Coverage for fastapi_restly / schemas / _base.py: 96%
210 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
1import types
2from datetime import datetime
3from typing import (
4 Annotated,
5 Any,
6 Generic,
7 Optional,
8 TypeVar,
9 Union,
10 get_args,
11 get_origin,
12)
14import pydantic
15from fastapi import HTTPException
16from pydantic.fields import Field, FieldInfo
17from sqlalchemy import select
18from sqlalchemy.exc import NoResultFound
19from sqlalchemy.ext.asyncio.session import AsyncSession as SA_AsyncSession
20from sqlalchemy.orm import DeclarativeBase
21from sqlalchemy.orm.session import Session as SA_Session
class BaseSchema(pydantic.BaseModel):
    # `from_attributes=True` lets this schema be validated straight from
    # attribute-bearing objects such as SQLAlchemy model instances, so aliased
    # fields keep working when FastAPI validates ORM objects in
    # request/response flows.
    model_config = pydantic.ConfigDict(
        from_attributes=True,
    )
class _Marker:
    """Named sentinel object placed in ``Annotated`` metadata to tag a field."""

    def __init__(self, name: str):
        # The name is only used for display in ``__repr__``.
        self.name = name

    def __repr__(self):
        return "fr.{}".format(self.name)


# Module-level sentinels that the ReadOnly / WriteOnly annotations carry.
readonly_marker = _Marker("ReadOnly")
writeonly_marker = _Marker("WriteOnly")
_T = TypeVar("_T")

# Generic aliases: ``ReadOnly[X]`` / ``WriteOnly[X]`` annotate a field of type X
# with the matching marker object plus a JSON-schema hint
# ("readOnly" / "writeOnly").
ReadOnly = Annotated[
    _T,
    readonly_marker,
    Field(json_schema_extra={"readOnly": True}),
]
WriteOnly = Annotated[
    _T,
    writeonly_marker,
    Field(json_schema_extra={"writeOnly": True}),
]
class TimestampsSchemaMixin(pydantic.BaseModel):
    # Both timestamps are tagged ReadOnly, so the read-only helpers in this
    # module (e.g. OmitReadOnlyMixin, get_writable_inputs) treat them as
    # non-writable.
    created_at: ReadOnly[datetime]
    updated_at: ReadOnly[datetime]
SQLAlchemyModel = TypeVar("SQLAlchemyModel", bound=DeclarativeBase)


class IDSchema(BaseSchema, Generic[SQLAlchemyModel]):
    """Generic schema useful for serializing only the id of objects.
    Can be used as IDSchema[MyModel].
    """

    # Keep this broad so relation-id payloads can target non-int primary keys.
    id: ReadOnly[Any]

    @classmethod
    def _get_sql_model_annotation(cls) -> type[DeclarativeBase] | None:
        # Return the first generic argument of this (parametrized) schema when
        # it is a class; None when the schema is not parametrized or the
        # argument is not a class (e.g. still a TypeVar).
        try:
            sql_model = cls.__pydantic_generic_metadata__["args"][0]
        except Exception:
            return None
        return sql_model if isinstance(sql_model, type) else None

    @classmethod
    def _get_sql_model_id_type(cls) -> Any:
        # Best-effort lookup of the Python type of the target model's ``id``.
        # Returns None when it cannot be determined.
        sql_model = cls._get_sql_model_annotation()
        if sql_model is None:
            return None

        # First try the ``id`` annotation declared on the model or one of its
        # bases, walking the MRO.
        for model_cls in sql_model.mro():
            annotation = getattr(model_cls, "__annotations__", {}).get("id")
            if annotation is None:
                continue
            if isinstance(annotation, str):
                # Unevaluated (string) annotation: it cannot be unwrapped with
                # get_origin/get_args and must not be handed to TypeAdapter
                # as-is, so fall through to the mapper-based lookup below.
                break
            origin = get_origin(annotation)
            if origin is not None:
                # Wrapper annotation (e.g. ``Wrapper[int]``): use the first
                # type argument.
                args = get_args(annotation)
                if args:
                    return args[0]
            return annotation

        # Fallback: ask the mapper for the first primary-key column's type.
        try:
            return sql_model.__mapper__.primary_key[0].type.python_type
        except Exception:
            return None

    @pydantic.field_validator("id", mode="before", check_fields=False)
    @classmethod
    def _coerce_id_to_model_primary_key_type(cls, value: Any) -> Any:
        # Validate/coerce the incoming id against the model's id type when
        # that type is known; otherwise pass the value through untouched.
        id_type = cls._get_sql_model_id_type()
        if id_type in (None, Any):
            return value
        return pydantic.TypeAdapter(id_type).validate_python(value)

    def get_sql_model_annotation(self) -> type[DeclarativeBase] | None:
        """
        Return the annotation on IDSchema when used as:

            foo: IDSchema[Foo]

        This method returns the class ``Foo``, or None when the schema is not
        parametrized with a class.
        """
        return self._get_sql_model_annotation()
