proper lexical scoping now woo

This commit is contained in:
Lobo 2026-01-14 12:10:29 -03:00
parent ce9489b5d2
commit 27df5f8ce0
13 changed files with 349 additions and 45 deletions

View file

@ -4,7 +4,7 @@ root = true
end_of_line = lf end_of_line = lf
insert_final_newline = true insert_final_newline = true
[*.{c,h}] [*.{c,h,scm}]
indent_style = space indent_style = space
indent_size = 2 indent_size = 2

22
examples/fizzbuzz.scm Normal file
View file

@ -0,0 +1,22 @@
(def each-integer-aux
  (fn (remaining total visit)
    (if (= remaining 0)
      '()
      (progn
        (visit (- (+ total 1) remaining))
        (each-integer-aux (- remaining 1) total visit)))))
(def each-integer
  (fn (limit visit)
    (each-integer-aux limit limit visit)))
(each-integer 30
  (fn (k)
    (if (or (= 0 (% k 3)) (= 0 (% k 5)))
      (progn
        (if (= 0 (% k 3))
          (write "Fizz"))
        (if (= 0 (% k 5))
          (write "Buzz")))
      (print k))
    (write "\n")))

19
examples/map.scm Normal file
View file

@ -0,0 +1,19 @@
(def defn
  (mac (name params . body)
    (list 'def name (cons 'fn (cons params body)))))
