
Commit d7d8a0c

Feat: Stringer Identity + Optimize
- Reduce potential debug contention by using compare-and-swap atomics
- Add the ability to use fmt.Stringer implementations for Identity functionality (not sure why I ever did anything else, to be honest)
- Complete test coverage
1 parent e008c05 · commit d7d8a0c

File tree: 4 files changed (+115, −57 lines)


debug.go

Lines changed: 19 additions & 20 deletions
@@ -1,18 +1,21 @@
 package rate5
 
-import "fmt"
+import (
+	"fmt"
+	"sync/atomic"
+)
 
 func (q *Limiter) debugPrintf(format string, a ...interface{}) {
-	q.debugMutex.RLock()
-	defer q.debugMutex.RUnlock()
-	if !q.debug {
+	if atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugDisabled) {
 		return
 	}
 	msg := fmt.Sprintf(format, a...)
 	select {
 	case q.debugChannel <- msg:
+		//
 	default:
-		println(msg)
+		// drop the message but increment the lost counter
+		atomic.AddInt64(&q.debugLost, 1)
 	}
 }
 
@@ -23,26 +26,22 @@ func (q *Limiter) setDebugEvict() {
 }
 
 func (q *Limiter) SetDebug(on bool) {
-	q.debugMutex.Lock()
-	if !on {
-		q.debug = false
-		q.Patrons.OnEvicted(nil)
-		q.debugMutex.Unlock()
-		return
+	switch on {
+	case true:
+		atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled)
+		q.debugPrintf("rate5 debug enabled")
+	case false:
+		atomic.CompareAndSwapUint32(&q.debug, DebugEnabled, DebugDisabled)
 	}
-	q.debug = on
-	q.setDebugEvict()
-	q.debugMutex.Unlock()
-	q.debugPrintf("rate5 debug enabled")
 }
 
 // DebugChannel enables debug mode and returns a channel where debug messages are sent.
-// NOTE: You must read from this channel if created via this function or it will block
+//
+// NOTE: If you do not read from this channel, the debug messages will eventually be lost.
+// If this happens,
 func (q *Limiter) DebugChannel() chan string {
 	defer func() {
-		q.debugMutex.Lock()
-		q.debug = true
-		q.debugMutex.Unlock()
+		atomic.CompareAndSwapUint32(&q.debug, DebugDisabled, DebugEnabled)
 	}()
 	q.debugMutex.RLock()
 	if q.debugChannel != nil {
@@ -52,7 +51,7 @@ func (q *Limiter) DebugChannel() chan string {
 	q.debugMutex.RUnlock()
 	q.debugMutex.Lock()
 	defer q.debugMutex.Unlock()
-	q.debugChannel = make(chan string, 25)
+	q.debugChannel = make(chan string, 55)
 	q.setDebugEvict()
 	return q.debugChannel
 }
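
The debugPrintf change above swaps the read lock for a lock-free flag check: CompareAndSwapUint32 with identical old and new values never modifies the flag, it only reports whether the current value matches. A minimal standalone sketch of that pattern, reusing the DebugDisabled/DebugEnabled constants added in models.go (the debugIsOff helper is illustrative and not part of rate5):

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	DebugDisabled uint32 = iota
	DebugEnabled
)

// debugIsOff reports whether flag currently equals DebugDisabled.
// Swapping DebugDisabled for DebugDisabled is a no-op, so the CAS acts
// purely as an atomic equality test and never blocks other goroutines.
func debugIsOff(flag *uint32) bool {
	return atomic.CompareAndSwapUint32(flag, DebugDisabled, DebugDisabled)
}

func main() {
	var debug uint32 = DebugDisabled
	fmt.Println(debugIsOff(&debug)) // true: debug output is skipped

	atomic.StoreUint32(&debug, DebugEnabled)
	fmt.Println(debugIsOff(&debug)) // false: debug output is emitted
}

An equivalent read would be atomic.LoadUint32(flag) == DebugDisabled; the CAS form used in the commit behaves the same because the value it would swap in is identical to the value it expects.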

models.go

Lines changed: 18 additions & 3 deletions
@@ -1,6 +1,7 @@
 package rate5
 
 import (
+	"fmt"
 	"sync"
 
 	"github.com/patrickmn/go-cache"
@@ -18,19 +19,33 @@ type Identity interface {
 	UniqueKey() string
 }
 
+// IdentityStringer is an implentation of Identity that acts as a shim for types that implement fmt.Stringer.
+type IdentityStringer struct {
+	stringer fmt.Stringer
+}
+
+func (i IdentityStringer) UniqueKey() string {
+	return i.stringer.String()
+}
+
+const (
+	DebugDisabled uint32 = iota
+	DebugEnabled
+)
+
 // Limiter implements an Enforcer to create an arbitrary ratelimiter.
 type Limiter struct {
-	// Source is the implementation of the Identity interface. It is used to create a unique key for each request.
-	Source Identity
 	// Patrons gives access to the underlying cache type that powers the ratelimiter.
 	// It is exposed for testing purposes.
 	Patrons *cache.Cache
+
 	// Ruleset determines the Policy which is used to determine whether or not to ratelimit.
 	// It consists of a Window and Burst, see Policy for more details.
 	Ruleset Policy
 
-	debug        bool
+	debug        uint32
 	debugChannel chan string
+	debugLost    int64
 	known        map[interface{}]*int64
 	debugMutex   *sync.RWMutex
 	*sync.RWMutex
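
IdentityStringer simply forwards UniqueKey to the wrapped value's String method. The sketch below re-declares the two types locally to show that delegation in isolation; outside the rate5 package the stringer field is unexported, so callers would normally go through CheckStringer or PeekStringer instead (added in ratelimiter.go below). The visitor type is hypothetical, not part of the library:

package main

import "fmt"

// Local copies of the rate5 declarations, for illustration only.
type Identity interface {
	UniqueKey() string
}

type IdentityStringer struct {
	stringer fmt.Stringer
}

func (i IdentityStringer) UniqueKey() string {
	return i.stringer.String()
}

// visitor only implements fmt.Stringer; it knows nothing about Identity.
type visitor struct {
	addr string
}

func (v visitor) String() string {
	return v.addr
}

func main() {
	var id Identity = IdentityStringer{stringer: visitor{addr: "10.0.0.1"}}
	fmt.Println(id.UniqueKey()) // prints 10.0.0.1
}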

ratelimiter.go

Lines changed: 16 additions & 3 deletions
@@ -1,6 +1,7 @@
 package rate5
 
 import (
+	"fmt"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -57,10 +58,12 @@ func NewStrictLimiter(window int, burst int) *Limiter {
 	})
 }
 
-/*NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled.
+/*
+NewHardcoreLimiter returns a custom limiter with Strict + Hardcore modes enabled.
 
 Hardcore mode causes the time limited to be multiplied by the number of hits.
-This differs from strict mode which is only using addition instead of multiplication.*/
+This differs from strict mode which is only using addition instead of multiplication.
+*/
 func NewHardcoreLimiter(window int, burst int) *Limiter {
 	l := NewStrictLimiter(window, burst)
 	l.Ruleset.Hardcore = true
@@ -80,7 +83,7 @@ func newLimiter(policy Policy) *Limiter {
 		known:      make(map[interface{}]*int64),
 		RWMutex:    &sync.RWMutex{},
 		debugMutex: &sync.RWMutex{},
-		debug:      false,
+		debug:      DebugDisabled,
 	}
 }
 
@@ -122,6 +125,11 @@ func (q *Limiter) strictLogic(src string, count int64) {
 	q.debugPrintf("%s ratelimit for %s: last count %d. time: %s", prefix, src, count, exttime)
 }
 
+func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
+	targ := IdentityStringer{stringer: from}
+	return q.Check(targ)
+}
+
 // Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
 func (q *Limiter) Check(from Identity) (limited bool) {
 	var count int64
@@ -159,3 +167,8 @@ func (q *Limiter) Peek(from Identity) bool {
 	}
 	return false
 }
+
+func (q *Limiter) PeekStringer(from fmt.Stringer) bool {
+	targ := IdentityStringer{stringer: from}
+	return q.Peek(targ)
+}
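
CheckStringer and PeekStringer wrap any fmt.Stringer in an IdentityStringer before delegating to Check and Peek, so callers no longer need to implement Identity themselves. A rough usage sketch, written as if it sat alongside the package so no import path has to be assumed (clientAddr, allow, and status are hypothetical names):

// clientAddr already satisfies fmt.Stringer; that is all the limiter needs now.
type clientAddr struct {
	ip string
}

func (c clientAddr) String() string {
	return c.ip
}

// allow records a hit for c and reports whether the request may proceed.
func allow(limiter *Limiter, c clientAddr) bool {
	return !limiter.CheckStringer(c)
}

// status inspects the current ratelimit state for c without recording a hit.
func status(limiter *Limiter, c clientAddr) bool {
	return limiter.PeekStringer(c)
}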

ratelimiter_test.go

Lines changed: 62 additions & 31 deletions
@@ -48,7 +48,6 @@ var (
 )
 
 func watchDebug(ctx context.Context, r *Limiter, t *testing.T) {
-	t.Helper()
 	watchDebugMutex.Lock()
 	defer watchDebugMutex.Unlock()
 	rd := r.DebugChannel()
@@ -68,25 +67,28 @@ func watchDebug(ctx context.Context, r *Limiter, t *testing.T) {
 	}
 }
 
-func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe bool) {
-	t.Helper()
+func peekCheckLimited(t *testing.T, limiter *Limiter, shouldbe, stringer bool) {
+	limited := limiter.Peek(dummyTicker)
+	if stringer {
+		limited = limiter.PeekStringer(dummyTicker)
+	}
 	switch {
-	case limiter.Peek(dummyTicker) && !shouldbe:
+	case limited && !shouldbe:
 		if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
 			t.Errorf("Should not have been limited. Ratelimiter count: %d", ct)
 		} else {
 			t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
 		}
-	case !limiter.Peek(dummyTicker) && shouldbe:
+	case !limited && shouldbe:
 		if ct, ok := limiter.Patrons.Get(dummyTicker.UniqueKey()); ok {
 			t.Errorf("Should have been limited. Ratelimiter count: %d", ct)
 		} else {
 			t.Fatalf("dummyTicker does not exist in ratelimiter at all!")
 		}
-	case limiter.Peek(dummyTicker) && shouldbe:
-		t.Logf("dummyTicker is limited")
-	case !limiter.Peek(dummyTicker) && !shouldbe:
-		t.Logf("dummyTicker is not limited")
+	case limited && shouldbe:
+		t.Logf("dummyTicker is limited as expected.")
+	case !limited && !shouldbe:
+		t.Logf("dummyTicker is not limited as expected.")
 	}
 }
 
@@ -105,6 +107,10 @@ func (tick *ticker) UniqueKey() string {
 	return "TestItem"
 }
 
+func (tick *ticker) String() string {
+	return "TestItem"
+}
+
 func Test_ResetItem(t *testing.T) {
 	limiter := NewLimiter(500, 1)
 	ctx, cancel := context.WithCancel(context.Background())
@@ -114,26 +120,36 @@ func Test_ResetItem(t *testing.T) {
 		limiter.Check(dummyTicker)
 	}
 	limiter.ResetItem(dummyTicker)
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	cancel()
 }
 
 func Test_NewDefaultLimiter(t *testing.T) {
 	limiter := NewDefaultLimiter()
 	limiter.Check(dummyTicker)
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	for n := 0; n != DefaultBurst; n++ {
 		limiter.Check(dummyTicker)
 	}
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
+}
+
+func Test_CheckAndPeekStringer(t *testing.T) {
+	limiter := NewDefaultLimiter()
+	limiter.CheckStringer(dummyTicker)
+	peekCheckLimited(t, limiter, false, true)
+	for n := 0; n != DefaultBurst; n++ {
+		limiter.CheckStringer(dummyTicker)
+	}
+	peekCheckLimited(t, limiter, true, true)
 }
 
 func Test_NewLimiter(t *testing.T) {
 	limiter := NewLimiter(5, 1)
 	limiter.Check(dummyTicker)
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	limiter.Check(dummyTicker)
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
 }
 
 func Test_NewDefaultStrictLimiter(t *testing.T) {
@@ -144,9 +160,9 @@ func Test_NewDefaultStrictLimiter(t *testing.T) {
 	for n := 0; n < 25; n++ {
 		limiter.Check(dummyTicker)
 	}
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	limiter.Check(dummyTicker)
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
 	cancel()
 	limiter = nil
 }
@@ -156,23 +172,23 @@ func Test_NewStrictLimiter(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	go watchDebug(ctx, limiter, t)
 	limiter.Check(dummyTicker)
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	limiter.Check(dummyTicker)
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
 	limiter.Check(dummyTicker)
 	// for coverage, first we give the debug messages a couple seconds to be safe,
 	// then we wait for the cache eviction to trigger a debug message.
 	time.Sleep(2 * time.Second)
 	t.Logf(<-limiter.DebugChannel())
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	for n := 0; n != 6; n++ {
 		limiter.Check(dummyTicker)
 	}
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
 	time.Sleep(5 * time.Second)
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
 	time.Sleep(8 * time.Second)
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	cancel()
 	limiter = nil
 }
@@ -184,35 +200,35 @@ func Test_NewHardcoreLimiter(t *testing.T) {
 	for n := 0; n != 4; n++ {
 		limiter.Check(dummyTicker)
 	}
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	if !limiter.Check(dummyTicker) {
 		t.Errorf("Should have been limited")
 	}
 	t.Logf("limited once, waiting for cache eviction")
 	time.Sleep(2 * time.Second)
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	for n := 0; n != 4; n++ {
 		limiter.Check(dummyTicker)
 	}
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	if !limiter.Check(dummyTicker) {
 		t.Errorf("Should have been limited")
 	}
 	limiter.Check(dummyTicker)
 	limiter.Check(dummyTicker)
 	time.Sleep(3 * time.Second)
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
 	time.Sleep(5 * time.Second)
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	for n := 0; n != 4; n++ {
 		limiter.Check(dummyTicker)
 	}
-	peekCheckLimited(t, limiter, false)
+	peekCheckLimited(t, limiter, false, false)
 	for n := 0; n != 10; n++ {
 		limiter.Check(dummyTicker)
 	}
 	time.Sleep(10 * time.Second)
-	peekCheckLimited(t, limiter, true)
+	peekCheckLimited(t, limiter, true, false)
 	cancel()
 	// for coverage, triggering the switch statement case for hardcore logic
 	limiter2 := NewHardcoreLimiter(2, 5)
@@ -221,9 +237,9 @@ func Test_NewHardcoreLimiter(t *testing.T) {
 	for n := 0; n != 6; n++ {
 		limiter2.Check(dummyTicker)
 	}
-	peekCheckLimited(t, limiter2, true)
+	peekCheckLimited(t, limiter2, true, false)
 	time.Sleep(4 * time.Second)
-	peekCheckLimited(t, limiter2, false)
+	peekCheckLimited(t, limiter2, false, false)
 	cancel2()
 }
 
@@ -314,3 +330,18 @@ func Test_ConcurrentShouldLimit(t *testing.T) {
 	concurrentTest(t, 50, 21, 20, true)
 	concurrentTest(t, 50, 51, 50, true)
 }
+
+func Test_debugChannelOverflow(t *testing.T) {
+	limiter := NewDefaultLimiter()
+	_ = limiter.DebugChannel()
+	for n := 0; n != 78; n++ {
+		limiter.Check(dummyTicker)
+		if limiter.debugLost > 0 {
+			t.Fatalf("debug channel overflowed")
+		}
+	}
+	limiter.Check(dummyTicker)
+	if limiter.debugLost == 0 {
+		t.Fatalf("debug channel did not overflow")
+	}
+}
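
Test_debugChannelOverflow pins down the new behavior: once the 55-slot debug buffer fills up, further messages are dropped and debugLost is incremented rather than blocking or falling back to println. A consumer that wants every message should therefore keep the channel drained. A hedged sketch of one way to do that, written in-package and assuming the standard context and log packages are imported (drainDebug is a hypothetical helper, not part of rate5):

// drainDebug keeps the buffered debug channel empty so that messages are
// not dropped and debugLost stays at zero. Cancel the context to stop it.
func drainDebug(ctx context.Context, limiter *Limiter) {
	msgs := limiter.DebugChannel() // also flips the debug flag on
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case msg := <-msgs:
				log.Println("rate5:", msg)
			}
		}
	}()
}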
