diff --git a/datadog_lambda/asm.py b/datadog_lambda/asm.py index 31f750d82..7bd8272f8 100644 --- a/datadog_lambda/asm.py +++ b/datadog_lambda/asm.py @@ -73,17 +73,16 @@ def asm_start_request( route: Optional[str] = None if event_source.event_type == EventTypes.ALB: - headers = event.get("headers") - multi_value_request_headers = event.get("multiValueHeaders") - if multi_value_request_headers: - request_headers = _to_single_value_headers(multi_value_request_headers) - else: - request_headers = headers or {} - raw_uri = event.get("path") - parsed_query = event.get("multiValueQueryStringParameters") or event.get( - "queryStringParameters" - ) + + if event_source.subtype == EventSubtypes.ALB: + request_headers = event.get("headers", {}) + parsed_query = event.get("queryStringParameters") + if event_source.subtype == EventSubtypes.ALB_MULTI_VALUE_HEADERS: + request_headers = _to_single_value_headers( + event.get("multiValueHeaders", {}) + ) + parsed_query = event.get("multiValueQueryStringParameters") elif event_source.event_type == EventTypes.LAMBDA_FUNCTION_URL: request_headers = event.get("headers", {}) @@ -226,15 +225,27 @@ def get_asm_blocked_response( content_type = blocked.get("content-type", "application/json") content = http_utils._get_blocked_template(content_type) - response_headers = { - "content-type": content_type, - } - if "location" in blocked: - response_headers["location"] = blocked["location"] - - return { + response = { "statusCode": blocked.get("status_code", 403), - "headers": response_headers, "body": content, "isBase64Encoded": False, } + + needs_multi_value_headers = event_source.equals( + EventTypes.ALB, EventSubtypes.ALB_MULTI_VALUE_HEADERS + ) + + if needs_multi_value_headers: + response["multiValueHeaders"] = { + "content-type": [content_type], + } + if "location" in blocked: + response["multiValueHeaders"]["location"] = [blocked["location"]] + else: + response["headers"] = { + "content-type": content_type, + } + if "location" in blocked: + response["headers"]["location"] = blocked["location"] + + return response diff --git a/datadog_lambda/trigger.py b/datadog_lambda/trigger.py index 65800cf72..e0c3c4fa1 100644 --- a/datadog_lambda/trigger.py +++ b/datadog_lambda/trigger.py @@ -54,6 +54,10 @@ class EventSubtypes(_stringTypedEnum): WEBSOCKET = "websocket" HTTP_API = "http-api" + ALB = "alb" # regular alb + # ALB with the multi-value headers option checked on the target group + ALB_MULTI_VALUE_HEADERS = "alb-multi-value-headers" + class _EventSource: """ @@ -133,7 +137,12 @@ def parse_event_source(event: dict) -> _EventSource: event_source.subtype = EventSubtypes.WEBSOCKET if request_context and request_context.get("elb"): - event_source = _EventSource(EventTypes.ALB) + if "multiValueHeaders" in event: + event_source = _EventSource( + EventTypes.ALB, EventSubtypes.ALB_MULTI_VALUE_HEADERS + ) + else: + event_source = _EventSource(EventTypes.ALB, EventSubtypes.ALB) if event.get("awslogs"): event_source = _EventSource(EventTypes.CLOUDWATCH_LOGS) diff --git a/tests/event_samples/application-load-balancer-mutivalue-headers.json b/tests/event_samples/application-load-balancer-multivalue-headers.json similarity index 96% rename from tests/event_samples/application-load-balancer-mutivalue-headers.json rename to tests/event_samples/application-load-balancer-multivalue-headers.json index 6d446d15c..a35ca5023 100644 --- a/tests/event_samples/application-load-balancer-mutivalue-headers.json +++ b/tests/event_samples/application-load-balancer-multivalue-headers.json @@ -6,7 +6,7 @@ }, "httpMethod": "GET", "path": "/lambda", - "queryStringParameters": { + "multiValueQueryStringParameters": { "query": "1234ABCD" }, "multiValueHeaders": { diff --git a/tests/test_asm.py b/tests/test_asm.py index 7a5e6c560..1e11b102d 100644 --- a/tests/test_asm.py +++ b/tests/test_asm.py @@ -8,6 +8,7 @@ get_asm_blocked_response, ) from datadog_lambda.trigger import ( + EventSubtypes, EventTypes, _EventSource, extract_trigger_tags, @@ -34,7 +35,7 @@ ), ( "application_load_balancer_multivalue_headers", - "application-load-balancer-mutivalue-headers.json", + "application-load-balancer-multivalue-headers.json", "72.12.164.125", "/lambda?query=1234ABCD", "GET", @@ -111,7 +112,7 @@ ), ( "application_load_balancer_multivalue_headers", - "application-load-balancer-mutivalue-headers.json", + "application-load-balancer-multivalue-headers.json", { "statusCode": 404, "multiValueHeaders": { @@ -397,6 +398,25 @@ def test_get_asm_blocked_response_blocked( response = get_asm_blocked_response(event_source) assert response["statusCode"] == expected_status assert response["headers"] == expected_headers + assert "multiValueHeaders" not in response + + +@patch("datadog_lambda.asm.get_blocked") +def test_get_asm_blocked_response_blocked_multi_value_headers( + mock_get_blocked, +): + # HTML blocking response + mock_get_blocked.return_value = { + "status_code": 401, + "type": "html", + "content-type": "text/html", + } + + event_source = _EventSource(EventTypes.ALB, EventSubtypes.ALB_MULTI_VALUE_HEADERS) + response = get_asm_blocked_response(event_source) + assert response["statusCode"] == 401 + assert response["multiValueHeaders"] == {"content-type": ["text/html"]} + assert "headers" not in response @patch("datadog_lambda.asm.get_blocked") diff --git a/tests/test_trigger.py b/tests/test_trigger.py index 452635cfc..182e61d8e 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -5,6 +5,7 @@ from datadog_lambda.trigger import ( EventSubtypes, + EventTypes, parse_event_source, get_event_source_arn, extract_trigger_tags, @@ -117,6 +118,22 @@ def test_event_source_application_load_balancer(self): event_source = parse_event_source(event) event_source_arn = get_event_source_arn(event_source, event, ctx) self.assertEqual(event_source.to_string(), event_sample_source) + self.assertEqual(event_source.subtype, EventSubtypes.ALB) + self.assertEqual( + event_source_arn, + "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-xyz/123abc", + ) + + def test_event_source_application_load_balancer_multi_value_headers(self): + event_sample_source = "application-load-balancer-multivalue-headers" + test_file = event_samples + event_sample_source + ".json" + with open(test_file, "r") as event: + event = json.load(event) + ctx = get_mock_context() + event_source = parse_event_source(event) + event_source_arn = get_event_source_arn(event_source, event, ctx) + self.assertEqual(event_source.event_type, EventTypes.ALB) + self.assertEqual(event_source.subtype, EventSubtypes.ALB_MULTI_VALUE_HEADERS) self.assertEqual( event_source_arn, "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-xyz/123abc",