Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 4 additions & 297 deletions pkg/ir/air/gadgets/bitwidth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/consensys/go-corset/pkg/util/collection/hash"
"github.com/consensys/go-corset/pkg/util/field"
"github.com/consensys/go-corset/pkg/util/source/sexp"
"github.com/consensys/go-corset/pkg/util/word"
)

// BitwidthGadget is a general-purpose mechanism for enforcing type constraints
Expand All @@ -41,19 +40,14 @@ type BitwidthGadget[F field.Element[F]] struct {
// translated into AIR range constraints, versus using a horizontal
// bitwidth gadget.
maxRangeConstraint uint
// Enables the use of type proofs which exploit the
// limitless prover. Specifically, modules with a recursive structure are
// created specifically for the purpose of checking types.
limitless bool
// Schema into which constraints are placed.
schema *air.SchemaBuilder[F]
}

// NewBitwidthGadget constructs a new bitwidth gadget.
func NewBitwidthGadget[F field.Element[F]](schema *air.SchemaBuilder[F]) *BitwidthGadget[F] {
return &BitwidthGadget[F]{
maxRangeConstraint: 8,
limitless: false,
maxRangeConstraint: 16,
schema: schema,
}
}
Expand All @@ -64,12 +58,6 @@ func (p *BitwidthGadget[F]) WithMaxRangeConstraint(width uint) *BitwidthGadget[F
return p
}

// WithLimitless enables or disables use of limitless type proofs.
func (p *BitwidthGadget[F]) WithLimitless(flag bool) *BitwidthGadget[F] {
p.limitless = flag
return p
}

// Constrain ensures all values in a given register fit within a given bitwidth.
func (p *BitwidthGadget[F]) Constrain(ref register.Ref, bitwidth uint) {
var (
Expand All @@ -92,12 +80,8 @@ func (p *BitwidthGadget[F]) Constrain(ref register.Ref, bitwidth uint) {
[]*term.RegisterAccess[F, air.Term[F]]{access}, []uint{bitwidth}))
// Done
return
case p.limitless:
p.applyRecursiveBitwidthGadget(ref, bitwidth)
default:
// NOTE: this should be deprecated once the limitless prover is well
// established.
p.applyHorizontalBitwidthGadget(ref, bitwidth)
p.applyRecursiveBitwidthGadget(ref, bitwidth)
}
}

Expand Down Expand Up @@ -139,32 +123,6 @@ func (p *BitwidthGadget[F]) applyBinaryGadget(ref register.Ref) {
air.NewVanishingConstraint(handle, module.Id(), util.None[int](), X_X_m1))
}

// ApplyHorizontalBitwidthGadget ensures all values in a given column fit within
// a given bitwidth. This is implemented using a *horizontal byte
// decomposition* which adds n columns and a vanishing constraint (where n*8 >=
// bitwidth).
func (p *BitwidthGadget[F]) applyHorizontalBitwidthGadget(ref register.Ref, bitwidth uint) {
	var (
		// Module in which the source register lives.
		module = p.schema.Module(ref.Module())
		// Source register being constrained.
		reg = module.Register(ref.Register())
		// Handle used to identify the vanishing constraint below.
		lookupHandle = fmt.Sprintf("%s:u%d", reg.Name, bitwidth)
	)
	// Allocate computed byte registers in the given module, and add required
	// range constraints.
	byteRegisters := allocateByteRegisters(reg.Name, bitwidth, module)
	// Build up the decomposition sum
	sum := buildDecompositionTerm[F](bitwidth, byteRegisters)
	// Construct X == (X:0 * 1) + ... + (X:n * 2^n)
	X := term.NewRegisterAccess[F, air.Term[F]](ref.Register(), reg.Width, 0)
	// Constraint requires X - sum == 0 on every row.
	eq := term.Subtract(X, sum)
	// Construct column name
	module.AddConstraint(
		air.NewVanishingConstraint(lookupHandle, module.Id(), util.None[int](), eq))
	// Add decomposition assignment which computes the byte columns from the
	// source column during trace expansion.
	module.AddAssignment(&byteDecomposition[F]{reg.Name, bitwidth, ref, byteRegisters})
}

// ApplyRecursiveBitwidthGadget ensures all values in a given column fit within
// a given bitwidth. This is implemented using a combination of reference tables
// and lookups. Specifically, if the width is below 16bits, then a static
Expand Down Expand Up @@ -357,117 +315,7 @@ func (p *typeDecomposition[F]) Lisp(schema sc.AnySchema[F]) sexp.SExp {
}

// ============================================================================
// Byte Decomposition Assignment
// ============================================================================

// byteDecomposition is part of a range constraint for wide columns (e.g. u32)
// implemented using a byte decomposition.  It records the source register
// being decomposed, along with the (computed) byte registers holding the
// decomposition itself.
type byteDecomposition[F field.Element[F]] struct {
	// Handle for identifying this assignment
	handle string
	// Width of decomposition (i.e. bitwidth of the source register).
	bitwidth uint
	// The source register being decomposed
	source register.Ref
	// Target registers holding the decomposition (little endian order).
	targets []register.Ref
}

