Skip to content

Commit 3e069e7

Browse files
ccojocarCosmin Cojocar
authored and
Cosmin Cojocar
committed
Fix the errors rule whitelist to work on types methods
Signed-off-by: Cosmin Cojocar <[email protected]>
1 parent 459e2d3 commit 3e069e7

17 files changed

+285
-37
lines changed

call_list.go

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,22 @@ func (c CallList) Contains(selector, ident string) bool {
5656
return false
5757
}
5858

59-
// ContainsCallExpr resolves the call expression name and type
60-
/// or package and determines if it exists within the CallList
61-
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr {
59+
// ContainsPointer returns true if a pointer to the selector type or the type
60+
// itslef is a members of this call list.
61+
func (c CallList) ContainsPointer(selector, indent string) bool {
62+
if strings.HasPrefix(selector, "*") {
63+
if c.Contains(selector, indent) {
64+
return true
65+
}
66+
s := strings.TrimPrefix(selector, "*")
67+
return c.Contains(s, indent)
68+
}
69+
return false
70+
}
71+
72+
// ContainsPkgCallExpr resolves the call expression name and type, and then further looks
73+
// up the package path for that type. Finally, it determines if the call exists within the CallList
74+
func (c CallList) ContainsPkgCallExpr(n ast.Node, ctx *Context, stripVendor bool) *ast.CallExpr {
6275
selector, ident, err := GetCallInfo(n, ctx)
6376
if err != nil {
6477
return nil
@@ -79,12 +92,18 @@ func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context, stripVendor bool) *
7992
}
8093

8194
return n.(*ast.CallExpr)
82-
/*
83-
// Try direct resolution
84-
if c.Contains(selector, ident) {
85-
log.Printf("c.Contains == true, %s, %s.", selector, ident)
86-
return n.(*ast.CallExpr)
87-
}
88-
*/
95+
}
8996

97+
// ContainsCallExpr resolves the call experssion name and type, and then determines
98+
// if the call existis with the call list
99+
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) *ast.CallExpr {
100+
selector, ident, err := GetCallInfo(n, ctx)
101+
if err != nil {
102+
return nil
103+
}
104+
if !c.Contains(selector, ident) && !c.ContainsPointer(selector, ident) {
105+
return nil
106+
}
107+
108+
return n.(*ast.CallExpr)
90109
}

call_list_test.go

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@ var _ = Describe("Call List", func() {
4646
Expect(actual).Should(Equal(expected))
4747
})
4848

49+
It("should be possible to add pointer call", func() {
50+
Expect(calls).Should(HaveLen(0))
51+
calls.Add("*bytes.Buffer", "WriteString")
52+
actual := calls.ContainsPointer("*bytes.Buffer", "WriteString")
53+
Expect(actual).Should(BeTrue())
54+
})
55+
56+
It("should be possible to check pointer call", func() {
57+
Expect(calls).Should(HaveLen(0))
58+
calls.Add("bytes.Buffer", "WriteString")
59+
actual := calls.ContainsPointer("*bytes.Buffer", "WriteString")
60+
Expect(actual).Should(BeTrue())
61+
})
62+
4963
It("should not return a match if none are present", func() {
5064
calls.Add("ioutil", "Copy")
5165
Expect(calls.Contains("fmt", "Println")).Should(BeFalse())
@@ -56,8 +70,7 @@ var _ = Describe("Call List", func() {
5670
Expect(calls.Contains("ioutil", "Copy")).Should(BeTrue())
5771
})
5872

59-
It("should match a call expression", func() {
60-
73+
It("should match a package call expression", func() {
6174
// Create file to be scanned
6275
pkg := testutils.NewTestPackage()
6376
defer pkg.Close()
@@ -73,14 +86,39 @@ var _ = Describe("Call List", func() {
7386
v := testutils.NewMockVisitor()
7487
v.Context = ctx
7588
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
76-
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx, false) != nil {
89+
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsPkgCallExpr(n, ctx, false) != nil {
7790
matched++
7891
}
7992
return true
8093
}
8194
ast.Walk(v, ctx.Root)
8295
Expect(matched).Should(Equal(1))
83-
8496
})
8597

98+
It("should match a call expression", func() {
99+
// Create file to be scanned
100+
pkg := testutils.NewTestPackage()
101+
defer pkg.Close()
102+
pkg.AddFile("main.go", testutils.SampleCodeG104[5].Code[0])
103+
104+
ctx := pkg.CreateContext("main.go")
105+
106+
calls.Add("bytes.Buffer", "WriteString")
107+
calls.Add("strings.Builder", "WriteString")
108+
calls.Add("io.Pipe", "CloseWithError")
109+
calls.Add("fmt", "Fprintln")
110+
111+
// Stub out visitor and count number of matched call expr
112+
matched := 0
113+
v := testutils.NewMockVisitor()
114+
v.Context = ctx
115+
v.Callback = func(n ast.Node, ctx *gosec.Context) bool {
116+
if _, ok := n.(*ast.CallExpr); ok && calls.ContainsCallExpr(n, ctx) != nil {
117+
matched++
118+
}
119+
return true
120+
}
121+
ast.Walk(v, ctx.Root)
122+
Expect(matched).Should(Equal(5))
123+
})
86124
})

helpers.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,40 @@ func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
135135
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
136136
}
137137
return expr.Name, fn.Sel.Name, nil
138+
case *ast.CallExpr:
139+
switch call := expr.Fun.(type) {
140+
case *ast.Ident:
141+
if call.Name == "new" {
142+
t := ctx.Info.TypeOf(expr.Args[0])
143+
if t != nil {
144+
return t.String(), fn.Sel.Name, nil
145+
}
146+
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
147+
}
148+
if call.Obj != nil {
149+
switch decl := call.Obj.Decl.(type) {
150+
case *ast.FuncDecl:
151+
ret := decl.Type.Results
152+
if ret != nil && len(ret.List) > 0 {
153+
ret1 := ret.List[0]
154+
if ret1 != nil {
155+
t := ctx.Info.TypeOf(ret1.Type)
156+
if t != nil {
157+
return t.String(), fn.Sel.Name, nil
158+
}
159+
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
160+
}
161+
}
162+
}
163+
}
164+
165+
}
138166
}
139167
case *ast.Ident:
140168
return ctx.Pkg.Name(), fn.Name, nil
141169
}
142170
}
171+
143172
return "", "", fmt.Errorf("unable to determine call info")
144173
}
145174

