diff --git a/go/dpagg/count.go b/go/dpagg/count.go index 0126bfda..674a1512 100644 --- a/go/dpagg/count.go +++ b/go/dpagg/count.go @@ -305,13 +305,17 @@ func (c *Count) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode Count from bytes") } + nObj, err := validateDecodedAggregation(enc.Epsilon, enc.Delta, float64(enc.LInfSensitivity), enc.L0Sensitivity, enc.NoiseKind) + if err != nil { + return fmt.Errorf("couldn't decode Count: %v", err) + } *c = Count{ epsilon: enc.Epsilon, delta: enc.Delta, l0Sensitivity: enc.L0Sensitivity, lInfSensitivity: enc.LInfSensitivity, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: nObj, count: enc.Count, state: defaultState, } diff --git a/go/dpagg/gobdecode_checks.go b/go/dpagg/gobdecode_checks.go new file mode 100644 index 00000000..34f7b736 --- /dev/null +++ b/go/dpagg/gobdecode_checks.go @@ -0,0 +1,41 @@ +package dpagg + +import ( + "fmt" + + "github.com/google/differential-privacy/go/v4/checks" + "github.com/google/differential-privacy/go/v4/noise" +) + +// validateDecodedNoise rejects NoiseKind values that don't correspond to a +// concrete Noise implementation; without this gate, Result() panics on a nil +// receiver after deserializing a tampered payload. +func validateDecodedNoise(kind noise.Kind) (noise.Noise, error) { + n := noise.ToNoise(kind) + if n == nil { + return nil, fmt.Errorf("unsupported NoiseKind value %d", kind) + } + return n, nil +} + +// validateDecodedAggregation mirrors the precondition checks that New*() +// constructors already perform, applied to fields recovered from gob bytes. +func validateDecodedAggregation(epsilon, delta, lInfSensitivity float64, l0Sensitivity int64, kind noise.Kind) (noise.Noise, error) { + n, err := validateDecodedNoise(kind) + if err != nil { + return nil, err + } + if err := checks.CheckEpsilonVeryStrict(epsilon); err != nil { + return nil, err + } + if err := checks.CheckDelta(delta); err != nil { + return nil, err + } + if err := checks.CheckL0Sensitivity(l0Sensitivity); err != nil { + return nil, err + } + if err := checks.CheckLInfSensitivity(lInfSensitivity); err != nil { + return nil, err + } + return n, nil +} diff --git a/go/dpagg/mean.go b/go/dpagg/mean.go index 292b69ea..d67b2b7a 100644 --- a/go/dpagg/mean.go +++ b/go/dpagg/mean.go @@ -379,6 +379,9 @@ func (bm *BoundedMean) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedMean from bytes") } + if err := checks.CheckBoundsFloat64(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("couldn't decode BoundedMean: %v", err) + } *bm = BoundedMean{ lower: enc.Lower, upper: enc.Upper, diff --git a/go/dpagg/quantiles.go b/go/dpagg/quantiles.go index b25dc6c9..501793a6 100644 --- a/go/dpagg/quantiles.go +++ b/go/dpagg/quantiles.go @@ -445,6 +445,19 @@ func (bq *BoundedQuantiles) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedQuantiles from bytes") } + nObj, err := validateDecodedAggregation(enc.Epsilon, enc.Delta, enc.LInfSensitivity, enc.L0Sensitivity, enc.NoiseKind) + if err != nil { + return fmt.Errorf("couldn't decode BoundedQuantiles: %v", err) + } + if err := checks.CheckBoundsFloat64(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("couldn't decode BoundedQuantiles: %v", err) + } + if err := checks.CheckTreeHeight(enc.TreeHeight); err != nil { + return fmt.Errorf("couldn't decode BoundedQuantiles: %v", err) + } + if err := checks.CheckBranchingFactor(enc.BranchingFactor); err != nil { + return fmt.Errorf("couldn't decode BoundedQuantiles: %v", err) + } *bq = BoundedQuantiles{ epsilon: enc.Epsilon, delta: enc.Delta, @@ -455,7 +468,7 @@ func (bq *BoundedQuantiles) GobDecode(data []byte) error { lower: enc.Lower, upper: enc.Upper, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: nObj, numLeaves: enc.NumLeaves, leftmostLeafIndex: enc.LeftmostLeafIndex, tree: enc.QuantileTree, diff --git a/go/dpagg/select_partition.go b/go/dpagg/select_partition.go index 88e85c0a..8ef2dc64 100644 --- a/go/dpagg/select_partition.go +++ b/go/dpagg/select_partition.go @@ -439,6 +439,15 @@ func (s *PreAggSelectPartition) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode PreAggSelectPartition from bytes") } + if err := checks.CheckEpsilonVeryStrict(enc.Epsilon); err != nil { + return fmt.Errorf("couldn't decode PreAggSelectPartition: %v", err) + } + if err := checks.CheckDelta(enc.Delta); err != nil { + return fmt.Errorf("couldn't decode PreAggSelectPartition: %v", err) + } + if err := checks.CheckL0Sensitivity(enc.L0Sensitivity); err != nil { + return fmt.Errorf("couldn't decode PreAggSelectPartition: %v", err) + } *s = PreAggSelectPartition{ epsilon: enc.Epsilon, delta: enc.Delta, diff --git a/go/dpagg/sum.go b/go/dpagg/sum.go index f655e453..50672b3f 100644 --- a/go/dpagg/sum.go +++ b/go/dpagg/sum.go @@ -376,6 +376,13 @@ func (bs *BoundedSumInt64) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedSumInt64 from bytes") } + nObj, err := validateDecodedAggregation(enc.Epsilon, enc.Delta, float64(enc.LInfSensitivity), enc.L0Sensitivity, enc.NoiseKind) + if err != nil { + return fmt.Errorf("couldn't decode BoundedSumInt64: %v", err) + } + if err := checks.CheckBoundsInt64(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("couldn't decode BoundedSumInt64: %v", err) + } *bs = BoundedSumInt64{ epsilon: enc.Epsilon, delta: enc.Delta, @@ -384,7 +391,7 @@ func (bs *BoundedSumInt64) GobDecode(data []byte) error { lower: enc.Lower, upper: enc.Upper, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: nObj, sum: enc.Sum, state: defaultState, } @@ -693,6 +700,13 @@ func (bs *BoundedSumFloat64) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedSumFloat64 from bytes") } + nObj2, err := validateDecodedAggregation(enc.Epsilon, enc.Delta, enc.LInfSensitivity, enc.L0Sensitivity, enc.NoiseKind) + if err != nil { + return fmt.Errorf("couldn't decode BoundedSumFloat64: %v", err) + } + if err := checks.CheckBoundsFloat64(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("couldn't decode BoundedSumFloat64: %v", err) + } *bs = BoundedSumFloat64{ epsilon: enc.Epsilon, delta: enc.Delta, @@ -701,7 +715,7 @@ func (bs *BoundedSumFloat64) GobDecode(data []byte) error { lower: enc.Lower, upper: enc.Upper, noiseKind: enc.NoiseKind, - Noise: noise.ToNoise(enc.NoiseKind), + Noise: nObj2, sum: enc.Sum, state: defaultState, } diff --git a/go/dpagg/variance.go b/go/dpagg/variance.go index 3a6352db..0bb6a75d 100644 --- a/go/dpagg/variance.go +++ b/go/dpagg/variance.go @@ -350,6 +350,9 @@ func (bv *BoundedVariance) GobDecode(data []byte) error { if err != nil { return fmt.Errorf("couldn't decode BoundedVariance from bytes") } + if err := checks.CheckBoundsFloat64(enc.Lower, enc.Upper); err != nil { + return fmt.Errorf("couldn't decode BoundedVariance: %v", err) + } *bv = BoundedVariance{ lower: enc.Lower, upper: enc.Upper,