// Compute determines the values of the columns defined by this assignment.
// Specifically, it fills out each byte column of the decomposition from the
// rows of the source register.
func (p *byteDecomposition[F]) Compute(tr trace.Trace[F], schema sc.AnySchema[F],
) ([]array.MutArray[F], error) {
	// Fetch raw data for the source register.
	inputs := assignment.ReadRegistersRef(tr, p.source)
	// Decompose every row into its constituent bytes.
	columns := byteDecompositionNativeFunction(uint(len(p.targets)), inputs, tr.Builder())
	// Done
	return columns, nil
}

// Bounds determines the well-definedness bounds for this assignment for both
// the negative (left) or positive (right) directions. For example, consider an
// expression such as "(shift X -1)". This is technically undefined for the
// first row of any trace and, by association, any constraint evaluating this
// expression on that first row is also undefined (and hence must pass).  A
// byte decomposition applies no shifts, hence the empty bound is returned.
func (p *byteDecomposition[F]) Bounds(_ sc.ModuleId) util.Bounds {
	return util.EMPTY_BOUND
}

// Consistent performs some simple checks that the given schema is consistent.
// This provides a double check of certain key properties, such as that
// registers used for assignments are large enough, etc.  Specifically, the
// widths of the target registers must sum to the width of the source.
func (p *byteDecomposition[F]) Consistent(schema sc.AnySchema[F]) []error {
	expected := schema.Register(p.source).Width
	// Sum widths of all target registers.
	var sum uint
	//
	for _, tgt := range p.targets {
		sum += schema.Module(tgt.Module()).Register(tgt.Register()).Width
	}
	// Check targets cover the source exactly.
	if sum != expected {
		return []error{
			fmt.Errorf("inconsistent byte decomposition (decomposed %d bits, but expected %d)", sum, expected),
		}
	}
	// Consistent, hence no errors.
	return nil
}

// RegistersExpanded identifies registers expanded by this assignment.  A byte
// decomposition expands no registers, hence nil is returned.
func (p *byteDecomposition[F]) RegistersExpanded() []register.Ref {
	return nil
}

// RegistersRead returns the set of columns that this assignment depends upon.
// That can include both input columns, as well as other computed columns.  In
// this case, only the source register (being decomposed) is read.
func (p *byteDecomposition[F]) RegistersRead() []register.Ref {
	return []register.Ref{p.source}
}

// RegistersWritten identifies registers assigned by this assignment, namely
// the byte registers holding the decomposition.
func (p *byteDecomposition[F]) RegistersWritten() []register.Ref {
	return p.targets
}

