diff --git a/go/dpagg/count.go b/go/dpagg/count.go index 0126bfda..3751dc67 100644 --- a/go/dpagg/count.go +++ b/go/dpagg/count.go @@ -305,13 +305,23 @@ func (c *Count) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode Count from bytes") } + n := noise.ToNoise(enc.NoiseKind) + if n == nil { + return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind) + } + if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil { + return fmt.Errorf("GobDecode Count: %v", err) + } + if enc.LInfSensitivity <= 0 { + return fmt.Errorf("GobDecode Count: lInfSensitivity must be positive, got %d", enc.LInfSensitivity) + } *c = Count{ epsilon: enc.Epsilon, delta: enc.Delta, l0Sensitivity: enc.L0Sensitivity, lInfSensitivity: enc.LInfSensitivity, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: n, count: enc.Count, state: defaultState, } diff --git a/go/dpagg/gobdecode_validation_test.go b/go/dpagg/gobdecode_validation_test.go new file mode 100644 index 00000000..a69d21cc --- /dev/null +++ b/go/dpagg/gobdecode_validation_test.go @@ -0,0 +1,153 @@ +// +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package dpagg + +import ( + "testing" + + "github.com/google/differential-privacy/go/v4/noise" +) + +func TestCountGobDecodeRejectsInvalidNoiseKind(t *testing.T) { + enc := encodableCount{ + Epsilon: 1.0, + Delta: 0, + L0Sensitivity: 1, + LInfSensitivity: 1, + NoiseKind: 99, // invalid + Count: 0, + } + data, err := encode(&enc) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var c Count + if err := c.GobDecode(data); err == nil { + t.Error("GobDecode: expected error for invalid NoiseKind, got nil") + } +} + +func TestCountGobDecodeRejectsInvalidL0Sensitivity(t *testing.T) { + enc := encodableCount{ + Epsilon: 1.0, + Delta: 0, + L0Sensitivity: -1, // invalid + LInfSensitivity: 1, + NoiseKind: noise.LaplaceNoise, + Count: 0, + } + data, err := encode(&enc) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var c Count + if err := c.GobDecode(data); err == nil { + t.Error("GobDecode: expected error for negative l0Sensitivity, got nil") + } +} + +func TestBoundedSumFloat64GobDecodeRejectsInvalidNoiseKind(t *testing.T) { + enc := encodableBoundedSumFloat64{ + Epsilon: 1.0, + Delta: 0, + L0Sensitivity: 1, + LInfSensitivity: 1.0, + Lower: 0, + Upper: 10, + NoiseKind: 99, // invalid + Sum: 0, + } + data, err := encode(&enc) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var bs BoundedSumFloat64 + if err := bs.GobDecode(data); err == nil { + t.Error("GobDecode: expected error for invalid NoiseKind, got nil") + } +} + +func TestBoundedSumFloat64GobDecodeRejectsInvalidLInfSensitivity(t *testing.T) { + enc := encodableBoundedSumFloat64{ + Epsilon: 1.0, + Delta: 0, + L0Sensitivity: 1, + LInfSensitivity: -1.0, // invalid: must be positive + Lower: 0, + Upper: 10, + NoiseKind: noise.LaplaceNoise, + Sum: 0, + } + data, err := encode(&enc) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var bs BoundedSumFloat64 + if err := bs.GobDecode(data); err == nil { + t.Error("GobDecode: expected error for negative lInfSensitivity, got nil") + } +} + +func TestBoundedQuantilesGobDecodeRejectsZeroBranchingFactor(t *testing.T) { + enc := encodableBoundedQuantiles{ + Epsilon: 1.0, + Delta: 1e-5, + L0Sensitivity: 1, + LInfSensitivity: 1.0, + TreeHeight: 4, + BranchingFactor: 0, // invalid: causes div-by-zero + Lower: 0, + Upper: 10, + NumLeaves: 0, + LeftmostLeafIndex: 0, + NoiseKind: noise.GaussianNoise, + QuantileTree: make(map[int]int64), + } + data, err := encode(&enc) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var bq BoundedQuantiles + if err := bq.GobDecode(data); err == nil { + t.Error("GobDecode: expected error for zero branchingFactor, got nil") + } +} + +func TestBoundedQuantilesGobDecodeRejectsInvalidNoiseKind(t *testing.T) { + enc := encodableBoundedQuantiles{ + Epsilon: 1.0, + Delta: 1e-5, + L0Sensitivity: 1, + LInfSensitivity: 1.0, + TreeHeight: 4, + BranchingFactor: 16, + Lower: 0, + Upper: 10, + NumLeaves: 0, + LeftmostLeafIndex: 0, + NoiseKind: 99, // invalid + QuantileTree: make(map[int]int64), + } + data, err := encode(&enc) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + var bq BoundedQuantiles + if err := bq.GobDecode(data); err == nil { + t.Error("GobDecode: expected error for invalid NoiseKind, got nil") + } +} diff --git a/go/dpagg/quantiles.go b/go/dpagg/quantiles.go index b25dc6c9..75809559 100644 --- a/go/dpagg/quantiles.go +++ b/go/dpagg/quantiles.go @@ -445,6 +445,25 @@ func (bq *BoundedQuantiles) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedQuantiles from bytes") } + n := noise.ToNoise(enc.NoiseKind) + if n == nil { + return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind) + } + if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil { + return fmt.Errorf("GobDecode BoundedQuantiles: %v", err) + } + if err := checks.CheckLInfSensitivity(enc.LInfSensitivity); err != nil { + return fmt.Errorf("GobDecode BoundedQuantiles: %v", err) + } + if err := checks.CheckBoundsFloat64IgnoreOverflows(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("GobDecode BoundedQuantiles: %v", err) + } + if err := checks.CheckTreeHeight(enc.TreeHeight); err != nil { + return fmt.Errorf("GobDecode BoundedQuantiles: %v", err) + } + if enc.BranchingFactor < 2 { + return fmt.Errorf("GobDecode BoundedQuantiles: branchingFactor must be at least 2, got %d", enc.BranchingFactor) + } *bq = BoundedQuantiles{ epsilon: enc.Epsilon, delta: enc.Delta, @@ -455,7 +474,7 @@ func (bq *BoundedQuantiles) GobDecode(data []byte) error { lower: enc.Lower, upper: enc.Upper, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: n, numLeaves: enc.NumLeaves, leftmostLeafIndex: enc.LeftmostLeafIndex, tree: enc.QuantileTree, diff --git a/go/dpagg/select_partition.go b/go/dpagg/select_partition.go index 88e85c0a..e7d666a6 100644 --- a/go/dpagg/select_partition.go +++ b/go/dpagg/select_partition.go @@ -439,6 +439,9 @@ func (s *PreAggSelectPartition) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode PreAggSelectPartition from bytes") } + if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil { + return fmt.Errorf("GobDecode PreAggSelectPartition: %v", err) + } *s = PreAggSelectPartition{ epsilon: enc.Epsilon, delta: enc.Delta, diff --git a/go/dpagg/sum.go b/go/dpagg/sum.go index f655e453..84f1ebb7 100644 --- a/go/dpagg/sum.go +++ b/go/dpagg/sum.go @@ -376,6 +376,16 @@ func (bs *BoundedSumInt64) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedSumInt64 from bytes") } + n := noise.ToNoise(enc.NoiseKind) + if n == nil { + return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind) + } + if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil { + return fmt.Errorf("GobDecode BoundedSumInt64: %v", err) + } + if err := checks.CheckBoundsInt64IgnoreOverflows(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("GobDecode BoundedSumInt64: %v", err) + } *bs = BoundedSumInt64{ epsilon: enc.Epsilon, delta: enc.Delta, @@ -384,7 +394,7 @@ func (bs *BoundedSumInt64) GobDecode(data []byte) error { lower: enc.Lower, upper: enc.Upper, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: n, sum: enc.Sum, state: defaultState, } @@ -693,6 +703,19 @@ func (bs *BoundedSumFloat64) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedSumFloat64 from bytes") } + n := noise.ToNoise(enc.NoiseKind) + if n == nil { + return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind) + } + if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil { + return fmt.Errorf("GobDecode BoundedSumFloat64: %v", err) + } + if err := checks.CheckLInfSensitivity(enc.LInfSensitivity); err != nil { + return fmt.Errorf("GobDecode BoundedSumFloat64: %v", err) + } + if err := checks.CheckBoundsFloat64IgnoreOverflows(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("GobDecode BoundedSumFloat64: %v", err) + } *bs = BoundedSumFloat64{ epsilon: enc.Epsilon, delta: enc.Delta, @@ -701,7 +724,7 @@ func (bs *BoundedSumFloat64) GobDecode(data []byte) error { lower: enc.Lower, upper: enc.Upper, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: n, sum: enc.Sum, state: defaultState, }