Skip to content

Commit f4f60b4

Browse files
authored
Merge pull request #52 from ShiftLeftSecurity/preetam/support-placeholder-escaping-ExpandArgs
also support placeholder escaping in ExpandArgs
2 parents c34ed70 + a57162c commit f4f60b4

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

db/chain/placeholders.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,20 @@ func ExpandArgs(args []interface{}, querySegment string) (string, []interface{})
1616
expandedArgs := []interface{}{}
1717
newQuery := &strings.Builder{}
1818
var argPosition = 0
19-
for _, queryChar := range querySegment {
19+
skip := false
20+
for i, queryChar := range querySegment {
21+
if skip {
22+
skip = false
23+
continue
24+
}
25+
26+
if queryChar == '\\' && i < len(querySegment)-1 && querySegment[i+1] == '?' {
27+
// Escaped '?'
28+
newQuery.WriteRune('?')
29+
skip = true
30+
continue
31+
}
32+
2033
if queryChar == '?' {
2134
arg := args[argPosition]
2235
if arg == nil {

db/chain/placeholders_test.go

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,41 +62,52 @@ func Test_digitSize(t *testing.T) {
6262

6363
func TestPlaceholderEscaping(t *testing.T) {
6464
tests := []struct {
65-
q string
66-
want string
67-
args []interface{}
65+
q string
66+
wantPlaceholders string
67+
wantExpanded string
68+
args []interface{}
6869
}{
6970
{
70-
q: "? = 1",
71-
want: "$1 = 1",
72-
args: []interface{}{1},
71+
q: "? = 1",
72+
wantPlaceholders: "$1 = 1",
73+
wantExpanded: "? = 1",
74+
args: []interface{}{1},
7375
},
7476
{
75-
q: "\\? = 1",
76-
want: "? = 1",
77-
args: []interface{}{},
77+
q: "\\? = 1",
78+
wantPlaceholders: "? = 1",
79+
wantExpanded: "? = 1",
80+
args: []interface{}{},
7881
},
7982
{
80-
q: "? = ? AND \\? = 1",
81-
want: "$1 = $2 AND ? = 1",
82-
args: []interface{}{1, 1},
83+
q: "? = ? AND \\? = 1",
84+
wantPlaceholders: "$1 = $2 AND ? = 1",
85+
wantExpanded: "? = ? AND ? = 1",
86+
args: []interface{}{1, 1},
8387
},
8488
{
85-
q: `'["a", "b"]'::jsonb \?& array['a', 'b']`,
86-
want: `'["a", "b"]'::jsonb ?& array['a', 'b']`,
87-
args: []interface{}{},
89+
q: `'["a", "b"]'::jsonb \?& array['a', 'b']`,
90+
wantPlaceholders: `'["a", "b"]'::jsonb ?& array['a', 'b']`,
91+
wantExpanded: `'["a", "b"]'::jsonb ?& array['a', 'b']`,
92+
args: []interface{}{},
8893
},
8994
{
90-
q: `'["a", "b"]'::jsonb \?& array[?]`,
91-
want: `'["a", "b"]'::jsonb ?& array[$1]`,
92-
args: []interface{}{"a"},
95+
q: `'["a", "b"]'::jsonb \?& array[?]`,
96+
wantPlaceholders: `'["a", "b"]'::jsonb ?& array[$1]`,
97+
wantExpanded: `'["a", "b"]'::jsonb ?& array[?]`,
98+
args: []interface{}{"a"},
9399
},
94100
}
95101
for i, tt := range tests {
96102
t.Run(fmt.Sprint(i), func(t *testing.T) {
97103
result, _, _ := MarksToPlaceholders(tt.q, tt.args)
98-
if result != tt.want {
99-
t.Errorf("got %v, want %v", result, tt.want)
104+
if result != tt.wantPlaceholders {
105+
t.Errorf("got %v, want placeholders %v", result, tt.wantPlaceholders)
106+
}
107+
108+
result, _ = ExpandArgs(tt.args, tt.q)
109+
if result != tt.wantExpanded {
110+
t.Errorf("got %v, want expanded %v", result, tt.wantExpanded)
100111
}
101112
})
102113
}

0 commit comments

Comments
 (0)