diff --git a/compression.go b/compression.go index 36a7410..1657b27 100644 --- a/compression.go +++ b/compression.go @@ -48,6 +48,7 @@ type DecompressionReader struct { mReadBuffer unsafe.Pointer mTempBuffer unsafe.Pointer mTempBufferLen int + mEof bool } type CompressionWriter struct { mBase io.WriteCloser @@ -113,6 +114,7 @@ func NewDecompressionReadCloser(r io.ReadCloser) (retval DecompressionReader) { retval.mTempBuffer = C.malloc(IMPL_LZMA_BUFFER_LENGTH) retval.mTempBufferLen = int(IMPL_LZMA_BUFFER_LENGTH) retval.mBase = r; + retval.mEof = false //mStream = LZMA_STREAM_INIT;<-- assume we're zero initialized var ret C.lzma_ret ret = C.lzma_stream_decoder( @@ -162,20 +164,32 @@ func (dr *DecompressionReader) Read(data []byte) (int, error) { var action C.lzma_action action = LZMA_RUN; var err error - if (dr.mStream.avail_in == 0) { + if (dr.mStream.avail_in == 0 && !dr.mEof) { dr.mStream.next_in = (*C.uint8_t)(dr.mReadBuffer); var bytesRead int bytesRead, err = dr.mBase.Read(readSlice) dr.mStream.avail_in = C.size_t(bytesRead); - if (bytesRead == 0) { - action = LZMA_FINISH; + if err == io.EOF { + dr.mEof = true + } else if err != nil { + return 0, err } } + if (dr.mStream.avail_in == 0) { + action = LZMA_FINISH + } + var ret C.lzma_ret ret = C.lzma_code(&dr.mStream, action); if (dr.mStream.avail_out == 0 || ret == LZMA_STREAM_END) { writeSize := len(data) - int(dr.mStream.avail_out) copy(data[:writeSize], tempSlice[:writeSize]) + + if ret == LZMA_STREAM_END { + err = io.EOF + } else { + err = nil + } return writeSize, err ///// (ret == LZMA_STREAM_END ///// || (ret == LZMA_OK &&writeSize > 0)) diff --git a/compression_test.go b/compression_test.go index 1931910..28797df 100644 --- a/compression_test.go +++ b/compression_test.go @@ -86,3 +86,61 @@ func TestRoundTrip(t *testing.T) { t.Errorf("Data to compress got bigger after xzing") } } + +type earlyEofReader struct { + r io.Reader +} + +func (eer *earlyEofReader) Read(p []byte) (int, error) { + n, _ := eer.r.Read(p) + return n, io.EOF +} + +func TestRoundTripEarlyEof(t *testing.T) { + byteArray := make([]byte, 8192) + initialReader := bytes.NewBuffer(byteArray) + var compressedData bytes.Buffer + cw := NewCompressionWriter(&compressedData) + for { + var buffer [4096]byte + nRead, readErr := initialReader.Read(buffer[:]) + if readErr != nil && readErr != io.EOF { + panic(readErr) + } + nWrite, writeErr := cw.Write(buffer[:nRead]) + if writeErr != nil { + panic(writeErr) + } + _ = nWrite + if readErr != nil { + break + } + } + cw.Close() + dr := NewDecompressionReader(&earlyEofReader{r: bytes.NewBuffer(compressedData.Bytes())}) + var roundTrippedData bytes.Buffer + for { + var buffer [4096]byte + nRead, readErr := dr.Read(buffer[:]) + if readErr != nil && readErr != io.EOF { + panic(readErr) + } + nWrite, writeErr := roundTrippedData.Write(buffer[:nRead]) + if writeErr != nil || nWrite < nRead{ + panic(writeErr) + } + if readErr != nil { + break + } + } + dr.Close() + if !bytes.Equal(byteArray, roundTrippedData.Bytes()) { + t.Errorf("Byte array does not match") + } + if string(compressedData.Bytes()[1:5]) != "7zXZ" { + t.Errorf("Invalid 7z signature") + } + if len(compressedData.Bytes()) > len(byteArray) { + t.Errorf("Data to compress got bigger after xzing") + } +}