#!/usr/bin/env python3
"""
LDAP Honeypot Server
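Answers LDAP BIND and SEARCH requests with canned success responses and logs
the credentials sent in BIND requests. A password captured from the third
request is additionally fetched as a URL (GET, TLS verification disabled).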
"""

import socket
import struct
from datetime import datetime

# LDAP Protocol Constants
# Every protocolOp tag is 0x40 (APPLICATION class) | 0x20 (constructed)
# | the application number RFC 4511 assigns to the operation.
LDAP_BIND_REQUEST = 0x40 | 0x20 | 0         # [APPLICATION 0] -> 0x60
LDAP_BIND_RESPONSE = 0x40 | 0x20 | 1        # [APPLICATION 1] -> 0x61
LDAP_SEARCH_REQUEST = 0x40 | 0x20 | 3       # [APPLICATION 3] -> 0x63
LDAP_SEARCH_RESULT_ENTRY = 0x40 | 0x20 | 4  # [APPLICATION 4] -> 0x64
LDAP_SEARCH_RESULT_DONE = 0x40 | 0x20 | 5   # [APPLICATION 5] -> 0x65

# LDAP Result Codes
RESULT_SUCCESS = 0  # resultCode "success"


def encode_ber_length(length):
    """Return the BER definite-form encoding of *length*.

    Lengths below 128 use the single-byte short form. Larger lengths use the
    long form: one byte ``0x80 | n`` followed by the *n* big-endian bytes of
    the length.
    """
    if length >= 128:
        n_bytes = (length.bit_length() + 7) // 8
        return bytes([0x80 | n_bytes]) + length.to_bytes(n_bytes, 'big')
    return bytes([length])


def encode_ber_integer(value):
    """Encode *value* as a BER INTEGER TLV (tag 0x02).

    The content octets are the minimal two's-complement big-endian form that
    X.690 requires, e.g. 0 -> 00, 127 -> 7F, 128 -> 00 80, -128 -> 80,
    -129 -> FF 7F.
    """
    # A non-negative v needs v.bit_length() + 1 bits (room for a 0 sign bit);
    # a negative v needs (~v).bit_length() + 1 bits. Rounding up to whole
    # octets gives bit_length // 8 + 1, which is always at least one octet.
    # Using abs(v) for negatives would over-allocate for -128, -32768, ...
    magnitude = value if value >= 0 else ~value
    byte_length = magnitude.bit_length() // 8 + 1
    value_bytes = value.to_bytes(byte_length, 'big', signed=True)
    return b'\x02' + encode_ber_length(len(value_bytes)) + value_bytes


def encode_ber_octet_string(data):
    """Wrap *data* as a BER OCTET STRING TLV (tag 0x04).

    ``str`` input is UTF-8 encoded first; bytes-like input is used as-is.
    """
    payload = data.encode('utf-8') if isinstance(data, str) else data
    return b''.join((b'\x04', encode_ber_length(len(payload)), payload))


def encode_ber_sequence(data):
    """Wrap the already-encoded members in *data* as a BER SEQUENCE (tag 0x30)."""
    header = b'\x30' + encode_ber_length(len(data))
    return header + data


def parse_ber_length(data, offset):
    """Parse a BER definite-form length starting at *offset*.

    Returns:
        ``(length, new_offset)``, where *new_offset* indexes the first
        content byte of the element.

    Raises:
        ValueError: if *offset* is at or past the end of *data*, the length
            uses the indefinite (0x80) or reserved (0xFF) form, or the
            long-form length bytes are truncated.
    """
    if offset >= len(data):
        raise ValueError("Truncated BER length")
    first_byte = data[offset]
    if first_byte < 0x80:
        return first_byte, offset + 1

    num_bytes = first_byte & 0x7F
    # 0x80 is the indefinite form, which LDAP never uses (RFC 4511 section 5.1);
    # 0xFF is reserved by X.690. Treating either as a real length yields garbage.
    if num_bytes == 0 or num_bytes == 0x7F:
        raise ValueError(f"Unsupported BER length form: 0x{first_byte:02x}")
    end = offset + 1 + num_bytes
    if end > len(data):
        raise ValueError("Truncated BER long-form length")
    return int.from_bytes(data[offset + 1:end], 'big'), end


def parse_ber_integer(data, offset):
    """Parse a BER INTEGER (tag 0x02) starting at *offset*.

    Returns:
        ``(value, new_offset)`` with *value* decoded as signed
        two's-complement big-endian, and *new_offset* just past the element.

    Raises:
        ValueError: if there is no INTEGER tag at *offset*, the content is
            empty, or the content runs past the end of *data*.
    """
    if offset >= len(data) or data[offset] != 0x02:
        raise ValueError("Not an integer")
    length, offset = parse_ber_length(data, offset + 1)
    end = offset + length
    # X.690 requires at least one content octet; a short slice would
    # otherwise decode silently to a wrong value.
    if length == 0 or end > len(data):
        raise ValueError("Truncated or empty BER integer")
    return int.from_bytes(data[offset:end], 'big', signed=True), end


def parse_ber_octet_string(data, offset):
    """Parse a BER OCTET STRING (tag 0x04) starting at *offset*.

    Returns:
        ``(raw_bytes, new_offset)``. The content is returned undecoded, and
        *new_offset* points just past the element.

    Raises:
        ValueError: if there is no OCTET STRING tag at *offset* or the
            content runs past the end of *data*.
    """
    if offset >= len(data) or data[offset] != 0x04:
        raise ValueError("Not an octet string")
    length, offset = parse_ber_length(data, offset + 1)
    end = offset + length
    if end > len(data):
        # A silently short slice would log a truncated DN as if it were complete.
        raise ValueError("Truncated BER octet string")
    return data[offset:end], end


def parse_bind_request(data):
    """Extract ``(message_id, dn, password)`` from a raw LDAP BindRequest.

    Expected layout::

        SEQUENCE { messageID INTEGER,
                   [APPLICATION 0] { version INTEGER, name OCTET STRING,
                                     authentication ... } }

    The DN and password are decoded as UTF-8, dropping undecodable bytes.
    *password* is ``None`` unless the authentication choice is simple
    (context tag 0x80). ``(None, None, None)`` is returned when *data* is not
    a SEQUENCE or the operation is not a BindRequest. Malformed BER
    propagates the BER helpers' exceptions.
    """
    not_a_bind = (None, None, None)

    if data[0] != 0x30:  # LDAPMessage must be a SEQUENCE
        return not_a_bind
    _, pos = parse_ber_length(data, 1)
    message_id, pos = parse_ber_integer(data, pos)

    if data[pos] != LDAP_BIND_REQUEST:
        return not_a_bind
    _, pos = parse_ber_length(data, pos + 1)

    _version, pos = parse_ber_integer(data, pos)
    raw_dn, pos = parse_ber_octet_string(data, pos)

    password = None
    has_simple_auth = pos < len(data) and data[pos] == 0x80
    if has_simple_auth:
        pw_len, pos = parse_ber_length(data, pos + 1)
        password = data[pos:pos + pw_len].decode('utf-8', errors='ignore')

    return message_id, raw_dn.decode('utf-8', errors='ignore'), password


def _read_ber_tlv(data, offset):
    """Return ``(tag, value_bytes, end_offset)`` for the BER element at *offset*.

    Assumes single-byte tags, which covers every tag RFC 4511 defines.

    Raises:
        IndexError: if *offset* is past the end of *data*.
        ValueError: if the element's content runs past the end of *data*.
    """
    tag = data[offset]
    length, start = parse_ber_length(data, offset + 1)
    end = start + length
    if end > len(data):
        raise ValueError("Truncated BER element")
    return tag, data[start:end], end


def _render_ldap_filter(data, offset=0):
    """Render the BER-encoded LDAP Filter at *offset* as an RFC 4515 string.

    Returns ``(text, end_offset)``. and/or/not, equality, >=, <=, ~=,
    substrings and presence filters are rendered. Other kinds (e.g.
    extensibleMatch) are shown as ``(?tag=0xNN)``. Values are decoded as
    UTF-8 with replacement characters and are not re-escaped.
    """
    comparisons = {0xA3: '=', 0xA5: '>=', 0xA6: '<=', 0xA8: '~='}
    tag, value, end = _read_ber_tlv(data, offset)

    def text(raw):
        return raw.decode('utf-8', errors='replace')

    if tag in (0xA0, 0xA1):  # and [0] / or [1]: SET OF Filter
        parts = []
        pos = 0
        while pos < len(value):
            part, pos = _render_ldap_filter(value, pos)
            parts.append(part)
        operator = '&' if tag == 0xA0 else '|'
        return '(' + operator + ''.join(parts) + ')', end
    if tag == 0xA2:  # not [2] Filter
        inner, _ = _render_ldap_filter(value)
        return '(!' + inner + ')', end
    if tag in comparisons:  # AttributeValueAssertion { attributeDesc, assertionValue }
        _, attr, pos = _read_ber_tlv(value, 0)
        _, assertion, _ = _read_ber_tlv(value, pos)
        return f"({text(attr)}{comparisons[tag]}{text(assertion)})", end
    if tag == 0xA4:  # substrings [4] { type, SEQUENCE OF initial/any/final }
        _, attr, pos = _read_ber_tlv(value, 0)
        _, chunks, _ = _read_ber_tlv(value, pos)
        initial, middle, final = '', [], ''
        pos = 0
        while pos < len(chunks):
            chunk_tag, chunk, pos = _read_ber_tlv(chunks, pos)
            if chunk_tag == 0x80:
                initial = text(chunk)
            elif chunk_tag == 0x81:
                middle.append(text(chunk))
            elif chunk_tag == 0x82:
                final = text(chunk)
        pattern = '*'.join([initial] + middle + [final])
        return f"({text(attr)}={pattern})", end
    if tag == 0x87:  # present [7] AttributeDescription (primitive)
        return f"({text(value)}=*)", end
    return f"(?tag=0x{tag:02x})", end


def parse_search_request(data):
    """Parse an LDAP SearchRequest into ``(message_id, base_dn, filter_info)``.

    *filter_info* is the search filter rendered in RFC 4515 string form, e.g.
    ``(&(objectClass=user)(uid=alice))``. If the filter cannot be decoded, up
    to 50 bytes following the base DN are returned as lossy UTF-8 text
    instead. ``(None, None, None)`` is returned when *data* is not a SEQUENCE
    or the operation is not a SearchRequest. Malformed BER in the message
    header or base DN propagates the BER helpers' exceptions.
    """
    offset = 0

    # LDAPMessage SEQUENCE header
    if data[offset] != 0x30:
        return None, None, None
    offset += 1
    _, offset = parse_ber_length(data, offset)

    message_id, offset = parse_ber_integer(data, offset)

    if data[offset] != LDAP_SEARCH_REQUEST:
        return None, None, None
    offset += 1
    _, offset = parse_ber_length(data, offset)

    base_dn, offset = parse_ber_octet_string(data, offset)
    base_dn = base_dn.decode('utf-8', errors='ignore')

    try:
        # scope, derefAliases, sizeLimit, timeLimit and typesOnly precede the
        # filter in a SearchRequest; skip those five elements to reach it.
        filter_offset = offset
        for _ in range(5):
            _, _, filter_offset = _read_ber_tlv(data, filter_offset)
        filter_info, _ = _render_ldap_filter(data, filter_offset)
    except (IndexError, ValueError, RecursionError):
        # Undecodable (or absurdly nested) filter: fall back to raw bytes.
        filter_info = data[offset:offset + 50].decode('utf-8', errors='ignore')

    return message_id, base_dn, filter_info


def _encode_ldap_result(op_tag, message_id, result_code=RESULT_SUCCESS):
    """Build an LDAPMessage whose protocolOp is an LDAPResult tagged *op_tag*.

    LDAPResult ::= SEQUENCE {
        resultCode         ENUMERATED,
        matchedDN          LDAPDN,
        diagnosticMessage  LDAPString }

    matchedDN and diagnosticMessage are sent empty.
    """
    # resultCode is ENUMERATED (universal tag 0x0A), not INTEGER (0x02).
    # Both share the same length/content encoding, so reuse the INTEGER
    # encoder and swap its tag byte.
    encoded_result = b'\x0a' + encode_ber_integer(result_code)[1:]
    content = (
        encoded_result
        + encode_ber_octet_string(b'')  # matchedDN
        + encode_ber_octet_string(b'')  # diagnosticMessage
    )
    protocol_op = bytes([op_tag]) + encode_ber_length(len(content)) + content
    return encode_ber_sequence(encode_ber_integer(message_id) + protocol_op)


def create_bind_response(message_id):
    """Create an LDAP BindResponse ([APPLICATION 1]) reporting success."""
    return _encode_ldap_result(LDAP_BIND_RESPONSE, message_id)


def create_search_result_entry(message_id, dn):
    """Create an LDAP SearchResultEntry for *dn* with no attributes.

    SearchResultEntry ::= [APPLICATION 4] SEQUENCE {
        objectName  LDAPDN,
        attributes  PartialAttributeList }   -- sent as an empty SEQUENCE
    """
    entry_content = encode_ber_octet_string(dn) + encode_ber_sequence(b'')
    search_entry = (
        bytes([LDAP_SEARCH_RESULT_ENTRY])
        + encode_ber_length(len(entry_content))
        + entry_content
    )
    return encode_ber_sequence(encode_ber_integer(message_id) + search_entry)


def create_search_result_done(message_id):
    """Create an LDAP SearchResultDone ([APPLICATION 5]) reporting success."""
    return _encode_ldap_result(LDAP_SEARCH_RESULT_DONE, message_id)


def _ldap_op_tag(data):
    """Return the protocolOp tag byte of the LDAPMessage in *data*, or None.

    An LDAPMessage is ``SEQUENCE { messageID INTEGER, protocolOp ... }``, so
    the operation tag is the byte right after the message ID. Reading it from
    that position, instead of searching the whole buffer for a tag value,
    avoids misclassifying a request whose length or message-ID byte happens
    to equal 0x60 or 0x63. Malformed headers yield None.
    """
    try:
        if data[0] != 0x30:
            return None
        _, offset = parse_ber_length(data, 1)
        _, offset = parse_ber_integer(data, offset)
        return data[offset]
    except (IndexError, ValueError):
        return None


def _probe_captured_url(target):
    """Send one GET request to *target* and log the outcome; never raises.

    ``https://`` is prepended when *target* has no scheme.

    NOTE(review): *target* is the BIND password, i.e. fully client-controlled.
    This makes the honeypot host request an arbitrary URL, including internal
    or link-local addresses it can reach (urllib also accepts schemes such as
    file: and ftp:). TLS certificate verification is deliberately disabled.
    Only run this on an isolated network.
    """
    try:
        import ssl
        import urllib.request
        from urllib.parse import urlparse

        url_to_fetch = target
        if not urlparse(url_to_fetch).scheme:
            url_to_fetch = 'https://' + url_to_fetch
            print(f"🔧 No scheme found — prepending https:// -> {url_to_fetch}")

        print(f"📡 Attempting GET request to: {url_to_fetch}")
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        # The context manager closes the HTTP response once the status is read.
        with urllib.request.urlopen(url_to_fetch, context=ctx, timeout=5) as response:
            status_code = response.getcode()
        print(f"✅ GET request successful! Status code: {status_code}")
    except Exception as url_error:  # best-effort probe: log and carry on
        print(f"❌ GET request failed: {url_error}")


def _answer_search(conn, data):
    """Log a SEARCH request and reply with one synthetic entry plus SearchResultDone."""
    message_id, base_dn, filter_info = parse_search_request(data)
    if message_id is None:
        return

    print("SEARCH Request:")
    print(f"  Message ID: {message_id}")
    print(f"  Base DN: {base_dn}")
    print(f"  Filter: {filter_info[:50]}...")

    entry_dn = base_dn if base_dn else "cn=fortigate_user,dc=example,dc=com"
    conn.sendall(create_search_result_entry(message_id, entry_dn))
    print("  -> Sent SEARCH Result Entry")

    conn.sendall(create_search_result_done(message_id))
    print("  -> Sent SEARCH Result Done")


def handle_client(conn, addr, timeout=30.0):
    """Serve one LDAP client connection, logging BIND and SEARCH requests.

    Up to six ``recv`` chunks are processed. Each chunk is treated as one
    complete LDAPMessage; split or pipelined messages are not reassembled.
    BIND requests get a success BindResponse. SEARCH requests get one
    synthetic entry followed by SearchResultDone. If the third chunk is a
    BIND carrying a password, that password is recorded, probed as a URL,
    and repeated in a summary when the connection closes.

    Args:
        conn: Connected client socket. It is always closed before returning.
        addr: ``(host, port)`` of the peer, used only for logging.
        timeout: Seconds a single socket operation may block. The server
            handles one client at a time, so without a limit an idle client
            would stall every other connection. ``None`` blocks indefinitely.
    """
    print(f"\n[{datetime.now()}] New connection from {addr[0]}:{addr[1]}")

    captured_password = None
    request_count = 0

    try:
        conn.settimeout(timeout)
        while request_count < 6:
            data = conn.recv(4096)
            if not data:
                break

            request_count += 1
            print(f"\n--- Request #{request_count} ---")

            # Too short to hold a BIND or SEARCH request; count it and move on.
            if len(data) < 10:
                continue

            op_tag = _ldap_op_tag(data)

            if op_tag == LDAP_BIND_REQUEST:
                message_id, dn, password = parse_bind_request(data)
                if message_id is None:
                    continue

                print("BIND Request:")
                print(f"  Message ID: {message_id}")
                print(f"  DN: {dn}")
                print(f"  Password: {password}")

                # Request #3 is expected to be the second bind (user authentication).
                if request_count == 3 and password:
                    captured_password = password
                    print(f"\n🔑 CAPTURED PASSWORD: {captured_password}")
                    _probe_captured_url(captured_password)

                # sendall: a plain send() may transmit only part of the buffer.
                conn.sendall(create_bind_response(message_id))
                print("  -> Sent BIND Response (Success)")

            elif op_tag == LDAP_SEARCH_REQUEST:
                _answer_search(conn, data)

    except socket.timeout:
        print(f"Client {addr[0]}:{addr[1]} idle for {timeout}s, dropping connection")
    except Exception as e:
        print(f"Error handling client: {e}")
        import traceback
        traceback.print_exc()

    finally:
        conn.close()
        print(f"\n[{datetime.now()}] Connection closed from {addr[0]}:{addr[1]}")
        if captured_password:
            print(f"\n{'='*60}")
            print(f"SESSION SUMMARY - Captured Password: {captured_password}")
            print(f"{'='*60}\n")


def main(host='0.0.0.0', port=389):
    """Run the server, handling one client at a time until Ctrl+C.

    Args:
        host: Interface address to bind. The default listens on all IPv4
            interfaces.
        port: TCP port to listen on. The default, 389, is the standard LDAP
            port; on most Unix systems ports below 1024 need elevated
            privileges.
    """
    # The with-block closes the listening socket on every exit path, including
    # a failed bind()/listen() (e.g. port in use or permission denied).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((host, port))
        server.listen(5)

        print(f"LDAP Honeypot Server started on {host}:{port}")
        print("Waiting for connections...\n")

        try:
            while True:
                conn, addr = server.accept()
                # Clients are served sequentially; handle_client closes conn.
                handle_client(conn, addr)
        except KeyboardInterrupt:
            print("\n\nShutting down server...")


# Start the server only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
