246 changes: 246 additions & 0 deletions cmd/migration-checker/main.go
@@ -0,0 +1,246 @@
package main

import (
"bytes"
"encoding/hex"
"flag"
"fmt"
"os"
"runtime"
"sync"
"sync/atomic"

"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/ethdb/leveldb"
"github.com/scroll-tech/go-ethereum/rlp"
"github.com/scroll-tech/go-ethereum/trie"
)

var accountsDone atomic.Uint64
var trieCheckers = make(chan struct{}, runtime.GOMAXPROCS(0)*4)

type dbs struct {
zkDb *leveldb.Database
mptDb *leveldb.Database
}

func main() {
var (
mptDbPath = flag.String("mpt-db", "", "path to the MPT node DB")
zkDbPath = flag.String("zk-db", "", "path to the ZK node DB")
mptRoot = flag.String("mpt-root", "", "root hash of the MPT node")
zkRoot = flag.String("zk-root", "", "root hash of the ZK node")
paranoid = flag.Bool("paranoid", false, "verifies all node contents against their expected hash")
)
flag.Parse()

zkDb, err := leveldb.New(*zkDbPath, 1024, 128, "", true)
panicOnError(err, "", "failed to open zk db")
mptDb, err := leveldb.New(*mptDbPath, 1024, 128, "", true)
panicOnError(err, "", "failed to open mpt db")

zkRootHash := common.HexToHash(*zkRoot)
mptRootHash := common.HexToHash(*mptRoot)

for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
trieCheckers <- struct{}{}
}

checkTrieEquality(&dbs{
zkDb: zkDb,
mptDb: mptDb,
}, zkRootHash, mptRootHash, "", checkAccountEquality, true, *paranoid)

for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
<-trieCheckers
}
}

func panicOnError(err error, label, msg string) {
if err != nil {
panic(fmt.Sprint(label, " error: ", msg, " ", err))
}
}
Comment on lines +73 to +77

🛠️ Refactor suggestion

Prefer returning errors instead of calling panicOnError.
This helper function can abruptly terminate the program. For production tools or libraries, returning errors often leads to more flexible handling.
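
As a rough sketch (the helper name and its wiring into main below are illustrative, not part of this PR), the check could wrap the error and let main decide how to report and exit:

func wrapErr(err error, label, msg string) error {
	if err != nil {
		return fmt.Errorf("%s error: %s: %w", label, msg, err)
	}
	return nil
}

// in main:
if err := wrapErr(err, "", "failed to open zk db"); err != nil {
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}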


func dup(s []byte) []byte {
return append([]byte{}, s...)
}
func checkTrieEquality(dbs *dbs, zkRoot, mptRoot common.Hash, label string, leafChecker func(string, *dbs, []byte, []byte, bool), top, paranoid bool) {
zkTrie, err := trie.NewZkTrie(zkRoot, trie.NewZktrieDatabaseFromTriedb(trie.NewDatabaseWithConfig(dbs.zkDb, &trie.Config{Preimages: true})))
panicOnError(err, label, "failed to create zk trie")
mptTrie, err := trie.NewSecureNoTracer(mptRoot, trie.NewDatabaseWithConfig(dbs.mptDb, &trie.Config{Preimages: true}))
panicOnError(err, label, "failed to create mpt trie")

mptLeafCh := loadMPT(mptTrie, top)
zkLeafCh := loadZkTrie(zkTrie, top, paranoid)

mptLeafMap := <-mptLeafCh
zkLeafMap := <-zkLeafCh

if len(mptLeafMap) != len(zkLeafMap) {
panic(fmt.Sprintf("%s MPT and ZK trie leaf count mismatch: MPT: %d, ZK: %d", label, len(mptLeafMap), len(zkLeafMap)))
}

for preimageKey, zkValue := range zkLeafMap {
if top {
// ZkTrie pads preimages with zeros to make them 32 bytes,
// so we might need to strip those zeros here: at the top level (i.e. the state trie) keys are 20-byte addresses.
if len(preimageKey) > 20 {
for _, b := range []byte(preimageKey)[20:] {
if b != 0 {
panic(fmt.Sprintf("%s padded byte is not 0 (preimage %s)", label, hex.EncodeToString([]byte(preimageKey))))
}
}
preimageKey = preimageKey[:20]
}
} else if len(preimageKey) != 32 {
// storage leaves should have 32-byte keys; pad them if needed
zeroes := make([]byte, 32)
copy(zeroes, []byte(preimageKey))
preimageKey = string(zeroes)
}

mptKey := crypto.Keccak256([]byte(preimageKey))
mptVal, ok := mptLeafMap[string(mptKey)]
if !ok {
panic(fmt.Sprintf("%s key %s (preimage %s) not found in mpt", label, hex.EncodeToString(mptKey), hex.EncodeToString([]byte(preimageKey))))
}

leafChecker(fmt.Sprintf("%s key: %s", label, hex.EncodeToString([]byte(preimageKey))), dbs, zkValue, mptVal, paranoid)
}
}

func checkAccountEquality(label string, dbs *dbs, zkAccountBytes, mptAccountBytes []byte, paranoid bool) {
mptAccount := &types.StateAccount{}
panicOnError(rlp.DecodeBytes(mptAccountBytes, mptAccount), label, "failed to decode mpt account")
zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes)
panicOnError(err, label, "failed to decode zk account")

if mptAccount.Nonce != zkAccount.Nonce {
panic(fmt.Sprintf("%s nonce mismatch: zk: %d, mpt: %d", label, zkAccount.Nonce, mptAccount.Nonce))
}

if mptAccount.Balance.Cmp(zkAccount.Balance) != 0 {
panic(fmt.Sprintf("%s balance mismatch: zk: %s, mpt: %s", label, zkAccount.Balance.String(), mptAccount.Balance.String()))
}

if !bytes.Equal(mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash) {
panic(fmt.Sprintf("%s code hash mismatch: zk: %s, mpt: %s", label, hex.EncodeToString(zkAccount.KeccakCodeHash), hex.EncodeToString(mptAccount.KeccakCodeHash)))
}

if (zkAccount.Root == common.Hash{}) != (mptAccount.Root == types.EmptyRootHash) {
panic(fmt.Sprintf("%s empty account root mismatch", label))
} else if zkAccount.Root != (common.Hash{}) {
zkRoot := common.BytesToHash(zkAccount.Root[:])
mptRoot := common.BytesToHash(mptAccount.Root[:])
<-trieCheckers
go func() {
defer func() {
if p := recover(); p != nil {
fmt.Println(p)
os.Exit(1)
}
}()

checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false, paranoid)
accountsDone.Add(1)
fmt.Println("Accounts done:", accountsDone.Load())
trieCheckers <- struct{}{}
}()
} else {
accountsDone.Add(1)
fmt.Println("Accounts done:", accountsDone.Load())
}
}

func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte, _ bool) {
zkValue := common.BytesToHash(zkStorageBytes)
_, content, _, err := rlp.Split(mptStorageBytes)
panicOnError(err, label, "failed to decode mpt storage")
mptValue := common.BytesToHash(content)
if !bytes.Equal(zkValue[:], mptValue[:]) {
panic(fmt.Sprintf("%s storage mismatch: zk: %s, mpt: %s", label, zkValue.Hex(), mptValue.Hex()))
}
}

func loadMPT(mptTrie *trie.SecureTrie, parallel bool) chan map[string][]byte {
startKey := make([]byte, 32)
workers := 1 << 5
if !parallel {
workers = 1
}
step := byte(0xFF) / byte(workers)

mptLeafMap := make(map[string][]byte, 1000)
var mptLeafMutex sync.Mutex

var mptWg sync.WaitGroup
for i := 0; i < workers; i++ {
startKey[0] = byte(i) * step
trieIt := trie.NewIterator(mptTrie.NodeIterator(startKey))

mptWg.Add(1)
go func() {
defer mptWg.Done()
for trieIt.Next() {
if parallel {
mptLeafMutex.Lock()
}

if _, ok := mptLeafMap[string(trieIt.Key)]; ok {
mptLeafMutex.Unlock()
break
}

mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)

if parallel {
mptLeafMutex.Unlock()
}

if parallel && len(mptLeafMap)%10000 == 0 {
fmt.Println("MPT Accounts Loaded:", len(mptLeafMap))
}
}
}()
}

respChan := make(chan map[string][]byte)
go func() {
mptWg.Wait()
respChan <- mptLeafMap
}()
return respChan
}

func loadZkTrie(zkTrie *trie.ZkTrie, parallel, paranoid bool) chan map[string][]byte {
zkLeafMap := make(map[string][]byte, 1000)
var zkLeafMutex sync.Mutex
zkDone := make(chan map[string][]byte)
go func() {
zkTrie.CountLeaves(func(key, value []byte) {
preimageKey := zkTrie.GetKey(key)
if len(preimageKey) == 0 {
panic(fmt.Sprintf("preimage not found zk trie %s", hex.EncodeToString(key)))
}

if parallel {
zkLeafMutex.Lock()
}

zkLeafMap[string(dup(preimageKey))] = value

if parallel {
zkLeafMutex.Unlock()
}

if parallel && len(zkLeafMap)%10000 == 0 {
fmt.Println("ZK Accounts Loaded:", len(zkLeafMap))
}
}, parallel, paranoid)
zkDone <- zkLeafMap
}()
return zkDone
}
10 changes: 10 additions & 0 deletions trie/secure_trie.go
@@ -65,6 +65,16 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
return &SecureTrie{trie: *trie, preimages: db.preimages}, nil
}

func NewSecureNoTracer(root common.Hash, db *Database) (*SecureTrie, error) {
t, err := NewSecure(root, db)
if err != nil {
return nil, err
}

t.trie.tracer = nil
return t, nil
}

// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *SecureTrie) Get(key []byte) []byte {
24 changes: 24 additions & 0 deletions trie/tracer.go
@@ -60,13 +60,21 @@ func newTracer() *tracer {
// blob internally. Don't change the value outside of function since
// it's not deep-copied.
func (t *tracer) onRead(path []byte, val []byte) {
if t == nil {
return
}

t.accessList[string(path)] = val
}

// onInsert tracks the newly inserted trie node. If it's already
// in the deletion set (resurrected node), then just wipe it from
// the deletion set as it's "untouched".
func (t *tracer) onInsert(path []byte) {
if t == nil {
return
}

if _, present := t.deletes[string(path)]; present {
delete(t.deletes, string(path))
return
@@ -78,6 +86,10 @@ func (t *tracer) onInsert(path []byte) {
// in the addition set, then just wipe it from the addition set
// as it's untouched.
func (t *tracer) onDelete(path []byte) {
if t == nil {
return
}

if _, present := t.inserts[string(path)]; present {
delete(t.inserts, string(path))
return
@@ -87,13 +99,21 @@ func (t *tracer) onDelete(path []byte) {

// reset clears the content tracked by tracer.
func (t *tracer) reset() {
if t == nil {
return
}

t.inserts = make(map[string]struct{})
t.deletes = make(map[string]struct{})
t.accessList = make(map[string][]byte)
}

// copy returns a deep copied tracer instance.
func (t *tracer) copy() *tracer {
if t == nil {
return nil
}

accessList := make(map[string][]byte, len(t.accessList))
for path, blob := range t.accessList {
accessList[path] = common.CopyBytes(blob)
@@ -107,6 +127,10 @@ func (t *tracer) copy() *tracer {

// deletedNodes returns a list of node paths which are deleted from the trie.
func (t *tracer) deletedNodes() []string {
if t == nil {
return nil
}

var paths []string
for path := range t.deletes {
// It's possible a few deleted nodes were embedded
49 changes: 49 additions & 0 deletions trie/zk_trie.go
@@ -238,3 +238,52 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead
func (t *ZkTrie) Witness() map[string]struct{} {
panic("not implemented")
}

func (t *ZkTrie) CountLeaves(cb func(key, value []byte), parallel, verifyNodeHashes bool) uint64 {
root, err := t.ZkTrie.Tree().Root()
if err != nil {
panic("CountLeaves cannot get root")
}
return t.countLeaves(root, cb, 0, parallel, verifyNodeHashes)
}

🛠️ Refactor suggestion

Replace panic with proper error handling.

Using panic in a library method can lead to unexpected application crashes. Return errors instead so callers can handle them appropriately.

-func (t *ZkTrie) CountLeaves(cb func(key, value []byte), parallel, verifyNodeHashes bool) uint64 {
+// CountLeaves counts the number of leaf nodes in the trie.
+// It accepts a callback function that is called for each leaf node with its key and value,
+// a parallel flag to enable concurrent counting, and a verification flag to verify node hashes.
+// Returns the total number of leaves and any error encountered.
+func (t *ZkTrie) CountLeaves(cb func(key, value []byte), parallel, verifyNodeHashes bool) (uint64, error) {
+	if cb == nil {
+		return 0, fmt.Errorf("callback function cannot be nil")
+	}
 	root, err := t.ZkTrie.Tree().Root()
 	if err != nil {
-		panic("CountLeaves cannot get root")
+		return 0, fmt.Errorf("failed to get root: %w", err)
 	}
-	return t.countLeaves(root, cb, 0, parallel, verifyNodeHashes)
+	return t.countLeaves(root, cb, 0, parallel, verifyNodeHashes)
}

func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth int, parallel, verifyNodeHashes bool) uint64 {
if root == nil {
return 0
}

rootNode, err := t.ZkTrie.Tree().GetNode(root)
if err != nil {
panic("countLeaves cannot get rootNode")
}

if rootNode.Type == zktrie.NodeTypeLeaf_New {
if verifyNodeHashes {
calculatedNodeHash, err := rootNode.NodeHash()
if err != nil {
panic("countLeaves cannot get calculatedNodeHash")
}
if *calculatedNodeHash != *root {
panic("countLeaves node hash mismatch")
}
}

cb(append([]byte{}, rootNode.NodeKey.Bytes()...), append([]byte{}, rootNode.Data()...))
return 1
} else {
if parallel && depth < 5 {
count := make(chan uint64)
leftT := t.Copy()
rightT := t.Copy()
go func() {
count <- leftT.countLeaves(rootNode.ChildL, cb, depth+1, parallel, verifyNodeHashes)
}()
go func() {
count <- rightT.countLeaves(rootNode.ChildR, cb, depth+1, parallel, verifyNodeHashes)
}()
return <-count + <-count
} else {
return t.countLeaves(rootNode.ChildL, cb, depth+1, parallel, verifyNodeHashes) + t.countLeaves(rootNode.ChildR, cb, depth+1, parallel, verifyNodeHashes)
}
}
}