helpers_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gosec_test
22

33
import (
4+
"go/ast"
45
"io/ioutil"
56
"os"
67
"path/filepath"
@@ -9,6 +10,7 @@ import (
910
. "github.com/onsi/ginkgo"
1011
. "github.com/onsi/gomega"
1112
"github.com/securego/gosec"
13+
"github.com/securego/gosec/testutils"
1214
)
1315

1416
var _ = Describe("Helpers", func() {
@@ -91,4 +93,140 @@ var _ = Describe("Helpers", func() {
9193
Expect(len(r)).Should(Equal(0))
9294
})
9395
})
96+
97+
Context("when getting call info", func() {
98+
It("should return the type and call name for selector expression", func() {
99+
pkg := testutils.NewTestPackage()
100+
defer pkg.Close()
101+
pkg.AddFile("main.go", `
102+
package main
103+
104+
import(
105+
"bytes"
106+
)
107+
108+
func main() {
109+
b := new(bytes.Buffer)
110+
_, err := b.WriteString("test")
111+
if err != nil {
112+
panic(err)
113+
}
114+
}
115+
`)
116+
ctx := pkg.CreateContext("main.go")
117+
result := map[string]string{}
118+
visitor := testutils.NewMockVisitor()
119+
visitor.Context = ctx
120+
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
121+
typeName, call, err := gosec.GetCallInfo(n, ctx)
122+
if err == nil {
123+
result[typeName] = call
124+
}
125+
return true
126+
}
127+
ast.Walk(visitor, ctx.Root)
128+
129+
Expect(result).Should(HaveKeyWithValue("*bytes.Buffer", "WriteString"))
130+
})
131+
132+
It("should return the type and call name for new selector expression", func() {
133+
pkg := testutils.NewTestPackage()
134+
defer pkg.Close()
135+
pkg.AddFile("main.go", `
136+
package main
137+
138+
import(
139+
"bytes"
140+
)
141+
142+
func main() {
143+
_, err := new(bytes.Buffer).WriteString("test")
144+
if err != nil {
145+
panic(err)
146+
}
147+
}
148+
`)
149+
ctx := pkg.CreateContext("main.go")
150+
result := map[string]string{}
151+
visitor := testutils.NewMockVisitor()
152+
visitor.Context = ctx
153+
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
154+
typeName, call, err := gosec.GetCallInfo(n, ctx)
155+
if err == nil {
156+
result[typeName] = call
157+
}
158+
return true
159+
}
160+
ast.Walk(visitor, ctx.Root)
161+
162+
Expect(result).Should(HaveKeyWithValue("bytes.Buffer", "WriteString"))
163+
})
164+
165+
It("should return the type and call name for function selector expression", func() {
166+
pkg := testutils.NewTestPackage()
167+
defer pkg.Close()
168+
pkg.AddFile("main.go", `
169+
package main
170+
171+
import(
172+
"bytes"
173+
)
174+
175+
func createBuffer() *bytes.Buffer {
176+
return new(bytes.Buffer)
177+
}
178+
179+
func main() {
180+
_, err := createBuffer().WriteString("test")
181+
if err != nil {
182+
panic(err)
183+
}
184+
}
185+
`)
186+
ctx := pkg.CreateContext("main.go")
187+
result := map[string]string{}
188+
visitor := testutils.NewMockVisitor()
189+
visitor.Context = ctx
190+
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
191+
typeName, call, err := gosec.GetCallInfo(n, ctx)
192+
if err == nil {
193+
result[typeName] = call
194+
}
195+
return true
196+
}
197+
ast.Walk(visitor, ctx.Root)
198+
199+
Expect(result).Should(HaveKeyWithValue("*bytes.Buffer", "WriteString"))
200+
})
201+
202+
It("should return the type and call name for package function", func() {
203+
pkg := testutils.NewTestPackage()
204+
defer pkg.Close()
205+
pkg.AddFile("main.go", `
206+
package main
207+
208+
import(
209+
"fmt"
210+
)
211+
212+
func main() {
213+
fmt.Println("test")
214+
}
215+
`)
216+
ctx := pkg.CreateContext("main.go")
217+
result := map[string]string{}
218+
visitor := testutils.NewMockVisitor()
219+
visitor.Context = ctx
220+
visitor.Callback = func(n ast.Node, ctx *gosec.Context) bool {
221+
typeName, call, err := gosec.GetCallInfo(n, ctx)
222+
if err == nil {
223+
result[typeName] = call
224+
}
225+
return true
226+
}
227+
ast.Walk(visitor, ctx.Root)
228+
229+
Expect(result).Should(HaveKeyWithValue("fmt", "Println"))
230+
})
231+
})
94232
})

