diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go new file mode 100644 index 0000000..99c6732 --- /dev/null +++ b/circuitbreaker/circuitbreaker.go @@ -0,0 +1,284 @@ +package circuitbreaker + +import ( + "fmt" + "sync" + "time" +) + +// Interface contains behavior that needs to be implemented +// for an object to be wrapped in a circuit-breaker +type Interface interface { + OnFailure() + OnCircuitBreak() +} + +// Counts holds the numbers of requests and their successes/failures. +// CircuitBreaker clears the internal Counts either +// on the change of the state or at the closed-state intervals. +// Counts ignores the results of the requests sent before clearing. +type Counter struct { + Requests uint32 + TotalSuccesses uint32 + TotalFailures uint32 + ConsecutiveSuccesses uint32 + ConsecutiveFailures uint32 +} + +func (c *Counter) Request() { + c.Requests++ +} + +func (c *Counter) Success() { + c.TotalSuccesses++ + + c.ConsecutiveFailures = 0 + c.ConsecutiveSuccesses++ +} + +func (c *Counter) Failure() { + c.TotalFailures++ + + c.ConsecutiveSuccesses = 0 + c.ConsecutiveFailures++ +} + +func (c *Counter) Clear() { + c.Requests = 0 + c.TotalSuccesses = 0 + c.TotalFailures = 0 + c.ConsecutiveSuccesses = 0 + c.ConsecutiveFailures = 0 +} + +// Settings configures CircuitBreaker: +// +// Name is the name of the CircuitBreaker. +// +// MaxRequests is the maximum number of requests allowed to pass through +// when the CircuitBreaker is half-open. +// If MaxRequests is 0, the CircuitBreaker allows only 1 request. +// +// Interval is the cyclic period of the closed state +// for the CircuitBreaker to clear the internal Counts. +// If Interval is 0, the CircuitBreaker doesn't clear internal Counts during the closed state. +// +// Timeout is the period of the open state, +// after which the state of the CircuitBreaker becomes half-open. +// If Timeout is 0, the timeout value of the CircuitBreaker is set to 60 seconds. +// +// ReadyToTrip is called with a copy of Counts whenever a request fails in the closed state. +// If ReadyToTrip returns true, the CircuitBreaker will be placed into the open state. +// If ReadyToTrip is nil, default ReadyToTrip is used. +// Default ReadyToTrip returns true when the number of consecutive failures is more than 5. +// +// OnStateChange is called whenever the state of the CircuitBreaker changes. +type Settings struct { + Name string + MaxRequests uint32 + Interval time.Duration + Timeout time.Duration + ReadyToTrip func(counts Counts) bool + OnStateChange func(name string, from State, to State) +} + +// CircuitBreaker is a state machine to prevent sending requests that are likely to fail. +type CircuitBreaker struct { + name string + maxRequests uint32 + interval time.Duration + timeout time.Duration + readyToTrip func(counts Counts) bool + onStateChange func(name string, from State, to State) + + mutex sync.Mutex + state State + generation uint64 + counts Counts + expiry time.Time +} + +// NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. +func NewCircuitBreaker(st Settings) *CircuitBreaker { + cb := new(CircuitBreaker) + + cb.name = st.Name + cb.interval = st.Interval + cb.onStateChange = st.OnStateChange + + if st.MaxRequests == 0 { + cb.maxRequests = 1 + } else { + cb.maxRequests = st.MaxRequests + } + + if st.Timeout == 0 { + cb.timeout = defaultTimeout + } else { + cb.timeout = st.Timeout + } + + if st.ReadyToTrip == nil { + cb.readyToTrip = defaultReadyToTrip + } else { + cb.readyToTrip = st.ReadyToTrip + } + + cb.toNewGeneration(time.Now()) + + return cb +} + +const defaultTimeout = time.Duration(60) * time.Second + +func defaultReadyToTrip(counts Counts) bool { + return counts.ConsecutiveFailures > 5 +} + +// State returns the current state of the CircuitBreaker. +func (cb *CircuitBreaker) State() State { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, _ := cb.currentState(now) + return state +} + +// Execute runs the given request if the CircuitBreaker accepts it. +// Execute returns an error instantly if the CircuitBreaker rejects the request. +// Otherwise, Execute returns the result of the request. +// If a panic occurs in the request, the CircuitBreaker handles it as an error +// and causes the same panic again. +func (cb *CircuitBreaker) Execute(req func() (interface{}, error)) (interface{}, error) { + generation, err := cb.beforeRequest() + if err != nil { + return nil, err + } + + defer func() { + e := recover() + if e != nil { + cb.afterRequest(generation, fmt.Errorf("panic in request")) + panic(e) + } + }() + + result, err := req() + cb.afterRequest(generation, err) + return result, err +} + +func (cb *CircuitBreaker) beforeRequest() (uint64, error) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, generation := cb.currentState(now) + + if state == StateOpen { + return generation, cb.errorStateOpen() + } else if state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests { + return generation, fmt.Errorf("too many requests") + } + + cb.counts.onRequest() + return generation, nil +} + +func (cb *CircuitBreaker) afterRequest(before uint64, err error) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, generation := cb.currentState(now) + if generation != before { + return + } + + if err == nil { + cb.onSuccess(state, now) + } else { + cb.onFailure(state, now) + } +} + +func (cb *CircuitBreaker) onSuccess(state State, now time.Time) { + switch state { + case StateClosed: + cb.counts.onSuccess() + case StateHalfOpen: + cb.counts.onSuccess() + if cb.counts.ConsecutiveSuccesses >= cb.maxRequests { + cb.setState(StateClosed, now) + } + } +} + +func (cb *CircuitBreaker) onFailure(state State, now time.Time) { + switch state { + case StateClosed: + cb.counts.onFailure() + if cb.readyToTrip(cb.counts) { + cb.setState(StateOpen, now) + } + case StateHalfOpen: + cb.setState(StateOpen, now) + } +} + +func (cb *CircuitBreaker) currentState(now time.Time) (State, uint64) { + switch cb.state { + case StateClosed: + if !cb.expiry.IsZero() && cb.expiry.Before(now) { + cb.toNewGeneration(now) + } + case StateOpen: + if cb.expiry.Before(now) { + cb.setState(StateHalfOpen, now) + } + } + return cb.state, cb.generation +} + +func (cb *CircuitBreaker) setState(state State, now time.Time) { + if cb.state == state { + return + } + + prev := cb.state + cb.state = state + + cb.toNewGeneration(now) + + if cb.onStateChange != nil { + cb.onStateChange(cb.name, prev, state) + } +} + +func (cb *CircuitBreaker) toNewGeneration(now time.Time) { + cb.generation++ + cb.counts.clear() + + var zero time.Time + switch cb.state { + case StateClosed: + if cb.interval == 0 { + cb.expiry = zero + } else { + cb.expiry = now.Add(cb.interval) + } + case StateOpen: + cb.expiry = now.Add(cb.timeout) + default: // StateHalfOpen + cb.expiry = zero + } +} + +func (cb *CircuitBreaker) errorStateOpen() error { + if cb.name == "" { + return fmt.Errorf("circuit breaker is open") + } + + return fmt.Errorf("circuit breaker '%s' is open", cb.name) +}