Skip to content

Commit ce0186e

Browse files
authored
feat: add MulNoReduce and Sum methods in field emulation (#1072)
* feat: implement mulnoreduce * test: mulnoreduce test * docs: add method doc * feat: add AddMany * refactor: rename AddMany to Sum * feat: if only single input then return as is * test: non-native sum
1 parent 781de03 commit ce0186e

File tree

3 files changed

+141
-0
lines changed

3 files changed

+141
-0
lines changed

std/math/emulated/element_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package emulated
33
import (
44
"crypto/rand"
55
"fmt"
6+
"math"
67
"math/big"
78
"reflect"
89
"testing"
@@ -970,3 +971,91 @@ func testSqrt[T FieldParams](t *testing.T) {
970971
assert.ProverSucceeded(&SqrtCircuit[T]{}, &SqrtCircuit[T]{X: ValueOf[T](X), Expected: ValueOf[T](exp)}, test.WithCurves(testCurve), test.NoSerializationChecks(), test.WithBackends(backend.GROTH16, backend.PLONK))
971972
}, testName[T]())
972973
}
974+
975+
type MulNoReduceCircuit[T FieldParams] struct {
976+
A, B, C Element[T]
977+
expectedOverflow uint
978+
expectedNbLimbs int
979+
}
980+
981+
func (c *MulNoReduceCircuit[T]) Define(api frontend.API) error {
982+
f, err := NewField[T](api)
983+
if err != nil {
984+
return err
985+
}
986+
res := f.MulNoReduce(&c.A, &c.B)
987+
f.AssertIsEqual(res, &c.C)
988+
if res.overflow != c.expectedOverflow {
989+
return fmt.Errorf("unexpected overflow: got %d, expected %d", res.overflow, c.expectedOverflow)
990+
}
991+
if len(res.Limbs) != c.expectedNbLimbs {
992+
return fmt.Errorf("unexpected number of limbs: got %d, expected %d", len(res.Limbs), c.expectedNbLimbs)
993+
}
994+
return nil
995+
}
996+
997+
func TestMulNoReduce(t *testing.T) {
998+
testMulNoReduce[Goldilocks](t)
999+
testMulNoReduce[Secp256k1Fp](t)
1000+
testMulNoReduce[BN254Fp](t)
1001+
}
1002+
1003+
func testMulNoReduce[T FieldParams](t *testing.T) {
1004+
var fp T
1005+
assert := test.NewAssert(t)
1006+
assert.Run(func(assert *test.Assert) {
1007+
A, _ := rand.Int(rand.Reader, fp.Modulus())
1008+
B, _ := rand.Int(rand.Reader, fp.Modulus())
1009+
C := new(big.Int).Mul(A, B)
1010+
C.Mod(C, fp.Modulus())
1011+
expectedLimbs := 2*fp.NbLimbs() - 1
1012+
expectedOverFlow := math.Ceil(math.Log2(float64(expectedLimbs+1))) + float64(fp.BitsPerLimb())
1013+
circuit := &MulNoReduceCircuit[T]{expectedOverflow: uint(expectedOverFlow), expectedNbLimbs: int(expectedLimbs)}
1014+
assignment := &MulNoReduceCircuit[T]{A: ValueOf[T](A), B: ValueOf[T](B), C: ValueOf[T](C)}
1015+
assert.CheckCircuit(circuit, test.WithValidAssignment(assignment))
1016+
}, testName[T]())
1017+
}
1018+
1019+
type SumCircuit[T FieldParams] struct {
1020+
Inputs []Element[T]
1021+
Expected Element[T]
1022+
}
1023+
1024+
func (c *SumCircuit[T]) Define(api frontend.API) error {
1025+
f, err := NewField[T](api)
1026+
if err != nil {
1027+
return err
1028+
}
1029+
inputs := make([]*Element[T], len(c.Inputs))
1030+
for i := range inputs {
1031+
inputs[i] = &c.Inputs[i]
1032+
}
1033+
res := f.Sum(inputs...)
1034+
f.AssertIsEqual(res, &c.Expected)
1035+
return nil
1036+
}
1037+
1038+
func TestSum(t *testing.T) {
1039+
testSum[Goldilocks](t)
1040+
testSum[Secp256k1Fp](t)
1041+
testSum[BN254Fp](t)
1042+
}
1043+
1044+
func testSum[T FieldParams](t *testing.T) {
1045+
var fp T
1046+
nbInputs := 1024
1047+
assert := test.NewAssert(t)
1048+
assert.Run(func(assert *test.Assert) {
1049+
circuit := &SumCircuit[T]{Inputs: make([]Element[T], nbInputs)}
1050+
inputs := make([]Element[T], nbInputs)
1051+
result := new(big.Int)
1052+
for i := range inputs {
1053+
val, _ := rand.Int(rand.Reader, fp.Modulus())
1054+
result.Add(result, val)
1055+
inputs[i] = ValueOf[T](val)
1056+
}
1057+
result.Mod(result, fp.Modulus())
1058+
witness := &SumCircuit[T]{Inputs: inputs, Expected: ValueOf[T](result)}
1059+
assert.CheckCircuit(circuit, test.WithValidAssignment(witness))
1060+
}, testName[T]())
1061+
}

