diff --git a/gomore/06_circuit_breaker/gobreaker/gobreaker.go b/gomore/06_circuit_breaker/gobreaker/gobreaker.go new file mode 100644 index 0000000..f402217 --- /dev/null +++ b/gomore/06_circuit_breaker/gobreaker/gobreaker.go @@ -0,0 +1,350 @@ +// Package gobreaker implements the Circuit Breaker pattern. +// See https://msdn.microsoft.com/en-us/library/dn589784.aspx. +package gobreaker + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// State is a type that represents a state of CircuitBreaker. +type State int + +// These constants are states of CircuitBreaker. +const ( + StateClosed State = iota + StateHalfOpen + StateOpen +) + +var ( + // ErrTooManyRequests is returned when the CB state is half open and the requests count is over the cb maxRequests + ErrTooManyRequests = errors.New("too many requests") + // ErrOpenState is returned when the CB state is open + ErrOpenState = errors.New("circuit breaker is open") +) + +// String implements stringer interface. +func (s State) String() string { + switch s { + case StateClosed: + return "closed" + case StateHalfOpen: + return "half-open" + case StateOpen: + return "open" + default: + return fmt.Sprintf("unknown state: %d", s) + } +} + +// Counts holds the numbers of requests and their successes/failures. +// CircuitBreaker clears the internal Counts either +// on the change of the state or at the closed-state intervals. +// Counts ignores the results of the requests sent before clearing. +type Counts struct { + Requests uint32 + TotalSuccesses uint32 + TotalFailures uint32 + ConsecutiveSuccesses uint32 + ConsecutiveFailures uint32 +} + +func (c *Counts) onRequest() { + c.Requests++ +} + +func (c *Counts) onSuccess() { + c.TotalSuccesses++ + c.ConsecutiveSuccesses++ + c.ConsecutiveFailures = 0 +} + +func (c *Counts) onFailure() { + c.TotalFailures++ + c.ConsecutiveFailures++ + c.ConsecutiveSuccesses = 0 +} + +func (c *Counts) clear() { + c.Requests = 0 + c.TotalSuccesses = 0 + c.TotalFailures = 0 + c.ConsecutiveSuccesses = 0 + c.ConsecutiveFailures = 0 +} + +// Settings configures CircuitBreaker: +// +// Name is the name of the CircuitBreaker. +// +// MaxRequests is the maximum number of requests allowed to pass through +// when the CircuitBreaker is half-open. +// If MaxRequests is 0, the CircuitBreaker allows only 1 request. +// +// Interval is the cyclic period of the closed state +// for the CircuitBreaker to clear the internal Counts. +// If Interval is less than or equal to 0, the CircuitBreaker doesn't clear internal Counts during the closed state. +// +// Timeout is the period of the open state, +// after which the state of the CircuitBreaker becomes half-open. +// If Timeout is less than or equal to 0, the timeout value of the CircuitBreaker is set to 60 seconds. +// +// ReadyToTrip is called with a copy of Counts whenever a request fails in the closed state. +// If ReadyToTrip returns true, the CircuitBreaker will be placed into the open state. +// If ReadyToTrip is nil, default ReadyToTrip is used. +// Default ReadyToTrip returns true when the number of consecutive failures is more than 5. +// +// OnStateChange is called whenever the state of the CircuitBreaker changes. +type Settings struct { + Name string + MaxRequests uint32 + Interval time.Duration + Timeout time.Duration + ReadyToTrip func(counts Counts) bool + OnStateChange func(name string, from State, to State) +} + +// CircuitBreaker is a state machine to prevent sending requests that are likely to fail. +type CircuitBreaker struct { + name string + maxRequests uint32 + interval time.Duration + timeout time.Duration + readyToTrip func(counts Counts) bool + onStateChange func(name string, from State, to State) + + mutex sync.Mutex + state State + generation uint64 + counts Counts + expiry time.Time +} + +// TwoStepCircuitBreaker is like CircuitBreaker but instead of surrounding a function +// with the breaker functionality, it only checks whether a request can proceed and +// expects the caller to report the outcome in a separate step using a callback. +type TwoStepCircuitBreaker struct { + cb *CircuitBreaker +} + +// NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. +func NewCircuitBreaker(st Settings) *CircuitBreaker { + cb := new(CircuitBreaker) + + cb.name = st.Name + cb.onStateChange = st.OnStateChange + + if st.MaxRequests == 0 { + cb.maxRequests = 1 + } else { + cb.maxRequests = st.MaxRequests + } + + if st.Interval <= 0 { + cb.interval = defaultInterval + } else { + cb.interval = st.Interval + } + + if st.Timeout <= 0 { + cb.timeout = defaultTimeout + } else { + cb.timeout = st.Timeout + } + + if st.ReadyToTrip == nil { + cb.readyToTrip = defaultReadyToTrip + } else { + cb.readyToTrip = st.ReadyToTrip + } + + cb.toNewGeneration(time.Now()) + + return cb +} + +// NewTwoStepCircuitBreaker returns a new TwoStepCircuitBreaker configured with the given Settings. +func NewTwoStepCircuitBreaker(st Settings) *TwoStepCircuitBreaker { + return &TwoStepCircuitBreaker{ + cb: NewCircuitBreaker(st), + } +} + +const defaultInterval = time.Duration(0) * time.Second +const defaultTimeout = time.Duration(60) * time.Second + +func defaultReadyToTrip(counts Counts) bool { + return counts.ConsecutiveFailures > 5 +} + +// Name returns the name of the CircuitBreaker. +func (cb *CircuitBreaker) Name() string { + return cb.name +} + +// State returns the current state of the CircuitBreaker. +func (cb *CircuitBreaker) State() State { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, _ := cb.currentState(now) + return state +} + +// Execute runs the given request if the CircuitBreaker accepts it. +// Execute returns an error instantly if the CircuitBreaker rejects the request. +// Otherwise, Execute returns the result of the request. +// If a panic occurs in the request, the CircuitBreaker handles it as an error +// and causes the same panic again. +func (cb *CircuitBreaker) Execute(req func() (interface{}, error)) (interface{}, error) { + generation, err := cb.beforeRequest() + if err != nil { + return nil, err + } + + defer func() { + e := recover() + if e != nil { + cb.afterRequest(generation, false) + panic(e) + } + }() + + result, err := req() + cb.afterRequest(generation, err == nil) + return result, err +} + +// Name returns the name of the TwoStepCircuitBreaker. +func (tscb *TwoStepCircuitBreaker) Name() string { + return tscb.cb.Name() +} + +// State returns the current state of the TwoStepCircuitBreaker. +func (tscb *TwoStepCircuitBreaker) State() State { + return tscb.cb.State() +} + +// Allow checks if a new request can proceed. It returns a callback that should be used to +// register the success or failure in a separate step. If the circuit breaker doesn't allow +// requests, it returns an error. +func (tscb *TwoStepCircuitBreaker) Allow() (done func(success bool), err error) { + generation, err := tscb.cb.beforeRequest() + if err != nil { + return nil, err + } + + return func(success bool) { + tscb.cb.afterRequest(generation, success) + }, nil +} + +func (cb *CircuitBreaker) beforeRequest() (uint64, error) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, generation := cb.currentState(now) + + if state == StateOpen { + return generation, ErrOpenState + } else if state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests { + return generation, ErrTooManyRequests + } + + cb.counts.onRequest() + return generation, nil +} + +func (cb *CircuitBreaker) afterRequest(before uint64, success bool) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + now := time.Now() + state, generation := cb.currentState(now) + if generation != before { + return + } + + if success { + cb.onSuccess(state, now) + } else { + cb.onFailure(state, now) + } +} + +func (cb *CircuitBreaker) onSuccess(state State, now time.Time) { + switch state { + case StateClosed: + cb.counts.onSuccess() + case StateHalfOpen: + cb.counts.onSuccess() + if cb.counts.ConsecutiveSuccesses >= cb.maxRequests { + cb.setState(StateClosed, now) + } + } +} + +func (cb *CircuitBreaker) onFailure(state State, now time.Time) { + switch state { + case StateClosed: + cb.counts.onFailure() + if cb.readyToTrip(cb.counts) { + cb.setState(StateOpen, now) + } + case StateHalfOpen: + cb.setState(StateOpen, now) + } +} + +func (cb *CircuitBreaker) currentState(now time.Time) (State, uint64) { + switch cb.state { + case StateClosed: + if !cb.expiry.IsZero() && cb.expiry.Before(now) { + cb.toNewGeneration(now) + } + case StateOpen: + if cb.expiry.Before(now) { + cb.setState(StateHalfOpen, now) + } + } + return cb.state, cb.generation +} + +func (cb *CircuitBreaker) setState(state State, now time.Time) { + if cb.state == state { + return + } + + prev := cb.state + cb.state = state + + cb.toNewGeneration(now) + + if cb.onStateChange != nil { + cb.onStateChange(cb.name, prev, state) + } +} + +func (cb *CircuitBreaker) toNewGeneration(now time.Time) { + cb.generation++ + cb.counts.clear() + + var zero time.Time + switch cb.state { + case StateClosed: + if cb.interval == 0 { + cb.expiry = zero + } else { + cb.expiry = now.Add(cb.interval) + } + case StateOpen: + cb.expiry = now.Add(cb.timeout) + default: // StateHalfOpen + cb.expiry = zero + } +} diff --git a/gomore/06_circuit_breaker/gobreaker/gobreaker_example_test.go b/gomore/06_circuit_breaker/gobreaker/gobreaker_example_test.go new file mode 100644 index 0000000..e85ac49 --- /dev/null +++ b/gomore/06_circuit_breaker/gobreaker/gobreaker_example_test.go @@ -0,0 +1,57 @@ +package gobreaker + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + +) + +var cb *gobreaker.CircuitBreaker + + +func TestGoBreaker(t *testing.T) { + body, err := Get("http://www.google.com/robots.txt") + if err != nil { + t.Fatal(err) + } + + fmt.Println(string(body)) +} + + +func initBreaker() { + var st gobreaker.Settings + st.Name = "HTTP GET" + st.ReadyToTrip = func(counts gobreaker.Counts) bool { + failureRatio := float64(counts.TotalFailures) / float64(counts.Requests) + return counts.Requests >= 3 && failureRatio >= 0.6 + } + + cb = gobreaker.NewCircuitBreaker(st) +} + +// Get wraps http.Get in CircuitBreaker. +func Get(url string) ([]byte, error) { + body, err := cb.Execute(func() (interface{}, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return body, nil + }) + if err != nil { + return nil, err + } + + return body.([]byte), nil +} + diff --git a/gomore/06_circuit_breaker/gobreaker/gobreaker_test.go b/gomore/06_circuit_breaker/gobreaker/gobreaker_test.go new file mode 100644 index 0000000..eb8ecc5 --- /dev/null +++ b/gomore/06_circuit_breaker/gobreaker/gobreaker_test.go @@ -0,0 +1,370 @@ +package gobreaker + +import ( + "fmt" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var defaultCB *CircuitBreaker +var customCB *CircuitBreaker +var negativeDurationCB *CircuitBreaker + +type StateChange struct { + name string + from State + to State +} + +var stateChange StateChange + +func pseudoSleep(cb *CircuitBreaker, period time.Duration) { + if !cb.expiry.IsZero() { + cb.expiry = cb.expiry.Add(-period) + } +} + +func succeed(cb *CircuitBreaker) error { + _, err := cb.Execute(func() (interface{}, error) { return nil, nil }) + return err +} + +func succeedLater(cb *CircuitBreaker, delay time.Duration) <-chan error { + ch := make(chan error) + go func() { + _, err := cb.Execute(func() (interface{}, error) { + time.Sleep(delay) + return nil, nil + }) + ch <- err + }() + return ch +} + +func succeed2Step(cb *TwoStepCircuitBreaker) error { + done, err := cb.Allow() + if err != nil { + return err + } + + done(true) + return nil +} + +func fail(cb *CircuitBreaker) error { + msg := "fail" + _, err := cb.Execute(func() (interface{}, error) { return nil, fmt.Errorf(msg) }) + if err.Error() == msg { + return nil + } + return err +} + +func fail2Step(cb *TwoStepCircuitBreaker) error { + done, err := cb.Allow() + if err != nil { + return err + } + + done(false) + return nil +} + +func causePanic(cb *CircuitBreaker) error { + _, err := cb.Execute(func() (interface{}, error) { panic("oops"); return nil, nil }) + return err +} + +func newCustom() *CircuitBreaker { + var customSt Settings + customSt.Name = "cb" + customSt.MaxRequests = 3 + customSt.Interval = time.Duration(30) * time.Second + customSt.Timeout = time.Duration(90) * time.Second + customSt.ReadyToTrip = func(counts Counts) bool { + numReqs := counts.Requests + failureRatio := float64(counts.TotalFailures) / float64(numReqs) + + counts.clear() // no effect on customCB.counts + + return numReqs >= 3 && failureRatio >= 0.6 + } + customSt.OnStateChange = func(name string, from State, to State) { + stateChange = StateChange{name, from, to} + } + + return NewCircuitBreaker(customSt) +} + +func newNegativeDurationCB() *CircuitBreaker { + var negativeSt Settings + negativeSt.Name = "ncb" + negativeSt.Interval = time.Duration(-30) * time.Second + negativeSt.Timeout = time.Duration(-90) * time.Second + + return NewCircuitBreaker(negativeSt) +} + +func init() { + defaultCB = NewCircuitBreaker(Settings{}) + customCB = newCustom() + negativeDurationCB = newNegativeDurationCB() +} + +func TestStateConstants(t *testing.T) { + assert.Equal(t, State(0), StateClosed) + assert.Equal(t, State(1), StateHalfOpen) + assert.Equal(t, State(2), StateOpen) + + assert.Equal(t, StateClosed.String(), "closed") + assert.Equal(t, StateHalfOpen.String(), "half-open") + assert.Equal(t, StateOpen.String(), "open") + assert.Equal(t, State(100).String(), "unknown state: 100") +} + +func TestNewCircuitBreaker(t *testing.T) { + defaultCB := NewCircuitBreaker(Settings{}) + assert.Equal(t, "", defaultCB.name) + assert.Equal(t, uint32(1), defaultCB.maxRequests) + assert.Equal(t, time.Duration(0), defaultCB.interval) + assert.Equal(t, time.Duration(60)*time.Second, defaultCB.timeout) + assert.NotNil(t, defaultCB.readyToTrip) + assert.Nil(t, defaultCB.onStateChange) + assert.Equal(t, StateClosed, defaultCB.state) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.True(t, defaultCB.expiry.IsZero()) + + customCB := newCustom() + assert.Equal(t, "cb", customCB.name) + assert.Equal(t, uint32(3), customCB.maxRequests) + assert.Equal(t, time.Duration(30)*time.Second, customCB.interval) + assert.Equal(t, time.Duration(90)*time.Second, customCB.timeout) + assert.NotNil(t, customCB.readyToTrip) + assert.NotNil(t, customCB.onStateChange) + assert.Equal(t, StateClosed, customCB.state) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + assert.False(t, customCB.expiry.IsZero()) + + negativeDurationCB := newNegativeDurationCB() + assert.Equal(t, "ncb", negativeDurationCB.name) + assert.Equal(t, uint32(1), negativeDurationCB.maxRequests) + assert.Equal(t, time.Duration(0)*time.Second, negativeDurationCB.interval) + assert.Equal(t, time.Duration(60)*time.Second, negativeDurationCB.timeout) + assert.NotNil(t, negativeDurationCB.readyToTrip) + assert.Nil(t, negativeDurationCB.onStateChange) + assert.Equal(t, StateClosed, negativeDurationCB.state) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, negativeDurationCB.counts) + assert.True(t, negativeDurationCB.expiry.IsZero()) +} + +func TestDefaultCircuitBreaker(t *testing.T) { + assert.Equal(t, "", defaultCB.Name()) + + for i := 0; i < 5; i++ { + assert.Nil(t, fail(defaultCB)) + } + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{5, 0, 5, 0, 5}, defaultCB.counts) + + assert.Nil(t, succeed(defaultCB)) + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{6, 1, 5, 1, 0}, defaultCB.counts) + + assert.Nil(t, fail(defaultCB)) + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{7, 1, 6, 0, 1}, defaultCB.counts) + + // StateClosed to StateOpen + for i := 0; i < 5; i++ { + assert.Nil(t, fail(defaultCB)) // 6 consecutive failures + } + assert.Equal(t, StateOpen, defaultCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.False(t, defaultCB.expiry.IsZero()) + + assert.Error(t, succeed(defaultCB)) + assert.Error(t, fail(defaultCB)) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + + pseudoSleep(defaultCB, time.Duration(59)*time.Second) + assert.Equal(t, StateOpen, defaultCB.State()) + + // StateOpen to StateHalfOpen + pseudoSleep(defaultCB, time.Duration(1)*time.Second) // over Timeout + assert.Equal(t, StateHalfOpen, defaultCB.State()) + assert.True(t, defaultCB.expiry.IsZero()) + + // StateHalfOpen to StateOpen + assert.Nil(t, fail(defaultCB)) + assert.Equal(t, StateOpen, defaultCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.False(t, defaultCB.expiry.IsZero()) + + // StateOpen to StateHalfOpen + pseudoSleep(defaultCB, time.Duration(60)*time.Second) + assert.Equal(t, StateHalfOpen, defaultCB.State()) + assert.True(t, defaultCB.expiry.IsZero()) + + // StateHalfOpen to StateClosed + assert.Nil(t, succeed(defaultCB)) + assert.Equal(t, StateClosed, defaultCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, defaultCB.counts) + assert.True(t, defaultCB.expiry.IsZero()) +} + +func TestCustomCircuitBreaker(t *testing.T) { + assert.Equal(t, "cb", customCB.Name()) + + for i := 0; i < 5; i++ { + assert.Nil(t, succeed(customCB)) + assert.Nil(t, fail(customCB)) + } + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{10, 5, 5, 0, 1}, customCB.counts) + + pseudoSleep(customCB, time.Duration(29)*time.Second) + assert.Nil(t, succeed(customCB)) + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{11, 6, 5, 1, 0}, customCB.counts) + + pseudoSleep(customCB, time.Duration(1)*time.Second) // over Interval + assert.Nil(t, fail(customCB)) + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{1, 0, 1, 0, 1}, customCB.counts) + + // StateClosed to StateOpen + assert.Nil(t, succeed(customCB)) + assert.Nil(t, fail(customCB)) // failure ratio: 2/3 >= 0.6 + assert.Equal(t, StateOpen, customCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + assert.False(t, customCB.expiry.IsZero()) + assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) + + // StateOpen to StateHalfOpen + pseudoSleep(customCB, time.Duration(90)*time.Second) + assert.Equal(t, StateHalfOpen, customCB.State()) + assert.True(t, defaultCB.expiry.IsZero()) + assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) + + assert.Nil(t, succeed(customCB)) + assert.Nil(t, succeed(customCB)) + assert.Equal(t, StateHalfOpen, customCB.State()) + assert.Equal(t, Counts{2, 2, 0, 2, 0}, customCB.counts) + + // StateHalfOpen to StateClosed + ch := succeedLater(customCB, time.Duration(100)*time.Millisecond) // 3 consecutive successes + time.Sleep(time.Duration(50) * time.Millisecond) + assert.Equal(t, Counts{3, 2, 0, 2, 0}, customCB.counts) + assert.Error(t, succeed(customCB)) // over MaxRequests + assert.Nil(t, <-ch) + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + assert.False(t, customCB.expiry.IsZero()) + assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) +} + +func TestTwoStepCircuitBreaker(t *testing.T) { + tscb := NewTwoStepCircuitBreaker(Settings{Name: "tscb"}) + assert.Equal(t, "tscb", tscb.Name()) + + for i := 0; i < 5; i++ { + assert.Nil(t, fail2Step(tscb)) + } + + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{5, 0, 5, 0, 5}, tscb.cb.counts) + + assert.Nil(t, succeed2Step(tscb)) + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{6, 1, 5, 1, 0}, tscb.cb.counts) + + assert.Nil(t, fail2Step(tscb)) + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{7, 1, 6, 0, 1}, tscb.cb.counts) + + // StateClosed to StateOpen + for i := 0; i < 5; i++ { + assert.Nil(t, fail2Step(tscb)) // 6 consecutive failures + } + assert.Equal(t, StateOpen, tscb.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + assert.False(t, tscb.cb.expiry.IsZero()) + + assert.Error(t, succeed2Step(tscb)) + assert.Error(t, fail2Step(tscb)) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + + pseudoSleep(tscb.cb, time.Duration(59)*time.Second) + assert.Equal(t, StateOpen, tscb.State()) + + // StateOpen to StateHalfOpen + pseudoSleep(tscb.cb, time.Duration(1)*time.Second) // over Timeout + assert.Equal(t, StateHalfOpen, tscb.State()) + assert.True(t, tscb.cb.expiry.IsZero()) + + // StateHalfOpen to StateOpen + assert.Nil(t, fail2Step(tscb)) + assert.Equal(t, StateOpen, tscb.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + assert.False(t, tscb.cb.expiry.IsZero()) + + // StateOpen to StateHalfOpen + pseudoSleep(tscb.cb, time.Duration(60)*time.Second) + assert.Equal(t, StateHalfOpen, tscb.State()) + assert.True(t, tscb.cb.expiry.IsZero()) + + // StateHalfOpen to StateClosed + assert.Nil(t, succeed2Step(tscb)) + assert.Equal(t, StateClosed, tscb.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, tscb.cb.counts) + assert.True(t, tscb.cb.expiry.IsZero()) +} + +func TestPanicInRequest(t *testing.T) { + assert.Panics(t, func() { causePanic(defaultCB) }) + assert.Equal(t, Counts{1, 0, 1, 0, 1}, defaultCB.counts) +} + +func TestGeneration(t *testing.T) { + pseudoSleep(customCB, time.Duration(29)*time.Second) + assert.Nil(t, succeed(customCB)) + ch := succeedLater(customCB, time.Duration(1500)*time.Millisecond) + time.Sleep(time.Duration(500) * time.Millisecond) + assert.Equal(t, Counts{2, 1, 0, 1, 0}, customCB.counts) + + time.Sleep(time.Duration(500) * time.Millisecond) // over Interval + assert.Equal(t, StateClosed, customCB.State()) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) + + // the request from the previous generation has no effect on customCB.counts + assert.Nil(t, <-ch) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts) +} + +func TestCircuitBreakerInParallel(t *testing.T) { + runtime.GOMAXPROCS(runtime.NumCPU()) + + ch := make(chan error) + + const numReqs = 10000 + routine := func() { + for i := 0; i < numReqs; i++ { + ch <- succeed(customCB) + } + } + + const numRoutines = 10 + for i := 0; i < numRoutines; i++ { + go routine() + } + + total := uint32(numReqs * numRoutines) + for i := uint32(0); i < total; i++ { + err := <-ch + assert.Nil(t, err) + } + assert.Equal(t, Counts{total, total, 0, total, 0}, customCB.counts) +}