#!/usr/bin/env python3

import struct
import sys
from pathlib import Path

class DNGVulnAnalyzer:
    """Locate SamplesPerPixel metadata and JPEG SOF3 headers in a DNG file.

    The analyzer walks the TIFF IFD chain (including SubIFDs), records the
    SamplesPerPixel tag value and the component count of every SOF3 frame
    header found in the file, and reports whether the two disagree.
    """

    # Byte size of one value of each TIFF field type. Unknown types fall back
    # to 4 bytes, i.e. the raw 32-bit value field is returned unchanged.
    _TIFF_TYPE_SIZES = {
        1: 1,   # BYTE
        2: 1,   # ASCII
        3: 2,   # SHORT
        4: 4,   # LONG
        5: 8,   # RATIONAL
        6: 1,   # SBYTE
        7: 1,   # UNDEFINED
        8: 2,   # SSHORT
        9: 4,   # SLONG
        10: 8,  # SRATIONAL
        11: 4,  # FLOAT
        12: 8,  # DOUBLE
        13: 4,  # IFD
    }

    def __init__(self, filepath):
        """Load *filepath* fully into memory.

        Raises:
            FileNotFoundError: if *filepath* does not exist.
        """
        self.filepath = Path(filepath)
        if not self.filepath.exists():
            raise FileNotFoundError(f"File not found: {filepath}")

        self.data = self.filepath.read_bytes()
        # Absolute position of the TIFF header; IFD offsets are relative to it.
        self.tiff_offset = 0
        # struct byte-order prefix: '<' for an "II" header, '>' for "MM".
        self.endian = '<'
        # Findings: 'samples_per_pixel' -> dict, 'jpeg_streams' -> list of dicts.
        self.vulnerability_data = {}

    def find_tiff_header(self):
        """Find the first TIFF header and set ``tiff_offset`` and ``endian``.

        When both byte-order signatures occur in the file, the earliest one
        wins (previously a little-endian match always won, even if it appeared
        later than the real big-endian header).

        Returns:
            True if a header was found, False otherwise.
        """
        candidates = [
            (self.data.find(b'II\x2a\x00'), '<', 'little-endian'),
            (self.data.find(b'MM\x00\x2a'), '>', 'big-endian'),
        ]
        found = [cand for cand in candidates if cand[0] != -1]
        if not found:
            print("No TIFF header found")
            return False

        pos, endian, label = min(found)
        self.tiff_offset = pos
        self.endian = endian
        print(f"TIFF header ({label}) at: 0x{pos:X}")
        return True

    def _entry_value(self, data_type, count, value_pos):
        """Decode the first value of an IFD entry whose 4-byte field is at *value_pos*.

        TIFF stores values narrower than 4 bytes left-justified in the field,
        so a SHORT must be read as 16 bits; reading 32 bits only happens to
        work for little-endian files. When the values do not fit in 4 bytes
        the field holds an offset, which is returned as a 32-bit integer.
        """
        size = self._TIFF_TYPE_SIZES.get(data_type, 4)
        if size >= 4 or count * size > 4:
            return struct.unpack_from(f'{self.endian}I', self.data, value_pos)[0]
        if size == 1:
            return self.data[value_pos]
        return struct.unpack_from(f'{self.endian}H', self.data, value_pos)[0]

    def _subifd_offsets(self, count, value):
        """Return the non-zero SubIFD offsets referenced by a SubIFDs (0x014A) entry.

        With a single SubIFD the value field is the offset itself; with several
        it is an offset (relative to the TIFF header) to an array of 32-bit
        IFD offsets. The array is clipped to the data actually present.
        """
        if count <= 1:
            offsets = [value]
        else:
            base = self.tiff_offset + value
            # Bound the loop by the bytes available so a corrupt count is cheap.
            available = max(0, (len(self.data) - base) // 4)
            offsets = [
                struct.unpack_from(f'{self.endian}I', self.data, base + 4 * k)[0]
                for k in range(min(count, available))
            ]
        return [off for off in offsets if off]

    def read_ifd(self, ifd_offset, ifd_name="IFD", _visited=None):
        """Print the interesting tags of one IFD and recurse into its SubIFDs.

        Records SamplesPerPixel (tag 0x0115) in ``vulnerability_data``; the
        last IFD that carries the tag wins.

        Args:
            ifd_offset: IFD offset relative to the TIFF header.
            ifd_name: label used in the printed report.
            _visited: set of IFD offsets already parsed in this walk, used to
                stop loops in malformed files. Created when omitted.

        Returns:
            The offset of the next IFD in the chain, or None when there is none
            (zero pointer, truncated data, or an already-visited IFD).
        """
        if _visited is None:
            _visited = set()

        abs_offset = self.tiff_offset + ifd_offset
        print(f"\nAnalyzing {ifd_name} at offset 0x{abs_offset:X}")

        if ifd_offset in _visited:
            print("   Already analyzed (IFD loop); skipping")
            return None
        _visited.add(ifd_offset)

        if abs_offset + 2 > len(self.data):
            print("Offset beyond file end")
            return None

        entry_count = struct.unpack_from(f'{self.endian}H', self.data, abs_offset)[0]
        print(f"   Entries: {entry_count}")

        compression_types = {
            1: "None",
            7: "JPEG",
            34892: "Lossy JPEG",  # DNG 1.4 value; this is NOT lossless JPEG
            34712: "JPEG 2000",
        }

        entries_start = abs_offset + 2
        entries_end = entries_start + 12 * entry_count
        subifd_offsets = []

        for entry_pos in range(entries_start, entries_end, 12):
            if entry_pos + 12 > len(self.data):
                break

            tag, data_type, count = struct.unpack_from(
                f'{self.endian}HHI', self.data, entry_pos)
            value_pos = entry_pos + 8
            value = self._entry_value(data_type, count, value_pos)

            if tag == 0x0115:
                self.vulnerability_data['samples_per_pixel'] = {
                    'value': value,
                    'offset': value_pos,
                    'absolute_offset': value_pos,
                }
                print(f"   SamplesPerPixel: {value} (at 0x{value_pos:X})")

            elif tag == 0x0103:
                compression = compression_types.get(value, f"Unknown ({value})")
                print(f"   Compression: {compression}")

            elif tag == 0x014A:
                subifd_offsets = self._subifd_offsets(count, value)
                for off in subifd_offsets:
                    print(f"   SubIFD at: 0x{off:X}")

            elif tag == 0x0100:
                print(f"   Width: {value}")

            elif tag == 0x0101:
                print(f"   Height: {value}")

        # The next-IFD pointer follows the full entry table; if the table was
        # truncated this check fails too, so no garbage pointer is read.
        next_ifd = None
        if entries_end + 4 <= len(self.data):
            next_ifd = struct.unpack_from(
                f'{self.endian}I', self.data, entries_end)[0] or None

        # SubIFDs are visited regardless of whether the next pointer exists.
        for k, off in enumerate(subifd_offsets):
            name = "SubIFD" if len(subifd_offsets) == 1 else f"SubIFD_{k}"
            self.read_ifd(off, name, _visited)

        return next_ifd

    def find_jpeg_lossless_streams(self):
        """Scan the whole file for SOF3 markers (FF C3) and record their headers.

        Every FF C3 byte pair is treated as a candidate; the surrounding JPEG
        structure is not validated, so false positives are possible. The
        reported 'length' is the SOF segment length field, not the size of
        the compressed stream. Results replace any previous scan's results.

        Returns:
            List of dicts with offset, length, precision, height, width,
            components and component_offset (absolute offset of the
            component-count byte).
        """
        print("\nSearching for JPEG Lossless streams...")
        sof3_marker = b'\xFF\xC3'
        found_streams = []

        pos = self.data.find(sof3_marker)
        while pos != -1:
            print(f"   SOF3 marker at: 0x{pos:X}")

            # marker(2) + length(2) + precision(1) + height(2) + width(2)
            # + components(1) = 10 bytes, i.e. indices pos .. pos+9.
            if pos + 10 <= len(self.data):
                length, precision, height, width, components = struct.unpack_from(
                    '>HBHHB', self.data, pos + 2)

                stream_info = {
                    'offset': pos,
                    'length': length,
                    'precision': precision,
                    'height': height,
                    'width': width,
                    'components': components,
                    'component_offset': pos + 9,
                }
                found_streams.append(stream_info)

                print(f"      Length: {length} bytes")
                print(f"      Precision: {precision} bits")
                print(f"      Dimensions: {width} x {height}")
                print(f"      Components: {components} (byte at 0x{pos+9:X})")

            pos = self.data.find(sof3_marker, pos + 1)

        # Assign rather than extend so re-running the scan is idempotent.
        self.vulnerability_data['jpeg_streams'] = found_streams

        if not found_streams:
            print("   No JPEG Lossless streams found")

        return found_streams

    def analyze_vulnerability(self):
        """Compare SamplesPerPixel against each SOF3 header's component count.

        Prints a report and flags every stream whose component count differs
        from the SamplesPerPixel metadata. Relies on read_ifd() and
        find_jpeg_lossless_streams() having populated ``vulnerability_data``.
        """
        print(f"\n{'='*60}")
        print("VULNERABILITY ANALYSIS")
        print(f"{'='*60}")

        samples_per_pixel = self.vulnerability_data.get('samples_per_pixel')
        jpeg_streams = self.vulnerability_data.get('jpeg_streams', [])

        if not samples_per_pixel:
            print("No SamplesPerPixel metadata found")
            return

        if not jpeg_streams:
            print("No JPEG Lossless streams found")
            return

        spp_value = samples_per_pixel['value']
        spp_offset = samples_per_pixel['absolute_offset']

        print("Current state:")
        print(f"   SamplesPerPixel (metadata): {spp_value}")
        print(f"   Metadata location: 0x{spp_offset:X}")

        for i, stream in enumerate(jpeg_streams, start=1):
            print(f"   JPEG Stream {i} components: {stream['components']}")
            print(f"   Component count location: 0x{stream['component_offset']:X}")

        vulnerable = False
        for stream in jpeg_streams:
            if spp_value != stream['components']:
                print("\nPOTENTIAL VULNERABILITY DETECTED!")
                print(f"   Metadata says: {spp_value} components")
                print(f"   JPEG stream says: {stream['components']} components")
                print("   This mismatch can cause buffer allocation/write issues!")
                vulnerable = True

        if not vulnerable:
            print("\nFile appears consistent (no immediate vulnerability)")
            print("   To create test case for research:")
            print(f"   1. Modify byte at 0x{spp_offset:X} (SamplesPerPixel)")
            for i, stream in enumerate(jpeg_streams, start=1):
                print(f"   2. Modify byte at 0x{stream['component_offset']:X} (Stream {i} components)")

    def generate_poc_offsets(self):
        """Print the recorded field offsets with their current values."""
        print("\nPOC GENERATION OFFSETS:")
        print(f"{'='*40}")

        samples_per_pixel = self.vulnerability_data.get('samples_per_pixel')
        jpeg_streams = self.vulnerability_data.get('jpeg_streams', [])

        if samples_per_pixel:
            current_val = samples_per_pixel['value']
            offset = samples_per_pixel['absolute_offset']
            suggested = current_val + 1 if current_val < 255 else current_val - 1
            print("SamplesPerPixel modification:")
            print(f"  Offset: 0x{offset:X}")
            print(f"  Current: 0x{current_val:02X}")
            print(f"  Suggested change: 0x{current_val:02X} -> 0x{suggested:02X}")

        for i, stream in enumerate(jpeg_streams, start=1):
            current_val = stream['components']
            offset = stream['component_offset']
            suggested = current_val - 1 if current_val > 1 else current_val + 1
            print(f"JPEG Stream {i} component count:")
            print(f"  Offset: 0x{offset:X}")
            print(f"  Current: 0x{current_val:02X}")
            print(f"  Suggested change: 0x{current_val:02X} -> 0x{suggested:02X}")

    def run_analysis(self, max_ifds=10):
        """Run the full pipeline: header, IFD chain, SOF3 scan, report.

        Args:
            max_ifds: maximum number of top-level IFDs to follow in the chain
                (SubIFDs are not counted). Defaults to 10.
        """
        print(f"Analyzing DNG file: {self.filepath}")
        print(f"File size: {len(self.data):,} bytes ({len(self.data)/1024/1024:.2f} MB)")

        if not self.find_tiff_header():
            return

        # Byte order (2) + magic (2) + first IFD offset (4).
        if self.tiff_offset + 8 > len(self.data):
            print("Truncated TIFF header")
            return

        first_ifd_offset = struct.unpack_from(
            f'{self.endian}I', self.data, self.tiff_offset + 4)[0]

        # Shared with read_ifd so loops through IFDs or SubIFDs terminate.
        visited = set()
        current_ifd = first_ifd_offset
        ifd_count = 0

        while current_ifd and ifd_count < max_ifds and current_ifd not in visited:
            current_ifd = self.read_ifd(current_ifd, f"IFD_{ifd_count}", visited)
            ifd_count += 1

        self.find_jpeg_lossless_streams()
        self.analyze_vulnerability()
        self.generate_poc_offsets()

def main():
    """Command-line entry point: analyze the DNG file named in ``sys.argv[1]``.

    Usage and error messages are written to stderr so they are not mixed
    with the analysis report on stdout. Exits with status 1 on a wrong
    argument count or on any error raised while loading/analyzing the file.
    """
    if len(sys.argv) != 2:
        print("Usage: python3 dng_vulnerability_analyzer.py <dng_file>", file=sys.stderr)
        print("\nExample: python3 dng_vulnerability_analyzer.py IMGP0847.DNG", file=sys.stderr)
        sys.exit(1)

    try:
        analyzer = DNGVulnAnalyzer(sys.argv[1])
        analyzer.run_analysis()
    except Exception as e:  # top-level boundary: report any failure and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"\n{'='*60}")
    print("SECURITY REMINDER:")
    print("- Only test modified files on isolated devices you own")
    print("- Never share modified files publicly")
    print("- Share analysis results and diffs only for research")
    print("- Follow responsible disclosure practices")
    print(f"{'='*60}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()