Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 64 additions & 57 deletions pkg/corset/compiler/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,22 @@ import (
// Thus, you should never expect to see duplicate module names in the returned
// array.
func ParseSourceFiles(files []source.File, enforceTypes bool) (ast.Circuit, *source.Maps[ast.Node], []SyntaxError) {
var circuit ast.Circuit
// (for now) at most one error per source file is supported.
var errors []SyntaxError
// Construct an initially empty source map
srcmaps := source.NewSourceMaps[ast.Node]()
// Contents map holds the combined fragments of each module.
contents := make(map[string]ast.Module, 0)
// Names identifies the names of each unique module.
names := make([]string, 0)
var (
circuit ast.Circuit
// (for now) at most one error per source file is supported.
errors []SyntaxError
// Construct an initially empty source map
srcmaps = source.NewSourceMaps[ast.Node]()
// Contents map holds the combined fragments of each module.
contents = make(map[string]ast.Module, 0)
// Names identifies the names of each unique module.
names = make([]string, 0)
// Handles used for detecting duplicate constraint handles
handles = make(map[string]bool)
)
//
for _, file := range files {
c, srcmap, errs := ParseSourceFile(file, enforceTypes)
c, srcmap, errs := parseSourceFile(file, enforceTypes, handles)
// Handle errors
if len(errs) > 0 {
// Report any errors encountered
Expand Down Expand Up @@ -94,10 +98,9 @@ func ParseSourceFiles(files []source.File, enforceTypes bool) (ast.Circuit, *sou
return circuit, srcmaps, nil
}

// ParseSourceFile parses the contents of a single lisp file into one or more
// modules. Observe that every lisp file starts in the "prelude" or "root"
// module, and may declare items for additional modules as necessary.
func ParseSourceFile(srcfile source.File, enforceTypes bool) (ast.Circuit, *source.Map[ast.Node], []SyntaxError) {
func parseSourceFile(srcfile source.File, enforceTypes bool,
handles map[string]bool) (ast.Circuit, *source.Map[ast.Node], []SyntaxError) {
//
//
var (
circuit ast.Circuit
Expand All @@ -111,7 +114,7 @@ func ParseSourceFile(srcfile source.File, enforceTypes bool) (ast.Circuit, *sour
return circuit, nil, []SyntaxError{*err}
}
// Construct parser for corset syntax
p := NewParser(srcfile, srcmap, enforceTypes)
p := NewParser(srcfile, srcmap, enforceTypes, handles)
// Parse whatever is declared at the beginning of the file before the first
// module declaration. These declarations form part of the "prelude".
if circuit.Declarations, terms, errors = p.parseModuleContents(path, terms); len(errors) > 0 {
Expand Down Expand Up @@ -150,6 +153,8 @@ func ParseSourceFile(srcfile source.File, enforceTypes bool) (ast.Circuit, *sour
// ensuring that expressions are well-typed, etc) --- that is left up to the
// compiler.
type Parser struct {
// Handles map used to detect duplicate handles
handles map[string]bool
// Translator used for recursive expressions.
translator *sexp.Translator[ast.Expr]
// Mapping from constructed S-Expressions to their spans in the original text.
Expand All @@ -160,12 +165,14 @@ type Parser struct {

// NewParser constructs a new parser using a given mapping from S-Expressions to
// spans in the underlying source file.
func NewParser(srcfile source.File, srcmap *source.Map[sexp.SExp], enforceTypes bool) *Parser {
func NewParser(srcfile source.File, srcmap *source.Map[sexp.SExp], enforceTypes bool, handles map[string]bool,
) *Parser {
//
p := sexp.NewTranslator[ast.Expr](&srcfile, srcmap)
// Construct (initially empty) node map
nodemap := source.NewSourceMap[ast.Node](srcmap.Source())
// Construct parser
parser := &Parser{p, nodemap, enforceTypes}
parser := &Parser{handles, p, nodemap, enforceTypes}
// Configure expression translator
p.AddSymbolRule(constantParserRule)
p.AddSymbolRule(varAccessParserRule)
Expand Down Expand Up @@ -215,9 +222,8 @@ func (p *Parser) parseModuleContents(path file.Path, terms []sexp.SExp) ([]ast.D
[]SyntaxError) {
//
var (
errors []SyntaxError
handles = make(map[string]bool)
decls = make([]ast.Declaration, 0)
errors []SyntaxError
decls = make([]ast.Declaration, 0)
)
//
for i, s := range terms {
Expand All @@ -228,7 +234,7 @@ func (p *Parser) parseModuleContents(path file.Path, terms []sexp.SExp) ([]ast.D
errors = append(errors, *err)
} else if e.MatchSymbols(2, "module") {
return decls, terms[i:], errors
} else if decl, errs := p.parseDeclaration(path, e, handles); len(errs) > 0 {
} else if decl, errs := p.parseDeclaration(path, e); len(errs) > 0 {
errors = append(errors, errs...)
} else {
// Continue accumulating declarations for this module.
Expand Down Expand Up @@ -268,8 +274,7 @@ func (p *Parser) parseModuleStart(s sexp.SExp) (string, []SyntaxError) {
return name, errors
}

func (p *Parser) parseDeclaration(module file.Path, s *sexp.List,
handles map[string]bool) (ast.Declaration, []SyntaxError) {
func (p *Parser) parseDeclaration(module file.Path, s *sexp.List) (ast.Declaration, []SyntaxError) {
//
var (
decl ast.Declaration
Expand All @@ -285,7 +290,7 @@ func (p *Parser) parseDeclaration(module file.Path, s *sexp.List,
} else if s.Len() > 1 && s.MatchSymbols(1, "defconst") {
decl, errors = p.parseDefConst(module, s.Elements)
} else if s.Len() == 4 && s.MatchSymbols(2, "defconstraint") {
decl, errors = p.parseDefConstraint(module, s.Elements, handles)
decl, errors = p.parseDefConstraint(module, s.Elements)
} else if s.Len() == 3 && s.MatchSymbols(1, "defpurefun") {
decl, errors = p.parseDefFun(module, true, s.Elements)
} else if s.Len() == 3 && s.MatchSymbols(1, "defun") {
Expand All @@ -295,21 +300,21 @@ func (p *Parser) parseDeclaration(module file.Path, s *sexp.List,
} else if s.Len() == 3 && s.MatchSymbols(1, "definterleaved") {
decl, errors = p.parseDefInterleaved(module, s.Elements)
} else if s.Len() == 4 && s.MatchSymbols(1, "deflookup") {
decl, errors = p.parseDefLookup(s.Elements, handles)
decl, errors = p.parseDefLookup(module, s.Elements)
} else if (s.Len() == 5 || s.Len() == 6) && s.MatchSymbols(1, "defclookup") {
decl, errors = p.parseDefConditionalLookup(s.Elements, handles)
decl, errors = p.parseDefConditionalLookup(module, s.Elements)
} else if s.Len() == 4 && s.MatchSymbols(1, "defmlookup") {
decl, errors = p.parseDefMultiLookup(s.Elements, handles)
decl, errors = p.parseDefMultiLookup(module, s.Elements)
} else if s.Len() == 3 && s.MatchSymbols(2, "defpermutation") {
decl, errors = p.parseDefPermutation(module, s.Elements)
} else if s.Len() == 4 && s.MatchSymbols(2, "defperspective") {
decl, errors = p.parseDefPerspective(module, s.Elements)
} else if 3 <= s.Len() && s.Len() <= 4 && s.MatchSymbols(2, "defproperty") {
decl, errors = p.parseDefProperty(s.Elements, handles)
decl, errors = p.parseDefProperty(module, s.Elements)
} else if s.Len() == 3 && s.MatchSymbols(2, "defsorted") {
decl, errors = p.parseDefSorted(false, s.Elements, handles)
decl, errors = p.parseDefSorted(module, false, s.Elements)
} else if 3 <= s.Len() && s.Len() <= 4 && s.MatchSymbols(2, "defstrictsorted") {
decl, errors = p.parseDefSorted(true, s.Elements, handles)
decl, errors = p.parseDefSorted(module, true, s.Elements)
} else if s.Len() == 3 && s.MatchSymbols(1, "defcomputedcolumn") {
decl, errors = p.parseDefComputedColumn(module, s.Elements)
} else {
Expand Down Expand Up @@ -761,23 +766,21 @@ func (p *Parser) parseDefConstHead(head sexp.SExp) (*sexp.Symbol, ast.Type, bool
}

// Parse a vanishing declaration
func (p *Parser) parseDefConstraint(module file.Path, elements []sexp.SExp,
handles map[string]bool) (ast.Declaration, []SyntaxError) {
var (
errors []SyntaxError
handle string
)
func (p *Parser) parseDefConstraint(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
var errors []SyntaxError
// Initial sanity checks
if !isIdentifier(elements[1]) {
return nil, p.translator.SyntaxErrors(elements[1], "invalid constraint handle")
} else {
handle = elements[1].AsSymbol().Value
}
//
handle := elements[1].AsSymbol().Value
// Generate qualified name
qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle)
// Check for duplicate
if _, ok := handles[handle]; ok {
if _, ok := p.handles[qualifiedHandle]; ok {
return nil, p.translator.SyntaxErrors(elements[1], "duplicate handle")
} else {
handles[handle] = true
p.handles[qualifiedHandle] = true
}
// Vanishing constraints do not have global scope, hence qualified column
// accesses are not permitted.
Expand Down Expand Up @@ -869,9 +872,9 @@ func (p *Parser) parseDefInterleavedSourceArray(source *sexp.Array) (ast.TypedSy
}

// Parse a lookup declaration
func (p *Parser) parseDefLookup(elements []sexp.SExp, handles map[string]bool) (ast.Declaration, []SyntaxError) {
func (p *Parser) parseDefLookup(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
// Extract items
handle, checked, errors := p.parseLookupHandle(elements[1], handles)
handle, checked, errors := p.parseLookupHandle(module, elements[1])
targets, tgtErrors := p.parseDefLookupSources("target", elements[2])
sources, srcErrors := p.parseDefLookupSources("source", elements[3])
// Combine any and all errors
Expand All @@ -897,16 +900,15 @@ func (p *Parser) parseDefLookup(elements []sexp.SExp, handles map[string]bool) (
}

// Parse a conditional lookup declaration
func (p *Parser) parseDefConditionalLookup(elements []sexp.SExp,
handles map[string]bool) (ast.Declaration, []SyntaxError) {
func (p *Parser) parseDefConditionalLookup(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
// Extract items
var (
targets, sources []ast.Expr
targetSelector, sourceSelector ast.Expr
errs1, errs2, errs3, errs4 []SyntaxError
)
//
handle, checked, errors := p.parseLookupHandle(elements[1], handles)
handle, checked, errors := p.parseLookupHandle(module, elements[1])
//
if len(elements) == 6 {
targetSelector, errs1 = p.translator.Translate(elements[2])
Expand Down Expand Up @@ -942,9 +944,9 @@ func (p *Parser) parseDefConditionalLookup(elements []sexp.SExp,
[][]ast.Expr{targets}), nil
}

func (p *Parser) parseDefMultiLookup(elements []sexp.SExp, handles map[string]bool) (ast.Declaration, []SyntaxError) {
func (p *Parser) parseDefMultiLookup(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
// Extract items
handle, checked, errors := p.parseLookupHandle(elements[1], handles)
handle, checked, errors := p.parseLookupHandle(module, elements[1])
m, targets, tgtErrors := p.parseDefLookupMultiSources("target", elements[2])
n, sources, srcErrors := p.parseDefLookupMultiSources("source", elements[3])
// Combine any and all errors
Expand All @@ -966,7 +968,7 @@ func (p *Parser) parseDefMultiLookup(elements []sexp.SExp, handles map[string]bo
return ast.NewDefLookup(handle, checked, sourceSelectors, sources, targetSelectors, targets), nil
}

func (p *Parser) parseLookupHandle(element sexp.SExp, handles map[string]bool) (string, bool, []SyntaxError) {
func (p *Parser) parseLookupHandle(module file.Path, element sexp.SExp) (string, bool, []SyntaxError) {
var (
checked = true
errors []SyntaxError
Expand All @@ -981,12 +983,14 @@ func (p *Parser) parseLookupHandle(element sexp.SExp, handles map[string]bool) (
} else {
handle = element.AsSymbol().Value
}
// Generate qualified name
qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle)
// Check for duplicate handle
if _, ok := handles[handle]; ok {
if _, ok := p.handles[qualifiedHandle]; ok {
return "", checked, p.translator.SyntaxErrors(element, "duplicate handle")
} else if len(errors) == 0 {
// Done
handles[handle] = true
p.handles[qualifiedHandle] = true
}
//
return handle, checked, errors
Expand Down Expand Up @@ -1149,8 +1153,7 @@ func (p *Parser) parseDefPermutation(module file.Path, elements []sexp.SExp) (as
}

// Parse a permutation declaration
func (p *Parser) parseDefSorted(strict bool, elements []sexp.SExp,
handles map[string]bool) (ast.Declaration, []SyntaxError) {
func (p *Parser) parseDefSorted(module file.Path, strict bool, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
//
var (
selector util.Option[ast.Expr]
Expand Down Expand Up @@ -1182,12 +1185,14 @@ func (p *Parser) parseDefSorted(strict bool, elements []sexp.SExp,
} else {
handle = elements[1].AsSymbol().Value
}
// Generate qualified name
qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Nil Pointer Panic in Definition Parsing

Potential nil pointer dereference in parseDefSorted. When isIdentifier(elements[1]) returns false, an error is appended but execution continues to line 1185 where elements[1].AsSymbol().Value is called. Since isIdentifier returns false when AsSymbol() returns nil, this will panic. The handle extraction should be inside an else block like in parseDefProperty, or the function should return early like in parseDefConstraint.

Fix in Cursor Fix in Web

// Check for duplicate handle
if _, ok := handles[handle]; ok {
if _, ok := p.handles[qualifiedHandle]; ok {
return nil, p.translator.SyntaxErrors(elements[1], "duplicate handle")
} else if len(errors) == 0 {
// Record handle
handles[handle] = true
p.handles[qualifiedHandle] = true
}
// Check source Expressions
if sexpSources == nil {
Expand Down Expand Up @@ -1327,7 +1332,7 @@ func (p *Parser) parseDefPerspective(module file.Path, elements []sexp.SExp) (as
}

// Parse a property assertion
func (p *Parser) parseDefProperty(elements []sexp.SExp, handles map[string]bool) (ast.Declaration, []SyntaxError) {
func (p *Parser) parseDefProperty(module file.Path, elements []sexp.SExp) (ast.Declaration, []SyntaxError) {
var (
errors []SyntaxError
assertion int = len(elements) - 1
Expand All @@ -1341,12 +1346,14 @@ func (p *Parser) parseDefProperty(elements []sexp.SExp, handles map[string]bool)
// Extract handle
handle = elements[1].AsSymbol().Value
}
// Generate qualified name
qualifiedHandle := fmt.Sprintf("%s:%s", module.String(), handle)
// Check for duplicate handle
if _, ok := handles[handle]; ok {
if _, ok := p.handles[qualifiedHandle]; ok {
return nil, p.translator.SyntaxErrors(elements[1], "duplicate handle")
} else if len(errors) == 0 {
// Record handle
handles[handle] = true
p.handles[qualifiedHandle] = true
}
// Check for any attributes
if len(elements) > 3 {
Expand Down
7 changes: 6 additions & 1 deletion pkg/test/corset_invalid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ func Test_Invalid_Basic_18(t *testing.T) {
func Test_Invalid_Basic_19(t *testing.T) {
checkCorsetInvalid(t, "corset/invalid/basic_invalid_19")
}

func Test_Invalid_Basic_20(t *testing.T) {
checkCorsetInvalid(t, "corset/invalid/basic_invalid_20")
}
func Test_Invalid_Logic_01(t *testing.T) {
checkCorsetInvalid(t, "corset/invalid/logic_invalid_01")
}
Expand Down Expand Up @@ -515,6 +517,9 @@ func Test_Invalid_Lookup_17(t *testing.T) {
func Test_Invalid_Lookup_18(t *testing.T) {
checkCorsetInvalid(t, "corset/invalid/lookup_invalid_18")
}
func Test_Invalid_Lookup_19(t *testing.T) {
checkCorsetInvalid(t, "corset/invalid/lookup_invalid_19")
}

// ===================================================================
// Interleavings
Expand Down
8 changes: 8 additions & 0 deletions testdata/corset/invalid/basic_invalid_20.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
;;error:8:16-20:duplicate handle
(module m1)

(defcolumns (X :i16) (Y :i16))
(defconstraint test () (== X Y))

(module m1)
(defconstraint test () (!= X Y))
7 changes: 7 additions & 0 deletions testdata/corset/invalid/lookup_invalid_19.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
;;error:7:13-17:duplicate handle
(module m1)
(defcolumns (X :i16) (Y :i16))
(defclookup test (Y) 1 (X))

(module m1)
(defclookup test (X) 1 (Y))
Loading