Skip to content

Commit 09262ec

Browse files
committed
part: Fix MarshalJSON and UnmarshalYAML with singletons
Neither of these properly handled the singleton cases. Fix and add tests. Signed-off-by: Jussi Maki <[email protected]>
1 parent 503524c commit 09262ec

File tree

3 files changed

+69
-10
lines changed

3 files changed

+69
-10
lines changed

part/map.go

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,23 @@ func (m Map[K, V]) Len() int {
259259
}
260260

261261
func (m Map[K, V]) MarshalJSON() ([]byte, error) {
262-
if m.tree == nil {
262+
if m.tree == nil && m.singleton == nil {
263263
return []byte("[]"), nil
264264
}
265265

266266
var b bytes.Buffer
267267
b.WriteRune('[')
268+
269+
if m.singleton != nil {
270+
bs, err := json.Marshal(*m.singleton)
271+
if err != nil {
272+
return nil, err
273+
}
274+
b.Write(bs)
275+
b.WriteRune(']')
276+
return b.Bytes(), nil
277+
}
278+
268279
iter := m.tree.Iterator()
269280
_, kv, ok := iter.Next()
270281
for ok {
@@ -293,17 +304,32 @@ func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
293304
if d, ok := t.(json.Delim); !ok || d != '[' {
294305
return fmt.Errorf("%T.UnmarshalJSON: expected '[' got %v", m, t)
295306
}
307+
if !dec.More() {
308+
return nil
309+
}
310+
311+
var kv mapKVPair[K, V]
312+
err = dec.Decode(&kv)
313+
if err != nil {
314+
return err
315+
}
316+
317+
if !dec.More() {
318+
m.singleton = &kv
319+
return nil
320+
}
321+
296322
m.ensureTree()
297323
txn := m.tree.Txn()
324+
txn.Insert(m.keyToBytes(kv.Key), kv)
298325
for dec.More() {
299326
var kv mapKVPair[K, V]
300327
err := dec.Decode(&kv)
301328
if err != nil {
302329
return err
303330
}
304-
txn.Insert(m.keyToBytes(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value})
331+
txn.Insert(m.keyToBytes(kv.Key), kv)
305332
}
306-
307333
t, err = dec.Token()
308334
if err != nil {
309335
return err
@@ -312,12 +338,6 @@ func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
312338
return fmt.Errorf("%T.UnmarshalJSON: expected ']' got %v", m, t)
313339
}
314340
m.tree = txn.CommitOnly()
315-
316-
if m.tree.size == 1 {
317-
_, kv, _ := m.tree.Iterator().Next()
318-
m.singleton = &kv
319-
m.tree = nil
320-
}
321341
return nil
322342
}
323343

@@ -335,9 +355,18 @@ func (m *Map[K, V]) UnmarshalYAML(value *yaml.Node) error {
335355
if value.Kind != yaml.SequenceNode {
336356
return fmt.Errorf("%T.UnmarshalYAML: expected sequence", m)
337357
}
338-
if len(value.Content) == 0 {
358+
switch len(value.Content) {
359+
case 0:
360+
return nil
361+
case 1:
362+
var kv mapKVPair[K, V]
363+
if err := value.Content[0].Decode(&kv); err != nil {
364+
return err
365+
}
366+
m.singleton = &kv
339367
return nil
340368
}
369+
341370
m.ensureTree()
342371
txn := m.tree.Txn()
343372
for _, e := range value.Content {

part/map_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,18 @@ func TestSingletonMap(t *testing.T) {
215215
v, found = m2.Get("two")
216216
assert.True(t, found)
217217
assert.Equal(t, 2, v)
218+
219+
var m3 part.Map[string, int]
220+
bs, err := m.MarshalJSON()
221+
assert.NoError(t, err)
222+
assert.NoError(t, m3.UnmarshalJSON(bs))
223+
assert.True(t, m.SlowEqual(m3))
224+
225+
m3 = part.Map[string, int]{}
226+
bs, err = yaml.Marshal(m)
227+
assert.NoError(t, err)
228+
assert.NoError(t, yaml.Unmarshal(bs, &m3))
229+
assert.True(t, m.SlowEqual(m3))
218230
}
219231

220232
func TestUint64Map(t *testing.T) {

part/quick_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// Copyright Authors of Cilium
3+
14
package part_test
25

36
import (
7+
"encoding/json"
48
"fmt"
59
"maps"
610
"slices"
@@ -9,6 +13,7 @@ import (
913

1014
"github.com/cilium/statedb/part"
1115
"github.com/stretchr/testify/require"
16+
"gopkg.in/yaml.v3"
1217
)
1318

1419
var quickConfig = &quick.Config{
@@ -233,6 +238,19 @@ func TestQuick_Map(t *testing.T) {
233238
)
234239
require.False(t, partMap.SlowEqual(newPartMap), "SlowEqual")
235240
}
241+
242+
bs, err := json.Marshal(newPartMap)
243+
require.NoError(t, err, "json.Marshal")
244+
var m part.Map[uint8, int]
245+
require.NoError(t, json.Unmarshal(bs, &m), "json.Unmarshal")
246+
require.True(t, m.SlowEqual(newPartMap), "SlowEqual after json.Marshal")
247+
248+
m = part.Map[uint8, int]{}
249+
bs, err = yaml.Marshal(newPartMap)
250+
require.NoError(t, err)
251+
require.NoError(t, yaml.Unmarshal(bs, &m), "yaml.Unmarshal")
252+
require.True(t, m.SlowEqual(newPartMap), "SlowEqual after yaml.Marshal")
253+
236254
partMap = newPartMap
237255
return
238256
}

0 commit comments

Comments
 (0)