
#include <algorithm>
#include <array>
#include <charconv>
#include <chrono>
#include <fstream>
#include <llvm/IR/BasicBlock.h>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <variant>
#include <vector>
// I still hate iostream.
#include <iostream>

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/DenseSet.h>
#include <llvm/ADT/PostOrderIterator.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringMap.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/Allocator.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Transforms/Utils/BasicBlockUtils.h>
#include <llvm/Transforms/Utils/Local.h>

/// Branch-prediction hint: marks condition `x` as expected to be false.
/// `!!` normalizes the argument to 0/1 first, so any contextually boolean
/// expression (e.g. a pointer) can be passed, not only integral ones.
#define UNLIKELY(x) (__builtin_expect(!!(x), 0))

namespace {

/// A source range described by a start and an end offset.
struct SrcLoc {
  std::size_t start, end;

  /// Smallest range that covers both this range and `other`.
  SrcLoc operator|(const SrcLoc &other) const {
    const std::size_t lo = other.start < start ? other.start : start;
    const std::size_t hi = end < other.end ? other.end : end;
    return SrcLoc{lo, hi};
  }
};

/// A diagnostic: the source range it refers to and a human-readable message.
struct Error {
  SrcLoc loc;       // source range the diagnostic points at
  std::string desc; // message text
};

/// A lexical token: its kind and the source range it was read from.
struct Token {
  // The enumerators are kept dense and end with MAX (the number of kinds), so
  // that a kind can serve as an array index. Append new kinds before MAX.
  enum Kind {
    // clang-format off
    T_EOF, ERROR,

    // Operators
    COMMA, SEMICOLON, AT, LPAREN, RPAREN, LCURLY, RCURLY, LBRACKET, RBRACKET,
    LANGLE, RANGLE, EQ, EQEQ, LE, GE, SL, SR, INV, XOR, PLUS, PLUSEQ, MINUS,
    MINUSEQ, TIMES, TIMESEQ, DIV, DIVEQ, REM, REMEQ, NOT, NOTEQ, AND, ANDAND,
    OR, OROR,

    // Generic things..
    NUMBER, IDENT,

    // Keywords
    AUTO, REGISTER, IF, ELSE, WHILE, RETURN,

    MAX
    // clang-format on
  };

  Kind kind;  // what was recognized
  SrcLoc loc; // where it was recognized
};

/// Splits the program text into tokens on demand (see next()) and maps
/// character offsets back to line/column positions for diagnostics.
class Tokenizer {
  const std::string_view prog;
  /// Offset of the next unread character in prog.
  std::size_t pos;
  /// Offsets of the line-break characters consumed so far, in ascending order.
  std::vector<std::size_t> lineBreaks;

public:
  Tokenizer(std::string_view prog) : prog(prog), pos(0) {}

private:
  enum { W = 1, L, I, N, S, Q, C };
  // Lookup table to classify characters. 0=invalid, W=whitespace, L=line break,
  // I=identifier start, N=number, S=special symbol, Q=quote, C=maybe comment.
  // Only the first 128 entries are listed; the remaining ones are zero, i.e.
  // invalid.
  static constexpr uint8_t CHAR_LUT[1 << CHAR_BIT] = {
      // clang-format off
      0,0,0,0,0,0,0,0,0,W,L,0,0,W,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
      W,S,Q,S,S,S,S,Q,S,S,S,S,S,S,S,C,N,N,N,N,N,N,N,N,N,N,S,S,S,S,S,S,
      S,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,S,S,S,S,I,
      S,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,I,S,S,S,S,0,
      // clang-format on
  };

  /// Consume an identifier or keyword starting at pos.
  Token tokenizeIdentifier() {
    std::size_t start = pos;
    // Identifiers continue with identifier characters or digits.
    for (; pos < prog.size(); pos++) {
      auto lv = CHAR_LUT[(unsigned char)prog[pos]];
      if (lv != I && lv != N)
        break;
    }
    std::string_view val = prog.substr(start, pos - start);
    auto kind = Token::IDENT;
    // There are only a few keywords => the inefficient approach is good enough.
    if (val == "auto")
      kind = Token::AUTO;
    else if (val == "register")
      kind = Token::REGISTER;
    else if (val == "if")
      kind = Token::IF;
    else if (val == "else")
      kind = Token::ELSE;
    else if (val == "while")
      kind = Token::WHILE;
    else if (val == "return")
      kind = Token::RETURN;
    return Token{kind, SrcLoc{start, pos}};
  }

  /// Consume a run of decimal digits starting at pos.
  Token tokenizeNumber() {
    std::size_t start = pos;
    while (pos < prog.size() && CHAR_LUT[(unsigned char)prog[pos]] == N)
      pos++;
    return Token{Token::NUMBER, SrcLoc{start, pos}};
  }

  /// Consume a one- or two-character operator starting at pos. Symbols that
  /// are not operators of the language yield a one-character ERROR token.
  Token tokenizeOperator() {
    auto kind = Token::ERROR;
    // One character of lookahead; 0 at the end of the input.
    auto next = pos + 1 < prog.size() ? prog[pos + 1] : 0;
    std::size_t start = pos;
    std::size_t len = 1; // bumped to 2 when a two-character operator matches
    switch (prog[pos]) {
    case ',': kind = Token::COMMA; break;
    case ';': kind = Token::SEMICOLON; break;
    case '@': kind = Token::AT; break;
    case '(': kind = Token::LPAREN; break;
    case ')': kind = Token::RPAREN; break;
    case '[': kind = Token::LBRACKET; break;
    case ']': kind = Token::RBRACKET; break;
    case '{': kind = Token::LCURLY; break;
    case '}': kind = Token::RCURLY; break;
    case '<':
      kind = next == '='   ? (len++, Token::LE)
             : next == '<' ? (len++, Token::SL)
                           : Token::LANGLE;
      break;
    case '>':
      kind = next == '='   ? (len++, Token::GE)
             : next == '>' ? (len++, Token::SR)
                           : Token::RANGLE;
      break;
    case '=': kind = next == '=' ? (len++, Token::EQEQ) : Token::EQ; break;
    case '!': kind = next == '=' ? (len++, Token::NOTEQ) : Token::NOT; break;
    case '+': kind = next == '=' ? (len++, Token::PLUSEQ) : Token::PLUS; break;
    case '-':
      kind = next == '=' ? (len++, Token::MINUSEQ) : Token::MINUS;
      break;
    case '*':
      kind = next == '=' ? (len++, Token::TIMESEQ) : Token::TIMES;
      break;
    case '/': kind = next == '=' ? (len++, Token::DIVEQ) : Token::DIV; break;
    case '%': kind = next == '=' ? (len++, Token::REMEQ) : Token::REM; break;
    case '~': kind = Token::INV; break;
    case '^': kind = Token::XOR; break;
    case '&': kind = next == '&' ? (len++, Token::ANDAND) : Token::AND; break;
    case '|': kind = next == '|' ? (len++, Token::OROR) : Token::OR; break;
    }
    pos += len;
    return Token{kind, {start, pos}};
  }

public:
  /// Return the next token, skipping whitespace and `//` line comments.
  /// Returns T_EOF (repeatedly) once the input is exhausted. Every other
  /// token, including ERROR, consumes at least one character.
  Token next() {
    while (pos < prog.size()) {
      unsigned char c = prog[pos];
      switch (CHAR_LUT[c]) {
      case W: pos++; continue;
      case L: lineBreaks.push_back(pos++); continue;
      case I: return tokenizeIdentifier();
      case N: return tokenizeNumber();
      case S: return tokenizeOperator();
      case C: // slash, might be a comment
        if (pos + 1 < prog.size() && prog[pos + 1] == '/')
          // Skip up to (not including) the line break, so that the line break
          // itself is still recorded by the L case above.
          for (pos += 2; pos < prog.size() && prog[pos] != '\n'; pos++) {
          }
        else
          return tokenizeOperator();
        continue;
      default: {
        // Invalid character. Note: quotation marks ' and " (Q) are caught
        // here, too. The character must be consumed: otherwise every
        // following call would return the very same ERROR token and callers
        // that skip tokens until they find a specific one would never
        // terminate.
        std::size_t start = pos++;
        return Token{Token::ERROR, {start, pos}};
      }
      }
    }
    return Token{Token::T_EOF, {pos, pos}};
  }