class IDStampsSchema(TimestampsSchemaMixin, IDSchema):
    # Pure composition of the two bases: the ``id`` field from IDSchema plus
    # the ``created_at`` / ``updated_at`` fields from TimestampsSchemaMixin.
    pass
async def async_resolve_ids_to_sqlalchemy_objects(
    session: SA_AsyncSession, schema_obj: BaseSchema
) -> None:
    """
    Replace IDSchema values on ``schema_obj`` with SQLAlchemy instances, in place.

    Only fields listed in ``schema_obj.model_fields_set`` are inspected. A single
    IDSchema value is looked up with ``session.get_one``. A list containing
    IDSchema objects is loaded with one ``IN`` query and replaced by the list of
    rows that query returned. Values whose IDSchema has no resolvable model
    annotation are left untouched.

    Raises:
        HTTPException: with status 404 when an id does not match a database row.
    """
    for field in schema_obj.model_fields_set:
        value = getattr(schema_obj, field, None)

        if isinstance(value, IDSchema):
            sql_model = value.get_sql_model_annotation()
            if not sql_model:
                continue

            # Replace the IDSchema object with a SQLAlchemy instance from the database
            try:
                sql_model_obj = await session.get_one(sql_model, value.id)
            except NoResultFound as e:
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {value.id}"
                ) from e
            setattr(schema_obj, field, sql_model_obj)

        elif isinstance(value, list) and any(isinstance(i, IDSchema) for i in value):
            # Assume all IDSchemas are for the same model
            sql_model = value[0].get_sql_model_annotation()
            if not sql_model:
                continue

            # Replace all IDSchema objects with SQLAlchemy instances
            ids = [obj.id for obj in value]
            query = select(sql_model).where(sql_model.id.in_(ids))
            sql_model_objs = list(await session.scalars(query))

            # The IN query yields each matching row once, so compare against the
            # number of *distinct* requested ids; otherwise a payload repeating
            # an id would be reported as "not found" with nothing missing.
            requested_ids = set(ids)
            if len(requested_ids) != len(sql_model_objs):
                missing_ids = requested_ids.difference(o.id for o in sql_model_objs)
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {missing_ids}"
                )

            setattr(schema_obj, field, sql_model_objs)
def resolve_ids_to_sqlalchemy_objects(
    session: SA_Session, schema_obj: BaseSchema
) -> None:
    """
    Replace IDSchema values on ``schema_obj`` with SQLAlchemy instances, in place.

    Only fields listed in ``schema_obj.model_fields_set`` are inspected. A single
    IDSchema value is looked up with ``session.get_one``. A list containing
    IDSchema objects is loaded with one ``IN`` query and replaced by the list of
    rows that query returned. Values whose IDSchema has no resolvable model
    annotation are left untouched.

    Raises:
        HTTPException: with status 404 when an id does not match a database row.
    """
    for field in schema_obj.model_fields_set:
        value = getattr(schema_obj, field, None)

        if isinstance(value, IDSchema):
            sql_model = value.get_sql_model_annotation()
            if not sql_model:
                continue

            # Replace the IDSchema object with a SQLAlchemy instance from the database
            try:
                sql_model_obj = session.get_one(sql_model, value.id)
            except NoResultFound as e:
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {value.id}"
                ) from e
            setattr(schema_obj, field, sql_model_obj)

        elif isinstance(value, list) and any(isinstance(i, IDSchema) for i in value):
            # Assume all IDSchemas are for the same model
            sql_model = value[0].get_sql_model_annotation()
            if not sql_model:
                continue

            # Replace all IDSchema objects with SQLAlchemy instances
            ids = [obj.id for obj in value]
            query = select(sql_model).where(sql_model.id.in_(ids))
            sql_model_objs = list(session.scalars(query))

            # The IN query yields each matching row once, so compare against the
            # number of *distinct* requested ids; otherwise a payload repeating
            # an id would be reported as "not found" with nothing missing.
            requested_ids = set(ids)
            if len(requested_ids) != len(sql_model_objs):
                missing_ids = requested_ids.difference(o.id for o in sql_model_objs)
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {missing_ids}"
                )

            setattr(schema_obj, field, sql_model_objs)
