diff --git a/named.go b/named.go index 6ac4477..59aa30f 100644 --- a/named.go +++ b/named.go @@ -131,7 +131,7 @@ type namedPreparer interface { func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { bindType := BindType(p.DriverName()) - q, args, err := compileNamedQuery([]byte(query), bindType) + q, args, err := compileNamedQuery(query, bindType) if err != nil { return nil, err } @@ -211,7 +211,7 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err // The rules for binding field names to parameter names follow the same // conventions as for StructScan, including obeying the `db` struct tags. func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { - bound, names, err := compileNamedQuery([]byte(query), bindType) + bound, names, err := compileNamedQuery(query, bindType) if err != nil { return "", []interface{}{}, err } @@ -273,7 +273,7 @@ func fixBound(bound string, loop int) string { func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { // do the initial binding with QUESTION; if bindType is not question, // we can rebind it at the end. - bound, names, err := compileNamedQuery([]byte(query), QUESTION) + bound, names, err := compileNamedQuery(query, QUESTION) if err != nil { return "", []interface{}{}, err } @@ -302,7 +302,7 @@ func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) // bindMap binds a named parameter query with a map of arguments. func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { - bound, names, err := compileNamedQuery([]byte(query), bindType) + bound, names, err := compileNamedQuery(query, bindType) if err != nil { return "", []interface{}{}, err } @@ -319,23 +319,16 @@ func bindMap(bindType int, query string, args map[string]interface{}) (string, [ // digits and numbers, where '5' is a digit but 'δΊ”' is not. var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} -// FIXME: this function isn't safe for unicode named params, as a failing test -// can testify. This is not a regression but a failure of the original code -// as well. It should be modified to range over runes in a string rather than -// bytes, even though this is less convenient and slower. Hopefully the -// addition of the prepared NamedStmt (which will only do this once) will make -// up for the slightly slower ad-hoc NamedExec/NamedQuery. - // compile a NamedQuery into an unbound query (using the '?' bindvar) and // a list of names. -func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { +func compileNamedQuery(qs string, bindType int) (query string, names []string, err error) { names = make([]string, 0, 10) - rebound := make([]byte, 0, len(qs)) + rebound := make([]rune, 0, len(qs)) inName := false last := len(qs) - 1 currentVar := 1 - name := make([]byte, 0, 10) + name := make([]rune, 0, 10) for i, b := range qs { // a ':' while we're in a name is an error @@ -350,19 +343,19 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e return query, names, err } inName = true - name = []byte{} + name = []rune{} } else if inName && i > 0 && b == '=' && len(name) == 0 { rebound = append(rebound, ':', '=') inName = false continue // if we're in a name, and this is an allowed character, continue } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { - // append the byte to the name if we are in a name and not on the last byte + // append the rune to the name if we are in a name and not on the last rune name = append(name, b) // if we're in a name and it's not an allowed character, the name is done } else if inName { inName = false - // if this is the final byte of the string and it is part of the name, then + // if this is the final rune of the string and it is part of the name, then // make sure to add it to the name if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { name = append(name, b) @@ -380,24 +373,24 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e case DOLLAR: rebound = append(rebound, '$') for _, b := range strconv.Itoa(currentVar) { - rebound = append(rebound, byte(b)) + rebound = append(rebound, b) } currentVar++ case AT: rebound = append(rebound, '@', 'p') for _, b := range strconv.Itoa(currentVar) { - rebound = append(rebound, byte(b)) + rebound = append(rebound, b) } currentVar++ } - // add this byte to string unless it was not part of the name + // add this rune to string unless it was not part of the name if i != last { rebound = append(rebound, b) } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { rebound = append(rebound, b) } } else { - // this is a normal byte and should just go onto the rebound query + // this is a normal rune and should just go onto the rebound query rebound = append(rebound, b) } } diff --git a/named_context.go b/named_context.go index 9ad23f4..7dc94d4 100644 --- a/named_context.go +++ b/named_context.go @@ -17,7 +17,7 @@ type namedPreparerContext interface { func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { bindType := BindType(p.DriverName()) - q, args, err := compileNamedQuery([]byte(query), bindType) + q, args, err := compileNamedQuery(query, bindType) if err != nil { return nil, err } diff --git a/named_test.go b/named_test.go index 0ee5b85..8250898 100644 --- a/named_test.go +++ b/named_test.go @@ -66,7 +66,7 @@ func TestCompileQuery(t *testing.T) { } for _, test := range table { - qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION) + qr, names, err := compileNamedQuery(test.Q, QUESTION) if err != nil { t.Error(err) } @@ -82,17 +82,17 @@ func TestCompileQuery(t *testing.T) { } } } - qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR) + qd, _, _ := compileNamedQuery(test.Q, DOLLAR) if qd != test.D { t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) } - qt, _, _ := compileNamedQuery([]byte(test.Q), AT) + qt, _, _ := compileNamedQuery(test.Q, AT) if qt != test.T { t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt) } - qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED) + qq, _, _ := compileNamedQuery(test.Q, NAMED) if qq != test.N { t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) } @@ -125,7 +125,7 @@ func TestEscapedColons(t *testing.T) { t.Skip("not sure it is possible to support this in general case without an SQL parser") var qs = `SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND (now() AT TIME ZONE 'utc') - interval '01:30:00') AND name = '\'this is a test\'' and id = :id` - _, _, err := compileNamedQuery([]byte(qs), DOLLAR) + _, _, err := compileNamedQuery(qs, DOLLAR) if err != nil { t.Error("Didn't handle colons correctly when inside a string") }