  /// Find line and column (both zero-based) of a character offset. Only line
  /// breaks that were already consumed by next() are taken into account.
  /// Runtime O(log n).
  std::pair<std::size_t, std::size_t> locate(std::size_t pos) const {
    // The number of recorded line breaks before pos is the line index.
    auto lineIt = std::lower_bound(lineBreaks.begin(), lineBreaks.end(), pos);
    std::size_t line = (lineIt - lineBreaks.begin());
    return std::make_pair(line, line ? pos - *(lineIt - 1) - 1 : pos);
  }

  /// Extract a line (zero-based index, without the line break) from the
  /// input. The line break preceding the requested line must already have
  /// been consumed. Runtime O(1) if tokenized, otw. O(n).
  std::string_view line(std::size_t line) const {
    std::size_t start = line ? lineBreaks[line - 1] + 1 : 0;
    if (line < lineBreaks.size())
      return prog.substr(start, lineBreaks[line] - start);
    // The rest of the input has not necessarily been tokenized yet, so the end
    // of the line has to be searched for.
    std::string_view lineStr = prog.substr(start);
    return lineStr.substr(0, lineStr.find('\n')); // no \n => whole line
  }
};

/// Identifier descriptor packed into one integer: the identifier id occupies
/// the upper 63 bits, the lowest bit stores the "addressable" flag.
struct IdentDesc {
  uint64_t value;

  IdentDesc(uint64_t iid, bool addressable = false)
      : value((iid << 1) | static_cast<uint64_t>(addressable)) {}

  /// Whether the flag bit is set.
  bool addressable() const { return (value & 1u) != 0; }

  /// The identifier id without the flag bit.
  uint64_t id() const { return value >> 1; }
};

/// Node of the abstract syntax tree. Children form a singly linked list:
/// `child` points to the first child, `sibling` to the next node on the same
/// level. The anonymous union carries the kind-specific payload.
class ASTNode {
public:
  enum Kind {
    BLOCK,   // children: statements
    DECLREG, // children: initializer; payload: ident
    DECLVAR, // children: initializer; payload: ident
    IF,      // children: condition, then, optional else
    WHILE,   // children: condition, body
    NUMBER,  // payload: numVal
    IDENT,   // payload: ident
    CALL,    // children: arguments; payload: fnId
    NEG,
    NOT,
    INV,
    ADDROF, // children: operand
    ADD,
    SUB,
    MUL,
    DIV,
    REM,
    SL,
    SR,
    LT,
    GT,
    LE,
    GE,
    EQ,
    NOTEQ,
    AND,
    OR,
    XOR,
    ANDAND,
    OROR,
    ASSIGN,    // children: target, value
    SUBSCRIPT, // children: base, index; payload: granularity
    RETURN,    // children: optional value
    ERROR,     // placeholder, only produced for invalid inputs
  };

  const Kind kind;
  const SrcLoc loc;
  ASTNode *child;
  ASTNode *sibling;
  union {
    uint64_t granularity;
    uint64_t numVal;
    uint64_t fnId;
    IdentDesc ident;
  };

  ASTNode(Kind kind, SrcLoc loc, ASTNode *child)
      : kind(kind), loc(loc), child(child), sibling(nullptr) {}

  /// Print the subtree rooted at this node to std::cout as an S-expression.
  void printSexpr() const {
    // Leaves are printed bare, without parentheses.
    if (kind == IDENT) {
      std::cout << "v" << ident.id();
      return;
    }
    if (kind == NUMBER) {
      std::cout << numVal;
      return;
    }

    std::cout << "(" << kindName(kind);
    // Payload of inner nodes, printed right after the operator name.
    if (kind == CALL)
      std::cout << " f" << fnId;
    if (kind == DECLVAR || kind == DECLREG)
      std::cout << " v" << ident.id();

    const ASTNode *cur = child;
    while (cur) {
      std::cout << " ";
      cur->printSexpr();
      cur = cur->sibling;
    }
    std::cout << ")";
  }

private:
  /// Operator name used in the S-expression output for a non-leaf kind.
  static std::string_view kindName(Kind k) {
    switch (k) {
    case BLOCK: return "block";
    case DECLREG: return "register";
    case DECLVAR: return "auto";
    case IF: return "if";
    case WHILE: return "while";
    case CALL: return "call";
    case NEG: return "u-";
    case NOT: return "!";
    case INV: return "~";
    case ADDROF: return "&";
    case ADD: return "+";
    case SUB: return "-";
    case MUL: return "*";
    case DIV: return "/";
    case REM: return "%";
    case SL: return "<<";
    case SR: return ">>";
    case LT: return "<";
    case GT: return ">";
    case LE: return "<=";
    case GE: return ">=";
    case EQ: return "==";
    case NOTEQ: return "!=";
    case AND: return "&";
    case OR: return "|";
    case XOR: return "^";
    case ANDAND: return "&&";
    case OROR: return "||";
    case ASSIGN: return "=";
    case SUBSCRIPT: return "[]";
    case RETURN: return "return";
    case ERROR: return "ERROR";
    case IDENT:
    case NUMBER:
      // Leaves are printed by printSexpr() directly.
      break;
    }
    return "???";
  }
};

/// A function that was declared (e.g. by being called) and possibly defined.
struct Function {
  /// Name. This is a non-owning view; the referenced text must outlive this
  /// object.
  std::string_view name;
  /// Number of parameters, identifiers 0..<numParams
  const unsigned numParams;
  /// Number of identifier IDs used by the function, i.e. one more than the
  /// largest ID (parameters included); 0 if the function is only declared.
  unsigned maxIid = 0;

