Skip to content

Commit 5242a2c

Browse files
committed
Extend helpers and call list
- Update call list to work directly with call expression - Add call list test cases - Extend helpers to add GetCallInfo to resolve call name and package or type if it's a var. - Add test cases to ensure correct behaviour
1 parent d29c648 commit 5242a2c

File tree

4 files changed

+139
-51
lines changed

4 files changed

+139
-51
lines changed

core/call_list.go

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,53 +13,53 @@
1313

1414
package core
1515

16-
type set map[string]bool
16+
import "go/ast"
1717

18-
type calls struct {
19-
matchAny bool
20-
functions set
21-
}
18+
type set map[string]bool
2219

2320
/// CallList is used to check for usage of specific packages
2421
/// and functions.
25-
type CallList map[string]*calls
22+
type CallList map[string]set
2623

2724
/// NewCallList creates a new empty CallList
2825
func NewCallList() CallList {
2926
return make(CallList)
3027
}
3128

3229
/// NewCallListFor createse a call list using the package path
33-
func NewCallListFor(pkg string, funcs ...string) CallList {
30+
func NewCallListFor(selector string, idents ...string) CallList {
3431
c := NewCallList()
35-
if len(funcs) == 0 {
36-
c[pkg] = &calls{true, make(set)}
37-
} else {
38-
for _, fn := range funcs {
39-
c.Add(pkg, fn)
40-
}
32+
c[selector] = make(set)
33+
for _, ident := range idents {
34+
c.Add(selector, ident)
4135
}
4236
return c
4337
}
4438

45-
/// Add a new package and function to the call list
46-
func (c CallList) Add(pkg, fn string) {
47-
if cl, ok := c[pkg]; ok {
48-
if cl.matchAny {
49-
cl.matchAny = false
50-
}
51-
} else {
52-
c[pkg] = &calls{false, make(set)}
39+
/// Add a selector and call to the call list
40+
func (c CallList) Add(selector, ident string) {
41+
if _, ok := c[selector]; !ok {
42+
c[selector] = make(set)
5343
}
54-
c[pkg].functions[fn] = true
44+
c[selector][ident] = true
5545
}
5646

5747
/// Contains returns true if the package and function are
5848
/// members of this call list.
59-
func (c CallList) Contains(pkg, fn string) bool {
60-
if funcs, ok := c[pkg]; ok {
61-
_, ok = funcs.functions[fn]
62-
return ok || funcs.matchAny
49+
func (c CallList) Contains(selector, ident string) bool {
50+
if idents, ok := c[selector]; ok {
51+
_, found := idents[ident]
52+
return found
6353
}
6454
return false
6555
}
56+
57+
/// ContainsCallExpr resolves the call expression name and type
58+
/// or package and determines if it exists within the CallList
59+
func (c CallList) ContainsCallExpr(n ast.Node, ctx *Context) bool {
60+
selector, ident, err := GetCallInfo(n, ctx)
61+
if err != nil {
62+
return false
63+
}
64+
return c.Contains(selector, ident)
65+
}

core/call_list_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package core
2+
3+
import (
4+
"go/ast"
5+
"testing"
6+
)
7+
8+
type callListRule struct {
9+
MetaData
10+
callList CallList
11+
matched int
12+
}
13+
14+
func (r *callListRule) Match(n ast.Node, c *Context) (gi *Issue, err error) {
15+
if r.callList.ContainsCallExpr(n, c) {
16+
r.matched += 1
17+
}
18+
return nil, nil
19+
}
20+
21+
func TestCallListContainsCallExpr(t *testing.T) {
22+
config := map[string]interface{}{"ignoreNosec": false}
23+
analyzer := NewAnalyzer(config, nil)
24+
rule := &callListRule{
25+
MetaData: MetaData{
26+
Severity: Low,
27+
Confidence: Low,
28+
What: "A dummy rule",
29+
},
30+
callList: NewCallListFor("bytes.Buffer", "Write", "WriteTo"),
31+
matched: 0,
32+
}
33+
analyzer.AddRule(rule, []ast.Node{(*ast.CallExpr)(nil)})
34+
source := `
35+
package main
36+
import (
37+
"bytes"
38+
"fmt"
39+
)
40+
func main() {
41+
var b bytes.Buffer
42+
b.Write([]byte("Hello "))
43+
fmt.Fprintf(&b, "world!")
44+
}`
45+
46+
analyzer.ProcessSource("dummy.go", source)
47+
if rule.matched != 1 {
48+
t.Errorf("Expected to match a bytes.Buffer.Write call")
49+
}
50+
}
51+
52+
func TestCallListContains(t *testing.T) {
53+
callList := NewCallList()
54+
callList.Add("fmt", "Printf")
55+
if !callList.Contains("fmt", "Printf") {
56+
t.Errorf("Expected call list to contain fmt.Printf")
57+
}
58+
}

core/helpers.go

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
6969
importName = alias
7070
}
7171

72-
switch node := n.(type) {
73-
case *ast.CallExpr:
74-
switch fn := node.Fun.(type) {
75-
case *ast.SelectorExpr:
76-
switch expr := fn.X.(type) {
77-
case *ast.Ident:
78-
if expr.Name == importName {
79-
for _, name := range names {
80-
if fn.Sel.Name == name {
81-
return node, true
82-
}
83-
}
72+
if callExpr, ok := n.(*ast.CallExpr); ok {
73+
packageName, callName, err := GetCallInfo(callExpr, c)
74+
if err != nil {
75+
return nil, false
76+
}
77+
if packageName == importName {
78+
for _, name := range names {
79+
if callName == name {
80+
return callExpr, true
8481
}
8582
}
8683
}
@@ -95,19 +92,15 @@ func MatchCallByPackage(n ast.Node, c *Context, pkg string, names ...string) (*a
9592
// node, matched := MatchCallByType(n, ctx, "bytes.Buffer", "WriteTo", "Write")
9693
//
9794
func MatchCallByType(n ast.Node, ctx *Context, requiredType string, calls ...string) (*ast.CallExpr, bool) {
98-
switch callExpr := n.(type) {
99-
case *ast.CallExpr:
100-
switch fn := callExpr.Fun.(type) {
101-
case *ast.SelectorExpr:
102-
switch expr := fn.X.(type) {
103-
case *ast.Ident:
104-
t := ctx.Info.TypeOf(expr)
105-
if t != nil && t.String() == requiredType {
106-
for _, call := range calls {
107-
if fn.Sel.Name == call {
108-
return callExpr, true
109-
}
110-
}
95+
if callExpr, ok := n.(*ast.CallExpr); ok {
96+
typeName, callName, err := GetCallInfo(callExpr, ctx)
97+
if err != nil {
98+
return nil, false
99+
}
100+
if typeName == requiredType {
101+
for _, call := range calls {
102+
if call == callName {
103+
return callExpr, true
111104
}
112105
}
113106
}
@@ -171,3 +164,28 @@ func GetCallObject(n ast.Node, ctx *Context) (*ast.CallExpr, types.Object) {
171164
}
172165
return nil, nil
173166
}
167+
168+
// GetCallInfo returns the package or type and name associated with a
169+
// call expression.
170+
func GetCallInfo(n ast.Node, ctx *Context) (string, string, error) {
171+
switch node := n.(type) {
172+
case *ast.CallExpr:
173+
switch fn := node.Fun.(type) {
174+
case *ast.SelectorExpr:
175+
switch expr := fn.X.(type) {
176+
case *ast.Ident:
177+
if expr.Obj != nil && expr.Obj.Kind == ast.Var {
178+
t := ctx.Info.TypeOf(expr)
179+
if t != nil {
180+
return t.String(), fn.Sel.Name, nil
181+
} else {
182+
return "undefined", fn.Sel.Name, fmt.Errorf("missing type info")
183+
}
184+
} else {
185+
return expr.Name, fn.Sel.Name, nil
186+
}
187+
}
188+
}
189+
}
190+
return "", "", fmt.Errorf("unable to determine call info")
191+
}

core/helpers_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,16 @@ func TestMatchCallByType(t *testing.T) {
5656
if rule.matched != 1 || len(rule.callExpr) != 1 {
5757
t.Errorf("Expected to match a bytes.Buffer.Write call")
5858
}
59+
60+
typeName, callName, err := GetCallInfo(rule.callExpr[0], &analyzer.context)
61+
if err != nil {
62+
t.Errorf("Unable to resolve call info: %v\n", err)
63+
}
64+
if typeName != "bytes.Buffer" {
65+
t.Errorf("Expected: %s, Got: %s\n", "bytes.Buffer", typeName)
66+
}
67+
if callName != "Write" {
68+
t.Errorf("Expected: %s, Got: %s\n", "Write", callName)
69+
}
70+
5971
}

0 commit comments

Comments
 (0)