ayin/src/interpret/interpret.rs
2025-12-20 22:31:46 +02:00

659 lines
24 KiB
Rust

//! Interpreter for Ayin.
use super::types::*;
use std::backtrace::Backtrace;
use std::collections::BTreeMap;
use crate::ast;
/// Returns the name of the environment that holds the top-level definitions.
pub fn global_env_name() -> ast::EnvName {
    let name = String::from("global");
    ast::EnvName(name)
}
/// Builds a fresh interpreter [`State`] for `program`.
///
/// The state is created with the global environment's name and the given
/// primitive functions, then every top-level definition of the program is
/// registered in the global environment.
///
/// # Errors
///
/// Propagates any error reported by `defs_to_env`.
pub fn setup(program: ast::Program, prim_funcs: PrimitiveFuncs) -> Result<State, Error> {
    let global = global_env_name();
    let mut state = State::new(global.0.clone(), prim_funcs);
    defs_to_env(program.0, &global, &mut state)?;
    Ok(state)
}
/// Calls the global definition named `func` with `args` and returns the
/// resulting value.
///
/// The definition is looked up in the global environment, wrapped in a
/// function-call expression, and that call is evaluated with the global
/// environment as the lexical scope.
///
/// # Errors
///
/// Fails when the global environment or `func` cannot be found, or when
/// evaluating the call fails.
pub fn interpret(
    state: &mut State,
    func: ast::Name,
    args: Vec<ast::Expr>,
) -> Result<ast::Value, Error> {
    let global = global_env_name();
    // Work on a snapshot of the global environment so `state` stays free to be
    // mutated during evaluation.
    let env: Env = state.envs.get(&global)?.clone();
    let callee = ast::Expr::Value(env.get(&func)?.clone());
    let call = ast::Expr::FunCall {
        func: Box::new(callee),
        args,
    };
    eval_expr(&env, state, &call)
}
/// Runs `func` from `program` with `args` and returns the result as a
/// standalone expression.
///
/// The program is set up without any primitive functions, the requested
/// function is interpreted, and the resulting value is converted with
/// [`value_to_stadnalone_expr`] so it no longer refers to interpreter state.
///
/// # Errors
///
/// Propagates errors from setup, from interpretation, and from the final
/// conversion (for example when the result is a closure, which cannot be
/// converted). The conversion error used to be unwrapped and therefore
/// panicked; it is now returned to the caller.
pub fn run(
    program: ast::Program,
    func: ast::Name,
    args: Vec<ast::Expr>,
) -> Result<ast::Expr, Error> {
    let mut state = setup(program, PrimitiveFuncs::new(vec![]))?;
    let value = interpret(&mut state, func, args)?;
    value_to_stadnalone_expr(&state, value)
}
/// Outcome of evaluating a single statement, as produced by `eval_statement`.
enum StatementResult {
/// Value produced by an expression statement or a `let` statement.
Result(ast::Value),
/// Value produced by a `return` statement.
Return(ast::Value),
}
/// Evaluates one statement inside `expr_env`.
///
/// * An expression statement yields its value as `StatementResult::Result`.
/// * A `let` evaluates its expression, stores the value in `state.variables`
///   (binding a reference to it) when the definition is mutable, inserts the
///   binding into `expr_env`, and yields the bound value as
///   `StatementResult::Result`.
/// * A `return` yields its value (or the unit expression's value when no
///   expression is given) as `StatementResult::Return`.
///
/// # Errors
///
/// Propagates any error from evaluating the contained expression.
fn eval_statement(
    expr_env: &mut Env,
    state: &mut State,
    statement: &ast::Statement,
) -> Result<StatementResult, Error> {
    match statement {
        ast::Statement::Expr(expr) => {
            let value = eval_expr(expr_env, state, expr)?;
            Ok(StatementResult::Result(value))
        }
        ast::Statement::Let(ast::Definition {
            mutable,
            name,
            expr,
        }) => {
            let evaluated = eval_expr(expr_env, state, expr)?;
            // Mutable bindings live in the variable store and are bound by reference.
            let bound = if *mutable {
                ast::Value::Ref(state.variables.insert(evaluated))
            } else {
                evaluated
            };
            expr_env.insert(name.clone(), bound.clone());
            Ok(StatementResult::Result(bound))
        }
        ast::Statement::Return(expr) => {
            let value = match expr {
                Some(expr) => eval_expr(expr_env, state, expr)?,
                None => eval_expr(expr_env, state, &ast::UNIT)?,
            };
            Ok(StatementResult::Return(value))
        }
    }
}
/// Evaluates `expr` in the lexical environment `expr_env`.
///
/// `state` provides the named environments, the variable store behind
/// references, and the primitive functions; it may be mutated (new
/// environments, new or updated variables) as a side effect.
///
/// A block that hits a `return` yields `ast::Value::Return`; a closure call
/// unwraps that marker so the caller sees the plain value.
///
/// # Errors
///
/// * `Error::NotABoolean` – an `if` condition or `Not` operand is not a boolean.
/// * `Error::NotARecord` – field access on a non-record value.
/// * `Error::NotAFunction` – calling something that is neither a closure over
///   a function nor a known primitive function.
/// * `Error::ArgumentsMismatch` – wrong number of arguments in a closure call.
/// * `Error::NotAReference` – assignment to something that is not a reference.
/// * `Error::OpError` – operands not supported by the operator. This includes
///   integer division/modulo that cannot be computed (e.g. a zero divisor) and
///   bitwise operators applied to floats; both cases previously panicked.
/// * Any error from environment, record, or variable lookups is propagated.
fn eval_expr(expr_env: &Env, state: &mut State, expr: &ast::Expr) -> Result<ast::Value, Error> {
    match expr {
        ast::Expr::If {
            condition,
            then,
            r#else,
        } => match eval_expr(expr_env, state, condition)? {
            ast::Value::Boolean(true) => eval_expr(expr_env, state, then),
            ast::Value::Boolean(false) => eval_expr(expr_env, state, r#else),
            v => Err(Error::NotABoolean(v, Backtrace::capture())),
        },
        ast::Expr::Func(func) => {
            // Capture the current environment under a fresh name for the closure.
            let env_name = state.generate_env(expr_env.clone())?;
            Ok(ast::Value::Closure {
                expr: Box::new(ast::Expr::Func(func.clone())),
                env: env_name,
            })
        }
        ast::Expr::Block(statements) => {
            let mut block_env = expr_env.clone();
            let mut result = ast::UNIT_VALUE;
            for statement in statements {
                match eval_statement(&mut block_env, state, statement)? {
                    // A return marker coming out of a nested expression or
                    // statement stops the block and is passed up unchanged.
                    StatementResult::Result(ast::Value::Return(r))
                    | StatementResult::Return(ast::Value::Return(r)) => {
                        return Ok(ast::Value::Return(r));
                    }
                    // A `return` statement stops the block and marks its value.
                    StatementResult::Return(r) => {
                        return Ok(ast::Value::Return(Box::new(r)));
                    }
                    // Otherwise remember the latest value; the last one is the
                    // value of the block.
                    StatementResult::Result(r) => result = r,
                }
            }
            Ok(result)
        }
        ast::Expr::Access { expr, field } => match eval_expr(expr_env, state, expr)? {
            ast::Value::Record(record) => Ok(state.variables.get(record.get(&field)?).clone()),
            v => Err(Error::NotARecord(v, Backtrace::capture())),
        },
        ast::Expr::Var(var) => {
            let value = match expr_env.get(var) {
                Ok(value) => value.clone(),
                // Names missing from the environment may still be primitive functions.
                Err(_) if state.primitive_funcs.get(&var).is_some() => {
                    ast::Value::PrimitiveFunc(var.clone())
                }
                Err(err) => return Err(err),
            };
            // Evaluating the value resolves references and evaluates closures.
            eval_expr(expr_env, state, &ast::Expr::Value(value))
        }
        ast::Expr::FunCall { func, args } => match eval_expr(expr_env, state, func)? {
            ast::Value::Closure { expr, env } => match *expr {
                ast::Expr::Func(func) => {
                    let mut closure_env: Env = state.envs.get(&env)?.clone();
                    if func.args.len() != args.len() {
                        return Err(Error::ArgumentsMismatch(Backtrace::capture()));
                    }
                    // Arguments are evaluated in the caller's environment and
                    // bound in the closure's environment.
                    for (arg, e) in func.args.into_iter().zip(args.iter()) {
                        let evalled = eval_expr(expr_env, state, e)?;
                        closure_env.insert(arg.name, evalled);
                    }
                    match eval_expr(&closure_env, state, &func.body)? {
                        ast::Value::Return(r) => Ok(*r),
                        v => Ok(v),
                    }
                }
                e => Err(Error::NotAFunction(e, Backtrace::capture())),
            },
            ast::Value::PrimitiveFunc(func) => {
                // Primitive functions receive standalone expressions, not
                // values tied to interpreter state.
                let mut argsvec = vec![];
                for arg in args.iter() {
                    let evalled = eval_expr(expr_env, state, arg)?;
                    let e = value_to_stadnalone_expr(state, evalled)?;
                    argsvec.push(e);
                }
                if let Some(primitive) = state.primitive_funcs.get(&func) {
                    Ok(primitive(argsvec))
                } else {
                    Err(Error::NotAFunction(
                        ast::Expr::Value(ast::Value::PrimitiveFunc(func)),
                        Backtrace::capture(),
                    ))
                }
            }
            e => Err(Error::NotAFunction(
                ast::Expr::Value(e),
                Backtrace::capture(),
            )),
        },
        ast::Expr::Not(e) => match eval_expr(expr_env, state, e)? {
            ast::Value::Boolean(b) => Ok(ast::Value::Boolean(!b)),
            v => Err(Error::NotABoolean(v, Backtrace::capture())),
        },
        ast::Expr::Value(v) => match v {
            // A closure value is evaluated by evaluating its stored expression
            // in its stored environment.
            ast::Value::Closure { expr, env } => {
                let closure_env: Env = state.envs.get(env)?.clone();
                eval_expr(&closure_env, state, expr)
            }
            ast::Value::Ref(reference) => Ok(state.variables.get(reference).clone()),
            _ => Ok(v.clone()),
        },
        ast::Expr::Op { lhs, rhs, op } => match op {
            ast::Op::Assign => match eval_expr_shallow(expr_env, state, lhs)? {
                ast::Value::Ref(reference) => {
                    let rhs = eval_expr(expr_env, state, rhs)?;
                    state.variables.set(reference, rhs.clone());
                    Ok(rhs)
                }
                v => Err(Error::NotAReference(
                    *lhs.clone(),
                    v.clone(),
                    Backtrace::capture(),
                )),
            },
            ast::Op::Calc(cop) => {
                let lhs = eval_expr(expr_env, state, lhs)?;
                let rhs = eval_expr(expr_env, state, rhs)?;
                // `None` means the operator cannot be applied to these operands.
                let computed = match (&lhs, &rhs) {
                    (ast::Value::Int(a), ast::Value::Int(b)) => {
                        let (a, b) = (*a, *b);
                        let n = match cop {
                            ast::Calc::Add => Some(a + b),
                            ast::Calc::Sub => Some(a - b),
                            ast::Calc::Mul => Some(a * b),
                            // Checked forms turn a zero divisor (and the
                            // overflowing MIN / -1 case) into an error
                            // instead of a panic.
                            ast::Calc::Div => a.checked_div(b),
                            ast::Calc::Mod => a.checked_rem(b),
                            ast::Calc::BinAnd => Some(a & b),
                            ast::Calc::BinOr => Some(a | b),
                        };
                        n.map(ast::Value::Int)
                    }
                    // Mixed int/float arithmetic is carried out in f32.
                    (ast::Value::Float(a), ast::Value::Int(b)) => {
                        calc_floats(cop, *a, *b as f32).map(ast::Value::Float)
                    }
                    (ast::Value::Int(a), ast::Value::Float(b)) => {
                        calc_floats(cop, *a as f32, *b).map(ast::Value::Float)
                    }
                    (ast::Value::Float(a), ast::Value::Float(b)) => {
                        calc_floats(cop, *a, *b).map(ast::Value::Float)
                    }
                    _ => None,
                };
                computed.ok_or_else(|| Error::OpError(lhs, op.clone(), rhs, Backtrace::capture()))
            }
            ast::Op::Compare(cop) => {
                let lhs = eval_expr(expr_env, state, lhs)?;
                let rhs = eval_expr(expr_env, state, rhs)?;
                let outcome = match (&lhs, &rhs) {
                    (ast::Value::Int(a), ast::Value::Int(b)) => Some(compare_values(cop, a, b)),
                    // Mixed int/float comparisons are carried out in f32.
                    (ast::Value::Int(a), ast::Value::Float(b)) => {
                        Some(compare_values(cop, &(*a as f32), b))
                    }
                    (ast::Value::Float(a), ast::Value::Int(b)) => {
                        Some(compare_values(cop, a, &(*b as f32)))
                    }
                    (ast::Value::Float(a), ast::Value::Float(b)) => {
                        Some(compare_values(cop, a, b))
                    }
                    (ast::Value::Boolean(a), ast::Value::Boolean(b)) => {
                        Some(compare_values(cop, a, b))
                    }
                    (ast::Value::String(a), ast::Value::String(b)) => {
                        Some(compare_values(cop, a, b))
                    }
                    _ => None,
                };
                match outcome {
                    Some(result) => Ok(ast::Value::Boolean(result)),
                    None => Err(Error::OpError(lhs, op.clone(), rhs, Backtrace::capture())),
                }
            }
            ast::Op::Bool(bop) => {
                // Both operands are always evaluated; there is no short-circuiting.
                let lhs = eval_expr(expr_env, state, lhs)?;
                let rhs = eval_expr(expr_env, state, rhs)?;
                match (lhs, rhs) {
                    (ast::Value::Boolean(a), ast::Value::Boolean(b)) => {
                        let result = match bop {
                            ast::BoolOp::And => a && b,
                            ast::BoolOp::Or => a || b,
                        };
                        Ok(ast::Value::Boolean(result))
                    }
                    (lhs, rhs) => Err(Error::OpError(lhs, op.clone(), rhs, Backtrace::capture())),
                }
            }
        },
        ast::Expr::Record(record) => {
            // Every field value is placed in the variable store; the record
            // itself only holds references.
            let mut map = BTreeMap::new();
            for (field, expr) in record {
                let value = eval_expr(expr_env, state, expr)?;
                let reference = state.variables.insert(value);
                map.insert(field.clone(), reference);
            }
            Ok(ast::Value::Record(ast::Record(map)))
        }
        ast::Expr::Vector(vector) => {
            // Same scheme as records: elements are stored and referenced.
            let mut vec = Vec::with_capacity(vector.len());
            for expr in vector {
                let value = eval_expr(expr_env, state, expr)?;
                let reference = state.variables.insert(value);
                vec.push(reference);
            }
            Ok(ast::Value::Vector(ast::Vector(vec)))
        }
    }
}

/// Applies the comparison operator `cmp` to two operands of the same type.
fn compare_values<T: PartialOrd>(cmp: &ast::Cmp, a: &T, b: &T) -> bool {
    match cmp {
        ast::Cmp::Eq => a == b,
        ast::Cmp::NotEq => a != b,
        ast::Cmp::Gt => a > b,
        ast::Cmp::Gte => a >= b,
        ast::Cmp::Lt => a < b,
        ast::Cmp::Lte => a <= b,
    }
}

/// Applies the arithmetic operator `calc` to two floats, or returns `None`
/// for the bitwise operators, which are not defined for floats.
fn calc_floats(calc: &ast::Calc, a: f32, b: f32) -> Option<f32> {
    match calc {
        ast::Calc::Add => Some(a + b),
        ast::Calc::Sub => Some(a - b),
        ast::Calc::Mul => Some(a * b),
        ast::Calc::Div => Some(a / b),
        ast::Calc::Mod => Some(a % b),
        ast::Calc::BinAnd | ast::Calc::BinOr => None,
    }
}
fn eval_expr_shallow(
expr_env: &Env,
state: &mut State,
expr: &ast::Expr,
) -> Result<ast::Value, Error> {
match expr {
ast::Expr::Var(var) => {
let result = expr_env.get(&var)?;
match result {
ast::Value::Ref(_) => Ok(result.clone()),
v => Err(Error::NotAReference(
expr.clone(),
v.clone(),
Backtrace::capture(),
)),
}
}
ast::Expr::Access { expr, field } => match eval_expr(expr_env, state, expr)? {
ast::Value::Record(record) => Ok(ast::Value::Ref(record.get(&field)?.clone())),
v => Err(Error::NotARecord(v, Backtrace::capture())),
},
_ => eval_expr(expr_env, state, expr),
}
}
/// Registers `defs` as a new environment named `env_name` in `state`.
///
/// Each definition is stored unevaluated, as a closure over its expression
/// and `env_name`. Mutable definitions are placed in the variable store and
/// bound by reference; immutable ones are bound to the closure directly.
///
/// # Errors
///
/// Propagates the error from `insert_nodup` when a definition cannot be
/// inserted.
fn defs_to_env(
    defs: Vec<ast::Definition>,
    env_name: &ast::EnvName,
    state: &mut State,
) -> Result<(), Error> {
    let mut env = Env::new(env_name.clone());
    for ast::Definition {
        mutable,
        expr,
        name,
    } in defs
    {
        let closure = ast::Value::Closure {
            expr: Box::new(expr),
            env: env_name.clone(),
        };
        let binding = if mutable {
            ast::Value::Ref(state.variables.insert(closure))
        } else {
            closure
        };
        env.insert_nodup(&name, binding)?;
    }
    state.envs.0.insert(env_name.clone(), env);
    Ok(())
}
pub fn value_to_stadnalone_expr(state: &State, value: ast::Value) -> Result<ast::Expr, Error> {
match value {
ast::Value::Ref(reference) => {
let value = state.variables.get(&reference).clone();
value_to_stadnalone_expr(state, value)
}
ast::Value::Vector(vector) => {
let mut vec = Vec::with_capacity(vector.0.len());
for reference in vector.0 {
let expr = value_to_stadnalone_expr(state, ast::Value::Ref(reference))?;
vec.push(expr);
}
Ok(ast::Expr::Vector(vec))
}
ast::Value::Record(record) => {
let mut map = BTreeMap::new();
for (field, reference) in record.0 {
let expr = value_to_stadnalone_expr(state, ast::Value::Ref(reference))?;
map.insert(field.clone(), expr);
}
Ok(ast::Expr::Record(map))
}
ast::Value::Closure { .. } => Err(Error::MigrationError(
value.clone(),
"Closure migration not supported".into(),
Backtrace::capture(),
)),
_ => Ok(ast::Expr::Value(value)),
}
}
// Unit tests for the interpreter. Each test builds a small program with the
// `ast::helpers` constructors, runs its `main` definition through `run`, and
// compares the resulting expression (or error) with the expected one.
#[cfg(test)]
mod tests {
use super::*;
use crate::ast::helpers;
// `main` is a zero-argument function whose body is the literal 0.
#[test]
fn main_0() {
let program = vec![helpers::define_expr(
"main",
helpers::func(vec![], 0.into()),
)]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(0.into()));
}
// `Not` applied to `false` yields `true`.
#[test]
fn main_not() {
let program = vec![helpers::define_expr(
"main",
helpers::func(vec![], ast::Expr::Not(Box::new(false.into()))),
)]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(true.into()));
}
// Boolean `Or` of `false` and `true` yields `true`.
#[test]
fn main_or() {
let program = vec![helpers::define_expr(
"main",
helpers::func(
vec![],
helpers::op(ast::Op::Bool(ast::BoolOp::Or), false.into(), true.into()),
),
)]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(true.into()));
}
// Boolean `And` of `false` and `true` yields `false`.
#[test]
fn main_and() {
let program = vec![helpers::define_expr(
"main",
helpers::func(
vec![],
helpers::op(ast::Op::Bool(ast::BoolOp::And), false.into(), true.into()),
),
)]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(false.into()));
}
// Integer addition: 7 + 8 yields 15.
#[test]
fn main_add() {
let program = vec![helpers::define_expr(
"main",
helpers::func(
vec![],
helpers::op(ast::Op::Calc(ast::Calc::Add), 7.into(), 8.into()),
),
)]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(15.into()));
}
// `main` refers to another top-level definition (`lit`) that is defined after it.
#[test]
fn var_lookup() {
let program = vec![
helpers::define_expr("main", helpers::func(vec![], "lit".into())),
helpers::define_expr("lit", 0.into()),
]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(0.into()));
}
// A block introduces `zero` via `helpers::assign` and then reads it back.
#[test]
fn var_assign_and_lookup() {
let program = vec![helpers::define_expr(
"main",
helpers::func(
vec![],
vec![
helpers::assign("zero", 0.into()),
helpers::stmt_expr("zero".into()),
]
.into(),
),
)]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(0.into()));
}
// Accessing `my_field` of a record built from a single field yields that field's value.
#[test]
fn field_access() {
let program = vec![
helpers::define_expr("main", helpers::func(vec![], "record".into())),
helpers::define_expr(
"record",
ast::Expr::from(vec![("my_field", 0.into())]).field("my_field"),
),
]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(0.into()));
}
// `zero` is defined as an immediate call of a zero-argument function returning 0.
#[test]
fn fun_call() {
let program = vec![
helpers::define_expr("main", helpers::func(vec![], "zero".into())),
helpers::define_expr("zero", helpers::func(vec![], 0.into()).call(vec![])),
]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(0.into()));
}
// Nested `if`: the outer condition is false, the inner one is true, so the result is 1.
#[test]
fn if_then_else() {
let program = vec![helpers::define_expr(
"main",
helpers::func(
vec![],
ast::Expr::If {
condition: Box::new(false.into()),
then: Box::new(0.into()),
r#else: Box::new(ast::Expr::If {
condition: Box::new(true.into()),
then: Box::new(1.into()),
r#else: Box::new(2.into()),
}),
},
),
)]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(1.into()));
}
// A two-argument function returning its second argument is called with (1, 0).
#[test]
fn fun_call_args() {
let program = vec![
helpers::define_expr("main", helpers::func(vec![], "zero".into())),
helpers::define_expr(
"zero",
helpers::func(
vec![ast::Arg { name: "a".into() }, ast::Arg { name: "b".into() }],
"b".into(),
)
.call(vec![1.into(), 0.into()]),
),
]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(result, Ok(0.into()));
}
// Errors
// NOTE(review): the expected errors below carry a freshly captured backtrace,
// so these comparisons presumably rely on `Error`'s equality ignoring the
// backtrace — confirm against the `Error` definition.
// Two top-level definitions named `main` are rejected as duplicates.
#[test]
fn duplicate_toplevel_defs() {
let program = vec![
helpers::define_expr("main", helpers::func(vec![], "record".into())),
helpers::define_expr("main", 0.into()),
]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(
result,
Err(Error::DuplicateNames("main".into(), Backtrace::capture()))
);
}
// Field access on the integer 0 reports `NotARecord`.
#[test]
fn field_access_not_a_record() {
let program = vec![
helpers::define_expr("main", helpers::func(vec![], "record".into())),
helpers::define_expr("record", ast::Expr::from(0).field("my_field")),
]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(
result,
Err(Error::NotARecord(0.into(), Backtrace::capture()))
);
}
// Calling the integer 0 as a function reports `NotAFunction`.
#[test]
fn fun_call_not_a_function() {
let program = vec![
helpers::define_expr("main", helpers::func(vec![], "zero".into())),
helpers::define_expr("zero", ast::Expr::from(0).call(vec![1.into()])),
]
.into();
let result = run(program, "main".into(), vec![]);
assert_eq!(
result,
Err(Error::NotAFunction(0.into(), Backtrace::capture()))
);
}
}