added error handling to worker pool

This commit is contained in:
Jian Han 2018-01-14 21:54:09 +10:00
parent e2c47bd910
commit 3e0f3e4dfb

View File

@ -11,16 +11,19 @@ package main
// queue that has no guarantee it will ever be worked on. // queue that has no guarantee it will ever be worked on.
import ( import (
"errors"
"log" "log"
"sync" "sync"
"time" "time"
"github.com/davecgh/go-spew/spew"
) )
// Worker must be implemented by types that want to use // Worker must be implemented by types that want to use
// the work pool. // the work pool.
// The Worker interface declares a single method called Task // The Worker interface declares a single method called Task
type Worker interface { type Worker interface {
Task() Task() error
} }
// Pool provides a pool of goroutines that can execute any Worker // Pool provides a pool of goroutines that can execute any Worker
@ -29,14 +32,16 @@ type Worker interface {
// pool of goroutines and will have methods that process the work. The type declares two // pool of goroutines and will have methods that process the work. The type declares two
// fields, one named work, which is a channel of the Worker interface type, and a sync.WaitGroup named wg. // fields, one named work, which is a channel of the Worker interface type, and a sync.WaitGroup named wg.
type Pool struct { type Pool struct {
work chan Worker work chan Worker
wg sync.WaitGroup wg sync.WaitGroup
errChan chan error
} }
// New creates a new work pool. // New creates a new work pool.
func New(maxGoroutines int) *Pool { func New(maxGoroutines int) *Pool {
p := Pool{ p := Pool{
work: make(chan Worker), work: make(chan Worker),
errChan: make(chan error),
} }
p.wg.Add(maxGoroutines) p.wg.Add(maxGoroutines)
// The for range loop blocks until theres a Worker interface value to receive on the // The for range loop blocks until theres a Worker interface value to receive on the
@ -46,7 +51,7 @@ func New(maxGoroutines int) *Pool {
for i := 0; i < maxGoroutines; i++ { for i := 0; i < maxGoroutines; i++ {
go func() { go func() {
for w := range p.work { for w := range p.work {
w.Task() p.errChan <- w.Task()
} }
p.wg.Done() p.wg.Done()
}() }()
@ -60,8 +65,12 @@ func New(maxGoroutines int) *Pool {
// work channel. Since the work channel is an unbuffered channel, the caller must wait // work channel. Since the work channel is an unbuffered channel, the caller must wait
// for a goroutine from the pool to receive it. This is what we want, because the caller // for a goroutine from the pool to receive it. This is what we want, because the caller
// needs the guarantee that the work being submitted is being worked on once the call to Run returns. // needs the guarantee that the work being submitted is being worked on once the call to Run returns.
func (p *Pool) Run(w Worker) { func (p *Pool) Run(w Worker) (err error) {
p.work <- w p.work <- w
select {
case err = <-p.errChan:
}
return
} }
// Shutdown waits for all the goroutines to shutdown. // Shutdown waits for all the goroutines to shutdown.
@ -71,6 +80,7 @@ func (p *Pool) Run(w Worker) {
// they have terminated. // they have terminated.
func (p *Pool) Shutdown() { func (p *Pool) Shutdown() {
close(p.work) close(p.work)
close(p.errChan)
p.wg.Wait() p.wg.Wait()
} }
@ -78,8 +88,10 @@ var names = []string{
"steve", "steve",
"bob", "bob",
"mary", "mary",
"therese",
"jason", "jason",
"Bob",
"Lee",
"Jane",
} }
// namePrinter provides special support for printing names. // namePrinter provides special support for printing names.
@ -88,9 +100,13 @@ type namePrinter struct {
} }
// Task implements the Worker interface. // Task implements the Worker interface.
func (m *namePrinter) Task() { func (m *namePrinter) Task() error {
time.Sleep(time.Second * 1)
if m.name == "jason" {
return errors.New("Invalid name")
}
log.Println(m.name) log.Println(m.name)
time.Sleep(time.Second) return nil
} }
// main is the entry point for all Go programs. // main is the entry point for all Go programs.
@ -98,22 +114,23 @@ func main() {
// Create a work pool with 2 goroutines. // Create a work pool with 2 goroutines.
p := New(2) p := New(2)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(100 * len(names)) wg.Add(len(names))
for i := 0; i < 100; i++ { // Iterate over the slice of names.
// Iterate over the slice of names. for _, name := range names {
for _, name := range names { // Create a namePrinter and provide the
// Create a namePrinter and provide the // specific name.
// specific name. np := namePrinter{
np := namePrinter{ name: name,
name: name,
}
go func() {
// Submit the task to be worked on. When RunTask
// returns we know it is being handled.
p.Run(&np)
wg.Done()
}()
} }
go func() {
// Submit the task to be worked on. When RunTask
// returns we know it is being handled.
if err := p.Run(&np); err != nil {
spew.Dump(err)
}
wg.Done()
}()
} }
wg.Wait() wg.Wait()
// Shutdown the work pool and wait for all existing work // Shutdown the work pool and wait for all existing work