(defn map-aux (func k items)
  (if (nil? items)
    (k func '())
    (map-aux
      func
      (fn (g mapped)
        (k g (cons (g (head items)) mapped)))
      (tail items))))
(defn map (func items)
  (map-aux func (fn (g done) done) items))
(println
  (map
    (fn (n) (* n n))
    '(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20)))

View file

@ -176,16 +176,20 @@ enum {
OP_GET_LOCAL, OP_GET_LOCAL,
OP_SET_LOCAL, OP_SET_LOCAL,
OP_RESERVE, OP_RESERVE,
OP_DUP,
}; };
// Local variable info // Local variable info
typedef struct Lv { typedef struct Lv {
O name; O name;
U16 index; U16 index;
I escapes;
} Lv; } Lv;
// Compiler context // Compiler context
typedef struct Cm { typedef struct Cm Cm;
struct Cm {
Cm *parent;
In *in; In *in;
U8 *code; U8 *code;
Z count; Z count;
@ -202,6 +206,8 @@ typedef struct Cm {
O mac; O mac;
O progn; O progn;
O def; O def;
O and;
O or;
} specials; } specials;
struct { struct {
Lv *data; Lv *data;
@ -209,7 +215,7 @@ typedef struct Cm {
Z capacity; Z capacity;
} locals; } locals;
I use_locals; I use_locals;
} Cm; };
enum { enum {
TOK_EOF = 0, TOK_EOF = 0,

View file

@ -25,6 +25,17 @@ static void add_local(Cm *co, O sym) {
co->locals.count++; co->locals.count++;
} }
// Set the `escapes` flag on the first local of `co` whose name compares equal
// (`==`) to `sym`. Does nothing when no local matches.
static void mark_escaping(Cm *co, O sym) {
  Z idx = 0;
  while (idx < co->locals.count) {
    Lv *lv = &co->locals.data[idx];
    if (lv->name == sym) {
      lv->escapes = 1;
      break;
    }
    idx++;
  }
}
// Non-zero when `find_local` reports a non-negative result for `sym` in `co`.
static I is_local(Cm *co, O sym) {
  return !(find_local(co, sym) < 0);
}
static V emit(Cm *co, U8 byte) { static V emit(Cm *co, U8 byte) {
if (co->count >= co->capacity) { if (co->count >= co->capacity) {
Z newcap = co->capacity == 0 ? 16 : co->capacity * 2; Z newcap = co->capacity == 0 ? 16 : co->capacity * 2;
@ -162,18 +173,88 @@ static V compile_progn(Cm *co, O forms, I tail) {
} }
} }
static int nested_p(Cm *c, O body) { // Find all free variables in an expression (vars used but not defined locally)
while (body != NIL) { static void find_free_vars(Cm *co, O expr, Cm *parent) {
O expr = list_next(c->in, &body); I ty = type(expr);
if (type(expr) == TYPE_PAIR) { switch (ty) {
Pa *p = pair_unwrap(c->in, expr); case TYPE_SYM:
if (p->head == c->specials.fn) // If this symbol isn't local to current function but is in parent scope
return 1; if (!is_local(co, expr) && parent && is_local(parent, expr)) {
if (nested_p(c, expr)) // Mark it as escaping in the parent
return 1; mark_escaping(parent, expr);
} }
break;
case TYPE_PAIR: {
Pa *p = pair_unwrap(co->in, expr);
O head = p->head;
// Skip quote - quoted expressions aren't variable references
if (type(head) == TYPE_SYM && head == co->specials.quote) {
return;
}
// Handle nested fn/mac - need to analyze their bodies for free variables
if (type(head) == TYPE_SYM &&
(head == co->specials.fn || head == co->specials.mac)) {
// This is a nested function definition
// Parse it to find what variables it captures
Pa *fn_tail = pair_unwrap(co->in, p->tail);
O nested_args = fn_tail->head;
O nested_body = fn_tail->tail;
// Create a temporary compiler for the nested function
Cm nested;
memset(&nested, 0, sizeof(Cm));
nested.in = co->in;
nested.specials = co->specials;
nested.parent = co;
// Add nested function's parameters as locals
O curr_arg = nested_args;
while (curr_arg != NIL) {
if (type(curr_arg) == TYPE_SYM) {
add_local(&nested, curr_arg);
break;
}
Pa *arg_pair = pair_unwrap(co->in, curr_arg);
add_local(&nested, arg_pair->head);
curr_arg = arg_pair->tail;
}
// Analyze the nested function's body to find free variables
O body_copy = nested_body;
while (body_copy != NIL) {
if (type(body_copy) != TYPE_PAIR) {
find_free_vars(&nested, body_copy, co);
break;
}
Pa *body_pair = pair_unwrap(co->in, body_copy);
find_free_vars(&nested, body_pair->head, co);
body_copy = body_pair->tail;
}
// Free the temporary compiler
if (nested.locals.data)
free(nested.locals.data);
return;
}
// For other expressions, recurse into all subexpressions
O curr = expr;
while (curr != NIL) {
if (type(curr) != TYPE_PAIR) {
break;
}
Pa *pair = pair_unwrap(co->in, curr);
find_free_vars(co, pair->head, parent);
curr = pair->tail;
}
break;
}
default:
break;
} }
return 0;
} }
// Compile a closure with `args` and `body`. // Compile a closure with `args` and `body`.
@ -183,13 +264,14 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
memset(&ic, 0, sizeof(Cm)); memset(&ic, 0, sizeof(Cm));
ic.in = c->in; ic.in = c->in;
ic.specials = c->specials; ic.specials = c->specials;
ic.parent = c; // Link to parent scope
O curr = args; O curr = args;
O fixed[256]; O fixed[256];
int fixed_count = 0; int fixed_count = 0;
O rest = NIL; O rest = NIL;
// Count fixed arguments, and if the function has a rest argument // Count fixed arguments
while (curr != NIL) { while (curr != NIL) {
if (type(curr) == TYPE_SYM) { if (type(curr) == TYPE_SYM) {
rest = curr; rest = curr;
@ -200,18 +282,46 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
curr = p->tail; curr = p->tail;
} }
int nested = nested_p(c, body); // Add parameters as locals to inner compiler
if (rest == NIL && !nested) { curr = args;
// If the function has no rest argument, and has no nested closures inside while (curr != NIL) {
// it, compile using stack locals as an optimization for tail-calls if (type(curr) == TYPE_SYM) {
ic.use_locals = 1; add_local(&ic, curr);
curr = args; break;
for (O next = list_next(c->in, &curr); next != NIL; }
next = list_next(c->in, &curr)) Pa *p = pair_unwrap(c->in, curr);
add_local(&ic, next); add_local(&ic, p->head);
} else { curr = p->tail;
// Otherwise, fallback to using environment bindings for locals }
// Analyze body to find free variables and mark them as escaping in parent
O body_copy = body;
while (body_copy != NIL) {
O expr = list_next(c->in, &body_copy);
find_free_vars(&ic, expr, c);
}
// Decide whether to use stack locals or environment bindings
// Use stack locals only if:
// 1. No rest parameter (rest parameters require environment bindings)
// 2. None of our own locals escape (are captured by nested functions)
ic.use_locals = 1;
// If we have a rest parameter, we must use environment bindings
if (rest != NIL)
ic.use_locals = 0; ic.use_locals = 0;
// Check if any of this function's locals escape (are captured by nested
// functions) find_free_vars has already marked escaping variables in
// ic.locals
for (Z i = 0; i < ic.locals.count; i++) {
if (ic.locals.data[i].escapes) {
ic.use_locals = 0;
break;
}
}
// If we need environment bindings, emit BIND instructions
if (!ic.use_locals) {
if (rest != NIL) { if (rest != NIL) {
emit(&ic, OP_BIND_REST); emit(&ic, OP_BIND_REST);
emit16(&ic, add_constant(&ic, rest)); emit16(&ic, add_constant(&ic, rest));
@ -223,24 +333,123 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
} }
} }
// Compile the function's body (as a `progn` form) // Compile the function's body
compile_progn(&ic, body, 1); compile_progn(&ic, body, 1);
O code = code_make(c->in, ic.code, ic.count, ic.constants.data, O code = code_make(c->in, ic.code, ic.count, ic.constants.data,
ic.constants.count); ic.constants.count);
Z code_idx = add_constant(c, code); Z code_idx = add_constant(c, code);
Z args_idx = add_constant(c, args); Z args_idx = add_constant(c, args);
// Compile pushing the closure to the stack. // Compile pushing the closure to the stack
emit(c, macro ? OP_MAC : OP_CLOS); emit(c, macro ? OP_MAC : OP_CLOS);
emit16(c, code_idx); emit16(c, code_idx);
emit16(c, args_idx); emit16(c, args_idx);
// Free the inner compiler context. // Free the inner compiler context
free(ic.code); free(ic.code);
if (ic.locals.data) if (ic.locals.data)
free(ic.locals.data); free(ic.locals.data);
} }
// Compile the `(and ...)` special form.
//
// With no forms, emits OP_CONST for the interpreter's `t` value. Otherwise
// every form except the last is compiled in non-tail position and followed by
//     OP_DUP, OP_JUMP_IF_NIL <placeholder>, OP_POP
// while the last form is compiled with the caller's `tail` flag. Once all
// forms are emitted, every placeholder is patched with the offset just past
// the last form. When `tail` is set, an OP_RET is emitted at that offset.
//
// The list of pending placeholders lives on the heap; if it cannot be grown
// the process is aborted rather than writing through a NULL pointer.
static V compile_and(Cm *co, O forms, I tail) {
  if (forms == NIL) {
    emit(co, OP_CONST);
    emit16(co, add_constant(co, co->in->t));
    if (tail)
      emit(co, OP_RET);
    return;
  }

  // Offsets (into co->code) of the 16-bit operands awaiting their target.
  Z *jumps = NULL;
  Z jump_count = 0;
  Z jump_cap = 0;

  while (forms != NIL) {
    O expr = list_next(co->in, &forms);
    if (forms == NIL) {
      // Last form: inherits the tail position of the whole `and`.
      compile(co, expr, tail);
    } else {
      compile(co, expr, 0);
      emit(co, OP_DUP);
      emit(co, OP_JUMP_IF_NIL);
      if (jump_count >= jump_cap) {
        Z new_cap = jump_cap == 0 ? 8 : jump_cap * 2;
        // Keep the old pointer until realloc is known to have succeeded.
        Z *grown = realloc(jumps, new_cap * sizeof *grown);
        if (grown == NULL) {
          free(jumps);
          abort();
        }
        jumps = grown;
        jump_cap = new_cap;
      }
      jumps[jump_count++] = co->count;
      emit16(co, 0); // placeholder, patched below
      emit(co, OP_POP);
    }
  }

  // Patch every short-circuit jump to land just past the last form
  // (high byte first, then low byte).
  Z end = co->count;
  for (Z i = 0; i < jump_count; i++) {
    Z j = jumps[i];
    co->code[j] = (U8)(end >> 8);
    co->code[j + 1] = (U8)(end & 0xff);
  }
  free(jumps);

  if (tail)
    emit(co, OP_RET);
}
// Compile the `(or ...)` special form.
//
// With no forms, emits OP_CONST for NIL. Otherwise every form except the last
// is compiled in non-tail position and followed by
//     OP_DUP, OP_JUMP_IF_NIL <next>, OP_JUMP <end>, next: OP_POP
// while the last form is compiled with the caller's `tail` flag. `<next>` is
// patched immediately; the `<end>` operands are collected and patched with the
// offset just past the last form. When `tail` is set, an OP_RET is emitted at
// that offset.
//
// The list of pending `<end>` operands lives on the heap; if it cannot be
// grown the process is aborted rather than writing through a NULL pointer.
static V compile_or(Cm *co, O forms, I tail) {
  if (forms == NIL) {
    emit(co, OP_CONST);
    emit16(co, add_constant(co, NIL));
    if (tail)
      emit(co, OP_RET);
    return;
  }

  // Offsets (into co->code) of the OP_JUMP operands awaiting the end target.
  Z *jumps = NULL;
  Z jump_count = 0;
  Z jump_cap = 0;

  while (forms != NIL) {
    O expr = list_next(co->in, &forms);
    if (forms == NIL) {
      // Last form: inherits the tail position of the whole `or`.
      compile(co, expr, tail);
    } else {
      compile(co, expr, 0);
      emit(co, OP_DUP);
      emit(co, OP_JUMP_IF_NIL);
      Z jump_next = co->count;
      emit16(co, 0); // placeholder for <next>, patched a few lines down
      emit(co, OP_JUMP);
      if (jump_count >= jump_cap) {
        Z new_cap = jump_cap == 0 ? 8 : jump_cap * 2;
        // Keep the old pointer until realloc is known to have succeeded.
        Z *grown = realloc(jumps, new_cap * sizeof *grown);
        if (grown == NULL) {
          free(jumps);
          abort();
        }
        jumps = grown;
        jump_cap = new_cap;
      }
      jumps[jump_count++] = co->count;
      emit16(co, 0); // placeholder for <end>, patched after the loop

      // The nil branch resumes here, at the OP_POP emitted next.
      Z next = co->count;
      co->code[jump_next] = (U8)(next >> 8);
      co->code[jump_next + 1] = (U8)(next & 0xff);
      emit(co, OP_POP);
    }
  }

  // Patch every short-circuit jump to land just past the last form
  // (high byte first, then low byte).
  Z end = co->count;
  for (Z i = 0; i < jump_count; i++) {
    Z j = jumps[i];
    co->code[j] = (U8)(end >> 8);
    co->code[j + 1] = (U8)(end & 0xff);
  }
  free(jumps);

  if (tail)
    emit(co, OP_RET);
}
// Compile the `(def name value)` special form // Compile the `(def name value)` special form
static V compile_def(Cm *c, O args, I tail) { static V compile_def(Cm *c, O args, I tail) {
O sym = list_next(c->in, &args); O sym = list_next(c->in, &args);
@ -290,6 +499,12 @@ static V compile_apply(Cm *co, O expr, I tail) {
if (tail) if (tail)
emit(co, OP_RET); emit(co, OP_RET);
return; return;
} else if (head == co->specials.and) {
compile_and(co, p->tail, tail);
return;
} else if (head == co->specials.or) {
compile_or(co, p->tail, tail);
return;
} else if (head == co->specials.def) { } else if (head == co->specials.def) {
compile_def(co, p->tail, tail); compile_def(co, p->tail, tail);
return; return;
@ -324,6 +539,10 @@ V compile(Cm *co, O expr, I tail) {
co->specials.def = symbol_make(co->in, "def"); co->specials.def = symbol_make(co->in, "def");
if (co->specials.mac == NIL) if (co->specials.mac == NIL)
co->specials.mac = symbol_make(co->in, "mac"); co->specials.mac = symbol_make(co->in, "mac");
if (co->specials.and == NIL)
co->specials.and = symbol_make(co->in, "and");
if (co->specials.or == NIL)
co->specials.or = symbol_make(co->in, "or");
switch (ty) { switch (ty) {
case TYPE_NIL: case TYPE_NIL:
@ -338,11 +557,13 @@ V compile(Cm *co, O expr, I tail) {
if (co->use_locals) { if (co->use_locals) {
int idx = find_local(co, expr); int idx = find_local(co, expr);
if (idx >= 0) { if (idx >= 0) {
emit(co, OP_GET_LOCAL); if (!co->locals.data[idx].escapes) {
emit16(co, idx); emit(co, OP_GET_LOCAL);
if (tail) emit16(co, idx);
emit(co, OP_RET); if (tail)
return; emit(co, OP_RET);
return;
}
} }
} }
emit(co, OP_GET); emit(co, OP_GET);

View file

@ -109,6 +109,9 @@ V disassemble(Cm *co) {
ofs += 2; ofs += 2;
break; break;
} }
case OP_DUP:
printf("DUP\n");
break;
default: default:
printf("%02x\n", co->code[ofs]); printf("%02x\n", co->code[ofs]);
} }

View file

@ -31,10 +31,13 @@ V interp_init(In *in) {
PRIM("print", prim_print, 1, 1); PRIM("print", prim_print, 1, 1);
PRIM("println", prim_println, 1, 1); PRIM("println", prim_println, 1, 1);
PRIM("write", prim_write, 1, 1); PRIM("write", prim_write, 1, 1);
PRIM("env", prim_env, 0, 0);
PRIM("nil?", prim_nil_p, 1, 1);
PRIM("+", prim_add, 0, -1); // variadic PRIM("+", prim_add, 0, -1); // variadic
PRIM("-", prim_sub, 0, -1); // variadic PRIM("-", prim_sub, 0, -1); // variadic
PRIM("*", prim_mul, 0, -1); // variadic PRIM("*", prim_mul, 0, -1); // variadic
PRIM("/", prim_div, 0, -1); // variadic PRIM("/", prim_div, 0, -1); // variadic
PRIM("%", prim_mod, 0, -1); // variadic
PRIM("<", prim_lt, 2, 2); PRIM("<", prim_lt, 2, 2);
PRIM(">", prim_gt, 2, 2); PRIM(">", prim_gt, 2, 2);
PRIM("=", prim_equal, 0, -1); // variadic PRIM("=", prim_equal, 0, -1); // variadic

View file

@ -70,7 +70,8 @@ O prim_write(In *in, O *args, int argc, O env) {
if (argc != 1) if (argc != 1)
error_throw(in, "write: expected 1 argument, got %d", argc); error_throw(in, "write: expected 1 argument, got %d", argc);
if (type(args[0]) != TYPE_STR) if (type(args[0]) != TYPE_STR)
error_throw(in, "write: expected string argument, got %s", typename(type(args[0]))); error_throw(in, "write: expected string argument, got %s",
typename(type(args[0])));
Ss *s = (Ss *)(UNBOX(args[0]) + 1); Ss *s = (Ss *)(UNBOX(args[0]) + 1);
printf("%.*s", (int)s->len, s->data); printf("%.*s", (int)s->len, s->data);
return NIL; return NIL;
@ -132,6 +133,23 @@ O prim_div(In *in, O *args, int argc, O env) {
return NUM(result); return NUM(result);
} }
// `%` primitive: folds the C remainder operator left to right over the
// arguments, i.e. ((a0 % a1) % a2) ...
//
// Every argument must satisfy IMM(); a zero divisor (any argument after the
// first) is reported through error_throw. With a single argument that
// argument is returned unchanged.
//
// NOTE(review): zero arguments yields NUM(1). There is no identity element
// for remainder, so this looks carried over from another arithmetic
// primitive -- confirm it is intended.
O prim_mod(In *in, O *args, int argc, O env) {
  (void)env;
  if (argc == 0)
    return NUM(1);
  // error_throw takes a printf-style format, so the primitive's name `%` has
  // to be spelled `%%` in the messages below.
  if (!IMM(args[0]))
    error_throw(in, "%%: non numeric argument at position 0");
  I result = ORD(args[0]);
  for (int i = 1; i < argc; i++) {
    if (!IMM(args[i]))
      error_throw(in, "%%: non numeric argument at position %d", i);
    if (ORD(args[i]) == 0)
      error_throw(in, "%%: division by zero at position %d", i);
    result %= ORD(args[i]);
  }
  return NUM(result);
}
O prim_equal(In *in, O *args, int argc, O env) { O prim_equal(In *in, O *args, int argc, O env) {
(void)env; (void)env;
if (argc < 2) if (argc < 2)
@ -168,6 +186,20 @@ O prim_gt(In *in, O *args, int argc, O env) {
} }
} }
// `nil?` primitive: BOOL of whether the single argument is NIL.
// Reports an arity error through error_throw unless argc is exactly 1.
O prim_nil_p(In *in, O *args, int argc, O env) {
  (void)env;
  I arity_ok = (argc == 1);
  if (!arity_ok)
    error_throw(in, "nil?: expected 1 argument, got %d", argc);
  O subject = args[0];
  return BOOL(subject == NIL);
}
// `env` primitive: returns the environment stored on the interpreter
// (`in->env`). The `args`, `argc` and `env` parameters are all ignored.
O prim_env(In *in, O *args, int argc, O env) {
  (void)env;
  (void)argc;
  (void)args;
  O current = in->env;
  return current;
}
O prim_gc(In *in, O *args, int argc, O env) { O prim_gc(In *in, O *args, int argc, O env) {
(void)args; (void)args;
(void)argc; (void)argc;

View file

@ -11,7 +11,10 @@ O prim_add(In *in, O *args, int argc, O env);
O prim_sub(In *in, O *args, int argc, O env); O prim_sub(In *in, O *args, int argc, O env);
O prim_mul(In *in, O *args, int argc, O env); O prim_mul(In *in, O *args, int argc, O env);
O prim_div(In *in, O *args, int argc, O env); O prim_div(In *in, O *args, int argc, O env);
O prim_mod(In *in, O *args, int argc, O env);
O prim_equal(In *in, O *args, int argc, O env); O prim_equal(In *in, O *args, int argc, O env);
O prim_lt(In *in, O *args, int argc, O env); O prim_lt(In *in, O *args, int argc, O env);
O prim_gt(In *in, O *args, int argc, O env); O prim_gt(In *in, O *args, int argc, O env);
O prim_nil_p(In *in, O *args, int argc, O env);
O prim_env(In *in, O *args, int argc, O env);
O prim_gc(In *in, O *args, int argc, O env); O prim_gc(In *in, O *args, int argc, O env);

View file

@ -15,9 +15,7 @@ static Sy *find(St *tab, const char *str, U32 hash, Z len) {
Sy *s = tab->data[ix]; Sy *s = tab->data[ix];
if (!s) if (!s)
return NULL; return NULL;
if (s->hash == hash && s->len == len) if (s->hash == hash && s->len == len && memcmp(s->data, str, len) == 0)
return s;
if (memcmp(s->data, str, len) == 0)
return s; return s;
ix = (ix + 1) % tab->capacity; ix = (ix + 1) % tab->capacity;
} }

View file

@ -228,6 +228,11 @@ static O vm_exec(Cm *co_in, O env_in, int argc_in) {
PUSH(BOX(hdr)); PUSH(BOX(hdr));
break; break;
} }
case OP_DUP: {
O val = PEEK();
PUSH(val);
break;
}
default: default:
error_throw(in, "unknown opcode %d", op); error_throw(in, "unknown opcode %d", op);
} }

View file

@ -1,8 +0,0 @@
(def twice (mac (x)
(list 'progn x x)))
(def when (mac (cond . body)
(list 'if cond (cons 'progn body))))
(when (= 1 1)
(twice (write "ok\n")))