Skip to content

Commit 56a7790

Browse files
authored
Fixed readBuffer issue (#15)
* Added test for reader returning n > 0 and io.EOF * Fixed test with n>0, io.EOF * Formatted * Refactored readBuffer.Get * More refactoring
1 parent c238446 commit 56a7790

File tree

2 files changed

+67
-19
lines changed

2 files changed

+67
-19
lines changed

decode_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,3 +2418,36 @@ func TestUnmarshalMaxDepth(t *testing.T) {
24182418
}
24192419
}
24202420
}
2421+
2422+
var _ io.Reader = &eofSignalReader{}
2423+
2424+
type eofSignalReader struct {
2425+
data []byte
2426+
index int
2427+
}
2428+
2429+
func newEOFSignalReader(data []byte) *eofSignalReader {
2430+
return &eofSignalReader{
2431+
data: data,
2432+
}
2433+
}
2434+
2435+
func (e *eofSignalReader) Read(p []byte) (int, error) {
2436+
n := len(p)
2437+
if len(e.data)-e.index >= n {
2438+
copy(p, e.data[e.index:e.index+n])
2439+
e.index += n
2440+
return n, nil
2441+
}
2442+
length := len(e.data) - e.index
2443+
copy(p, e.data[e.index:])
2444+
e.index = len(e.data)
2445+
return length, io.EOF
2446+
}
2447+
2448+
func TestUnmarshalEOFSignalReader(t *testing.T) {
2449+
want := strings.Repeat("a", 1<<10+1<<5)
2450+
got := ""
2451+
require.NoError(t, Unmarshal(newEOFSignalReader([]byte(`"`+want+`"`)), &got))
2452+
assert.Equal(t, want, got)
2453+
}

internal/readbuffer/readbuffer.go

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ type ReadBuffer struct {
1515
index int
1616
len int
1717
src io.Reader
18+
finished bool
1819
}
1920

2021
func New(stream io.Reader) ReadBuffer {
@@ -24,30 +25,44 @@ func New(stream io.Reader) ReadBuffer {
2425
}
2526
}
2627

27-
func (r *ReadBuffer) Get(n int) ([]byte, error) {
28-
if r.len-r.index >= n {
29-
res := r.buf[r.index : r.index+n]
30-
r.index += n
31-
return res, nil
28+
func (r *ReadBuffer) Get(n int) (res []byte, err error) {
29+
n, res = r.appendFromBuffer(n, res)
30+
if n == 0 {
31+
return
3232
}
33-
res := make([]byte, r.len-r.index)
34-
copy(res, r.buf[r.index:r.len])
35-
r.index = r.len
36-
n -= len(res)
37-
var err error
3833
for err = r.load(); err == nil; err = r.load() {
39-
if r.len-r.index >= n {
40-
res = append(res, r.buf[r.index:r.index+n]...)
41-
r.index += n
42-
return res, nil
34+
n, res = r.appendFromBuffer(n, res)
35+
if n == 0 {
36+
return
4337
}
44-
45-
res = append(res, r.buf[r.index:r.len]...)
46-
n -= r.len - r.index
47-
r.index = r.len
4838
time.Sleep(readTimeout)
4939
}
50-
return res, err
40+
if n == 0 {
41+
return
42+
}
43+
if err == io.EOF {
44+
r.finished = true
45+
err = nil
46+
} else {
47+
return nil, err
48+
}
49+
n, res = r.appendFromBuffer(n, res)
50+
if n > 0 && r.finished {
51+
return res, io.EOF
52+
}
53+
return
54+
}
55+
56+
func (r *ReadBuffer) appendFromBuffer(n int, dst []byte) (int, []byte) {
57+
length := r.len - r.index
58+
if length >= n {
59+
res := append(dst, r.buf[r.index:r.index+n]...)
60+
r.index += n
61+
return 0, res
62+
}
63+
res := append(dst, r.buf[r.index:r.len]...)
64+
r.index = r.len
65+
return n - length, res
5166
}
5267

5368
func (r *ReadBuffer) load() error {

0 commit comments

Comments
 (0)