From 27df5f8ce0f3804d34a440d949a74d8a3a482c89 Mon Sep 17 00:00:00 2001
From: "Javier B. Torres"
Date: Wed, 14 Jan 2026 12:10:29 -0300
Subject: [PATCH] proper lexical scoping now woo

Replace the nested_p() heuristic in compile_fn() with a free-variable
analysis (find_free_vars) that marks locals referenced from a nested
fn/mac as escaping. A function is compiled with stack locals only when
it has no rest parameter and none of its own locals escape; otherwise
its parameters are bound in the environment.

Also in this patch:
 * add the `and` / `or` special forms, built on a new OP_DUP opcode
   (with VM and disassembler support);
 * add the `%`, `nil?` and `env` primitives;
 * make symbol-table lookup require hash, length and bytes to match;
 * add examples/fizzbuzz.scm and examples/map.scm, move test.lisp to
   examples/tailcalls.scm and delete test_macro.lisp.
---
Review notes (this section is ignored by `git am`):
 * prim_mod: the error messages were prefixed "/:" (copied from
   prim_div); they now use "%%:", which prints as "%:" assuming
   error_throw() takes a printf-style format as its other call sites
   suggest. The post-image blob id on the src/core/prim.c index line
   predates this edit.
 * NOTE(review): find_free_vars() tests a symbol only against the
   immediate parent scope, so a variable captured across two levels of
   nested fn looks like it is not marked as escaping -- confirm.

 .editorconfig                       |   2 +-
 examples/fizzbuzz.scm               |  22 +++
 examples/map.scm                    |  19 ++
 test.lisp => examples/tailcalls.scm |   0
 include/wolflisp.h                  |  10 +-
 src/core/compile.c                  | 281 +++++++++++++++++++++++++---
 src/core/disasm.c                   |   3 +
 src/core/interp.c                   |   3 +
 src/core/prim.c                     |  34 +++-
 src/core/prim.h                     |   3 +
 src/core/symbol.c                   |   4 +-
 src/core/vm.c                       |   5 +
 test_macro.lisp                     |   8 -
 13 files changed, 349 insertions(+), 45 deletions(-)
 create mode 100644 examples/fizzbuzz.scm
 create mode 100644 examples/map.scm
 rename test.lisp => examples/tailcalls.scm (100%)
 delete mode 100644 test_macro.lisp

diff --git a/.editorconfig b/.editorconfig
index 1c51663..0ec63b3 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -4,7 +4,7 @@ root = true
 end_of_line = lf
 insert_final_newline = true
 
-[*.{c,h}]
+[*.{c,h,scm}]
 indent_style = space
 indent_size = 2
 
