diff --git a/circuit/circuit.go b/circuit/circuit.go index 4c8bc4a..047d727 100644 --- a/circuit/circuit.go +++ b/circuit/circuit.go @@ -11,6 +11,8 @@ import ( "fmt" "io" "math" + "sync" + "sync/atomic" "github.com/markkurossi/mpc/compiler/utils" "github.com/markkurossi/tabulate" @@ -122,6 +124,10 @@ type Circuit struct { Outputs IO Gates []Gate Stats Stats + + // garblePool holds reusable scratch for Garble, lazily created on first + // use and scoped to this circuit so it is collected with it. + garblePool atomic.Pointer[sync.Pool] } func (c *Circuit) String() string { diff --git a/circuit/garble.go b/circuit/garble.go index 4103d0d..e622915 100644 --- a/circuit/garble.go +++ b/circuit/garble.go @@ -12,6 +12,7 @@ import ( "encoding/binary" "fmt" "io" + "sync" "github.com/markkurossi/mpc/ot" ) @@ -155,11 +156,72 @@ func makeLabels(rand io.Reader, r ot.Label) (ot.Wire, error) { }, nil } -// Garbled contains garbled circuit information. +// Garbled holds garbled circuit information. When produced by Circuit.Garble, +// Wires and Gates point into reusable scratch; call Release once the garbled +// tables have been consumed (e.g. serialized) to return it for reuse. type Garbled struct { R ot.Label Wires []ot.Wire Gates [][]ot.Label + + scratch *garbledScratch + pool *sync.Pool +} + +// Release returns the scratch buffers to the circuit's pool. It is optional: +// skipping it just forgoes reuse. Idempotent; the Garbled must not be used +// afterwards. +func (g *Garbled) Release() { + if g == nil || g.pool == nil { + return + } + g.pool.Put(g.scratch) + g.scratch = nil + g.pool = nil + g.Wires = nil + g.Gates = nil +} + +// garbledScratch holds the heap buffers reused across Garble calls. slab is +// the backing array the per-gate table slices are carved from. +type garbledScratch struct { + wires []ot.Wire + slab []ot.Label + gates [][]ot.Label +} + +// garbleScratchPool returns this circuit's scratch pool, building it on first +// use. The pool is stored on the circuit so it lives and dies with it. +func (c *Circuit) garbleScratchPool() *sync.Pool { + if p := c.garblePool.Load(); p != nil { + return p + } + var slabSize int + for i := range c.Gates { + switch c.Gates[i].Op { + case AND: + slabSize += 2 + case OR: + slabSize += 3 + case INV: + slabSize += 1 + case XOR, XNOR: + // Free XOR: no garbled rows. + } + } + p := &sync.Pool{ + New: func() any { + return &garbledScratch{ + wires: make([]ot.Wire, c.NumWires), + slab: make([]ot.Label, slabSize), + gates: make([][]ot.Label, c.NumGates), + } + }, + } + if c.garblePool.CompareAndSwap(nil, p) { + return p + } + return c.garblePool.Load() } // Lambda returns the lambda value of the wire. @@ -181,62 +243,76 @@ func (g *Garbled) SetLambda(wire Wire, val uint) { g.Wires[wire] = w } -// Garble garbles the circuit. +// Garble garbles the circuit. The returned Garbled is backed by reusable +// scratch; call Release once its tables are consumed to return it for reuse. func (c *Circuit) Garble(rand io.Reader, key []byte) (*Garbled, error) { + pool := c.garbleScratchPool() + scratch := pool.Get().(*garbledScratch) + // Create R. r, err := ot.NewLabel(rand) if err != nil { + pool.Put(scratch) return nil, err } r.SetS(true) - garbled := make([][]ot.Label, c.NumGates) - alg, err := aes.NewCipher(key) if err != nil { + pool.Put(scratch) return nil, err } - // Wire labels. - wires := make([]ot.Wire, c.NumWires) + wires := scratch.wires + slab := scratch.slab + gates := scratch.gates - // Assing all input wires. + // Assign all input wires. for i := 0; i < c.Inputs.Size(); i++ { w, err := makeLabels(rand, r) if err != nil { + pool.Put(scratch) return nil, err } wires[i] = w } - // Garble gates. + // Each gate writes labels into a stack table; we copy into the slab. var data ot.LabelData var id uint32 + slabOff := 0 + var table [4]ot.Label for i := 0; i < len(c.Gates); i++ { gate := &c.Gates[i] - data, err := gate.garble(wires, alg, r, &id, &data) + start, count, err := gate.garbleInto(wires, alg, r, &id, &data, &table) if err != nil { + pool.Put(scratch) return nil, err } - garbled[i] = data + if count == 0 { + gates[i] = nil + continue + } + copy(slab[slabOff:slabOff+count], table[start:start+count]) + gates[i] = slab[slabOff : slabOff+count : slabOff+count] + slabOff += count } return &Garbled{ - R: r, - Wires: wires, - Gates: garbled, + R: r, + Wires: wires, + Gates: gates, + scratch: scratch, + pool: pool, }, nil } -// Garble garbles the gate and returns it labels. -func (g *Gate) garble(wires []ot.Wire, enc cipher.Block, r ot.Label, - idp *uint32, data *ot.LabelData) ([]ot.Label, error) { +// Writes the gate's output table into the caller's buffer; returns the slice [start, start+count). +func (g *Gate) garbleInto(wires []ot.Wire, enc cipher.Block, r ot.Label, + idp *uint32, data *ot.LabelData, table *[4]ot.Label) (start, count int, err error) { var a, b, c ot.Wire - var table [4]ot.Label - var start, count int - // Inputs. switch g.Op { case XOR, XNOR, AND, OR: @@ -247,7 +323,7 @@ func (g *Gate) garble(wires []ot.Wire, enc cipher.Block, r ot.Label, a = wires[g.Input0] default: - return nil, fmt.Errorf("invalid gate type %s", g.Op) + return 0, 0, fmt.Errorf("invalid gate type %s", g.Op) } // Output. @@ -398,9 +474,9 @@ func (g *Gate) garble(wires []ot.Wire, enc cipher.Block, r ot.Label, count = 1 default: - return nil, fmt.Errorf("invalid operand %s", g.Op) + return 0, 0, fmt.Errorf("invalid operand %s", g.Op) } wires[g.Output] = c - return table[start : start+count], nil + return start, count, nil } diff --git a/circuit/garble_bench_test.go b/circuit/garble_bench_test.go new file mode 100644 index 0000000..f1e93cf --- /dev/null +++ b/circuit/garble_bench_test.go @@ -0,0 +1,64 @@ +// +// Copyright (c) 2019-2026 Markku Rossi +// +// All rights reserved. +// + +package circuit + +import ( + "crypto/rand" + "fmt" + "strings" + "testing" +) + +// buildANDChain returns a circuit of n chained AND gates over two input bits: +// gate i computes AND(wire i, wire i+1) -> wire i+2. It exercises the per-gate +// garbled-table allocation path without depending on the compiler. +func buildANDChain(n int) *Circuit { + var b strings.Builder + fmt.Fprintf(&b, "%d %d\n", n, n+2) + b.WriteString("2 1 1\n") + b.WriteString("1 1\n\n") + for i := 0; i < n; i++ { + fmt.Fprintf(&b, "2 1 %d %d %d AND\n", i, i+1, i+2) + } + c, err := ParseBristol(strings.NewReader(b.String())) + if err != nil { + panic(err) + } + return c +} + +var benchKey = []byte("0123456789abcdef") + +// BenchmarkGarble measures a single Garble with no reuse. Runnable on any +// version, so it is the apples-to-apples baseline for the slab change. +func BenchmarkGarble(b *testing.B) { + c := buildANDChain(10000) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + g, err := c.Garble(rand.Reader, benchKey) + if err != nil { + b.Fatal(err) + } + _ = g + } +} + +// BenchmarkGarbleReuse measures Garble with Release, the intended steady-state +// usage where scratch is recycled across calls. +func BenchmarkGarbleReuse(b *testing.B) { + c := buildANDChain(10000) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + g, err := c.Garble(rand.Reader, benchKey) + if err != nil { + b.Fatal(err) + } + g.Release() + } +}