diff --git a/pkg/ir/air/gadgets/bitwidth.go b/pkg/ir/air/gadgets/bitwidth.go index a4b15b04..a2270d33 100644 --- a/pkg/ir/air/gadgets/bitwidth.go +++ b/pkg/ir/air/gadgets/bitwidth.go @@ -29,7 +29,6 @@ import ( "github.com/consensys/go-corset/pkg/util/collection/hash" "github.com/consensys/go-corset/pkg/util/field" "github.com/consensys/go-corset/pkg/util/source/sexp" - "github.com/consensys/go-corset/pkg/util/word" ) // BitwidthGadget is a general-purpose mechanism for enforcing type constraints @@ -41,10 +40,6 @@ type BitwidthGadget[F field.Element[F]] struct { // translated into AIR range constraints, versus using a horizontal // bitwidth gadget. maxRangeConstraint uint - // Enables the use of type proofs which exploit the - // limitless prover. Specifically, modules with a recursive structure are - // created specifically for the purpose of checking types. - limitless bool // Schema into which constraints are placed. schema *air.SchemaBuilder[F] } @@ -52,8 +47,7 @@ type BitwidthGadget[F field.Element[F]] struct { // NewBitwidthGadget constructs a new bitwidth gadget. func NewBitwidthGadget[F field.Element[F]](schema *air.SchemaBuilder[F]) *BitwidthGadget[F] { return &BitwidthGadget[F]{ - maxRangeConstraint: 8, - limitless: false, + maxRangeConstraint: 16, schema: schema, } } @@ -64,12 +58,6 @@ func (p *BitwidthGadget[F]) WithMaxRangeConstraint(width uint) *BitwidthGadget[F return p } -// WithLimitless enables or disables use of limitless type proofs. -func (p *BitwidthGadget[F]) WithLimitless(flag bool) *BitwidthGadget[F] { - p.limitless = flag - return p -} - // Constrain ensures all values in a given register fit within a given bitwidth. func (p *BitwidthGadget[F]) Constrain(ref register.Ref, bitwidth uint) { var ( @@ -92,12 +80,8 @@ func (p *BitwidthGadget[F]) Constrain(ref register.Ref, bitwidth uint) { []*term.RegisterAccess[F, air.Term[F]]{access}, []uint{bitwidth})) // Done return - case p.limitless: - p.applyRecursiveBitwidthGadget(ref, bitwidth) default: - // NOTE: this should be deprecated once the limitless prover is well - // established. - p.applyHorizontalBitwidthGadget(ref, bitwidth) + p.applyRecursiveBitwidthGadget(ref, bitwidth) } } @@ -139,32 +123,6 @@ func (p *BitwidthGadget[F]) applyBinaryGadget(ref register.Ref) { air.NewVanishingConstraint(handle, module.Id(), util.None[int](), X_X_m1)) } -// ApplyHorizontalBitwidthGadget ensures all values in a given column fit within -// a given bitwidth. This is implemented using a *horizontal byte -// decomposition* which adds n columns and a vanishing constraint (where n*8 >= -// bitwidth). -func (p *BitwidthGadget[F]) applyHorizontalBitwidthGadget(ref register.Ref, bitwidth uint) { - var ( - module = p.schema.Module(ref.Module()) - reg = module.Register(ref.Register()) - lookupHandle = fmt.Sprintf("%s:u%d", reg.Name, bitwidth) - ) - // Allocate computed byte registers in the given module, and add required - // range constraints. - byteRegisters := allocateByteRegisters(reg.Name, bitwidth, module) - // Build up the decomposition sum - sum := buildDecompositionTerm[F](bitwidth, byteRegisters) - // Construct X == (X:0 * 1) + ... 
+ (X:n * 2^n) - X := term.NewRegisterAccess[F, air.Term[F]](ref.Register(), reg.Width, 0) - // - eq := term.Subtract(X, sum) - // Construct column name - module.AddConstraint( - air.NewVanishingConstraint(lookupHandle, module.Id(), util.None[int](), eq)) - // Add decomposition assignment - module.AddAssignment(&byteDecomposition[F]{reg.Name, bitwidth, ref, byteRegisters}) -} - // ApplyRecursiveBitwidthGadget ensures all values in a given column fit within // a given bitwidth. This is implemented using a combination of reference tables // and lookups. Specifically, if the width is below 16bits, then a static @@ -357,117 +315,7 @@ func (p *typeDecomposition[F]) Lisp(schema sc.AnySchema[F]) sexp.SExp { } // ============================================================================ -// Byte Decomposition Assignment -// ============================================================================ - -// byteDecomposition is part of a range constraint for wide columns (e.g. u32) -// implemented using a byte decomposition. -type byteDecomposition[F field.Element[F]] struct { - // Handle for identifying this assignment - handle string - // Width of decomposition. - bitwidth uint - // The source register being decomposed - source register.Ref - // Target registers holding the decomposition - targets []register.Ref -} - -// Compute computes the values of columns defined by this assignment. -// This requires computing the value of each byte column in the decomposition. -func (p *byteDecomposition[F]) Compute(tr trace.Trace[F], schema sc.AnySchema[F], -) ([]array.MutArray[F], error) { - var n = uint(len(p.targets)) - // Read inputs - sources := assignment.ReadRegistersRef(tr, p.source) - // Apply native function - data := byteDecompositionNativeFunction(n, sources, tr.Builder()) - // - return data, nil -} - -// Bounds determines the well-definedness bounds for this assignment for both -// the negative (left) or positive (right) directions. For example, consider an -// expression such as "(shift X -1)". This is technically undefined for the -// first row of any trace and, by association, any constraint evaluating this -// expression on that first row is also undefined (and hence must pass). -func (p *byteDecomposition[F]) Bounds(_ sc.ModuleId) util.Bounds { - return util.EMPTY_BOUND -} - -// Consistent performs some simple checks that the given schema is consistent. -// This provides a double check of certain key properties, such as that -// registers used for assignments are large enough, etc. -func (p *byteDecomposition[F]) Consistent(schema sc.AnySchema[F]) []error { - var ( - bitwidth = schema.Register(p.source).Width - total = uint(0) - errors []error - ) - // - for _, ref := range p.targets { - reg := schema.Module(ref.Module()).Register(ref.Register()) - total += reg.Width - } - // - if total != bitwidth { - err := fmt.Errorf("inconsistent byte decomposition (decomposed %d bits, but expected %d)", total, bitwidth) - errors = append(errors, err) - } - // - return errors -} - -// RegistersExpanded identifies registers expanded by this assignment. -func (p *byteDecomposition[F]) RegistersExpanded() []register.Ref { - return nil -} - -// RegistersRead returns the set of columns that this assignment depends upon. -// That can include both input columns, as well as other computed columns. -func (p *byteDecomposition[F]) RegistersRead() []register.Ref { - return []register.Ref{p.source} -} - -// RegistersWritten identifies registers assigned by this assignment. 
-func (p *byteDecomposition[F]) RegistersWritten() []register.Ref { - return p.targets -} - -// Substitute any matchined labelled constants within this assignment -func (p *byteDecomposition[F]) Substitute(mapping map[string]F) { - // Nothing to do here. -} - -// Lisp converts this schema element into a simple S-Expression, for example -// so it can be printed. -func (p *byteDecomposition[F]) Lisp(schema sc.AnySchema[F]) sexp.SExp { - var ( - srcModule = schema.Module(p.source.Module()) - source = srcModule.Register(p.source.Register()) - targets = sexp.EmptyList() - ) - // - for _, t := range p.targets { - tgtModule := schema.Module(t.Module()) - reg := tgtModule.Register(t.Register()) - targets.Append(sexp.NewList([]sexp.SExp{ - // name - sexp.NewSymbol(reg.QualifiedName(tgtModule)), - // type - sexp.NewSymbol(fmt.Sprintf("u%d", reg.Width)), - })) - } - - return sexp.NewList( - []sexp.SExp{sexp.NewSymbol("decompose"), - targets, - sexp.NewSymbol(source.QualifiedName(srcModule)), - }) -} - -// ============================================================================ -// Helpers (for recursive) +// Helpers // ============================================================================ // Determine the split of limbs for the given bitwidth. For example, 33bits @@ -564,7 +412,7 @@ func decompose[F field.Element[F]](loWidth uint, ith F) (F, F) { ) // Sanity check assumption if loWidth%8 != 0 { - panic("unreachable") + panic(fmt.Sprintf("unreachable (u%d)", loWidth)) } // if loByteWidth >= n { @@ -580,144 +428,3 @@ func decompose[F field.Element[F]](loWidth uint, ith F) (F, F) { // return loFr, hiFr } - -// ============================================================================ -// Helpers (for horizontal) -// ============================================================================ - -// Allocate n byte registers, each of which requires a suitable range -// constraint. 
-func allocateByteRegisters[F field.Element[F]](prefix string, bitwidth uint, module *air.ModuleBuilder[F], -) []register.Ref { - var ( - n = bitwidth / 8 - zero big.Int - ) - // - if bitwidth == 0 { - panic("zero byte decomposition encountered") - } - // Account for asymetric case - if bitwidth%8 != 0 { - n++ - } - // Allocate target register ids - targets := make([]register.Ref, n) - // Allocate byte registers - for i := uint(0); i < n; i++ { - name := fmt.Sprintf("%s:%d", prefix, i) - byteRegister := register.NewComputed(name, min(8, bitwidth), zero) - // Allocate byte register - rid := module.NewRegister(byteRegister) - targets[i] = register.NewRef(module.Id(), rid) - // Add suitable range constraint - ith_access := term.RawRegisterAccess[F, air.Term[F]](rid, byteRegister.Width, 0) - // - module.AddConstraint( - air.NewRangeConstraint(name, module.Id(), - []*term.RegisterAccess[F, air.Term[F]]{ith_access}, - []uint{byteRegister.Width})) - // - bitwidth -= 8 - } - // - return targets -} - -func buildDecompositionTerm[F field.Element[F]](bitwidth uint, byteRegisters []register.Ref) air.Term[F] { - var ( - // Determine ranges required for the give bitwidth - ranges = splitColumnRanges[F](bitwidth) - // Initialise array of terms - terms = make([]air.Term[F], len(byteRegisters)) - // Initialise coefficient - coefficient F = field.One[F]() - ) - // Construct Columns - for i, ref := range byteRegisters { - // Create Column + Constraint - reg := term.NewRegisterAccess[F, air.Term[F]](ref.Register(), 8, 0) - terms[i] = term.Product(reg, term.Const[F, air.Term[F]](coefficient)) - // Update coefficient - coefficient = coefficient.Mul(ranges[i]) - } - // Construct (X:0 * 1) + ... + (X:n * 2^n) - return term.Sum(terms...) -} - -func splitColumnRanges[F field.Element[F]](nbits uint) []F { - var ( - n = nbits / 8 - m = nbits % 8 - ranges []F - // FIXME: following fails for very small fields like GF251! - two8 F = field.Uint64[F](256) - ) - // - if m == 0 { - ranges = make([]F, n) - } else { - // Most significant column has smaller range. - ranges = make([]F, n+1) - // Determine final range - ranges[n] = field.TwoPowN[F](m) - } - // - for i := range n { - ranges[i] = two8 - } - // - return ranges -} - -func byteDecompositionNativeFunction[F field.Element[F]](n uint, sources []array.Array[F], - builder array.Builder[F]) []array.MutArray[F] { - // - var ( - source = sources[0] - targets = make([]array.MutArray[F], n) - height = source.Len() - ) - // Sanity check - if len(sources) != 1 { - panic("too many source columns for byte decomposition") - } - // Initialise columns - for i := range n { - // Construct a byte array for ith byte - targets[i] = builder.NewArray(height, 8) - } - // Decompose each row of each column - for i := range height { - ith := decomposeIntoBytes(source.Get(i), n) - for j := uint(0); j < n; j++ { - targets[j].Set(i, ith[j]) - } - } - // - return targets -} - -// Decompose a given element into n bytes in little endian form. For example, -// decomposing 41b into 2 bytes gives [0x1b,0x04]. -func decomposeIntoBytes[W word.Word[W]](val W, n uint) []W { - // Construct return array - elements := make([]W, n) - // Determine bytes of this value (in big endian form). 
- bytes := val.Bytes() - // - l := uint(len(bytes)) - // - m := min(n, l) - // Convert each byte into a field element - for i := range m { - var ( - ith W - j = l - i - 1 - ) - // - elements[i] = ith.SetUint64(uint64(bytes[j])) - } - // Done - return elements -} diff --git a/pkg/ir/air/gadgets/lexicographic_sort.go b/pkg/ir/air/gadgets/lexicographic_sort.go index 78150816..8d3f4615 100644 --- a/pkg/ir/air/gadgets/lexicographic_sort.go +++ b/pkg/ir/air/gadgets/lexicographic_sort.go @@ -62,10 +62,6 @@ type LexicographicSortingGadget[F field.Element[F]] struct { // translated into AIR range constraints, versus using a horizontal // bitwidth gadget. maxRangeConstraint uint - // Enables the use of type proofs which exploit the - // limitless prover. Specifically, modules with a recursive structure are - // created specifically for the purpose of checking types. - limitless bool } // NewLexicographicSortingGadget constructs a default sorting gadget which can @@ -81,7 +77,7 @@ func NewLexicographicSortingGadget[F field.Element[F]](prefix string, columns [] } // return &LexicographicSortingGadget[F]{prefix, columns, signs, bitwidth, false, - term.Const64[F, air.Term[F]](1), 8, false} + term.Const64[F, air.Term[F]](1), 8} } // WithSigns configures the directions for all columns being sorted. @@ -113,12 +109,6 @@ func (p *LexicographicSortingGadget[F]) WithMaxRangeConstraint(width uint) *Lexi return p } -// WithLimitless enables or disables use of limitless type proofs. -func (p *LexicographicSortingGadget[F]) WithLimitless(flag bool) *LexicographicSortingGadget[F] { - p.limitless = flag - return p -} - // Apply this lexicographic sorting gadget to a given schema. func (p *LexicographicSortingGadget[F]) Apply(mid sc.ModuleId, schema *air.SchemaBuilder[F]) { var ( @@ -159,7 +149,6 @@ func (p *LexicographicSortingGadget[F]) Apply(mid sc.ModuleId, schema *air.Schem ref := register.NewRef(mid, deltaIndex) // Constrict gadget gadget := NewBitwidthGadget(schema). - WithLimitless(p.limitless). WithMaxRangeConstraint(p.maxRangeConstraint) // Apply bitwidth constraint gadget.Constrain(ref, p.bitwidth) @@ -194,7 +183,9 @@ func (p *LexicographicSortingGadget[F]) addLexicographicSelectorBits(deltaIndex for i := uint(0); i < ncols; i++ { ref := register.NewRef(mid, register.NewId(bitIndex+i)) // Add binarity constraints (i.e. to enforce that this column is a bit). - NewBitwidthGadget(schema).Constrain(ref, 1) + NewBitwidthGadget(schema). + WithMaxRangeConstraint(p.maxRangeConstraint). + Constrain(ref, 1) } // Apply constraints to ensure at most one is set. terms := make([]air.Term[F], ncols) diff --git a/pkg/ir/mir/lower.go b/pkg/ir/mir/lower.go index b9cd25e6..11186ff6 100644 --- a/pkg/ir/mir/lower.go +++ b/pkg/ir/mir/lower.go @@ -197,7 +197,6 @@ func (p *AirLowering[F]) lowerRangeConstraintToAir(v RangeConstraint[F], airModu ref := register.NewRef(airModule.Id(), e.Register()) // Construct gadget gadget := air_gadgets.NewBitwidthGadget(&p.airSchema). - WithLimitless(p.config.LimitlessTypeProofs). WithMaxRangeConstraint(p.config.MaxRangeConstraint) // gadget.Constrain(ref, v.Bitwidths[i]) @@ -302,7 +301,6 @@ func (p *AirLowering[F]) lowerSortedConstraintToAir(c SortedConstraint[F], airMo gadget := air_gadgets.NewLexicographicSortingGadget[F](c.Handle, sources, c.BitWidth). WithSigns(c.Signs...). WithStrictness(c.Strict). - WithLimitless(p.config.LimitlessTypeProofs). 
WithMaxRangeConstraint(p.config.MaxRangeConstraint) // Add (optional) selector if c.Selector.HasValue() { diff --git a/pkg/ir/mir/optimiser.go b/pkg/ir/mir/optimiser.go index 719e36ba..5a56c16a 100644 --- a/pkg/ir/mir/optimiser.go +++ b/pkg/ir/mir/optimiser.go @@ -29,10 +29,6 @@ type OptimisationConfig struct { // ShiftNormalisation is an optimisation for inverse columns involving // shifts. ShiftNormalisation bool - // LimitlessTypeProofs enables the use of type proofs which exploit the - // limitless prover. Specifically, modules with a recursive structure are - // created specifically for the purpose of checking types. - LimitlessTypeProofs bool } // OPTIMISATION_LEVELS provides a set of precanned optimisation configurations. @@ -41,9 +37,9 @@ type OptimisationConfig struct { // always improve performance). var OPTIMISATION_LEVELS = []OptimisationConfig{ // Level 0 == nothing enabled - {0, 8, false, false}, + {0, 16, false}, // Level 1 == minimal optimisations applied. - {1, 16, true, true}, + {1, 16, true}, } // DEFAULT_OPTIMISATION_INDEX gives the index of the default optimisation level diff --git a/pkg/test/assembly_bench_test.go b/pkg/test/assembly_bench_test.go index a16a6073..eab75993 100644 --- a/pkg/test/assembly_bench_test.go +++ b/pkg/test/assembly_bench_test.go @@ -32,11 +32,11 @@ func Test_AsmBench_Gas(t *testing.T) { } func Test_AsmBench_Shf(t *testing.T) { - util.Check(t, false, "asm/bench/shf") + util.CheckWithFields(t, false, "asm/bench/shf", util.ASM_MAX_PADDING, field.BLS12_377) } func Test_AsmBench_Stp(t *testing.T) { - util.Check(t, false, "asm/bench/stp") + util.CheckWithFields(t, false, "asm/bench/stp", util.ASM_MAX_PADDING, field.BLS12_377) } func Test_AsmBench_Trm(t *testing.T) {
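
The patch above makes the recursive (limb/lookup based) bitwidth gadget the default and removes the horizontal byte-decomposition path and the `limitless` toggle entirely. As a rough illustration of the limb-splitting idea the retained gadget relies on, here is a minimal, self-contained Go sketch over plain uint64 values. The helpers `splitWidths` and `decomposeU64` are hypothetical stand-ins: the real `decompose` in bitwidth.go operates on field elements, the real gadget discharges limbs via reference tables and lookups rather than computing them directly, and the greedy byte-aligned split with a 16-bit cap is only an assumption based on the new `maxRangeConstraint: 16` default.

// NOTE: illustrative sketch only -- not part of the patch above.
package main

import "fmt"

// splitWidths divides a bitwidth into a byte-aligned low limb of at most
// maxRange bits plus a high remainder. Hypothetical helper: the actual
// limb split in bitwidth.go may choose widths differently.
func splitWidths(bitwidth, maxRange uint) (lo, hi uint) {
	if bitwidth <= maxRange {
		return bitwidth, 0
	}
	// Keep the low limb byte aligned, matching the loWidth%8 == 0
	// assumption checked by decompose() in the patch.
	lo = (maxRange / 8) * 8
	return lo, bitwidth - lo
}

// decomposeU64 splits a value into low and high limbs at loWidth bits.
// This is a uint64 analogue of the field-element decompose() helper.
func decomposeU64(val uint64, loWidth uint) (lo, hi uint64) {
	mask := (uint64(1) << loWidth) - 1
	return val & mask, val >> loWidth
}

func main() {
	// A u33 value split with the assumed 16-bit cap: the low limb is
	// range-checked directly, the 17-bit remainder handled recursively.
	loW, hiW := splitWidths(33, 16)
	lo, hi := decomposeU64(0x1_2345_6789, loW)
	fmt.Printf("limb widths u%d/u%d, limbs %#x / %#x\n", loW, hiW, lo, hi)
}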