  /// AST, or null if the function is just declared
  const ASTNode *ast = nullptr;

  Function(std::string_view name, unsigned numParams)
      : name(name), numParams(numParams) {}
};

/// A parsed program: the functions together with the allocator that owns the
/// memory of their AST nodes.
struct Program {
  /// ASTNode allocator. Use a bump ptr allocator to avoid calling malloc/free
  /// for every created node.
  llvm::SpecificBumpPtrAllocator<ASTNode> nodeAllocator;

  /// All declared and defined functions. The index of a function in this
  /// vector is its function id.
  std::vector<Function> funcs;
};

template <class T> class Scopes {
  struct IdentEntry {
    unsigned nest;
    unsigned gen;
    T payload;
  };

  unsigned currentNest = 0;
  // LLVM's data structures are much better than libstdc++ (or libc++).
  llvm::SmallVector<unsigned> gens;
  llvm::StringMap<llvm::SmallVector<IdentEntry>> map;

public:
  Scopes() : gens(1), map() {}

  void nest() {
    currentNest++;
    if (gens.size() <= currentNest)
      gens.push_back(0);
  }
  void unnest() { gens[currentNest--]++; }
  bool tryDeclare(std::string_view name, T payload) {
    auto &entry = lookupImpl(name);
    if (!entry.empty() && entry.back().nest == currentNest)
      return false;
    entry.emplace_back(currentNest, gens[currentNest], payload);
    return true;
  }
  T *lookup(std::string_view name) {
    auto &entry = lookupImpl(name);
    return entry.empty() ? nullptr : &entry.back().payload;
  }

private:
  decltype(map)::mapped_type &lookupImpl(std::string_view name) {
    auto &entry = map[name];
    while (!entry.empty() && gens[entry.back().nest] != entry.back().gen)
      entry.pop_back();
    return entry;
  }
};

/// Recursive-descent parser. Expressions are parsed by precedence climbing
/// (see parseExpression). Errors are collected instead of aborting at the
/// first one; parseProgram() prints them and yields no program if there are
/// any.
class Parser {
  Program program;

  std::string_view prog;
  Tokenizer t;
  /// One token of lookahead, filled by peek() and drained by next().
  std::optional<Token> peekedToken;
  std::vector<Error> errors;

  /// Next free identifier id within the function that is being parsed.
  uint64_t nextIid = 0;
  Scopes<IdentDesc> scope;

  // Map of function name to function id
  llvm::StringMap<std::size_t> funcMap;

public:
  Parser(std::string_view prog) : prog(prog), t(prog) {}

private:
  /// Consume and return the next token.
  Token next() {
    if (peekedToken) {
      Token res = peekedToken.value();
      peekedToken = std::nullopt;
      return res;
    }

    return t.next();
  }
  /// Return the next token without consuming it.
  Token peek() {
    if (!peekedToken)
      peekedToken = next();
    return peekedToken.value();
  }
  bool eof() { return peek().kind == Token::T_EOF; }
  /// Consume the next token, which is expected to be of the given kind. On a
  /// mismatch an error is recorded and tokens are skipped until one of that
  /// kind (which is returned) or the end of the input is reached.
  Token expectNext(Token::Kind kind) {
    Token nextToken = next();
    if (UNLIKELY(nextToken.kind != kind)) {
      errors.emplace_back(nextToken.loc, "unexpected token");
      do {
        nextToken = next();
      } while (nextToken.kind != kind && nextToken.kind != Token::T_EOF);
    }
    return nextToken;
  }

  /// Source text covered by a token.
  std::string_view tokenValue(Token tok) {
    return prog.substr(tok.loc.start, tok.loc.end - tok.loc.start);
  }

  /// Allocate a node from the program's node allocator. The kind-specific
  /// payload of the node is left for the caller to fill in.
  ASTNode *createNode(ASTNode::Kind kind, SrcLoc loc,
                      ASTNode *children = nullptr) {
    return new (program.nodeAllocator.Allocate()) ASTNode(kind, loc, children);
  }

  /// Declare function with number of parameters and return function id. If
  /// the function is known already, the parameter count must match.
  unsigned declareFunc(SrcLoc loc, std::string_view name, std::size_t params) {
    auto [it, inserted] = funcMap.try_emplace(name, program.funcs.size());
    if (inserted)
      program.funcs.emplace_back(Function(name, params));
    else if (program.funcs[it->second].numParams != params)
      errors.emplace_back(loc, "function redeclared with parameter mismatch");
    return it->second;
  }

