diff --git a/pyproject.toml b/pyproject.toml index 4ebd99df..466d834d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ classifiers = [ dependencies = [ "aiohttp", "defusedxml", # For safely parsing XML files - "fastapi[standard-no-fastapi-cloud-cli]>=0.116.0,<0.128.1", + "fastapi[standard-no-fastapi-cloud-cli]>=0.116.0", "pydantic>=2", "pydantic-settings", "python-jose", diff --git a/src/murfey/cli/generate_route_manifest.py b/src/murfey/cli/generate_route_manifest.py index e88cba7a..45d59f73 100644 --- a/src/murfey/cli/generate_route_manifest.py +++ b/src/murfey/cli/generate_route_manifest.py @@ -11,7 +11,7 @@ from argparse import ArgumentParser from pathlib import Path from types import ModuleType -from typing import Any +from typing import Annotated, Union, get_args, get_origin import yaml from fastapi import APIRouter @@ -20,6 +20,39 @@ from murfey.cli import PrettierDumper +def extract_base_type(annotation): + """ + Given a Python type annotation, return its underlying base type. + + This function unwraps `typing.Annotated` to extract the annotated type + and simplifies `Optional[T]` / `Union[T, None]` to `T`. All other union + types and complex annotations are returned unchanged. + + Parameters + ---------- + annotation: + A Python type annotation (e.g. int, Annotated[int, ...], Optional[int], + Union[int, str]) + + Returns + ------- + The unwrapped base type, or the original annotation if no unambiguous base + type can be determined. + """ + # Unwrap Annotated type annotations + if get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + + # Unwrap and return single-type Optional type annotations + if get_origin(annotation) is Union: + args = [a for a in get_args(annotation) if a is not type(None)] + if len(args) == 1: + return args[0] + + # Return complex multi-type annotations or simple unpacked ones + return annotation + + def find_routers(name: str) -> dict[str, APIRouter]: def _extract_routers_from_module(module: ModuleType): routers = {} @@ -74,7 +107,7 @@ def get_route_manifest(routers: dict[str, APIRouter]): for route in router.routes: path_params = [] for param in route.dependant.path_params: - param_type = param.type_ if param.type_ is not None else Any + param_type = extract_base_type(param._type_adapter._type) param_info = { "name": param.name if hasattr(param, "name") else "", "type": ( @@ -86,7 +119,7 @@ def get_route_manifest(routers: dict[str, APIRouter]): path_params.append(param_info) for route_dependency in route.dependant.dependencies: for param in route_dependency.path_params: - param_type = param.type_ if param.type_ is not None else Any + param_type = extract_base_type(param._type_adapter._type) param_info = { "name": param.name if hasattr(param, "name") else "", "type": (