Coverage for fastapi_restly / views / _base.py: 88%
328 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
1"""
2This module provides a framework for class-based views on SQLAlchemy models.
4View class:
5This class is used to create a collection of endpoints that share an
6APIRouter (created when calling `include_view()`) and dependencies
7as class attributes. It uses the same mechanics as the class based
8view decorator from fastapi-utils.
9(https://fastapi-utils.davidmontague.xyz/user-guide/class-based-views/)
11AsyncRestView:
12Provides default reading and writing functions on the database using
13SQLAlchemy models.
14"""
16import dataclasses
17import functools
18import inspect
19import types
20from enum import Enum
21from math import ceil
22from typing import (
23 Annotated,
24 Any,
25 Callable,
26 ClassVar,
27 Sequence,
28 TypeVar,
29 Union,
30 get_args,
31 get_origin,
32 get_type_hints,
33 overload,
34)
36import fastapi
37import pydantic
38from pydantic import create_model
39from sqlalchemy import inspect as sa_inspect
40from sqlalchemy.orm import DeclarativeBase, selectinload
41from starlette.datastructures import QueryParams
43from ..query import (
44 QueryModifierVersion,
45 create_query_param_schema,
46 get_query_modifier_version,
47 use_query_modifier_version,
48)
49from ..schemas import (
50 BaseSchema,
51 IDSchema,
52 auto_generate_schema_for_view,
53 create_model_with_optional_fields,
54 create_model_without_read_only_fields,
55 is_field_writeonly,
56)
59def _accepts_init_kwarg(model_cls: type, attr_name: str) -> bool:
60 """Return True if attr_name can be passed as a keyword argument to model_cls.__init__.
62 Non-dataclass models (DeclarativeBase subclasses using mapped_column) accept all
63 kwargs. Dataclass-based models may have fields with init=False, in which case
64 passing the attribute to __init__ raises TypeError.
65 """
66 if not dataclasses.is_dataclass(model_cls): 66 ↛ 67line 66 didn't jump to line 67 because the condition on line 66 was never true
67 return True
68 dc_fields = {f.name: f for f in dataclasses.fields(model_cls)}
69 return attr_name not in dc_fields or dc_fields[attr_name].init
72def _unwrap_optional_annotation(annotation: Any) -> Any:
73 origin = get_origin(annotation)
74 if origin not in (types.UnionType, Union, None):
75 return annotation
77 if origin is None:
78 return annotation
80 non_none_args = [arg for arg in get_args(annotation) if arg is not type(None)]
81 if len(non_none_args) == 1: 81 ↛ 83line 81 didn't jump to line 83 because the condition on line 81 was always true
82 return non_none_args[0]
83 return annotation
def _is_idschema_reference_annotation(annotation: Any) -> bool:
    """Return True if ``annotation`` denotes an ID reference.

    ``annotation`` may be wrapped in ``Optional``. It counts as an ID reference
    when it is ``IDSchema`` itself, or an ``IDSchema`` subclass whose pydantic
    generic origin is ``IDSchema`` (such as ``IDSchema[Foo]``). Other
    ``IDSchema`` subclasses and non-class annotations return False.
    """
    candidate = _unwrap_optional_annotation(annotation)
    if candidate is IDSchema:
        return True
    if not inspect.isclass(candidate):
        return False

    try:
        derives_from_idschema = issubclass(candidate, IDSchema)
    except TypeError:
        return False
    if not derives_from_idschema:
        return False

    generic_metadata = getattr(candidate, "__pydantic_generic_metadata__", {})
    return generic_metadata.get("origin") is IDSchema