std/math/emulated/field_mul.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,3 +443,23 @@ func (f *Field[T]) mulPreCond(a, b *Element[T]) (nextOverflow uint, err error) {
443443
}
444444
return
445445
}
446+
447+
// MulNoReduce computes a*b and returns the result without reducing it modulo
448+
// the field order. The number of limbs of the returned element depends on the
449+
// number of limbs of the inputs.
450+
func (f *Field[T]) MulNoReduce(a, b *Element[T]) *Element[T] {
451+
return f.reduceAndOp(f.mulNoReduce, f.mulPreCond, a, b)
452+
}
453+
454+
func (f *Field[T]) mulNoReduce(a, b *Element[T], nextoverflow uint) *Element[T] {
455+
resLimbs := make([]frontend.Variable, nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))
456+
for i := range resLimbs {
457+
resLimbs[i] = 0
458+
}
459+
for i := range a.Limbs {
460+
for j := range b.Limbs {
461+
resLimbs[i+j] = f.api.MulAcc(resLimbs[i+j], a.Limbs[i], b.Limbs[j])
462+
}
463+
}
464+
return f.newInternalElement(resLimbs, nextoverflow)
465+
}

std/math/emulated/field_ops.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package emulated
33
import (
44
"errors"
55
"fmt"
6+
"math/bits"
67

78
"github.com/consensys/gnark/frontend"
89
"github.com/consensys/gnark/std/selector"
@@ -132,6 +133,37 @@ func (f *Field[T]) add(a, b *Element[T], nextOverflow uint) *Element[T] {
132133
return f.newInternalElement(limbs, nextOverflow)
133134
}
134135

136+
func (f *Field[T]) Sum(inputs ...*Element[T]) *Element[T] {
137+
if len(inputs) == 0 {
138+
return f.Zero()
139+
}
140+
if len(inputs) == 1 {
141+
return inputs[0]
142+
}
143+
overflow := uint(0)
144+
nbLimbs := 0
145+
for i := range inputs {
146+
f.enforceWidthConditional(inputs[i])
147+
if inputs[i].overflow > overflow {
148+
overflow = inputs[i].overflow
149+
}
150+
if len(inputs[i].Limbs) > nbLimbs {
151+
nbLimbs = len(inputs[i].Limbs)
152+
}
153+
}
154+
addOverflow := bits.Len(uint(len(inputs)))
155+
limbs := make([]frontend.Variable, nbLimbs)
156+
for i := range limbs {
157+
limbs[i] = 0
158+
}
159+
for i := range inputs {
160+
for j := range inputs[i].Limbs {
161+
limbs[j] = f.api.Add(limbs[j], inputs[i].Limbs[j])
162+
}
163+
}
164+
return f.newInternalElement(limbs, overflow+uint(addOverflow))
165+
}
166+
135167
// Reduce reduces a modulo the field order and returns it.
136168
func (f *Field[T]) Reduce(a *Element[T]) *Element[T] {
137169
f.enforceWidthConditional(a)

0 commit comments

Comments
 (0)