From ee2654e8799cbfa983a27bb1fd6cdb14d2dee379 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Sat, 4 Apr 2026 19:08:55 +0200 Subject: [PATCH] feat(event_handler): add support for cookie field in OpenAPI utility --- .../middlewares/openapi_validation.py | 9 +- .../event_handler/openapi/params.py | 73 +++++ .../data_classes/api_gateway_proxy_event.py | 11 + .../utilities/data_classes/common.py | 41 +++ .../test_openapi_validation_middleware.py | 295 ++++++++++++++++++ 5 files changed, 428 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index dbfc0a6f9d7..1bfe416dac7 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -96,10 +96,17 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> headers, ) + # Process cookie values + cookie_values, cookie_errors = _request_params_to_args( + route.dependant.cookie_params, + app.current_event.resolved_cookies_field, + ) + values.update(path_values) values.update(query_values) values.update(header_values) - errors += path_errors + query_errors + header_errors + values.update(cookie_values) + errors += path_errors + query_errors + header_errors + cookie_errors # Process the request body, if it exists if route.dependant.body_params: diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 8b70b7cb074..534a20a5686 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -658,6 +658,79 @@ def alias(self, value: str | None = None): self._alias = value.lower() +class Cookie(Param): # type: ignore[misc] + """ + A class used internally to represent a cookie parameter in a path operation. + """ + + in_ = ParamTypes.cookie + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Callable[[], Any] | None = _Unset, + annotation: Any | None = None, + alias: str | None = None, + alias_priority: int | None = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: str | None = _Unset, + serialization_alias: str | None = None, + title: str | None = None, + description: str | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, + discriminator: str | None = None, + strict: bool | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + examples: list[Any] | None = None, + openapi_examples: dict[str, Example] | None = None, + deprecated: bool | None = None, + include_in_schema: bool = True, + json_schema_extra: dict[str, Any] | None = None, + **extra: Any, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + openapi_examples=openapi_examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + class Body(FieldInfo): # type: ignore[misc] """ A class used internally to represent a body parameter in a path operation. diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 540e86a5c51..6e24873e6d7 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -280,6 +280,17 @@ def raw_query_string(self) -> str: def cookies(self) -> list[str]: return self.get("cookies") or [] + @property + def resolved_cookies_field(self) -> dict[str, str]: + """ + Parse cookies from the dedicated ``cookies`` field in API Gateway HTTP API v2 format. + + The ``cookies`` field contains a list of strings like ``["session=abc", "theme=dark"]``. + """ + from aws_lambda_powertools.utilities.data_classes.common import _parse_cookie_string + + return _parse_cookie_string("; ".join(self.cookies)) + @property def request_context(self) -> RequestContextV2: return RequestContextV2(self["requestContext"]) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index a85d7b2d2a8..c35a02cc3ce 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -29,6 +29,17 @@ ) +def _parse_cookie_string(cookie_string: str) -> dict[str, str]: + """Parse a cookie string (``key=value; key2=value2``) into a dict.""" + cookies: dict[str, str] = {} + for segment in cookie_string.split(";"): + stripped = segment.strip() + if "=" in stripped: + name, _, value = stripped.partition("=") + cookies[name.strip()] = value.strip() + return cookies + + class CaseInsensitiveDict(dict): """Case insensitive dict implementation. Assumes string keys only.""" @@ -203,6 +214,36 @@ def resolved_headers_field(self) -> dict[str, str]: """ return self.headers + @property + def resolved_cookies_field(self) -> dict[str, str]: + """ + This property extracts cookies from the request as a dict of name-value pairs. + + By default, cookies are parsed from the ``Cookie`` header. + Uses ``self.headers`` (CaseInsensitiveDict) first for reliable case-insensitive + lookup, then falls back to ``resolved_headers_field`` for proxies that only + populate multi-value headers (e.g., ALB without single-value headers). + Subclasses may override this for event formats that provide cookies + in a dedicated field (e.g., API Gateway HTTP API v2). + """ + # Primary: self.headers is CaseInsensitiveDict — case-insensitive lookup + cookie_value: str | list[str] = self.headers.get("cookie") or "" + + # Fallback: resolved_headers_field covers ALB/REST v1 multi-value headers + # where the event may not have a single-value 'headers' dict at all + if not cookie_value: + headers = self.resolved_headers_field or {} + cookie_value = headers.get("cookie") or headers.get("Cookie") or "" + + # Multi-value headers (ALB, REST v1) may return a list + if isinstance(cookie_value, list): + cookie_value = "; ".join(cookie_value) + + if not cookie_value: + return {} + + return _parse_cookie_string(cookie_value) + @property def is_base64_encoded(self) -> bool | None: return self.get("isBase64Encoded") diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index 21bc9b26e0a..01935d3aba3 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -3327,3 +3327,298 @@ def handler(items: Annotated[Union[_Item, List[_Item]], Body()]) -> Dict[str, An status, body = _post_json(app, "/items", big_payload) assert status == 200 assert body["count"] == 100 + + +# ---------- Cookie parameter tests ---------- + + +def test_cookie_param_basic(gw_event): + """Test basic cookie parameter extraction from REST API v1 (Cookie header).""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler(session_id: Annotated[str, Cookie()]): + return {"session_id": session_id} + + gw_event["path"] = "/me" + gw_event["headers"]["cookie"] = "session_id=abc123; theme=dark" + # Clear multiValueHeaders to avoid interference + gw_event.pop("multiValueHeaders", None) + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["session_id"] == "abc123" + + +def test_cookie_param_missing_required(gw_event): + """Test that a missing required cookie returns 422.""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler(session_id: Annotated[str, Cookie()]): + return {"session_id": session_id} + + gw_event["path"] = "/me" + gw_event["headers"]["cookie"] = "theme=dark" + gw_event.pop("multiValueHeaders", None) + + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + +def test_cookie_param_with_default(gw_event): + """Test cookie parameter with a default value when cookie is absent.""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler(theme: Annotated[str, Cookie()] = "light"): + return {"theme": theme} + + gw_event["path"] = "/me" + gw_event["headers"].pop("cookie", None) + gw_event.pop("multiValueHeaders", None) + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["theme"] == "light" + + +def test_cookie_param_multiple_cookies(gw_event): + """Test extracting multiple cookie parameters.""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler( + session_id: Annotated[str, Cookie()], + theme: Annotated[str, Cookie()] = "light", + ): + return {"session_id": session_id, "theme": theme} + + gw_event["path"] = "/me" + gw_event["headers"]["cookie"] = "session_id=abc123; theme=dark" + gw_event.pop("multiValueHeaders", None) + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["session_id"] == "abc123" + assert body["theme"] == "dark" + + +def test_cookie_param_int_validation(gw_event): + """Test cookie parameter with int type validation.""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler(visits: Annotated[int, Cookie()]): + return {"visits": visits} + + gw_event["path"] = "/me" + gw_event["headers"]["cookie"] = "visits=42" + gw_event.pop("multiValueHeaders", None) + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["visits"] == 42 + + # Invalid int + gw_event["headers"]["cookie"] = "visits=not_a_number" + result = app(gw_event, {}) + assert result["statusCode"] == 422 + + +def test_cookie_param_http_api_v2(gw_event_http): + """Test cookie parameter with HTTP API v2 (dedicated cookies field).""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayHttpResolver(enable_validation=True) + + @app.get("/me") + def handler(session_id: Annotated[str, Cookie()]): + return {"session_id": session_id} + + gw_event_http["rawPath"] = "/me" + gw_event_http["requestContext"]["http"]["method"] = "GET" + gw_event_http["cookies"] = ["session_id=xyz789", "theme=dark"] + + result = app(gw_event_http, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["session_id"] == "xyz789" + + +def test_cookie_param_lambda_function_url(gw_event_lambda_url): + """Test cookie parameter with Lambda Function URL (v2 format).""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = LambdaFunctionUrlResolver(enable_validation=True) + + @app.get("/me") + def handler(session_id: Annotated[str, Cookie()]): + return {"session_id": session_id} + + gw_event_lambda_url["rawPath"] = "/me" + gw_event_lambda_url["requestContext"]["http"]["method"] = "GET" + gw_event_lambda_url["cookies"] = ["session_id=fn_url_abc"] + + result = app(gw_event_lambda_url, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["session_id"] == "fn_url_abc" + + +def test_cookie_param_alb(gw_event_alb): + """Test cookie parameter with ALB (Cookie header in multiValueHeaders).""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = ALBResolver(enable_validation=True) + + @app.get("/me") + def handler(session_id: Annotated[str, Cookie()]): + return {"session_id": session_id} + + gw_event_alb["path"] = "/me" + gw_event_alb["httpMethod"] = "GET" + gw_event_alb["multiValueHeaders"]["cookie"] = ["session_id=alb_abc"] + + result = app(gw_event_alb, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["session_id"] == "alb_abc" + + +def test_cookie_param_openapi_schema(): + """Test that Cookie() generates correct OpenAPI schema with in=cookie.""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler( + session_id: Annotated[str, Cookie(description="Session identifier")], + theme: Annotated[str, Cookie(description="UI theme")] = "light", + ): + return {"session_id": session_id} + + schema = app.get_openapi_schema() + schema_dict = schema.model_dump(mode="json", by_alias=True, exclude_none=True) + + path = schema_dict["paths"]["/me"]["get"] + params = path["parameters"] + + cookie_params = [p for p in params if p["in"] == "cookie"] + assert len(cookie_params) == 2 + + session_param = next(p for p in cookie_params if p["name"] == "session_id") + assert session_param["required"] is True + assert session_param["description"] == "Session identifier" + + theme_param = next(p for p in cookie_params if p["name"] == "theme") + assert theme_param.get("required") is not True + assert theme_param["description"] == "UI theme" + + +def test_cookie_param_with_query_and_header(gw_event): + """Test that Cookie(), Query(), and Header() work together.""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler( + user_id: Annotated[str, Query()], + x_request_id: Annotated[str, Header()], + session_id: Annotated[str, Cookie()], + ): + return { + "user_id": user_id, + "x_request_id": x_request_id, + "session_id": session_id, + } + + gw_event["path"] = "/me" + gw_event["queryStringParameters"] = {"user_id": "u123"} + gw_event["multiValueQueryStringParameters"] = {"user_id": ["u123"]} + gw_event["headers"]["x-request-id"] = "req-456" + gw_event["multiValueHeaders"] = {"x-request-id": ["req-456"], "cookie": ["session_id=sess-789"]} + gw_event["headers"]["cookie"] = "session_id=sess-789" + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["user_id"] == "u123" + assert body["x_request_id"] == "req-456" + assert body["session_id"] == "sess-789" + + +def test_cookie_param_no_cookies_in_request(gw_event): + """Test that empty cookies dict is handled gracefully.""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/me") + def handler(theme: Annotated[str, Cookie()] = "light"): + return {"theme": theme} + + gw_event["path"] = "/me" + gw_event["headers"] = {} + gw_event.pop("multiValueHeaders", None) + + result = app(gw_event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["theme"] == "light" + + +def test_cookie_param_vpc_lattice_v2(gw_event_vpc_lattice): + """Test cookie parameter with VPC Lattice v2 (headers are lists).""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = VPCLatticeV2Resolver(enable_validation=True) + + @app.get("/me") + def handler(session_id: Annotated[str, Cookie()]): + return {"session_id": session_id} + + gw_event_vpc_lattice["method"] = "GET" + gw_event_vpc_lattice["path"] = "/me" + gw_event_vpc_lattice["headers"]["cookie"] = ["session_id=lattice_abc"] + + result = app(gw_event_vpc_lattice, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["session_id"] == "lattice_abc" + + +def test_cookie_param_vpc_lattice_v1(gw_event_vpc_lattice_v1): + """Test cookie parameter with VPC Lattice v1 (comma-separated headers).""" + from aws_lambda_powertools.event_handler.openapi.params import Cookie + + app = VPCLatticeResolver(enable_validation=True) + + @app.get("/me") + def handler(session_id: Annotated[str, Cookie()]): + return {"session_id": session_id} + + gw_event_vpc_lattice_v1["method"] = "GET" + gw_event_vpc_lattice_v1["raw_path"] = "/me" + gw_event_vpc_lattice_v1["headers"]["cookie"] = "session_id=lattice_v1_abc" + + result = app(gw_event_vpc_lattice_v1, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["session_id"] == "lattice_v1_abc"