101def _serialize_idschema_value(annotation: Any, value: Any) -> Any:
102 if value is None: 102 ↛ 103line 102 didn't jump to line 103 because the condition on line 102 was never true
103 return None
104 id_value = value.id if hasattr(value, "id") else value
105 if inspect.isclass(annotation) and issubclass(annotation, IDSchema): 105 ↛ 107line 105 didn't jump to line 107 because the condition on line 105 was always true
106 return annotation.model_construct(id=id_value)
107 return {"id": id_value}
def _serialize_response_value(annotation: Any, value: Any) -> Any:
    """Prepare an ORM attribute value for a response field typed ``annotation``.

    - ID-reference fields (``IDSchema`` / ``IDSchema[...]``, optionally
      ``Optional``) are reduced to their id via ``_serialize_idschema_value``.
    - ``list[...]`` of ID references is reduced item by item, when the value
      is a sequence.
    - Everything else is returned unchanged.
    """
    annotation = _unwrap_optional_annotation(annotation)

    if _is_idschema_reference_annotation(annotation):
        return _serialize_idschema_value(annotation, value)

    if get_origin(annotation) is list:
        type_args = get_args(annotation)
        # Unwrap ``X | None`` item types as well. Otherwise
        # _serialize_idschema_value would see the Union (not a class) and
        # fall back to plain dicts instead of IDSchema instances.
        item_annotation = _unwrap_optional_annotation(type_args[0]) if type_args else Any
        if _is_idschema_reference_annotation(item_annotation) and isinstance(
            value, Sequence
        ):
            return [
                _serialize_idschema_value(item_annotation, item) for item in value
            ]

    return value
def _get_nested_schema_annotation(annotation: Any) -> type[BaseSchema] | None:
    """Find the ``BaseSchema`` class referenced by ``annotation``.

    Looks through ``Optional`` and (possibly nested) ``list[...]`` wrappers.
    Returns None when no ``BaseSchema`` subclass is found.
    """
    candidate = _unwrap_optional_annotation(annotation)

    if inspect.isclass(candidate):
        try:
            is_schema = issubclass(candidate, BaseSchema)
        except TypeError:
            is_schema = False
        if is_schema:
            return candidate

    if get_origin(candidate) is not list:
        return None
    type_args = get_args(candidate)
    return _get_nested_schema_annotation(type_args[0]) if type_args else None
def _build_relationship_loader_options(
    model_cls: type[DeclarativeBase],
    schema_cls: type[BaseSchema],
    seen: set[tuple[type[DeclarativeBase], type[BaseSchema]]] | None = None,
) -> list[Any]:
    """Build ``selectinload`` options for the relationships a schema exposes.

    Every field of ``schema_cls`` whose name is a relationship on
    ``model_cls`` gets a ``selectinload``. If the field's annotation refers to
    a nested ``BaseSchema``, the function recurses so the nested schema's
    relationships are eager-loaded as well.

    ``seen`` holds the (model, schema) pairs on the current recursion path and
    stops cycles. A new set is handed down instead of mutating the caller's,
    so sibling branches do not affect each other.
    """
    visited = set() if seen is None else seen
    visit_key = (model_cls, schema_cls)
    if visit_key in visited:
        return []
    visited = visited | {visit_key}

    relationships = sa_inspect(model_cls).relationships
    loader_options: list[Any] = []
    for field_name, field_info in schema_cls.model_fields.items():
        if field_name not in relationships:
            continue

        target_model = relationships[field_name].mapper.class_
        loader = selectinload(getattr(model_cls, field_name))
        nested_schema = _get_nested_schema_annotation(field_info.annotation)
        if nested_schema is not None:
            child_options = _build_relationship_loader_options(
                target_model, nested_schema, visited
            )
            if child_options:
                loader = loader.options(*child_options)
        loader_options.append(loader)

    return loader_options
class View:
    """
    Base class for class-based views.

    Endpoints are methods decorated with ``@route`` (or ``@get``, ``@post``, ...).
    Registering the class with ``include_view()`` builds an APIRouter for it,
    much like the ``@cbv`` decorator from fastapi-utils:
    https://fastapi-utils.davidmontague.xyz/user-guide/class-based-views/
    """

    prefix: ClassVar[str]
    # The view class name is always used as a tag; these tags are added after it.
    tags: ClassVar[list[str] | None] = None
    dependencies: ClassVar[list[Any] | None] = None
    responses: ClassVar[dict[int, Any]] = {404: {"description": "Not found"}}

    @classmethod
    def before_include_view(cls):
        """Hook run while the view is being registered, before its router is built; no-op here."""

    @classmethod
    def add_to_router(cls, parent_router: fastapi.APIRouter | fastapi.FastAPI) -> None:
        """Register this view's routes on ``parent_router`` (same as ``include_view``)."""
        _init_view_cls_and_add_to_router(cls, parent_router)
V = TypeVar("V", bound=type[View])


@overload
def include_view(
    parent_router: fastapi.APIRouter | fastapi.FastAPI, view_cls: V
) -> V: ...
@overload
def include_view(
    parent_router: fastapi.APIRouter | fastapi.FastAPI,
) -> Callable[[V], V]: ...


