diff --git a/tests/IceRpc.Ice.Codec.Tests/DictionaryDecodingTests.cs b/tests/IceRpc.Ice.Codec.Tests/DictionaryDecodingTests.cs index 171693c61..80ffdf71e 100644 --- a/tests/IceRpc.Ice.Codec.Tests/DictionaryDecodingTests.cs +++ b/tests/IceRpc.Ice.Codec.Tests/DictionaryDecodingTests.cs @@ -1,6 +1,7 @@ // Copyright (c) ZeroC, Inc. using NUnit.Framework; +using System.Runtime.CompilerServices; using ZeroC.Tests.Common; namespace IceRpc.Ice.Codec.Tests; @@ -34,4 +35,62 @@ public void Decode_dictionary() Assert.That(decoded, Is.EqualTo(expected)); Assert.That(decoder.Consumed, Is.EqualTo(buffer.WrittenMemory.Length)); } + + [TestCase(10)] + [TestCase(50)] + [TestCase(100)] + public void Decode_dictionary_exceeds_max_collection_allocation(int count) + { + // Arrange + var buffer = new MemoryBufferWriter(new byte[count * (Unsafe.SizeOf() + Unsafe.SizeOf()) + 256]); + var encoder = new IceEncoder(buffer); + var dict = Enumerable.Range(0, count).ToDictionary(k => k, v => (long)v); + encoder.EncodeDictionary( + dict, + (ref IceEncoder encoder, int key) => encoder.EncodeInt(key), + (ref IceEncoder encoder, long value) => encoder.EncodeLong(value)); + + int allocationLimit = (count - 1) * (Unsafe.SizeOf() + Unsafe.SizeOf()); + + // Act/Assert + Assert.That( + () => + { + var sut = new IceDecoder(buffer.WrittenMemory, maxCollectionAllocation: allocationLimit); + _ = sut.DecodeDictionary( + count => new Dictionary(count), + (ref IceDecoder decoder) => decoder.DecodeInt(), + (ref IceDecoder decoder) => decoder.DecodeLong()); + }, + Throws.InstanceOf()); + } + + [TestCase(10)] + [TestCase(50)] + [TestCase(100)] + public void Decode_dictionary_within_max_collection_allocation(int count) + { + // Arrange + var buffer = new MemoryBufferWriter(new byte[count * (Unsafe.SizeOf() + Unsafe.SizeOf()) + 256]); + var encoder = new IceEncoder(buffer); + var dict = Enumerable.Range(0, count).ToDictionary(k => k, v => (long)v); + encoder.EncodeDictionary( + dict, + (ref IceEncoder encoder, int key) => encoder.EncodeInt(key), + (ref IceEncoder encoder, long value) => encoder.EncodeLong(value)); + + int allocationLimit = count * (Unsafe.SizeOf() + Unsafe.SizeOf()); + + // Act/Assert + Assert.That( + () => + { + var sut = new IceDecoder(buffer.WrittenMemory, maxCollectionAllocation: allocationLimit); + _ = sut.DecodeDictionary( + count => new Dictionary(count), + (ref IceDecoder decoder) => decoder.DecodeInt(), + (ref IceDecoder decoder) => decoder.DecodeLong()); + }, + Throws.Nothing); + } } diff --git a/tests/IceRpc.Ice.Codec.Tests/SequenceDecodingTests.cs b/tests/IceRpc.Ice.Codec.Tests/SequenceDecodingTests.cs index c9db93009..143bb24f1 100644 --- a/tests/IceRpc.Ice.Codec.Tests/SequenceDecodingTests.cs +++ b/tests/IceRpc.Ice.Codec.Tests/SequenceDecodingTests.cs @@ -1,6 +1,7 @@ // Copyright (c) ZeroC, Inc. using NUnit.Framework; +using System.Runtime.CompilerServices; using ZeroC.Tests.Common; namespace IceRpc.Ice.Codec.Tests; @@ -79,4 +80,52 @@ public void Decode_sequence_with_element_action() Assert.That(decoded, Is.EqualTo(expected)); Assert.That(checkedValues, Is.EqualTo(expected)); } + + [TestCase(10)] + [TestCase(50)] + [TestCase(100)] + public void Decode_sequence_exceeds_max_collection_allocation(int count) + { + // Arrange + var buffer = new MemoryBufferWriter(new byte[count * Unsafe.SizeOf() + 256]); + var encoder = new IceEncoder(buffer); + encoder.EncodeSequence( + Enumerable.Range(0, count), + (ref IceEncoder encoder, int value) => encoder.EncodeInt(value)); + + int allocationLimit = (count - 1) * Unsafe.SizeOf(); + + // Act/Assert + Assert.That( + () => + { + var sut = new IceDecoder(buffer.WrittenMemory, maxCollectionAllocation: allocationLimit); + _ = sut.DecodeSequence((ref IceDecoder decoder) => decoder.DecodeInt()); + }, + Throws.InstanceOf()); + } + + [TestCase(10)] + [TestCase(50)] + [TestCase(100)] + public void Decode_sequence_within_max_collection_allocation(int count) + { + // Arrange + var buffer = new MemoryBufferWriter(new byte[count * Unsafe.SizeOf() + 256]); + var encoder = new IceEncoder(buffer); + encoder.EncodeSequence( + Enumerable.Range(0, count), + (ref IceEncoder encoder, int value) => encoder.EncodeInt(value)); + + int allocationLimit = count * Unsafe.SizeOf(); + + // Act/Assert + Assert.That( + () => + { + var sut = new IceDecoder(buffer.WrittenMemory, maxCollectionAllocation: allocationLimit); + _ = sut.DecodeSequence((ref IceDecoder decoder) => decoder.DecodeInt()); + }, + Throws.Nothing); + } } diff --git a/tests/IceRpc.Ice.Codec.Tests/StringDecodingTests.cs b/tests/IceRpc.Ice.Codec.Tests/StringDecodingTests.cs index 3d63debd3..46b9d80c1 100644 --- a/tests/IceRpc.Ice.Codec.Tests/StringDecodingTests.cs +++ b/tests/IceRpc.Ice.Codec.Tests/StringDecodingTests.cs @@ -1,6 +1,7 @@ // Copyright (c) ZeroC, Inc. using NUnit.Framework; +using System.Runtime.CompilerServices; using ZeroC.Tests.Common; namespace IceRpc.Ice.Codec.Tests; @@ -42,4 +43,50 @@ public void Decode_non_utf8_string_fails() _ = sut.DecodeString(); }, Throws.InstanceOf()); } + + [TestCase(10)] + [TestCase(50)] + [TestCase(100)] + public void Decode_string_exceeds_max_collection_allocation(int length) + { + // Arrange + string testString = new('a', length); + var buffer = new MemoryBufferWriter(new byte[length + 256]); + var encoder = new IceEncoder(buffer); + encoder.EncodeString(testString); + + int allocationLimit = (length - 1) * Unsafe.SizeOf(); + + // Act/Assert + Assert.That( + () => + { + var sut = new IceDecoder(buffer.WrittenMemory, maxCollectionAllocation: allocationLimit); + _ = sut.DecodeString(); + }, + Throws.InstanceOf()); + } + + [TestCase(10)] + [TestCase(50)] + [TestCase(100)] + public void Decode_string_within_max_collection_allocation(int length) + { + // Arrange + string testString = new('a', length); + var buffer = new MemoryBufferWriter(new byte[length + 256]); + var encoder = new IceEncoder(buffer); + encoder.EncodeString(testString); + + int allocationLimit = length * Unsafe.SizeOf(); + + // Act/Assert + Assert.That( + () => + { + var sut = new IceDecoder(buffer.WrittenMemory, maxCollectionAllocation: allocationLimit); + _ = sut.DecodeString(); + }, + Throws.Nothing); + } }