Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion go/dpagg/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,23 @@ func (c *Count) GobDecode(data []byte) error {
if err != nil {
return fmt.Errorf("couldn't decode Count from bytes")
}
n := noise.ToNoise(enc.NoiseKind)
if n == nil {
return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind)
}
if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil {
return fmt.Errorf("GobDecode Count: %v", err)
}
if enc.LInfSensitivity <= 0 {
return fmt.Errorf("GobDecode Count: lInfSensitivity must be positive, got %d", enc.LInfSensitivity)
}
*c = Count{
epsilon: enc.Epsilon,
delta: enc.Delta,
l0Sensitivity: enc.L0Sensitivity,
lInfSensitivity: enc.LInfSensitivity,
noiseKind: enc.NoiseKind,
Noise: noise.ToNoise(enc.NoiseKind),
Noise: n,
count: enc.Count,
state: defaultState,
}
Expand Down
153 changes: 153 additions & 0 deletions go/dpagg/gobdecode_validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

package dpagg

import (
"testing"

"github.com/google/differential-privacy/go/v4/noise"
)

func TestCountGobDecodeRejectsInvalidNoiseKind(t *testing.T) {
enc := encodableCount{
Epsilon: 1.0,
Delta: 0,
L0Sensitivity: 1,
LInfSensitivity: 1,
NoiseKind: 99, // invalid
Count: 0,
}
data, err := encode(&enc)
if err != nil {
t.Fatalf("encode failed: %v", err)
}
var c Count
if err := c.GobDecode(data); err == nil {
t.Error("GobDecode: expected error for invalid NoiseKind, got nil")
}
}

func TestCountGobDecodeRejectsInvalidL0Sensitivity(t *testing.T) {
enc := encodableCount{
Epsilon: 1.0,
Delta: 0,
L0Sensitivity: -1, // invalid
LInfSensitivity: 1,
NoiseKind: noise.LaplaceNoise,
Count: 0,
}
data, err := encode(&enc)
if err != nil {
t.Fatalf("encode failed: %v", err)
}
var c Count
if err := c.GobDecode(data); err == nil {
t.Error("GobDecode: expected error for negative l0Sensitivity, got nil")
}
}

func TestBoundedSumFloat64GobDecodeRejectsInvalidNoiseKind(t *testing.T) {
enc := encodableBoundedSumFloat64{
Epsilon: 1.0,
Delta: 0,
L0Sensitivity: 1,
LInfSensitivity: 1.0,
Lower: 0,
Upper: 10,
NoiseKind: 99, // invalid
Sum: 0,
}
data, err := encode(&enc)
if err != nil {
t.Fatalf("encode failed: %v", err)
}
var bs BoundedSumFloat64
if err := bs.GobDecode(data); err == nil {
t.Error("GobDecode: expected error for invalid NoiseKind, got nil")
}
}

func TestBoundedSumFloat64GobDecodeRejectsInvalidLInfSensitivity(t *testing.T) {
enc := encodableBoundedSumFloat64{
Epsilon: 1.0,
Delta: 0,
L0Sensitivity: 1,
LInfSensitivity: -1.0, // invalid: must be positive
Lower: 0,
Upper: 10,
NoiseKind: noise.LaplaceNoise,
Sum: 0,
}
data, err := encode(&enc)
if err != nil {
t.Fatalf("encode failed: %v", err)
}
var bs BoundedSumFloat64
if err := bs.GobDecode(data); err == nil {
t.Error("GobDecode: expected error for negative lInfSensitivity, got nil")
}
}

func TestBoundedQuantilesGobDecodeRejectsZeroBranchingFactor(t *testing.T) {
enc := encodableBoundedQuantiles{
Epsilon: 1.0,
Delta: 1e-5,
L0Sensitivity: 1,
LInfSensitivity: 1.0,
TreeHeight: 4,
BranchingFactor: 0, // invalid: causes div-by-zero
Lower: 0,
Upper: 10,
NumLeaves: 0,
LeftmostLeafIndex: 0,
NoiseKind: noise.GaussianNoise,
QuantileTree: make(map[int]int64),
}
data, err := encode(&enc)
if err != nil {
t.Fatalf("encode failed: %v", err)
}
var bq BoundedQuantiles
if err := bq.GobDecode(data); err == nil {
t.Error("GobDecode: expected error for zero branchingFactor, got nil")
}
}