  /// Parse a parenthesized expression, identifier, function call, number or
  /// unary operation.
  ASTNode *parsePrimaryExpression() {
    Token tok = next();
    switch (tok.kind) {
      ASTNode::Kind nodeKind;

    case Token::LPAREN: {
      auto exp = parseExpression();
      (void)expectNext(Token::RPAREN);
      return exp;
    }
    case Token::IDENT: {
      std::string_view identName = tokenValue(tok);
      if (peek().kind != Token::LPAREN) {
        auto node = createNode(ASTNode::IDENT, tok.loc);
        if (IdentDesc *var = scope.lookup(identName)) {
          node->ident = *var;
        } else {
          // Give the node a well-defined payload: the address-of check below
          // reads it. Marking it addressable avoids a misleading follow-up
          // error about registers for a variable that does not even exist.
          node->ident = IdentDesc{0, /*addressable=*/true};
          errors.emplace_back(tok.loc, "undeclared variable");
        }
        return node;
      }

      // function call
      (void)next();
      ASTNode *args = nullptr;
      ASTNode **lastArgPtr = &args;
      std::size_t argCount = 0;
      while (!eof() && peek().kind != Token::RPAREN) {
        if (args)
          expectNext(Token::COMMA);
        *lastArgPtr = parseExpression();
        lastArgPtr = &(*lastArgPtr)->sibling;
        argCount++;
      }
      Token rparen = expectNext(Token::RPAREN);

      // Calling a function implicitly declares it.
      unsigned fnId = declareFunc(tok.loc | rparen.loc, identName, argCount);
      auto node = createNode(ASTNode::CALL, tok.loc | rparen.loc, args);
      node->fnId = fnId;
      return node;
    }
    case Token::NUMBER: {
      auto node = createNode(ASTNode::NUMBER, tok.loc);
      auto str = tokenValue(tok);
      // from_chars leaves its output untouched on failure; start from zero so
      // that the node never carries an indeterminate value.
      node->numVal = 0;
      auto convres =
          std::from_chars(str.data(), str.data() + str.size(), node->numVal);
      // The token consists of digits only, so the only possible failure is a
      // value that does not fit into 64 bits.
      if (convres.ec != std::errc())
        errors.emplace_back(tok.loc, "integer literal out of range");
      return node;
    }
    case Token::MINUS: nodeKind = ASTNode::NEG; goto unaryCommon;
    case Token::NOT: nodeKind = ASTNode::NOT; goto unaryCommon;
    case Token::INV: nodeKind = ASTNode::INV; goto unaryCommon;
    case Token::AND:
      nodeKind = ASTNode::ADDROF;
      goto unaryCommon;
    unaryCommon: {
      // Unary operators bind tighter than all binary operators, but weaker
      // than subscripts (precedence 14).
      auto exp = parseExpression(13);
      auto loc = tok.loc | exp->loc;
      if (nodeKind == ASTNode::ADDROF) {
        if (exp->kind != ASTNode::IDENT && exp->kind != ASTNode::SUBSCRIPT)
          errors.emplace_back(loc, "cannot take address of non-lvalue");
        if (exp->kind == ASTNode::IDENT && !exp->ident.addressable())
          errors.emplace_back(loc, "cannot take address of register");
      }
      return createNode(nodeKind, loc, exp);
    }

    default:
      errors.emplace_back(tok.loc, "invalid token for expression");
      // Just try to keep going...
      return createNode(ASTNode::ERROR, tok.loc);
    }
  }

  /// Table of binary operators, indexed by token kind: node kind, precedence
  /// (0 = not a binary operator) and whether the operator is
  /// right-associative.
  static constexpr std::array<std::tuple<ASTNode::Kind, uint8_t, bool>,
                              Token::MAX>
  getOperators() {
    std::array<std::tuple<ASTNode::Kind, uint8_t, bool>, Token::MAX> res;
    res[Token::LBRACKET] = {ASTNode::SUBSCRIPT, 14, false};
    res[Token::TIMES] = {ASTNode::MUL, 12, false};
    res[Token::DIV] = {ASTNode::DIV, 12, false};
    res[Token::REM] = {ASTNode::REM, 12, false};
    res[Token::PLUS] = {ASTNode::ADD, 11, false};
    res[Token::MINUS] = {ASTNode::SUB, 11, false};
    res[Token::SL] = {ASTNode::SL, 10, false};
    res[Token::SR] = {ASTNode::SR, 10, false};
    res[Token::LANGLE] = {ASTNode::LT, 9, false};
    res[Token::RANGLE] = {ASTNode::GT, 9, false};
    res[Token::LE] = {ASTNode::LE, 9, false};
    res[Token::GE] = {ASTNode::GE, 9, false};
    res[Token::EQEQ] = {ASTNode::EQ, 8, false};
    res[Token::NOTEQ] = {ASTNode::NOTEQ, 8, false};
    res[Token::AND] = {ASTNode::AND, 7, false};
    res[Token::XOR] = {ASTNode::XOR, 6, false};
    res[Token::OR] = {ASTNode::OR, 5, false};
    res[Token::ANDAND] = {ASTNode::ANDAND, 4, false};
    res[Token::OROR] = {ASTNode::OROR, 3, false};
    res[Token::EQ] = {ASTNode::ASSIGN, 2, true};
    return res;
  }

  /// Parse an expression consisting of operators with a precedence of at
  /// least minPrec (precedence climbing).
  ASTNode *parseExpression(int minPrec = 0) {
    ASTNode *lhs = parsePrimaryExpression();
    while (true) {
      static constexpr auto OPERATORS = getOperators();

      Token op = peek();
      auto [nodeKind, prec, rassoc] = OPERATORS[op.kind];
      if (!prec || prec < minPrec)
        break;
      (void)next();

      if (nodeKind == ASTNode::SUBSCRIPT) {
        // base[index] or base[index@granularity]
        lhs->sibling = parseExpression();
        int granularity = 0; // 0 = no granularity given
        if (peek().kind == Token::AT) {
          (void)next();
          Token granTok = expectNext(Token::NUMBER);
          std::string_view granStr = tokenValue(granTok);
          if (granStr == "1")
            granularity = 1;
          else if (granStr == "2")
            granularity = 2;
          else if (granStr == "4")
            granularity = 4;
          else if (granStr == "8")
            granularity = 8;
          else
            errors.emplace_back(granTok.loc, "invalid granularity");
        }
        Token rbracket = expectNext(Token::RBRACKET);
        auto loc = lhs->loc | rbracket.loc;
        lhs = createNode(nodeKind, loc, lhs);
        lhs->granularity = granularity;
        continue;
      }

      // For left-associative operators the right operand must bind strictly
      // tighter; for right-associative ones equal precedence is fine.
      lhs->sibling = parseExpression(prec + !rassoc);
      if (nodeKind == ASTNode::ASSIGN) {
        if (lhs->kind != ASTNode::IDENT && lhs->kind != ASTNode::SUBSCRIPT)
          errors.emplace_back(lhs->loc, "lhs of assignment must be lvalue");
      }

      auto loc = lhs->loc | lhs->sibling->loc;
      lhs = createNode(nodeKind, loc, lhs);
    }
    return lhs;
  }

  /// Parse a statement that is not a declaration: block, while, if, return
  /// or expression statement.
  ASTNode *parseStatement() {
    switch (peek().kind) {
    case Token::LCURLY: return parseBlock();
    case Token::WHILE: {
      Token start = next();
      (void)expectNext(Token::LPAREN);
      auto exp = parseExpression();
      (void)expectNext(Token::RPAREN);
      exp->sibling = parseStatement();
      SrcLoc loc = start.loc | exp->sibling->loc;
      return createNode(ASTNode::WHILE, loc, exp);
    }
    case Token::IF: {
      SrcLoc startLoc = next().loc;
      (void)expectNext(Token::LPAREN);
      ASTNode *cond = parseExpression();
      (void)expectNext(Token::RPAREN);
      cond->sibling = parseStatement();
      SrcLoc endLoc = cond->sibling->loc;
      if (peek().kind == Token::ELSE) {
        next();
        cond->sibling->sibling = parseStatement();
        endLoc = cond->sibling->sibling->loc;
      }
      return createNode(ASTNode::IF, startLoc | endLoc, cond);
    }
    case Token::RETURN: {
      SrcLoc startLoc = next().loc;
      ASTNode *child = nullptr;
      if (peek().kind != Token::SEMICOLON)
        child = parseExpression();
      SrcLoc semiLoc = expectNext(Token::SEMICOLON).loc;
      return createNode(ASTNode::RETURN, startLoc | semiLoc, child);
    }
    default: {
      ASTNode *node = parseExpression();
      (void)expectNext(Token::SEMICOLON);
      return node;
    }
    }
  }

