diff --git a/message.go b/message.go index 6315532..5f7391a 100644 --- a/message.go +++ b/message.go @@ -46,6 +46,16 @@ func (t MessageType) String() string { return "StreamCommit" case MessageTypeStreamAbort: return "StreamAbort" + case MessageTypeBeginPrepare: + return "BeginPrepare" + case MessageTypePrepare: + return "Prepare" + case MessageTypeCommitPrepared: + return "CommitPrepared" + case MessageTypeRollbackPrepared: + return "RollbackPrepared" + case MessageTypeStreamPrepare: + return "StreamPrepare" default: return "Unknown" } @@ -53,20 +63,29 @@ func (t MessageType) String() string { // List of types of logical replication messages. const ( - MessageTypeBegin MessageType = 'B' - MessageTypeMessage MessageType = 'M' - MessageTypeCommit MessageType = 'C' - MessageTypeOrigin MessageType = 'O' - MessageTypeRelation MessageType = 'R' - MessageTypeType MessageType = 'Y' - MessageTypeInsert MessageType = 'I' - MessageTypeUpdate MessageType = 'U' - MessageTypeDelete MessageType = 'D' - MessageTypeTruncate MessageType = 'T' + MessageTypeBegin MessageType = 'B' + MessageTypeMessage MessageType = 'M' + MessageTypeCommit MessageType = 'C' + MessageTypeOrigin MessageType = 'O' + MessageTypeRelation MessageType = 'R' + MessageTypeType MessageType = 'Y' + MessageTypeInsert MessageType = 'I' + MessageTypeUpdate MessageType = 'U' + MessageTypeDelete MessageType = 'D' + MessageTypeTruncate MessageType = 'T' + + // introduced in protocol version 2 MessageTypeStreamStart MessageType = 'S' MessageTypeStreamStop MessageType = 'E' MessageTypeStreamCommit MessageType = 'c' MessageTypeStreamAbort MessageType = 'A' + + // introduced in protocol version 3 + MessageTypeBeginPrepare MessageType = 'b' + MessageTypePrepare MessageType = 'P' + MessageTypeCommitPrepared MessageType = 'K' + MessageTypeRollbackPrepared MessageType = 'r' + MessageTypeStreamPrepare MessageType = 'p' ) // Message is a message received from server. @@ -182,7 +201,6 @@ func (m *BeginMessage) Decode(src []byte) error { m.Xid = binary.BigEndian.Uint32(src[low:]) m.SetType(MessageTypeBegin) - return nil } diff --git a/messageV2.go b/messageV2.go index 990e46f..9b23bb8 100644 --- a/messageV2.go +++ b/messageV2.go @@ -7,7 +7,6 @@ import ( // MessageDecoderV2 decodes message from V2 protocol into struct. type MessageDecoderV2 interface { - MessageDecoder DecodeV2(src []byte, inStream bool) error } diff --git a/messageV2_test.go b/messageV2_test.go index bebd996..9f98fbc 100644 --- a/messageV2_test.go +++ b/messageV2_test.go @@ -1,9 +1,9 @@ package pglogrepl import ( - "fmt" - "github.com/stretchr/testify/suite" "testing" + + "github.com/stretchr/testify/suite" ) func TestLogicalDecodingMessageV2Suite(t *testing.T) { @@ -123,7 +123,6 @@ func (s *streamCommitSuite) Test() { msg[0] = 'c' bigEndian.PutUint32(msg[1:], xid) - fmt.Printf("%+v\n", msg) msg[5] = flags bigEndian.PutUint64(msg[6:], uint64(commitLSN)) bigEndian.PutUint64(msg[14:], uint64(transactionEndLSN)) diff --git a/messageV3.go b/messageV3.go new file mode 100644 index 0000000..d29884d --- /dev/null +++ b/messageV3.go @@ -0,0 +1,215 @@ +package pglogrepl + +import ( + "time" +) + +type MessageDecoderV3 interface { + DecodeV3(src []byte, inStream bool) error +} + +type BeginPrepareMessageV3 struct { + baseMessage + PrepareLSN LSN + TransactionEndLSN LSN + // The time at which the transaction was prepared. + PrepareTime time.Time + // The transaction ID of the prepared transaction. + Xid uint32 + // The user defined GID of the prepared transaction. + Gid string +} + +func (m *BeginPrepareMessageV3) DecodeV3(src []byte, _ bool) (err error) { + if len(src) < 29 { + return m.lengthError("BeginPrepareMessage", 29, len(src)) + } + + var low, used int + m.PrepareLSN, used = m.decodeLSN(src) + low += used + m.TransactionEndLSN, used = m.decodeLSN(src[low:]) + low += used + m.PrepareTime, used = m.decodeTime(src[low:]) + low += used + m.Xid, used = m.decodeUint32(src[low:]) + low += used + m.Gid, _ = m.decodeString(src[low:]) + m.SetType(MessageTypeBeginPrepare) + + return nil +} + +type PrepareMessageV3 struct { + baseMessage + // Flags currently unused (must be 0). + Flags uint8 + PrepareLSN LSN + TransactionEndLSN LSN + // The time at which the transaction was prepared. + PrepareTime time.Time + // The transaction ID of the prepared transaction. + Xid uint32 + // The user defined GID of the prepared transaction. + Gid string +} + +func (m *PrepareMessageV3) DecodeV3(src []byte, _ bool) (err error) { + if len(src) < 30 { + return m.lengthError("PrepareMessage", 30, len(src)) + } + + var low, used int + m.Flags = src[low] + low += 1 + m.PrepareLSN, used = m.decodeLSN(src[low:]) + low += used + m.TransactionEndLSN, used = m.decodeLSN(src[low:]) + low += used + m.PrepareTime, used = m.decodeTime(src[low:]) + low += used + m.Xid, used = m.decodeUint32(src[low:]) + low += used + m.Gid, _ = m.decodeString(src[low:]) + m.SetType(MessageTypePrepare) + + return nil +} + +type CommitPreparedMessageV3 struct { + baseMessage + // Flags currently unused (must be 0). + Flags uint8 + CommitLSN LSN + TransactionEndLSN LSN + CommitTime time.Time + Xid uint32 + // The user defined GID of the prepared transaction. + Gid string +} + +func (m *CommitPreparedMessageV3) DecodeV3(src []byte, _ bool) (err error) { + if len(src) < 30 { + return m.lengthError("CommitPreparedMessage", 30, len(src)) + } + + var low, used int + m.Flags = src[low] + low += 1 + m.CommitLSN, used = m.decodeLSN(src[low:]) + low += used + m.TransactionEndLSN, used = m.decodeLSN(src[low:]) + low += used + m.CommitTime, used = m.decodeTime(src[low:]) + low += used + m.Xid, used = m.decodeUint32(src[low:]) + low += used + m.Gid, _ = m.decodeString(src[low:]) + m.SetType(MessageTypeCommitPrepared) + + return nil +} + +type RollbackPreparedMessageV3 struct { + baseMessage + // Flags currently unused (must be 0). + Flags uint8 + TransactionEndLSN LSN + // The end LSN of the rollback of the prepared transaction. + TransactionRollbackLSN LSN + PrepareTime time.Time + RollbackTime time.Time + Xid uint32 + // The user defined GID of the prepared transaction. + Gid string +} + +func (m *RollbackPreparedMessageV3) DecodeV3(src []byte, _ bool) (err error) { + if len(src) < 38 { + return m.lengthError("RollbackPreparedMessage", 38, len(src)) + } + + var low, used int + m.Flags = src[low] + low += 1 + m.TransactionEndLSN, used = m.decodeLSN(src[low:]) + low += used + m.TransactionRollbackLSN, used = m.decodeLSN(src[low:]) + low += used + m.PrepareTime, used = m.decodeTime(src[low:]) + low += used + m.RollbackTime, used = m.decodeTime(src[low:]) + low += used + m.Xid, used = m.decodeUint32(src[low:]) + low += used + m.Gid, _ = m.decodeString(src[low:]) + m.SetType(MessageTypeRollbackPrepared) + + return nil +} + +type StreamPrepareMessageV3 struct { + baseMessage + // Flags currently unused (must be 0). + Flags uint8 + PrepareLSN LSN + TransactionEndLSN LSN + PrepareTime time.Time + Xid uint32 + // The user defined GID of the prepared transaction. + Gid string +} + +func (m *StreamPrepareMessageV3) DecodeV3(src []byte, _ bool) (err error) { + if len(src) < 30 { + return m.lengthError("StreamPrepareMessage", 30, len(src)) + } + + var low, used int + m.Flags = src[low] + low += 1 + m.PrepareLSN, used = m.decodeLSN(src[low:]) + low += used + m.TransactionEndLSN, used = m.decodeLSN(src[low:]) + low += used + m.PrepareTime, used = m.decodeTime(src[low:]) + low += used + m.Xid, used = m.decodeUint32(src[low:]) + low += used + m.Gid, _ = m.decodeString(src[low:]) + m.SetType(MessageTypeStreamPrepare) + + return nil +} + +// ParseV3 parse a logical replication message from protocol version #3 +// it accepts a slice of bytes read from PG and inStream parameter +// inStream must be true when StreamStartMessageV2 has been read +// it must be false after StreamStopMessageV2 has been read +func ParseV3(data []byte, inStream bool) (m Message, err error) { + var decoder MessageDecoderV3 + msgType := MessageType(data[0]) + + switch msgType { + case MessageTypeBeginPrepare: + decoder = new(BeginPrepareMessageV3) + case MessageTypePrepare: + decoder = new(PrepareMessageV3) + case MessageTypeCommitPrepared: + decoder = new(CommitPreparedMessageV3) + case MessageTypeRollbackPrepared: + decoder = new(RollbackPreparedMessageV3) + case MessageTypeStreamPrepare: + decoder = new(StreamPrepareMessageV3) + default: + // all messages from V2 are unchanged in V3 + // so we can just call ParseV2 + return ParseV2(data, inStream) + } + + if err = decoder.DecodeV3(data[1:], inStream); err != nil { + return nil, err + } + + return decoder.(Message), nil +} diff --git a/messageV3_test.go b/messageV3_test.go new file mode 100644 index 0000000..4d8ae00 --- /dev/null +++ b/messageV3_test.go @@ -0,0 +1,431 @@ +package pglogrepl + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" +) + +func TestTooShortMessageV3Suite(t *testing.T) { + suite.Run(t, new(tooShortMessageV3Suite)) +} + +type tooShortMessageV3Suite struct { + messageSuite +} + +func (s *tooShortMessageV3Suite) TestTooShortError() { + msg := make([]byte, 29) + + tooShortForMessageType := func(messageType MessageType, minBytes int) { + for _, inStream := range []bool{false, true} { + msg[0] = uint8(messageType) + m, err := ParseV3(msg, inStream) + s.Nil(m) + s.ErrorContains(err, + fmt.Sprintf("%sMessage must have %d bytes, got 28 bytes", messageType, minBytes)) + } + } + + tooShortForMessageType(MessageTypeBeginPrepare, 29) + tooShortForMessageType(MessageTypePrepare, 30) + tooShortForMessageType(MessageTypeCommitPrepared, 30) + tooShortForMessageType(MessageTypeRollbackPrepared, 38) + tooShortForMessageType(MessageTypeStreamPrepare, 30) +} + +func TestBeginPrepareMessageV3Suite(t *testing.T) { + suite.Run(t, new(beginPrepareMessageV3Suite)) +} + +type beginPrepareMessageV3Suite struct { + messageSuite +} + +func (s *beginPrepareMessageV3Suite) Test() { + // last byte is for NUL terminator + msg := make([]byte, 1+8+8+8+4+4+1) + msg[0] = uint8(MessageTypeBeginPrepare) + prepareLSN := s.newLSN() + transactionEndLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + xid := s.newXid() + gid := "test" + bigEndian.PutUint64(msg[1:], uint64(prepareLSN)) + bigEndian.PutUint64(msg[1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+8+8:], prepareTimeU64) + bigEndian.PutUint32(msg[1+8+8+8:], xid) + s.putString(msg[1+8+8+8+4:], gid) + + expected := &BeginPrepareMessageV3{ + PrepareLSN: prepareLSN, + TransactionEndLSN: transactionEndLSN, + PrepareTime: prepareTime, + Xid: xid, + Gid: gid, + } + expected.msgType = MessageTypeBeginPrepare + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + // ideally we should error if inStream true + // but sticking to what other messages do for now + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*BeginPrepareMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func (s *beginPrepareMessageV3Suite) TestNoGID() { + msg := make([]byte, 1+8+8+8+4+1) + msg[0] = uint8(MessageTypeBeginPrepare) + prepareLSN := s.newLSN() + transactionEndLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + xid := s.newXid() + bigEndian.PutUint64(msg[1:], uint64(prepareLSN)) + bigEndian.PutUint64(msg[1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+8+8:], prepareTimeU64) + bigEndian.PutUint32(msg[1+8+8+8:], xid) + msg[1+8+8+8+4] = 0 + + expected := &BeginPrepareMessageV3{ + PrepareLSN: prepareLSN, + TransactionEndLSN: transactionEndLSN, + PrepareTime: prepareTime, + Xid: xid, + } + expected.msgType = MessageTypeBeginPrepare + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*BeginPrepareMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func TestPrepareMessageV3Suite(t *testing.T) { + suite.Run(t, new(prepareMessageV3Suite)) +} + +type prepareMessageV3Suite struct { + messageSuite +} + +func (s *prepareMessageV3Suite) Test() { + msg := make([]byte, 1+1+8+8+8+4+4+1) + msg[0] = uint8(MessageTypePrepare) + msg[1] = 0 + prepareLSN := s.newLSN() + transactionEndLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + xid := s.newXid() + gid := "test" + bigEndian.PutUint64(msg[1+1:], uint64(prepareLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], prepareTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8:], xid) + s.putString(msg[1+1+8+8+8+4:], gid) + + expected := &PrepareMessageV3{ + Flags: 0, + PrepareLSN: prepareLSN, + TransactionEndLSN: transactionEndLSN, + PrepareTime: prepareTime, + Xid: xid, + Gid: gid, + } + expected.msgType = MessageTypePrepare + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*PrepareMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func (s *prepareMessageV3Suite) TestNoGID() { + msg := make([]byte, 1+1+8+8+8+4+1) + msg[0] = uint8(MessageTypePrepare) + msg[1] = 0 + prepareLSN := s.newLSN() + transactionEndLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + xid := s.newXid() + bigEndian.PutUint64(msg[1+1:], uint64(prepareLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], prepareTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8:], xid) + msg[1+1+8+8+8+4] = 0 + + expected := &PrepareMessageV3{ + Flags: 0, + PrepareLSN: prepareLSN, + TransactionEndLSN: transactionEndLSN, + PrepareTime: prepareTime, + Xid: xid, + } + expected.msgType = MessageTypePrepare + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*PrepareMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func TestCommitPreparedV3Suite(t *testing.T) { + suite.Run(t, new(commitPreparedMessageV3Suite)) +} + +type commitPreparedMessageV3Suite struct { + messageSuite +} + +func (s *commitPreparedMessageV3Suite) Test() { + msg := make([]byte, 1+1+8+8+8+4+4+1) + msg[0] = uint8(MessageTypeCommitPrepared) + msg[1] = 0 + commitLSN := s.newLSN() + transactionEndLSN := s.newLSN() + commitTime, commitTimeU64 := s.newTime() + xid := s.newXid() + gid := "test" + bigEndian.PutUint64(msg[1+1:], uint64(commitLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], commitTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8:], xid) + s.putString(msg[1+1+8+8+8+4:], gid) + + expected := &CommitPreparedMessageV3{ + Flags: 0, + CommitLSN: commitLSN, + TransactionEndLSN: transactionEndLSN, + CommitTime: commitTime, + Xid: xid, + Gid: gid, + } + expected.msgType = MessageTypeCommitPrepared + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*CommitPreparedMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func (s *commitPreparedMessageV3Suite) TestNoGID() { + msg := make([]byte, 1+1+8+8+8+4+1) + msg[0] = uint8(MessageTypeCommitPrepared) + msg[1] = 0 + commitLSN := s.newLSN() + transactionEndLSN := s.newLSN() + commitTime, commitTimeU64 := s.newTime() + xid := s.newXid() + bigEndian.PutUint64(msg[1+1:], uint64(commitLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], commitTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8:], xid) + msg[1+1+8+8+8+4] = 0 + + expected := &CommitPreparedMessageV3{ + Flags: 0, + CommitLSN: commitLSN, + TransactionEndLSN: transactionEndLSN, + CommitTime: commitTime, + Xid: xid, + } + expected.msgType = MessageTypeCommitPrepared + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*CommitPreparedMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func TestRollbackPreparedV3Suite(t *testing.T) { + suite.Run(t, new(rollbackPreparedMessageV3Suite)) +} + +type rollbackPreparedMessageV3Suite struct { + messageSuite +} + +func (s *rollbackPreparedMessageV3Suite) Test() { + msg := make([]byte, 1+1+8+8+8+8+4+4+1) + msg[0] = uint8(MessageTypeRollbackPrepared) + msg[1] = 0 + transactionEndLSN := s.newLSN() + transactionRollbackLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + rollbackTime, rollbackTimeU64 := s.newTime() + xid := s.newXid() + gid := "test" + bigEndian.PutUint64(msg[1+1:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionRollbackLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], prepareTimeU64) + bigEndian.PutUint64(msg[1+1+8+8+8:], rollbackTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8+8:], xid) + s.putString(msg[1+1+8+8+8+8+4:], gid) + + expected := &RollbackPreparedMessageV3{ + Flags: 0, + TransactionEndLSN: transactionEndLSN, + TransactionRollbackLSN: transactionRollbackLSN, + PrepareTime: prepareTime, + RollbackTime: rollbackTime, + Xid: xid, + Gid: gid, + } + expected.msgType = MessageTypeRollbackPrepared + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*RollbackPreparedMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func (s *rollbackPreparedMessageV3Suite) TestNoGID() { + msg := make([]byte, 1+1+8+8+8+8+4+1) + msg[0] = uint8(MessageTypeRollbackPrepared) + msg[1] = 0 + transactionEndLSN := s.newLSN() + transactionRollbackLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + rollbackTime, rollbackTimeU64 := s.newTime() + xid := s.newXid() + bigEndian.PutUint64(msg[1+1:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionRollbackLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], prepareTimeU64) + bigEndian.PutUint64(msg[1+1+8+8+8:], rollbackTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8+8:], xid) + msg[1+1+8+8+8+8+4] = 0 + + expected := &RollbackPreparedMessageV3{ + Flags: 0, + TransactionEndLSN: transactionEndLSN, + TransactionRollbackLSN: transactionRollbackLSN, + PrepareTime: prepareTime, + RollbackTime: rollbackTime, + Xid: xid, + } + expected.msgType = MessageTypeRollbackPrepared + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*RollbackPreparedMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func TestStreamPrepareMessageV3Suite(t *testing.T) { + suite.Run(t, new(streamPrepareMessageV3Suite)) +} + +type streamPrepareMessageV3Suite struct { + messageSuite +} + +func (s *streamPrepareMessageV3Suite) Test() { + msg := make([]byte, 1+1+8+8+8+4+4+1) + msg[0] = uint8(MessageTypeStreamPrepare) + msg[1] = 0 + prepareLSN := s.newLSN() + transactionEndLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + xid := s.newXid() + gid := "test" + bigEndian.PutUint64(msg[1+1:], uint64(prepareLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], prepareTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8:], xid) + s.putString(msg[1+1+8+8+8+4:], gid) + + expected := &StreamPrepareMessageV3{ + Flags: 0, + PrepareLSN: prepareLSN, + TransactionEndLSN: transactionEndLSN, + PrepareTime: prepareTime, + Xid: xid, + Gid: gid, + } + expected.msgType = MessageTypeStreamPrepare + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*StreamPrepareMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} + +func (s *streamPrepareMessageV3Suite) TestNoGID() { + msg := make([]byte, 1+1+8+8+8+4+1) + msg[0] = uint8(MessageTypeStreamPrepare) + msg[1] = 0 + prepareLSN := s.newLSN() + transactionEndLSN := s.newLSN() + prepareTime, prepareTimeU64 := s.newTime() + xid := s.newXid() + bigEndian.PutUint64(msg[1+1:], uint64(prepareLSN)) + bigEndian.PutUint64(msg[1+1+8:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[1+1+8+8:], prepareTimeU64) + bigEndian.PutUint32(msg[1+1+8+8+8:], xid) + msg[1+1+8+8+8+4] = 0 + + expected := &StreamPrepareMessageV3{ + Flags: 0, + PrepareLSN: prepareLSN, + TransactionEndLSN: transactionEndLSN, + PrepareTime: prepareTime, + Xid: xid, + } + expected.msgType = MessageTypeStreamPrepare + s.assertV1NotSupported(msg) + s.assertV2NotSupported(msg) + + for _, inStream := range []bool{false, true} { + m, err := ParseV3(msg, inStream) + s.NoError(err) + logicalDecodingMsg, ok := m.(*StreamPrepareMessageV3) + s.True(ok) + s.Equal(expected, logicalDecodingMsg) + } +} diff --git a/message_test.go b/message_test.go index 069117a..3bdf329 100644 --- a/message_test.go +++ b/message_test.go @@ -119,6 +119,15 @@ func (s *messageSuite) assertV1NotSupported(msg []byte) { s.True(errors.Is(err, errMsgNotSupported)) } +func (s *messageSuite) assertV2NotSupported(msg []byte) { + _, err := ParseV2(msg, true) + s.Error(err) + s.True(errors.Is(err, errMsgNotSupported)) + _, err = ParseV2(msg, false) + s.Error(err) + s.True(errors.Is(err, errMsgNotSupported)) +} + func (s *messageSuite) createRelationTestData() ([]byte, *RelationMessage) { relationID := uint32(rand.Int31()) namespace := "public" diff --git a/pglogrepl_test.go b/pglogrepl_test.go index ee684b2..d4f57af 100644 --- a/pglogrepl_test.go +++ b/pglogrepl_test.go @@ -362,6 +362,7 @@ func TestBaseBackup(t *testing.T) { f, err := os.CreateTemp("", fmt.Sprintf("pglogrepl_test_tbs_%d.tar", i)) require.NoError(t, err) err = pglogrepl.NextTableSpace(context.Background(), conn) + require.NoError(t, err) var message pgproto3.BackendMessage L: for {