def get_read_only_fields(model_cls: type[pydantic.BaseModel]) -> list[str]:
    """Return the names of the model's fields annotated as ReadOnly[].

    A field counts as read-only when its (non-empty) ``metadata`` contains
    ``readonly_marker``. Names come back in ``model_fields`` order.
    """
    return [
        name
        for name, info in model_cls.model_fields.items()
        if (meta := getattr(info, "metadata", None)) and readonly_marker in meta
    ]
def is_readonly_field(
    model: pydantic.BaseModel | type[pydantic.BaseModel], field_name: str
) -> bool:
    """Tell whether ``field_name`` is marked ReadOnly on the given model.

    Accepts either a model instance or a model class. An unknown field name
    yields False.
    """
    model_cls = model.__class__ if isinstance(model, pydantic.BaseModel) else model
    return _is_readonly(model_cls.model_fields.get(field_name))
def _is_readonly(field_info: FieldInfo | None) -> bool:
    # True only when the field info exists, has non-empty metadata, and that
    # metadata contains the ReadOnly marker object.
    metadata = None if field_info is None else getattr(field_info, "metadata", None)
    return bool(metadata) and readonly_marker in metadata
def _is_writeonly(field_info: FieldInfo | None) -> bool:
    # True only when the field info exists, has non-empty metadata, and that
    # metadata contains the WriteOnly marker object.
    metadata = None if field_info is None else getattr(field_info, "metadata", None)
    return bool(metadata) and writeonly_marker in metadata
def get_write_only_fields(model_cls: type[pydantic.BaseModel]) -> list[str]:
    """Return the names of the model's fields annotated as WriteOnly[].

    Names come back in ``model_fields`` order.
    """
    return [
        name for name, info in model_cls.model_fields.items() if _is_writeonly(info)
    ]
def is_field_writeonly(
    model_cls: pydantic.BaseModel | type[pydantic.BaseModel], field_name: str
) -> bool:
    """Tell whether ``field_name`` is marked WriteOnly on the given model.

    Accepts either a model instance or a model class. An unknown field name
    yields False.
    """
    target = (
        model_cls.__class__ if isinstance(model_cls, pydantic.BaseModel) else model_cls
    )
    return _is_writeonly(target.model_fields.get(field_name))
def create_model_without_read_only_fields(
    model_cls: type[pydantic.BaseModel],
) -> type[pydantic.BaseModel]:
    """
    Build a "Create"-prefixed subclass of ``model_cls`` that also inherits from
    OmitReadOnlyMixin.

    The new class keeps the module of ``model_cls``; its docstring is the
    original one with a note appended.
    """
    namespace = {
        "__module__": model_cls.__module__,
        "__doc__": (model_cls.__doc__ or "") + "\nRead-only fields have been removed.",
    }
    # OmitReadOnlyMixin comes first in the bases, ahead of the original model.
    return type("Create" + model_cls.__name__, (OmitReadOnlyMixin, model_cls), namespace)
class OmitReadOnlyMixin(pydantic.BaseModel):
    """
    Pydantic mixin whose subclasses have every ReadOnly-marked field removed.
    """

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        super().__pydantic_init_subclass__(**kwargs)

        # Snapshot the names first so the mapping is never mutated while it is
        # being iterated.
        doomed = [name for name, info in cls.model_fields.items() if _is_readonly(info)]
        for name in doomed:
            del cls.model_fields[name]

        # Rebuild the model after the field mapping was edited.
        cls.model_rebuild(force=True)