  /// Parse a variable declaration (`auto`/`register` name = exp;) or any
  /// other statement. Declarations are only accepted directly inside blocks.
  ASTNode *parseDeclStatement() {
    switch (peek().kind) {
    case Token::AUTO:
    case Token::REGISTER: {
      Token storage = next();
      Token nameTok = expectNext(Token::IDENT);
      (void)expectNext(Token::EQ);
      ASTNode *exp = parseExpression();
      Token semi = expectNext(Token::SEMICOLON);

      std::string_view name = tokenValue(nameTok);
      bool addressable = storage.kind != Token::REGISTER;

      // The name is declared only after the initializer was parsed, so the
      // initializer cannot refer to the variable that is being declared.
      IdentDesc desc{nextIid++, addressable};
      if (!scope.tryDeclare(name, desc))
        errors.emplace_back(storage.loc | nameTok.loc,
                            "redundant variable declaration");

      auto nodeKind = addressable ? ASTNode::DECLVAR : ASTNode::DECLREG;
      auto node = createNode(nodeKind, storage.loc | semi.loc, exp);
      node->ident = desc;
      return node;
    }
    default: return parseStatement();
    }
  }

  /// Parse `{ statements... }`. Opens a new scope for the block unless
  /// nestScope is false.
  ASTNode *parseBlock(bool nestScope = true) {
    if (nestScope)
      scope.nest();
    SrcLoc lcurlyLoc = expectNext(Token::LCURLY).loc;
    ASTNode *stmts{nullptr};
    ASTNode **lastStmtPtr = &stmts;
    while (!eof() && peek().kind != Token::RCURLY) {
      *lastStmtPtr = parseDeclStatement();
      lastStmtPtr = &(*lastStmtPtr)->sibling;
    }
    SrcLoc rcurlyLoc = expectNext(Token::RCURLY).loc;
    if (nestScope)
      scope.unnest();
    return createNode(ASTNode::BLOCK, lcurlyLoc | rcurlyLoc, stmts);
  }

  /// Parse a function definition: name(params...) { body }.
  void parseFunction() {
    scope.nest();

    // identifier ids are unique within a function.
    nextIid = 0;

    Token nameTok = expectNext(Token::IDENT);
    std::string_view name = tokenValue(nameTok);

    (void)expectNext(Token::LPAREN);
    while (!eof() && peek().kind != Token::RPAREN) {
      if (nextIid)
        (void)expectNext(Token::COMMA);

      Token ident = expectNext(Token::IDENT);
      IdentDesc desc{nextIid++, /*addressable=*/false};
      if (!scope.tryDeclare(tokenValue(ident), desc))
        errors.emplace_back(ident.loc, "redundant parameter declaration");
    }
    Token rparen = expectNext(Token::RPAREN);

    unsigned fnid = declareFunc(nameTok.loc | rparen.loc, name, nextIid);
    // A function that already has a body must not be defined again; without
    // this check the earlier body would silently be replaced.
    if (program.funcs[fnid].ast)
      errors.emplace_back(nameTok.loc, "function redefinition");

    // Don't nest scope for the block, parameters belong to the innermost scope.
    // Parse the body before indexing program.funcs: calls inside the body may
    // declare further functions and thereby reallocate the vector.
    const ASTNode *body = parseBlock(/*nestScope=*/false);
    program.funcs[fnid].ast = body;
    program.funcs[fnid].maxIid = nextIid;

    scope.unnest();
  }

public:
  /// Parse the whole input. Returns the program if no error occurred;
  /// otherwise prints all errors to std::cerr and returns nullopt. The
  /// program is moved out of the parser, so this must be called only once.
  std::optional<Program> parseProgram() {
    while (!eof())
      parseFunction();
    if (errors.empty())
      return std::move(program);
    for (const Error &error : errors) {
      auto [lineStart, colStart] = t.locate(error.loc.start);
      std::string_view line = t.line(lineStart);
      // Only show a single line, if lineEnd is different, assume line.size().
      auto [lineEnd, colEnd] = t.locate(error.loc.end);
      colEnd = lineStart == lineEnd ? colEnd : line.size();

      // The first marked column is drawn as '^', the remaining ones as '~'.
      std::size_t markLen = colEnd != colStart ? colEnd - colStart - 1 : 0;
      std::cerr << "input:" << (lineStart + 1) << ":" << (colStart + 1) << ": "
                << error.desc << "\n";
      std::cerr << line << "\n";
      std::cerr << std::string(colStart, ' ') << "^"
                << std::string(markLen, '~') << "\n";
    }
    return std::nullopt;
  }
};

/// Read the whole file at `path` into a string.
///
/// If the size of the file can be determined by seeking, the content is read
/// in one go; otherwise it is read in chunks until no more data arrives.
/// Returns an empty string if the file cannot be opened or read.
std::string readFile(std::string_view path) {
  // A string_view is not guaranteed to be NUL-terminated, so make a copy
  // before handing the path to the stream.
  const std::string pathStr(path);
  std::ifstream stream(pathStr, std::ios::in);

  stream.seekg(0, std::ios::end);
  const std::streamoff size = stream.tellg();
  stream.seekg(0, std::ios::beg);
  stream.clear();
  if (size >= 0) {
    std::string data(static_cast<std::size_t>(size), '\0');
    stream.read(data.data(), static_cast<std::streamsize>(data.size()));
    // Keep only what was actually read; a short read must not leave padding.
    data.resize(static_cast<std::size_t>(stream.gcount()));
    return data;
  }

  // Size unknown (tellg failed): read chunk by chunk.
  std::string res;
  std::vector<char> buf(0x2000);
  do {
    stream.read(buf.data(), static_cast<std::streamsize>(buf.size()));
    // Append exactly the bytes read. The buffer is not NUL-terminated, so it
    // must never be treated as a C string.
    res.append(buf.data(), static_cast<std::size_t>(stream.gcount()));
  } while (stream.gcount());
  return res;
}

/// Generates LLVM-IR for the functions of a Program. One instance is used per
/// function. Variables that are not addressable are translated directly into
/// SSA values: the value of a variable is tracked per basic block and phi
/// nodes are created on demand when a value is read (see readVar). Addressable
/// variables live in stack slots.
class LLVMIRGen {
  llvm::LLVMContext &ctx;
  llvm::Module *mod;
  /// LLVM functions of the program, indexed by function id.
  llvm::ArrayRef<llvm::Function *> fns;
  /// The function that is being generated.
  llvm::Function *fn;
  llvm::IRBuilder<> irb;

