diff --git a/frontend/api.go b/frontend/api.go index 40a6fdfaba..57696f1290 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -99,7 +99,7 @@ type API interface { // // If the absolute difference between the variables i1 and i2 is known, then // it is more efficient to use the bounded methods in package - // [github.com/consensys/gnark/std/math/bits]. + // [https://github.com/Consensys/gnark/blob/master/std/math/cmp]. Cmp(i1, i2 Variable) Variable // --------------------------------------------------------------------------------------------- @@ -121,7 +121,7 @@ type API interface { // // If the absolute difference between the variables b and bound is known, then // it is more efficient to use the bounded methods in package - // [github.com/consensys/gnark/std/math/bits]. + // [https://github.com/Consensys/gnark/blob/master/std/math/cmp]. AssertIsLessOrEqual(v Variable, bound Variable) // Println behaves like fmt.Println but accepts cd.Variable as parameter diff --git a/std/math/cmp/bounded.go b/std/math/cmp/bounded.go index 14d5a433a5..7b60220f8e 100644 --- a/std/math/cmp/bounded.go +++ b/std/math/cmp/bounded.go @@ -2,10 +2,11 @@ package cmp import ( "fmt" + "math/big" + "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/bits" - "math/big" ) func init() { @@ -151,6 +152,10 @@ func (bc BoundedComparator) AssertIsLess(a, b frontend.Variable) { } // IsLess returns 1 if a < b, and returns 0 if a >= b. +// When |a - b| >= 2^absDiffUpp.BitLen(), a panic is occurred, +// then the method has no return value, and a proof can not be generated. +// It is recommended to use the IsLess method to get a valid return value +// in https://github.com/Consensys/gnark/blob/master/std/math/cmp/generic.go func (bc BoundedComparator) IsLess(a, b frontend.Variable) frontend.Variable { res, err := bc.api.Compiler().NewHint(isLessOutputHint, 1, a, b) if err != nil { @@ -164,6 +169,10 @@ func (bc BoundedComparator) IsLess(a, b frontend.Variable) frontend.Variable { } // IsLessEq returns 1 if a <= b, and returns 0 if a > b. +// When |a - b| > 2^absDiffUpp.BitLen(), a panic is occurred, +// then the method has no return value, and a proof can not be generated. +// It is recommended to use the IsLessOrEqual method to get a valid return value +// in https://github.com/Consensys/gnark/blob/master/std/math/cmp/generic.go func (bc BoundedComparator) IsLessEq(a, b frontend.Variable) frontend.Variable { // a <= b <==> a < b + 1 return bc.IsLess(a, bc.api.Add(b, 1)) diff --git a/std/math/cmp/bounded_test.go b/std/math/cmp/bounded_test.go index fcdfc0677d..12cee7bd53 100644 --- a/std/math/cmp/bounded_test.go +++ b/std/math/cmp/bounded_test.go @@ -1,11 +1,14 @@ package cmp_test import ( + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/cmp" "github.com/consensys/gnark/test" - "math/big" - "testing" ) func TestAssertIsLessEq(t *testing.T) { @@ -143,3 +146,72 @@ func (c *minCircuit) Define(api frontend.API) error { return nil } + +type boundedComparatorCircuit struct { + A frontend.Variable + + WantIsLess int + WantIsLessEq int + Bound int +} + +func (c *boundedComparatorCircuit) Define(api frontend.API) error { + comparator := cmp.NewBoundedComparator(api, big.NewInt(int64(c.Bound)), true) + if c.WantIsLess == 1 { + comparator.AssertIsLess(c.A, c.Bound) + } + if c.WantIsLessEq == 1 { + comparator.AssertIsLessEq(c.A, c.Bound) + } + + api.AssertIsEqual(c.WantIsLess, comparator.IsLess(c.A, c.Bound)) + api.AssertIsEqual(c.WantIsLessEq, comparator.IsLessEq(c.A, c.Bound)) + + return nil +} + +type boundedComparatorTestCase struct { + A int + + WantIsLess int + WantIsLessEq int + Bound int + + expectedSuccess bool +} + +func TestBoundedComparator(t *testing.T) { + assert := test.NewAssert(t) + + var testCases []boundedComparatorTestCase + for bound := 2; bound <= 15; bound++ { + c := 1 << (big.NewInt(int64(bound)).BitLen()) + for i := 0; i <= bound+5; i++ { + testCase := boundedComparatorTestCase{ + A: i, Bound: bound, WantIsLess: 1, WantIsLessEq: 1, expectedSuccess: true} + if i >= bound { + testCase.WantIsLess = 0 + if i > bound { + testCase.WantIsLessEq = 0 + } + } + if i-bound >= c { + testCase.expectedSuccess = false + } + testCases = append(testCases, testCase) + } + } + + for _, tc := range testCases { + assert.Run(func(assert *test.Assert) { + circuit := &boundedComparatorCircuit{Bound: tc.Bound, WantIsLess: tc.WantIsLess, WantIsLessEq: tc.WantIsLessEq} + assignment := &boundedComparatorCircuit{A: tc.A} + err := test.IsSolved(circuit, assignment, ecc.BN254.ScalarField()) + if tc.expectedSuccess { + assert.NoError(err) + } else { + assert.Error(err) + } + }, fmt.Sprintf("bound=%d a=%d", tc.Bound, tc.A)) + } +}