def rebase_with_model_config(
    base: tuple[type, ...], model_cls: type[pydantic.BaseModel]
) -> type[pydantic.BaseModel]:
    # Create a new class named "<ModelName>ModelConfig" on top of the given
    # bases whose only own attribute is a copy of ``model_cls.model_config``.
    def populate(namespace: dict[str, Any]) -> None:
        namespace["model_config"] = model_cls.model_config.copy()

    new_name = f"{model_cls.__name__}ModelConfig"
    return types.new_class(new_name, base, exec_body=populate)
def create_model_with_optional_fields(
    model_cls: type[pydantic.BaseModel],
) -> type[pydantic.BaseModel]:
    """
    Build an "Update"-prefixed subclass of ``model_cls`` that also inherits from
    PatchMixin and OmitReadOnlyMixin.

    The new class keeps the module of ``model_cls``; its docstring is the
    original one with a note appended.
    """
    namespace = {
        "__module__": model_cls.__module__,
        "__doc__": (model_cls.__doc__ or "")
        + "\nRead-only fields have been removed and all fields are optional.",
    }
    # Base order matters for the MRO: PatchMixin, then OmitReadOnlyMixin, then
    # the original model.
    bases = (PatchMixin, OmitReadOnlyMixin, model_cls)
    return type("Update" + model_cls.__name__, bases, namespace)
class PatchMixin(pydantic.BaseModel):
    """
    Pydantic mixin whose subclasses get ``None`` as the default of every field
    and an Optional annotation on every field.
    """

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        super().__pydantic_init_subclass__(**kwargs)

        none_type = type(None)
        for info in cls.model_fields.values():
            info.default = None
            current = info.annotation

            if isinstance(current, types.UnionType):
                # ``X | Y`` syntax: nothing to do when None is already a member.
                members = get_args(current)
                if none_type in members:
                    continue
                # Re-express as typing.Optional so FieldInfo.annotation stays
                # in the typing form.
                kept = [m for m in members if m is not none_type]
                if len(kept) == 1:
                    inner = kept[0]
                else:
                    inner = Union[tuple(kept)]
                info.annotation = Optional[inner]  # type: ignore[assignment]
                continue

            # typing.Union / Optional that already admits None is left alone,
            # which avoids producing Optional[Optional[T]].
            is_typing_union = getattr(current, "__origin__", None) is Union
            if is_typing_union and none_type in get_args(current):
                continue
            info.annotation = Optional[current]  # type: ignore[assignment]

        # Rebuild the model after the field infos were edited.
        cls.model_rebuild(force=True)
def getattrs(obj: Any, *attrs: str, default: Any = None) -> Any:
    """
    Follow a chain of attribute names starting at ``obj``.

    Args:
        obj: Object the chain starts from.
        *attrs: Attribute names, applied left to right.
        default: Value returned as soon as one attribute in the chain is missing.

    Returns:
        The value at the end of the chain (``obj`` itself when no names are
        given), or ``default`` if any attribute along the way does not exist.
    """
    # A private sentinel lets each attribute be read exactly once: a
    # hasattr()/getattr() pair would evaluate properties and __getattr__ hooks
    # twice, and None cannot serve as the marker because it is a legitimate
    # attribute value.
    missing = object()
    for attr in attrs:
        obj = getattr(obj, attr, missing)
        if obj is missing:
            return default
    return obj
def set_schema_title(schema_cls: type[pydantic.BaseModel]) -> None:
    """Use the schema class's own name as its configured title.

    Writes ``schema_cls.__name__`` into ``schema_cls.model_config["title"]`` so
    the schema title matches the class name in the OpenAPI schema.
    """
    title = schema_cls.__name__
    schema_cls.model_config["title"] = title
def get_writable_inputs(
    schema_obj: BaseSchema, schema_cls: type[pydantic.BaseModel] | None = None
) -> dict[str, Any]:
    """
    Collect the writable values that were actually supplied as input.

    A field is included only when it is in ``schema_obj.model_fields_set``
    (i.e. it was provided) and it is not marked ReadOnly on ``schema_cls``.

    Args:
        schema_obj: The schema object to read field values from.
        schema_cls: The schema class consulted for ReadOnly markers. Defaults to
            ``schema_obj.__class__`` when None.

    Returns:
        Dictionary mapping field names to values, for provided writable fields only.
    """
    marker_source = schema_obj.__class__ if schema_cls is None else schema_cls
    provided = schema_obj.model_fields_set

    return {
        name: value
        for name, value in schema_obj
        if name in provided and not is_readonly_field(marker_source, name)
    }