diff --git a/mssmt/compacted_tree.go b/mssmt/compacted_tree.go index abd654cc8..84246e502 100644 --- a/mssmt/compacted_tree.go +++ b/mssmt/compacted_tree.go @@ -392,3 +392,191 @@ func (t *CompactedTree) MerkleProof(ctx context.Context, key [hashSize]byte) ( return NewProof(proof), nil } + +// collectLeavesRecursive is a recursive helper function that's used to traverse +// down an MS-SMT tree and collect all leaf nodes. It returns a map of leaf +// nodes indexed by their hash. +func collectLeavesRecursive(ctx context.Context, tx TreeStoreViewTx, node Node, + depth int) (map[[hashSize]byte]*LeafNode, error) { + + // Base case: If it's a compacted leaf node. + if compactedLeaf, ok := node.(*CompactedLeafNode); ok { + if compactedLeaf.LeafNode.IsEmpty() { + return make(map[[hashSize]byte]*LeafNode), nil + } + return map[[hashSize]byte]*LeafNode{ + compactedLeaf.Key(): compactedLeaf.LeafNode, + }, nil + } + + // Recursive step: If it's a branch node. + if branchNode, ok := node.(*BranchNode); ok { + // Optimization: if the branch is empty, return early. + if depth < MaxTreeLevels && + IsEqualNode(branchNode, EmptyTree[depth]) { + + return make(map[[hashSize]byte]*LeafNode), nil + } + + // Handle case where depth might exceed EmptyTree bounds if + // logic error exists + if depth >= MaxTreeLevels { + // This shouldn't happen if called correctly, implies a + // leaf. + return nil, fmt.Errorf("invalid depth %d for branch "+ + "node", depth) + } + + left, right, err := tx.GetChildren(depth, branchNode.NodeHash()) + if err != nil { + // If children not found, it might be an empty branch + // implicitly Check if the error indicates "not found" + // or similar Depending on store impl, this might be how + // empty is signaled For now, treat error as fatal. + return nil, fmt.Errorf("error getting children for "+ + "branch %s at depth %d: %w", + branchNode.NodeHash(), depth, err) + } + + leftLeaves, err := collectLeavesRecursive( + ctx, tx, left, depth+1, + ) + if err != nil { + return nil, err + } + + rightLeaves, err := collectLeavesRecursive( + ctx, tx, right, depth+1, + ) + if err != nil { + return nil, err + } + + // Merge the results. + for k, v := range rightLeaves { + // Check for duplicate keys, although this shouldn't + // happen in a valid SMT. + if _, exists := leftLeaves[k]; exists { + return nil, fmt.Errorf("duplicate key %x "+ + "found during leaf collection", k) + } + leftLeaves[k] = v + } + + return leftLeaves, nil + } + + // Handle unexpected node types or implicit empty nodes. If node is nil + // or explicitly an EmptyLeafNode representation + if node == nil || IsEqualNode(node, EmptyLeafNode) { + return make(map[[hashSize]byte]*LeafNode), nil + } + + // Check against EmptyTree branches if possible (requires depth) + if depth < MaxTreeLevels && IsEqualNode(node, EmptyTree[depth]) { + return make(map[[hashSize]byte]*LeafNode), nil + } + + return nil, fmt.Errorf("unexpected node type %T encountered "+ + "during leaf collection at depth %d", node, depth) +} + +// Copy copies all the key-value pairs from the source tree into the target +// tree. +func (t *CompactedTree) Copy(ctx context.Context, targetTree Tree) error { + var leaves map[[hashSize]byte]*LeafNode + err := t.store.View(ctx, func(tx TreeStoreViewTx) error { + root, err := tx.RootNode() + if err != nil { + return fmt.Errorf("error getting root node: %w", err) + } + + // Optimization: If the source tree is empty, there's nothing to + // copy. + if IsEqualNode(root, EmptyTree[0]) { + leaves = make(map[[hashSize]byte]*LeafNode) + return nil + } + + // Start recursive collection from the root at depth 0. + leaves, err = collectLeavesRecursive(ctx, tx, root, 0) + if err != nil { + return fmt.Errorf("error collecting leaves: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + // Insert all found leaves into the target tree using InsertMany for + // efficiency. + _, err = targetTree.InsertMany(ctx, leaves) + if err != nil { + return fmt.Errorf("error inserting leaves into "+ + "target tree: %w", err) + } + + return nil +} + +// InsertMany inserts multiple leaf nodes provided in the leaves map within a +// single database transaction. +func (t *CompactedTree) InsertMany(ctx context.Context, + leaves map[[hashSize]byte]*LeafNode) (Tree, error) { + + if len(leaves) == 0 { + return t, nil + } + + dbErr := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error { + currentRoot, err := tx.RootNode() + if err != nil { + return err + } + rootBranch := currentRoot.(*BranchNode) + + for key, leaf := range leaves { + // Check for potential sum overflow before each + // insertion. + sumRoot := rootBranch.NodeSum() + sumLeaf := leaf.NodeSum() + err = CheckSumOverflowUint64(sumRoot, sumLeaf) + if err != nil { + return fmt.Errorf("compact tree leaf insert "+ + "sum overflow, root: %d, leaf: %d; %w", + sumRoot, sumLeaf, err) + } + + // Insert the leaf using the internal helper. + newRoot, err := t.insert( + tx, &key, 0, rootBranch, leaf, + ) + if err != nil { + return fmt.Errorf("error inserting leaf "+ + "with key %x: %w", key, err) + } + rootBranch = newRoot + + // Update the root within the transaction for + // consistency, even though the insert logic passes the + // root explicitly. + err = tx.UpdateRoot(rootBranch) + if err != nil { + return fmt.Errorf("error updating root "+ + "during InsertMany: %w", err) + } + } + + // The root is already updated by the last iteration of the + // loop. No final update needed here, but returning nil error + // signals success. + return nil + }) + if dbErr != nil { + return nil, dbErr + } + + return t, nil +} diff --git a/mssmt/interface.go b/mssmt/interface.go index 371fc6b77..bf3759f0c 100644 --- a/mssmt/interface.go +++ b/mssmt/interface.go @@ -30,4 +30,13 @@ type Tree interface { // proof. This is noted by the returned `Proof` containing an empty // leaf. MerkleProof(ctx context.Context, key [hashSize]byte) (*Proof, error) + + // InsertMany inserts multiple leaf nodes provided in the leaves map + // within a single database transaction. + InsertMany(ctx context.Context, leaves map[[hashSize]byte]*LeafNode) ( + Tree, error) + + // Copy copies all the key-value pairs from the source tree into the + // target tree. + Copy(ctx context.Context, targetTree Tree) error } diff --git a/mssmt/tree.go b/mssmt/tree.go index 764dc72c5..577d15e5e 100644 --- a/mssmt/tree.go +++ b/mssmt/tree.go @@ -97,6 +97,14 @@ func bitIndex(idx uint8, key *[hashSize]byte) byte { return (byteVal >> (idx % 8)) & 1 } +// setBit returns a copy of the key with the bit at the given depth set to 1. +func setBit(key [hashSize]byte, depth int) [hashSize]byte { + byteIndex := depth / 8 + bitIndex := depth % 8 + key[byteIndex] |= (1 << bitIndex) + return key +} + // iterFunc is a type alias for closures to be invoked at every iteration of // walking through a tree. type iterFunc = func(height int, current, sibling, parent Node) error @@ -333,6 +341,162 @@ func (t *FullTree) MerkleProof(ctx context.Context, key [hashSize]byte) ( return NewProof(proof), nil } +// findLeaves recursively traverses the tree represented by the given node and +// collects all non-empty leaf nodes along with their reconstructed keys. +func findLeaves(ctx context.Context, tx TreeStoreViewTx, node Node, + keyPrefix [hashSize]byte, + depth int) (map[[hashSize]byte]*LeafNode, error) { + + // Base case: If it's a leaf node. + if leafNode, ok := node.(*LeafNode); ok { + if leafNode.IsEmpty() { + return make(map[[hashSize]byte]*LeafNode), nil + } + return map[[hashSize]byte]*LeafNode{keyPrefix: leafNode}, nil + } + + // Recursive step: If it's a branch node. + if branchNode, ok := node.(*BranchNode); ok { + // Optimization: if the branch is empty, return early. + if IsEqualNode(branchNode, EmptyTree[depth]) { + return make(map[[hashSize]byte]*LeafNode), nil + } + + left, right, err := tx.GetChildren(depth, branchNode.NodeHash()) + if err != nil { + return nil, fmt.Errorf("error getting children for "+ + "branch %s at depth %d: %w", + branchNode.NodeHash(), depth, err) + } + + // Recursively find leaves in the left subtree. The key prefix + // remains the same as the 0 bit is implicitly handled by the + // initial keyPrefix state. + leftLeaves, err := findLeaves( + ctx, tx, left, keyPrefix, depth+1, + ) + if err != nil { + return nil, err + } + + // Recursively find leaves in the right subtree. Set the bit + // corresponding to the current depth in the key prefix. + rightKeyPrefix := setBit(keyPrefix, depth) + + rightLeaves, err := findLeaves( + ctx, tx, right, rightKeyPrefix, depth+1, + ) + if err != nil { + return nil, err + } + + // Merge the results. + for k, v := range rightLeaves { + leftLeaves[k] = v + } + return leftLeaves, nil + } + + // Handle unexpected node types. + return nil, fmt.Errorf("unexpected node type %T encountered "+ + "during leaf collection", node) +} + +// Copy copies all the key-value pairs from the source tree into the target +// tree. +func (t *FullTree) Copy(ctx context.Context, targetTree Tree) error { + var leaves map[[hashSize]byte]*LeafNode + err := t.store.View(ctx, func(tx TreeStoreViewTx) error { + root, err := tx.RootNode() + if err != nil { + return fmt.Errorf("error getting root node: %w", err) + } + + // Optimization: If the source tree is empty, there's nothing + // to copy. + if IsEqualNode(root, EmptyTree[0]) { + leaves = make(map[[hashSize]byte]*LeafNode) + return nil + } + + leaves, err = findLeaves(ctx, tx, root, [hashSize]byte{}, 0) + if err != nil { + return fmt.Errorf("error finding leaves: %w", err) + } + return nil + }) + if err != nil { + return err + } + + // Insert all found leaves into the target tree using InsertMany for + // efficiency. + _, err = targetTree.InsertMany(ctx, leaves) + if err != nil { + return fmt.Errorf("error inserting leaves into target "+ + "tree: %w", err) + } + + return nil +} + +// InsertMany inserts multiple leaf nodes provided in the leaves map within a +// single database transaction. +func (t *FullTree) InsertMany(ctx context.Context, + leaves map[[hashSize]byte]*LeafNode) (Tree, error) { + + if len(leaves) == 0 { + return t, nil + } + + err := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error { + currentRoot, err := tx.RootNode() + if err != nil { + return err + } + rootBranch := currentRoot.(*BranchNode) + + for key, leaf := range leaves { + // Check for potential sum overflow before each + // insertion. + sumRoot := rootBranch.NodeSum() + sumLeaf := leaf.NodeSum() + err = CheckSumOverflowUint64(sumRoot, sumLeaf) + if err != nil { + return fmt.Errorf("full tree leaf insert sum "+ + "overflow, root: %d, leaf: %d; %w", + sumRoot, sumLeaf, err) + } + + // Insert the leaf using the internal helper. + newRoot, err := t.insert(tx, &key, leaf) + if err != nil { + return fmt.Errorf("error inserting leaf "+ + "with key %x: %w", key, err) + } + rootBranch = newRoot + + // Update the root within the transaction so subsequent + // inserts in this batch read the correct state. + err = tx.UpdateRoot(rootBranch) + if err != nil { + return fmt.Errorf("error updating root "+ + "during InsertMany: %w", err) + } + } + + // The root is already updated by the last iteration of the + // loop. No final update needed here, but returning nil error + // signals success. + return nil + }) + if err != nil { + return nil, err + } + + return t, nil +} + // VerifyMerkleProof determines whether a merkle proof for the leaf found at the // given key is valid. func VerifyMerkleProof(key [hashSize]byte, leaf *LeafNode, proof *Proof, diff --git a/mssmt/tree_test.go b/mssmt/tree_test.go index 939571504..13c7f8791 100644 --- a/mssmt/tree_test.go +++ b/mssmt/tree_test.go @@ -714,7 +714,9 @@ func testMerkleProof(t *testing.T, tree mssmt.Tree, leaves []treeLeaf) { )) } -func testProofEquality(t *testing.T, tree1, tree2 mssmt.Tree, leaves []treeLeaf) { +func testProofEquality(t *testing.T, tree1, tree2 mssmt.Tree, + leaves []treeLeaf) { + assertEqualProof := func(proof1, proof2 *mssmt.Proof) { t.Helper() @@ -822,6 +824,232 @@ func TestBIPTestVectors(t *testing.T) { } } +// TestTreeCopy tests the Copy method for both FullTree and CompactedTree, +// including copying between different tree types. +func TestTreeCopy(t *testing.T) { + t.Parallel() + + leaves := randTree(50) // Use a smaller number for faster testing + + // Prepare source trees (Full and Compacted) + ctx := context.Background() + sourceFullStore := mssmt.NewDefaultStore() + sourceFullTree := mssmt.NewFullTree(sourceFullStore) + sourceCompactedStore := mssmt.NewDefaultStore() + sourceCompactedTree := mssmt.NewCompactedTree(sourceCompactedStore) + + for _, item := range leaves { + _, err := sourceFullTree.Insert(ctx, item.key, item.leaf) + require.NoError(t, err) + _, err = sourceCompactedTree.Insert(ctx, item.key, item.leaf) + require.NoError(t, err) + } + + sourceFullRoot, err := sourceFullTree.Root(ctx) + require.NoError(t, err) + sourceCompactedRoot, err := sourceCompactedTree.Root(ctx) + require.NoError(t, err) + require.True(t, mssmt.IsEqualNode(sourceFullRoot, sourceCompactedRoot)) + + // Define some leaves to pre-populate the target tree. + initialTargetLeaves := []treeLeaf{ + {key: test.RandHash(), leaf: randLeaf()}, + {key: test.RandHash(), leaf: randLeaf()}, + } + initialTargetLeavesMap := make(map[[hashSize]byte]*mssmt.LeafNode) + for _, item := range initialTargetLeaves { + initialTargetLeavesMap[item.key] = item.leaf + } + + // Define test cases + testCases := []struct { + name string + sourceTree mssmt.Tree + makeTarget func() mssmt.Tree + }{ + { + name: "Full -> Full", + sourceTree: sourceFullTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewFullTree( + mssmt.NewDefaultStore(), + ) + }, + }, + { + name: "Full -> Compacted", + sourceTree: sourceFullTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewCompactedTree( + mssmt.NewDefaultStore(), + ) + }, + }, + { + name: "Compacted -> Full", + sourceTree: sourceCompactedTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewFullTree( + mssmt.NewDefaultStore(), + ) + }, + }, + { + name: "Compacted -> Compacted", + sourceTree: sourceCompactedTree, + makeTarget: func() mssmt.Tree { + return mssmt.NewCompactedTree( + mssmt.NewDefaultStore(), + ) + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + targetTree := tc.makeTarget() + + // Pre-populate the target tree. + _, err := targetTree.InsertMany( + ctx, initialTargetLeavesMap, + ) + require.NoError(t, err) + + // Calculate the expected root after combining initial + // and source leaves. + expectedStateStore := mssmt.NewDefaultStore() + expectedStateTree := mssmt.NewFullTree( + expectedStateStore, + ) + _, err = expectedStateTree.InsertMany( + ctx, initialTargetLeavesMap, + ) + require.NoError(t, err) + sourceLeavesMap := make( + map[[hashSize]byte]*mssmt.LeafNode, + ) + for _, item := range leaves { + sourceLeavesMap[item.key] = item.leaf + } + _, err = expectedStateTree.InsertMany( + ctx, sourceLeavesMap, + ) + require.NoError(t, err) + expectedRoot, err := expectedStateTree.Root(ctx) + require.NoError(t, err) + + // Actually perform the copy. + err = tc.sourceTree.Copy(ctx, targetTree) + require.NoError(t, err) + + // Verify the target tree root matches the expected + // combined root. + targetRoot, err := targetTree.Root(ctx) + require.NoError(t, err) + require.True(t, + mssmt.IsEqualNode(expectedRoot, targetRoot), + "root mismatch after copy to non-empty target", + ) + + // Verify individual leaves (both initial and copied) in + // the target tree + allExpectedLeaves := append([]treeLeaf{}, leaves...) + allExpectedLeaves = append( + allExpectedLeaves, initialTargetLeaves..., + ) + for _, item := range allExpectedLeaves { + targetLeaf, err := targetTree.Get(ctx, item.key) + require.NoError(t, err) + require.Equal(t, item.leaf, targetLeaf, + "leaf mismatch for key %x", item.key) + } + + // Verify a non-existent key is still empty + emptyLeaf, err := targetTree.Get(ctx, test.RandHash()) + require.NoError(t, err) + require.True( + t, emptyLeaf.IsEmpty(), + "non-existent key found", + ) + }) + } +} + +// TestInsertMany tests inserting multiple leaves using the InsertMany method. +func TestInsertMany(t *testing.T) { + t.Parallel() + + leavesToInsert := randTree(50) + leavesMap := make(map[[hashSize]byte]*mssmt.LeafNode) + for _, item := range leavesToInsert { + leavesMap[item.key] = item.leaf + } + + // Calculate expected root after individual insertions for comparison. + tempStore := mssmt.NewDefaultStore() + tempTree := mssmt.NewFullTree(tempStore) + ctx := context.Background() + for key, leaf := range leavesMap { + _, err := tempTree.Insert(ctx, key, leaf) + require.NoError(t, err) + } + expectedRoot, err := tempTree.Root(ctx) + require.NoError(t, err) + + runTest := func(t *testing.T, name string, + makeTree func(mssmt.TreeStore) mssmt.Tree, + makeStore makeTestTreeStoreFunc) { + + t.Run(name, func(t *testing.T) { + store, err := makeStore() + require.NoError(t, err) + tree := makeTree(store) + + // Test inserting an empty map (should be a no-op). + _, err = tree.InsertMany( + ctx, make(map[[hashSize]byte]*mssmt.LeafNode), + ) + require.NoError(t, err) + initialRoot, err := tree.Root(ctx) + require.NoError(t, err) + require.True( + t, + mssmt.IsEqualNode( + mssmt.EmptyTree[0], initialRoot, + ), + ) + + // Insert the leaves using InsertMany. + _, err = tree.InsertMany(ctx, leavesMap) + require.NoError(t, err) + + // Verify the root. + finalRoot, err := tree.Root(ctx) + require.NoError(t, err) + require.True( + t, mssmt.IsEqualNode(expectedRoot, finalRoot), + ) + + // Verify each leaf can be retrieved. + for key, expectedLeaf := range leavesMap { + retrievedLeaf, err := tree.Get(ctx, key) + require.NoError(t, err) + require.Equal(t, expectedLeaf, retrievedLeaf) + } + }) + } + + for storeName, makeStore := range genTestStores(t) { + t.Run(storeName, func(t *testing.T) { + runTest(t, "full SMT", makeFullTree, makeStore) + runTest(t, "smol SMT", makeSmolTree, makeStore) + }) + } +} + // runBIPTestVector runs the tests in a single BIP test vector file. func runBIPTestVector(t *testing.T, testVectors *mssmt.TestVectors) { for _, validCase := range testVectors.ValidTestCases {