@@ -18,7 +18,7 @@ package trie
18
18
19
19
import (
20
20
"bytes"
21
-
21
+ "container/heap"
22
22
"github.com/ethereum/go-ethereum/common"
23
23
)
24
24
@@ -268,6 +268,26 @@ outer:
268
268
return nil
269
269
}
270
270
271
+ func compareNodes (a , b NodeIterator ) int {
272
+ cmp := bytes .Compare (a .Path (), b .Path ())
273
+ if cmp != 0 {
274
+ return cmp
275
+ }
276
+
277
+ if a .Leaf () && ! b .Leaf () {
278
+ return - 1
279
+ } else if b .Leaf () && ! a .Leaf () {
280
+ return 1
281
+ }
282
+
283
+ cmp = bytes .Compare (a .Hash ().Bytes (), b .Hash ().Bytes ())
284
+ if cmp != 0 {
285
+ return cmp
286
+ }
287
+
288
+ return bytes .Compare (a .LeafBlob (), b .LeafBlob ())
289
+ }
290
+
271
291
type differenceIterator struct {
272
292
a , b NodeIterator // Nodes returned are those in b - a.
273
293
eof bool // Indicates a has run out of elements
@@ -321,8 +341,7 @@ func (it *differenceIterator) Next(bool) bool {
321
341
}
322
342
323
343
for {
324
- apath , bpath := it .a .Path (), it .b .Path ()
325
- switch bytes .Compare (apath , bpath ) {
344
+ switch compareNodes (it .a , it .b ) {
326
345
case - 1 :
327
346
// b jumped past a; advance a
328
347
if ! it .a .Next (true ) {
@@ -334,15 +353,6 @@ func (it *differenceIterator) Next(bool) bool {
334
353
// b is before a
335
354
return true
336
355
case 0 :
337
- if it .a .Hash () != it .b .Hash () || it .a .Leaf () != it .b .Leaf () {
338
- // Keys are identical, but hashes or leaf status differs
339
- return true
340
- }
341
- if it .a .Leaf () && it .b .Leaf () && ! bytes .Equal (it .a .LeafBlob (), it .b .LeafBlob ()) {
342
- // Both are leaf nodes, but with different values
343
- return true
344
- }
345
-
346
356
// a and b are identical; skip this whole subtree if the nodes have hashes
347
357
hasHash := it .a .Hash () == common.Hash {}
348
358
if ! it .b .Next (hasHash ) {
@@ -364,3 +374,107 @@ func (it *differenceIterator) Error() error {
364
374
}
365
375
return it .b .Error ()
366
376
}
377
+
378
+ type nodeIteratorHeap []NodeIterator
379
+
380
+ func (h nodeIteratorHeap ) Len () int { return len (h ) }
381
+ func (h nodeIteratorHeap ) Less (i , j int ) bool { return compareNodes (h [i ], h [j ]) < 0 }
382
+ func (h nodeIteratorHeap ) Swap (i , j int ) { h [i ], h [j ] = h [j ], h [i ] }
383
+ func (h * nodeIteratorHeap ) Push (x interface {}) { * h = append (* h , x .(NodeIterator )) }
384
+ func (h * nodeIteratorHeap ) Pop () interface {} {
385
+ n := len (* h )
386
+ x := (* h )[n - 1 ]
387
+ * h = (* h )[0 : n - 1 ]
388
+ return x
389
+ }
390
+
391
+ type unionIterator struct {
392
+ items * nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
393
+ count int // Number of nodes scanned across all tries
394
+ err error // The error, if one has been encountered
395
+ }
396
+
397
+ // NewUnionIterator constructs a NodeIterator that iterates over elements in the union
398
+ // of the provided NodeIterators. Returns the iterator, and a pointer to an integer
399
+ // recording the number of nodes visited.
400
+ func NewUnionIterator (iters []NodeIterator ) (NodeIterator , * int ) {
401
+ h := make (nodeIteratorHeap , len (iters ))
402
+ copy (h , iters )
403
+ heap .Init (& h )
404
+
405
+ ui := & unionIterator {
406
+ items : & h ,
407
+ }
408
+ return ui , & ui .count
409
+ }
410
+
411
+ func (it * unionIterator ) Hash () common.Hash {
412
+ return (* it .items )[0 ].Hash ()
413
+ }
414
+
415
+ func (it * unionIterator ) Parent () common.Hash {
416
+ return (* it .items )[0 ].Parent ()
417
+ }
418
+
419
+ func (it * unionIterator ) Leaf () bool {
420
+ return (* it .items )[0 ].Leaf ()
421
+ }
422
+
423
+ func (it * unionIterator ) LeafBlob () []byte {
424
+ return (* it .items )[0 ].LeafBlob ()
425
+ }
426
+
427
+ func (it * unionIterator ) Path () []byte {
428
+ return (* it .items )[0 ].Path ()
429
+ }
430
+
431
+ // Next returns the next node in the union of tries being iterated over.
432
+ //
433
+ // It does this by maintaining a heap of iterators, sorted by the iteration
434
+ // order of their next elements, with one entry for each source trie. Each
435
+ // time Next() is called, it takes the least element from the heap to return,
436
+ // advancing any other iterators that also point to that same element. These
437
+ // iterators are called with descend=false, since we know that any nodes under
438
+ // these nodes will also be duplicates, found in the currently selected iterator.
439
+ // Whenever an iterator is advanced, it is pushed back into the heap if it still
440
+ // has elements remaining.
441
+ //
442
+ // In the case that descend=false - eg, we're asked to ignore all subnodes of the
443
+ // current node - we also advance any iterators in the heap that have the current
444
+ // path as a prefix.
445
+ func (it * unionIterator ) Next (descend bool ) bool {
446
+ if len (* it .items ) == 0 {
447
+ return false
448
+ }
449
+
450
+ // Get the next key from the union
451
+ least := heap .Pop (it .items ).(NodeIterator )
452
+
453
+ // Skip over other nodes as long as they're identical, or, if we're not descending, as
454
+ // long as they have the same prefix as the current node.
455
+ for len (* it .items ) > 0 && ((! descend && bytes .HasPrefix ((* it .items )[0 ].Path (), least .Path ())) || compareNodes (least , (* it .items )[0 ]) == 0 ) {
456
+ skipped := heap .Pop (it .items ).(NodeIterator )
457
+ // Skip the whole subtree if the nodes have hashes; otherwise just skip this node
458
+ if skipped .Next (skipped .Hash () == common.Hash {}) {
459
+ it .count += 1
460
+ // If there are more elements, push the iterator back on the heap
461
+ heap .Push (it .items , skipped )
462
+ }
463
+ }
464
+
465
+ if least .Next (descend ) {
466
+ it .count += 1
467
+ heap .Push (it .items , least )
468
+ }
469
+
470
+ return len (* it .items ) > 0
471
+ }
472
+
473
+ func (it * unionIterator ) Error () error {
474
+ for i := 0 ; i < len (* it .items ); i ++ {
475
+ if err := (* it .items )[i ].Error (); err != nil {
476
+ return err
477
+ }
478
+ }
479
+ return nil
480
+ }
0 commit comments