package circuitbreaker import ( "fmt" "sync" "time" ) // Interface contains behavior that needs to be implemented // for an object to be wrapped in a circuit-breaker type Interface interface { OnFailure() OnCircuitBreak() } // Counts holds the numbers of requests and their successes/failures. // CircuitBreaker clears the internal Counts either // on the change of the state or at the closed-state intervals. // Counts ignores the results of the requests sent before clearing. type Counter struct { Requests uint32 TotalSuccesses uint32 TotalFailures uint32 ConsecutiveSuccesses uint32 ConsecutiveFailures uint32 } func (c *Counter) Request() { c.Requests++ } func (c *Counter) Success() { c.TotalSuccesses++ c.ConsecutiveFailures = 0 c.ConsecutiveSuccesses++ } func (c *Counter) Failure() { c.TotalFailures++ c.ConsecutiveSuccesses = 0 c.ConsecutiveFailures++ } func (c *Counter) Clear() { c.Requests = 0 c.TotalSuccesses = 0 c.TotalFailures = 0 c.ConsecutiveSuccesses = 0 c.ConsecutiveFailures = 0 } // Settings configures CircuitBreaker: // // Name is the name of the CircuitBreaker. // // MaxRequests is the maximum number of requests allowed to pass through // when the CircuitBreaker is half-open. // If MaxRequests is 0, the CircuitBreaker allows only 1 request. // // Interval is the cyclic period of the closed state // for the CircuitBreaker to clear the internal Counts. // If Interval is 0, the CircuitBreaker doesn't clear internal Counts during the closed state. // // Timeout is the period of the open state, // after which the state of the CircuitBreaker becomes half-open. // If Timeout is 0, the timeout value of the CircuitBreaker is set to 60 seconds. // // ReadyToTrip is called with a copy of Counts whenever a request fails in the closed state. // If ReadyToTrip returns true, the CircuitBreaker will be placed into the open state. // If ReadyToTrip is nil, default ReadyToTrip is used. // Default ReadyToTrip returns true when the number of consecutive failures is more than 5. // // OnStateChange is called whenever the state of the CircuitBreaker changes. type Settings struct { Name string MaxRequests uint32 Interval time.Duration Timeout time.Duration ReadyToTrip func(counts Counts) bool OnStateChange func(name string, from State, to State) } // CircuitBreaker is a state machine to prevent sending requests that are likely to fail. type CircuitBreaker struct { name string maxRequests uint32 interval time.Duration timeout time.Duration readyToTrip func(counts Counts) bool onStateChange func(name string, from State, to State) mutex sync.Mutex state State generation uint64 counts Counts expiry time.Time } // NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. func NewCircuitBreaker(st Settings) *CircuitBreaker { cb := new(CircuitBreaker) cb.name = st.Name cb.interval = st.Interval cb.onStateChange = st.OnStateChange if st.MaxRequests == 0 { cb.maxRequests = 1 } else { cb.maxRequests = st.MaxRequests } if st.Timeout == 0 { cb.timeout = defaultTimeout } else { cb.timeout = st.Timeout } if st.ReadyToTrip == nil { cb.readyToTrip = defaultReadyToTrip } else { cb.readyToTrip = st.ReadyToTrip } cb.toNewGeneration(time.Now()) return cb } const defaultTimeout = time.Duration(60) * time.Second func defaultReadyToTrip(counts Counts) bool { return counts.ConsecutiveFailures > 5 } // State returns the current state of the CircuitBreaker. func (cb *CircuitBreaker) State() State { cb.mutex.Lock() defer cb.mutex.Unlock() now := time.Now() state, _ := cb.currentState(now) return state } // Execute runs the given request if the CircuitBreaker accepts it. // Execute returns an error instantly if the CircuitBreaker rejects the request. // Otherwise, Execute returns the result of the request. // If a panic occurs in the request, the CircuitBreaker handles it as an error // and causes the same panic again. func (cb *CircuitBreaker) Execute(req func() (interface{}, error)) (interface{}, error) { generation, err := cb.beforeRequest() if err != nil { return nil, err } defer func() { e := recover() if e != nil { cb.afterRequest(generation, fmt.Errorf("panic in request")) panic(e) } }() result, err := req() cb.afterRequest(generation, err) return result, err } func (cb *CircuitBreaker) beforeRequest() (uint64, error) { cb.mutex.Lock() defer cb.mutex.Unlock() now := time.Now() state, generation := cb.currentState(now) if state == StateOpen { return generation, cb.errorStateOpen() } else if state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests { return generation, fmt.Errorf("too many requests") } cb.counts.onRequest() return generation, nil } func (cb *CircuitBreaker) afterRequest(before uint64, err error) { cb.mutex.Lock() defer cb.mutex.Unlock() now := time.Now() state, generation := cb.currentState(now) if generation != before { return } if err == nil { cb.onSuccess(state, now) } else { cb.onFailure(state, now) } } func (cb *CircuitBreaker) onSuccess(state State, now time.Time) { switch state { case StateClosed: cb.counts.onSuccess() case StateHalfOpen: cb.counts.onSuccess() if cb.counts.ConsecutiveSuccesses >= cb.maxRequests { cb.setState(StateClosed, now) } } } func (cb *CircuitBreaker) onFailure(state State, now time.Time) { switch state { case StateClosed: cb.counts.onFailure() if cb.readyToTrip(cb.counts) { cb.setState(StateOpen, now) } case StateHalfOpen: cb.setState(StateOpen, now) } } func (cb *CircuitBreaker) currentState(now time.Time) (State, uint64) { switch cb.state { case StateClosed: if !cb.expiry.IsZero() && cb.expiry.Before(now) { cb.toNewGeneration(now) } case StateOpen: if cb.expiry.Before(now) { cb.setState(StateHalfOpen, now) } } return cb.state, cb.generation } func (cb *CircuitBreaker) setState(state State, now time.Time) { if cb.state == state { return } prev := cb.state cb.state = state cb.toNewGeneration(now) if cb.onStateChange != nil { cb.onStateChange(cb.name, prev, state) } } func (cb *CircuitBreaker) toNewGeneration(now time.Time) { cb.generation++ cb.counts.clear() var zero time.Time switch cb.state { case StateClosed: if cb.interval == 0 { cb.expiry = zero } else { cb.expiry = now.Add(cb.interval) } case StateOpen: cb.expiry = now.Add(cb.timeout) default: // StateHalfOpen cb.expiry = zero } } func (cb *CircuitBreaker) errorStateOpen() error { if cb.name == "" { return fmt.Errorf("circuit breaker is open") } return fmt.Errorf("circuit breaker '%s' is open", cb.name) }