diff --git a/aws_lambda_powertools/event_handler/openapi/merge.py b/aws_lambda_powertools/event_handler/openapi/merge.py index 38b80914df3..9db9c0daa5c 100644 --- a/aws_lambda_powertools/event_handler/openapi/merge.py +++ b/aws_lambda_powertools/event_handler/openapi/merge.py @@ -67,11 +67,18 @@ def _file_has_resolver(file_path: Path, resolver_name: str) -> bool: return False for node in ast.walk(tree): + targets: list[ast.expr] = [] + value: ast.expr | None = None if isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Name) and target.id == resolver_name: - if _is_resolver_call(node.value): - return True + targets = node.targets + value = node.value + elif isinstance(node, ast.AnnAssign): + targets = [node.target] + value = node.value + for target in targets: + if isinstance(target, ast.Name) and target.id == resolver_name: + if value is not None and _is_resolver_call(value): + return True return False diff --git a/tests/functional/event_handler/_pydantic/merge_handlers/typed_handler.py b/tests/functional/event_handler/_pydantic/merge_handlers/typed_handler.py new file mode 100644 index 00000000000..53c2b4e0a12 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/merge_handlers/typed_handler.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + +app: APIGatewayRestResolver = APIGatewayRestResolver(enable_validation=True) + + +class Product(BaseModel): + id: int + name: str + price: float + + +@app.get("/products") +def get_products() -> list[Product]: + return [ + Product(id=1, name="Widget", price=9.99), + ] + + +def handler(event, context): + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_merge.py b/tests/functional/event_handler/_pydantic/test_openapi_merge.py index b4dc1d70232..12d41566e32 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_merge.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_merge.py @@ -367,3 +367,19 @@ def test_openapi_merge_schema_is_cached(): # AND paths should not be duplicated assert len([p for p in schema1["paths"] if p == "/users"]) == 1 + + +def test_openapi_merge_discover_type_annotated_resolver(): + # GIVEN an OpenAPIMerge instance + merge = OpenAPIMerge(title="Typed API", version="1.0.0") + + # WHEN discovering a handler with a type-annotated resolver (app: Resolver = Resolver()) + merge.discover( + path=MERGE_HANDLERS_PATH, + pattern="**/typed_handler.py", + resolver_name="app", + ) + + # THEN it should find the resolver and include its routes in the schema + schema = merge.get_openapi_schema() + assert "/products" in schema["paths"]