@@ -19,12 +19,13 @@ import (
1919
2020 "github.com/consensys/go-corset/pkg/asm/io"
2121 "github.com/consensys/go-corset/pkg/ir"
22+ "github.com/consensys/go-corset/pkg/ir/hir"
2223 sc "github.com/consensys/go-corset/pkg/schema"
2324 "github.com/consensys/go-corset/pkg/schema/module"
2425 "github.com/consensys/go-corset/pkg/schema/register"
26+ "github.com/consensys/go-corset/pkg/trace"
2527 "github.com/consensys/go-corset/pkg/trace/lt"
2628 "github.com/consensys/go-corset/pkg/util/collection/array"
27- "github.com/consensys/go-corset/pkg/util/field"
2829 "github.com/consensys/go-corset/pkg/util/word"
2930)
3031
@@ -38,8 +39,8 @@ type RawModule = lt.Module[word.BigEndian]
3839// Validation?
3940// Batch size?
4041// Recursion limit (to prevent infinite loops)
41- func PropagateAll [F field. Element [ F ], T io.Instruction [T ], M sc.Module [F ]](p MixedProgram [F , T , M ], ts []lt. TraceFile ,
42- expanding bool ) ([]lt.TraceFile , []error ) {
42+ func PropagateAll [T io.Instruction [T ], M sc.Module [word. BigEndian ]](p MixedProgram [word. BigEndian , T , M ],
43+ ts []lt. TraceFile , expanding bool ) ([]lt.TraceFile , []error ) {
4344 //
4445 var (
4546 errors []error
@@ -76,12 +77,12 @@ func PropagateAll[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p Mix
7677// Validation?
7778// Batch size?
7879// Recursion limit (to prevent infinite loops)
79- func Propagate [F field. Element [ F ], T io.Instruction [T ], M sc.Module [F ]](p MixedProgram [F , T , M ], trace lt. TraceFile ,
80- expanding bool ) (lt.TraceFile , []error ) {
80+ func Propagate [T io.Instruction [T ], M sc.Module [word. BigEndian ]](p MixedProgram [word. BigEndian , T , M ],
81+ trace lt. TraceFile , expanding bool ) (lt.TraceFile , []error ) {
8182 // Construct suitable executior for the given program
8283 var (
8384 errors []error
84- n = len (p .program .Functions ())
85+ n = uint ( len (p .program .Functions () ))
8586 //
8687 executor = io .NewExecutor (p .program )
8788 // Clone heap in trace file, since will mutate this.
@@ -94,7 +95,7 @@ func Propagate[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p MixedP
9495 return lt.TraceFile {}, errors
9596 }
9697 // Write seed instances
97- errors = writeInstances (p . program , trace .Modules [: n ] , executor )
98+ errors = writeInstances (p , n , trace .Modules , executor )
9899 // Read out generated instances
99100 modules := readInstances (& heap , p .program , executor )
100101 // Append external modules (which are unaffected by propagation).
@@ -106,15 +107,24 @@ func Propagate[F field.Element[F], T io.Instruction[T], M sc.Module[F]](p MixedP
106107// WriteInstances writes all of the instances defined in the given trace columns
107108// into the executor which, in turn, forces it to execute the relevant
108109// functions, and functions they call, etc.
109- func writeInstances [T io.Instruction [T ]]( p io. Program [ T ], trace []lt .Module [word.BigEndian ],
110- executor * io.Executor [T ]) []error {
110+ func writeInstances [T io.Instruction [T ], M sc .Module [word.BigEndian ]]( p MixedProgram [word. BigEndian , T , M ], n uint ,
111+ trace []lt. Module [word. BigEndian ], executor * io.Executor [T ]) []error {
111112 //
112113 var errors []error
113- //
114- for i , m := range trace {
115- errs := writeFunctionInstances (uint (i ), p , m , executor )
114+ // Write all from assembly modules
115+ for i , m := range trace [: n ] {
116+ errs := writeFunctionInstances (uint (i ), p . program , m , executor )
116117 errors = append (errors , errs ... )
117118 }
119+ // Write all from non-assembly modules
120+ for i , m := range trace [n :] {
121+ var extern = p .externs [i ]
122+ // Write instances from any external calls
123+ for _ , call := range extractExternalCalls (extern ) {
124+ errs := writeExternCall (call , p .program , m , executor )
125+ errors = append (errors , errs ... )
126+ }
127+ }
118128 //
119129 return errors
120130}
@@ -145,6 +155,68 @@ func writeFunctionInstances[T io.Instruction[T]](fid uint, p io.Program[T], mod
145155 return errors
146156}
147157
158+ // Extract any external function calls found within the given module, returning
159+ // them as an array.
160+ func extractExternalCalls [M sc.Module [word.BigEndian ]](extern M ) []hir.FunctionCall {
161+ var calls []hir.FunctionCall
162+ //
163+ for iter := extern .Constraints (); iter .HasNext (); {
164+ c := iter .Next ()
165+ // This should always hold
166+ if hc , ok := c .(hir.Constraint ); ok {
167+ // Check whether its a call or not
168+ if call , ok := hc .Unwrap ().(hir.FunctionCall ); ok {
169+ // Yes, so record it
170+ calls = append (calls , call )
171+ }
172+ }
173+ }
174+ //
175+ return calls
176+ }
177+
178+ // Write any function instances arising from the given call.
179+ func writeExternCall [T io.Instruction [T ]](call hir.FunctionCall , p io.Program [T ], mod RawModule ,
180+ executor * io.Executor [T ]) []error {
181+ //
182+ var (
183+ trMod = & ltModuleAdaptor {mod }
184+ height = mod .Height ()
185+ fn = p .Function (call .Callee )
186+ inputs = make ([]big.Int , fn .NumInputs ())
187+ outputs = make ([]big.Int , fn .NumOutputs ())
188+ errors []error
189+ )
190+ //
191+ if call .Selector .HasValue () {
192+ var selector = call .Selector .Unwrap ()
193+ // Invoke each user-defined instance in turn
194+ for i := range height {
195+ // execute if selector enabled
196+ if enabled , _ , err := selector .TestAt (int (i ), trMod , nil ); enabled {
197+ // Extract external columns
198+ extractExternColumns (int (i ), call , trMod , inputs , outputs )
199+ // Execute function call to produce outputs
200+ errs := executeAndCheck (call .Callee , fn .Name (), inputs , outputs , executor )
201+ errors = append (errors , errs ... )
202+ } else if err != nil {
203+ errors = append (errors , err )
204+ }
205+ }
206+ } else {
207+ // Invoke each user-defined instance in turn
208+ for i := range height {
209+ // Extract external columns
210+ extractExternColumns (int (i ), call , trMod , inputs , outputs )
211+ // Execute function call to produce outputs
212+ errs := executeAndCheck (call .Callee , fn .Name (), inputs , outputs , executor )
213+ errors = append (errors , errs ... )
214+ }
215+ }
216+ //
217+ return errors
218+ }
219+
148220func executeAndCheck [T io.Instruction [T ]](fid uint , name module.Name , inputs , outputs []big.Int ,
149221 executor * io.Executor [T ]) []error {
150222 var (
@@ -198,6 +270,34 @@ func extractFunctionColumns(row uint, mod RawModule, inputs, outputs []big.Int)
198270 }
199271}
200272
273+ func extractExternColumns (row int , call hir.FunctionCall , mod trace.Module [word.BigEndian ],
274+ inputs , outputs []big.Int ) []error {
275+ //
276+ // Extract function arguments
277+ errs1 := extractExternTerms (row , call .Arguments , mod , inputs )
278+ // Extract function returns
279+ errs2 := extractExternTerms (row , call .Returns , mod , outputs )
280+ //
281+ return append (errs1 , errs2 ... )
282+ }
283+
284+ func extractExternTerms (row int , terms []hir.Term , mod trace.Module [word.BigEndian ], values []big.Int ) []error {
285+ var errors []error
286+ //
287+ for i , arg := range terms {
288+ var (
289+ ith big.Int
290+ val , err = arg .EvalAt (row , mod , nil )
291+ )
292+ ith .SetBytes (val .Bytes ())
293+ values [i ] = ith
294+ //
295+ errors = append (errors , err )
296+ }
297+ //
298+ return errors
299+ }
300+
201301func extractFunctionPadding (registers []register.Register , inputs , outputs []big.Int ) {
202302 var numInputs = len (inputs )
203303 //
@@ -282,3 +382,50 @@ func toArgumentString(args []big.Int) string {
282382 //
283383 return builder .String ()
284384}
385+
386+ // The purpose of the lt adaptor is to make an lt.TraceFile look like a Trace.
387+ // In general, this is not safe. However, we use this once we already know that
388+ // the trace has been aligned. Also, it is only used in a specific context.
389+ type ltModuleAdaptor struct {
390+ module lt.Module [word.BigEndian ]
391+ }
392+
393+ func (p * ltModuleAdaptor ) Name () trace.ModuleName {
394+ return p .module .Name
395+ }
396+
397+ func (p * ltModuleAdaptor ) Width () uint {
398+ return uint (len (p .module .Columns ))
399+ }
400+
401+ func (p * ltModuleAdaptor ) Height () uint {
402+ return p .module .Height ()
403+ }
404+
405+ func (p * ltModuleAdaptor ) Column (cid uint ) trace.Column [word.BigEndian ] {
406+ return & ltColumnAdaptor {p .module .Columns [cid ]}
407+ }
408+
409+ func (p * ltModuleAdaptor ) ColumnOf (col string ) trace.Column [word.BigEndian ] {
410+ panic ("unsupported operation" )
411+ }
412+
413+ type ltColumnAdaptor struct {
414+ column lt.Column [word.BigEndian ]
415+ }
416+
417+ func (p * ltColumnAdaptor ) Name () string {
418+ return p .column .Name
419+ }
420+
421+ func (p * ltColumnAdaptor ) Get (row int ) word.BigEndian {
422+ return p .column .Data .Get (uint (row ))
423+ }
424+
425+ func (p * ltColumnAdaptor ) Data () array.Array [word.BigEndian ] {
426+ return p .column .Data
427+ }
428+
429+ func (p * ltColumnAdaptor ) Padding () word.BigEndian {
430+ panic ("unsupported operation" )
431+ }
0 commit comments