Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions aws_lambda_powertools/event_handler/openapi/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,18 @@ def _file_has_resolver(file_path: Path, resolver_name: str) -> bool:
return False

for node in ast.walk(tree):
targets: list[ast.expr] = []
value: ast.expr | None = None
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == resolver_name:
if _is_resolver_call(node.value):
return True
targets = node.targets
value = node.value
elif isinstance(node, ast.AnnAssign):
targets = [node.target]
value = node.value
for target in targets:
if isinstance(target, ast.Name) and target.id == resolver_name:
if value is not None and _is_resolver_call(value):
return True
return False


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import APIGatewayRestResolver

app: APIGatewayRestResolver = APIGatewayRestResolver(enable_validation=True)


class Product(BaseModel):
id: int
name: str
price: float


@app.get("/products")
def get_products() -> list[Product]:
return [
Product(id=1, name="Widget", price=9.99),
]


def handler(event, context):
return app.resolve(event, context)
16 changes: 16 additions & 0 deletions tests/functional/event_handler/_pydantic/test_openapi_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,19 @@ def test_openapi_merge_schema_is_cached():

# AND paths should not be duplicated
assert len([p for p in schema1["paths"] if p == "/users"]) == 1


def test_openapi_merge_discover_type_annotated_resolver():
# GIVEN an OpenAPIMerge instance
merge = OpenAPIMerge(title="Typed API", version="1.0.0")

# WHEN discovering a handler with a type-annotated resolver (app: Resolver = Resolver())
merge.discover(
path=MERGE_HANDLERS_PATH,
pattern="**/typed_handler.py",
resolver_name="app",
)

# THEN it should find the resolver and include its routes in the schema
schema = merge.get_openapi_schema()
assert "/products" in schema["paths"]
Loading