Skip to content

Commit a18311e

Browse files
committed
pr feedback (sethf)
1 parent 94b54c2 commit a18311e

File tree

4 files changed

+84
-73
lines changed

4 files changed

+84
-73
lines changed

relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,6 @@ import (
1515
"github.com/fullstorydev/relay-core/relay/version"
1616
)
1717

18-
type Encoding int
19-
20-
const (
21-
Identity Encoding = iota
22-
Gzip
23-
)
24-
2518
func TestContentBlocking(t *testing.T) {
2619
testCases := []contentBlockerTestCase{
2720
{
@@ -141,8 +134,8 @@ func TestContentBlocking(t *testing.T) {
141134
}
142135

143136
for _, testCase := range testCases {
144-
runContentBlockerTest(t, testCase, Identity)
145-
runContentBlockerTest(t, testCase, Gzip)
137+
runContentBlockerTest(t, testCase, traffic.Identity)
138+
runContentBlockerTest(t, testCase, traffic.Gzip)
146139
}
147140
}
148141

@@ -194,12 +187,12 @@ type contentBlockerTestCase struct {
194187
expectedHeaders map[string]string
195188
}
196189

197-
func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding Encoding) {
190+
func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding traffic.Encoding) {
198191
var encodingStr string
199192
switch encoding {
200-
case Gzip:
193+
case traffic.Gzip:
201194
encodingStr = "gzip"
202-
case Identity:
195+
case traffic.Identity:
203196
encodingStr = ""
204197
}
205198

@@ -223,7 +216,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi
223216
expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease
224217

225218
test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) {
226-
b, err := traffic.EncodeData([]byte(testCase.originalBody), encodingStr)
219+
b, err := traffic.EncodeData([]byte(testCase.originalBody), encoding)
227220
if err != nil {
228221
t.Errorf("Test '%v': Error encoding data: %v", desc, err)
229222
return
@@ -239,7 +232,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi
239232
return
240233
}
241234

242-
if encoding == Gzip {
235+
if encoding == traffic.Gzip {
243236
request.Header.Set("Content-Encoding", "gzip")
244237
}
245238

@@ -309,7 +302,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi
309302
)
310303
}
311304

312-
decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encodingStr)
305+
decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encoding)
313306
if err != nil {
314307
t.Errorf("Test '%v': Error decoding data: %v", desc, err)
315308
return

relay/traffic/encoding.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,22 @@ package traffic
33
import (
44
"bytes"
55
"compress/gzip"
6+
"fmt"
67
"io"
78
"net/http"
89
"net/url"
910
"strings"
1011
)
1112

12-
func GetContentEncoding(request *http.Request) (string, error) {
13+
type Encoding int
14+
15+
const (
16+
Unsupported Encoding = iota
17+
Identity
18+
Gzip
19+
)
20+
21+
func GetContentEncoding(request *http.Request) (Encoding, error) {
1322
// NOTE: This is a workaround for a bug in post-Go 1.17. See golang.org/issue/25192.
1423
// Our algorithm differs from the logic of AllowQuerySemicolons by replacing semicolons with encoded semicolons instead
1524
// of with ampersands. This is because we want to preserve the original query string as much as possible.
@@ -19,37 +28,47 @@ func GetContentEncoding(request *http.Request) (string, error) {
1928

2029
queryParams, err := url.ParseQuery(request.URL.RawQuery)
2130
if err != nil {
22-
return "", err
31+
return Unsupported, err
2332
}
2433

2534
// request query parameter takes precedence over request header
2635
encoding := queryParams.Get("ContentEncoding")
2736
if encoding == "" {
2837
encoding = request.Header.Get("Content-Encoding")
2938
}
30-
return encoding, nil
39+
40+
switch encoding {
41+
case "gzip":
42+
return Gzip, nil
43+
case "":
44+
return Identity, nil
45+
default:
46+
return Unsupported, fmt.Errorf("unsupported encoding: %v", encoding)
47+
}
3148
}
3249

3350
// WrapReader checks if the request Content-Encoding or request query parameter indicates gzip compression.
3451
// If so, it returns a gzip.Reader that decompresses the content.
35-
func WrapReader(request *http.Request, encoding string) (io.ReadCloser, error) {
52+
func WrapReader(request *http.Request, encoding Encoding) (io.ReadCloser, error) {
3653
if request.Body == nil {
3754
return nil, nil
3855
}
3956

4057
switch encoding {
41-
case "gzip":
58+
case Gzip:
4259
// Create a new gzip.Reader to decompress the request body
4360
return gzip.NewReader(request.Body)
44-
default:
61+
case Identity:
4562
// If the content is not gzip-compressed, return the original request body
4663
return request.Body, nil
64+
default:
65+
return nil, fmt.Errorf("unsupported encoding: %v", encoding)
4766
}
4867
}
4968

50-
func EncodeData(data []byte, encoding string) ([]byte, error) {
69+
func EncodeData(data []byte, encoding Encoding) ([]byte, error) {
5170
switch encoding {
52-
case "gzip":
71+
case Gzip:
5372
var buf bytes.Buffer
5473
gz := gzip.NewWriter(&buf)
5574

@@ -65,15 +84,16 @@ func EncodeData(data []byte, encoding string) ([]byte, error) {
6584

6685
compressedData := buf.Bytes()
6786
return compressedData, nil
68-
default:
69-
// identity encoding
87+
case Identity:
7088
return data, nil
89+
default:
90+
return nil, fmt.Errorf("unsupported encoding: %v", encoding)
7191
}
7292
}
7393

74-
func DecodeData(data []byte, encoding string) ([]byte, error) {
94+
func DecodeData(data []byte, encoding Encoding) ([]byte, error) {
7595
switch encoding {
76-
case "gzip":
96+
case Gzip:
7797
reader, err := gzip.NewReader(bytes.NewReader(data))
7898
if err != nil {
7999
return nil, err
@@ -85,8 +105,9 @@ func DecodeData(data []byte, encoding string) ([]byte, error) {
85105
}
86106

87107
return decodedData, nil
88-
default:
89-
// identity encoding
108+
case Identity:
90109
return data, nil
110+
default:
111+
return nil, fmt.Errorf("unsupported encoding: %v", encoding)
91112
}
92113
}

relay/traffic/handler.go

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re
6060

6161
encoding, err := GetContentEncoding(request)
6262
if err != nil {
63-
logger.Printf("URL %v error getting request content encoding: %v", request.URL, err)
63+
logger.Printf("URL %v error in request content encoding: %v", request.URL, err)
6464
request.Body = http.NoBody
6565
return
6666
}
@@ -95,7 +95,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re
9595
}
9696

9797
// prepareRequestBody wraps the request Body with a reader that will decode the content if necessary.
98-
func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding string) error {
98+
func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding Encoding) error {
9999
if reader, err := WrapReader(clientRequest, encoding); err != nil {
100100
return err
101101
} else if reader != nil && reader != http.NoBody {
@@ -104,7 +104,7 @@ func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding
104104
return nil
105105
}
106106

107-
func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding string) bool {
107+
func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding Encoding) bool {
108108
if serviced {
109109
return false
110110
}
@@ -124,35 +124,39 @@ func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, client
124124
}
125125
}
126126

127-
func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding string) {
128-
if encoding == "" || encoding == "identity" {
127+
func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding Encoding) {
128+
switch encoding {
129+
case Unsupported:
130+
logger.Println("Error unsupported content-encoding")
129131
return
130-
}
131-
132-
servicedBody, err := io.ReadAll(clientRequest.Body)
133-
if err != nil {
134-
logger.Printf("Error reading request body: %s", err)
135-
clientRequest.Body = http.NoBody
132+
case Identity:
136133
return
137-
}
134+
case Gzip:
135+
servicedBody, err := io.ReadAll(clientRequest.Body)
136+
if err != nil {
137+
logger.Printf("Error reading request body: %s", err)
138+
clientRequest.Body = http.NoBody
139+
return
140+
}
138141

139-
if encodedData, err := EncodeData(servicedBody, encoding); err != nil {
140-
logger.Printf("Error encoding request body: %s", err)
141-
clientRequest.Body = http.NoBody
142-
return
143-
} else {
144-
servicedBody = encodedData
145-
}
142+
if encodedData, err := EncodeData(servicedBody, encoding); err != nil {
143+
logger.Printf("Error encoding request body: %s", err)
144+
clientRequest.Body = http.NoBody
145+
return
146+
} else {
147+
servicedBody = encodedData
148+
}
146149

147-
// If the length of the body has changed, we should update the
148-
// Content-Length header too.
149-
contentLength := int64(len(servicedBody))
150-
if contentLength != clientRequest.ContentLength {
151-
clientRequest.ContentLength = contentLength
152-
clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
153-
}
150+
// If the length of the body has changed, we should update the
151+
// Content-Length header too.
152+
contentLength := int64(len(servicedBody))
153+
if contentLength != clientRequest.ContentLength {
154+
clientRequest.ContentLength = contentLength
155+
clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
156+
}
154157

155-
clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody))
158+
clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody))
159+
}
156160

157161
}
158162

