From 7d2741c8ebd7a388891530ba1ee803aa29a7ed72 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Wed, 12 Nov 2025 09:59:37 +1300 Subject: [PATCH 1/2] add test cases This adds two (failing) tests for this issue. --- pkg/test/corset_invalid_test.go | 7 ++++++- testdata/corset/invalid/basic_invalid_20.lisp | 8 ++++++++ testdata/corset/invalid/lookup_invalid_19.lisp | 6 ++++++ 3 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 testdata/corset/invalid/basic_invalid_20.lisp create mode 100644 testdata/corset/invalid/lookup_invalid_19.lisp diff --git a/pkg/test/corset_invalid_test.go b/pkg/test/corset_invalid_test.go index 04ad10c9..871e6d5d 100644 --- a/pkg/test/corset_invalid_test.go +++ b/pkg/test/corset_invalid_test.go @@ -97,7 +97,9 @@ func Test_Invalid_Basic_18(t *testing.T) { func Test_Invalid_Basic_19(t *testing.T) { checkCorsetInvalid(t, "corset/invalid/basic_invalid_19") } - +func Test_Invalid_Basic_20(t *testing.T) { + checkCorsetInvalid(t, "corset/invalid/basic_invalid_20") +} func Test_Invalid_Logic_01(t *testing.T) { checkCorsetInvalid(t, "corset/invalid/logic_invalid_01") } @@ -515,6 +517,9 @@ func Test_Invalid_Lookup_17(t *testing.T) { func Test_Invalid_Lookup_18(t *testing.T) { checkCorsetInvalid(t, "corset/invalid/lookup_invalid_18") } +func Test_Invalid_Lookup_19(t *testing.T) { + checkCorsetInvalid(t, "corset/invalid/lookup_invalid_19") +} // =================================================================== // Interleavings diff --git a/testdata/corset/invalid/basic_invalid_20.lisp b/testdata/corset/invalid/basic_invalid_20.lisp new file mode 100644 index 00000000..d842e526 --- /dev/null +++ b/testdata/corset/invalid/basic_invalid_20.lisp @@ -0,0 +1,8 @@ +;; +(module m1) + +(defcolumns (X :i16) (Y :i16)) +(defconstraint test () (== X Y)) + +(module m1) +(defconstraint test () (!= X Y)) diff --git a/testdata/corset/invalid/lookup_invalid_19.lisp b/testdata/corset/invalid/lookup_invalid_19.lisp new file mode 100644 index 00000000..2b7bc5d0 --- /dev/null +++ b/testdata/corset/invalid/lookup_invalid_19.lisp @@ -0,0 +1,6 @@ +(module m1) +(defcolumns (X :i16) (Y :i16)) +(defclookup test (Y) 1 (X)) + +(module m1) +(defclookup test (X) 1 (Y)) From eaa5f223b1a9e8cdfa05fe37985cd132a6cc48a5 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Wed, 12 Nov 2025 10:53:07 +1300 Subject: [PATCH 2/2] fix: use global handle checking Previously, the algorithm for checking that two constraints do not have the same handle (in the same module) only worked when the constraints were defined within the same module block. However, if they were defined in different blocks (e.g. in different files) the duplicate handle was not reported as an error. To resolve this, a global map of all handles is now maintained which is shared across all source files and module blocks. --- pkg/corset/compiler/parser.go | 121 +++++++++--------- testdata/corset/invalid/basic_invalid_20.lisp | 2 +- .../corset/invalid/lookup_invalid_19.lisp | 1 + 3 files changed, 66 insertions(+), 58 deletions(-) diff --git a/pkg/corset/compiler/parser.go b/pkg/corset/compiler/parser.go index 7a6dce89..192a94da 100644 --- a/pkg/corset/compiler/parser.go +++ b/pkg/corset/compiler/parser.go @@ -42,18 +42,22 @@ import ( // Thus, you should never expect to see duplicate module names in the returned // array. func ParseSourceFiles(files []source.File, enforceTypes bool) (ast.Circuit, *source.Maps[ast.Node], []SyntaxError) { - var circuit ast.Circuit - // (for now) at most one error per source file is supported. - var errors []SyntaxError - // Construct an initially empty source map - srcmaps := source.NewSourceMaps[ast.Node]() - // Contents map holds the combined fragments of each module. - contents := make(map[string]ast.Module, 0) - // Names identifies the names of each unique module. - names := make([]string, 0) + var ( + circuit ast.Circuit + // (for now) at most one error per source file is supported. + errors []SyntaxError + // Construct an initially empty source map + srcmaps = source.NewSourceMaps[ast.Node]() + // Contents map holds the combined fragments of each module. + contents = make(map[string]ast.Module, 0) + // Names identifies the names of each unique module. + names = make([]string, 0) + // Handles used for detecting duplicate constraint handles + handles = make(map[string]bool) + ) // for _, file := range files { - c, srcmap, errs := ParseSourceFile(file, enforceTypes) + c, srcmap, errs := parseSourceFile(file, enforceTypes, handles) // Handle errors if len(errs) > 0 { // Report any errors encountered @@ -94,10 +98,9 @@ func ParseSourceFiles(files []source.File, enforceTypes bool) (ast.Circuit, *sou return circuit, srcmaps, nil } -// ParseSourceFile parses the contents of a single lisp file into one or more -// modules. Observe that every lisp file starts in the "prelude" or "root" -// module, and may declare items for additional modules as necessary. -func ParseSourceFile(srcfile source.File, enforceTypes bool) (ast.Circuit, *source.Map[ast.Node], []SyntaxError) { +func parseSourceFile(srcfile source.File, enforceTypes bool, + handles map[string]bool) (ast.Circuit, *source.Map[ast.Node], []SyntaxError) { + // // var ( circuit ast.Circuit @@ -111,7 +114,7 @@ func ParseSourceFile(srcfile source.File, enforceTypes bool) (ast.Circuit, *sour return circuit, nil, []SyntaxError{*err} } // Construct parser for corset syntax - p := NewParser(srcfile, srcmap, enforceTypes) + p := NewParser(srcfile, srcmap, enforceTypes, handles) // Parse whatever is declared at the beginning of the file before the first // module declaration. These declarations form part of the "prelude". if circuit.Declarations, terms, errors = p.parseModuleContents(path, terms); len(errors) > 0 { @@ -150,6 +153,8 @@ func ParseSourceFile(srcfile source.File, enforceTypes bool) (ast.Circuit, *sour // ensuring that expressions are well-typed, etc) --- that is left up to the // compiler. type Parser struct { + // Handles map used to detect duplicate handles + handles map[string]bool // Translator used for recursive expressions. translator *sexp.Translator[ast.Expr] // Mapping from constructed S-Expressions to their spans in the original text. @@ -160,12 +165,14 @@ type Parser struct { // NewParser constructs a new parser using a given mapping from S-Expressions to // spans in the underlying source file. -func NewParser(srcfile source.File, srcmap *source.Map[sexp.SExp], enforceTypes bool) *Parser { +func NewParser(srcfile source.File, srcmap *source.Map[sexp.SExp], enforceTypes bool, handles map[string]bool, +) *Parser { + // p := sexp.NewTranslator[ast.Expr](&srcfile, srcmap) // Construct (initially empty) node map nodemap := source.NewSourceMap[ast.Node](srcmap.Source()) // Construct parser - parser := &Parser{p, nodemap, enforceTypes} + parser := &Parser{handles, p, nodemap, enforceTypes} // Configure expression translator p.AddSymbolRule(constantParserRule) p.AddSymbolRule(varAccessParserRule) @@ -215,9 +222,8 @@ func (p *Parser) parseModuleContents(path file.Path, terms []sexp.SExp) ([]ast.D []SyntaxError) { // var ( - errors []SyntaxError - handles = make(map[string]bool) - decls = make([]ast.Declaration, 0) + errors []SyntaxError + decls = make([]ast.Declaration, 0) ) // for i, s := range terms { @@ -228,7 +234,7 @@ func (p *Parser) parseModuleContents(path file.Path, terms []sexp.SExp) ([]ast.D errors = append(errors, *err) } else if e.MatchSymbols(2, "module") { return decls, terms[i:], errors - } else if decl, errs := p.parseDeclaration(path, e, handles); len(errs) > 0 { + } else if decl, errs := p.parseDeclaration(path, e); len(errs) > 0 { errors = append(errors, errs...) } else { // Continue accumulating declarations for this module. @@ -268,8 +274,7 @@ func (p *Parser) parseModuleStart(s sexp.SExp) (string, []SyntaxError) { return name, errors } -func (p *Parser) parseDeclaration(module file.Path, s *sexp.List, - handles map[string]bool) (ast.Declaration, []SyntaxError) { +func (p *Parser) parseDeclaration(module file.Path, s *sexp.List) (ast.Declaration, []SyntaxError) { // var ( decl ast.Declaration @@ -285,7 +290,7 @@ func (p *Parser) parseDeclaration(module file.Path, s *sexp.List, } else if s.Len() > 1 && s.MatchSymbols(1, "defconst") { decl, errors = p.parseDefConst(module, s.Elements) } else if s.Len() == 4 && s.MatchSymbols(2, "defconstraint") { - decl, errors = p.parseDefConstraint(module, s.Elements, handles) + decl, errors = p.parseDefConstraint(module, s.Elements) } else if s.Len() == 3 && s.MatchSymbols(1, "defpurefun") { decl, errors = p.parseDefFun(module, true, s.Elements) } else if s.Len() == 3 && s.MatchSymbols(1, "defun") { @@ -295,21 +300,21 @@ func (p *Parser) parseDeclaration(module file.Path, s *sexp.List, } else if s.Len() == 3 && s.MatchSymbols(1, "definterleaved") { decl, errors = p.parseDefInterleaved(module, s.Elements) } else if s.Len() == 4 && s.MatchSymbols(1, "deflookup") { - decl, errors = p.parseDefLookup(s.Elements, handles) + decl, errors = p.parseDefLookup(module, s.Elements) } else if (s.Len() == 5 || s.Len() == 6) && s.MatchSymbols(1, "defclookup") { - decl, errors = p.parseDefConditionalLookup(s.Elements, handles) + decl, errors = p.parseDefConditionalLookup(module, s.Elements) } else if s.Len() == 4 && s.MatchSymbols(1, "defmlookup") { - decl, errors = p.parseDefMultiLookup(s.Elements, handles) + decl, errors = p.parseDefMultiLookup(module, s.Elements) } else if s.Len() == 3 && s.MatchSymbols(2, "defpermutation") { decl, errors = p.parseDefPermutation(module, s.Elements) } else if s.Len() == 4 && s.MatchSymbols(2, "defperspective") { decl, errors = p.parseDefPerspective(module, s.Elements) } else if 3 <= s.Len() && s.Len() <= 4 && s.MatchSymbols(2, "defproperty") { - decl, errors = p.parseDefProperty(s.Elements, handles) + decl, errors = p.parseDefProperty(module, s.Elements) } else if s.Len() == 3 && s.MatchSymbols(2, "defsorted") { - decl, errors = p.parseDefSorted(false, s.Elements, handles) + decl, errors = p.parseDefSorted(module, false, s.Elements) } else if 3 <= s.Len() && s.Len() <= 4 && s.MatchSymbols(2, "defstrictsorted") { - decl, errors = p.parseDefSorted(true, s.Elements, handles) + decl, errors = p.parseDefSorted(module, true, s.Elements) } else if s.Len() == 3 && s.MatchSymbols(1, "defcomputedcolumn") { decl, errors = p.parseDefComputedColumn(module, s.Elements) } else { @@ -761,23 +766,21 @@ func (p *Parser) parseDefConstHead(head sexp.SExp) (*sexp.Symbol, ast.Type, bool } // Parse a vanishing declaration -func (p *Parser) parseDefConstraint(module file.Path, elements []sexp.SExp, - handles map[string]bool) (ast.Declaration, []SyntaxError) { - var ( - errors []SyntaxError - handle string - ) +func (p *Parser) parseDefConstraint(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) { + var errors []SyntaxError // Initial sanity checks if !isIdentifier(elements[1]) { return nil, p.translator.SyntaxErrors(elements[1], "invalid constraint handle") - } else { - handle = elements[1].AsSymbol().Value } + // + handle := elements[1].AsSymbol().Value + // Generate qualified name + qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle) // Check for duplicate - if _, ok := handles[handle]; ok { + if _, ok := p.handles[qualifiedHandle]; ok { return nil, p.translator.SyntaxErrors(elements[1], "duplicate handle") } else { - handles[handle] = true + p.handles[qualifiedHandle] = true } // Vanishing constraints do not have global scope, hence qualified column // accesses are not permitted. @@ -869,9 +872,9 @@ func (p *Parser) parseDefInterleavedSourceArray(source *sexp.Array) (ast.TypedSy } // Parse a lookup declaration -func (p *Parser) parseDefLookup(elements []sexp.SExp, handles map[string]bool) (ast.Declaration, []SyntaxError) { +func (p *Parser) parseDefLookup(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) { // Extract items - handle, checked, errors := p.parseLookupHandle(elements[1], handles) + handle, checked, errors := p.parseLookupHandle(module, elements[1]) targets, tgtErrors := p.parseDefLookupSources("target", elements[2]) sources, srcErrors := p.parseDefLookupSources("source", elements[3]) // Combine any and all errors @@ -897,8 +900,7 @@ func (p *Parser) parseDefLookup(elements []sexp.SExp, handles map[string]bool) ( } // Parse a conditional lookup declaration -func (p *Parser) parseDefConditionalLookup(elements []sexp.SExp, - handles map[string]bool) (ast.Declaration, []SyntaxError) { +func (p *Parser) parseDefConditionalLookup(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) { // Extract items var ( targets, sources []ast.Expr @@ -906,7 +908,7 @@ func (p *Parser) parseDefConditionalLookup(elements []sexp.SExp, errs1, errs2, errs3, errs4 []SyntaxError ) // - handle, checked, errors := p.parseLookupHandle(elements[1], handles) + handle, checked, errors := p.parseLookupHandle(module, elements[1]) // if len(elements) == 6 { targetSelector, errs1 = p.translator.Translate(elements[2]) @@ -942,9 +944,9 @@ func (p *Parser) parseDefConditionalLookup(elements []sexp.SExp, [][]ast.Expr{targets}), nil } -func (p *Parser) parseDefMultiLookup(elements []sexp.SExp, handles map[string]bool) (ast.Declaration, []SyntaxError) { +func (p *Parser) parseDefMultiLookup(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) { // Extract items - handle, checked, errors := p.parseLookupHandle(elements[1], handles) + handle, checked, errors := p.parseLookupHandle(module, elements[1]) m, targets, tgtErrors := p.parseDefLookupMultiSources("target", elements[2]) n, sources, srcErrors := p.parseDefLookupMultiSources("source", elements[3]) // Combine any and all errors @@ -966,7 +968,7 @@ func (p *Parser) parseDefMultiLookup(elements []sexp.SExp, handles map[string]bo return ast.NewDefLookup(handle, checked, sourceSelectors, sources, targetSelectors, targets), nil } -func (p *Parser) parseLookupHandle(element sexp.SExp, handles map[string]bool) (string, bool, []SyntaxError) { +func (p *Parser) parseLookupHandle(module file.Path, element sexp.SExp) (string, bool, []SyntaxError) { var ( checked = true errors []SyntaxError @@ -981,12 +983,14 @@ func (p *Parser) parseLookupHandle(element sexp.SExp, handles map[string]bool) ( } else { handle = element.AsSymbol().Value } + // Generate qualified name + qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle) // Check for duplicate handle - if _, ok := handles[handle]; ok { + if _, ok := p.handles[qualifiedHandle]; ok { return "", checked, p.translator.SyntaxErrors(element, "duplicate handle") } else if len(errors) == 0 { // Done - handles[handle] = true + p.handles[qualifiedHandle] = true } // return handle, checked, errors @@ -1149,8 +1153,7 @@ func (p *Parser) parseDefPermutation(module file.Path, elements []sexp.SExp) (as } // Parse a permutation declaration -func (p *Parser) parseDefSorted(strict bool, elements []sexp.SExp, - handles map[string]bool) (ast.Declaration, []SyntaxError) { +func (p *Parser) parseDefSorted(module file.Path, strict bool, elements []sexp.SExp) (ast.Declaration, []SyntaxError) { // var ( selector util.Option[ast.Expr] @@ -1182,12 +1185,14 @@ func (p *Parser) parseDefSorted(strict bool, elements []sexp.SExp, } else { handle = elements[1].AsSymbol().Value } + // Generate qualified name + qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle) // Check for duplicate handle - if _, ok := handles[handle]; ok { + if _, ok := p.handles[qualifiedHandle]; ok { return nil, p.translator.SyntaxErrors(elements[1], "duplicate handle") } else if len(errors) == 0 { // Record handle - handles[handle] = true + p.handles[qualifiedHandle] = true } // Check source Expressions if sexpSources == nil { @@ -1327,7 +1332,7 @@ func (p *Parser) parseDefPerspective(module file.Path, elements []sexp.SExp) (as } // Parse a property assertion -func (p *Parser) parseDefProperty(elements []sexp.SExp, handles map[string]bool) (ast.Declaration, []SyntaxError) { +func (p *Parser) parseDefProperty(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) { var ( errors []SyntaxError assertion int = len(elements) - 1 @@ -1341,12 +1346,14 @@ func (p *Parser) parseDefProperty(elements []sexp.SExp, handles map[string]bool) // Extract handle handle = elements[1].AsSymbol().Value } + // Generate qualified name + qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle) // Check for duplicate handle - if _, ok := handles[handle]; ok { + if _, ok := p.handles[qualifiedHandle]; ok { return nil, p.translator.SyntaxErrors(elements[1], "duplicate handle") } else if len(errors) == 0 { // Record handle - handles[handle] = true + p.handles[qualifiedHandle] = true } // Check for any attributes if len(elements) > 3 { diff --git a/testdata/corset/invalid/basic_invalid_20.lisp b/testdata/corset/invalid/basic_invalid_20.lisp index d842e526..37311ba8 100644 --- a/testdata/corset/invalid/basic_invalid_20.lisp +++ b/testdata/corset/invalid/basic_invalid_20.lisp @@ -1,4 +1,4 @@ -;; +;;error:8:16-20:duplicate handle (module m1) (defcolumns (X :i16) (Y :i16)) diff --git a/testdata/corset/invalid/lookup_invalid_19.lisp b/testdata/corset/invalid/lookup_invalid_19.lisp index 2b7bc5d0..27185324 100644 --- a/testdata/corset/invalid/lookup_invalid_19.lisp +++ b/testdata/corset/invalid/lookup_invalid_19.lisp @@ -1,3 +1,4 @@ +;;error:7:13-17:duplicate handle (module m1) (defcolumns (X :i16) (Y :i16)) (defclookup test (Y) 1 (X))