diff --git a/ecc/bls12-377/fr/fft/domain.go b/ecc/bls12-377/fr/fft/domain.go index 1d44a171ac..3e6c9ad8c7 100644 --- a/ecc/bls12-377/fr/fft/domain.go +++ b/ecc/bls12-377/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls12-377/fr/fft/domain_test.go b/ecc/bls12-377/fr/fft/domain_test.go index 7049120e6a..8b539cd505 100644 --- a/ecc/bls12-377/fr/fft/domain_test.go +++ b/ecc/bls12-377/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls12-377/fr/fft/options.go b/ecc/bls12-377/fr/fft/options.go index a562b0ae72..e4ed53672b 100644 --- a/ecc/bls12-377/fr/fft/options.go +++ b/ecc/bls12-377/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bls12-377/fr/mimc/mimc.go b/ecc/bls12-377/fr/mimc/mimc.go index 4c4aec6c6c..7b72977738 100644 --- a/ecc/bls12-377/fr/mimc/mimc.go +++ b/ecc/bls12-377/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_BLS12_377, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls12-377/fr/mimc/mimc_test.go b/ecc/bls12-377/fr/mimc/mimc_test.go index 6f076df566..d5972e58b4 100644 --- a/ecc/bls12-377/fr/mimc/mimc_test.go +++ b/ecc/bls12-377/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/ecc/bls12-381/fr/fft/domain.go b/ecc/bls12-381/fr/fft/domain.go index 03f1fbf491..65e73d8e05 100644 --- a/ecc/bls12-381/fr/fft/domain.go +++ b/ecc/bls12-381/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls12-381/fr/fft/domain_test.go b/ecc/bls12-381/fr/fft/domain_test.go index 7049120e6a..9f21b00e23 100644 --- a/ecc/bls12-381/fr/fft/domain_test.go +++ b/ecc/bls12-381/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls12-381/fr/fft/options.go b/ecc/bls12-381/fr/fft/options.go index e705081cda..a775c3f480 100644 --- a/ecc/bls12-381/fr/fft/options.go +++ b/ecc/bls12-381/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bls12-381/fr/mimc/mimc.go b/ecc/bls12-381/fr/mimc/mimc.go index bedaf2d9e9..d6912e0fb8 100644 --- a/ecc/bls12-381/fr/mimc/mimc.go +++ b/ecc/bls12-381/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_BLS12_381, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls12-381/fr/mimc/mimc_test.go b/ecc/bls12-381/fr/mimc/mimc_test.go index 6f497c36e8..808f0cb2d4 100644 --- a/ecc/bls12-381/fr/mimc/mimc_test.go +++ b/ecc/bls12-381/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/ecc/bls24-315/fr/fft/domain.go b/ecc/bls24-315/fr/fft/domain.go index 1b8860e7d8..fd9f1c8d29 100644 --- a/ecc/bls24-315/fr/fft/domain.go +++ b/ecc/bls24-315/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls24-315/fr/fft/domain_test.go b/ecc/bls24-315/fr/fft/domain_test.go index 7049120e6a..35e1bce786 100644 --- a/ecc/bls24-315/fr/fft/domain_test.go +++ b/ecc/bls24-315/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls24-315/fr/fft/options.go b/ecc/bls24-315/fr/fft/options.go index 8538f4cdaa..c621412df8 100644 --- a/ecc/bls24-315/fr/fft/options.go +++ b/ecc/bls24-315/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bls24-315/fr/mimc/mimc.go b/ecc/bls24-315/fr/mimc/mimc.go index 35b38a584b..e8fa5d9a78 100644 --- a/ecc/bls24-315/fr/mimc/mimc.go +++ b/ecc/bls24-315/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_BLS24_315, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls24-315/fr/mimc/mimc_test.go b/ecc/bls24-315/fr/mimc/mimc_test.go index 2f901b07e8..d6f36855f4 100644 --- a/ecc/bls24-315/fr/mimc/mimc_test.go +++ b/ecc/bls24-315/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/ecc/bls24-317/fr/fft/domain.go b/ecc/bls24-317/fr/fft/domain.go index c3745da086..a2fee91212 100644 --- a/ecc/bls24-317/fr/fft/domain.go +++ b/ecc/bls24-317/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls24-317/fr/fft/domain_test.go b/ecc/bls24-317/fr/fft/domain_test.go index 7049120e6a..346befd5af 100644 --- a/ecc/bls24-317/fr/fft/domain_test.go +++ b/ecc/bls24-317/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls24-317/fr/fft/options.go b/ecc/bls24-317/fr/fft/options.go index 9a73619358..6f629f63f9 100644 --- a/ecc/bls24-317/fr/fft/options.go +++ b/ecc/bls24-317/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bls24-317/fr/mimc/mimc.go b/ecc/bls24-317/fr/mimc/mimc.go index 1807b6726b..b2329216b6 100644 --- a/ecc/bls24-317/fr/mimc/mimc.go +++ b/ecc/bls24-317/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_BLS24_317, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls24-317/fr/mimc/mimc_test.go b/ecc/bls24-317/fr/mimc/mimc_test.go index cc838f60c6..ed41c62bc3 100644 --- a/ecc/bls24-317/fr/mimc/mimc_test.go +++ b/ecc/bls24-317/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/ecc/bn254/fr/fft/domain.go b/ecc/bn254/fr/fft/domain.go index 5c9d3e545c..fea6e366b3 100644 --- a/ecc/bn254/fr/fft/domain.go +++ b/ecc/bn254/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bn254/fr/fft/domain_test.go b/ecc/bn254/fr/fft/domain_test.go index 7049120e6a..f274817c05 100644 --- a/ecc/bn254/fr/fft/domain_test.go +++ b/ecc/bn254/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bn254/fr/fft/options.go b/ecc/bn254/fr/fft/options.go index 87e5bae69a..54ff79010f 100644 --- a/ecc/bn254/fr/fft/options.go +++ b/ecc/bn254/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bn254/fr/mimc/mimc.go b/ecc/bn254/fr/mimc/mimc.go index 91f260e3cd..73e9f741d7 100644 --- a/ecc/bn254/fr/mimc/mimc.go +++ b/ecc/bn254/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_BN254, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bn254/fr/mimc/mimc_test.go b/ecc/bn254/fr/mimc/mimc_test.go index aa9e21c105..c4737fda4b 100644 --- a/ecc/bn254/fr/mimc/mimc_test.go +++ b/ecc/bn254/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/ecc/bw6-633/fr/fft/domain.go b/ecc/bw6-633/fr/fft/domain.go index ab6cc41a90..6f83b4ea28 100644 --- a/ecc/bw6-633/fr/fft/domain.go +++ b/ecc/bw6-633/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bw6-633/fr/fft/domain_test.go b/ecc/bw6-633/fr/fft/domain_test.go index 7049120e6a..8cd438b873 100644 --- a/ecc/bw6-633/fr/fft/domain_test.go +++ b/ecc/bw6-633/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bw6-633/fr/fft/options.go b/ecc/bw6-633/fr/fft/options.go index 3b9f572b47..2f2ee7e39a 100644 --- a/ecc/bw6-633/fr/fft/options.go +++ b/ecc/bw6-633/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bw6-633/fr/mimc/mimc.go b/ecc/bw6-633/fr/mimc/mimc.go index 1245152b83..7f3cbe465f 100644 --- a/ecc/bw6-633/fr/mimc/mimc.go +++ b/ecc/bw6-633/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_BW6_633, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bw6-633/fr/mimc/mimc_test.go b/ecc/bw6-633/fr/mimc/mimc_test.go index 105bf45c29..13694fe588 100644 --- a/ecc/bw6-633/fr/mimc/mimc_test.go +++ b/ecc/bw6-633/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/ecc/bw6-761/fr/fft/domain.go b/ecc/bw6-761/fr/fft/domain.go index 079b9ada80..bc2569e6ab 100644 --- a/ecc/bw6-761/fr/fft/domain.go +++ b/ecc/bw6-761/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bw6-761/fr/fft/domain_test.go b/ecc/bw6-761/fr/fft/domain_test.go index 7049120e6a..6752274e55 100644 --- a/ecc/bw6-761/fr/fft/domain_test.go +++ b/ecc/bw6-761/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bw6-761/fr/fft/options.go b/ecc/bw6-761/fr/fft/options.go index 276471bd1c..6bd158b940 100644 --- a/ecc/bw6-761/fr/fft/options.go +++ b/ecc/bw6-761/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bw6-761/fr/mimc/mimc.go b/ecc/bw6-761/fr/mimc/mimc.go index d341995043..33770aec78 100644 --- a/ecc/bw6-761/fr/mimc/mimc.go +++ b/ecc/bw6-761/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_BW6_761, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bw6-761/fr/mimc/mimc_test.go b/ecc/bw6-761/fr/mimc/mimc_test.go index afb5d5e6fe..2d44034b02 100644 --- a/ecc/bw6-761/fr/mimc/mimc_test.go +++ b/ecc/bw6-761/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/ecc/grumpkin/fr/mimc/mimc.go b/ecc/grumpkin/fr/mimc/mimc.go index a1fe21bbad..a47790cab5 100644 --- a/ecc/grumpkin/fr/mimc/mimc.go +++ b/ecc/grumpkin/fr/mimc/mimc.go @@ -17,6 +17,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_GRUMPKIN, func() stdhash.Hash { return NewMiMC() @@ -53,8 +78,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +108,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +167,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -144,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/grumpkin/fr/mimc/mimc_test.go b/ecc/grumpkin/fr/mimc/mimc_test.go index 88f802a09c..b26236c1b5 100644 --- a/ecc/grumpkin/fr/mimc/mimc_test.go +++ b/ecc/grumpkin/fr/mimc/mimc_test.go @@ -116,3 +116,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} diff --git a/field/babybear/extensions/e2_test.go b/field/babybear/extensions/e2_test.go index 68178cd56e..e95d627137 100644 --- a/field/babybear/extensions/e2_test.go +++ b/field/babybear/extensions/e2_test.go @@ -513,3 +513,36 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + elmt[i] = uint32(w) + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/babybear/fft/domain.go b/field/babybear/fft/domain.go index 49a6ab8afb..cdb69ae2be 100644 --- a/field/babybear/fft/domain.go +++ b/field/babybear/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/field/babybear" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() babybear.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen babybear.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/field/babybear/fft/domain_test.go b/field/babybear/fft/domain_test.go index 7049120e6a..b9e8f1a387 100644 --- a/field/babybear/fft/domain_test.go +++ b/field/babybear/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/field/babybear" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := babybear.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/babybear/fft/options.go b/field/babybear/fft/options.go index d6fcb9c156..a586629a72 100644 --- a/field/babybear/fft/options.go +++ b/field/babybear/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *babybear.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/field/eisenstein/eisenstein.go b/field/eisenstein/eisenstein.go index e809154133..033ed902b7 100644 --- a/field/eisenstein/eisenstein.go +++ b/field/eisenstein/eisenstein.go @@ -9,6 +9,27 @@ type ComplexNumber struct { A0, A1 *big.Int } +// ────────────────────────────────────────────────────────────────────────────── +// helpers – hex-lattice geometry & symmetric rounding +// ────────────────────────────────────────────────────────────────────────────── + +// six axial directions of the hexagonal lattice +var neighbours = [][2]int64{ + {1, 0}, {0, 1}, {-1, 1}, {-1, 0}, {0, -1}, {1, -1}, +} + +// roundNearest returns ⌊(z + d/2) / d⌋ for *any* sign of z, d>0 +func roundNearest(z, d *big.Int) *big.Int { + half := new(big.Int).Rsh(d, 1) // d / 2 + if z.Sign() >= 0 { + return new(big.Int).Div(new(big.Int).Add(z, half), d) + } + tmp := new(big.Int).Neg(z) + tmp.Add(tmp, half) + tmp.Div(tmp, d) + return tmp.Neg(tmp) +} + func (z *ComplexNumber) init() { if z.A0 == nil { z.A0 = new(big.Int) @@ -124,19 +145,55 @@ func (z *ComplexNumber) Norm() *big.Int { return norm } -// QuoRem sets z to the quotient of x and y, r to the remainder, and returns z and r. +// QuoRem sets z to the Euclidean quotient of x / y, r to the remainder, +// and guarantees ‖r‖ < ‖y‖ (true Euclidean division in ℤ[ω]). func (z *ComplexNumber) QuoRem(x, y, r *ComplexNumber) (*ComplexNumber, *ComplexNumber) { - norm := y.Norm() - if norm.Cmp(big.NewInt(0)) == 0 { + + norm := y.Norm() // > 0 (Eisenstein norm is always non-neg) + if norm.Sign() == 0 { panic("division by zero") } - z.Conjugate(y) - z.Mul(x, z) - z.A0.Div(z.A0, norm) - z.A1.Div(z.A1, norm) + + // num = x * ȳ (ȳ computed in a fresh variable → y unchanged) + var yConj, num ComplexNumber + yConj.Conjugate(y) + num.Mul(x, &yConj) + + // first guess by *symmetric* rounding of both coordinates + q0 := roundNearest(num.A0, norm) + q1 := roundNearest(num.A1, norm) + z.A0, z.A1 = q0, q1 + + // r = x – q*y r.Mul(y, z) r.Sub(x, r) + // If Euclidean inequality already holds we're done. + // Otherwise walk ≤2 unit steps in the hex lattice until N(r) < N(y). + if r.Norm().Cmp(norm) >= 0 { + bestQ0, bestQ1 := new(big.Int).Set(z.A0), new(big.Int).Set(z.A1) + bestR := new(ComplexNumber).Set(r) + bestN2 := bestR.Norm() + + for _, dir := range neighbours { + candQ0 := new(big.Int).Add(z.A0, big.NewInt(dir[0])) + candQ1 := new(big.Int).Add(z.A1, big.NewInt(dir[1])) + var candQ ComplexNumber + candQ.A0, candQ.A1 = candQ0, candQ1 + + var candR ComplexNumber + candR.Mul(y, &candQ) + candR.Sub(x, &candR) + + if candR.Norm().Cmp(bestN2) < 0 { + bestQ0, bestQ1 = candQ0, candQ1 + bestR.Set(&candR) + bestN2 = bestR.Norm() + } + } + z.A0, z.A1 = bestQ0, bestQ1 + r.Set(bestR) // update remainder and retry; Euclidean property ⇒ ≤ 2 loops + } return z, r } diff --git a/field/eisenstein/eisenstein_test.go b/field/eisenstein/eisenstein_test.go index 6aff795f92..0d46204448 100644 --- a/field/eisenstein/eisenstein_test.go +++ b/field/eisenstein/eisenstein_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "math/big" "testing" + "time" "github.com/leanovate/gopter" "github.com/leanovate/gopter/prop" @@ -240,6 +241,66 @@ func TestEisensteinHalfGCD(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestEisensteinQuoRem(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + genE := GenComplexNumber(boundSize) + + properties.Property("QuoRem should be correct", prop.ForAll( + func(a, b *ComplexNumber) bool { + var z, rem ComplexNumber + z.QuoRem(a, b, &rem) + var res ComplexNumber + res.Mul(b, &z) + res.Add(&res, &rem) + return res.Equal(a) + }, + genE, + genE, + )) + + properties.Property("QuoRem remainder should be smaller than divisor", prop.ForAll( + func(a, b *ComplexNumber) bool { + var z, rem ComplexNumber + z.QuoRem(a, b, &rem) + return rem.Norm().Cmp(b.Norm()) == -1 + }, + genE, + genE, + )) +} + +func TestRegressionHalfGCD1483(t *testing.T) { + // This test is a regression test for issue #1483 in gnark + a0, _ := new(big.Int).SetString("64502973549206556628585045361533709077", 10) + a1, _ := new(big.Int).SetString("-303414439467246543595250775667605759171", 10) + c0, _ := new(big.Int).SetString("-432420386565659656852420866390673177323", 10) + c1, _ := new(big.Int).SetString("238911465918039986966665730306072050094", 10) + a := ComplexNumber{A0: a0, A1: a1} + c := ComplexNumber{A0: c0, A1: c1} + + ticker := time.NewTimer(time.Second * 3) + doneCh := make(chan struct{}) + go func() { + HalfGCD(&a, &c) + close(doneCh) + }() + + select { + case <-ticker.C: + t.Error("HalfGCD took too long to compute") + case <-doneCh: + // Test passed + } +} + // GenNumber generates a random integer func GenNumber(boundSize int64) gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult { diff --git a/field/generator/internal/templates/extensions/e2_test.go.tmpl b/field/generator/internal/templates/extensions/e2_test.go.tmpl index 502d21bd4d..5be2f5a94d 100644 --- a/field/generator/internal/templates/extensions/e2_test.go.tmpl +++ b/field/generator/internal/templates/extensions/e2_test.go.tmpl @@ -512,3 +512,43 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + + + + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + {{- if or (eq .FF "babybear") (eq .FF "koalabear")}} + elmt[i] = uint32(w) + {{- else}} + elmt[i] = uint64(w) + {{- end}} + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/generator/internal/templates/fft/domain.go.tmpl b/field/generator/internal/templates/fft/domain.go.tmpl index 7e1d2e0827..88d1fe93f8 100644 --- a/field/generator/internal/templates/fft/domain.go.tmpl +++ b/field/generator/internal/templates/fft/domain.go.tmpl @@ -4,6 +4,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "errors" "encoding/binary" @@ -54,17 +55,122 @@ func GeneratorFullMultiplicativeGroup() {{ .FF }}.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. + +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen {{ .FF }}.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) domain.FrMultiplicativeGen = GeneratorFullMultiplicativeGroup() - if opt.shift != nil{ + if opt.shift != nil { domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) diff --git a/field/generator/internal/templates/fft/options.go.tmpl b/field/generator/internal/templates/fft/options.go.tmpl index 87fe89ff4a..9213fb8a69 100644 --- a/field/generator/internal/templates/fft/options.go.tmpl +++ b/field/generator/internal/templates/fft/options.go.tmpl @@ -57,6 +57,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *{{ .FF }}.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -75,11 +76,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/field/generator/internal/templates/fft/tests/domain.go.tmpl b/field/generator/internal/templates/fft/tests/domain.go.tmpl index ff775ad362..71d9b4cc0e 100644 --- a/field/generator/internal/templates/fft/tests/domain.go.tmpl +++ b/field/generator/internal/templates/fft/tests/domain.go.tmpl @@ -1,8 +1,13 @@ import ( + "bytes" "reflect" + "runtime" "testing" - "bytes" + + "{{ .FieldPackagePath }}" + "github.com/stretchr/testify/require" + ) func TestDomainSerialization(t *testing.T) { @@ -27,4 +32,99 @@ func TestDomainSerialization(t *testing.T) { if !reflect.DeepEqual(domain, &reconstructed) { t.Fatal("Domain.SetBytes(Bytes()) failed") } -} \ No newline at end of file +} + + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := {{ .FF }}.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/goldilocks/extensions/e2_test.go b/field/goldilocks/extensions/e2_test.go index c50e416a7f..abb4dc21fb 100644 --- a/field/goldilocks/extensions/e2_test.go +++ b/field/goldilocks/extensions/e2_test.go @@ -513,3 +513,36 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + elmt[i] = uint64(w) + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/goldilocks/fft/domain.go b/field/goldilocks/fft/domain.go index 80e9902fe2..4a35e1adca 100644 --- a/field/goldilocks/fft/domain.go +++ b/field/goldilocks/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/field/goldilocks" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() goldilocks.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen goldilocks.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/field/goldilocks/fft/domain_test.go b/field/goldilocks/fft/domain_test.go index 7049120e6a..7761c83cee 100644 --- a/field/goldilocks/fft/domain_test.go +++ b/field/goldilocks/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/field/goldilocks" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := goldilocks.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/goldilocks/fft/options.go b/field/goldilocks/fft/options.go index db171a3724..1789d5a3c8 100644 --- a/field/goldilocks/fft/options.go +++ b/field/goldilocks/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *goldilocks.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/field/koalabear/extensions/e2_test.go b/field/koalabear/extensions/e2_test.go index d838cb5bb6..acfc6c853e 100644 --- a/field/koalabear/extensions/e2_test.go +++ b/field/koalabear/extensions/e2_test.go @@ -513,3 +513,36 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + elmt[i] = uint32(w) + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/koalabear/fft/domain.go b/field/koalabear/fft/domain.go index 6037f1d6c6..dcaf90e8ae 100644 --- a/field/koalabear/fft/domain.go +++ b/field/koalabear/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/field/koalabear" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() koalabear.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen koalabear.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/field/koalabear/fft/domain_test.go b/field/koalabear/fft/domain_test.go index 7049120e6a..44fcc4a629 100644 --- a/field/koalabear/fft/domain_test.go +++ b/field/koalabear/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/field/koalabear" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := koalabear.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/koalabear/fft/options.go b/field/koalabear/fft/options.go index 7c4ba3ffb9..b3d8a95c15 100644 --- a/field/koalabear/fft/options.go +++ b/field/koalabear/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *koalabear.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl index b1130e2094..8cdd8f2ef3 100644 --- a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl @@ -10,6 +10,31 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements +type FieldHasher interface { + hash.StateStorer + + // WriteElement adds a field element to the running hash. + WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. + SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} + func init() { hash.RegisterHash(hash.MIMC_{{ .EnumID }}, func() stdhash.Hash { return NewMiMC() @@ -62,8 +87,19 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -81,12 +117,20 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -133,6 +177,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. +func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form @@ -153,6 +202,30 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} {{ if eq .Name "bls12-377" }} // plain execution of a mimc run diff --git a/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl b/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl index 23c4484374..4a70661ed6 100644 --- a/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl @@ -84,7 +84,6 @@ func TestSetState(t *testing.T) { storedStates := make([][]byte, len(randInputs)) - for i := range randInputs { storedStates[i] = h1.State() @@ -110,3 +109,37 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + +} \ No newline at end of file