diff --git a/analyser/module_analyser.c2 b/analyser/module_analyser.c2 index 0ce6e9384..09b99f2f3 100644 --- a/analyser/module_analyser.c2 +++ b/analyser/module_analyser.c2 @@ -28,6 +28,7 @@ import label_vector local; import name_vector local; import src_loc local; import scope; +import string_buffer; import string_pool; import warning_flags; import struct_func_list as sf_list; @@ -507,7 +508,24 @@ fn void Analyser.handleStaticAssert(void* arg, StaticAssert* d) { Value val2 = ast.evalExpr(rhs); if (!val1.is_equal(&val2)) { - ma.errorRange(rhs.getStartLoc(), rhs.getRange(), "static_assert failed, expected %s, got %s", val1.str(), val2.str()); + char[256] tmp; + string_buffer.Buf buf.init(tmp, elemsof(tmp), false, false, 0); + lhs.printLiteral(&buf); + buf.add(" == "); + rhs.printLiteral(&buf); + if (lhs.isLiteral()) { + if (rhs.isLiteral()) { + ma.errorRange(rhs.getStartLoc(), rhs.getRange(), "static_assert failed: %s", buf.data()); + } else { + ma.errorRange(rhs.getStartLoc(), rhs.getRange(), "static_assert failed: %s, got %s", buf.data(), val2.str()); + } + } else { + if (rhs.isLiteral()) { + ma.errorRange(lhs.getStartLoc(), lhs.getRange(), "static_assert failed: %s, got %s", buf.data(), val1.str()); + } else { + ma.errorRange(lhs.getStartLoc(), lhs.getRange(), "static_assert failed: %s, got %s and %s", buf.data(), val1.str(), val2.str()); + } + } } } diff --git a/ast/expr.c2 b/ast/expr.c2 index 6fb570837..57cebb216 100644 --- a/ast/expr.c2 +++ b/ast/expr.c2 @@ -200,6 +200,28 @@ public fn bool Expr.isStringLiteral(const Expr* e) { return e.getKind() == ExprKind.StringLiteral; } +public fn bool Expr.isLiteral(const Expr* e) { + switch (e.getKind()) { + case IntegerLiteral: + case FloatLiteral: + case BooleanLiteral: + case CharLiteral: + case StringLiteral: + case Nil: + return true; + default: + break; + } + return false; +} + +public fn bool Expr.isComparison(const Expr* e) { + if (e.getKind() != BinaryOperator) return false; + BinaryOperator* binop = (BinaryOperator*)e; + BinaryOpcode opcode = binop.getOpcode(); + return opcode.isComparison(); +} + public fn bool Expr.isNil(const Expr* e) { return e.getKind() == ExprKind.Nil; } diff --git a/generator/c/c_generator.c2 b/generator/c/c_generator.c2 index d943f679e..e9f1c8222 100644 --- a/generator/c/c_generator.c2 +++ b/generator/c/c_generator.c2 @@ -1433,10 +1433,18 @@ const char[] C_defines = const char[] C2_assert = ```c + #define va_list __builtin_va_list + #define va_start __builtin_va_start + #define va_end __builtin_va_end int dprintf(int fd, const char *format, ...); + int vdprintf(int fd, const char *format, va_list args); void abort(void); - static int c2_assert(const char* filename, int line, const char* funcname, const char* condstr) { - dprintf(2, "%s:%d: function %s: Assertion failed: %s\n", filename, line, funcname, condstr); + static int c2_assert(const char* filename, int line, const char* funcname, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + dprintf(2, "%s:%d: function %s: assertion failed: ", filename, line, funcname); + vdprintf(2, fmt, args); + dprintf(2, "\n"); abort(); return 0; } diff --git a/generator/c/c_generator_stmt.c2 b/generator/c/c_generator_stmt.c2 index f23f01fb7..0894d96cf 100644 --- a/generator/c/c_generator_stmt.c2 +++ b/generator/c/c_generator_stmt.c2 @@ -237,26 +237,114 @@ fn void Generator.emitStmt(Generator* gen, Stmt* s, u32 indent, bool newline) { gen.emitAsmStmt((AsmStmt*)s, indent); break; case Assert: - if (!gen.enable_asserts) out.print(";//assert"); AssertStmt* a = (AssertStmt*)s; - out.add1('('); - Expr* inner = a.getInner(); - gen.emitExpr(out, inner); - out.add1(')'); - if (gen.enable_asserts) { - source_mgr.Location loc = gen.sm.locate(s.getLoc()); - const char* funcname = gen.cur_function.asDecl().getFullName(); - out.print(" || c2_assert(\"%s\", %d, \"%s\", \"", loc.filename, loc.line, funcname); - // encode expression as a string - string_buffer.Buf* str = string_buffer.create(128, false, 0); - inner.printLiteral(str); - out.encodeBytes(str.data(), str.size(), '"'); - str.free(); - out.add("\")"); + if (!gen.enable_asserts) { + out.print(";//assert"); + out.add1('('); + Expr* inner = a.getInner(); + gen.emitExpr(out, inner); + out.add1(')'); + break; } - out.add(";\n"); + gen.emitAssertStmt(a, indent); + break; + } +} + +fn const char* get_type_format(Expr* e) { + QualType qt = e.getType(); + QualType canon = qt.getCanonicalType(); + if (canon.isPointer() || canon.isFunction()) return "p"; + const Type* t = canon.getTypeOrNil(); + if (canon.isEnum()) { + // output numeric value + EnumType* et = (EnumType*)t; + canon = et.getImplType(); + t = canon.getTypeOrNil(); + } + const BuiltinType* bi = (BuiltinType*)t; + switch (bi.getBaseKind()) { + case Char: + case Int8: + case Int16: + case Int32: + return "d"; + case Int64: + return "ld"; + case UInt8: + case UInt16: + case UInt32: + return "u"; + case UInt64: + return "lu"; + case Float32: + case Float64: + return "g"; + case Bool: + return "d"; + case ISize: // not a base kind + case USize: // not a base kind + case Void: break; } + return nil; +} + +// encode expression as a string +fn void encode_expression(string_buffer.Buf* out, Expr* e) { + char[128] tmp; + string_buffer.Buf buf.init(tmp, elemsof(tmp), false, false, 0); + e.printLiteral(&buf); + const char* s = buf.data(); + for (const char* p = s;; p++) { + if (!*p || *p == '%') { + out.encodeBytes(s, (u32)(p - s), '"'); + if (!*p) break; + if (*p == '%') out.add("%%"); + } + } +} + +fn void Generator.emitAssertStmt(Generator* gen, AssertStmt* a, u32 indent) { + string_buffer.Buf* out = gen.out; + source_mgr.Location loc = gen.sm.locate(((Stmt*)a).getLoc()); + const char* funcname = gen.cur_function.asDecl().getFullName(); + Expr* inner = a.getInner(); + out.add1('('); + gen.emitExpr(out, inner); + out.print(") || c2_assert(\"%s\", %d, \"%s\", \"", loc.filename, loc.line, funcname); + encode_expression(out, inner); + if (inner.isComparison()) { + BinaryOperator* b = (BinaryOperator*)inner; + Expr* lhs = b.getLHS(); + Expr* rhs = b.getRHS(); + const char* fmt1 = lhs.isLiteral() ? nil : get_type_format(lhs); + const char* fmt2 = rhs.isLiteral() ? nil : get_type_format(rhs); + if (fmt1) { + out.print(", "); + encode_expression(out, lhs); + out.print(": %%%s", fmt1); + } + if (fmt2) { + out.print(", "); + encode_expression(out, rhs); + out.print(": %%%s", fmt2); + } + out.add1('"'); + if (fmt1) { + out.add(", "); + if (*fmt1 == 'p') out.add("(void*)"); + gen.emitExpr(out, lhs); + } + if (fmt2) { + out.add(", "); + if (*fmt2 == 'p') out.add("(void*)"); + gen.emitExpr(out, rhs); + } + } else { + out.add1('"'); + } + out.add(");\n"); } fn void emitAsmPart(string_buffer.Buf* out, bool multi_line, u32 indent) { diff --git a/test/c_generator/stmts/assert.c2t b/test/c_generator/stmts/assert.c2t index b5c284aef..2c1b17d07 100644 --- a/test/c_generator/stmts/assert.c2t +++ b/test/c_generator/stmts/assert.c2t @@ -6,11 +6,16 @@ module test; i32 a = 10; +i32 b = 20; void* p = nil; public fn i32 main() { - assert(a); assert(p); + assert(a); + assert(a == 10); + assert(a != 10); + assert(a > 10); + assert(a == b); return 0; } @@ -19,8 +24,12 @@ int32_t main(void); int32_t main(void) { - (test_a) || c2_assert("file1.c2", 7, "test.main", "a"); (test_p) || c2_assert("file1.c2", 8, "test.main", "p"); + (test_a) || c2_assert("file1.c2", 9, "test.main", "a"); + (test_a == 10) || c2_assert("file1.c2", 10, "test.main", "a == 10, a: %d", test_a); + (test_a != 10) || c2_assert("file1.c2", 11, "test.main", "a != 10, a: %d", test_a); + (test_a > 10) || c2_assert("file1.c2", 12, "test.main", "a > 10, a: %d", test_a); + (test_a == test_b) || c2_assert("file1.c2", 13, "test.main", "a == b, a: %d, b: %d", test_a, test_b); return 0; } diff --git a/test/globals/static_asserts/static_assert_int.c2 b/test/globals/static_asserts/static_assert_int.c2 index 300613ce9..56e83127f 100644 --- a/test/globals/static_asserts/static_assert_int.c2 +++ b/test/globals/static_asserts/static_assert_int.c2 @@ -1,5 +1,17 @@ // @warnings{no-unused} module test; -static_assert(3, 4); // @error{static_assert failed, expected 3, got 4} +static_assert(3, 4); // @error{static_assert failed: 3 == 4} +static_assert(0xff, 0b11111110); // @error{static_assert failed: 0xff == 0b11111110} +type S struct { + i32 x, y; +} + +type T struct { + u32 x; +} + +static_assert(sizeof(S), 4); // @error{static_assert failed: sizeof(S) == 4, got 8} +static_assert(4, sizeof(S)); // @error{static_assert failed: 4 == sizeof(S), got 8} +static_assert(sizeof(T), sizeof(S)); // @error{static_assert failed: sizeof(T) == sizeof(S), got 4 and 8}