From 47c8353366c5cd7544d182a897bacaa303c08d8e Mon Sep 17 00:00:00 2001 From: Mistivia Date: Sun, 22 Jun 2025 17:23:32 +0800 Subject: math functions --- Makefile | 2 +- src/builtins.c | 76 +++++++++++++++++++++++++++++++++++++++++++-- src/builtins.h | 20 ++++++++++++ src/interp.c | 61 ++++++++++++++++++++++++------------ src/prelude.c | 2 +- src/prelude.lisp | 1 + tests/arithmetic.lisp | 13 ++++++++ tests/math.lisp | 86 +++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test.lisp | 1 + 9 files changed, 238 insertions(+), 24 deletions(-) create mode 100644 tests/math.lisp diff --git a/Makefile b/Makefile index 0ac435b..ac30fc3 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ mode ?= debug cc = gcc includes = -I3rdparty/algds/build/include/ 3rdlibs = 3rdparty/algds/build/lib/libalgds.a -ldflags = -lreadline +ldflags = -lm -lreadline ifeq ($(mode), debug) cflags = $(includes) \ -g \ diff --git a/src/builtins.c b/src/builtins.c index fbf5855..b8040ef 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -4,6 +4,78 @@ #include #include #include +#include + +SExpRef builtin_float(Interp *interp, SExpRef args) { + if (LENGTH(args) != 1) return new_error(interp, "float: expect 1 arg.\n"); + SExpRef x = CAR(args); + if (VALTYPE(x) != kIntegerSExp) return new_error(interp, "float: wrong type.\n"); + return new_real(interp, REF(x)->integer); +} + +SExpRef builtin_abs(Interp *interp, SExpRef args) { + if (LENGTH(args) != 1) return new_error(interp, "abs: expect 1 arg.\n"); + SExpRef x = CAR(args); + if (VALTYPE(x) != kIntegerSExp && VALTYPE(x) != kRealSExp) { + return new_error(interp, "abs: wrong type.\n"); + } + if (VALTYPE(x) == kIntegerSExp) { + int64_t val = REF(x)->integer; + if (val < 0) val = -val; + return new_integer(interp, val); + } else { + double val = REF(x)->real; + if (val < 0) val = -val; + return new_real(interp, val); + } +} + +static double real_value(Interp *interp, SExpRef x) { + if (VALTYPE(x) == kIntegerSExp) { + return REF(x)->integer; + } else { + return REF(x)->real; + } +} + +SExpRef builtin_pow(Interp *interp, SExpRef args) { + if (LENGTH(args) != 2) return new_error(interp, "pow: expect 2 args.\n"); + SExpRef x = CAR(args), y = CADR(args); + if (VALTYPE(x) != kIntegerSExp && VALTYPE(x) != kRealSExp) { + return new_error(interp, "pow: wrong type.\n"); + } + if (VALTYPE(y) != kIntegerSExp && VALTYPE(y) != kRealSExp) { + return new_error(interp, "pow: wrong type.\n"); + } + return new_real(interp, pow(real_value(interp, x), real_value(interp, y))); +} + +#define GEN_MATH_FUNC(name, cfunc) \ +SExpRef builtin_##name(Interp *interp, SExpRef args) { \ + if (LENGTH(args) != 1) return new_error(interp, #name": expect 1 args.\n"); \ + SExpRef x = CAR(args); \ + if (VALTYPE(x) != kIntegerSExp && VALTYPE(x) != kRealSExp) { \ + return new_error(interp, #name": wrong type.\n"); \ + } \ + return new_real(interp, cfunc(real_value(interp, x))); \ +} + +GEN_MATH_FUNC(sqrt, sqrt); +GEN_MATH_FUNC(cbrt, cbrt); +GEN_MATH_FUNC(floor, floor); +GEN_MATH_FUNC(truncate, trunc); +GEN_MATH_FUNC(ceiling, ceil); +GEN_MATH_FUNC(round, round); +GEN_MATH_FUNC(sin, sin); +GEN_MATH_FUNC(cos, cos); +GEN_MATH_FUNC(tan, tan); +GEN_MATH_FUNC(asin, asin); +GEN_MATH_FUNC(acos, acos); +GEN_MATH_FUNC(atan, atan); +GEN_MATH_FUNC(ln, log); +GEN_MATH_FUNC(log10, log10); +GEN_MATH_FUNC(log2, log2); +GEN_MATH_FUNC(exp, exp); SExpRef builtin_min(Interp *interp, SExpRef args) { if (LENGTH(args) < 1) return new_error(interp, "min: wrong arg number.\n"); @@ -294,9 +366,9 @@ static SExp raw_add(SExp a, SExp b) { static SExp raw_mul(SExp a, SExp b) { if (a.type == kRealSExp || b.type == kRealSExp) { double result = 1.0; - if (a.type == kRealSExp) result += a.real; + if (a.type == kRealSExp) result *= a.real; else result *= a.integer; - if (b.type == kRealSExp) result += b.real; + if (b.type == kRealSExp) result *= b.real; else result *= b.integer; return (SExp){ .type = kRealSExp, .real = result }; } else { diff --git a/src/builtins.h b/src/builtins.h index abcd7a0..3c54bdc 100644 --- a/src/builtins.h +++ b/src/builtins.h @@ -3,6 +3,26 @@ #include "interp.h" + +SExpRef builtin_sqrt(Interp *interp, SExpRef sexp); +SExpRef builtin_cbrt(Interp *interp, SExpRef sexp); +SExpRef builtin_float(Interp *interp, SExpRef sexp); +SExpRef builtin_abs(Interp *interp, SExpRef sexp); +SExpRef builtin_pow(Interp *interp, SExpRef sexp); +SExpRef builtin_floor(Interp *interp, SExpRef sexp); +SExpRef builtin_truncate(Interp *interp, SExpRef sexp); +SExpRef builtin_ceiling(Interp *interp, SExpRef sexp); +SExpRef builtin_round(Interp *interp, SExpRef sexp); +SExpRef builtin_sin(Interp *interp, SExpRef sexp); +SExpRef builtin_cos(Interp *interp, SExpRef sexp); +SExpRef builtin_tan(Interp *interp, SExpRef sexp); +SExpRef builtin_asin(Interp *interp, SExpRef sexp); +SExpRef builtin_acos(Interp *interp, SExpRef sexp); +SExpRef builtin_atan(Interp *interp, SExpRef sexp); +SExpRef builtin_ln(Interp *interp, SExpRef sexp); +SExpRef builtin_log10(Interp *interp, SExpRef sexp); +SExpRef builtin_log2(Interp *interp, SExpRef sexp); +SExpRef builtin_exp(Interp *interp, SExpRef sexp); SExpRef builtin_min(Interp *interp, SExpRef sexp); SExpRef builtin_max(Interp *interp, SExpRef sexp); SExpRef builtin_equal(Interp *interp, SExpRef sexp); diff --git a/src/interp.c b/src/interp.c index 8813d37..a8c9ad7 100644 --- a/src/interp.c +++ b/src/interp.c @@ -86,33 +86,54 @@ void Interp_init(Interp *self) { Interp_add_primitive(self, "assert-error", primitive_assert_error); Interp_add_primitive(self, "load", primitive_load); - Interp_add_userfunc(self, "min", builtin_min); - Interp_add_userfunc(self, "max", builtin_max); + Interp_add_userfunc(self, "round", builtin_round); + Interp_add_userfunc(self, "acos", builtin_acos); + Interp_add_userfunc(self, "floor", builtin_floor); + Interp_add_userfunc(self, "asin", builtin_asin); + Interp_add_userfunc(self, "log2", builtin_log2); + Interp_add_userfunc(self, "pow", builtin_pow); + Interp_add_userfunc(self, "float", builtin_float); Interp_add_userfunc(self, "eq", builtin_eq); - Interp_add_userfunc(self, "equal", builtin_equal); - Interp_add_userfunc(self, "format", builtin_format); + Interp_add_userfunc(self, "ln", builtin_ln); + Interp_add_userfunc(self, "=", builtin_num_equal); + Interp_add_userfunc(self, "/=", builtin_num_neq); Interp_add_userfunc(self, "concat", builtin_concat); - Interp_add_userfunc(self, "error", builtin_error); Interp_add_userfunc(self, "print", builtin_print); - Interp_add_userfunc(self, "princ", builtin_princ); - Interp_add_userfunc(self, "car", builtin_car); - Interp_add_userfunc(self, "list", builtin_list); - Interp_add_userfunc(self, "cdr", builtin_cdr); - Interp_add_userfunc(self, "cons", builtin_cons); - Interp_add_userfunc(self, "+", builtin_add); + Interp_add_userfunc(self, "format", builtin_format); + Interp_add_userfunc(self, "truncate", builtin_truncate); + Interp_add_userfunc(self, "mod", builtin_mod); + Interp_add_userfunc(self, "i/", builtin_idiv); Interp_add_userfunc(self, "-", builtin_sub); + Interp_add_userfunc(self, "abs", builtin_abs); Interp_add_userfunc(self, "*", builtin_mul); - Interp_add_userfunc(self, "/", builtin_div); - Interp_add_userfunc(self, "i/", builtin_idiv); - Interp_add_userfunc(self, "mod", builtin_mod); - Interp_add_userfunc(self, "=", builtin_num_equal); - Interp_add_userfunc(self, "/=", builtin_num_neq); - Interp_add_userfunc(self, "<", builtin_lt); + Interp_add_userfunc(self, "tan", builtin_tan); + Interp_add_userfunc(self, "exp", builtin_exp); + Interp_add_userfunc(self, "log10", builtin_log10); + Interp_add_userfunc(self, "list", builtin_list); + Interp_add_userfunc(self, "car", builtin_car); + Interp_add_userfunc(self, "sin", builtin_sin); + Interp_add_userfunc(self, "max", builtin_max); + Interp_add_userfunc(self, "exit", builtin_exit); + Interp_add_userfunc(self, "not", builtin_not); + Interp_add_userfunc(self, "cos", builtin_cos); + Interp_add_userfunc(self, "<=", builtin_le); + Interp_add_userfunc(self, "princ", builtin_princ); Interp_add_userfunc(self, ">", builtin_gt); + Interp_add_userfunc(self, "+", builtin_add); + Interp_add_userfunc(self, "equal", builtin_equal); + Interp_add_userfunc(self, "/", builtin_div); + Interp_add_userfunc(self, "atan", builtin_atan); + Interp_add_userfunc(self, "cons", builtin_cons); + Interp_add_userfunc(self, "cdr", builtin_cdr); + Interp_add_userfunc(self, "ceiling", builtin_ceiling); + Interp_add_userfunc(self, "min", builtin_min); + Interp_add_userfunc(self, "error", builtin_error); Interp_add_userfunc(self, ">=", builtin_ge); - Interp_add_userfunc(self, "<=", builtin_le); - Interp_add_userfunc(self, "not", builtin_not); - Interp_add_userfunc(self, "exit", builtin_exit); + Interp_add_userfunc(self, "<", builtin_lt); + Interp_add_userfunc(self, "sqrt", builtin_sqrt); + Interp_add_userfunc(self, "cbrt", builtin_cbrt); + + Interp_add_userfunc(self, "_gcstat", builtin_gcstat); SExpRef ret = Interp_eval_string(self, bamboo_lisp_prelude); diff --git a/src/prelude.c b/src/prelude.c index 3d1a971..ca9109d 100644 --- a/src/prelude.c +++ b/src/prelude.c @@ -1,6 +1,6 @@ #include "prelude.h" -const char *bamboo_lisp_prelude = "(defvar nil \'())\n\n(defvar pi 3.1415926)\n\n(defmacro incq (i)\n `(setq ,i (+ ,i 1)))\n\n(defmacro decq (i)\n `(setq ,i (- ,i 1)))\n\n(defun zerop (x) (= x 0))\n(defun plusp (x) (> x 0))\n(defun minusp (x) (< x 0))\n\n(defmacro when (pred . body)\n `(if ,pred\n (progn ,@body)\n nil))\n\n(defmacro unless (pred . body)\n `(if ,pred\n nil\n (progn ,@body)))\n"; +const char *bamboo_lisp_prelude = "(defvar nil \'())\n\n(defvar pi 3.1415926)\n(defvar e 2.718281828)\n\n(defmacro incq (i)\n `(setq ,i (+ ,i 1)))\n\n(defmacro decq (i)\n `(setq ,i (- ,i 1)))\n\n(defun zerop (x) (= x 0))\n(defun plusp (x) (> x 0))\n(defun minusp (x) (< x 0))\n\n(defmacro when (pred . body)\n `(if ,pred\n (progn ,@body)\n nil))\n\n(defmacro unless (pred . body)\n `(if ,pred\n nil\n (progn ,@body)))\n"; diff --git a/src/prelude.lisp b/src/prelude.lisp index df85a9b..7e9992b 100644 --- a/src/prelude.lisp +++ b/src/prelude.lisp @@ -1,6 +1,7 @@ (defvar nil '()) (defvar pi 3.1415926) +(defvar e 2.718281828) (defmacro incq (i) `(setq ,i (+ ,i 1))) diff --git a/tests/arithmetic.lisp b/tests/arithmetic.lisp index e8634d7..1942a1d 100644 --- a/tests/arithmetic.lisp +++ b/tests/arithmetic.lisp @@ -5,6 +5,19 @@ (assert (= 2 (i/ 11 5))) (assert (= 1 (mod 11 5))) +(assert (zerop 0)) +(assert (not (zerop 1))) +(assert (not (zerop -1))) + +(assert (plusp 1)) +(assert (plusp 1.0)) +(assert (not (plusp 0))) +(assert (not (plusp -1))) + +(assert (minusp -1)) +(assert (not (minusp 0))) +(assert (not (minusp 1))) + (assert (< 1 2)) (assert (< 1.0 2)) (assert (not (> 1 2))) diff --git a/tests/math.lisp b/tests/math.lisp new file mode 100644 index 0000000..53e4668 --- /dev/null +++ b/tests/math.lisp @@ -0,0 +1,86 @@ +(defun ~~ (a b) + (if (< (abs (- a b)) 0.01) + nil + (error "failed"))) + +(assert (= 1 (abs -1))) +(assert (= 1.1 (abs -1.1))) +(assert (= 1 (abs 1))) +(assert (= 1.1 (abs 1.1))) + +(~~ 3.141 pi) +(assert-error (~~ 3.2 pi)) + +(~~ 2.718 e) + +(assert (= 1.0 (float 1))) +(assert (= -1.0 (float -1))) + +(~~ 8 (pow 2 3)) +(~~ 1.414 (pow 2 0.5)) +(~~ 1.732 (pow 3 0.5)) + +(~~ 2.0 (floor 2.1)) +(~~ 2.0 (floor 2.5)) +(~~ 2.0 (floor 2.7)) +(~~ -2.0 (floor -1.1)) +(~~ -2.0 (floor -1.5)) +(~~ -2.0 (floor -1.7)) + +(~~ 2.0 (truncate 2.1)) +(~~ 2.0 (truncate 2.5)) +(~~ 2.0 (truncate 2.7)) +(~~ -2.0 (truncate -2.1)) +(~~ -2.0 (truncate -2.5)) +(~~ -2.0 (truncate -2.7)) + +(~~ 2.0 (ceiling 1.1)) +(~~ 2.0 (ceiling 1.5)) +(~~ 2.0 (ceiling 1.7)) +(~~ -2.0 (ceiling -2.1)) +(~~ -2.0 (ceiling -2.5)) +(~~ -2.0 (ceiling -2.7)) + +(~~ 2.0 (round 2.1)) +(~~ 2.0 (round 1.5)) +(~~ 2.0 (round 1.7)) +(~~ -2.0 (round -2.1)) +(~~ -2.0 (round -1.5)) +(~~ -2.0 (round -1.7)) + +(~~ 0 (sin 0)) +(~~ 1 (sin (/ pi 2))) +(~~ -1 (sin (- (/ pi 2)))) + +(~~ 1 (cos 0)) +(~~ 0 (cos (/ pi 2))) +(~~ 0 (cos (- (/ pi 2)))) + +(~~ (tan 1.1234) (/ (sin 1.1234) (cos 1.1234))) + +(~~ (asin 0.5) 0.525) +(~~ (acos 0.5) 1.047) +(~~ (atan 0.5) 0.463) + +(~~ 0 (ln 1)) +(~~ 1 (ln e)) +(~~ 2 (ln (* e e))) +(~~ 1.5 (ln (* e (sqrt e)))) +(~~ 1.333 (ln (* e (cbrt e)))) +(~~ 0.667 (ln (/ e (cbrt e)))) + +(~~ 0 (log10 1)) +(~~ 1 (log10 10)) +(~~ 2 (log10 (* 10 10))) +(~~ 1.5 (log10 (* 10 (sqrt 10)))) +(~~ 1.333 (log10 (* 10 (cbrt 10)))) +(~~ 0.667 (log10 (/ 10 (cbrt 10)))) + +(~~ 0 (log2 1)) +(~~ 1 (log2 2)) +(~~ 2 (log2 (* 2 2))) +(~~ 1.5 (log2 (* 2 (sqrt 2)))) +(~~ 1.333 (log2 (* 2 (cbrt 2)))) +(~~ 0.667 (log2 (/ 2 (cbrt 2)))) + +(~~ (pow e 1.5) (exp 1.5)) diff --git a/tests/test.lisp b/tests/test.lisp index 797fc0d..9b0f888 100644 --- a/tests/test.lisp +++ b/tests/test.lisp @@ -4,6 +4,7 @@ (load (format "%s.lisp" ,name)) (princ (format "[PASS] %s\n" ,name)))) +(test-module "math") (test-module "eq") (test-module "arithmetic") (test-module "error") -- cgit v1.0