diff --git a/memstore.go b/memstore.go index e9f80882..610ee493 100644 --- a/memstore.go +++ b/memstore.go @@ -65,7 +65,7 @@ func (store *MemoryStore) Put(key string, message packets.ControlPacket) { } // Get takes a key and looks in the store for a matching Message -// returning either the Message pointer or nil. +// returning either a copy of the Message as packets.ControlPacket or nil. func (store *MemoryStore) Get(key string) packets.ControlPacket { store.RLock() defer store.RUnlock() @@ -74,13 +74,15 @@ func (store *MemoryStore) Get(key string) packets.ControlPacket { return nil } mid := mIDFromKey(key) - m := store.messages[key] - if m == nil { + m, ok := store.messages[key] + if !ok { CRITICAL.Println(STR, "memorystore get: message", mid, "not found") - } else { - DEBUG.Println(STR, "memorystore get: message", mid, "found") + return nil } - return m + + DEBUG.Println(STR, "memorystore get: message", mid, "found") + + return m.Copy() } // All returns a slice of strings containing all the keys currently diff --git a/packets/connack.go b/packets/connack.go index 3a7b98fc..07f3bf9e 100644 --- a/packets/connack.go +++ b/packets/connack.go @@ -66,3 +66,12 @@ func (ca *ConnackPacket) Unpack(b io.Reader) error { func (ca *ConnackPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the ConnackPacket +func (ca *ConnackPacket) Copy() ControlPacket { + cp := NewControlPacket(Connack).(*ConnackPacket) + + *cp = *ca + + return cp +} diff --git a/packets/connect.go b/packets/connect.go index b4446a55..62dd6bb8 100644 --- a/packets/connect.go +++ b/packets/connect.go @@ -169,3 +169,22 @@ func (c *ConnectPacket) Validate() byte { func (c *ConnectPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the ConnectPacket +func (c *ConnectPacket) Copy() ControlPacket { + cp := NewControlPacket(Connect).(*ConnectPacket) + + *cp = *c + + if len(c.Password) > 0 { + cp.Password = make([]byte, len(c.Password)) + copy(cp.Password, c.Password) + } + + if len(c.WillMessage) > 0 { + cp.WillMessage = make([]byte, len(c.WillMessage)) + copy(cp.WillMessage, c.WillMessage) + } + + return cp +} diff --git a/packets/disconnect.go b/packets/disconnect.go index cf352a37..49c5aabd 100644 --- a/packets/disconnect.go +++ b/packets/disconnect.go @@ -48,3 +48,12 @@ func (d *DisconnectPacket) Unpack(b io.Reader) error { func (d *DisconnectPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the DisconnectPacket +func (d *DisconnectPacket) Copy() ControlPacket { + cp := NewControlPacket(Disconnect).(*DisconnectPacket) + + *cp = *d + + return cp +} diff --git a/packets/packets.go b/packets/packets.go index b2d7ed1b..73d27303 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -32,6 +32,7 @@ type ControlPacket interface { Unpack(io.Reader) error String() string Details() Details + Copy() ControlPacket } // PacketNames maps the constants for each of the MQTT packet types diff --git a/packets/packets_test.go b/packets/packets_test.go index c829dc4f..b05dd868 100644 --- a/packets/packets_test.go +++ b/packets/packets_test.go @@ -18,6 +18,8 @@ package packets import ( "bytes" + "fmt" + "reflect" "testing" ) @@ -251,3 +253,120 @@ func TestEncoding(t *testing.T) { } } } + +// isCopy checks if the original and copy are the same, recursively. +// It will fail the test if the values are different or if the pointer +// of the original and copy are the same. +func isCopy(t *testing.T, original, copy any, fieldName ...string) { + t.Helper() + + log := func(field string, original, copy interface{}) { + t.Logf("Field: %s", field) + t.Logf("Original: %#v", original) + t.Logf("Copy: %#v", copy) + } + + originalValue := reflect.ValueOf(original) + copyValue := reflect.ValueOf(copy) + + fullFieldName := "" + if len(fieldName) > 0 { + fullFieldName = fieldName[0] + for _, name := range fieldName[1:] { + fullFieldName += "." + name + } + } + + if originalValue.Kind() != copyValue.Kind() { + log(fullFieldName, original, copy) + t.Errorf("Kind of original and copy are different: %s != %s", originalValue.Kind(), copyValue.Kind()) + } + + switch originalValue.Kind() { + case reflect.Ptr: + if originalValue.Pointer() == copyValue.Pointer() { + log(fullFieldName, original, copy) + t.Errorf("Pointer of original and copy are the same: %x == %x", originalValue.Pointer(), copyValue.Pointer()) + } + isCopy(t, originalValue.Elem().Interface(), copyValue.Elem().Interface(), append(fieldName, originalValue.Type().Elem().Name())...) + case reflect.Slice: + if originalValue.IsNil() && copyValue.IsNil() { + return + } + if originalValue.IsNil() != copyValue.IsNil() { + log(fullFieldName, original, copy) + t.Errorf("IsNil of original and copy are different: %t != %t", originalValue.IsNil(), copyValue.IsNil()) + } + if originalValue.Len() != copyValue.Len() { + log(fullFieldName, original, copy) + t.Errorf("Length of original and copy are different: %d != %d", originalValue.Len(), copyValue.Len()) + } + if originalValue.Len() > 0 && originalValue.Pointer() == copyValue.Pointer() { + log(fullFieldName, original, copy) + t.Errorf("Pointer of original and copy are the same: %x == %x", originalValue.Pointer(), copyValue.Pointer()) + } + for i := 0; i < originalValue.Len(); i++ { + isCopy(t, originalValue.Index(i).Interface(), copyValue.Index(i).Interface(), append(fieldName, fmt.Sprintf("[%d]", i))...) + } + case reflect.Struct: + for i := 0; i < originalValue.Type().NumField(); i++ { + field := originalValue.Type().Field(i) + isCopy(t, originalValue.Field(i).Interface(), copyValue.Field(i).Interface(), append(fieldName, field.Name)...) + } + default: + if !reflect.DeepEqual(originalValue.Interface(), copyValue.Interface()) { + log(fullFieldName, original, copy) + t.Errorf("Values of original and copy are different: %v != %v", originalValue.Interface(), copyValue.Interface()) + } + } +} + +// createValidPointers creates valid pointer for map, slices or normal pointer if they are nil. +func createValidPointers(s any) { + val := reflect.ValueOf(s).Elem() + for i := range val.NumField() { + field := val.Field(i) + switch field.Kind() { + case reflect.Ptr: + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + case reflect.Slice: + if field.IsNil() { + field.Set(reflect.MakeSlice(field.Type(), 1, 1)) + } + case reflect.Map: + if field.IsNil() { + field.Set(reflect.MakeMap(field.Type())) + } + case reflect.Struct: + createValidPointers(field.Addr().Interface()) + } + } +} + +func TestPacketCopy(t *testing.T) { + packets := []ControlPacket{ + NewControlPacket(Connack).(*ConnackPacket), + NewControlPacket(Connect).(*ConnectPacket), + NewControlPacket(Disconnect).(*DisconnectPacket), + NewControlPacket(Pingreq).(*PingreqPacket), + NewControlPacket(Pingresp).(*PingrespPacket), + NewControlPacket(Puback).(*PubackPacket), + NewControlPacket(Pubcomp).(*PubcompPacket), + NewControlPacket(Publish).(*PublishPacket), + NewControlPacket(Pubrec).(*PubrecPacket), + NewControlPacket(Pubrel).(*PubrelPacket), + NewControlPacket(Suback).(*SubackPacket), + NewControlPacket(Subscribe).(*SubscribePacket), + NewControlPacket(Unsuback).(*UnsubackPacket), + NewControlPacket(Unsubscribe).(*UnsubscribePacket), + } + + for _, packet := range packets { + createValidPointers(packet) + copy := packet.Copy() + + isCopy(t, packet, copy) + } +} diff --git a/packets/pingreq.go b/packets/pingreq.go index cd52948e..3443a7c7 100644 --- a/packets/pingreq.go +++ b/packets/pingreq.go @@ -48,3 +48,12 @@ func (pr *PingreqPacket) Unpack(b io.Reader) error { func (pr *PingreqPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the PingreqPacket +func (pr *PingreqPacket) Copy() ControlPacket { + cp := NewControlPacket(Pingreq).(*PingreqPacket) + + *cp = *pr + + return cp +} diff --git a/packets/pingresp.go b/packets/pingresp.go index d7becdf2..1520b2be 100644 --- a/packets/pingresp.go +++ b/packets/pingresp.go @@ -48,3 +48,12 @@ func (pr *PingrespPacket) Unpack(b io.Reader) error { func (pr *PingrespPacket) Details() Details { return Details{Qos: 0, MessageID: 0} } + +// Copy creates a deep copy of the PingrespPacket +func (pr *PingrespPacket) Copy() ControlPacket { + cp := NewControlPacket(Pingresp).(*PingrespPacket) + + *cp = *pr + + return cp +} diff --git a/packets/puback.go b/packets/puback.go index f6e727ec..55377ae9 100644 --- a/packets/puback.go +++ b/packets/puback.go @@ -56,3 +56,12 @@ func (pa *PubackPacket) Unpack(b io.Reader) error { func (pa *PubackPacket) Details() Details { return Details{Qos: pa.Qos, MessageID: pa.MessageID} } + +// Copy creates a deep copy of the PubackPacket +func (pa *PubackPacket) Copy() ControlPacket { + cp := NewControlPacket(Puback).(*PubackPacket) + + *cp = *pa + + return cp +} diff --git a/packets/pubcomp.go b/packets/pubcomp.go index 84a1af5d..051b133b 100644 --- a/packets/pubcomp.go +++ b/packets/pubcomp.go @@ -56,3 +56,12 @@ func (pc *PubcompPacket) Unpack(b io.Reader) error { func (pc *PubcompPacket) Details() Details { return Details{Qos: pc.Qos, MessageID: pc.MessageID} } + +// Copy creates a deep copy of the PubcompPacket +func (pc *PubcompPacket) Copy() ControlPacket { + cp := NewControlPacket(Pubcomp).(*PubcompPacket) + + *cp = *pc + + return cp +} diff --git a/packets/publish.go b/packets/publish.go index 9fba5df8..228ad0ca 100644 --- a/packets/publish.go +++ b/packets/publish.go @@ -55,7 +55,7 @@ func (p *PublishPacket) Write(w io.Writer) error { // Unpack decodes the details of a ControlPacket after the fixed // header has been read func (p *PublishPacket) Unpack(b io.Reader) error { - var payloadLength = p.FixedHeader.RemainingLength + payloadLength := p.FixedHeader.RemainingLength var err error p.TopicName, err = decodeString(b) if err != nil { @@ -80,20 +80,22 @@ func (p *PublishPacket) Unpack(b io.Reader) error { return err } -// Copy creates a new PublishPacket with the same topic and payload -// but an empty fixed header, useful for when you want to deliver -// a message with different properties such as Qos but the same -// content -func (p *PublishPacket) Copy() *PublishPacket { - newP := NewControlPacket(Publish).(*PublishPacket) - newP.TopicName = p.TopicName - newP.Payload = p.Payload - - return newP -} - // Details returns a Details struct containing the Qos and // MessageID of this ControlPacket func (p *PublishPacket) Details() Details { return Details{Qos: p.Qos, MessageID: p.MessageID} } + +// Copy creates a deep copy of the PublishPacket +func (p *PublishPacket) Copy() ControlPacket { + cp := NewControlPacket(Publish).(*PublishPacket) + + *cp = *p + + if len(p.Payload) > 0 { + cp.Payload = make([]byte, len(p.Payload)) + copy(cp.Payload, p.Payload) + } + + return cp +} diff --git a/packets/pubrec.go b/packets/pubrec.go index da9ed2a4..3ec42917 100644 --- a/packets/pubrec.go +++ b/packets/pubrec.go @@ -56,3 +56,12 @@ func (pr *PubrecPacket) Unpack(b io.Reader) error { func (pr *PubrecPacket) Details() Details { return Details{Qos: pr.Qos, MessageID: pr.MessageID} } + +// Copy creates a deep copy of the PubrecPacket +func (pr *PubrecPacket) Copy() ControlPacket { + cp := NewControlPacket(Pubrec).(*PubrecPacket) + + *cp = *pr + + return cp +} diff --git a/packets/pubrel.go b/packets/pubrel.go index f418ff86..3d321aa5 100644 --- a/packets/pubrel.go +++ b/packets/pubrel.go @@ -56,3 +56,12 @@ func (pr *PubrelPacket) Unpack(b io.Reader) error { func (pr *PubrelPacket) Details() Details { return Details{Qos: pr.Qos, MessageID: pr.MessageID} } + +// Copy creates a deep copy of the PubrelPacket +func (pr *PubrelPacket) Copy() ControlPacket { + cp := NewControlPacket(Pubrel).(*PubrelPacket) + + *cp = *pr + + return cp +} diff --git a/packets/suback.go b/packets/suback.go index 261cf21c..7342693c 100644 --- a/packets/suback.go +++ b/packets/suback.go @@ -71,3 +71,17 @@ func (sa *SubackPacket) Unpack(b io.Reader) error { func (sa *SubackPacket) Details() Details { return Details{Qos: 0, MessageID: sa.MessageID} } + +// Copy creates a deep copy of the SubackPacket +func (sa *SubackPacket) Copy() ControlPacket { + cp := NewControlPacket(Suback).(*SubackPacket) + + *cp = *sa + + if len(sa.ReturnCodes) > 0 { + cp.ReturnCodes = make([]byte, len(sa.ReturnCodes)) + copy(cp.ReturnCodes, sa.ReturnCodes) + } + + return cp +} diff --git a/packets/subscribe.go b/packets/subscribe.go index 313bf5a2..2ec81474 100644 --- a/packets/subscribe.go +++ b/packets/subscribe.go @@ -83,3 +83,22 @@ func (s *SubscribePacket) Unpack(b io.Reader) error { func (s *SubscribePacket) Details() Details { return Details{Qos: 1, MessageID: s.MessageID} } + +// Copy creates a deep copy of the SubscribePacket +func (s *SubscribePacket) Copy() ControlPacket { + cp := NewControlPacket(Subscribe).(*SubscribePacket) + + *cp = *s + + if len(s.Topics) > 0 { + cp.Topics = make([]string, len(s.Topics)) + copy(cp.Topics, s.Topics) + } + + if len(s.Qoss) > 0 { + cp.Qoss = make([]byte, len(s.Qoss)) + copy(cp.Qoss, s.Qoss) + } + + return cp +} diff --git a/packets/unsuback.go b/packets/unsuback.go index acdd400a..a151ef31 100644 --- a/packets/unsuback.go +++ b/packets/unsuback.go @@ -56,3 +56,12 @@ func (ua *UnsubackPacket) Unpack(b io.Reader) error { func (ua *UnsubackPacket) Details() Details { return Details{Qos: 0, MessageID: ua.MessageID} } + +// Copy creates a deep copy of the UnsubackPacket +func (ua *UnsubackPacket) Copy() ControlPacket { + cp := NewControlPacket(Unsuback).(*UnsubackPacket) + + *cp = *ua + + return cp +} diff --git a/packets/unsubscribe.go b/packets/unsubscribe.go index 54d06aa2..525c2822 100644 --- a/packets/unsubscribe.go +++ b/packets/unsubscribe.go @@ -70,3 +70,17 @@ func (u *UnsubscribePacket) Unpack(b io.Reader) error { func (u *UnsubscribePacket) Details() Details { return Details{Qos: 1, MessageID: u.MessageID} } + +// Copy creates a deep copy of the UnsubscribePacket +func (u *UnsubscribePacket) Copy() ControlPacket { + cp := NewControlPacket(Unsubscribe).(*UnsubscribePacket) + + *cp = *u + + if len(u.Topics) > 0 { + cp.Topics = make([]string, len(u.Topics)) + copy(cp.Topics, u.Topics) + } + + return cp +}