func TestBoundedQuantilesGobDecodeRejectsInvalidNoiseKind(t *testing.T) {
enc := encodableBoundedQuantiles{
Epsilon: 1.0,
Delta: 1e-5,
L0Sensitivity: 1,
LInfSensitivity: 1.0,
TreeHeight: 4,
BranchingFactor: 16,
Lower: 0,
Upper: 10,
NumLeaves: 0,
LeftmostLeafIndex: 0,
NoiseKind: 99, // invalid
QuantileTree: make(map[int]int64),
}
data, err := encode(&enc)
if err != nil {
t.Fatalf("encode failed: %v", err)
}
var bq BoundedQuantiles
if err := bq.GobDecode(data); err == nil {
t.Error("GobDecode: expected error for invalid NoiseKind, got nil")
}
}
21 changes: 20 additions & 1 deletion go/dpagg/quantiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,25 @@ func (bq *BoundedQuantiles) GobDecode(data []byte) error {
if err != nil {
return fmt.Errorf("couldn't decode BoundedQuantiles from bytes")
}
n := noise.ToNoise(enc.NoiseKind)
if n == nil {
return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind)
}
if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil {
return fmt.Errorf("GobDecode BoundedQuantiles: %v", err)
}
if err := checks.CheckLInfSensitivity(enc.LInfSensitivity); err != nil {
return fmt.Errorf("GobDecode BoundedQuantiles: %v", err)
}
if err := checks.CheckBoundsFloat64IgnoreOverflows(enc.Lower, enc.Upper); err != nil {
return fmt.Errorf("GobDecode BoundedQuantiles: %v", err)
}
if err := checks.CheckTreeHeight(enc.TreeHeight); err != nil {
return fmt.Errorf("GobDecode BoundedQuantiles: %v", err)
}
if enc.BranchingFactor < 2 {
return fmt.Errorf("GobDecode BoundedQuantiles: branchingFactor must be at least 2, got %d", enc.BranchingFactor)
}
*bq = BoundedQuantiles{
epsilon: enc.Epsilon,
delta: enc.Delta,
Expand All @@ -455,7 +474,7 @@ func (bq *BoundedQuantiles) GobDecode(data []byte) error {
lower: enc.Lower,
upper: enc.Upper,
noiseKind: enc.NoiseKind,
Noise: noise.ToNoise(enc.NoiseKind),
Noise: n,
numLeaves: enc.NumLeaves,
leftmostLeafIndex: enc.LeftmostLeafIndex,
tree: enc.QuantileTree,
Expand Down
3 changes: 3 additions & 0 deletions go/dpagg/select_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,9 @@ func (s *PreAggSelectPartition) GobDecode(data []byte) error {
if err != nil {
return fmt.Errorf("couldn't decode PreAggSelectPartition from bytes")
}
if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil {
return fmt.Errorf("GobDecode PreAggSelectPartition: %v", err)
}
*s = PreAggSelectPartition{
epsilon: enc.Epsilon,
delta: enc.Delta,
Expand Down
27 changes: 25 additions & 2 deletions go/dpagg/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,16 @@ func (bs *BoundedSumInt64) GobDecode(data []byte) error {
if err != nil {
return fmt.Errorf("couldn't decode BoundedSumInt64 from bytes")
}
n := noise.ToNoise(enc.NoiseKind)
if n == nil {
return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind)
}
if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil {
return fmt.Errorf("GobDecode BoundedSumInt64: %v", err)
}
if err := checks.CheckBoundsInt64IgnoreOverflows(enc.Lower, enc.Upper); err != nil {
return fmt.Errorf("GobDecode BoundedSumInt64: %v", err)
}
*bs = BoundedSumInt64{
epsilon: enc.Epsilon,
delta: enc.Delta,
Expand All @@ -384,7 +394,7 @@ func (bs *BoundedSumInt64) GobDecode(data []byte) error {
lower: enc.Lower,
upper: enc.Upper,
noiseKind: enc.NoiseKind,
Noise: noise.ToNoise(enc.NoiseKind),
Noise: n,
sum: enc.Sum,
state: defaultState,
}
Expand Down Expand Up @@ -693,6 +703,19 @@ func (bs *BoundedSumFloat64) GobDecode(data []byte) error {
if err != nil {
return fmt.Errorf("couldn't decode BoundedSumFloat64 from bytes")
}
n := noise.ToNoise(enc.NoiseKind)
if n == nil {
return fmt.Errorf("GobDecode: invalid NoiseKind %d", enc.NoiseKind)
}
if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil {
return fmt.Errorf("GobDecode BoundedSumFloat64: %v", err)
}
if err := checks.CheckLInfSensitivity(enc.LInfSensitivity); err != nil {
return fmt.Errorf("GobDecode BoundedSumFloat64: %v", err)
}
if err := checks.CheckBoundsFloat64IgnoreOverflows(enc.Lower, enc.Upper); err != nil {
return fmt.Errorf("GobDecode BoundedSumFloat64: %v", err)
}
*bs = BoundedSumFloat64{
epsilon: enc.Epsilon,
delta: enc.Delta,
Expand All @@ -701,7 +724,7 @@ func (bs *BoundedSumFloat64) GobDecode(data []byte) error {
lower: enc.Lower,
upper: enc.Upper,
noiseKind: enc.NoiseKind,
Noise: noise.ToNoise(enc.NoiseKind),
Noise: n,
sum: enc.Sum,
state: defaultState,
}
Expand Down