relay/traffic/traffic_test.go

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -151,33 +151,26 @@ func TestMaxBodySize(t *testing.T) {
151151
})
152152
}
153153

154-
type Encoding int
155-
156-
const (
157-
Identity Encoding = iota
158-
Gzip
159-
)
160-
161154
func TestRelaySupportsContentEncoding(t *testing.T) {
162155
testCases := map[string]struct {
163-
encoding Encoding
156+
encoding traffic.Encoding
164157
bodyContentStr string
165158
headers map[string]string
166159
customUrl func(relayServiceURL string) string
167160
}{
168161
"identity": {
169-
encoding: Identity,
162+
encoding: traffic.Identity,
170163
bodyContentStr: "Hello, world!",
171164
},
172165
"gzip - with header": {
173-
encoding: Gzip,
166+
encoding: traffic.Gzip,
174167
bodyContentStr: "Hello, world!",
175168
headers: map[string]string{
176169
"Content-Encoding": "gzip",
177170
},
178171
},
179172
"gzip - with query param": {
180-
encoding: Gzip,
173+
encoding: traffic.Gzip,
181174
bodyContentStr: "Hello, world!",
182175
customUrl: func(relayServiceURL string) string {
183176
return fmt.Sprintf("%v?ContentEncoding=gzip", relayServiceURL)
@@ -190,14 +183,14 @@ func TestRelaySupportsContentEncoding(t *testing.T) {
190183
// convert the body content to a reader with the proper content encoding applied
191184
var body io.Reader
192185
switch testCase.encoding {
193-
case Gzip:
194-
b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), "gzip")
186+
case traffic.Gzip:
187+
b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), traffic.Gzip)
195188
if err != nil {
196189
t.Errorf("Test %s - Error encoding data: %v", desc, err)
197190
return
198191
}
199192
body = bytes.NewReader(b)
200-
case Identity:
193+
case traffic.Identity:
201194
body = strings.NewReader(testCase.bodyContentStr)
202195
}
203196

@@ -235,16 +228,16 @@ func TestRelaySupportsContentEncoding(t *testing.T) {
235228
}
236229

237230
switch testCase.encoding {
238-
case Gzip:
239-
decodedData, err := traffic.DecodeData(lastRequest, "gzip")
231+
case traffic.Gzip:
232+
decodedData, err := traffic.DecodeData(lastRequest, traffic.Gzip)
240233
if err != nil {
241234
t.Errorf("Test %s - Error decoding data: %v", desc, err)
242235
return
243236
}
244237
if string(decodedData) != testCase.bodyContentStr {
245238
t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(decodedData))
246239
}
247-
case Identity:
240+
case traffic.Identity:
248241
if string(lastRequest) != testCase.bodyContentStr {
249242
t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(lastRequest))
250243
}

0 commit comments

Comments
 (0)