@@ -9,6 +9,29 @@ import (
9
9
"github.com/patrickmn/go-cache"
10
10
)
11
11
12
+ const (
13
+ strictPrefix = "strict"
14
+ hardcorePrefix = "hardcore"
15
+ )
16
+
17
+ var _counters = & sync.Pool {
18
+ New : func () interface {} {
19
+ i := & atomic.Int64 {}
20
+ i .Store (0 )
21
+ return i
22
+ },
23
+ }
24
+
25
+ func getCounter () * atomic.Int64 {
26
+ got := _counters .Get ().(* atomic.Int64 )
27
+ got .Store (0 )
28
+ return got
29
+ }
30
+
31
+ func putCounter (i * atomic.Int64 ) {
32
+ _counters .Put (i )
33
+ }
34
+
12
35
/*NewDefaultLimiter returns a ratelimiter with default settings without Strict mode.
13
36
* Default window: 25 seconds
14
37
* Default burst: 25 requests */
@@ -70,28 +93,40 @@ func NewHardcoreLimiter(window int, burst int) *Limiter {
70
93
return l
71
94
}
72
95
96
+ // ResetItem removes an Identity from the limiter's cache.
97
+ // This effectively resets the rate limit for the Identity.
73
98
func (q * Limiter ) ResetItem (from Identity ) {
74
99
q .Patrons .Delete (from .UniqueKey ())
75
- q .debugPrintf ("ratelimit for %s has been reset" , from .UniqueKey ())
100
+ q .debugPrintf (msgRateLimitedRst , from .UniqueKey ())
101
+ }
102
+
103
+ func (q * Limiter ) onEvict (src string , count interface {}) {
104
+ q .debugPrintf (msgRateLimitExpired , src , count )
105
+ putCounter (count .(* atomic.Int64 ))
106
+
76
107
}
77
108
78
109
func newLimiter (policy Policy ) * Limiter {
79
110
window := time .Duration (policy .Window ) * time .Second
80
- return & Limiter {
111
+ q := & Limiter {
81
112
Ruleset : policy ,
82
113
Patrons : cache .New (window , time .Duration (policy .Window )* time .Second ),
83
- known : make (map [interface {}]* int64 ),
114
+ known : make (map [interface {}]* atomic. Int64 ),
84
115
RWMutex : & sync.RWMutex {},
85
116
debugMutex : & sync.RWMutex {},
86
117
debug : DebugDisabled ,
87
118
}
119
+ q .Patrons .OnEvicted (q .onEvict )
120
+ return q
88
121
}
89
122
90
- func intPtr (i int64 ) * int64 {
91
- return & i
123
+ func intPtr (i int64 ) * atomic.Int64 {
124
+ a := getCounter ()
125
+ a .Store (i )
126
+ return a
92
127
}
93
128
94
- func (q * Limiter ) getHitsPtr (src string ) * int64 {
129
+ func (q * Limiter ) getHitsPtr (src string ) * atomic. Int64 {
95
130
q .RLock ()
96
131
if _ , ok := q .known [src ]; ok {
97
132
oldPtr := q .known [src ]
@@ -100,29 +135,29 @@ func (q *Limiter) getHitsPtr(src string) *int64 {
100
135
}
101
136
q .RUnlock ()
102
137
q .Lock ()
103
- newPtr := intPtr ( 0 )
138
+ newPtr := getCounter ( )
104
139
q .known [src ] = newPtr
105
140
q .Unlock ()
106
141
return newPtr
107
142
}
108
143
109
- func (q * Limiter ) strictLogic (src string , count int64 ) {
144
+ func (q * Limiter ) strictLogic (src string , count * atomic. Int64 ) {
110
145
knownHits := q .getHitsPtr (src )
111
- atomic . AddInt64 ( knownHits , 1 )
146
+ knownHits . Add ( 1 )
112
147
var extwindow int64
113
- prefix := "hardcore"
148
+ prefix := hardcorePrefix
114
149
switch {
115
150
case q .Ruleset .Hardcore && q .Ruleset .Window > 1 :
116
- extwindow = atomic . LoadInt64 ( knownHits ) * q .Ruleset .Window
151
+ extwindow = knownHits . Load ( ) * q .Ruleset .Window
117
152
case q .Ruleset .Hardcore && q .Ruleset .Window <= 1 :
118
- extwindow = atomic . LoadInt64 ( knownHits ) * 2
153
+ extwindow = knownHits . Load ( ) * 2
119
154
case ! q .Ruleset .Hardcore :
120
- prefix = "strict"
121
- extwindow = atomic . LoadInt64 ( knownHits ) + q .Ruleset .Window
155
+ prefix = strictPrefix
156
+ extwindow = knownHits . Load ( ) + q .Ruleset .Window
122
157
}
123
158
exttime := time .Duration (extwindow ) * time .Second
124
159
_ = q .Patrons .Replace (src , count , exttime )
125
- q .debugPrintf ("%s ratelimit for %s: last count %d. time: %s" , prefix , src , count , exttime )
160
+ q .debugPrintf (msgRateLimitStrict , prefix , src , count . Load () , exttime )
126
161
}
127
162
128
163
func (q * Limiter ) CheckStringer (from fmt.Stringer ) bool {
@@ -133,33 +168,32 @@ func (q *Limiter) CheckStringer(from fmt.Stringer) bool {
133
168
// Check checks and increments an Identities UniqueKey() output against a list of cached strings to determine and raise it's ratelimitting status.
134
169
func (q * Limiter ) Check (from Identity ) (limited bool ) {
135
170
var count int64
136
- var err error
137
- src := from .UniqueKey ()
138
- count , err = q .Patrons .IncrementInt64 (src , 1 )
139
- if err != nil {
140
- // IncrementInt64 should only error if the value is not an int64, so we can assume it's a new key.
141
- q .debugPrintf ("ratelimit %s (new) " , src )
171
+ aval , ok := q .Patrons .Get (from .UniqueKey ())
172
+ switch {
173
+ case ! ok :
174
+ q .debugPrintf (msgRateLimitedNew , from .UniqueKey ())
175
+ aval = intPtr (1 )
142
176
// We can't reproduce this throwing an error, we can only assume that the key is new.
143
- _ = q .Patrons .Add (src , int64 (1 ), time .Duration (q .Ruleset .Window )* time .Second )
144
- return false
145
- }
146
- if count < q .Ruleset .Burst {
177
+ _ = q .Patrons .Add (from .UniqueKey (), aval , time .Duration (q .Ruleset .Window )* time .Second )
147
178
return false
179
+ case aval != nil :
180
+ count = aval .(* atomic.Int64 ).Add (1 )
181
+ if count < q .Ruleset .Burst {
182
+ return false
183
+ }
148
184
}
149
185
if q .Ruleset .Strict {
150
- q .strictLogic (src , count )
151
- } else {
152
- q .debugPrintf ("ratelimit %s: last count %d. time: %s" ,
153
- src , count , time .Duration (q .Ruleset .Window )* time .Second )
186
+ q .strictLogic (from .UniqueKey (), aval .(* atomic.Int64 ))
187
+ return true
154
188
}
189
+ q .debugPrintf (msgRateLimited , from .UniqueKey (), count , time .Duration (q .Ruleset .Window )* time .Second )
155
190
return true
156
191
}
157
192
158
193
// Peek checks an Identities UniqueKey() output against a list of cached strings to determine ratelimitting status without adding to its request count.
159
194
func (q * Limiter ) Peek (from Identity ) bool {
160
- q .Patrons .DeleteExpired ()
161
195
if ct , ok := q .Patrons .Get (from .UniqueKey ()); ok {
162
- count := ct .(int64 )
196
+ count := ct .(* atomic. Int64 ). Load ( )
163
197
if count > q .Ruleset .Burst {
164
198
return true
165
199
}
0 commit comments