#include #include #include static int find_local(Cm *co, O sym) { for (Z i = 0; i < co->locals.count; i++) { if (co->locals.data[i].name == sym) { return (int)i; } } return -1; } static void add_local(Cm *co, O sym) { if (co->locals.count >= co->locals.capacity) { Z newcap = co->locals.capacity == 0 ? 16 : co->locals.capacity * 2; Lv *newdata = realloc(co->locals.data, newcap * sizeof(Lv)); if (!newdata) abort(); co->locals.capacity = newcap; co->locals.data = newdata; } co->locals.data[co->locals.count].name = sym; co->locals.data[co->locals.count].index = (U16)co->locals.count; co->locals.count++; } static V emit(Cm *co, U8 byte) { if (co->count >= co->capacity) { Z newcap = co->capacity == 0 ? 16 : co->capacity * 2; U8 *newdata = realloc(co->code, newcap); if (!newdata) abort(); co->capacity = newcap; co->code = newdata; } co->code[co->count++] = byte; } static V emit16(Cm *co, U16 word) { emit(co, word >> 8); emit(co, word & 0xff); } O code_make(In *in, const U8 *code, Z len, O *constants, Z clen) { Z size = sizeof(Gh) + sizeof(Bc); Gh *hdr = gc_alloc(&in->gc, size); hdr->type = TYPE_CODE; Bc *s = (Bc *)(hdr + 1); s->len = len; s->data = malloc(len + 1); if (!s->data) abort(); s->constants = constants; s->constant_count = clen; memcpy(s->data, code, len); s->data[len] = 0; return BOX(hdr); } static Z add_constant(Cm *co, O obj) { for (Z i = 0; i < co->constants.count; i++) { if (co->constants.data[i] == obj) return i; } if (co->constants.count >= co->constants.capacity) { Z newcap = co->constants.capacity == 0 ? 16 : co->constants.capacity * 2; O *newdata = realloc(co->constants.data, newcap * sizeof(O *)); if (!newdata) abort(); co->constants.capacity = newcap; co->constants.data = newdata; } co->constants.data[co->constants.count++] = obj; return co->constants.count - 1; } V compile(Cm *co, O expr, I tail); // Compile a (potentially tail-)call to `fn` with `args`. static V compile_call(Cm *co, O fn, O args, I tail) { I argc = 0; // Compile each argument expression for the function. for (O next = list_next(co->in, &args); next != NIL; next = list_next(co->in, &args)) { compile(co, next, 0); argc++; } // Compile the function reference itself compile(co, fn, 0); // Compile the call (opcode followed by number of arguments as a byte) emit(co, tail ? OP_TAIL_CALL : OP_CALL); emit(co, (U8)argc); } // Compile the `(if cond then else?)` special form. static V compile_if(Cm *c, O form, I tail) { O cond_expr = list_next(c->in, &form); O then_expr = list_next(c->in, &form); O else_expr = list_next(c->in, &form); if (cond_expr == NIL || then_expr == NIL) error_throw(c->in, "expected at least two arguments for if"); // Compile the condition expression compile(c, cond_expr, 0); // Prepare the jump to the else-expression emit(c, OP_JUMP_IF_NIL); Z jump_else = c->count; emit16(c, 0); // Compile the then-expression compile(c, then_expr, tail); Z jump_then = 0; if (!tail) { // If the expression is not on a tail-position, compile a jump to the code // following the else-expression emit(c, OP_JUMP); jump_then = c->count; emit16(c, 0); } // Patch the first jump (to the else-expression) Z else_offset = c->count; c->code[jump_else] = (U8)(else_offset >> 8); c->code[jump_else + 1] = (U8)(else_offset & 0xff); // Compile the else-expression compile(c, else_expr, tail); Z end = c->count; if (!tail) { // Patch the second jump (to the end of the else-expression) if we're not // on a tail-position c->code[jump_then] = (U8)(end >> 8); c->code[jump_then + 1] = (U8)(end & 0xff); } } // Compile the `(progn expr...)` special form. static V compile_progn(Cm *co, O forms, I tail) { // If there are no forms to compile, simply compile NIL. if (forms == NIL) { emit(co, OP_CONST); emit16(co, add_constant(co, NIL)); if (tail) emit(co, OP_RET); return; } // Compile all forms, discarding intermediate results. while (forms != NIL) { O expr = list_next(co->in, &forms); compile(co, expr, forms == NIL ? tail : 0); if (forms != NIL) emit(co, OP_POP); } } static int nested_p(Cm *c, O body) { while (body != NIL) { O expr = list_next(c->in, &body); if (type(expr) == TYPE_PAIR) { Pa *p = pair_unwrap(c->in, expr); if (p->head == c->specials.fn) return 1; if (nested_p(c, expr)) return 1; } } return 0; } // Compile a closure with `args` and `body`. static V compile_fn(Cm *c, O args, O body, I macro) { // Create an inner compiler context for compiling the closure's body. Cm ic; memset(&ic, 0, sizeof(Cm)); ic.in = c->in; ic.specials = c->specials; O curr = args; O fixed[256]; int fixed_count = 0; O rest = NIL; // Count fixed arguments, and if the function has a rest argument while (curr != NIL) { if (type(curr) == TYPE_SYM) { rest = curr; break; } Pa *p = pair_unwrap(c->in, curr); fixed[fixed_count++] = p->head; curr = p->tail; } int nested = nested_p(c, body); if (rest == NIL && !nested) { // If the function has no rest argument, and has no nested closures inside // it, compile using stack locals as an optimization for tail-calls ic.use_locals = 1; curr = args; for (O next = list_next(c->in, &curr); next != NIL; next = list_next(c->in, &curr)) add_local(&ic, next); } else { // Otherwise, fallback to using environment bindings for locals ic.use_locals = 0; if (rest != NIL) { emit(&ic, OP_BIND_REST); emit16(&ic, add_constant(&ic, rest)); emit16(&ic, fixed_count); } for (int i = fixed_count - 1; i >= 0; i--) { emit(&ic, OP_BIND); emit16(&ic, add_constant(&ic, fixed[i])); } } // Compile the function's body (as a `progn` form) compile_progn(&ic, body, 1); O code = code_make(c->in, ic.code, ic.count, ic.constants.data, ic.constants.count); Z code_idx = add_constant(c, code); Z args_idx = add_constant(c, args); // Compile pushing the closure to the stack. emit(c, macro ? OP_MAC : OP_CLOS); emit16(c, code_idx); emit16(c, args_idx); // Free the inner compiler context. free(ic.code); if (ic.locals.data) free(ic.locals.data); } // Compile the `(def name value)` special form static V compile_def(Cm *c, O args, I tail) { O sym = list_next(c->in, &args); if (type(sym) != TYPE_SYM) error_throw(c->in, "def: expected symbol"); O val = list_next(c->in, &args); compile(c, val, 0); emit(c, OP_SET); emit16(c, add_constant(c, sym)); emit(c, OP_CONST); emit16(c, add_constant(c, sym)); if (tail) emit(c, OP_RET); } O vm_apply(In *in, O macro, O args); // Compile a function application/special form static V compile_apply(Cm *co, O expr, I tail) { Pa *p = pair_unwrap(co->in, expr); O head = p->head; // Compile special forms if (type(head) == TYPE_SYM) { if (head == co->specials.quote) { Pa *args = pair_unwrap(co->in, p->tail); emit(co, OP_CONST); emit16(co, add_constant(co, args->head)); if (tail) emit(co, OP_RET); return; } else if (head == co->specials.iff) { compile_if(co, p->tail, tail); return; } else if (head == co->specials.progn) { compile_progn(co, p->tail, tail); return; } else if (head == co->specials.fn) { Pa *args = pair_unwrap(co->in, p->tail); compile_fn(co, args->head, args->tail, 0); if (tail) emit(co, OP_RET); return; } else if (head == co->specials.mac) { Pa *args = pair_unwrap(co->in, p->tail); compile_fn(co, args->head, args->tail, 1); if (tail) emit(co, OP_RET); return; } else if (head == co->specials.def) { compile_def(co, p->tail, tail); return; } if (find_local(co, head) == -1) { O obj = list_assoc(co->in, head, co->in->env); if (obj != NIL) { obj = pair_unwrap(co->in, obj)->tail; if (type(obj) == TYPE_MAC) { O exp = vm_apply(co->in, obj, p->tail); compile(co, exp, tail); return; } } } } compile_call(co, head, p->tail, tail); } V compile(Cm *co, O expr, I tail) { I ty = type(expr); if (co->specials.quote == NIL) co->specials.quote = symbol_make(co->in, "quote"); if (co->specials.iff == NIL) co->specials.iff = symbol_make(co->in, "if"); if (co->specials.progn == NIL) co->specials.progn = symbol_make(co->in, "progn"); if (co->specials.fn == NIL) co->specials.fn = symbol_make(co->in, "fn"); if (co->specials.def == NIL) co->specials.def = symbol_make(co->in, "def"); if (co->specials.mac == NIL) co->specials.mac = symbol_make(co->in, "mac"); switch (ty) { case TYPE_NIL: case TYPE_NUM: case TYPE_STR: emit(co, OP_CONST); emit16(co, add_constant(co, expr)); if (tail) emit(co, OP_RET); break; case TYPE_SYM: if (co->use_locals) { int idx = find_local(co, expr); if (idx >= 0) { emit(co, OP_GET_LOCAL); emit16(co, idx); if (tail) emit(co, OP_RET); return; } } emit(co, OP_GET); emit16(co, add_constant(co, expr)); if (tail) emit(co, OP_RET); break; case TYPE_PAIR: compile_apply(co, expr, tail); break; default: error_throw(co->in, "compile: cannot compile type %s", typename(ty)); } }