// Substitute any matching labelled constants within this assignment.  A byte
// decomposition contains no labelled constants, so this is a no-op.
func (p *byteDecomposition[F]) Substitute(mapping map[string]F) {
	// Nothing to do here.
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.  The result has the form
// "(decompose ((tgt1 u8) ...) source)".
func (p *byteDecomposition[F]) Lisp(schema sc.AnySchema[F]) sexp.SExp {
	targets := sexp.EmptyList()
	// Describe each target register as a (name type) pair.
	for _, tgt := range p.targets {
		var (
			mod  = schema.Module(tgt.Module())
			reg  = mod.Register(tgt.Register())
			name = sexp.NewSymbol(reg.QualifiedName(mod))
			kind = sexp.NewSymbol(fmt.Sprintf("u%d", reg.Width))
		)
		//
		targets.Append(sexp.NewList([]sexp.SExp{name, kind}))
	}
	// Describe the source register by its qualified name.
	srcModule := schema.Module(p.source.Module())
	source := sexp.NewSymbol(srcModule.Register(p.source.Register()).QualifiedName(srcModule))
	//
	return sexp.NewList([]sexp.SExp{
		sexp.NewSymbol("decompose"),
		targets,
		source,
	})
}

// ============================================================================
// Helpers (for recursive)
// Helpers
// ============================================================================

// Determine the split of limbs for the given bitwidth. For example, 33bits
Expand Down Expand Up @@ -564,7 +412,7 @@ func decompose[F field.Element[F]](loWidth uint, ith F) (F, F) {
)
// Sanity check assumption
if loWidth%8 != 0 {
panic("unreachable")
panic(fmt.Sprintf("unreachable (u%d)", loWidth))
}
//
if loByteWidth >= n {
Expand All @@ -580,144 +428,3 @@ func decompose[F field.Element[F]](loWidth uint, ith F) (F, F) {
//
return loFr, hiFr
}

// ============================================================================
// Helpers (for horizontal)
// ============================================================================

// Allocate n byte registers in the given module (where n*8 >= bitwidth), each
// of which is given a suitable range constraint.  When bitwidth is not an
// exact multiple of 8, the final register is constrained to fewer than 8 bits.
func allocateByteRegisters[F field.Element[F]](prefix string, bitwidth uint, module *air.ModuleBuilder[F],
) []register.Ref {
	var zero big.Int
	// Sanity check bitwidth
	if bitwidth == 0 {
		panic("zero byte decomposition encountered")
	}
	// Number of byte registers required, rounding up for any partial byte
	// (i.e. the asymmetric case).
	n := (bitwidth + 7) / 8
	// Allocate target register ids
	targets := make([]register.Ref, n)
	// Allocate byte registers, tracking how many bits remain at each step.
	for i, remaining := uint(0), bitwidth; i < n; i, remaining = i+1, remaining-8 {
		var (
			name = fmt.Sprintf("%s:%d", prefix, i)
			// Final register may hold fewer than 8 bits.
			width = min(8, remaining)
		)
		// Declare computed register holding the ith byte.
		rid := module.NewRegister(register.NewComputed(name, width, zero))
		targets[i] = register.NewRef(module.Id(), rid)
		// Add suitable range constraint
		access := term.RawRegisterAccess[F, air.Term[F]](rid, width, 0)
		//
		module.AddConstraint(
			air.NewRangeConstraint(name, module.Id(),
				[]*term.RegisterAccess[F, air.Term[F]]{access},
				[]uint{width}))
	}
	//
	return targets
}

// buildDecompositionTerm constructs the weighted sum
// (X:0 * 1) + (X:1 * 256) + ... + (X:n * 2^8n) over the given byte registers,
// which (by construction) should equal the original undecomposed value.
func buildDecompositionTerm[F field.Element[F]](bitwidth uint, byteRegisters []register.Ref) air.Term[F] {
	var (
		// Determine ranges required for the given bitwidth
		ranges = splitColumnRanges[F](bitwidth)
		// One summand per byte register.
		summands = make([]air.Term[F], len(byteRegisters))
		// Running coefficient, starting at 1 for the least significant byte.
		coeff F = field.One[F]()
	)
	// Construct each summand X:i * coeff
	for i, ref := range byteRegisters {
		access := term.NewRegisterAccess[F, air.Term[F]](ref.Register(), 8, 0)
		summands[i] = term.Product(access, term.Const[F, air.Term[F]](coeff))
		// Scale coefficient up for the next (more significant) byte.
		coeff = coeff.Mul(ranges[i])
	}
	// Combine summands into a single term.
	return term.Sum(summands...)
}

// splitColumnRanges determines the multiplier range for each byte column of
// an nbits decomposition: 256 for every full byte, plus a final range of 2^m
// when nbits is not an exact multiple of 8 (where m = nbits % 8).
func splitColumnRanges[F field.Element[F]](nbits uint) []F {
	var (
		fullBytes = nbits / 8
		rem       = nbits % 8
		ranges    []F
		// FIXME: following fails for very small fields like GF251!
		byteRange F = field.Uint64[F](256)
	)
	//
	if rem == 0 {
		ranges = make([]F, fullBytes)
	} else {
		// Most significant column has smaller range.
		ranges = make([]F, fullBytes+1)
		// Determine final range
		ranges[fullBytes] = field.TwoPowN[F](rem)
	}
	// All full-byte columns span 256 values.
	for i := range fullBytes {
		ranges[i] = byteRange
	}
	//
	return ranges
}

// byteDecompositionNativeFunction decomposes every row of the (single) source
// column into n byte columns, in little endian form.
//
// NOTE: the sanity check must run before sources[0] is accessed; previously
// the check came after the indexing, so an empty slice produced an
// uninformative index-out-of-range panic rather than the intended message.
func byteDecompositionNativeFunction[F field.Element[F]](n uint, sources []array.Array[F],
	builder array.Builder[F]) []array.MutArray[F] {
	// Sanity check exactly one source column given (before indexing it).
	if len(sources) != 1 {
		panic("too many source columns for byte decomposition")
	}
	//
	var (
		source  = sources[0]
		targets = make([]array.MutArray[F], n)
		height  = source.Len()
	)
	// Initialise columns
	for i := range n {
		// Construct a byte array for ith byte
		targets[i] = builder.NewArray(height, 8)
	}
	// Decompose each row of each column
	for i := range height {
		ith := decomposeIntoBytes(source.Get(i), n)
		for j := uint(0); j < n; j++ {
			targets[j].Set(i, ith[j])
		}
	}
	//
	return targets
}

// decomposeIntoBytes splits a given element into n bytes in little endian
// form. For example, decomposing 41b into 2 bytes gives [0x1b,0x04].
func decomposeIntoBytes[W word.Word[W]](val W, n uint) []W {
	var (
		// Big endian byte representation of val.
		be = val.Bytes()
		//
		count = uint(len(be))
		// Resulting little endian words (zero padded up to n).
		words = make([]W, n)
	)
	// Translate each available byte, reversing order to obtain little endian
	// form.  Any remaining words simply keep their zero value.
	for i := uint(0); i < min(n, count); i++ {
		var w W
		words[i] = w.SetUint64(uint64(be[count-i-1]))
	}
	// Done
	return words
}
Loading
Loading