From 9c73cabd21928c58fd6564eabc214f0c2e525ea6 Mon Sep 17 00:00:00 2001 From: Neur0toxine Date: Thu, 26 Sep 2024 15:29:39 +0300 Subject: [PATCH] fix for body field writer exhaustion --- core/logger/attrs.go | 5 +++ core/logger/attrs_test.go | 81 +++++++++++++++++++++++++++++---------- 2 files changed, 66 insertions(+), 20 deletions(-) diff --git a/core/logger/attrs.go b/core/logger/attrs.go index 66ce7b0..2d6e2c3 100644 --- a/core/logger/attrs.go +++ b/core/logger/attrs.go @@ -96,6 +96,11 @@ func Body(val any) zap.Field { if err != nil { return zap.String(BodyAttr, fmt.Sprintf("%#v", val)) } + if seeker, ok := item.(io.Seeker); ok { + _, _ = seeker.Seek(0, 0) + } else if writer, ok := item.(io.Writer); ok { + _, _ = writer.Write(data) + } var m interface{} if err := json.Unmarshal(data, &m); err == nil { return zap.Any(BodyAttr, m) diff --git a/core/logger/attrs_test.go b/core/logger/attrs_test.go index 376a719..016a550 100644 --- a/core/logger/attrs_test.go +++ b/core/logger/attrs_test.go @@ -4,6 +4,8 @@ import ( "bytes" "errors" "fmt" + "github.com/stretchr/testify/require" + "go.uber.org/zap" "io" "net/http" "testing" @@ -62,94 +64,133 @@ func TestHTTPStatusName(t *testing.T) { func TestStreamID(t *testing.T) { var cases = []struct { + name string input interface{} result interface{} }{ { + name: "empty", input: "", result: "", }, { + name: "string", input: "test body", result: "test body", }, } for _, c := range cases { - val := StreamID(c.input) - assert.Equal(t, StreamIDAttr, val.Key) - assert.Equal(t, c.result, val.String) + t.Run(c.name, func(t *testing.T) { + val := StreamID(c.input) + assert.Equal(t, StreamIDAttr, val.Key) + assert.Equal(t, c.result, val.String) + }) } } func TestBody(t *testing.T) { var cases = []struct { - input interface{} - result interface{} + name string + input interface{} + result interface{} + asserts func(t *testing.T, field zap.Field, input, result interface{}) }{ { + name: "empty string input", input: "", result: nil, }, { + name: "nil input", input: nil, result: nil, }, { + name: "string input", input: "test body", result: "test body", }, { + name: "json input", input: `{"success":true}`, result: map[string]interface{}{"success": true}, }, { + name: "empty byte slice input", input: []byte{}, result: nil, }, { - input: nil, - result: nil, - }, - { + name: "byte slice input", input: []byte("test body"), result: "test body", }, { + name: "json byte slice input", input: []byte(`{"success":true}`), result: map[string]interface{}{"success": true}, }, { + name: "eof reader input", input: newReaderMock(func(p []byte) (n int, err error) { return 0, io.EOF }), result: nil, }, { + name: "empty reader input", input: newReaderMockData([]byte{}), result: nil, }, { + name: "data reader input", input: newReaderMockData([]byte("ooga booga")), result: "ooga booga", }, - { + name: "json data reader input", input: newReaderMockData([]byte(`{"success":true}`)), result: map[string]interface{}{"success": true}, }, + { + name: "check that seeker is rewound", + input: bytes.NewReader([]byte(`{"success":true}`)), + result: map[string]interface{}{"success": true}, + asserts: func(t *testing.T, val zap.Field, input, result interface{}) { + data, err := io.ReadAll(input.(io.Reader)) + require.NoError(t, err) + assert.Equal(t, []byte(`{"success":true}`), data) + }, + }, + { + name: "check that writer is rebuilt", + input: bytes.NewBuffer([]byte(`{"success":true}`)), + result: map[string]interface{}{"success": true}, + asserts: func(t *testing.T, val zap.Field, input, result interface{}) { + data, err := io.ReadAll(input.(io.Reader)) + require.NoError(t, err) + assert.Equal(t, []byte(`{"success":true}`), data) + }, + }, } for _, c := range cases { - val := Body(c.input) - assert.Equal(t, BodyAttr, val.Key) + t.Run(c.name, func(t *testing.T) { + val := Body(c.input) + assert.Equal(t, BodyAttr, val.Key) - switch assertion := c.result.(type) { - case string: - assert.Equal(t, assertion, val.String) - case int: - assert.Equal(t, assertion, int(val.Integer)) - default: - assert.Equal(t, c.result, val.Interface) - } + switch assertion := c.result.(type) { + case string: + assert.Equal(t, assertion, val.String) + case int: + assert.Equal(t, assertion, int(val.Integer)) + default: + assert.Equal(t, c.result, val.Interface) + } + + if c.asserts != nil { + c.asserts(t, val, c.input, c.result) + } + }) } }