  using VarBlockMap = llvm::DenseMap<llvm::BasicBlock *, llvm::Value *>;
  /// Storage of an identifier: nothing (not declared yet), the value per basic
  /// block (registers and parameters) or a stack slot (auto variables).
  using VarDesc = std::variant<std::monostate, VarBlockMap, llvm::AllocaInst *>;
  /// Storage of all identifiers of the function, indexed by identifier id.
  llvm::SmallVector<VarDesc, 0> varMap;

  /// Blocks that might still get additional predecessors (loop headers).
  llvm::DenseSet<llvm::BasicBlock *> unsealedBlocks;
  using IncompletePhi = std::pair<uint64_t, llvm::PHINode *>;
  /// Phi nodes created in unsealed blocks; their operands are added once the
  /// block gets sealed.
  llvm::DenseMap<llvm::BasicBlock *, llvm::SmallVector<IncompletePhi>>
      incompletePhis;

  LLVMIRGen(llvm::ArrayRef<llvm::Function *> fns, llvm::Function *fn)
      : ctx(fn->getContext()), mod(fn->getParent()), fns(fns), fn(fn),
        irb(ctx) {}

  /// Add the value of variable iid at the end of every predecessor as
  /// incoming value to phi.
  void addPhiOperands(uint64_t iid, llvm::PHINode *phi) {
    for (llvm::BasicBlock *pred : llvm::predecessors(phi->getParent()))
      phi->addIncoming(readVar(iid, pred), pred);
  }

  /// Set the value of variable iid in block: a store for stack variables, an
  /// update of the block's value otherwise.
  void writeVar(uint64_t iid, llvm::BasicBlock *block, llvm::Value *val) {
    if (auto aip = std::get_if<llvm::AllocaInst *>(&varMap[iid])) {
      irb.CreateStore(val, *aip);
    } else {
      auto &bm = std::get<VarBlockMap>(varMap[iid]);
      bm[block] = val;
    }
  }

  /// Get the value of variable iid at the end of block: a load for stack
  /// variables; otherwise the value known for the block, found by searching
  /// the predecessors and inserting phi nodes where necessary.
  llvm::Value *readVar(uint64_t iid, llvm::BasicBlock *block) {
    if (auto aip = std::get_if<llvm::AllocaInst *>(&varMap[iid]))
      return irb.CreateLoad(irb.getInt64Ty(), *aip);

    auto &bm = std::get<VarBlockMap>(varMap[iid]);
    auto bit = bm.find(block);
    if (bit != bm.end())
      return bit->second;

    bool sealed = !unsealedBlocks.contains(block);
    auto *pred = block->getUniquePredecessor();
    if (pred && sealed) {
      // Exactly one predecessor and no further ones to come: no phi needed.
      llvm::Value *predVal = readVar(iid, pred);
      bm[block] = predVal;
      return predVal;
    }

    llvm::IRBuilder<> phiIrb(block, block->begin());
    llvm::PHINode *phi = phiIrb.CreatePHI(irb.getInt64Ty(), 2);
    // Register the phi before looking at the predecessors, so that a search
    // that comes back to this block (loops) terminates here.
    bm[block] = phi;
    if (sealed)
      addPhiOperands(iid, phi);
    else
      incompletePhis[block].emplace_back(iid, phi);
    return phi;
  }

  /// Generate the address computation for a subscript node. Returns the type
  /// of the accessed element (given by the granularity in bytes; 64 bits if
  /// none was specified) and the address as pointer.
  std::pair<llvm::Type *, llvm::Value *> getSubscriptAddr(const ASTNode &node) {
    llvm::Type *ty =
        irb.getIntNTy(node.granularity ? node.granularity * 8 : 64);
    llvm::Value *base = genValue(*node.child);
    llvm::Value *idx = genValue(*node.child->sibling);
    base = irb.CreateIntToPtr(base, llvm::PointerType::get(ctx, 0));
    // No address arithmetic needed for a constant index of zero.
    if (auto *cstIdx = llvm::dyn_cast<llvm::ConstantInt>(idx);
        cstIdx && cstIdx->isZero())
      return {ty, base};
    return {ty, irb.CreateGEP(ty, base, {idx})};
  }

  /// Continue code generation at the end of bb.
  void changeBlock(llvm::BasicBlock *bb) {
    // Reorder blocks so that the LLVM-IR follows the program order. This is not
    // required (or even beneficial other than manual inspection of the code).
    bb->moveAfter(irb.GetInsertBlock());
    irb.SetInsertPoint(bb);
  }

  /// Generate a conditional branch based on the expression in the ASTNode.
  /// && and || are translated into control flow, so that the second operand
  /// is only evaluated if required.
  void genCondBr(const ASTNode &node, llvm::BasicBlock *thenBB,
                 llvm::BasicBlock *elseBB) {
    if (node.kind == ASTNode::ANDAND) {
      auto secondBB = llvm::BasicBlock::Create(ctx, "", fn);
      genCondBr(*node.child, secondBB, elseBB);
      changeBlock(secondBB);
      genCondBr(*node.child->sibling, thenBB, elseBB);
    } else if (node.kind == ASTNode::OROR) {
      auto secondBB = llvm::BasicBlock::Create(ctx, "", fn);
      genCondBr(*node.child, thenBB, secondBB);
      changeBlock(secondBB);
      genCondBr(*node.child->sibling, thenBB, elseBB);
    } else {
      llvm::Value *cond = genValueAny(node);
      if (cond->getType() != irb.getInt1Ty())
        cond = irb.CreateIsNotNull(cond);
      irb.CreateCondBr(cond, thenBB, elseBB);
    }
  }