diff --git a/examples/fizzbuzz.scm b/examples/fizzbuzz.scm
new file mode 100644
index 0000000..25381d3
--- /dev/null
+++ b/examples/fizzbuzz.scm
@@ -0,0 +1,22 @@
+(def each-integer-aux
+  (fn (n i thunk)
+    (if (= n 0)
+      '()
+      (progn
+        (thunk (- (+ i 1) n))
+        (each-integer-aux (- n 1) i thunk)))))
+
+(def each-integer
+  (fn (n thunk)
+    (each-integer-aux n n thunk)))
+
+(each-integer 30
+  (fn (x)
+    (if (or (= 0 (% x 3)) (= 0 (% x 5)))
+      (progn
+        (if (= 0 (% x 3))
+          (write "Fizz"))
+        (if (= 0 (% x 5))
+          (write "Buzz")))
+      (print x))
+    (write "\n")))
diff --git a/examples/map.scm b/examples/map.scm
new file mode 100644
index 0000000..aaa40d6
--- /dev/null
+++ b/examples/map.scm
@@ -0,0 +1,19 @@
+(def defn
+  (mac (name args . body)
+    (list 'def name (cons 'fn (cons args body)))))
+
+(defn map-aux (f acc l)
+  (if (nil? l)
+    (acc f '())
+    (map-aux
+      f
+      (fn (f ys)
+        (acc f (cons (f (head l)) ys)))
+      (tail l))))
+
+(defn map (f l)
+  (map-aux f (fn (f x) x) l))
+
+(println
+  (map (fn (x) (* x x))
+       '(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20)))
diff --git a/test.lisp b/examples/tailcalls.scm
similarity index 100%
rename from test.lisp
rename to examples/tailcalls.scm
diff --git a/include/wolflisp.h b/include/wolflisp.h
index 3889a43..14eb097 100644
--- a/include/wolflisp.h
+++ b/include/wolflisp.h
@@ -176,16 +176,20 @@ enum {
   OP_GET_LOCAL,
   OP_SET_LOCAL,
   OP_RESERVE,
+  OP_DUP,
 };
 
 // Local variable info
 typedef struct Lv {
   O name;
   U16 index;
+  I escapes;
 } Lv;
 
 // Compiler context
-typedef struct Cm {
+typedef struct Cm Cm;
+struct Cm {
+  Cm *parent;
   In *in;
   U8 *code;
   Z count;
@@ -202,6 +206,8 @@ typedef struct Cm {
     O mac;
     O progn;
     O def;
+    O and;
+    O or;
   } specials;
   struct {
     Lv *data;
@@ -209,7 +215,7 @@ typedef struct Cm {
     Z capacity;
   } locals;
   I use_locals;
-} Cm;
+};
 
 enum {
   TOK_EOF = 0,
diff --git a/src/core/compile.c b/src/core/compile.c
index bb5a6c9..0f32ba7 100644
--- a/src/core/compile.c
+++ b/src/core/compile.c
@@ -25,6 +25,17 @@ static void add_local(Cm *co, O sym) {
   co->locals.count++;
 }
 
+static void mark_escaping(Cm *co, O sym) {
+  for (Z i = 0; i < co->locals.count; i++) {
+    if (co->locals.data[i].name == sym) {
+      co->locals.data[i].escapes = 1;
+      return;
+    }
+  }
+}
+
+static I is_local(Cm *co, O sym) { return find_local(co, sym) >= 0; }
+
 static V emit(Cm *co, U8 byte) {
   if (co->count >= co->capacity) {
     Z newcap = co->capacity == 0 ? 16 : co->capacity * 2;
@@ -162,18 +173,88 @@ static V compile_progn(Cm *co, O forms, I tail) {
   }
 }
 
-static int nested_p(Cm *c, O body) {
-  while (body != NIL) {
-    O expr = list_next(c->in, &body);
-    if (type(expr) == TYPE_PAIR) {
-      Pa *p = pair_unwrap(c->in, expr);
-      if (p->head == c->specials.fn)
-        return 1;
-      if (nested_p(c, expr))
-        return 1;
+// Find all free variables in an expression (vars used but not defined locally)
+static void find_free_vars(Cm *co, O expr, Cm *parent) {
+  I ty = type(expr);
+  switch (ty) {
+  case TYPE_SYM:
+    // If this symbol isn't local to current function but is in parent scope
+    if (!is_local(co, expr) && parent && is_local(parent, expr)) {
+      // Mark it as escaping in the parent
+      mark_escaping(parent, expr);
     }
+    break;
+  case TYPE_PAIR: {
+    Pa *p = pair_unwrap(co->in, expr);
+    O head = p->head;
+
+    // Skip quote - quoted expressions aren't variable references
+    if (type(head) == TYPE_SYM && head == co->specials.quote) {
+      return;
+    }
+
+    // Handle nested fn/mac - need to analyze their bodies for free variables
+    if (type(head) == TYPE_SYM &&
+        (head == co->specials.fn || head == co->specials.mac)) {
+      // This is a nested function definition
+      // Parse it to find what variables it captures
+      Pa *fn_tail = pair_unwrap(co->in, p->tail);
+      O nested_args = fn_tail->head;
+      O nested_body = fn_tail->tail;
+
+      // Create a temporary compiler for the nested function
+      Cm nested;
+      memset(&nested, 0, sizeof(Cm));
+      nested.in = co->in;
+      nested.specials = co->specials;
+      nested.parent = co;
+
+      // Add nested function's parameters as locals
+      O curr_arg = nested_args;
+      while (curr_arg != NIL) {
+        if (type(curr_arg) == TYPE_SYM) {
+          add_local(&nested, curr_arg);
+          break;
+        }
+        Pa *arg_pair = pair_unwrap(co->in, curr_arg);
+        add_local(&nested, arg_pair->head);
+        curr_arg = arg_pair->tail;
+      }
+
+      // Analyze the nested function's body to find free variables
+      O body_copy = nested_body;
+      while (body_copy != NIL) {
+        if (type(body_copy) != TYPE_PAIR) {
+          find_free_vars(&nested, body_copy, co);
+          break;
+        }
+        Pa *body_pair = pair_unwrap(co->in, body_copy);
+        find_free_vars(&nested, body_pair->head, co);
+        body_copy = body_pair->tail;
+      }
+
+      // Free the temporary compiler
+      if (nested.locals.data)
+        free(nested.locals.data);
+
+      return;
+    }
+
+    // For other expressions, recurse into all subexpressions
+    O curr = expr;
+    while (curr != NIL) {
+      if (type(curr) != TYPE_PAIR) {
+        break;
+      }
+      Pa *pair = pair_unwrap(co->in, curr);
+      find_free_vars(co, pair->head, parent);
+      curr = pair->tail;
+    }
+    break;
+  }
+  default:
+    break;
   }
-  return 0;
 }
 
 // Compile a closure with `args` and `body`.
@@ -183,13 +264,14 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
   memset(&ic, 0, sizeof(Cm));
   ic.in = c->in;
   ic.specials = c->specials;
+  ic.parent = c; // Link to parent scope
 
   O curr = args;
   O fixed[256];
   int fixed_count = 0;
   O rest = NIL;
 
-  // Count fixed arguments, and if the function has a rest argument
+  // Count fixed arguments
   while (curr != NIL) {
     if (type(curr) == TYPE_SYM) {
       rest = curr;
@@ -200,18 +282,46 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
     curr = p->tail;
   }
 
-  int nested = nested_p(c, body);
-  if (rest == NIL && !nested) {
-    // If the function has no rest argument, and has no nested closures inside
-    // it, compile using stack locals as an optimization for tail-calls
-    ic.use_locals = 1;
-    curr = args;
-    for (O next = list_next(c->in, &curr); next != NIL;
-         next = list_next(c->in, &curr))
-      add_local(&ic, next);
-  } else {
-    // Otherwise, fallback to using environment bindings for locals
+  // Add parameters as locals to inner compiler
+  curr = args;
+  while (curr != NIL) {
+    if (type(curr) == TYPE_SYM) {
+      add_local(&ic, curr);
+      break;
+    }
+    Pa *p = pair_unwrap(c->in, curr);
+    add_local(&ic, p->head);
+    curr = p->tail;
+  }
+
+  // Analyze body to find free variables and mark them as escaping in parent
+  O body_copy = body;
+  while (body_copy != NIL) {
+    O expr = list_next(c->in, &body_copy);
+    find_free_vars(&ic, expr, c);
+  }
+
+  // Decide whether to use stack locals or environment bindings
+  // Use stack locals only if:
+  // 1. No rest parameter (rest parameters require environment bindings)
+  // 2. None of our own locals escape (are captured by nested functions)
+  ic.use_locals = 1;
+
+  // If we have a rest parameter, we must use environment bindings
+  if (rest != NIL) ic.use_locals = 0;
+
+  // Check if any of this function's locals escape (are captured by nested
+  // functions) find_free_vars has already marked escaping variables in
+  // ic.locals
+  for (Z i = 0; i < ic.locals.count; i++) {
+    if (ic.locals.data[i].escapes) {
+      ic.use_locals = 0;
+      break;
+    }
+  }
 
+  // If we need environment bindings, emit BIND instructions
+  if (!ic.use_locals) {
     if (rest != NIL) {
       emit(&ic, OP_BIND_REST);
       emit16(&ic, add_constant(&ic, rest));
@@ -223,24 +333,123 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
     }
   }
 
-  // Compile the function's body (as a `progn` form)
+  // Compile the function's body
   compile_progn(&ic, body, 1);
+
   O code = code_make(c->in, ic.code, ic.count, ic.constants.data,
                      ic.constants.count);
   Z code_idx = add_constant(c, code);
   Z args_idx = add_constant(c, args);
 
-  // Compile pushing the closure to the stack.
+  // Compile pushing the closure to the stack
   emit(c, macro ? OP_MAC : OP_CLOS);
   emit16(c, code_idx);
   emit16(c, args_idx);
 
-  // Free the inner compiler context.
+  // Free the inner compiler context
   free(ic.code);
   if (ic.locals.data)
     free(ic.locals.data);
 }
 
+// Compile the `(and ...)` special form
+static V compile_and(Cm *co, O forms, I tail) {
+  if (forms == NIL) {
+    emit(co, OP_CONST);
+    emit16(co, add_constant(co, co->in->t));
+    if (tail)
+      emit(co, OP_RET);
+    return;
+  }
+
+  Z *jumps = NULL;
+  Z jump_count = 0;
+  Z jump_cap = 0;
+
+  while (forms != NIL) {
+    O expr = list_next(co->in, &forms);
+    if (forms == NIL) {
+      compile(co, expr, tail);
+    } else {
+      compile(co, expr, 0);
+      emit(co, OP_DUP);
+      emit(co, OP_JUMP_IF_NIL);
+      if (jump_count >= jump_cap) {
+        jump_cap = jump_cap == 0 ? 8 : jump_cap * 2;
+        jumps = realloc(jumps, jump_cap * sizeof(Z));
+      }
+      jumps[jump_count++] = co->count;
+      emit16(co, 0);
+      emit(co, OP_POP);
+    }
+  }
+
+  Z end = co->count;
+  for (Z i = 0; i < jump_count; i++) {
+    Z j = jumps[i];
+    co->code[j] = (U8)(end >> 8);
+    co->code[j + 1] = (U8)(end & 0xff);
+  }
+  if (jumps)
+    free(jumps);
+
+  if (tail)
+    emit(co, OP_RET);
+}
+
+// Compile the `(or ...)` special form
+static V compile_or(Cm *co, O forms, I tail) {
+  if (forms == NIL) {
+    emit(co, OP_CONST);
+    emit16(co, add_constant(co, NIL));
+    if (tail)
+      emit(co, OP_RET);
+    return;
+  }
+
+  Z *jumps = NULL;
+  Z jump_count = 0;
+  Z jump_cap = 0;
+
+  while (forms != NIL) {
+    O expr = list_next(co->in, &forms);
+    if (forms == NIL) {
+      compile(co, expr, tail);
+    } else {
+      compile(co, expr, 0);
+      emit(co, OP_DUP);
+      emit(co, OP_JUMP_IF_NIL);
+      Z jump_next = co->count;
+      emit16(co, 0);
+
+      emit(co, OP_JUMP);
+      if (jump_count >= jump_cap) {
+        jump_cap = jump_cap == 0 ? 8 : jump_cap * 2;
+        jumps = realloc(jumps, jump_cap * sizeof(Z));
+      }
+      jumps[jump_count++] = co->count;
+      emit16(co, 0);
+
+      Z next = co->count;
+      co->code[jump_next] = (U8)(next >> 8);
+      co->code[jump_next + 1] = (U8)(next & 0xff);
+      emit(co, OP_POP);
+    }
+  }
+
+  Z end = co->count;
+  for (Z i = 0; i < jump_count; i++) {
+    Z j = jumps[i];
+    co->code[j] = (U8)(end >> 8);
+    co->code[j + 1] = (U8)(end & 0xff);
+  }
+  if (jumps)
+    free(jumps);
+
+  if (tail)
+    emit(co, OP_RET);
+}
+
 // Compile the `(def name value)` special form
 static V compile_def(Cm *c, O args, I tail) {
   O sym = list_next(c->in, &args);
@@ -290,6 +499,12 @@ static V compile_apply(Cm *co, O expr, I tail) {
     if (tail)
       emit(co, OP_RET);
     return;
+  } else if (head == co->specials.and) {
+    compile_and(co, p->tail, tail);
+    return;
+  } else if (head == co->specials.or) {
+    compile_or(co, p->tail, tail);
+    return;
   } else if (head == co->specials.def) {
     compile_def(co, p->tail, tail);
     return;
@@ -324,6 +539,10 @@ V compile(Cm *co, O expr, I tail) {
     co->specials.def = symbol_make(co->in, "def");
   if (co->specials.mac == NIL)
     co->specials.mac = symbol_make(co->in, "mac");
+  if (co->specials.and == NIL)
+    co->specials.and = symbol_make(co->in, "and");
+  if (co->specials.or == NIL)
+    co->specials.or = symbol_make(co->in, "or");
 
   switch (ty) {
   case TYPE_NIL:
@@ -338,11 +557,13 @@ V compile(Cm *co, O expr, I tail) {
     if (co->use_locals) {
       int idx = find_local(co, expr);
       if (idx >= 0) {
-        emit(co, OP_GET_LOCAL);
-        emit16(co, idx);
-        if (tail)
-          emit(co, OP_RET);
-        return;
+        if (!co->locals.data[idx].escapes) {
+          emit(co, OP_GET_LOCAL);
+          emit16(co, idx);
+          if (tail)
+            emit(co, OP_RET);
+          return;
+        }
       }
     }
     emit(co, OP_GET);
diff --git a/src/core/disasm.c b/src/core/disasm.c
index 6174f30..7c62600 100644
--- a/src/core/disasm.c
+++ b/src/core/disasm.c
@@ -109,6 +109,9 @@ V disassemble(Cm *co) {
       ofs += 2;
       break;
     }
+    case OP_DUP:
+      printf("DUP\n");
+      break;
     default:
       printf("%02x\n", co->code[ofs]);
     }
diff --git a/src/core/interp.c b/src/core/interp.c
index dbac969..d3f2c46 100644
--- a/src/core/interp.c
+++ b/src/core/interp.c
@@ -31,10 +31,13 @@ V interp_init(In *in) {
   PRIM("print", prim_print, 1, 1);
   PRIM("println", prim_println, 1, 1);
   PRIM("write", prim_write, 1, 1);
+  PRIM("env", prim_env, 0, 0);
+  PRIM("nil?", prim_nil_p, 1, 1);
   PRIM("+", prim_add, 0, -1); // variadic
   PRIM("-", prim_sub, 0, -1); // variadic
   PRIM("*", prim_mul, 0, -1); // variadic
   PRIM("/", prim_div, 0, -1); // variadic
+  PRIM("%", prim_mod, 0, -1); // variadic
   PRIM("<", prim_lt, 2, 2);
   PRIM(">", prim_gt, 2, 2);
   PRIM("=", prim_equal, 0, -1); // variadic
diff --git a/src/core/prim.c b/src/core/prim.c
index 1e08534..8e75dcb 100644
--- a/src/core/prim.c
+++ b/src/core/prim.c
@@ -70,7 +70,8 @@ O prim_write(In *in, O *args, int argc, O env) {
   if (argc != 1)
     error_throw(in, "write: expected 1 argument, got %d", argc);
   if (type(args[0]) != TYPE_STR)
-    error_throw(in, "write: expected string argument, got %s", typename(type(args[0])));
+    error_throw(in, "write: expected string argument, got %s",
+                typename(type(args[0])));
   Ss *s = (Ss *)(UNBOX(args[0]) + 1);
   printf("%.*s", (int)s->len, s->data);
   return NIL;
@@ -132,6 +133,23 @@ O prim_div(In *in, O *args, int argc, O env) {
   return NUM(result);
 }
 
+O prim_mod(In *in, O *args, int argc, O env) {
+  (void)env;
+  if (argc == 0)
+    return NUM(1);
+  if (!IMM(args[0]))
+    error_throw(in, "%%: non numeric argument at position 0");
+  I result = ORD(args[0]);
+  for (int i = 1; i < argc; i++) {
+    if (!IMM(args[i]))
+      error_throw(in, "%%: non numeric argument at position %d", i);
+    if (ORD(args[i]) == 0)
+      error_throw(in, "%%: division by zero at position %d", i);
+    result %= ORD(args[i]);
+  }
+  return NUM(result);
+}
+
 O prim_equal(In *in, O *args, int argc, O env) {
   (void)env;
   if (argc < 2)
@@ -168,6 +186,20 @@ O prim_gt(In *in, O *args, int argc, O env) {
   }
 }
 
+O prim_nil_p(In *in, O *args, int argc, O env) {
+  (void)env;
+  if (argc != 1)
+    error_throw(in, "nil?: expected 1 argument, got %d", argc);
+  return BOOL(args[0] == NIL);
+}
+
+O prim_env(In *in, O *args, int argc, O env) {
+  (void)args;
+  (void)argc;
+  (void)env;
+  return in->env;
+}
+
 O prim_gc(In *in, O *args, int argc, O env) {
   (void)args;
   (void)argc;
diff --git a/src/core/prim.h b/src/core/prim.h
index 1bde768..5d419aa 100644
--- a/src/core/prim.h
+++ b/src/core/prim.h
@@ -11,7 +11,10 @@ O prim_add(In *in, O *args, int argc, O env);
 O prim_sub(In *in, O *args, int argc, O env);
 O prim_mul(In *in, O *args, int argc, O env);
 O prim_div(In *in, O *args, int argc, O env);
+O prim_mod(In *in, O *args, int argc, O env);
 O prim_equal(In *in, O *args, int argc, O env);
 O prim_lt(In *in, O *args, int argc, O env);
 O prim_gt(In *in, O *args, int argc, O env);
+O prim_nil_p(In *in, O *args, int argc, O env);
+O prim_env(In *in, O *args, int argc, O env);
 O prim_gc(In *in, O *args, int argc, O env);
diff --git a/src/core/symbol.c b/src/core/symbol.c
index ab28c00..897662f 100644
--- a/src/core/symbol.c
+++ b/src/core/symbol.c
@@ -15,9 +15,7 @@ static Sy *find(St *tab, const char *str, U32 hash, Z len) {
     Sy *s = tab->data[ix];
     if (!s)
       return NULL;
-    if (s->hash == hash && s->len == len)
-      return s;
-    if (memcmp(s->data, str, len) == 0)
+    if (s->hash == hash && s->len == len && memcmp(s->data, str, len) == 0)
       return s;
     ix = (ix + 1) % tab->capacity;
   }
diff --git a/src/core/vm.c b/src/core/vm.c
index b5918a8..0c67f4f 100644
--- a/src/core/vm.c
+++ b/src/core/vm.c
@@ -228,6 +228,11 @@ static O vm_exec(Cm *co_in, O env_in, int argc_in) {
       PUSH(BOX(hdr));
       break;
     }
+    case OP_DUP: {
+      O val = PEEK();
+      PUSH(val);
+      break;
+    }
     default:
       error_throw(in, "unknown opcode %d", op);
     }
diff --git a/test_macro.lisp b/test_macro.lisp
deleted file mode 100644
index 67c93a6..0000000
--- a/test_macro.lisp
+++ /dev/null
@@ -1,8 +0,0 @@
-(def twice (mac (x)
-  (list 'progn x x)))
-
-(def when (mac (cond . body)
-  (list 'if cond (cons 'progn body))))
-
-(when (= 1 1)
-  (twice (write "ok\n")))