rules/archive.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func (a *archive) ID() string {
1919

2020
// Match inspects AST nodes to determine if the filepath.Joins uses any argument derived from type zip.File
2121
func (a *archive) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
22-
if node := a.calls.ContainsCallExpr(n, c, false); node != nil {
22+
if node := a.calls.ContainsPkgCallExpr(n, c, false); node != nil {
2323
for _, arg := range node.Args {
2424
var argType types.Type
2525
if selector, ok := arg.(*ast.SelectorExpr); ok {

rules/bind.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func (r *bindsToAllNetworkInterfaces) ID() string {
3333
}
3434

3535
func (r *bindsToAllNetworkInterfaces) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
36-
callExpr := r.calls.ContainsCallExpr(n, c, false)
36+
callExpr := r.calls.ContainsPkgCallExpr(n, c, false)
3737
if callExpr == nil {
3838
return nil, nil
3939
}

rules/decompression-bomb.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ package rules
1616

1717
import (
1818
"fmt"
19-
"github.com/securego/gosec"
2019
"go/ast"
20+
21+
"github.com/securego/gosec"
2122
)
2223

2324
type decompressionBombCheck struct {
@@ -31,15 +32,12 @@ func (d *decompressionBombCheck) ID() string {
3132
}
3233

3334
func containsReaderCall(node ast.Node, ctx *gosec.Context, list gosec.CallList) bool {
34-
if list.ContainsCallExpr(node, ctx, false) != nil {
35+
if list.ContainsPkgCallExpr(node, ctx, false) != nil {
3536
return true
3637
}
3738
// Resolve type info of ident (for *archive/zip.File.Open)
3839
s, idt, _ := gosec.GetCallInfo(node, ctx)
39-
if list.Contains(s, idt) {
40-
return true
41-
}
42-
return false
40+
return list.Contains(s, idt)
4341
}
4442

4543
func (d *decompressionBombCheck) Match(node ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
@@ -70,7 +68,7 @@ func (d *decompressionBombCheck) Match(node ast.Node, ctx *gosec.Context) (*gose
7068
}
7169
}
7270
case *ast.CallExpr:
73-
if d.copyCalls.ContainsCallExpr(n, ctx, false) != nil {
71+
if d.copyCalls.ContainsPkgCallExpr(n, ctx, false) != nil {
7472
if idt, ok := n.Args[1].(*ast.Ident); ok {
7573
if _, ok := readerVarObj[idt.Obj]; ok {
7674
// Detect io.Copy(x, r)

rules/errors.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
5555
cfg := ctx.Config
5656
if enabled, err := cfg.IsGlobalEnabled(gosec.Audit); err == nil && enabled {
5757
for _, expr := range stmt.Rhs {
58-
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx, false) == nil {
58+
if callExpr, ok := expr.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(expr, ctx) == nil {
5959
pos := returnsError(callExpr, ctx)
6060
if pos < 0 || pos >= len(stmt.Lhs) {
6161
return nil, nil
@@ -67,7 +67,7 @@ func (r *noErrorCheck) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, erro
6767
}
6868
}
6969
case *ast.ExprStmt:
70-
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx, false) == nil {
70+
if callExpr, ok := stmt.X.(*ast.CallExpr); ok && r.whitelist.ContainsCallExpr(stmt.X, ctx) == nil {
7171
pos := returnsError(callExpr, ctx)
7272
if pos >= 0 {
7373
return gosec.NewIssue(ctx, n, r.ID(), r.What, r.Severity, r.Confidence), nil

0 commit comments

Comments
 (0)