#include "as_parser.h"
#include "as_tokenizer.h"
#include "utils.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// BNF
// ===
//
// <prog> ::= <stmts>
// <stmts> ::= <stmt> newline | <stmt> newline <stmts>
// <stmt> ::= <label> <instr> | <instr> | <label>
// <instr> ::= <op> | <op> arg | <op> tag
// <label> ::= tag ":"
// <op> ::= "add" | "sub" | "mul" | "div" | "mod" | "eq" | ...

struct result parse_prog(struct allocator * alct, struct token_stream * ts) {
    struct result result;
    struct prog * p = allocate(alct, sizeof(struct prog));
    result = parse_stmts(alct, ts);
    if (result.errmsg != NULL) return result;
    p->stmts = result.value;
    return (struct result){.value = p, .errmsg = NULL};
}

// Parse statements until end of file:
//   <stmts> ::= <stmt> newline | <stmt> newline <stmts>
//
// Every statement must be followed by a newline token, which is consumed
// here. Blank lines (parse_stmt yields a NULL statement) are skipped.
//
// On success the result's value is a `struct stmts` whose `stmts` array is
// NULL-terminated. On failure the first error encountered is returned.
struct result parse_stmts(struct allocator * alct, struct token_stream * ts) {
    struct token *token;
    struct result result;
    const char* errmsg;
    struct stmts * ss = allocate(alct, sizeof(struct stmts));
    // `capacity` counts every slot of ss->stmts, including the one reserved
    // for the NULL terminator; the initial array holds exactly that slot.
    size_t capacity = 1;
    size_t len = 0;
    ss->stmts = allocate(alct, sizeof(*ss->stmts) * capacity);
    ss->stmts[0] = NULL;

    while (1) {
        result = peek_token(alct, ts);
        if (result.errmsg != NULL) return result;
        token = result.value;
        if (token->type == TK_ENDOFFILE) {
            break;
        }

        result = parse_stmt(alct, ts);
        if (result.errmsg != NULL) return result;
        struct stmt * s = result.value;

        // Every statement, including an empty one, ends with a newline.
        // For an empty line parse_stmt only peeked at that newline, so
        // consuming it here is what lets the loop make progress.
        result = peek_token(alct, ts);
        if (result.errmsg != NULL) return result;
        token = result.value;
        if (token->type != TK_NEWLINE) {
            errmsg = safe_sprintf(alct, "%d:%d expect newline.\n", token->line, token->col);
            return (struct result){.value = NULL, .errmsg = errmsg};
        }
        result = next_token(alct, ts);
        if (result.errmsg != NULL) return result;

        if (s == NULL) continue;  // blank line: nothing to store

        // Grow so that one free slot remains after the new element for the
        // NULL terminator written after the loop.
        if (len + 1 >= capacity) {
            size_t new_capacity = capacity * 2;
            void * new_stmts = allocate(alct, sizeof(*ss->stmts) * new_capacity);
            memcpy(new_stmts, ss->stmts, sizeof(*ss->stmts) * len);
            ss->stmts = new_stmts;
            capacity = new_capacity;
        }
        ss->stmts[len++] = s;
    }
    ss->stmts[len] = NULL;  // always in bounds: len < capacity
    return (struct result){.value = ss, .errmsg = NULL};
}

struct result parse_label(struct allocator * alct, struct token_stream * ts) {
    const char *errmsg;
    struct result result;
    struct token * t;
    result = next_token(alct, ts);
    if (result.errmsg != NULL) return result;
    t = result.value;
    if (t->type != TK_TAG) {
        errmsg = safe_sprintf(alct, "%d:%d expect label.\n", t->line, t->col);
        return (struct result){.value = NULL, .errmsg = errmsg};
    }
    struct label * l = allocate(alct, sizeof(struct label *));
    l->name = t->sval;
    result = next_token(alct, ts);
    if (result.errmsg != NULL) return result;
    t = result.value;
    if (t->type != TK_COLON) {
        errmsg = safe_sprintf(alct, "%d:%d expect colon.\n", t->line, t->col);
        return (struct result){.value = NULL, .errmsg = errmsg};
    }
    return (struct result){.value = l, .errmsg = NULL};
}

