proper lexical scoping now woo

This commit is contained in:
Lobo 2026-01-14 12:10:29 -03:00
parent ce9489b5d2
commit 27df5f8ce0
13 changed files with 349 additions and 45 deletions

View file

@ -4,7 +4,7 @@ root = true
end_of_line = lf
insert_final_newline = true
[*.{c,h}]
[*.{c,h,scm}]
indent_style = space
indent_size = 2

22
examples/fizzbuzz.scm Normal file
View file

@ -0,0 +1,22 @@
(def each-integer-aux
  (fn (left total visit)
    (if (= left 0)
        '()
        (progn
          (visit (- (+ total 1) left))
          (each-integer-aux (- left 1) total visit)))))

(def each-integer
  (fn (total visit)
    (each-integer-aux total total visit)))

(each-integer 30
  (fn (k)
    (if (or (= 0 (% k 3)) (= 0 (% k 5)))
        (progn
          (if (= 0 (% k 3))
              (write "Fizz"))
          (if (= 0 (% k 5))
              (write "Buzz")))
        (print k))
    (write "\n")))

19
examples/map.scm Normal file
View file

@ -0,0 +1,19 @@
(def defn
  (mac (name params . forms)
    (list 'def name (cons 'fn (cons params forms)))))

(defn map-aux (proc done items)
  (if (nil? items)
      (done proc '())
      (map-aux
        proc
        (fn (g mapped)
          (done g (cons (g (head items)) mapped)))
        (tail items))))

(defn map (proc items)
  (map-aux proc (fn (g result) result) items))

(println
  (map (fn (n) (* n n))
       '(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20)))

View file

@ -176,16 +176,20 @@ enum {
OP_GET_LOCAL,
OP_SET_LOCAL,
OP_RESERVE,
OP_DUP,
};
// Local variable info: one entry of a compiler context's `locals` table.
typedef struct Lv {
// Symbol naming the variable; mark_escaping() matches it with `==`.
O name;
// NOTE(review): presumably the variable's stack slot number - the code that
// assigns it is not visible here; confirm.
U16 index;
// Set to 1 by mark_escaping() once a nested fn/mac body is found to reference
// this variable. compile_fn() falls back to environment bindings when any
// local has this flag set.
I escapes;
} Lv;
// Compiler context
typedef struct Cm {
typedef struct Cm Cm;
struct Cm {
Cm *parent;
In *in;
U8 *code;
Z count;
@ -202,6 +206,8 @@ typedef struct Cm {
O mac;
O progn;
O def;
O and;
O or;
} specials;
struct {
Lv *data;
@ -209,7 +215,7 @@ typedef struct Cm {
Z capacity;
} locals;
I use_locals;
} Cm;
};
enum {
TOK_EOF = 0,

View file

@ -25,6 +25,17 @@ static void add_local(Cm *co, O sym) {
co->locals.count++;
}
// Flag the local named `sym` in `co` as escaping.
// Only the first entry whose name matches is updated; when no entry matches,
// `co` is left untouched.
static void mark_escaping(Cm *co, O sym) {
  Z slot = 0;
  while (slot < co->locals.count && co->locals.data[slot].name != sym)
    slot++;
  if (slot < co->locals.count)
    co->locals.data[slot].escapes = 1;
}
// Returns 1 when find_local() reports a non-negative index for `sym` in `co`,
// 0 otherwise.
static I is_local(Cm *co, O sym) {
  return find_local(co, sym) < 0 ? 0 : 1;
}
static V emit(Cm *co, U8 byte) {
if (co->count >= co->capacity) {
Z newcap = co->capacity == 0 ? 16 : co->capacity * 2;
@ -162,18 +173,88 @@ static V compile_progn(Cm *co, O forms, I tail) {
}
}
static int nested_p(Cm *c, O body) {
while (body != NIL) {
O expr = list_next(c->in, &body);
if (type(expr) == TYPE_PAIR) {
Pa *p = pair_unwrap(c->in, expr);
if (p->head == c->specials.fn)
return 1;
if (nested_p(c, expr))
return 1;
// Walk `expr` and mark, in `parent`, every symbol that is referenced but not
// bound in `co` (i.e. a free variable of `co` that `parent` owns).
//
// co     - scope whose locals count as "bound" for this walk
// expr   - expression to scan (symbol, pair, or anything else)
// parent - scope whose locals get their `escapes` flag set; may be NULL, in
//          which case symbols are never marked
//
// Quoted forms are not scanned. A nested fn/mac gets a temporary scope holding
// its parameters, and its body is scanned with `co` acting as the parent.
static void find_free_vars(Cm *co, O expr, Cm *parent) {
I ty = type(expr);
switch (ty) {
case TYPE_SYM:
// A symbol that is not local to `co` but is local to `parent` is a capture
if (!is_local(co, expr) && parent && is_local(parent, expr)) {
// Record the capture on the parent's entry
mark_escaping(parent, expr);
}
break;
case TYPE_PAIR: {
Pa *p = pair_unwrap(co->in, expr);
O head = p->head;
// Skip quote - quoted expressions aren't variable references
if (type(head) == TYPE_SYM && head == co->specials.quote) {
return;
}
// Handle nested fn/mac - need to analyze their bodies for free variables
if (type(head) == TYPE_SYM &&
(head == co->specials.fn || head == co->specials.mac)) {
// This is a nested function definition: split it into its parameter list
// and body.
// NOTE(review): p->tail is unwrapped without a type check, so a bare `(fn)`
// relies on how pair_unwrap treats a non-pair - confirm.
Pa *fn_tail = pair_unwrap(co->in, p->tail);
O nested_args = fn_tail->head;
O nested_body = fn_tail->tail;
// Create a temporary, zeroed compiler context for the nested function; only
// its locals table is used by this walk
Cm nested;
memset(&nested, 0, sizeof(Cm));
nested.in = co->in;
nested.specials = co->specials;
nested.parent = co;
// Add nested function's parameters as locals. A symbol in tail position
// (the rest parameter of a dotted list, or a lone symbol) is added too and
// ends the loop.
O curr_arg = nested_args;
while (curr_arg != NIL) {
if (type(curr_arg) == TYPE_SYM) {
add_local(&nested, curr_arg);
break;
}
Pa *arg_pair = pair_unwrap(co->in, curr_arg);
add_local(&nested, arg_pair->head);
curr_arg = arg_pair->tail;
}
// Analyze the nested function's body to find free variables.
// NOTE(review): these calls pass `co` as the parent, so inside the nested
// body only `co`'s locals can be marked. A symbol owned by the original
// `parent` and referenced from this deeper level does not appear to get
// marked - confirm whether that is intended.
O body_copy = nested_body;
while (body_copy != NIL) {
if (type(body_copy) != TYPE_PAIR) {
find_free_vars(&nested, body_copy, co);
break;
}
Pa *body_pair = pair_unwrap(co->in, body_copy);
find_free_vars(&nested, body_pair->head, co);
body_copy = body_pair->tail;
}
// Release the temporary scope's locals table
if (nested.locals.data)
free(nested.locals.data);
return;
}
// For other expressions, recurse into every list element (the head
// included). A non-pair tail of an improper list ends the loop without
// being scanned.
O curr = expr;
while (curr != NIL) {
if (type(curr) != TYPE_PAIR) {
break;
}
Pa *pair = pair_unwrap(co->in, curr);
find_free_vars(co, pair->head, parent);
curr = pair->tail;
}
break;
}
default:
// Any other type of expression contributes no variable references
break;
}
// NOTE(review): `return 0;` inside a function declared `void` - this looks
// like a leftover from the removed nested_p(); confirm against the real file.
return 0;
}
// Compile a closure with `args` and `body`.
@ -183,13 +264,14 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
memset(&ic, 0, sizeof(Cm));
ic.in = c->in;
ic.specials = c->specials;
ic.parent = c; // Link to parent scope
O curr = args;
O fixed[256];
int fixed_count = 0;
O rest = NIL;
// Count fixed arguments, and if the function has a rest argument
// Count fixed arguments
while (curr != NIL) {
if (type(curr) == TYPE_SYM) {
rest = curr;
@ -200,18 +282,46 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
curr = p->tail;
}
int nested = nested_p(c, body);
if (rest == NIL && !nested) {
// If the function has no rest argument, and has no nested closures inside
// it, compile using stack locals as an optimization for tail-calls
ic.use_locals = 1;
curr = args;
for (O next = list_next(c->in, &curr); next != NIL;
next = list_next(c->in, &curr))
add_local(&ic, next);
} else {
// Otherwise, fallback to using environment bindings for locals
// Add parameters as locals to inner compiler
curr = args;
while (curr != NIL) {
if (type(curr) == TYPE_SYM) {
add_local(&ic, curr);
break;
}
Pa *p = pair_unwrap(c->in, curr);
add_local(&ic, p->head);
curr = p->tail;
}
// Analyze body to find free variables and mark them as escaping in parent
O body_copy = body;
while (body_copy != NIL) {
O expr = list_next(c->in, &body_copy);
find_free_vars(&ic, expr, c);
}
// Decide whether to use stack locals or environment bindings
// Use stack locals only if:
// 1. No rest parameter (rest parameters require environment bindings)
// 2. None of our own locals escape (are captured by nested functions)
ic.use_locals = 1;
// If we have a rest parameter, we must use environment bindings
if (rest != NIL)
ic.use_locals = 0;
// Check if any of this function's locals escape (are captured by nested
// functions). find_free_vars has already marked escaping variables in
// ic.locals.
for (Z i = 0; i < ic.locals.count; i++) {
if (ic.locals.data[i].escapes) {
ic.use_locals = 0;
break;
}
}
// If we need environment bindings, emit BIND instructions
if (!ic.use_locals) {
if (rest != NIL) {
emit(&ic, OP_BIND_REST);
emit16(&ic, add_constant(&ic, rest));
@ -223,24 +333,123 @@ static V compile_fn(Cm *c, O args, O body, I macro) {
}
}
// Compile the function's body (as a `progn` form)
// Compile the function's body
compile_progn(&ic, body, 1);
O code = code_make(c->in, ic.code, ic.count, ic.constants.data,
ic.constants.count);
Z code_idx = add_constant(c, code);
Z args_idx = add_constant(c, args);
// Compile pushing the closure to the stack.
// Compile pushing the closure to the stack
emit(c, macro ? OP_MAC : OP_CLOS);
emit16(c, code_idx);
emit16(c, args_idx);
// Free the inner compiler context.
// Free the inner compiler context
free(ic.code);
if (ic.locals.data)
free(ic.locals.data);
}
// Compile the `(and ...)` special form.
//
// co    - compiler context that receives the bytecode
// forms - operand list (everything after the `and` symbol)
// tail  - non-zero when the form sits in tail position; an OP_RET is then
//         appended after the emitted code
//
// `(and)` with no operands pushes the interpreter's `t` constant. Otherwise
// every operand except the last is compiled in non-tail mode and followed by
//   OP_DUP, OP_JUMP_IF_NIL <end>, OP_POP
// where <end> is patched, once all operands are emitted, to the offset just
// past the last operand. The last operand is compiled in the caller's tail
// mode.
static V compile_and(Cm *co, O forms, I tail) {
  if (forms == NIL) {
    emit(co, OP_CONST);
    emit16(co, add_constant(co, co->in->t));
    if (tail)
      emit(co, OP_RET);
    return;
  }

  // Offsets of the 16-bit jump operands still waiting for the end address.
  // Offsets (not pointers) are kept because emitting more code can move
  // co->code.
  Z *jumps = NULL;
  Z jump_count = 0;
  Z jump_cap = 0;

  while (forms != NIL) {
    O expr = list_next(co->in, &forms);
    if (forms == NIL) {
      compile(co, expr, tail);
    } else {
      compile(co, expr, 0);
      emit(co, OP_DUP);
      emit(co, OP_JUMP_IF_NIL);
      if (jump_count >= jump_cap) {
        Z new_cap = jump_cap == 0 ? 8 : jump_cap * 2;
        Z *grown = realloc(jumps, new_cap * sizeof(Z));
        // Out of memory: stop here instead of writing through a NULL pointer
        if (!grown)
          abort();
        jumps = grown;
        jump_cap = new_cap;
      }
      jumps[jump_count++] = co->count;
      emit16(co, 0); // placeholder, patched below
      emit(co, OP_POP);
    }
  }

  // Point every pending jump at the end of the form (big-endian 16-bit)
  Z end = co->count;
  for (Z i = 0; i < jump_count; i++) {
    Z j = jumps[i];
    co->code[j] = (U8)(end >> 8);
    co->code[j + 1] = (U8)(end & 0xff);
  }
  free(jumps); // free(NULL) is a no-op

  if (tail)
    emit(co, OP_RET);
}
// Compile the `(or ...)` special form.
//
// co    - compiler context that receives the bytecode
// forms - operand list (everything after the `or` symbol)
// tail  - non-zero when the form sits in tail position; an OP_RET is then
//         appended after the emitted code
//
// `(or)` with no operands pushes NIL. Otherwise every operand except the last
// is compiled in non-tail mode and followed by
//   OP_DUP, OP_JUMP_IF_NIL <next>, OP_JUMP <end>, next: OP_POP
// <next> is patched immediately to the offset of the OP_POP; <end> is patched,
// once all operands are emitted, to the offset just past the last operand.
// The last operand is compiled in the caller's tail mode.
static V compile_or(Cm *co, O forms, I tail) {
  if (forms == NIL) {
    emit(co, OP_CONST);
    emit16(co, add_constant(co, NIL));
    if (tail)
      emit(co, OP_RET);
    return;
  }

  // Offsets of the OP_JUMP operands still waiting for the end address.
  // Offsets (not pointers) are kept because emitting more code can move
  // co->code.
  Z *jumps = NULL;
  Z jump_count = 0;
  Z jump_cap = 0;

  while (forms != NIL) {
    O expr = list_next(co->in, &forms);
    if (forms == NIL) {
      compile(co, expr, tail);
    } else {
      compile(co, expr, 0);
      emit(co, OP_DUP);
      emit(co, OP_JUMP_IF_NIL);
      Z jump_next = co->count;
      emit16(co, 0); // placeholder for <next>
      emit(co, OP_JUMP);
      if (jump_count >= jump_cap) {
        Z new_cap = jump_cap == 0 ? 8 : jump_cap * 2;
        Z *grown = realloc(jumps, new_cap * sizeof(Z));
        // Out of memory: stop here instead of writing through a NULL pointer
        if (!grown)
          abort();
        jumps = grown;
        jump_cap = new_cap;
      }
      jumps[jump_count++] = co->count;
      emit16(co, 0); // placeholder for <end>, patched below
      // The OP_JUMP_IF_NIL above lands on the OP_POP emitted next
      Z next = co->count;
      co->code[jump_next] = (U8)(next >> 8);
      co->code[jump_next + 1] = (U8)(next & 0xff);
      emit(co, OP_POP);
    }
  }

  // Point every pending OP_JUMP at the end of the form (big-endian 16-bit)
  Z end = co->count;
  for (Z i = 0; i < jump_count; i++) {
    Z j = jumps[i];
    co->code[j] = (U8)(end >> 8);
    co->code[j + 1] = (U8)(end & 0xff);
  }
  free(jumps); // free(NULL) is a no-op

  if (tail)
    emit(co, OP_RET);
}
// Compile the `(def name value)` special form
static V compile_def(Cm *c, O args, I tail) {
O sym = list_next(c->in, &args);
@ -290,6 +499,12 @@ static V compile_apply(Cm *co, O expr, I tail) {
if (tail)
emit(co, OP_RET);
return;
} else if (head == co->specials.and) {
compile_and(co, p->tail, tail);
return;
} else if (head == co->specials.or) {
compile_or(co, p->tail, tail);
return;
} else if (head == co->specials.def) {
compile_def(co, p->tail, tail);
return;
@ -324,6 +539,10 @@ V compile(Cm *co, O expr, I tail) {
co->specials.def = symbol_make(co->in, "def");
if (co->specials.mac == NIL)
co->specials.mac = symbol_make(co->in, "mac");
if (co->specials.and == NIL)
co->specials.and = symbol_make(co->in, "and");
if (co->specials.or == NIL)
co->specials.or = symbol_make(co->in, "or");
switch (ty) {
case TYPE_NIL:
@ -338,11 +557,13 @@ V compile(Cm *co, O expr, I tail) {
if (co->use_locals) {
int idx = find_local(co, expr);
if (idx >= 0) {
emit(co, OP_GET_LOCAL);
emit16(co, idx);
if (tail)
emit(co, OP_RET);
return;
if (!co->locals.data[idx].escapes) {
emit(co, OP_GET_LOCAL);
emit16(co, idx);
if (tail)
emit(co, OP_RET);
return;
}
}
}
emit(co, OP_GET);

View file

@ -109,6 +109,9 @@ V disassemble(Cm *co) {
ofs += 2;
break;
}
case OP_DUP:
printf("DUP\n");
break;
default:
printf("%02x\n", co->code[ofs]);
}

View file

@ -31,10 +31,13 @@ V interp_init(In *in) {
PRIM("print", prim_print, 1, 1);
PRIM("println", prim_println, 1, 1);
PRIM("write", prim_write, 1, 1);
PRIM("env", prim_env, 0, 0);
PRIM("nil?", prim_nil_p, 1, 1);
PRIM("+", prim_add, 0, -1); // variadic
PRIM("-", prim_sub, 0, -1); // variadic
PRIM("*", prim_mul, 0, -1); // variadic
PRIM("/", prim_div, 0, -1); // variadic
PRIM("%", prim_mod, 0, -1); // variadic
PRIM("<", prim_lt, 2, 2);
PRIM(">", prim_gt, 2, 2);
PRIM("=", prim_equal, 0, -1); // variadic

View file

@ -70,7 +70,8 @@ O prim_write(In *in, O *args, int argc, O env) {
if (argc != 1)
error_throw(in, "write: expected 1 argument, got %d", argc);
if (type(args[0]) != TYPE_STR)
error_throw(in, "write: expected string argument, got %s", typename(type(args[0])));
error_throw(in, "write: expected string argument, got %s",
typename(type(args[0])));
Ss *s = (Ss *)(UNBOX(args[0]) + 1);
printf("%.*s", (int)s->len, s->data);
return NIL;
@ -132,6 +133,23 @@ O prim_div(In *in, O *args, int argc, O env) {
return NUM(result);
}
// `%` primitive: folds the C remainder operator over the arguments, left to
// right, starting from the first argument.
//
// in   - interpreter, used for error reporting
// args - argument values; each must satisfy IMM()
// argc - number of arguments
// env  - unused
//
// Returns NUM(1) when called with no arguments, otherwise the folded
// remainder. Reports an error through error_throw() for a non-numeric
// argument or a zero divisor.
//
// The messages name the primitive as `%`. It is written `%%` because
// error_throw() is called with printf-style formats elsewhere in this file.
O prim_mod(In *in, O *args, int argc, O env) {
  (void)env;
  if (argc == 0)
    return NUM(1);
  if (!IMM(args[0]))
    error_throw(in, "%%: non numeric argument at position 0");
  I result = ORD(args[0]);
  for (int i = 1; i < argc; i++) {
    if (!IMM(args[i]))
      error_throw(in, "%%: non numeric argument at position %d", i);
    if (ORD(args[i]) == 0)
      error_throw(in, "%%: division by zero at position %d", i);
    result %= ORD(args[i]);
  }
  return NUM(result);
}
O prim_equal(In *in, O *args, int argc, O env) {
(void)env;
if (argc < 2)
@ -168,6 +186,20 @@ O prim_gt(In *in, O *args, int argc, O env) {
}
}
// `nil?` primitive: yields BOOL(true) when its single argument is NIL.
// Reports an error through error_throw() when the argument count is not 1.
// `env` is unused.
O prim_nil_p(In *in, O *args, int argc, O env) {
  (void)env;
  if (argc != 1) {
    error_throw(in, "nil?: expected 1 argument, got %d", argc);
  }
  int is_nil = (args[0] == NIL);
  return BOOL(is_nil);
}
// `env` primitive: returns the interpreter's `env` field (in->env).
// The argument vector, the argument count and the `env` parameter passed by
// the caller are all ignored.
O prim_env(In *in, O *args, int argc, O env) {
  (void)env;
  (void)argc;
  (void)args;
  O current = in->env;
  return current;
}
O prim_gc(In *in, O *args, int argc, O env) {
(void)args;
(void)argc;

View file

@ -11,7 +11,10 @@ O prim_add(In *in, O *args, int argc, O env);
O prim_sub(In *in, O *args, int argc, O env);
O prim_mul(In *in, O *args, int argc, O env);
O prim_div(In *in, O *args, int argc, O env);
O prim_mod(In *in, O *args, int argc, O env);
O prim_equal(In *in, O *args, int argc, O env);
O prim_lt(In *in, O *args, int argc, O env);
O prim_gt(In *in, O *args, int argc, O env);
O prim_nil_p(In *in, O *args, int argc, O env);
O prim_env(In *in, O *args, int argc, O env);
O prim_gc(In *in, O *args, int argc, O env);

View file

@ -15,9 +15,7 @@ static Sy *find(St *tab, const char *str, U32 hash, Z len) {
Sy *s = tab->data[ix];
if (!s)
return NULL;
if (s->hash == hash && s->len == len)
return s;
if (memcmp(s->data, str, len) == 0)
if (s->hash == hash && s->len == len && memcmp(s->data, str, len) == 0)
return s;
ix = (ix + 1) % tab->capacity;
}

View file

@ -228,6 +228,11 @@ static O vm_exec(Cm *co_in, O env_in, int argc_in) {
PUSH(BOX(hdr));
break;
}
case OP_DUP: {
O val = PEEK();
PUSH(val);
break;
}
default:
error_throw(in, "unknown opcode %d", op);
}

View file

@ -1,8 +0,0 @@
(def twice (mac (x)
(list 'progn x x)))
(def when (mac (cond . body)
(list 'if cond (cons 'progn body))))
(when (= 1 1)
(twice (write "ok\n")))