def include_view(
    parent_router: fastapi.APIRouter | fastapi.FastAPI, view_cls: V | None = None
) -> V | Callable[[V], V]:
    """
    Register the routes of a View class on a FastAPI app or APIRouter.
    Every View class has to go through this function.

    Usable as a class decorator::

        @include_view(app)
        class MyView(AsyncRestView):
            ...

    or called directly::

        include_view(app, MyView)

    In both forms the view class itself is returned unchanged.
    """

    def register(cls_to_register: V) -> V:
        _init_view_cls_and_add_to_router(cls_to_register, parent_router)
        return cls_to_register

    if view_cls is None:
        return register
    return register(view_cls)
def route(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Mark a View method as an endpoint.

    ``path`` and ``api_route_kwargs`` are recorded on the method and later
    forwarded to ``APIRouter.add_api_route()`` when the view is registered
    with ``include_view()``. See for example:
    https://fastapi.tiangolo.com/reference/apirouter/#fastapi.APIRouter.get
    """

    def mark_endpoint(func: Callable[..., Any]) -> Callable[..., Any]:
        # `_api_route_args` is the marker the view machinery looks for.
        setattr(func, "_api_route_args", (path, api_route_kwargs))
        return func

    return mark_endpoint
def get(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator marking a View method as a GET endpoint.

    Same as::

        @route(path, methods=["GET"], status_code=200, ... )

    ``methods`` / ``status_code`` passed by the caller override these defaults.
    """
    return route(path, **{"methods": ["GET"], "status_code": 200, **api_route_kwargs})
def post(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator marking a View method as a POST endpoint.

    Same as::

        @route(path, methods=["POST"], status_code=201, ... )

    ``methods`` / ``status_code`` passed by the caller override these defaults.
    """
    return route(path, **{"methods": ["POST"], "status_code": 201, **api_route_kwargs})
def put(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator marking a View method as a PUT endpoint.

    Same as::

        @route(path, methods=["PUT"], ... )

    No status code default is added, so FastAPI uses 200 unless
    ``status_code`` is passed.
    """
    return route(path, **{"methods": ["PUT"], **api_route_kwargs})
def patch(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator marking a View method as a PATCH endpoint.

    Same as::

        @route(path, methods=["PATCH"], ... )

    No status code default is added, so FastAPI uses 200 unless
    ``status_code`` is passed.
    """
    return route(path, **{"methods": ["PATCH"], **api_route_kwargs})
def delete(path: str, **api_route_kwargs: Any) -> Callable[..., Any]:
    """Decorator marking a View method as a DELETE endpoint.

    Same as::

        @route(path, methods=["DELETE"], status_code=204, ... )

    ``methods`` / ``status_code`` passed by the caller override these defaults.
    """
    return route(path, **{"methods": ["DELETE"], "status_code": 204, **api_route_kwargs})
class BaseRestView(View):
    """
    Base class for RestView implementations.

    Holds what AsyncRestView and RestView have in common:
    - schema and model configuration;
    - serialization of ORM objects to the response schema;
    - pagination metadata;
    - annotation of the CRUD endpoints (``index``, ``get``, ``post``,
      ``patch``, ``delete``) right before the view is registered.
    """

    schema: ClassVar[type[BaseSchema]]
    # If 'creation_schema' is not defined it will be created from 'schema'
    # using `create_model_without_read_only_fields()`.
    creation_schema: ClassVar[type[BaseSchema]]
    # If 'update_schema' is not defined it will be created from 'schema'
    # using `create_model_with_optional_fields()`.
    update_schema: ClassVar[type[BaseSchema]]
    model: ClassVar[type[DeclarativeBase]]
    # Annotation used for the `id` parameter of get/patch/delete.
    id_type: ClassVar[type[Any]] = int
    # Set True to include count/total in list responses.
    include_pagination_metadata: ClassVar[bool] = False
    # Names of route methods that must not be registered (see `_exclude_routes`).
    exclude_routes: ClassVar[tuple[str, ...]] = ()
    # Controls V1 vs V2 query parameter style; defaults to global setting.
    query_modifier_version: ClassVar[QueryModifierVersion]

    # Not a ClassVar: `_init_class_based_view` turns it into an __init__
    # dependency that is set on each view instance.
    request: fastapi.Request

    def get_query_modifier_version(self) -> QueryModifierVersion:
        """Return this view's query modifier version, or the global one if unset."""
        return getattr(self, "query_modifier_version", get_query_modifier_version())

    def get_relationship_loader_options(self) -> list[Any]:
        """Return ``selectinload`` options for the relationships exposed by ``schema``."""
        return _build_relationship_loader_options(self.model, self.schema)

    def to_response_schema(self, obj: Any) -> BaseSchema:
        """Serialize an ORM object to the configured response schema.

        Instances of ``schema`` are returned as-is. Otherwise each non
        write-only field is read from ``obj``: first by field name, then by
        the field's alias. Whichever attribute is used, the value goes through
        the same ID-reference serialization. Fields found under neither name
        are left out.
        """
        if isinstance(obj, self.schema):
            return obj

        # Build a payload using canonical field names. Alias rendering happens
        # when FastAPI serializes the response model.
        payload: dict[str, Any] = {}
        for field_name, field_info in self.schema.model_fields.items():
            if is_field_writeonly(self.schema, field_name):
                continue
            if hasattr(obj, field_name):
                source_attr = field_name
            elif field_info.alias and hasattr(obj, field_info.alias):
                source_attr = field_info.alias
            else:
                continue
            # Previously the alias branch skipped this step, so alias-sourced
            # ID-reference fields were not reduced to their id.
            payload[field_name] = _serialize_response_value(
                field_info.annotation, getattr(obj, source_attr)
            )

        # model_construct intentionally bypasses validation so response-only
        # omissions (for example WriteOnly fields) don't trigger required errors.
        return self.schema.model_construct(**payload)

    @staticmethod
    def _to_query_params(query_params: Any) -> QueryParams:
        """Normalize query parameters to a starlette ``QueryParams``.

        Accepts a ``QueryParams`` (returned as-is), a pydantic model (dumped
        by alias with ``None`` values dropped), a dict, or anything
        ``QueryParams`` itself accepts. Model and dict values are converted to
        strings.
        """
        if isinstance(query_params, QueryParams):
            return query_params
        if isinstance(query_params, pydantic.BaseModel):
            dumped = query_params.model_dump(
                exclude_none=True, by_alias=True, mode="json"
            )
            return QueryParams({k: str(v) for k, v in dumped.items()})
        if isinstance(query_params, dict):
            return QueryParams({k: str(v) for k, v in query_params.items()})
        return QueryParams(query_params)

    @classmethod
    def _create_pagination_response_schema(
        cls, response_schema: type[BaseSchema]
    ) -> type[pydantic.BaseModel]:
        """Create the ``<ViewName>PaginatedResponse`` model wrapping ``response_schema`` items."""
        return create_model(
            f"{cls.__name__}PaginatedResponse",
            items=(Sequence[response_schema], ...),
            total=(int, ...),
            page=(int | None, None),
            page_size=(int | None, None),
            total_pages=(int | None, None),
            limit=(int | None, None),
            offset=(int | None, None),
        )

    def _build_pagination_payload(
        self, query_params: Any, items: Sequence[Any], total: int
    ) -> dict[str, Any]:
        """Build the body of a paginated index response.

        Page-based metadata (``page``, ``page_size``, ``total_pages`` and the
        equivalent ``limit``/``offset``) is filled in when the view uses V2
        query modifiers, or when the request contains ``page``/``page_size``.
        Defaults are page 1 and page_size 100. Otherwise only ``limit`` and
        ``offset`` are echoed back, if present.
        """
        params = self._to_query_params(query_params)
        payload: dict[str, Any] = {
            "items": [self.to_response_schema(obj) for obj in items],
            "total": total,
            "page": None,
            "page_size": None,
            "total_pages": None,
            "limit": None,
            "offset": None,
        }
        uses_v2_pagination = (
            self.get_query_modifier_version() == QueryModifierVersion.V2
        )
        if uses_v2_pagination or "page" in params or "page_size" in params:
            page = int(params.get("page", "1"))
            page_size = int(params.get("page_size", "100"))
            payload["page"] = page
            payload["page_size"] = page_size
            # Guard against division by zero for a non-positive page size.
            payload["total_pages"] = ceil(total / page_size) if page_size > 0 else 0
            payload["limit"] = page_size
            payload["offset"] = (page - 1) * page_size
            return payload

        if "limit" in params:
            payload["limit"] = int(params["limit"])
        if "offset" in params:
            payload["offset"] = int(params["offset"])
        return payload

    @classmethod
    def before_include_view(cls):
        """
        Apply type annotations needed for FastAPI, before creating an APIRouter from
        this view and registering it.

        Missing ``schema``, ``query_modifier_version``, ``index_param_schema``,
        ``creation_schema`` and ``update_schema`` are generated first.

        This function can be overridden to further tweak the endpoints before they
        are added to FastAPI.

        Raises:
            ValueError: if neither ``schema`` nor ``model`` is set.
        """
        # Auto-generate schema if none is provided
        if not hasattr(cls, "schema"):
            if not hasattr(cls, "model"):
                raise ValueError(
                    f"'{cls.__name__}.model' must be specified to auto-generate schema"
                )
            cls.schema = auto_generate_schema_for_view(cls, cls.model)

        if not hasattr(cls, "query_modifier_version"):
            cls.query_modifier_version = get_query_modifier_version()
        if not hasattr(cls, "index_param_schema"):
            # Build the query-param schema in this view's modifier style.
            with use_query_modifier_version(cls.query_modifier_version):
                cls.index_param_schema = create_query_param_schema(cls.schema)
        if not hasattr(cls, "creation_schema"):
            cls.creation_schema = create_model_without_read_only_fields(cls.schema)
        if not hasattr(cls, "update_schema"):
            cls.update_schema = create_model_with_optional_fields(cls.schema)

        response_schema = cls.schema

        # Only annotate if the methods exist (they will be overridden in subclasses)
        index_response_annotation: Any = Sequence[response_schema]
        if cls.include_pagination_metadata:
            cls.pagination_response_schema = cls._create_pagination_response_schema(
                response_schema
            )
            index_response_annotation = cls.pagination_response_schema

        if hasattr(cls, "index"):
            _annotate(
                cls.index,
                return_annotation=index_response_annotation,
                query_params=Annotated[cls.index_param_schema, fastapi.Query()],
            )
        if hasattr(cls, "get"):
            _annotate(cls.get, return_annotation=response_schema, id=cls.id_type)
        if hasattr(cls, "post"):
            _annotate(
                cls.post,
                return_annotation=response_schema,
                schema_obj=cls.creation_schema,
            )
        if hasattr(cls, "patch"):
            _annotate(
                cls.patch,
                return_annotation=response_schema,
                schema_obj=cls.update_schema,
                id=cls.id_type,
            )
        if hasattr(cls, "delete"):
            _annotate(cls.delete, return_annotation=fastapi.Response, id=cls.id_type)
        _exclude_routes(cls)
def _exclude_routes(cls: type[View]):
    """
    Unregister the route methods listed in ``cls.exclude_routes``.

    The ``@route`` decorator attaches ``_api_route_args`` to a method so the
    route can be created later. Deleting that attribute means the method is
    no longer added as a route.

    Raises:
        AttributeError: if a listed name is not an attribute of ``cls``, or
            is an attribute that is not a route.
    """
    for method_name in cls.exclude_routes:
        try:
            view_func = getattr(cls, method_name)
        except AttributeError:
            # `from None`: the original lookup error adds nothing to this
            # message and would otherwise show up as a confusing
            # "During handling of the above exception..." chain.
            raise AttributeError(
                f"{method_name!r} is not a route on {cls.__name__}"
            ) from None
        if not hasattr(view_func, "_api_route_args"):
            raise AttributeError(f"{method_name!r} is not a route on {cls.__name__}")
        del view_func._api_route_args
def _init_view_cls_and_add_to_router(
    view_cls: type[View], parent_router: fastapi.APIRouter | fastapi.FastAPI
):
    """
    Prepare ``view_cls`` for FastAPI and mount its routes on ``parent_router``.

    FastAPI reads a lot from annotations. Request and response bodies, for
    instance, are usually declared with Pydantic classes::

        def my_endpoint(foo: FooSchema) -> FooSchema:

    Most of the work here sets the right annotations on (possibly inherited)
    endpoint methods. The steps run in this order:
    1. Copy inherited endpoints onto ``view_cls``.
    2. Rename the endpoints and bind ``self``.
    3. Run the ``before_include_view`` hook.
    4. Patch ``__init__`` for dependency injection.
    5. Build the router and include it in ``parent_router``.
    """
    _copy_all_parent_class_endpoints_into_this_subclass(view_cls)
    _init_all_endpoints(view_cls)
    view_cls.before_include_view()
    _init_class_based_view(view_cls)
    parent_router.include_router(_init_api_router(view_cls))
def _copy_all_parent_class_endpoints_into_this_subclass(view_cls: type[View]):
    """
    Give ``view_cls`` its own copy of every endpoint it inherits.

    Without a copy, FooView.get() simply resolves to AsyncRestView.get()
    through the MRO. Annotating it as returning FooSchema would then change
    AsyncRestView.get() and every other subclass as well. Endpoints that
    ``view_cls`` defines itself are left alone.
    """
    for parent_endpoint in _get_all_parent_endpoints(view_cls):
        endpoint_name = parent_endpoint.__name__
        # Look only at view_cls.__dict__ (not the MRO) to see whether the
        # subclass already has its own endpoint under this name.
        if endpoint_name not in view_cls.__dict__:
            copied_endpoint = _make_copy(parent_endpoint, view_cls)
            # Distinct __qualname__ makes the copies easier to tell apart when debugging.
            copied_endpoint.__qualname__ = (
                f"{view_cls.__name__}_{parent_endpoint.__qualname__}_wrapper"
            )
            setattr(view_cls, endpoint_name, copied_endpoint)
def _make_copy(endpoint: Callable, view_cls: type[View]) -> Callable:
    """
    Return a thin wrapper around ``endpoint`` that can be annotated independently.

    The wrapper is async if ``endpoint`` is a coroutine function. It carries
    ``endpoint``'s metadata via ``functools.wraps`` but has its own
    ``__annotations__`` dict.

    This is a separate function on purpose. A closure created directly inside
    a for-loop captures the loop *variable*, not its current value, so every
    wrapper would end up calling the last endpoint. See
    https://eev.ee/blog/2011/04/24/gotcha-python-scoping-closures/
    """
    if not inspect.iscoroutinefunction(endpoint):

        @functools.wraps(endpoint)
        def wrapper(self, *args, **kwargs):
            return endpoint(self, *args, **kwargs)

    else:

        @functools.wraps(endpoint)
        async def wrapper(self, *args, **kwargs):
            return await endpoint(self, *args, **kwargs)

    # functools.wraps shares the original __annotations__ dict; use a copy so
    # re-annotating the wrapper cannot leak into the original endpoint.
    wrapper.__annotations__ = dict(endpoint.__annotations__)
    return wrapper
def _init_all_endpoints(view_cls: type[View]):
    """
    Rename each endpoint defined on ``view_cls`` to a unique name and bind ``self``.

    FooView.post() is renamed to "fooview_post". ``self`` is then annotated
    so FastAPI injects a ``view_cls`` instance.
    """
    name_prefix = view_cls.__name__.lower()
    endpoints = (
        member
        for member in view_cls.__dict__.values()
        if hasattr(member, "_api_route_args")
    )
    for endpoint in endpoints:
        endpoint.__name__ = f"{name_prefix}_{endpoint.__name__}"
        _annotate_self(view_cls, endpoint)
601def _annotate(func: Callable, return_annotation: Any = None, **param_annotations):
602 """
603 Annotate a function by setting func.__signature__ explicitly.
604 """
605 sig = inspect.signature(func)
606 new_params = []
607 for param in sig.parameters.values():
608 if param.name in param_annotations:
609 annotation = param_annotations[param.name]
610 new_param = param.replace(annotation=annotation)
611 new_params.append(new_param)
612 else:
613 new_params.append(param)
614 func.__signature__ = sig.replace( # type: ignore[attr-defined]
615 parameters=new_params, return_annotation=return_annotation
616 )
def _get_all_parent_endpoints(view_cls: type[View]) -> list[Callable]:
    """
    Return every endpoint defined directly on a base class of ``view_cls``.

    An endpoint is any class attribute carrying ``_api_route_args`` (set by
    ``@route``). Classes are visited in MRO order, so endpoints of nearer
    bases come first; ``view_cls`` itself is skipped.
    """
    # Only the attribute values are needed; iterating .items() left an unused name.
    return [
        value
        for cls in view_cls.mro()
        if cls is not view_cls
        for value in cls.__dict__.values()
        if hasattr(value, "_api_route_args")
    ]
def _init_api_router(view_cls: type[View]) -> fastapi.APIRouter:
    """
    Build an APIRouter with a route for every endpoint defined on ``view_cls``.

    - Tags: the class name, followed by ``view_cls.tags`` if any.
    - Prefix: every ``prefix`` defined along the class hierarchy, joined from
      the most basic class down to ``view_cls``.
    - ``responses`` and ``dependencies`` come from the class attributes.
    """
    tags: list[str | Enum] = [view_cls.__name__, *(view_cls.tags or [])]

    prefix = "".join(
        klass.__dict__["prefix"]
        for klass in reversed(view_cls.mro())
        if "prefix" in klass.__dict__
    )
    router = fastapi.APIRouter(
        prefix=prefix,
        tags=tags,
        responses=view_cls.responses,
        dependencies=view_cls.dependencies,
    )

    # Only attributes in view_cls.__dict__ are considered; `@route` marks
    # endpoints with `_api_route_args = (path, kwargs)`.
    endpoints = (
        member
        for member in view_cls.__dict__.values()
        if hasattr(member, "_api_route_args")
    )
    for endpoint in endpoints:
        path, route_kwargs = endpoint._api_route_args
        router.add_api_route(path, endpoint, **route_kwargs)

    return router
def _annotate_self(view_cls: type[View], endpoint: Callable) -> None:
    """
    Make FastAPI build a ``view_cls`` instance and pass it as ``self``.

    The first parameter gets the default ``Depends(view_cls)``. This is done
    by assigning ``endpoint.__signature__``, which takes precedence over any
    other signature inspection.

    Adapted (MIT license) from:
    https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py
    """
    signature = inspect.signature(endpoint)
    parameters = list(signature.parameters.values())
    injected_self = parameters[0].replace(default=fastapi.Depends(view_cls))
    # `self` now has a default. Making the remaining parameters keyword-only
    # keeps the signature valid even if they have no defaults.
    keyword_only_rest = [
        param.replace(kind=inspect.Parameter.KEYWORD_ONLY) for param in parameters[1:]
    ]
    endpoint.__signature__ = signature.replace(  # type: ignore[attr-defined]
        parameters=[injected_self, *keyword_only_rest]
    )
def _init_class_based_view(view_cls: type[View]) -> None:
    """
    Turn the non-ClassVar annotations of ``view_cls`` into constructor dependencies.

    Adapted (MIT license) from:
    https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py

    Modifies ``view_cls`` in place, at most once:
    * ``__init__`` is replaced so every annotated (non-ClassVar) attribute is
      taken from keyword arguments and set on the instance.
    * ``__signature__`` is set so FastAPI knows which arguments to pass to the
      initializer. Each dependency is a keyword-only parameter; its default is
      the class attribute value, if any.
    """
    # NOTE(review): getattr also sees a flag inherited from an already-included
    # parent class, so such a subclass is skipped and keeps the parent's
    # __init__ and __signature__.
    if getattr(view_cls, "__class_based_view", False):
        return  # Already initialized

    original_init: Callable[..., Any] = view_cls.__init__
    original_signature = inspect.signature(original_init)
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    # Keep the original explicit parameters, minus `self` and *args/**kwargs.
    init_parameters = [
        param
        for param in list(original_signature.parameters.values())[1:]
        if param.kind not in variadic_kinds
    ]

    dependency_names: list[str] = []
    for attr_name, hint in get_type_hints(view_cls, include_extras=True).items():
        if get_origin(hint) is ClassVar:
            continue
        dependency_names.append(attr_name)
        init_parameters.append(
            inspect.Parameter(
                name=attr_name,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=getattr(view_cls, attr_name, inspect.Parameter.empty),
                annotation=hint,
            )
        )
    patched_signature = original_signature.replace(parameters=init_parameters)

    def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
        # Consume the dependency kwargs before delegating to the original __init__.
        for dep_name in dependency_names:
            setattr(self, dep_name, kwargs.pop(dep_name))
        original_init(self, *args, **kwargs)

    setattr(view_cls, "__signature__", patched_signature)
    setattr(view_cls, "__init__", new_init)
    setattr(view_cls, "__class_based_view", True)