Skip to content

Commit e487520

Browse files
authored
feat: function call declaration (#1329)
This adds support for parsing defcall declarations, along with an appropriate AST node. This adds support for the resolution, processing and typing stages for the DefCall construct. An initial implementation of translation is also included, but this does not yet actually translate the call. This adds support for checking arguments and returns. In particular, check that there are enough arguments and returns, and also that their bitwidths follow the expected subtyping pattern. Condition function calls are supported where a (logical) selector is provided. Function calls are encoded / decoded at the HIR level.
1 parent 58b4d75 commit e487520

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+7039
-73
lines changed

pkg/asm/propagate.go

Lines changed: 159 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ import (
1919

2020
"github.com/consensys/go-corset/pkg/asm/io"
2121
"github.com/consensys/go-corset/pkg/ir"
22+
"github.com/consensys/go-corset/pkg/ir/hir"
2223
sc "github.com/consensys/go-corset/pkg/schema"
2324
"github.com/consensys/go-corset/pkg/schema/module"
2425
"github.com/consensys/go-corset/pkg/schema/register"
26+
"github.com/consensys/go-corset/pkg/trace"
2527
"github.com/consensys/go-corset/pkg/trace/lt"
2628
"github.com/consensys/go-corset/pkg/util/collection/array"
27-
"github.com/consensys/go-corset/pkg/util/field"
2829
"github.com/consensys/go-corset/pkg/util/word"
2930
)
3031

@@ -38,8 +39,8 @@ type RawModule = lt.Module[word.BigEndian]
3839
// Validation?
3940
// Batch size?
4041
// Recursion limit (to prevent infinite loops)
41-
func PropagateAll[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p MixedProgram[F, T, M], ts []lt.TraceFile,
42-
expanding bool) ([]lt.TraceFile, []error) {
42+
func PropagateAll[T io.Instruction[T], M sc.Module[word.BigEndian]](p MixedProgram[word.BigEndian, T, M],
43+
ts []lt.TraceFile, expanding bool) ([]lt.TraceFile, []error) {
4344
//
4445
var (
4546
errors []error
@@ -76,12 +77,12 @@ func PropagateAll[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p Mix
7677
// Validation?
7778
// Batch size?
7879
// Recursion limit (to prevent infinite loops)
79-
func Propagate[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p MixedProgram[F, T, M], trace lt.TraceFile,
80-
expanding bool) (lt.TraceFile, []error) {
80+
func Propagate[T io.Instruction[T], M sc.Module[word.BigEndian]](p MixedProgram[word.BigEndian, T, M],
81+
trace lt.TraceFile, expanding bool) (lt.TraceFile, []error) {
8182
// Construct suitable executior for the given program
8283
var (
8384
errors []error
84-
n = len(p.program.Functions())
85+
n = uint(len(p.program.Functions()))
8586
//
8687
executor = io.NewExecutor(p.program)
8788
// Clone heap in trace file, since will mutate this.
@@ -94,7 +95,7 @@ func Propagate[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p MixedP
9495
return lt.TraceFile{}, errors
9596
}
9697
// Write seed instances
97-
errors = writeInstances(p.program, trace.Modules[:n], executor)
98+
errors = writeInstances(p, n, trace.Modules, executor)
9899
// Read out generated instances
99100
modules := readInstances(&heap, p.program, executor)
100101
// Append external modules (which are unaffected by propagation).
@@ -106,15 +107,24 @@ func Propagate[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p MixedP
106107
// WriteInstances writes all of the instances defined in the given trace columns
107108
// into the executor which, in turn, forces it to execute the relevant
108109
// functions, and functions they call, etc.
109-
func writeInstances[T io.Instruction[T]](p io.Program[T], trace []lt.Module[word.BigEndian],
110-
executor *io.Executor[T]) []error {
110+
func writeInstances[T io.Instruction[T], M sc.Module[word.BigEndian]](p MixedProgram[word.BigEndian, T, M], n uint,
111+
trace []lt.Module[word.BigEndian], executor *io.Executor[T]) []error {
111112
//
112113
var errors []error
113-
//
114-
for i, m := range trace {
115-
errs := writeFunctionInstances(uint(i), p, m, executor)
114+
// Write all from assembly modules
115+
for i, m := range trace[:n] {
116+
errs := writeFunctionInstances(uint(i), p.program, m, executor)
116117
errors = append(errors, errs...)
117118
}
119+
// Write all from non-assembly modules
120+
for i, m := range trace[n:] {
121+
var extern = p.externs[i]
122+
// Write instances from any external calls
123+
for _, call := range extractExternalCalls(extern) {
124+
errs := writeExternCall(call, p.program, m, executor)
125+
errors = append(errors, errs...)
126+
}
127+
}
118128
//
119129
return errors
120130
}
@@ -145,6 +155,68 @@ func writeFunctionInstances[T io.Instruction[T]](fid uint, p io.Program[T], mod
145155
return errors
146156
}
147157

158+
// Extract any external function calls found within the given module, returning
159+
// them as an array.
160+
func extractExternalCalls[M sc.Module[word.BigEndian]](extern M) []hir.FunctionCall {
161+
var calls []hir.FunctionCall
162+
//
163+
for iter := extern.Constraints(); iter.HasNext(); {
164+
c := iter.Next()
165+
// This should always hold
166+
if hc, ok := c.(hir.Constraint); ok {
167+
// Check whether its a call or not
168+
if call, ok := hc.Unwrap().(hir.FunctionCall); ok {
169+
// Yes, so record it
170+
calls = append(calls, call)
171+
}
172+
}
173+
}
174+
//
175+
return calls
176+
}
177+
178+
// Write any function instances arising from the given call.
179+
func writeExternCall[T io.Instruction[T]](call hir.FunctionCall, p io.Program[T], mod RawModule,
180+
executor *io.Executor[T]) []error {
181+
//
182+
var (
183+
trMod = &ltModuleAdaptor{mod}
184+
height = mod.Height()
185+
fn = p.Function(call.Callee)
186+
inputs = make([]big.Int, fn.NumInputs())
187+
outputs = make([]big.Int, fn.NumOutputs())
188+
errors []error
189+
)
190+
//
191+
if call.Selector.HasValue() {
192+
var selector = call.Selector.Unwrap()
193+
// Invoke each user-defined instance in turn
194+
for i := range height {
195+
// execute if selector enabled
196+
if enabled, _, err := selector.TestAt(int(i), trMod, nil); enabled {
197+
// Extract external columns
198+
extractExternColumns(int(i), call, trMod, inputs, outputs)
199+
// Execute function call to produce outputs
200+
errs := executeAndCheck(call.Callee, fn.Name(), inputs, outputs, executor)
201+
errors = append(errors, errs...)
202+
} else if err != nil {
203+
errors = append(errors, err)
204+
}
205+
}
206+
} else {
207+
// Invoke each user-defined instance in turn
208+
for i := range height {
209+
// Extract external columns
210+
extractExternColumns(int(i), call, trMod, inputs, outputs)
211+
// Execute function call to produce outputs
212+
errs := executeAndCheck(call.Callee, fn.Name(), inputs, outputs, executor)
213+
errors = append(errors, errs...)
214+
}
215+
}
216+
//
217+
return errors
218+
}
219+
148220
func executeAndCheck[T io.Instruction[T]](fid uint, name module.Name, inputs, outputs []big.Int,
149221
executor *io.Executor[T]) []error {
150222
var (
@@ -198,6 +270,34 @@ func extractFunctionColumns(row uint, mod RawModule, inputs, outputs []big.Int)
198270
}
199271
}
200272

273+
func extractExternColumns(row int, call hir.FunctionCall, mod trace.Module[word.BigEndian],
274+
inputs, outputs []big.Int) []error {
275+
//
276+
// Extract function arguments
277+
errs1 := extractExternTerms(row, call.Arguments, mod, inputs)
278+
// Extract function returns
279+
errs2 := extractExternTerms(row, call.Returns, mod, outputs)
280+
//
281+
return append(errs1, errs2...)
282+
}
283+
284+
func extractExternTerms(row int, terms []hir.Term, mod trace.Module[word.BigEndian], values []big.Int) []error {
285+
var errors []error
286+
//
287+
for i, arg := range terms {
288+
var (
289+
ith big.Int
290+
val, err = arg.EvalAt(row, mod, nil)
291+
)
292+
ith.SetBytes(val.Bytes())
293+
values[i] = ith
294+
//
295+
errors = append(errors, err)
296+
}
297+
//
298+
return errors
299+
}
300+
201301
func extractFunctionPadding(registers []register.Register, inputs, outputs []big.Int) {
202302
var numInputs = len(inputs)
203303
//
@@ -282,3 +382,50 @@ func toArgumentString(args []big.Int) string {
282382
//
283383
return builder.String()
284384
}
385+
386+
// The purpose of the lt adaptor is to make an lt.TraceFile look like a Trace.
387+
// In general, this is not safe. However, we use this once we already know that
388+
// the trace has been aligned. Also, it is only used in a specific context.
389+
type ltModuleAdaptor struct {
390+
module lt.Module[word.BigEndian]
391+
}
392+
393+
func (p *ltModuleAdaptor) Name() trace.ModuleName {
394+
return p.module.Name
395+
}
396+
397+
func (p *ltModuleAdaptor) Width() uint {
398+
return uint(len(p.module.Columns))
399+
}
400+
401+
func (p *ltModuleAdaptor) Height() uint {
402+
return p.module.Height()
403+
}
404+
405+
func (p *ltModuleAdaptor) Column(cid uint) trace.Column[word.BigEndian] {
406+
return &ltColumnAdaptor{p.module.Columns[cid]}
407+
}
408+
409+
func (p *ltModuleAdaptor) ColumnOf(col string) trace.Column[word.BigEndian] {
410+
panic("unsupported operation")
411+
}
412+
413+
type ltColumnAdaptor struct {
414+
column lt.Column[word.BigEndian]
415+
}
416+
417+
func (p *ltColumnAdaptor) Name() string {
418+
return p.column.Name
419+
}
420+
421+
func (p *ltColumnAdaptor) Get(row int) word.BigEndian {
422+
return p.column.Data.Get(uint(row))
423+
}
424+
425+
func (p *ltColumnAdaptor) Data() array.Array[word.BigEndian] {
426+
return p.column.Data
427+
}
428+
429+
func (p *ltColumnAdaptor) Padding() word.BigEndian {
430+
panic("unsupported operation")
431+
}

pkg/corset/ast/declaration.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,97 @@ func (p *DefAlias) Lisp() sexp.SExp {
147147
return sexp.NewSymbol(p.Name)
148148
}
149149

150+
// ============================================================================
151+
// defcall
152+
// ============================================================================
153+
154+
// DefCall captures a function call between a lisp module and an assembly
155+
// function. A key feature of this is that it triggers trace propagation.
156+
type DefCall struct {
157+
// Returns for the call
158+
Returns []Expr
159+
// Function being called
160+
Function string
161+
// Arguments for the call
162+
Arguments []Expr
163+
// Optional source selector
164+
Selector util.Option[Expr]
165+
// determines whether or not this has been finalised.
166+
finalised bool
167+
}
168+
169+
// NewDefCall creates a new (unfinalised) function call.
170+
func NewDefCall(returns []Expr, fun string, args []Expr, selector util.Option[Expr]) *DefCall {
171+
//
172+
return &DefCall{returns, fun, args, selector, false}
173+
}
174+
175+
// Definitions returns the set of symbols defined by this declaration. Observe
176+
// that these may not yet have been finalised.
177+
func (p *DefCall) Definitions() iter.Iterator[SymbolDefinition] {
178+
return iter.NewArrayIterator[SymbolDefinition](nil)
179+
}
180+
181+
// Dependencies needed to signal declaration.
182+
func (p *DefCall) Dependencies() iter.Iterator[Symbol] {
183+
var deps []Symbol
184+
//
185+
deps = append(deps, DependenciesOfExpressions(p.Arguments)...)
186+
deps = append(deps, DependenciesOfExpressions(p.Returns)...)
187+
// Include selector dependencies (if applicable)
188+
if p.Selector.HasValue() {
189+
deps = append(deps, p.Selector.Unwrap().Dependencies()...)
190+
}
191+
// Combine deps
192+
return iter.NewArrayIterator(deps)
193+
}
194+
195+
// Defines checks whether this declaration defines the given symbol. The symbol
196+
// in question needs to have been resolved already for this to make sense.
197+
func (p *DefCall) Defines(symbol Symbol) bool {
198+
return false
199+
}
200+
201+
// IsFinalised checks whether this declaration has already been finalised. If
202+
// so, then we don't need to finalise it again.
203+
func (p *DefCall) IsFinalised() bool {
204+
return p.finalised
205+
}
206+
207+
// Finalise this declaration, which means that all source and target expressions
208+
// have been resolved.
209+
func (p *DefCall) Finalise() {
210+
p.finalised = true
211+
}
212+
213+
// Lisp converts this node into its lisp representation. This is primarily used
214+
// for debugging purposes.
215+
func (p *DefCall) Lisp() sexp.SExp {
216+
returns := make([]sexp.SExp, len(p.Returns))
217+
args := make([]sexp.SExp, len(p.Arguments))
218+
// Returns
219+
for i, t := range p.Returns {
220+
returns[i] = t.Lisp()
221+
}
222+
// Arguments
223+
for i, t := range p.Arguments {
224+
args[i] = t.Lisp()
225+
}
226+
//
227+
list := sexp.NewList([]sexp.SExp{
228+
sexp.NewSymbol("defcall"),
229+
sexp.NewList(returns),
230+
sexp.NewSymbol(p.Function),
231+
sexp.NewList(args),
232+
})
233+
// Include selector (if applicable)
234+
if p.Selector.HasValue() {
235+
list.Append(p.Selector.Unwrap().Lisp())
236+
}
237+
//
238+
return list
239+
}
240+
150241
// ============================================================================
151242
// defcolumns
152243
// ============================================================================

pkg/corset/ast/expression.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ type For struct {
524524
// NewFor constructs a new for-expression given a variable name, a static index
525525
// range and a body.
526526
func NewFor(name string, start uint, end uint, body Expr) *For {
527-
binding := NewLocalVariableBinding(name, INT_TYPE)
527+
binding := NewLocalVariableBinding(name, UINT_TYPE)
528528
return &For{binding, start, end, body}
529529
}
530530

@@ -714,7 +714,7 @@ func NewLet(bindings []util.Pair[string, Expr], body Expr) *Let {
714714
exprs := make([]Expr, len(bindings))
715715
//
716716
for i, p := range bindings {
717-
vars[i] = NewLocalVariableBinding(p.Left, INT_TYPE)
717+
vars[i] = NewLocalVariableBinding(p.Left, UINT_TYPE)
718718
exprs[i] = p.Right
719719
}
720720
//

pkg/corset/ast/type.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,10 @@ func (p *AnyType) String() string {
136136
// IntType
137137
// ============================================================================
138138

139-
// INT_TYPE represents the infinite integer range. This cannot be translated
139+
// UINT_TYPE represents the infinite integer range. This cannot be translated
140140
// into a concrete type at the lower level, and therefore can only be used
141141
// internally (e.g. for type checking).
142-
var INT_TYPE = &IntType{math.INFINITY}
142+
var UINT_TYPE = &IntType{math.INFINITY}
143143

144144
// IntType represents a set of signed integer values.
145145
type IntType struct {

pkg/corset/compiler/intrinsics.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func (p *IntrinsicDefinition) Signature() *ast.FunctionSignature {
9191
types := make([]ast.Type, p.arity)
9292
//
9393
for i := 0; i < len(types); i++ {
94-
types[i] = ast.INT_TYPE
94+
types[i] = ast.UINT_TYPE
9595
}
9696
// Allow return type to be inferred.
9797
return ast.NewFunctionSignature(true, types, nil, body)

0 commit comments

Comments
 (0)