From 7ea178bbc3d0839a9b5d5cceac7e38d2085817ff Mon Sep 17 00:00:00 2001 From: goodliu Date: Sun, 11 May 2025 15:15:14 +0800 Subject: [PATCH 1/2] feat(circuit breaker): add sliding window for request result tracking --- circuit_breaker.go | 124 +++++++++++++++++++++++++++++++++++++++------ client_test.go | 4 +- 2 files changed, 111 insertions(+), 17 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index 8b25251b..456b702e 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -8,6 +8,7 @@ package resty import ( "errors" "net/http" + "sync" "sync/atomic" "time" ) @@ -29,15 +30,15 @@ type CircuitBreaker struct { failureThreshold uint32 successThreshold uint32 state atomic.Value // circuitBreakerState - failureCount atomic.Uint32 - successCount atomic.Uint32 - lastFailureAt time.Time + openStartAt atomic.Value // time.Time + sw *tfsw } // NewCircuitBreaker method creates a new [CircuitBreaker] with default settings. // // The default settings are: // - Timeout: 10 seconds +// - SlidingWindowBucketSize: 10 // - FailThreshold: 3 // - SuccessThreshold: 1 // - Policies: CircuitBreaker5xxPolicy @@ -48,6 +49,11 @@ func NewCircuitBreaker() *CircuitBreaker { failureThreshold: 3, successThreshold: 1, } + cb.sw = newSlidingWindow( + func() totalAndFailures { return totalAndFailures{} }, + cb.timeout, + 10, + ) cb.state.Store(circuitBreakerStateClosed) return cb } @@ -75,6 +81,7 @@ func (cb *CircuitBreaker) SetPolicies(policies ...CircuitBreakerPolicy) *Circuit // timeout reaches, a single request is allowed to determine the state. func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) *CircuitBreaker { cb.timeout = timeout + cb.sw.SetInterval(timeout) return cb } @@ -142,30 +149,34 @@ func (cb *CircuitBreaker) applyPolicies(resp *http.Response) { } if failed { - if cb.failureCount.Load() > 0 && time.Since(cb.lastFailureAt) > cb.timeout { - cb.failureCount.Store(0) - } - + cb.sw.Add(totalAndFailures{total: 1, failures: 1}) switch cb.getState() { case circuitBreakerStateClosed: - failCount := cb.failureCount.Add(1) - if failCount >= cb.failureThreshold { + if cb.sw.Get().failures >= int(cb.failureThreshold) { cb.open() - } else { - cb.lastFailureAt = time.Now() } case circuitBreakerStateHalfOpen: cb.open() + case circuitBreakerStateOpen: + if time.Since(cb.openStartAt.Load().(time.Time)) >= cb.timeout { + cb.changeState(circuitBreakerStateHalfOpen) + } } + } else { + cb.sw.Add(totalAndFailures{total: 1, failures: 0}) switch cb.getState() { case circuitBreakerStateClosed: return case circuitBreakerStateHalfOpen: - successCount := cb.successCount.Add(1) - if successCount >= cb.successThreshold { + totalAndFailure := cb.sw.Get() + if totalAndFailure.total-totalAndFailure.failures >= int(cb.successThreshold) { cb.changeState(circuitBreakerStateClosed) } + case circuitBreakerStateOpen: + if time.Since(cb.openStartAt.Load().(time.Time)) >= cb.timeout { + cb.changeState(circuitBreakerStateHalfOpen) + } } } } @@ -179,7 +190,90 @@ func (cb *CircuitBreaker) open() { } func (cb *CircuitBreaker) changeState(state circuitBreakerState) { - cb.failureCount.Store(0) - cb.successCount.Store(0) cb.state.Store(state) + cb.openStartAt.Store(time.Now()) +} + +type tfsw = slidingWindow[totalAndFailures] + +func newSlidingWindow[G group[G]]( + newEmpty func() G, + interval time.Duration, + bucketSize int, +) *slidingWindow[G] { + values := make([]G, 0, bucketSize) + for i := 0; i < bucketSize; i++ { + values = append(values, newEmpty()) + } + return &slidingWindow[G]{ + total: newEmpty(), + values: values, + lastStart: time.Now(), + interval: interval / time.Duration(bucketSize), + } +} + +type slidingWindow[G group[G]] struct { + mutex sync.RWMutex + total G + values []G + + idx int + lastStart time.Time + interval time.Duration +} + +// group is a mathematical concept. The values in the sliding window adhere to group properties. +type group[T any] interface { + op(T) T + empty() T + inverse() T +} + +func (sw *slidingWindow[G]) Add(val G) { + sw.mutex.Lock() + defer sw.mutex.Unlock() + for elapsed := time.Since(sw.lastStart); elapsed > sw.interval; elapsed -= sw.interval { + sw.idx++ + if sw.idx >= len(sw.values) { + sw.idx = 0 + } + sw.lastStart = sw.lastStart.Add(sw.interval) + sw.total = sw.total.op(sw.values[sw.idx].inverse()) + sw.values[sw.idx] = sw.values[sw.idx].empty() + } + sw.total = sw.total.op(val) + sw.values[sw.idx] = sw.values[sw.idx].op(val) +} + +func (sw *slidingWindow[G]) Get() G { + sw.mutex.RLock() + defer sw.mutex.RUnlock() + return sw.total +} +func (sw *slidingWindow[G]) SetInterval(interval time.Duration) { + sw.mutex.Lock() + defer sw.mutex.Unlock() + sw.interval = interval / time.Duration(len(sw.values)) +} + +type totalAndFailures struct { + total int + failures int +} + +func (tf totalAndFailures) op(g totalAndFailures) totalAndFailures { + tf.total += g.total + tf.failures += g.failures + return tf +} + +func (tf totalAndFailures) empty() totalAndFailures { + return totalAndFailures{} +} + +func (tf totalAndFailures) inverse() totalAndFailures { + tf.total = -tf.total + tf.failures = -tf.failures + return tf } diff --git a/client_test.go b/client_test.go index 8a5eef50..6c064ecc 100644 --- a/client_test.go +++ b/client_test.go @@ -1507,11 +1507,11 @@ func TestClientCircuitBreaker(t *testing.T) { _, err = c.R().Get(ts.URL + "/500") assertError(t, err) - assertEqual(t, uint32(1), c.circuitBreaker.failureCount.Load()) + assertEqual(t, 1, c.circuitBreaker.sw.Get().failures) time.Sleep(timeout) _, err = c.R().Get(ts.URL + "/500") assertError(t, err) - assertEqual(t, uint32(1), c.circuitBreaker.failureCount.Load()) + assertEqual(t, 1, c.circuitBreaker.sw.Get().failures) } From f383d5dbfc32448ae171da02d99ae47be86cc765 Mon Sep 17 00:00:00 2001 From: goodliu Date: Sun, 11 May 2025 15:40:48 +0800 Subject: [PATCH 2/2] refactor: streamline state transition logic in CircuitBreaker --- circuit_breaker.go | 41 ++++++++++++++++------------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index 456b702e..8592dba7 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -150,33 +150,24 @@ func (cb *CircuitBreaker) applyPolicies(resp *http.Response) { if failed { cb.sw.Add(totalAndFailures{total: 1, failures: 1}) - switch cb.getState() { - case circuitBreakerStateClosed: - if cb.sw.Get().failures >= int(cb.failureThreshold) { - cb.open() - } - case circuitBreakerStateHalfOpen: - cb.open() - case circuitBreakerStateOpen: - if time.Since(cb.openStartAt.Load().(time.Time)) >= cb.timeout { - cb.changeState(circuitBreakerStateHalfOpen) - } - } - } else { cb.sw.Add(totalAndFailures{total: 1, failures: 0}) - switch cb.getState() { - case circuitBreakerStateClosed: - return - case circuitBreakerStateHalfOpen: - totalAndFailure := cb.sw.Get() - if totalAndFailure.total-totalAndFailure.failures >= int(cb.successThreshold) { - cb.changeState(circuitBreakerStateClosed) - } - case circuitBreakerStateOpen: - if time.Since(cb.openStartAt.Load().(time.Time)) >= cb.timeout { - cb.changeState(circuitBreakerStateHalfOpen) - } + } + switch cb.getState() { + case circuitBreakerStateClosed: + if cb.sw.Get().failures >= int(cb.failureThreshold) { + cb.open() + } + case circuitBreakerStateHalfOpen: + totalAndFailure := cb.sw.Get() + if totalAndFailure.total-totalAndFailure.failures >= int(cb.successThreshold) { + cb.changeState(circuitBreakerStateClosed) + } else { + cb.open() + } + case circuitBreakerStateOpen: + if time.Since(cb.openStartAt.Load().(time.Time)) >= cb.timeout { + cb.changeState(circuitBreakerStateHalfOpen) } } }