  /// Generate LLVM-IR code for an ASTNode and return the result value, or null.
  /// Result is either i64 for integer operations or i1 for logical operations.
  /// Statements yield null. If the current block is terminated already (code
  /// after a return), nothing is generated and an undefined i64 is returned.
  llvm::Value *genValueAny(const ASTNode &node) {
    if (irb.GetInsertBlock()->getTerminator())
      return llvm::UndefValue::get(irb.getInt64Ty());

    switch (node.kind) {
      llvm::Value *val;
      llvm::Instruction::BinaryOps binOp;
      llvm::CmpInst::Predicate cmpPred;

    case ASTNode::IDENT: return readVar(node.ident.id(), irb.GetInsertBlock());
    case ASTNode::NUMBER: return irb.getInt64(node.numVal);
    case ASTNode::ASSIGN:
      // The assigned value is generated before the address of the target.
      val = genValue(*node.child->sibling);
      if (node.child->kind == ASTNode::IDENT) {
        writeVar(node.child->ident.id(), irb.GetInsertBlock(), val);
      } else if (node.child->kind == ASTNode::SUBSCRIPT) {
        auto [ty, ptr] = getSubscriptAddr(*node.child);
        irb.CreateStore(irb.CreateSExtOrTrunc(val, ty), ptr);
      }
      return val;
    case ASTNode::SUBSCRIPT: {
      auto [ty, ptr] = getSubscriptAddr(node);
      return irb.CreateSExtOrTrunc(irb.CreateLoad(ty, ptr), irb.getInt64Ty());
    }
    case ASTNode::ADDROF: {
      if (node.child->kind == ASTNode::IDENT) {
        auto ptr = std::get<llvm::AllocaInst *>(varMap[node.child->ident.id()]);
        return irb.CreatePtrToInt(ptr, irb.getInt64Ty());
      } else if (node.child->kind == ASTNode::SUBSCRIPT) {
        auto [_, ptr] = getSubscriptAddr(*node.child);
        return irb.CreatePtrToInt(ptr, irb.getInt64Ty());
      }
      break;
    }
    case ASTNode::NEG: return irb.CreateNeg(genValue(*node.child));
    case ASTNode::INV: return irb.CreateNot(genValue(*node.child));
    case ASTNode::NOT: return irb.CreateIsNull(genValue(*node.child)); // i1

    case ASTNode::ADD: binOp = llvm::Instruction::Add; goto binOpCommon;
    case ASTNode::SUB: binOp = llvm::Instruction::Sub; goto binOpCommon;
    case ASTNode::MUL: binOp = llvm::Instruction::Mul; goto binOpCommon;
    case ASTNode::DIV: binOp = llvm::Instruction::SDiv; goto binOpCommon;
    case ASTNode::REM: binOp = llvm::Instruction::SRem; goto binOpCommon;
    case ASTNode::SL: binOp = llvm::Instruction::Shl; goto binOpCommon;
    case ASTNode::SR: binOp = llvm::Instruction::AShr; goto binOpCommon;
    case ASTNode::AND: binOp = llvm::Instruction::And; goto binOpCommon;
    case ASTNode::OR: binOp = llvm::Instruction::Or; goto binOpCommon;
    case ASTNode::XOR:
      binOp = llvm::Instruction::Xor;
      goto binOpCommon;
    binOpCommon:
      return irb.CreateBinOp(binOp, genValue(*node.child),
                             genValue(*node.child->sibling));

    case ASTNode::EQ: cmpPred = llvm::CmpInst::ICMP_EQ; goto cmpCommon;
    case ASTNode::NOTEQ: cmpPred = llvm::CmpInst::ICMP_NE; goto cmpCommon;
    case ASTNode::LT: cmpPred = llvm::CmpInst::ICMP_SLT; goto cmpCommon;
    case ASTNode::GT: cmpPred = llvm::CmpInst::ICMP_SGT; goto cmpCommon;
    case ASTNode::LE: cmpPred = llvm::CmpInst::ICMP_SLE; goto cmpCommon;
    case ASTNode::GE:
      cmpPred = llvm::CmpInst::ICMP_SGE;
      goto cmpCommon;
    cmpCommon:
      return irb.CreateICmp(cmpPred, genValue(*node.child),
                            genValue(*node.child->sibling)); // i1

    case ASTNode::OROR:
    case ASTNode::ANDAND: {
      // Logical operation used as value: branch on the condition and merge
      // the constants 1 and 0 with a phi node.
      auto thenBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto elseBB = llvm::BasicBlock::Create(ctx, "", fn);
      genCondBr(node, thenBB, elseBB);

      auto contBB = llvm::BasicBlock::Create(ctx, "", fn);
      changeBlock(thenBB);
      irb.CreateBr(contBB);
      changeBlock(elseBB);
      irb.CreateBr(contBB);
      changeBlock(contBB);
      llvm::PHINode *phi = irb.CreatePHI(irb.getInt64Ty(), 2);
      phi->addIncoming(irb.getInt64(1), thenBB);
      phi->addIncoming(irb.getInt64(0), elseBB);
      return phi;
    }
    case ASTNode::CALL: {
      llvm::SmallVector<llvm::Value *, 4> args;
      for (const ASTNode *arg = node.child; arg; arg = arg->sibling)
        args.push_back(genValue(*arg));
      return irb.CreateCall(fns[node.fnId]->getFunctionType(), fns[node.fnId],
                            args);
    }

    // Statements
    case ASTNode::BLOCK:
      for (const ASTNode *child = node.child; child; child = child->sibling)
        (void)genValueAny(*child);
      return nullptr;
    case ASTNode::DECLREG:
      val = genValue(*node.child);
      varMap[node.ident.id()] = VarBlockMap{{irb.GetInsertBlock(), val}};
      return nullptr;
    case ASTNode::DECLVAR: {
      // Stack slots are created at the beginning of the entry block.
      llvm::BasicBlock *entryBB = &fn->getEntryBlock();
      llvm::IRBuilder<> allocaIrb(entryBB, entryBB->begin());
      llvm::AllocaInst *alloca = allocaIrb.CreateAlloca(irb.getInt64Ty());
      varMap[node.ident.id()] = alloca;
      irb.CreateStore(genValue(*node.child), alloca);
      return nullptr;
    }
    case ASTNode::IF: {
      auto thenBB = llvm::BasicBlock::Create(ctx, "", fn);
      llvm::BasicBlock *elseBB = nullptr;
      if (node.child->sibling->sibling)
        elseBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto contBB = llvm::BasicBlock::Create(ctx, "", fn);

      genCondBr(*node.child, thenBB, elseBB ? elseBB : contBB);

      changeBlock(thenBB);
      (void)genValueAny(*node.child->sibling);
      if (!irb.GetInsertBlock()->getTerminator())
        irb.CreateBr(contBB);
      if (elseBB) {
        changeBlock(elseBB);
        (void)genValueAny(*node.child->sibling->sibling);
        if (!irb.GetInsertBlock()->getTerminator())
          irb.CreateBr(contBB);
      }
      // In case both branches interrupt control flow...
      if (contBB->hasNPredecessors(0))
        contBB->eraseFromParent();
      else
        changeBlock(contBB);
      return nullptr;
    }
    case ASTNode::WHILE: {
      auto headerBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto bodyBB = llvm::BasicBlock::Create(ctx, "", fn);
      auto contBB = llvm::BasicBlock::Create(ctx, "", fn);

      // Header is missing the predecessor from the loop body.
      unsealedBlocks.insert(headerBB);

      irb.CreateBr(headerBB);
      changeBlock(headerBB);
      genCondBr(*node.child, bodyBB, contBB);
      changeBlock(bodyBB);
      (void)genValueAny(*node.child->sibling);
      if (!irb.GetInsertBlock()->getTerminator())
        irb.CreateBr(headerBB);

      // Seal headerBB: all predecessors are known now, so the operands of the
      // phi nodes created in the meantime can be added. Detach the list from
      // the map first: addPhiOperands() may add entries for other unsealed
      // blocks to incompletePhis, and growing the map would invalidate both
      // the iterator and the list it refers to.
      auto ipit = incompletePhis.find(headerBB);
      if (ipit != incompletePhis.end()) {
        llvm::SmallVector<IncompletePhi> pending = std::move(ipit->second);
        incompletePhis.erase(ipit);
        for (const auto &[iid, phi] : pending)
          addPhiOperands(iid, phi);
      }
      unsealedBlocks.erase(headerBB);

      changeBlock(contBB);
      return nullptr;
    }
    case ASTNode::RETURN:
      if (node.child)
        irb.CreateRet(genValue(*node.child));
      else
        irb.CreateRet(llvm::UndefValue::get(irb.getInt64Ty()));
      return nullptr;
    case ASTNode::ERROR:
      assert(false && "error node in valid AST!");
      return nullptr;
    }
    return nullptr;
  }