// Parse one statement: <stmt> ::= <label> <instr> | <instr> | <label>
//
// The trailing newline is only peeked, never consumed; the caller is
// expected to consume it. Results:
//   - value NULL, errmsg NULL: the line is blank (next token is a newline);
//   - value `struct stmt *`: `label` and/or `instr` set, the other NULL;
//   - errmsg set: the statement is malformed or a sub-parser failed.
struct result parse_stmt(struct allocator * alct, struct token_stream * ts) {
    const char *errmsg;
    struct result result;
    struct token * t;
    result = peek_token(alct, ts);
    if (result.errmsg != NULL) return result;
    t = result.value;

    // Blank line: report "no statement" without allocating one.
    if (t->type == TK_NEWLINE) {
        return (struct result){.value = NULL, .errmsg = NULL};
    }

    struct stmt * stmt = allocate(alct, sizeof(*stmt));
    stmt->label = NULL;
    stmt->instr = NULL;

    if (t->type == TK_TAG) {
        result = parse_label(alct, ts);
        if (result.errmsg != NULL) return result;
        stmt->label = result.value;
        // This peek serves both the label-only check and the op check below.
        result = peek_token(alct, ts);
        if (result.errmsg != NULL) return result;
        t = result.value;
        if (t->type == TK_NEWLINE) {
            return (struct result){.value = stmt, .errmsg = NULL};
        }
    }
    if (t->type == TK_OP) {
        result = parse_instr(alct, ts);
        if (result.errmsg != NULL) return result;
        stmt->instr = result.value;

        result = peek_token(alct, ts);
        if (result.errmsg != NULL) return result;
        t = result.value;
        if (t->type == TK_NEWLINE) {
            return (struct result){.value = stmt, .errmsg = NULL};
        }
    }
    errmsg = safe_sprintf(alct, "%d:%d expect lable + instruction, lable, or instruction.\n", t->line, t->col);
    return (struct result){.value = NULL, .errmsg = errmsg};
}

struct result parse_op(struct allocator * alct, struct token_stream * ts) {
    const char *errmsg;
    struct result result;
    struct token * t;
    result = next_token(alct, ts);
    if (result.errmsg != NULL) return result;
    t = result.value;
    enum op op;
    if (t->type == TK_OP) {
        op = str2op(t->sval);
        if (op == OP_END) {
            errmsg = safe_sprintf(alct, "%d:%d invalid op.\n", t->line, t->col);
            return (struct result){.value = NULL, .errmsg = errmsg};
        }
    } else {
        errmsg = safe_sprintf(alct, "%d:%d expect op.\n", t->line, t->col);
        return (struct result){.value = NULL, .errmsg = errmsg};
    }
    return (struct result){.value = (void*)op, .errmsg = NULL};
}

// Parse an instruction: <instr> ::= <op> | <op> arg | <op> tag
//
// If the next token is not TK_OP, nothing is consumed and an instr with
// op == OP_END and no operand is returned. Otherwise the op is parsed and,
// if the following token is TK_ARG or TK_TAG, that operand is consumed and
// stored (`arg` gets the token's ival/fval; `tag_name` points at its sval,
// not copied). Errors from parse_op and the tokenizer are propagated.
struct result parse_instr(struct allocator * alct, struct token_stream * ts) {
    struct result result;
    struct token * t;
    result = peek_token(alct, ts);
    if (result.errmsg != NULL) return result;
    t = result.value;
    struct instr * i = allocate(alct, sizeof(*i));
    i->tag_name = NULL;
    i->arg = NULL;
    i->op = OP_END;
    if (t->type != TK_OP) {
        return (struct result){.value = i, .errmsg = NULL};
    }

    result = parse_op(alct, ts);
    // Previously ignored: an invalid op would silently become (enum op)NULL.
    if (result.errmsg != NULL) return result;
    i->op = (enum op)(intptr_t)result.value;

    result = peek_token(alct, ts);
    if (result.errmsg != NULL) return result;
    t = result.value;
    if (t->type == TK_ARG) {
        struct arg * a = allocate(alct, sizeof(*a));
        a->ival = t->ival;
        a->fval = t->fval;
        i->arg = a;
    } else if (t->type == TK_TAG) {
        i->tag_name = t->sval;
    } else {
        // No operand; leave the token for the caller.
        return (struct result){.value = i, .errmsg = NULL};
    }
    // Consume the operand token peeked above, surfacing tokenizer errors.
    result = next_token(alct, ts);
    if (result.errmsg != NULL) return result;
    return (struct result){.value = i, .errmsg = NULL};
}