1+ from copy import deepcopy
12import logging
23from typing import Any , Dict , List , Optional
34
1516logger = logging .getLogger (__name__ )
1617
1718
18- def _to_single_value_headers (request_headers : Dict [str , List [str ]]) -> Dict [str , str ]:
19+ def _to_single_value_headers (headers : Dict [str , List [str ]]) -> Dict [str , str ]:
1920 """
2021 Convert multi-value headers to single-value headers.
21- If a header has multiple values, the first value is used .
22+ If a header has multiple values, join them with commas .
2223 """
2324 single_value_headers = {}
24- for key , values in request_headers .items ():
25- if len (values ) >= 1 :
26- single_value_headers [key ] = values [0 ]
25+ for key , values in headers .items ():
26+ single_value_headers [key ] = ", " .join (values )
2727 return single_value_headers
2828
2929
30+ def _merge_single_and_multi_value_headers (
31+ single_value_headers : Dict [str , str ],
32+ multi_value_headers : Dict [str , List [str ]],
33+ ):
34+ """
35+ Merge single-value headers with multi-value headers.
36+ If a header exists in both, we merge them removing duplicates
37+ """
38+ merged_headers = deepcopy (multi_value_headers )
39+ for key , value in single_value_headers .items ():
40+ if key not in merged_headers :
41+ merged_headers [key ] = [value ]
42+ elif value not in merged_headers [key ]:
43+ merged_headers [key ].append (value )
44+ return _to_single_value_headers (merged_headers )
45+
46+
3047def asm_start_request (
3148 span : Span ,
3249 event : Dict [str , Any ],
@@ -36,6 +53,7 @@ def asm_start_request(
3653 request_headers : Dict [str , str ] = {}
3754 peer_ip : Optional [str ] = None
3855 request_path_parameters : Optional [Dict [str , Any ]] = None
56+ route : Optional [str ] = None
3957
4058 if event_source .event_type == EventTypes .ALB :
4159 headers = event .get ("headers" )
@@ -59,11 +77,10 @@ def asm_start_request(
5977 elif event_source .event_type == EventTypes .API_GATEWAY :
6078 request_context = event .get ("requestContext" , {})
6179 request_path_parameters = event .get ("pathParameters" )
80+ route = trigger_tags .get ("http.route" )
6281
6382 if event_source .subtype == EventSubtypes .API_GATEWAY :
64- request_headers = _to_single_value_headers (
65- event .get ("multiValueHeaders" , {})
66- )
83+ request_headers = event .get ("headers" , {})
6784 peer_ip = request_context .get ("identity" , {}).get ("sourceIp" )
6885 raw_uri = event .get ("path" )
6986 parsed_query = event .get ("multiValueQueryStringParameters" )
@@ -105,7 +122,7 @@ def asm_start_request(
105122 body ,
106123 is_base64_encoded ,
107124 raw_uri ,
108- trigger_tags . get ( "http. route" ) ,
125+ route ,
109126 trigger_tags .get ("http.method" ),
110127 parsed_query ,
111128 request_path_parameters ,
@@ -122,9 +139,14 @@ def asm_start_response(
122139 if event_source .event_type not in _http_event_types :
123140 return
124141
125- response_headers = response .get ("headers" , {})
126- if not isinstance (response_headers , dict ):
127- response_headers = {}
142+ headers = response .get ("headers" , {})
143+ multi_value_request_headers = response .get ("multiValueHeaders" )
144+ if multi_value_request_headers :
145+ response_headers = _merge_single_and_multi_value_headers (
146+ headers , multi_value_request_headers
147+ )
148+ else :
149+ response_headers = headers
128150
129151 core .dispatch (
130152 "aws_lambda.start_response" ,
0 commit comments