From 04de9bfb259c64fc9fa2b223f040b459e009f966 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?coffee=20=E2=98=95=EF=B8=8F?= Date: Sun, 27 Jul 2025 21:32:51 -0400 Subject: [PATCH 1/2] feat: add LibAlias (implementation of AJ Walker Alias method) for random probability calculation --- .gitmodules | 3 + lib/solady | 1 + src/utils/LibAlias.sol | 141 ++++++++++++++++++++++++ test/utils/LibAlias.t.sol | 226 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 371 insertions(+) create mode 160000 lib/solady create mode 100644 src/utils/LibAlias.sol create mode 100644 test/utils/LibAlias.t.sol diff --git a/.gitmodules b/.gitmodules index 8efacd6..2fc5b55 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "lib/zksync-contracts"] path = lib/zksync-contracts url = https://github.com/matter-labs/zksync-contracts +[submodule "lib/solady"] + path = lib/solady + url = https://github.com/vectorized/solady diff --git a/lib/solady b/lib/solady new file mode 160000 index 0000000..a5bb996 --- /dev/null +++ b/lib/solady @@ -0,0 +1 @@ +Subproject commit a5bb996e91aae5b0c068087af7594d92068b12f1 diff --git a/src/utils/LibAlias.sol b/src/utils/LibAlias.sol new file mode 100644 index 0000000..faacf5f --- /dev/null +++ b/src/utils/LibAlias.sol @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {LibMap} from "solady/utils/LibMap.sol"; + +/// @title LibAlias +/// @notice A library implementing the alias method for efficient weighted random sampling. +/// @dev The alias method allows O(1) sampling from a discrete probability distribution after O(n) preprocessing. +/// @dev This is particularly useful for applications requiring frequent random sampling from a fixed distribution, +/// such as randomized algorithms, simulations, and games. +/// @dev Implementation uses two arrays: probabilities and aliases, enabling constant-time sampling. +/// @dev Limited size of alias table for gas efficiency and minimized storage writes; no individual weight +/// may be greater than (2^32 - 1) / size, and the total set of probabilities must not exceed 2^32 - 1 items. +library LibAlias { + using LibMap for LibMap.Uint32Map; + + /// @notice Thrown when input arrays have mismatched lengths. + error ArrayLengthMismatch(); + + /// @notice Main data structure for the alias method. + /// @dev Stores the preprocessed alias table for efficient sampling. + /// @param size Number of elements in the distribution. + /// @param totalWeight Sum of all weights in the original distribution. + /// @param probs Scaled probabilities for each index (probability * size) + /// @param aliases Alias indices for rejection sampling + struct Alias { + uint32 size; + uint32 totalWeight; + LibMap.Uint32Map probs; + LibMap.Uint32Map aliases; + } + + /// @notice Temporary state used during alias table construction + /// @dev Used internally by the fill function for organizing probabilities + /// @param index Original index in the weight array + /// @param scaledProb Probability scaled by the distribution size + struct WorkingState { + uint32 index; + uint32 scaledProb; + } + + /// @notice Select a random index from the alias table + /// @dev Uses the alias method for O(1) sampling. Splits the seed into two parts: + /// @dev - Lower 128 bits for column selection + /// @dev - Upper 128 bits for probability comparison within the column + /// @param self The alias table to sample from + /// @param seed Random seed for selection (256-bit value) + /// @return The selected index (0 to size-1) + function select(Alias storage self, uint256 seed) internal view returns (uint32) { + uint32 cachedSize = self.size; + uint32 colIndex = uint32(uint128(seed) % cachedSize); + uint256 prob = ((seed >> 128) % self.totalWeight) * cachedSize; + + uint32 col = self.probs.get(colIndex); + if (prob < col) { + return colIndex; + } else { + return self.aliases.get(colIndex); + } + } + + /// @notice Set the alias table from preprocessed data + /// @dev Use this when you have pre-calculated probabilities and aliases + /// @dev The probs array should contain scaled probabilities (original_prob * size) + /// @dev The aliases array should contain the alias indices for each column + /// @param self The alias table to populate + /// @param totalWeight Sum of all original weights + /// @param probs Array of scaled probabilities + /// @param aliases Array of alias indices + function setRaw(Alias storage self, uint32 totalWeight, uint256[] calldata probs, uint256[] calldata aliases) + internal + { + uint256 length = probs.length; + if (length != aliases.length) { + revert ArrayLengthMismatch(); + } + for (uint32 i; i < length; ++i) { + self.probs.map[i] = probs[i]; + self.aliases.map[i] = aliases[i]; + } + self.totalWeight = totalWeight; + self.size = uint32(length); + } + + /// @notice Construct alias table from weight distribution + /// @dev Implements the alias method construction algorithm with O(n) complexity + /// @dev Automatically handles the probability scaling and alias assignment + /// @dev Uses the "small" and "large" probability redistribution technique + /// @param self The alias table to populate + /// @param weights Array of weights for each outcome (must be non-zero length) + function fill(Alias storage self, uint16[] memory weights) internal { + uint32 size = uint32(weights.length); + WorkingState[] memory smallProbs = new WorkingState[](size); + WorkingState[] memory largeProbs = new WorkingState[](size); + uint256 smallCount; + uint256 largeCount; + uint32 totalWeight; + + for (uint32 i; i < size; ++i) { + totalWeight += weights[i]; + } + + for (uint32 i; i < size; ++i) { + uint32 scaledProb = weights[i] * size; + self.probs.set(i, scaledProb); + self.aliases.set(i, i); + + WorkingState memory workingState = WorkingState({scaledProb: scaledProb, index: i}); + if (scaledProb < totalWeight) { + smallProbs[smallCount++] = workingState; + } else { + largeProbs[largeCount++] = workingState; + } + } + + while (smallCount > 0 && largeCount > 0) { + WorkingState memory small = smallProbs[--smallCount]; + WorkingState memory large = largeProbs[--largeCount]; + + self.aliases.set(small.index, large.index); + self.probs.set(small.index, small.scaledProb); + + large.scaledProb -= (totalWeight - small.scaledProb); + + if (large.scaledProb < totalWeight) { + smallProbs[smallCount++] = large; + } else { + largeCount++; + } + } + + while (largeCount > 0) { + WorkingState memory large = largeProbs[--largeCount]; + self.probs.set(large.index, totalWeight); + self.aliases.set(large.index, large.index); + } + + self.size = size; + self.totalWeight = totalWeight; + } +} diff --git a/test/utils/LibAlias.t.sol b/test/utils/LibAlias.t.sol new file mode 100644 index 0000000..8d4369b --- /dev/null +++ b/test/utils/LibAlias.t.sol @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.27; + +import {Test} from "forge-std/Test.sol"; +import {LibAlias} from "../../src/utils/LibAlias.sol"; +import {LibMap} from "solady/utils/LibMap.sol"; +import {console} from "forge-std/console.sol"; + +contract LibAliasTest is Test { + using LibAlias for LibAlias.Alias; + using LibMap for LibMap.Uint32Map; + + LibAlias.Alias _alias; + + uint32[] public fixtureSize = [10, 100, 1000]; + + function test_fill_4_items() public { + vm.pauseGasMetering(); + uint16[] memory probs = new uint16[](4); + probs[0] = 270; // 2.7% + probs[1] = 950; // 9.5% + probs[2] = 2280; // 22.8% + probs[3] = 6500; // 65% + + vm.resumeGasMetering(); + + _alias.fill(probs); + + vm.pauseGasMetering(); + + assertEq(_alias.size, 4); + assertEq(_alias.probs.get(0), 1080); + assertEq(_alias.probs.get(1), 3800); + assertEq(_alias.probs.get(2), 9120); + assertEq(_alias.probs.get(3), 10000); + assertEq(_alias.aliases.get(0), 3); + assertEq(_alias.aliases.get(1), 3); + assertEq(_alias.aliases.get(2), 3); + assertEq(_alias.aliases.get(3), 3); + + vm.resumeGasMetering(); + } + + function test_select_4_items() public { + uint16[] memory probs = new uint16[](4); + probs[0] = 270; // 2.7% + probs[1] = 950; // 9.5% + probs[2] = 2280; // 22.8% + probs[3] = 6500; // 65% + + _alias.fill(probs); + + // Test direct selections + assertEq(_alias.select(0x8700000000000000000000000000000000), 0); + assertEq(_alias.select(0x1db00000000000000000000000000000001), 1); + assertEq(_alias.select(0x47400000000000000000000000000000002), 2); + assertEq(_alias.select(0x4e200000000000000000000000000000003), 3); + + // Test alias selections + assertEq(_alias.select(0x4f600000000000000000000000000000000), 3); + assertEq(_alias.select(0x79e00000000000000000000000000000001), 3); + assertEq(_alias.select(0xcd000000000000000000000000000000002), 3); + assertEq(_alias.select(0xdac00000000000000000000000000000003), 3); + } + + function test_fill_4_items_notOneSummed() public { + vm.pauseGasMetering(); + uint16[] memory probs = new uint16[](4); + probs[0] = 540; // 5.4% + probs[1] = 1900; // 19% + probs[2] = 4560; // 45.6% + probs[3] = 13000; // 130% + + vm.resumeGasMetering(); + + _alias.fill(probs); + + vm.pauseGasMetering(); + + assertEq(_alias.totalWeight, 20000); + + assertEq(_alias.size, 4); + assertEq(_alias.probs.get(0), 2160); + assertEq(_alias.probs.get(1), 7600); + assertEq(_alias.probs.get(2), 18240); + assertEq(_alias.probs.get(3), 20000); + assertEq(_alias.aliases.get(0), 3); + assertEq(_alias.aliases.get(1), 3); + assertEq(_alias.aliases.get(2), 3); + assertEq(_alias.aliases.get(3), 3); + + vm.resumeGasMetering(); + } + + function test_fill_5_items_unbalanced() public { + uint16[] memory probs = new uint16[](5); + probs[0] = 1000; + probs[1] = 2000; + probs[2] = 2000; + probs[3] = 4500; + probs[4] = 5500; + + _alias.fill(probs); + } + + function test_fill_20_items() public { + uint16[] memory probs = new uint16[](20); + probs[0] = 50; // 0.5% + probs[1] = 100; // 1% + probs[2] = 150; // 1.5% + probs[3] = 200; // 2% + probs[4] = 250; // 2.5% + probs[5] = 300; // 3% + probs[6] = 350; // 3.5% + probs[7] = 400; // 4% + probs[8] = 450; // 4.5% + probs[9] = 500; // 5% + probs[10] = 550; // 5.5% + probs[11] = 600; // 6% + probs[12] = 650; // 6.5% + probs[13] = 700; // 7% + probs[14] = 750; // 7.5% + probs[15] = 800; // 8% + probs[16] = 850; // 8.5% + probs[17] = 900; // 9% + probs[18] = 950; // 9.5% + probs[19] = 500; // 5% + + _alias.fill(probs); + + assertEq(_alias.size, 20); + assertEq(_alias.probs.get(0), 1000); + assertEq(_alias.probs.get(1), 2000); + assertEq(_alias.probs.get(2), 3000); + assertEq(_alias.probs.get(3), 4000); + assertEq(_alias.probs.get(4), 5000); + assertEq(_alias.probs.get(5), 6000); + assertEq(_alias.probs.get(6), 7000); + assertEq(_alias.probs.get(7), 8000); + assertEq(_alias.probs.get(8), 9000); + assertEq(_alias.probs.get(9), 10000); + assertEq(_alias.probs.get(10), 10000); + assertEq(_alias.probs.get(11), 9000); + assertEq(_alias.probs.get(12), 7000); + assertEq(_alias.probs.get(13), 4000); + assertEq(_alias.probs.get(14), 9000); + assertEq(_alias.probs.get(15), 4000); + assertEq(_alias.probs.get(16), 6000); + assertEq(_alias.probs.get(17), 6000); + assertEq(_alias.probs.get(18), 9000); + assertEq(_alias.probs.get(19), 9000); + assertEq(_alias.aliases.get(0), 13); + assertEq(_alias.aliases.get(1), 15); + assertEq(_alias.aliases.get(2), 16); + assertEq(_alias.aliases.get(3), 17); + assertEq(_alias.aliases.get(4), 17); + assertEq(_alias.aliases.get(5), 18); + assertEq(_alias.aliases.get(6), 18); + assertEq(_alias.aliases.get(7), 18); + assertEq(_alias.aliases.get(8), 19); + assertEq(_alias.aliases.get(9), 9); + assertEq(_alias.aliases.get(10), 10); + assertEq(_alias.aliases.get(11), 10); + assertEq(_alias.aliases.get(12), 11); + assertEq(_alias.aliases.get(13), 12); + assertEq(_alias.aliases.get(14), 13); + assertEq(_alias.aliases.get(15), 14); + assertEq(_alias.aliases.get(16), 15); + assertEq(_alias.aliases.get(17), 16); + assertEq(_alias.aliases.get(18), 17); + assertEq(_alias.aliases.get(19), 18); + } + + function test_fill_edge_case() public { + uint16[] memory weights = new uint16[](20); + weights[0] = 1; + weights[1] = 1; + weights[2] = 1; + weights[3] = 1; + weights[4] = 1; + weights[5] = 1; + weights[6] = 1; + weights[7] = 1; + weights[8] = 1; + weights[9] = 1; + weights[10] = 1; + weights[11] = 1; + weights[12] = 1; + weights[13] = 1; + weights[14] = 1; + weights[15] = 1; + weights[16] = 1; + weights[17] = 1; + weights[18] = 1; + weights[19] = 65535; + + _alias.fill(weights); + + assertEq(_alias.size, 20); + } + + function testFuzz_fill_items(uint16[] memory weights) public { + _alias.fill(weights); + + uint32 totalWeight = 0; + for (uint32 i; i < weights.length; ++i) { + totalWeight += weights[i]; + } + + assertEq(_alias.size, weights.length); + assertEq(_alias.totalWeight, totalWeight); + } + + function testFuzz_fill_items(uint32 size) public { + vm.assume(size < 20000); + + uint16[] memory weights = new uint16[](size); + for (uint32 i; i < size; ++i) { + weights[i] = uint16(i); + } + + _alias.fill(weights); + + assertEq(_alias.size, size); + } +} From 6a97d799f2ff59e404a0ed30b3cb7d0db2c2850a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?coffee=20=E2=98=95=EF=B8=8F?= Date: Sun, 27 Jul 2025 21:33:36 -0400 Subject: [PATCH 2/2] adjust pragma --- src/utils/LibAlias.sol | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/LibAlias.sol b/src/utils/LibAlias.sol index faacf5f..03a76b4 100644 --- a/src/utils/LibAlias.sol +++ b/src/utils/LibAlias.sol @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -pragma solidity ^0.8.24; +pragma solidity ^0.8.4; import {LibMap} from "solady/utils/LibMap.sol";