Skip to content

Commit d0e56b4

Browse files
authored
增加treemap的Iterate (#278)
* 增加treemap的Iterate function * 修复go lint * update
1 parent 40e0aa7 commit d0e56b4

File tree

5 files changed

+177
-4
lines changed

5 files changed

+177
-4
lines changed

internal/tree/red_black_tree.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,24 @@ func (rb *RBTree[K, V]) KeyValues() ([]K, []V) {
118118
if rb.root == nil {
119119
return keys, values
120120
}
121-
rb.inOrderTraversal(func(node *rbNode[K, V]) {
121+
rb.inOrderTraversal(func(node *rbNode[K, V]) bool {
122122
keys = append(keys, node.key)
123123
values = append(values, node.value)
124+
return true
124125
})
125126
return keys, values
126127
}
127128

129+
// Iterate 按照key的顺序遍历并执行cb,如果cb返回值为false则结束遍历,否则继续遍历
130+
func (rb *RBTree[K, V]) Iterate(cb func(key K, value V) bool) {
131+
rb.inOrderTraversal(func(node *rbNode[K, V]) bool {
132+
return cb(node.key, node.value)
133+
})
134+
}
135+
128136
// inOrderTraversal 中序遍历
129-
func (rb *RBTree[K, V]) inOrderTraversal(visit func(node *rbNode[K, V])) {
130-
stack := make([]*rbNode[K, V], 0, rb.size)
137+
func (rb *RBTree[K, V]) inOrderTraversal(visit func(node *rbNode[K, V]) bool) {
138+
stack := make([]*rbNode[K, V], 0)
131139
curr := rb.root
132140
for curr != nil || len(stack) > 0 {
133141
for curr != nil {
@@ -136,7 +144,9 @@ func (rb *RBTree[K, V]) inOrderTraversal(visit func(node *rbNode[K, V])) {
136144
}
137145
curr = stack[len(stack)-1]
138146
stack = stack[:len(stack)-1]
139-
visit(curr)
147+
if !visit(curr) {
148+
break
149+
}
140150
curr = curr.right
141151
}
142152
}

internal/tree/red_black_tree_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,6 +1397,73 @@ func TestRBTree_KeyValues(t *testing.T) {
13971397
}
13981398
}
13991399

1400+
func TestRBTree_Iterate(t *testing.T) {
1401+
for _, testCase := range []struct {
1402+
name string
1403+
expectedLen int
1404+
inputStart int
1405+
inputEnd int
1406+
// 如果为true则遍历结束
1407+
endConditionFunc func(key int) bool
1408+
}{
1409+
{
1410+
name: "treeMap为空",
1411+
expectedLen: 0,
1412+
inputStart: 1,
1413+
inputEnd: 0,
1414+
endConditionFunc: func(key int) bool {
1415+
return false
1416+
},
1417+
},
1418+
{
1419+
name: "treeMap 有10000个元素,遍历所有小于等于8000的元素",
1420+
expectedLen: 8000,
1421+
inputStart: 1,
1422+
inputEnd: 10000,
1423+
endConditionFunc: func(key int) bool {
1424+
return key > 8000
1425+
},
1426+
},
1427+
{
1428+
name: "treeMap 有10000个元素,遍历所有元素",
1429+
expectedLen: 10000,
1430+
inputStart: 1,
1431+
inputEnd: 10000,
1432+
endConditionFunc: func(key int) bool {
1433+
return false
1434+
},
1435+
},
1436+
{
1437+
name: "treeMap 有10个元素,由于第一个就不符合条件所以遍历立刻中断",
1438+
expectedLen: 0,
1439+
inputStart: 1,
1440+
inputEnd: 10,
1441+
endConditionFunc: func(key int) bool {
1442+
return key < 5
1443+
},
1444+
},
1445+
} {
1446+
t.Run(testCase.name, func(t *testing.T) {
1447+
rbTree := NewRBTree[int, int](compare())
1448+
for i := testCase.inputStart; i <= testCase.inputEnd; i++ {
1449+
assert.Nil(t, rbTree.Add(i, i))
1450+
}
1451+
arr := make([]int, 0)
1452+
rbTree.Iterate(func(key, value int) bool {
1453+
if testCase.endConditionFunc(key) {
1454+
return false
1455+
}
1456+
arr = append(arr, value)
1457+
return true
1458+
})
1459+
assert.Equal(t, testCase.expectedLen, len(arr))
1460+
for i := 0; i < testCase.expectedLen; i++ {
1461+
assert.Equal(t, testCase.inputStart+i, arr[i])
1462+
}
1463+
})
1464+
}
1465+
}
1466+
14001467
// IsRedBlackTree 检测是否满足红黑树
14011468
func IsRedBlackTree[K any, V any](root *rbNode[K, V]) bool {
14021469
// 检测节点是否黑色

mapx/treemap.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,9 @@ func (treeMap *TreeMap[T, V]) Len() int64 {
101101
return int64(treeMap.tree.Size())
102102
}
103103

104+
// Iterate 按照key的顺序遍历并执行cb,如果cb返回值为false则结束遍历,否则继续遍历
105+
func (treeMap *TreeMap[K, V]) Iterate(cb func(key K, value V) bool) {
106+
treeMap.tree.Iterate(cb)
107+
}
108+
104109
var _ mapi[any, any] = (*TreeMap[any, any])(nil)

mapx/treemap_example_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,26 @@ func ExampleNewTreeMap() {
2929
// Output:
3030
// 11
3131
}
32+
33+
func ExampleTreeMap_Iterate() {
34+
m, _ := mapx.NewTreeMap[int, int](ekit.ComparatorRealNumber[int])
35+
_ = m.Put(1, 11)
36+
_ = m.Put(-1, 12)
37+
_ = m.Put(100, 13)
38+
_ = m.Put(-100, 14)
39+
_ = m.Put(-101, 15)
40+
41+
m.Iterate(func(key, value int) bool {
42+
if key > 1 {
43+
return false
44+
}
45+
fmt.Println(key, value)
46+
return true
47+
})
48+
49+
// Output:
50+
// -101 15
51+
// -100 14
52+
// -1 12
53+
// 1 11
54+
}

mapx/treemap_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,74 @@ func TestTreeMap_Len(t *testing.T) {
445445
}
446446
}
447447

448+
func TestRBTree_Iterate(t *testing.T) {
449+
for _, testCase := range []struct {
450+
name string
451+
expectedLen int
452+
inputStart int
453+
inputEnd int
454+
// 如果为true则遍历结束
455+
endConditionFunc func(key int) bool
456+
}{
457+
{
458+
name: "treeMap为空",
459+
expectedLen: 0,
460+
inputStart: 1,
461+
inputEnd: 0,
462+
endConditionFunc: func(key int) bool {
463+
return false
464+
},
465+
},
466+
{
467+
name: "treeMap 有10000个元素,遍历所有小于等于8000的元素",
468+
expectedLen: 8000,
469+
inputStart: 1,
470+
inputEnd: 10000,
471+
endConditionFunc: func(key int) bool {
472+
return key > 8000
473+
},
474+
},
475+
{
476+
name: "treeMap 有10000个元素,遍历所有元素",
477+
expectedLen: 10000,
478+
inputStart: 1,
479+
inputEnd: 10000,
480+
endConditionFunc: func(key int) bool {
481+
return false
482+
},
483+
},
484+
{
485+
name: "treeMap 有10个元素,由于第一个就不符合条件所以遍历立刻中断",
486+
expectedLen: 0,
487+
inputStart: 1,
488+
inputEnd: 10,
489+
endConditionFunc: func(key int) bool {
490+
return key < 5
491+
},
492+
},
493+
} {
494+
t.Run(testCase.name, func(t *testing.T) {
495+
treeMap, err := NewTreeMap[int, int](compare())
496+
assert.Nil(t, err)
497+
for i := testCase.inputStart; i <= testCase.inputEnd; i++ {
498+
assert.Nil(t, treeMap.Put(i, i))
499+
}
500+
arr := make([]int, 0)
501+
treeMap.Iterate(func(key, value int) bool {
502+
if testCase.endConditionFunc(key) {
503+
return false
504+
}
505+
arr = append(arr, value)
506+
return true
507+
})
508+
assert.Equal(t, testCase.expectedLen, len(arr))
509+
for i := 0; i < testCase.expectedLen; i++ {
510+
assert.Equal(t, testCase.inputStart+i, arr[i])
511+
}
512+
})
513+
}
514+
}
515+
448516
func compare() ekit.Comparator[int] {
449517
return ekit.ComparatorRealNumber[int]
450518
}

0 commit comments

Comments
 (0)