From ddb410df07fa9c718f1cee7d6ebdcb3d85ba12e1 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 11 Apr 2025 19:01:08 -0700 Subject: [PATCH 1/2] mssmt: add tree copy functionality for full and compacted trees This commit introduces a new `Copy` method to both the `FullTree` and `CompactedTree` implementations of the MS-SMT. This method allows copying all key-value pairs from a source tree to a target tree, assuming the target tree is initially empty. The `Copy` method is implemented differently for each tree type: - For `FullTree`, the method recursively traverses the tree, collecting all non-empty leaf nodes along with their keys. It then inserts these leaves into the target tree. - For `CompactedTree`, the method similarly traverses the tree, collecting all non-empty compacted leaf nodes along with their keys. It then inserts these leaves into the target tree. A new test case, `TestTreeCopy`, is added to verify the correctness of the `Copy` method for both tree types, including copying between different tree types (FullTree to CompactedTree and vice versa). The test case generates a set of random leaves, inserts them into a source tree, copies the source tree to a target tree, and then verifies that the target tree contains the same leaves as the source tree. --- mssmt/compacted_tree.go | 130 ++++++++++++++++++++++++++++++++++++++++ mssmt/interface.go | 4 ++ mssmt/tree.go | 111 ++++++++++++++++++++++++++++++++++ mssmt/tree_test.go | 97 ++++++++++++++++++++++++++++++ 4 files changed, 342 insertions(+) diff --git a/mssmt/compacted_tree.go b/mssmt/compacted_tree.go index abd654cc8..51b81ab79 100644 --- a/mssmt/compacted_tree.go +++ b/mssmt/compacted_tree.go @@ -392,3 +392,133 @@ func (t *CompactedTree) MerkleProof(ctx context.Context, key [hashSize]byte) ( return NewProof(proof), nil } + +// collectLeavesRecursive is a recursive helper function that's used to traverse +// down an MS-SMT tree and collect all leaf nodes. It returns a map of leaf +// nodes indexed by their hash. +func collectLeavesRecursive(ctx context.Context, tx TreeStoreViewTx, node Node, + depth int) (map[[hashSize]byte]*LeafNode, error) { + + // Base case: If it's a compacted leaf node. + if compactedLeaf, ok := node.(*CompactedLeafNode); ok { + if compactedLeaf.LeafNode.IsEmpty() { + return make(map[[hashSize]byte]*LeafNode), nil + } + return map[[hashSize]byte]*LeafNode{ + compactedLeaf.Key(): compactedLeaf.LeafNode, + }, nil + } + + // Recursive step: If it's a branch node. + if branchNode, ok := node.(*BranchNode); ok { + // Optimization: if the branch is empty, return early. + if depth < MaxTreeLevels && + IsEqualNode(branchNode, EmptyTree[depth]) { + + return make(map[[hashSize]byte]*LeafNode), nil + } + + // Handle case where depth might exceed EmptyTree bounds if + // logic error exists + if depth >= MaxTreeLevels { + // This shouldn't happen if called correctly, implies a + // leaf. + return nil, fmt.Errorf("invalid depth %d for branch "+ + "node", depth) + } + + left, right, err := tx.GetChildren(depth, branchNode.NodeHash()) + if err != nil { + // If children not found, it might be an empty branch + // implicitly Check if the error indicates "not found" + // or similar Depending on store impl, this might be how + // empty is signaled For now, treat error as fatal. + return nil, fmt.Errorf("error getting children for "+ + "branch %s at depth %d: %w", + branchNode.NodeHash(), depth, err) + } + + leftLeaves, err := collectLeavesRecursive( + ctx, tx, left, depth+1, + ) + if err != nil { + return nil, err + } + + rightLeaves, err := collectLeavesRecursive( + ctx, tx, right, depth+1, + ) + if err != nil { + return nil, err + } + + // Merge the results. + for k, v := range rightLeaves { + // Check for duplicate keys, although this shouldn't + // happen in a valid SMT. + if _, exists := leftLeaves[k]; exists { + return nil, fmt.Errorf("duplicate key %x "+ + "found during leaf collection", k) + } + leftLeaves[k] = v + } + + return leftLeaves, nil + } + + // Handle unexpected node types or implicit empty nodes. If node is nil + // or explicitly an EmptyLeafNode representation + if node == nil || IsEqualNode(node, EmptyLeafNode) { + return make(map[[hashSize]byte]*LeafNode), nil + } + + // Check against EmptyTree branches if possible (requires depth) + if depth < MaxTreeLevels && IsEqualNode(node, EmptyTree[depth]) { + return make(map[[hashSize]byte]*LeafNode), nil + } + + return nil, fmt.Errorf("unexpected node type %T encountered "+ + "during leaf collection at depth %d", node, depth) +} + +// Copy copies all the key-value pairs from the source tree into the target +// tree. +func (t *CompactedTree) Copy(ctx context.Context, targetTree Tree) error { + var leaves map[[hashSize]byte]*LeafNode + err := t.store.View(ctx, func(tx TreeStoreViewTx) error { + root, err := tx.RootNode() + if err != nil { + return fmt.Errorf("error getting root node: %w", err) + } + + // Optimization: If the source tree is empty, there's nothing to + // copy. + if IsEqualNode(root, EmptyTree[0]) { + leaves = make(map[[hashSize]byte]*LeafNode) + return nil + } + + // Start recursive collection from the root at depth 0. + leaves, err = collectLeavesRecursive(ctx, tx, root, 0) + if err != nil { + return fmt.Errorf("error collecting leaves: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + // Insert all found leaves into the target tree. + for key, leaf := range leaves { + // Use the target tree's Insert method. + _, err := targetTree.Insert(ctx, key, leaf) + if err != nil { + return fmt.Errorf("error inserting leaf with key %x "+ + "into target tree: %w", key, err) + } + } + + return nil +} diff --git a/mssmt/interface.go b/mssmt/interface.go index 371fc6b77..a9e3cd315 100644 --- a/mssmt/interface.go +++ b/mssmt/interface.go @@ -30,4 +30,8 @@ type Tree interface { // proof. This is noted by the returned `Proof` containing an empty // leaf. MerkleProof(ctx context.Context, key [hashSize]byte) (*Proof, error) + + // Copy copies all the key-value pairs from the source tree into the + // target tree. + Copy(ctx context.Context, targetTree Tree) error } diff --git a/mssmt/tree.go b/mssmt/tree.go index 764dc72c5..9bb37d34f 100644 --- a/mssmt/tree.go +++ b/mssmt/tree.go @@ -97,6 +97,14 @@ func bitIndex(idx uint8, key *[hashSize]byte) byte { return (byteVal >> (idx % 8)) & 1 } +// setBit returns a copy of the key with the bit at the given depth set to 1. +func setBit(key [hashSize]byte, depth int) [hashSize]byte { + byteIndex := depth / 8 + bitIndex := depth % 8 + key[byteIndex] |= (1 << bitIndex) + return key +} + // iterFunc is a type alias for closures to be invoked at every iteration of // walking through a tree. type iterFunc = func(height int, current, sibling, parent Node) error @@ -333,6 +341,109 @@ func (t *FullTree) MerkleProof(ctx context.Context, key [hashSize]byte) ( return NewProof(proof), nil } +// findLeaves recursively traverses the tree represented by the given node and +// collects all non-empty leaf nodes along with their reconstructed keys. +func findLeaves(ctx context.Context, tx TreeStoreViewTx, node Node, + keyPrefix [hashSize]byte, depth int) (map[[hashSize]byte]*LeafNode, error) { + + // Base case: If it's a leaf node. + if leafNode, ok := node.(*LeafNode); ok { + if leafNode.IsEmpty() { + return make(map[[hashSize]byte]*LeafNode), nil + } + return map[[hashSize]byte]*LeafNode{keyPrefix: leafNode}, nil + } + + // Recursive step: If it's a branch node. + if branchNode, ok := node.(*BranchNode); ok { + // Optimization: if the branch is empty, return early. + if IsEqualNode(branchNode, EmptyTree[depth]) { + return make(map[[hashSize]byte]*LeafNode), nil + } + + left, right, err := tx.GetChildren(depth, branchNode.NodeHash()) + if err != nil { + return nil, fmt.Errorf("error getting children for "+ + "branch %s at depth %d: %w", + branchNode.NodeHash(), depth, err) + } + + // Recursively find leaves in the left subtree. The key prefix + // remains the same as the 0 bit is implicitly handled by the + // initial keyPrefix state. + leftLeaves, err := findLeaves( + ctx, tx, left, keyPrefix, depth+1, + ) + if err != nil { + return nil, err + } + + // Recursively find leaves in the right subtree. Set the bit + // corresponding to the current depth in the key prefix. + rightKeyPrefix := setBit(keyPrefix, depth) + + rightLeaves, err := findLeaves( + ctx, tx, right, rightKeyPrefix, depth+1, + ) + if err != nil { + return nil, err + } + + // Merge the results. + for k, v := range rightLeaves { + leftLeaves[k] = v + } + return leftLeaves, nil + } + + // Handle unexpected node types. + return nil, fmt.Errorf("unexpected node type %T encountered "+ + "during leaf collection", node) +} + +// Copy copies all the key-value pairs from the source tree into the target +// tree. +func (t *FullTree) Copy(ctx context.Context, targetTree Tree) error { + var leaves map[[hashSize]byte]*LeafNode + err := t.store.View(ctx, func(tx TreeStoreViewTx) error { + root, err := tx.RootNode() + if err != nil { + return fmt.Errorf("error getting root node: %w", err) + } + + // Optimization: If the source tree is empty, there's nothing + // to copy. + if IsEqualNode(root, EmptyTree[0]) { + leaves = make(map[[hashSize]byte]*LeafNode) + return nil + } + + leaves, err = findLeaves(ctx, tx, root, [hashSize]byte{}, 0) + if err != nil { + return fmt.Errorf("error finding leaves: %w", err) + } + return nil + }) + if err != nil { + return err + } + + // Insert all found leaves into the target tree. We assume the target + // tree handles batching or individual inserts efficiently. + for key, leaf := range leaves { + // Use the target tree's Insert method. We ignore the returned + // tree as we are modifying the targetTree in place via its + // store. + _, err := targetTree.Insert(ctx, key, leaf) + if err != nil { + return fmt.Errorf("error inserting leaf with key %x "+ + "into target tree: %w", key, err) + } + } + + return nil +} + // VerifyMerkleProof determines whether a merkle proof for the leaf found at the // given key is valid. func VerifyMerkleProof(key [hashSize]byte, leaf *LeafNode, proof *Proof, diff --git a/mssmt/tree_test.go b/mssmt/tree_test.go index 939571504..d1e37de0e 100644 --- a/mssmt/tree_test.go +++ b/mssmt/tree_test.go @@ -822,6 +822,103 @@ func TestBIPTestVectors(t *testing.T) { } } +// TestTreeCopy tests the Copy method for both FullTree and CompactedTree, +// including copying between different tree types. +func TestTreeCopy(t *testing.T) { + t.Parallel() + + leaves := randTree(50) // Use a smaller number for faster testing + + // Prepare source trees (Full and Compacted) + ctx := context.Background() + sourceFullStore := mssmt.NewDefaultStore() + sourceFullTree := mssmt.NewFullTree(sourceFullStore) + sourceCompactedStore := mssmt.NewDefaultStore() + sourceCompactedTree := mssmt.NewCompactedTree(sourceCompactedStore) + + for _, item := range leaves { + _, err := sourceFullTree.Insert(ctx, item.key, item.leaf) + require.NoError(t, err) + _, err = sourceCompactedTree.Insert(ctx, item.key, item.leaf) + require.NoError(t, err) + } + + sourceFullRoot, err := sourceFullTree.Root(ctx) + require.NoError(t, err) + sourceCompactedRoot, err := sourceCompactedTree.Root(ctx) + require.NoError(t, err) + require.True(t, mssmt.IsEqualNode(sourceFullRoot, sourceCompactedRoot)) + + // Define test cases + testCases := []struct { + name string + sourceTree mssmt.Tree + makeTarget func() mssmt.Tree + }{ + { + name: "Full -> Full", + sourceTree: sourceFullTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewFullTree(mssmt.NewDefaultStore()) + }, + }, + { + name: "Full -> Compacted", + sourceTree: sourceFullTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewCompactedTree(mssmt.NewDefaultStore()) + }, + }, + { + name: "Compacted -> Full", + sourceTree: sourceCompactedTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewFullTree(mssmt.NewDefaultStore()) + }, + }, + { + name: "Compacted -> Compacted", + sourceTree: sourceCompactedTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewCompactedTree(mssmt.NewDefaultStore()) + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + targetTree := tc.makeTarget() + + // Perform the copy + err := tc.sourceTree.Copy(ctx, targetTree) + require.NoError(t, err) + + // Verify the target tree root + targetRoot, err := targetTree.Root(ctx) + require.NoError(t, err) + require.True(t, mssmt.IsEqualNode(sourceFullRoot, targetRoot), + "Root mismatch after copy") + + // Verify individual leaves in the target tree + for _, item := range leaves { + targetLeaf, err := targetTree.Get(ctx, item.key) + require.NoError(t, err) + require.Equal(t, item.leaf, targetLeaf, + "Leaf mismatch for key %x", item.key) + } + + // Verify a non-existent key is still empty + emptyLeaf, err := targetTree.Get(ctx, test.RandHash()) + require.NoError(t, err) + require.True(t, emptyLeaf.IsEmpty(), "Non-existent key found") + }) + } +} + + // runBIPTestVector runs the tests in a single BIP test vector file. func runBIPTestVector(t *testing.T, testVectors *mssmt.TestVectors) { for _, validCase := range testVectors.ValidTestCases { From bd32d61a384c1f0db0f271c9457ffb3daaeff26c Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 11 Apr 2025 19:15:49 -0700 Subject: [PATCH 2/2] mssmt: add InsertMany method to full and compacted trees This commit introduces the InsertMany method to both the FullTree and CompactedTree implementations of the MS-SMT. This method allows for the insertion of multiple leaf nodes in a single database transaction, improving efficiency when adding multiple leaves at once. The InsertMany method is added to the Tree interface and implemented in both FullTree and CompactedTree. The implementation includes sum overflow checks before each insertion and updates the root within the transaction for consistency. A new test case, TestInsertMany, is added to verify the functionality of the InsertMany method in both FullTree and CompactedTree. The test inserts a random set of leaves using InsertMany and verifies the resulting root and retrieved leaves. The Copy method in both FullTree and CompactedTree is updated to use InsertMany for efficiency when copying leaves to the target tree. --- mssmt/compacted_tree.go | 72 ++++++++++++++++-- mssmt/interface.go | 5 ++ mssmt/tree.go | 75 ++++++++++++++++--- mssmt/tree_test.go | 159 ++++++++++++++++++++++++++++++++++++---- 4 files changed, 279 insertions(+), 32 deletions(-) diff --git a/mssmt/compacted_tree.go b/mssmt/compacted_tree.go index 51b81ab79..84246e502 100644 --- a/mssmt/compacted_tree.go +++ b/mssmt/compacted_tree.go @@ -510,15 +510,73 @@ func (t *CompactedTree) Copy(ctx context.Context, targetTree Tree) error { return err } - // Insert all found leaves into the target tree. - for key, leaf := range leaves { - // Use the target tree's Insert method. - _, err := targetTree.Insert(ctx, key, leaf) + // Insert all found leaves into the target tree using InsertMany for + // efficiency. + _, err = targetTree.InsertMany(ctx, leaves) + if err != nil { + return fmt.Errorf("error inserting leaves into "+ + "target tree: %w", err) + } + + return nil +} + +// InsertMany inserts multiple leaf nodes provided in the leaves map within a +// single database transaction. +func (t *CompactedTree) InsertMany(ctx context.Context, + leaves map[[hashSize]byte]*LeafNode) (Tree, error) { + + if len(leaves) == 0 { + return t, nil + } + + dbErr := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error { + currentRoot, err := tx.RootNode() if err != nil { - return fmt.Errorf("error inserting leaf with key %x "+ - "into target tree: %w", key, err) + return err } + rootBranch := currentRoot.(*BranchNode) + + for key, leaf := range leaves { + // Check for potential sum overflow before each + // insertion. + sumRoot := rootBranch.NodeSum() + sumLeaf := leaf.NodeSum() + err = CheckSumOverflowUint64(sumRoot, sumLeaf) + if err != nil { + return fmt.Errorf("compact tree leaf insert "+ + "sum overflow, root: %d, leaf: %d; %w", + sumRoot, sumLeaf, err) + } + + // Insert the leaf using the internal helper. + newRoot, err := t.insert( + tx, &key, 0, rootBranch, leaf, + ) + if err != nil { + return fmt.Errorf("error inserting leaf "+ + "with key %x: %w", key, err) + } + rootBranch = newRoot + + // Update the root within the transaction for + // consistency, even though the insert logic passes the + // root explicitly. + err = tx.UpdateRoot(rootBranch) + if err != nil { + return fmt.Errorf("error updating root "+ + "during InsertMany: %w", err) + } + } + + // The root is already updated by the last iteration of the + // loop. No final update needed here, but returning nil error + // signals success. + return nil + }) + if dbErr != nil { + return nil, dbErr } - return nil + return t, nil } diff --git a/mssmt/interface.go b/mssmt/interface.go index a9e3cd315..bf3759f0c 100644 --- a/mssmt/interface.go +++ b/mssmt/interface.go @@ -31,6 +31,11 @@ type Tree interface { // leaf. MerkleProof(ctx context.Context, key [hashSize]byte) (*Proof, error) + // InsertMany inserts multiple leaf nodes provided in the leaves map + // within a single database transaction. + InsertMany(ctx context.Context, leaves map[[hashSize]byte]*LeafNode) ( + Tree, error) + // Copy copies all the key-value pairs from the source tree into the // target tree. Copy(ctx context.Context, targetTree Tree) error diff --git a/mssmt/tree.go b/mssmt/tree.go index 9bb37d34f..577d15e5e 100644 --- a/mssmt/tree.go +++ b/mssmt/tree.go @@ -344,7 +344,8 @@ func (t *FullTree) MerkleProof(ctx context.Context, key [hashSize]byte) ( // findLeaves recursively traverses the tree represented by the given node and // collects all non-empty leaf nodes along with their reconstructed keys. func findLeaves(ctx context.Context, tx TreeStoreViewTx, node Node, - keyPrefix [hashSize]byte, depth int) (map[[hashSize]byte]*LeafNode, error) { + keyPrefix [hashSize]byte, + depth int) (map[[hashSize]byte]*LeafNode, error) { // Base case: If it's a leaf node. if leafNode, ok := node.(*LeafNode); ok { @@ -428,20 +429,72 @@ func (t *FullTree) Copy(ctx context.Context, targetTree Tree) error { return err } - // Insert all found leaves into the target tree. We assume the target - // tree handles batching or individual inserts efficiently. - for key, leaf := range leaves { - // Use the target tree's Insert method. We ignore the returned - // tree as we are modifying the targetTree in place via its - // store. - _, err := targetTree.Insert(ctx, key, leaf) + // Insert all found leaves into the target tree using InsertMany for + // efficiency. + _, err = targetTree.InsertMany(ctx, leaves) + if err != nil { + return fmt.Errorf("error inserting leaves into target "+ + "tree: %w", err) + } + + return nil +} + +// InsertMany inserts multiple leaf nodes provided in the leaves map within a +// single database transaction. +func (t *FullTree) InsertMany(ctx context.Context, + leaves map[[hashSize]byte]*LeafNode) (Tree, error) { + + if len(leaves) == 0 { + return t, nil + } + + err := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error { + currentRoot, err := tx.RootNode() if err != nil { - return fmt.Errorf("error inserting leaf with key %x "+ - "into target tree: %w", key, err) + return err } + rootBranch := currentRoot.(*BranchNode) + + for key, leaf := range leaves { + // Check for potential sum overflow before each + // insertion. + sumRoot := rootBranch.NodeSum() + sumLeaf := leaf.NodeSum() + err = CheckSumOverflowUint64(sumRoot, sumLeaf) + if err != nil { + return fmt.Errorf("full tree leaf insert sum "+ + "overflow, root: %d, leaf: %d; %w", + sumRoot, sumLeaf, err) + } + + // Insert the leaf using the internal helper. + newRoot, err := t.insert(tx, &key, leaf) + if err != nil { + return fmt.Errorf("error inserting leaf "+ + "with key %x: %w", key, err) + } + rootBranch = newRoot + + // Update the root within the transaction so subsequent + // inserts in this batch read the correct state. + err = tx.UpdateRoot(rootBranch) + if err != nil { + return fmt.Errorf("error updating root "+ + "during InsertMany: %w", err) + } + } + + // The root is already updated by the last iteration of the + // loop. No final update needed here, but returning nil error + // signals success. + return nil + }) + if err != nil { + return nil, err } - return nil + return t, nil } // VerifyMerkleProof determines whether a merkle proof for the leaf found at the diff --git a/mssmt/tree_test.go b/mssmt/tree_test.go index d1e37de0e..13c7f8791 100644 --- a/mssmt/tree_test.go +++ b/mssmt/tree_test.go @@ -714,7 +714,9 @@ func testMerkleProof(t *testing.T, tree mssmt.Tree, leaves []treeLeaf) { )) } -func testProofEquality(t *testing.T, tree1, tree2 mssmt.Tree, leaves []treeLeaf) { +func testProofEquality(t *testing.T, tree1, tree2 mssmt.Tree, + leaves []treeLeaf) { + assertEqualProof := func(proof1, proof2 *mssmt.Proof) { t.Helper() @@ -849,6 +851,16 @@ func TestTreeCopy(t *testing.T) { require.NoError(t, err) require.True(t, mssmt.IsEqualNode(sourceFullRoot, sourceCompactedRoot)) + // Define some leaves to pre-populate the target tree. + initialTargetLeaves := []treeLeaf{ + {key: test.RandHash(), leaf: randLeaf()}, + {key: test.RandHash(), leaf: randLeaf()}, + } + initialTargetLeavesMap := make(map[[hashSize]byte]*mssmt.LeafNode) + for _, item := range initialTargetLeaves { + initialTargetLeavesMap[item.key] = item.leaf + } + // Define test cases testCases := []struct { name string @@ -859,28 +871,36 @@ func TestTreeCopy(t *testing.T) { name: "Full -> Full", sourceTree: sourceFullTree, makeTarget: func() mssmt.Tree { - return mssmt.NewFullTree(mssmt.NewDefaultStore()) + return mssmt.NewFullTree( + mssmt.NewDefaultStore(), + ) }, }, { name: "Full -> Compacted", sourceTree: sourceFullTree, makeTarget: func() mssmt.Tree { - return mssmt.NewCompactedTree(mssmt.NewDefaultStore()) + return mssmt.NewCompactedTree( + mssmt.NewDefaultStore(), + ) }, }, { name: "Compacted -> Full", sourceTree: sourceCompactedTree, makeTarget: func() mssmt.Tree { - return mssmt.NewFullTree(mssmt.NewDefaultStore()) + return mssmt.NewFullTree( + mssmt.NewDefaultStore(), + ) }, }, { name: "Compacted -> Compacted", sourceTree: sourceCompactedTree, makeTarget: func() mssmt.Tree { - return mssmt.NewCompactedTree(mssmt.NewDefaultStore()) + return mssmt.NewCompactedTree( + mssmt.NewDefaultStore(), + ) }, }, } @@ -892,32 +912,143 @@ func TestTreeCopy(t *testing.T) { targetTree := tc.makeTarget() - // Perform the copy - err := tc.sourceTree.Copy(ctx, targetTree) + // Pre-populate the target tree. + _, err := targetTree.InsertMany( + ctx, initialTargetLeavesMap, + ) require.NoError(t, err) - // Verify the target tree root + // Calculate the expected root after combining initial + // and source leaves. + expectedStateStore := mssmt.NewDefaultStore() + expectedStateTree := mssmt.NewFullTree( + expectedStateStore, + ) + _, err = expectedStateTree.InsertMany( + ctx, initialTargetLeavesMap, + ) + require.NoError(t, err) + sourceLeavesMap := make( + map[[hashSize]byte]*mssmt.LeafNode, + ) + for _, item := range leaves { + sourceLeavesMap[item.key] = item.leaf + } + _, err = expectedStateTree.InsertMany( + ctx, sourceLeavesMap, + ) + require.NoError(t, err) + expectedRoot, err := expectedStateTree.Root(ctx) + require.NoError(t, err) + + // Actually perform the copy. + err = tc.sourceTree.Copy(ctx, targetTree) + require.NoError(t, err) + + // Verify the target tree root matches the expected + // combined root. targetRoot, err := targetTree.Root(ctx) require.NoError(t, err) - require.True(t, mssmt.IsEqualNode(sourceFullRoot, targetRoot), - "Root mismatch after copy") + require.True(t, + mssmt.IsEqualNode(expectedRoot, targetRoot), + "root mismatch after copy to non-empty target", + ) - // Verify individual leaves in the target tree - for _, item := range leaves { + // Verify individual leaves (both initial and copied) in + // the target tree + allExpectedLeaves := append([]treeLeaf{}, leaves...) + allExpectedLeaves = append( + allExpectedLeaves, initialTargetLeaves..., + ) + for _, item := range allExpectedLeaves { targetLeaf, err := targetTree.Get(ctx, item.key) require.NoError(t, err) require.Equal(t, item.leaf, targetLeaf, - "Leaf mismatch for key %x", item.key) + "leaf mismatch for key %x", item.key) } // Verify a non-existent key is still empty emptyLeaf, err := targetTree.Get(ctx, test.RandHash()) require.NoError(t, err) - require.True(t, emptyLeaf.IsEmpty(), "Non-existent key found") + require.True( + t, emptyLeaf.IsEmpty(), + "non-existent key found", + ) }) } } +// TestInsertMany tests inserting multiple leaves using the InsertMany method. +func TestInsertMany(t *testing.T) { + t.Parallel() + + leavesToInsert := randTree(50) + leavesMap := make(map[[hashSize]byte]*mssmt.LeafNode) + for _, item := range leavesToInsert { + leavesMap[item.key] = item.leaf + } + + // Calculate expected root after individual insertions for comparison. + tempStore := mssmt.NewDefaultStore() + tempTree := mssmt.NewFullTree(tempStore) + ctx := context.Background() + for key, leaf := range leavesMap { + _, err := tempTree.Insert(ctx, key, leaf) + require.NoError(t, err) + } + expectedRoot, err := tempTree.Root(ctx) + require.NoError(t, err) + + runTest := func(t *testing.T, name string, + makeTree func(mssmt.TreeStore) mssmt.Tree, + makeStore makeTestTreeStoreFunc) { + + t.Run(name, func(t *testing.T) { + store, err := makeStore() + require.NoError(t, err) + tree := makeTree(store) + + // Test inserting an empty map (should be a no-op). + _, err = tree.InsertMany( + ctx, make(map[[hashSize]byte]*mssmt.LeafNode), + ) + require.NoError(t, err) + initialRoot, err := tree.Root(ctx) + require.NoError(t, err) + require.True( + t, + mssmt.IsEqualNode( + mssmt.EmptyTree[0], initialRoot, + ), + ) + + // Insert the leaves using InsertMany. + _, err = tree.InsertMany(ctx, leavesMap) + require.NoError(t, err) + + // Verify the root. + finalRoot, err := tree.Root(ctx) + require.NoError(t, err) + require.True( + t, mssmt.IsEqualNode(expectedRoot, finalRoot), + ) + + // Verify each leaf can be retrieved. + for key, expectedLeaf := range leavesMap { + retrievedLeaf, err := tree.Get(ctx, key) + require.NoError(t, err) + require.Equal(t, expectedLeaf, retrievedLeaf) + } + }) + } + + for storeName, makeStore := range genTestStores(t) { + t.Run(storeName, func(t *testing.T) { + runTest(t, "full SMT", makeFullTree, makeStore) + runTest(t, "smol SMT", makeSmolTree, makeStore) + }) + } +} // runBIPTestVector runs the tests in a single BIP test vector file. func runBIPTestVector(t *testing.T, testVectors *mssmt.TestVectors) {