Skip to content

Commit 7fd9446

Browse files
cschoenduve-splunkgcmurphy
authored andcommitted
update to G304 which adds binary expressions and file joining (#233)
* Added features to G304 * Linted * Added path selectors * Used better solution * removed debugging lines * fixed comments * Added test code * fixed a spacing change
1 parent e4ba96a commit 7fd9446

File tree

3 files changed

+133
-1
lines changed

3 files changed

+133
-1
lines changed

helpers.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,33 @@ func ConcatString(n *ast.BinaryExpr) (string, bool) {
281281
}
282282
return s, true
283283
}
284+
285+
// FindVarIdentities returns array of all variable identities in a given binary expression
286+
func FindVarIdentities(n *ast.BinaryExpr, c *Context) ([]*ast.Ident, bool) {
287+
identities := []*ast.Ident{}
288+
// sub expressions are found in X object, Y object is always the last term
289+
if rightOperand, ok := n.Y.(*ast.Ident); ok {
290+
obj := c.Info.ObjectOf(rightOperand)
291+
if _, ok := obj.(*types.Var); ok && !TryResolve(rightOperand, c) {
292+
identities = append(identities, rightOperand)
293+
}
294+
}
295+
if leftOperand, ok := n.X.(*ast.BinaryExpr); ok {
296+
if leftIdentities, ok := FindVarIdentities(leftOperand, c); ok {
297+
identities = append(identities, leftIdentities...)
298+
}
299+
} else {
300+
if leftOperand, ok := n.X.(*ast.Ident); ok {
301+
obj := c.Info.ObjectOf(leftOperand)
302+
if _, ok := obj.(*types.Var); ok && !TryResolve(leftOperand, c) {
303+
identities = append(identities, leftOperand)
304+
}
305+
}
306+
}
307+
308+
if len(identities) > 0 {
309+
return identities, true
310+
}
311+
// if nil or error, return false
312+
return nil, false
313+
}

rules/readfile.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,57 @@ import (
2424
type readfile struct {
2525
gosec.MetaData
2626
gosec.CallList
27+
pathJoin gosec.CallList
2728
}
2829

2930
// ID returns the identifier for this rule
3031
func (r *readfile) ID() string {
3132
return r.MetaData.ID
3233
}
3334

35+
// isJoinFunc checks if there is a filepath.Join or other join function
36+
func (r *readfile) isJoinFunc(n ast.Node, c *gosec.Context) bool {
37+
if call := r.pathJoin.ContainsCallExpr(n, c); call != nil {
38+
for _, arg := range call.Args {
39+
// edge case: check if one of the args is a BinaryExpr
40+
if binExp, ok := arg.(*ast.BinaryExpr); ok {
41+
// iterate and resolve all found identites from the BinaryExpr
42+
if _, ok := gosec.FindVarIdentities(binExp, c); ok {
43+
return true
44+
}
45+
}
46+
47+
// try and resolve identity
48+
if ident, ok := arg.(*ast.Ident); ok {
49+
obj := c.Info.ObjectOf(ident)
50+
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
51+
return true
52+
}
53+
}
54+
}
55+
}
56+
return false
57+
}
58+
3459
// Match inspects AST nodes to determine if the match the methods `os.Open` or `ioutil.ReadFile`
3560
func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
3661
if node := r.ContainsCallExpr(n, c); node != nil {
3762
for _, arg := range node.Args {
63+
// handles path joining functions in Arg
64+
// eg. os.Open(filepath.Join("/tmp/", file))
65+
if callExpr, ok := arg.(*ast.CallExpr); ok {
66+
if r.isJoinFunc(callExpr, c) {
67+
return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil
68+
}
69+
}
70+
// handles binary string concatenation eg. ioutil.Readfile("/tmp/" + file + "/blob")
71+
if binExp, ok := arg.(*ast.BinaryExpr); ok {
72+
// resolve all found identites from the BinaryExpr
73+
if _, ok := gosec.FindVarIdentities(binExp, c); ok {
74+
return gosec.NewIssue(c, n, r.ID(), r.What, r.Severity, r.Confidence), nil
75+
}
76+
}
77+
3878
if ident, ok := arg.(*ast.Ident); ok {
3979
obj := c.Info.ObjectOf(ident)
4080
if _, ok := obj.(*types.Var); ok && !gosec.TryResolve(ident, c) {
@@ -49,6 +89,7 @@ func (r *readfile) Match(n ast.Node, c *gosec.Context) (*gosec.Issue, error) {
4989
// NewReadFile detects cases where we read files
5090
func NewReadFile(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
5191
rule := &readfile{
92+
pathJoin: gosec.NewCallList(),
5293
CallList: gosec.NewCallList(),
5394
MetaData: gosec.MetaData{
5495
ID: id,
@@ -57,6 +98,8 @@ func NewReadFile(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
5798
Confidence: gosec.High,
5899
},
59100
}
101+
rule.pathJoin.Add("path/filepath", "Join")
102+
rule.pathJoin.Add("path", "Join")
60103
rule.Add("io/ioutil", "ReadFile")
61104
rule.Add("os", "Open")
62105
return rule, []ast.Node{(*ast.CallExpr)(nil)}

testutils/source.go

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ import (
500500
501501
func main() {
502502
http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) {
503-
title := r.URL.Query().Get("title")
503+
title := r.URL.Query().Get("title")
504504
f, err := os.Open(title)
505505
if err != nil {
506506
fmt.Printf("Error: %v\n", err)
@@ -512,6 +512,65 @@ func main() {
512512
fmt.Fprintf(w, "%s", body)
513513
})
514514
log.Fatal(http.ListenAndServe(":3000", nil))
515+
}`, 1}, {`
516+
package main
517+
518+
import (
519+
"log"
520+
"os"
521+
"io/ioutil"
522+
)
523+
524+
func main() {
525+
f2 := os.Getenv("tainted_file2")
526+
body, err := ioutil.ReadFile("/tmp/" + f2)
527+
if err != nil {
528+
log.Printf("Error: %v\n", err)
529+
}
530+
log.Print(body)
531+
}`, 1}, {`
532+
package main
533+
534+
import (
535+
"bufio"
536+
"fmt"
537+
"os"
538+
"path/filepath"
539+
)
540+
541+
func main() {
542+
reader := bufio.NewReader(os.Stdin)
543+
fmt.Print("Please enter file to read: ")
544+
file, _ := reader.ReadString('\n')
545+
file = file[:len(file)-1]
546+
f, err := os.Open(filepath.Join("/tmp/service/", file))
547+
if err != nil {
548+
fmt.Printf("Error: %v\n", err)
549+
}
550+
contents := make([]byte, 15)
551+
if _, err = f.Read(contents); err != nil {
552+
fmt.Printf("Error: %v\n", err)
553+
}
554+
fmt.Println(string(contents))
555+
}`, 1}, {`
556+
package main
557+
558+
import (
559+
"log"
560+
"os"
561+
"io/ioutil"
562+
"path/filepath"
563+
)
564+
565+
func main() {
566+
dir := os.Getenv("server_root")
567+
f3 := os.Getenv("tainted_file3")
568+
// edge case where both a binary expression and file Join are used.
569+
body, err := ioutil.ReadFile(filepath.Join("/var/"+dir, f3))
570+
if err != nil {
571+
log.Printf("Error: %v\n", err)
572+
}
573+
log.Print(body)
515574
}`, 1}}
516575

517576
// SampleCodeG305 - File path traversal when extracting zip archives

0 commit comments

Comments
 (0)