  /// Generate value for ASTNode as 64-bit integer. The node must be an
  /// expression; i1 results are zero-extended.
  llvm::Value *genValue(const ASTNode &node) {
    llvm::Value *v = genValueAny(node);
    if (v->getType() == irb.getInt64Ty())
      return v;
    return irb.CreateZExt(v, irb.getInt64Ty());
  }

  /// Generate the body of fn from the AST of func.
  void genFunction(const Function &func) {
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(ctx, "", fn);
    irb.SetInsertPoint(entryBB);
    varMap.resize(func.maxIid);
    // Parameters have the identifier ids 0..<numParams.
    for (unsigned i = 0; i < func.numParams; i++)
      varMap[i] = VarBlockMap{{entryBB, fn->getArg(i)}};

    (void)genValueAny(*func.ast);

    // Control flow reaching the end of the function returns an undefined
    // value.
    if (!irb.GetInsertBlock()->getTerminator())
      irb.CreateRet(llvm::UndefValue::get(irb.getInt64Ty()));

    varMap.clear();
  }

public:
  /// Generate an LLVM module for the program. All functions take and return
  /// i64 and have external linkage; functions without AST are only declared.
  static std::unique_ptr<llvm::Module> genIR(llvm::LLVMContext &ctx,
                                             const Program &prog) {
    auto modUP = std::make_unique<llvm::Module>("mod", ctx);
    llvm::Module *mod = modUP.get();

    llvm::Type *i64 = llvm::Type::getInt64Ty(ctx);
    auto linkage = llvm::GlobalValue::ExternalLinkage;

    // Create all functions first, so that calls can refer to functions that
    // are defined later.
    llvm::SmallVector<llvm::Function *> irFuncs;
    irFuncs.reserve(prog.funcs.size());
    llvm::SmallVector<llvm::Type *, 4> argTys;
    for (const Function &func : prog.funcs) {
      argTys.resize(func.numParams, i64);
      auto fnTy = llvm::FunctionType::get(i64, argTys, false);
      irFuncs.push_back(llvm::Function::Create(fnTy, linkage, func.name, mod));
    }
    for (unsigned i = 0; i < prog.funcs.size(); i++)
      if (prog.funcs[i].ast)
        LLVMIRGen{irFuncs, irFuncs[i]}.genFunction(prog.funcs[i]);

    return modUP;
  }
};

} // end anonymous namespace

/// Driver: parses the program in the file given on the command line and,
/// depending on the selected mode, only checks it (-c), prints the AST as
/// S-expressions (-a, the default) or prints the generated LLVM-IR (-l).
/// -t additionally prints the time taken by each phase to stderr.
/// Exits with a non-zero status on invalid usage or invalid input.
int main(int argc, char **argv) {
  enum class Mode {
    CHECK,
    AST,
    LLVM,
  } mode = Mode::AST;
  bool printTimes = false;

  const auto usage = [argv] {
    fprintf(stderr, "usage: %s <bfprog>\n", argv[0]);
    return EXIT_FAILURE;
  };

  int c;
  while ((c = getopt(argc, argv, "aclirSt")) != -1) {
    switch (c) {
    case 'c': mode = Mode::CHECK; break;
    case 'a': mode = Mode::AST; break;
    case 'l': mode = Mode::LLVM; break;
    case 't': printTimes = true; break;
    default: return usage();
    }
  }

  // The input file is mandatory; tell the user instead of failing silently.
  if (optind >= argc)
    return usage();

  using Clock = std::chrono::steady_clock;
  // Prints the duration of a phase if timing output was requested.
  const auto reportTime = [printTimes](const char *phase,
                                       Clock::time_point begin,
                                       Clock::time_point end) {
    if (!printTimes)
      return;
    std::cerr << phase << ": "
              << std::chrono::duration_cast<std::chrono::milliseconds>(end -
                                                                       begin)
                     .count()
              << "ms\n";
  };

  // The parsed program refers to the file content, so keep it alive.
  auto fc = readFile(argv[optind]);
  auto time_argparse_end = Clock::now();

  auto p = Parser{fc}.parseProgram();
  auto time_parse_end = Clock::now();
  reportTime("parsing", time_argparse_end, time_parse_end);

  if (!p)
    return 1;
  if (mode == Mode::CHECK)
    return 0;
  if (mode == Mode::AST) {
    for (const Function &fn : p->funcs) {
      if (fn.ast) {
        fn.ast->printSexpr();
        std::cout << "\n";
      }
    }
    return 0;
  }

  llvm::LLVMContext ctx;
  auto mod = LLVMIRGen::genIR(ctx, *p);
  auto time_irgen_end = Clock::now();
  reportTime("irgen", time_parse_end, time_irgen_end);

  if (mode == Mode::LLVM) {
    if (llvm::verifyModule(*mod, &llvm::errs()))
      return 1;
    mod->print(llvm::outs(), nullptr);
    return 0;
  }

  return 0;
}
