Skip to content

Return a copy from the memstore to avoid data races #708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions memstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (store *MemoryStore) Put(key string, message packets.ControlPacket) {
}

// Get takes a key and looks in the store for a matching Message
// returning either the Message pointer or nil.
// returning either a copy of the Message as packets.ControlPacket or nil.
func (store *MemoryStore) Get(key string) packets.ControlPacket {
store.RLock()
defer store.RUnlock()
Expand All @@ -74,13 +74,15 @@ func (store *MemoryStore) Get(key string) packets.ControlPacket {
return nil
}
mid := mIDFromKey(key)
m := store.messages[key]
if m == nil {
m, ok := store.messages[key]
if !ok {
CRITICAL.Println(STR, "memorystore get: message", mid, "not found")
} else {
DEBUG.Println(STR, "memorystore get: message", mid, "found")
return nil
}
return m

DEBUG.Println(STR, "memorystore get: message", mid, "found")

return m.Copy()
}

// All returns a slice of strings containing all the keys currently
Expand Down
9 changes: 9 additions & 0 deletions packets/connack.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,12 @@ func (ca *ConnackPacket) Unpack(b io.Reader) error {
func (ca *ConnackPacket) Details() Details {
return Details{Qos: 0, MessageID: 0}
}

// Copy creates a deep copy of the ConnackPacket
func (ca *ConnackPacket) Copy() ControlPacket {
cp := NewControlPacket(Connack).(*ConnackPacket)

*cp = *ca

return cp
}
19 changes: 19 additions & 0 deletions packets/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,22 @@ func (c *ConnectPacket) Validate() byte {
func (c *ConnectPacket) Details() Details {
return Details{Qos: 0, MessageID: 0}
}

// Copy creates a deep copy of the ConnectPacket
func (c *ConnectPacket) Copy() ControlPacket {
cp := NewControlPacket(Connect).(*ConnectPacket)

*cp = *c

if len(c.Password) > 0 {
cp.Password = make([]byte, len(c.Password))
copy(cp.Password, c.Password)
}

if len(c.WillMessage) > 0 {
cp.WillMessage = make([]byte, len(c.WillMessage))
copy(cp.WillMessage, c.WillMessage)
}

return cp
}
9 changes: 9 additions & 0 deletions packets/disconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,12 @@ func (d *DisconnectPacket) Unpack(b io.Reader) error {
func (d *DisconnectPacket) Details() Details {
return Details{Qos: 0, MessageID: 0}
}

// Copy creates a deep copy of the DisconnectPacket
func (d *DisconnectPacket) Copy() ControlPacket {
cp := NewControlPacket(Disconnect).(*DisconnectPacket)

*cp = *d

return cp
}
1 change: 1 addition & 0 deletions packets/packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type ControlPacket interface {
Unpack(io.Reader) error
String() string
Details() Details
Copy() ControlPacket
}

// PacketNames maps the constants for each of the MQTT packet types
Expand Down
119 changes: 119 additions & 0 deletions packets/packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package packets

import (
"bytes"
"fmt"
"reflect"
"testing"
)

Expand Down Expand Up @@ -251,3 +253,120 @@ func TestEncoding(t *testing.T) {
}
}
}

// isCopy checks if the original and copy are the same, recursively.
// It will fail the test if the values are different or if the pointer
// of the original and copy are the same.
func isCopy(t *testing.T, original, copy any, fieldName ...string) {
t.Helper()

log := func(field string, original, copy interface{}) {
t.Logf("Field: %s", field)
t.Logf("Original: %#v", original)
t.Logf("Copy: %#v", copy)
}

originalValue := reflect.ValueOf(original)
copyValue := reflect.ValueOf(copy)

fullFieldName := ""
if len(fieldName) > 0 {
fullFieldName = fieldName[0]
for _, name := range fieldName[1:] {
fullFieldName += "." + name
}
}

if originalValue.Kind() != copyValue.Kind() {
log(fullFieldName, original, copy)
t.Errorf("Kind of original and copy are different: %s != %s", originalValue.Kind(), copyValue.Kind())
}

switch originalValue.Kind() {
case reflect.Ptr:
if originalValue.Pointer() == copyValue.Pointer() {
log(fullFieldName, original, copy)
t.Errorf("Pointer of original and copy are the same: %x == %x", originalValue.Pointer(), copyValue.Pointer())
}
isCopy(t, originalValue.Elem().Interface(), copyValue.Elem().Interface(), append(fieldName, originalValue.Type().Elem().Name())...)
case reflect.Slice:
if originalValue.IsNil() && copyValue.IsNil() {
return
}
if originalValue.IsNil() != copyValue.IsNil() {
log(fullFieldName, original, copy)
t.Errorf("IsNil of original and copy are different: %t != %t", originalValue.IsNil(), copyValue.IsNil())
}
if originalValue.Len() != copyValue.Len() {
log(fullFieldName, original, copy)
t.Errorf("Length of original and copy are different: %d != %d", originalValue.Len(), copyValue.Len())
}
if originalValue.Len() > 0 && originalValue.Pointer() == copyValue.Pointer() {
log(fullFieldName, original, copy)
t.Errorf("Pointer of original and copy are the same: %x == %x", originalValue.Pointer(), copyValue.Pointer())
}
for i := 0; i < originalValue.Len(); i++ {
isCopy(t, originalValue.Index(i).Interface(), copyValue.Index(i).Interface(), append(fieldName, fmt.Sprintf("[%d]", i))...)
}
case reflect.Struct:
for i := 0; i < originalValue.Type().NumField(); i++ {
field := originalValue.Type().Field(i)
isCopy(t, originalValue.Field(i).Interface(), copyValue.Field(i).Interface(), append(fieldName, field.Name)...)
}
default:
if !reflect.DeepEqual(originalValue.Interface(), copyValue.Interface()) {
log(fullFieldName, original, copy)
t.Errorf("Values of original and copy are different: %v != %v", originalValue.Interface(), copyValue.Interface())
}
}
}

// createValidPointers creates valid pointer for map, slices or normal pointer if they are nil.
func createValidPointers(s any) {
val := reflect.ValueOf(s).Elem()
for i := range val.NumField() {
field := val.Field(i)
switch field.Kind() {
case reflect.Ptr:
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
case reflect.Slice:
if field.IsNil() {
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
}
case reflect.Map:
if field.IsNil() {
field.Set(reflect.MakeMap(field.Type()))
}
case reflect.Struct:
createValidPointers(field.Addr().Interface())
}
}
}

func TestPacketCopy(t *testing.T) {
packets := []ControlPacket{
NewControlPacket(Connack).(*ConnackPacket),
NewControlPacket(Connect).(*ConnectPacket),
NewControlPacket(Disconnect).(*DisconnectPacket),
NewControlPacket(Pingreq).(*PingreqPacket),
NewControlPacket(Pingresp).(*PingrespPacket),
NewControlPacket(Puback).(*PubackPacket),
NewControlPacket(Pubcomp).(*PubcompPacket),
NewControlPacket(Publish).(*PublishPacket),
NewControlPacket(Pubrec).(*PubrecPacket),
NewControlPacket(Pubrel).(*PubrelPacket),
NewControlPacket(Suback).(*SubackPacket),
NewControlPacket(Subscribe).(*SubscribePacket),
NewControlPacket(Unsuback).(*UnsubackPacket),
NewControlPacket(Unsubscribe).(*UnsubscribePacket),
}

for _, packet := range packets {
createValidPointers(packet)
copy := packet.Copy()

isCopy(t, packet, copy)
}
}
9 changes: 9 additions & 0 deletions packets/pingreq.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,12 @@ func (pr *PingreqPacket) Unpack(b io.Reader) error {
func (pr *PingreqPacket) Details() Details {
return Details{Qos: 0, MessageID: 0}
}

// Copy creates a deep copy of the PingreqPacket
func (pr *PingreqPacket) Copy() ControlPacket {
cp := NewControlPacket(Pingreq).(*PingreqPacket)

*cp = *pr

return cp
}
9 changes: 9 additions & 0 deletions packets/pingresp.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,12 @@ func (pr *PingrespPacket) Unpack(b io.Reader) error {
func (pr *PingrespPacket) Details() Details {
return Details{Qos: 0, MessageID: 0}
}

// Copy creates a deep copy of the PingrespPacket
func (pr *PingrespPacket) Copy() ControlPacket {
cp := NewControlPacket(Pingresp).(*PingrespPacket)

*cp = *pr

return cp
}
9 changes: 9 additions & 0 deletions packets/puback.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ func (pa *PubackPacket) Unpack(b io.Reader) error {
func (pa *PubackPacket) Details() Details {
return Details{Qos: pa.Qos, MessageID: pa.MessageID}
}

// Copy creates a deep copy of the PubackPacket
func (pa *PubackPacket) Copy() ControlPacket {
cp := NewControlPacket(Puback).(*PubackPacket)

*cp = *pa

return cp
}
9 changes: 9 additions & 0 deletions packets/pubcomp.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ func (pc *PubcompPacket) Unpack(b io.Reader) error {
func (pc *PubcompPacket) Details() Details {
return Details{Qos: pc.Qos, MessageID: pc.MessageID}
}

// Copy creates a deep copy of the PubcompPacket
func (pc *PubcompPacket) Copy() ControlPacket {
cp := NewControlPacket(Pubcomp).(*PubcompPacket)

*cp = *pc

return cp
}
28 changes: 15 additions & 13 deletions packets/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (p *PublishPacket) Write(w io.Writer) error {
// Unpack decodes the details of a ControlPacket after the fixed
// header has been read
func (p *PublishPacket) Unpack(b io.Reader) error {
var payloadLength = p.FixedHeader.RemainingLength
payloadLength := p.FixedHeader.RemainingLength
var err error
p.TopicName, err = decodeString(b)
if err != nil {
Expand All @@ -80,20 +80,22 @@ func (p *PublishPacket) Unpack(b io.Reader) error {
return err
}

// Copy creates a new PublishPacket with the same topic and payload
// but an empty fixed header, useful for when you want to deliver
// a message with different properties such as Qos but the same
// content
func (p *PublishPacket) Copy() *PublishPacket {
newP := NewControlPacket(Publish).(*PublishPacket)
newP.TopicName = p.TopicName
newP.Payload = p.Payload

return newP
}

// Details returns a Details struct containing the Qos and
// MessageID of this ControlPacket
func (p *PublishPacket) Details() Details {
return Details{Qos: p.Qos, MessageID: p.MessageID}
}

// Copy creates a deep copy of the PublishPacket
func (p *PublishPacket) Copy() ControlPacket {
cp := NewControlPacket(Publish).(*PublishPacket)

*cp = *p

if len(p.Payload) > 0 {
cp.Payload = make([]byte, len(p.Payload))
copy(cp.Payload, p.Payload)
}

return cp
}
9 changes: 9 additions & 0 deletions packets/pubrec.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ func (pr *PubrecPacket) Unpack(b io.Reader) error {
func (pr *PubrecPacket) Details() Details {
return Details{Qos: pr.Qos, MessageID: pr.MessageID}
}

// Copy creates a deep copy of the PubrecPacket
func (pr *PubrecPacket) Copy() ControlPacket {
cp := NewControlPacket(Pubrec).(*PubrecPacket)

*cp = *pr

return cp
}
9 changes: 9 additions & 0 deletions packets/pubrel.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ func (pr *PubrelPacket) Unpack(b io.Reader) error {
func (pr *PubrelPacket) Details() Details {
return Details{Qos: pr.Qos, MessageID: pr.MessageID}
}

// Copy creates a deep copy of the PubrelPacket
func (pr *PubrelPacket) Copy() ControlPacket {
cp := NewControlPacket(Pubrel).(*PubrelPacket)

*cp = *pr

return cp
}
14 changes: 14 additions & 0 deletions packets/suback.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,17 @@ func (sa *SubackPacket) Unpack(b io.Reader) error {
func (sa *SubackPacket) Details() Details {
return Details{Qos: 0, MessageID: sa.MessageID}
}

// Copy creates a deep copy of the SubackPacket
func (sa *SubackPacket) Copy() ControlPacket {
cp := NewControlPacket(Suback).(*SubackPacket)

*cp = *sa

if len(sa.ReturnCodes) > 0 {
cp.ReturnCodes = make([]byte, len(sa.ReturnCodes))
copy(cp.ReturnCodes, sa.ReturnCodes)
}

return cp
}
19 changes: 19 additions & 0 deletions packets/subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,22 @@ func (s *SubscribePacket) Unpack(b io.Reader) error {
func (s *SubscribePacket) Details() Details {
return Details{Qos: 1, MessageID: s.MessageID}
}

// Copy creates a deep copy of the SubscribePacket
func (s *SubscribePacket) Copy() ControlPacket {
cp := NewControlPacket(Subscribe).(*SubscribePacket)

*cp = *s

if len(s.Topics) > 0 {
cp.Topics = make([]string, len(s.Topics))
copy(cp.Topics, s.Topics)
}

if len(s.Qoss) > 0 {
cp.Qoss = make([]byte, len(s.Qoss))
copy(cp.Qoss, s.Qoss)
}

return cp
}
9 changes: 9 additions & 0 deletions packets/unsuback.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ func (ua *UnsubackPacket) Unpack(b io.Reader) error {
func (ua *UnsubackPacket) Details() Details {
return Details{Qos: 0, MessageID: ua.MessageID}
}

// Copy creates a deep copy of the UnsubackPacket
func (ua *UnsubackPacket) Copy() ControlPacket {
cp := NewControlPacket(Unsuback).(*UnsubackPacket)

*cp = *ua

return cp
}
Loading