From 04f7a4011323bb3d7c61c6641940813afaa3b095 Mon Sep 17 00:00:00 2001 From: Edward Date: Fri, 8 May 2020 19:34:33 +0800 Subject: [PATCH] add a parallelism pattern --- gomore/09_parallelism/parallelism.go | 109 ++++++++++++++++++++++ gomore/09_parallelism/parallelism_test.go | 26 ++++++ 2 files changed, 135 insertions(+) create mode 100644 gomore/09_parallelism/parallelism.go create mode 100644 gomore/09_parallelism/parallelism_test.go diff --git a/gomore/09_parallelism/parallelism.go b/gomore/09_parallelism/parallelism.go new file mode 100644 index 0000000..45d97d5 --- /dev/null +++ b/gomore/09_parallelism/parallelism.go @@ -0,0 +1,109 @@ +package parallelism + +import ( + "crypto/md5" + "errors" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "sort" + "sync" +) + +// A result is the product of reading and summing a file using MD5. +type result struct { + path string + sum [md5.Size]byte + err error +} + +// sumFiles starts goroutines to walk the directory tree at root and digest each +// regular file. These goroutines send the results of the digests on the result +// channel and send the result of the walk on the error channel. If done is +// closed, sumFiles abandons its work. +func sumFiles(done <-chan struct{}, root string) (<-chan result, <-chan error) { + // For each regular file, start a goroutine that sums the file and sends + // the result on c. Send the result of the walk on errc. + c := make(chan result) + errc := make(chan error, 1) + go func() { // HL + var wg sync.WaitGroup + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.Mode().IsRegular() { + return nil + } + wg.Add(1) + go func() { // HL + data, err := ioutil.ReadFile(path) + select { + case c <- result{path, md5.Sum(data), err}: // HL + case <-done: // HL + } + wg.Done() + }() + // Abort the walk if done is closed. + select { + case <-done: // HL + return errors.New("walk canceled") + default: + return nil + } + }) + // Walk has returned, so all calls to wg.Add are done. Start a + // goroutine to close c once all the sends are done. + go func() { // HL + wg.Wait() + close(c) // HL + }() + // No select needed here, since errc is buffered. + errc <- err // HL + }() + return c, errc +} + +// MD5All reads all the files in the file tree rooted at root and returns a map +// from file path to the MD5 sum of the file's contents. If the directory walk +// fails or any read operation fails, MD5All returns an error. In that case, +// MD5All does not wait for inflight read operations to complete. +func MD5All(root string) (map[string][md5.Size]byte, error) { + // MD5All closes the done channel when it returns; it may do so before + // receiving all the values from c and errc. + done := make(chan struct{}) // HLdone + defer close(done) // HLdone + + c, errc := sumFiles(done, root) // HLdone + + m := make(map[string][md5.Size]byte) + for r := range c { // HLrange + if r.err != nil { + return nil, r.err + } + m[r.path] = r.sum + } + if err := <-errc; err != nil { + return nil, err + } + return m, nil +} + +func main() { + // Calculate the MD5 sum of all files under the specified directory, + // then print the results sorted by path name. + m, err := MD5All(os.Args[1]) + if err != nil { + fmt.Println(err) + return + } + var paths []string + for path := range m { + paths = append(paths, path) + } + sort.Strings(paths) + for _, path := range paths { + fmt.Printf("%x %s\n", m[path], path) + } +} diff --git a/gomore/09_parallelism/parallelism_test.go b/gomore/09_parallelism/parallelism_test.go new file mode 100644 index 0000000..90b3590 --- /dev/null +++ b/gomore/09_parallelism/parallelism_test.go @@ -0,0 +1,26 @@ +package parallelism + +import ( + "fmt" + "os" + "sort" + "testing" +) + +func TestParallelism(t *testing.T) { + // Calculate the MD5 sum of all files under the specified directory, + // then print the results sorted by path name. + m, err := MD5All(os.Args[1]) + if err != nil { + fmt.Println(err) + return + } + var paths []string + for path := range m { + paths = append(paths, path) + } + sort.Strings(paths) + for _, path := range paths { + fmt.Printf("%x %s\n", m[path], path) + } +}