diff --git a/redis/scan.go b/redis/scan.go index 82121011..c1fb800d 100644 --- a/redis/scan.go +++ b/redis/scan.go @@ -477,7 +477,7 @@ var errScanStructValue = errors.New("redigo.ScanStruct: value must be non-nil po // ScanStruct uses exported field names to match values in the response. Use // 'redis' field tag to override the name: // -// Field int `redis:"myName"` +// Field int `redis:"myName"` // // Fields with the tag redis:"-" are ignored. // @@ -513,9 +513,9 @@ func ScanStruct(src []interface{}, dest interface{}) error { continue } - name, ok := src[i].([]byte) + name, ok := convertToBulk(src[i]) if !ok { - return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value", i) + return fmt.Errorf("redigo.ScanStruct: key %d not a bulk string value got type: %T", i, src[i]) } fs := ss.fieldSpec(name) @@ -530,6 +530,19 @@ func ScanStruct(src []interface{}, dest interface{}) error { return nil } +// convertToBulk converts src to a []byte if src is a string or bulk string +// and returns true. Otherwise nil and false is returned. +func convertToBulk(src interface{}) ([]byte, bool) { + switch v := src.(type) { + case []byte: + return v, true + case string: + return []byte(v), true + default: + return nil, false + } +} + var ( errScanSliceValue = errors.New("redigo.ScanSlice: dest must be non-nil pointer to a struct") ) diff --git a/redis/scan_test.go b/redis/scan_test.go index 53556ebb..a2a5d672 100644 --- a/redis/scan_test.go +++ b/redis/scan_test.go @@ -327,6 +327,22 @@ func TestScanStruct(t *testing.T) { } } +func TestScanStructStringKeys(t *testing.T) { + reply := []interface{}{"simple", []byte("value"), "number", []byte("123")} + expected := &struct { + Simple string `redis:"simple"` + Number int `redis:"number"` + }{ + Simple: "value", + Number: 123, + } + + value := reflect.New(reflect.ValueOf(expected).Type().Elem()).Interface() + err := redis.ScanStruct(reply, value) + require.NoError(t, err) + require.Equal(t, expected, value) +} + func TestBadScanStructArgs(t *testing.T) { x := []interface{}{"A", "b"} test := func(v interface{}) {