Coverage for fastapi_restly / schemas / _generator.py: 92%
161 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-05 09:15 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-05 09:15 +0000
1"""
2Schema generation utilities for auto-generating Pydantic schemas from SQLAlchemy models.
3"""
5import enum
6import inspect
7import types
8from datetime import date, datetime, time
9from decimal import Decimal
10from typing import Any, Dict, List, Optional, Union, get_args
11from uuid import UUID
13from sqlalchemy import inspect as sa_inspect
14from sqlalchemy.orm import DeclarativeBase, Mapped, RelationshipProperty
16from ._base import BaseSchema, IDSchema, TimestampsSchemaMixin
def get_sqlalchemy_field_type(field: Any) -> Any:
    """
    Look up the type information carried by a field-like object.

    The object is probed for a ``type`` attribute first and for a
    ``__origin__`` attribute second; the first one found is returned as-is.

    Args:
        field: A SQLAlchemy Mapped field (or any object exposing ``type`` or
            ``__origin__``)

    Returns:
        ``field.type`` if present, otherwise ``field.__origin__`` if present,
        otherwise ``typing.Any`` as a fallback.
    """
    # Probe in priority order; ``type`` wins over ``__origin__``.
    for attribute_name in ("type", "__origin__"):
        if hasattr(field, attribute_name):
            return getattr(field, attribute_name)

    # Nothing usable was found on the object.
    return Any
def is_relationship_field(field: Any) -> bool:
    """
    Tell whether a field represents a SQLAlchemy relationship.

    Args:
        field: A SQLAlchemy Mapped field

    Returns:
        True if ``field`` is itself a ``RelationshipProperty`` or exposes one
        through its ``property`` attribute, False otherwise
    """
    # The ``property`` attribute is only consulted when the object itself is
    # not a RelationshipProperty (short-circuit evaluation).
    return isinstance(field, RelationshipProperty) or isinstance(
        getattr(field, "property", None), RelationshipProperty
    )
def get_relationship_target_model(field: Any) -> Optional[type[DeclarativeBase]]:
    """
    Resolve the model class a relationship field points at.

    Resolution order:
      1. ``mapper.class_`` of the underlying ``RelationshipProperty``.
      2. The field's ``type`` attribute: the first argument of a
         ``list[...]`` type, or the type itself when it is a
         ``DeclarativeBase`` subclass.

    Args:
        field: A SQLAlchemy relationship field

    Returns:
        The target model class or None if not found
    """
    if not is_relationship_field(field):
        return None

    # Locate the RelationshipProperty: either the field itself or the object
    # behind its ``property`` attribute.
    if isinstance(field, RelationshipProperty):
        rel_prop = field
    else:
        rel_prop = getattr(field, "property", None)

    if rel_prop is not None and hasattr(rel_prop, "mapper"):
        if hasattr(rel_prop.mapper, "class_"):
            return rel_prop.mapper.class_

    # Fall back to the type annotation carried by the field, if any.
    if not hasattr(field, "type"):
        return None

    annotated = field.type
    if getattr(annotated, "__origin__", None) is list:
        # List[Model] case: the element type is the target.
        element_types = get_args(annotated)
        return element_types[0] if element_types else None

    if inspect.isclass(annotated) and issubclass(annotated, DeclarativeBase):
        return annotated

    return None
def get_model_fields(model_cls: type[DeclarativeBase]) -> Dict[str, Any]:
    """
    Extract field information from a SQLAlchemy model.

    Only public (no leading underscore) annotations of the form
    ``Mapped[...]`` declared on the model class or any of its bases are
    considered. Annotations without an ``__origin__`` (for example
    unevaluated string annotations) are skipped.

    Args:
        model_cls: A SQLAlchemy model class

    Returns:
        Dictionary mapping field names to a metadata dict with the keys:

        - ``type``: the annotated Python type, with ``None`` stripped from
          an optional union (the first non-None member is kept)
        - ``is_relationship``: whether the mapper lists it as a relationship
        - ``target_model``: the related model class, or None
        - ``is_optional``: True for ``X | None`` / ``Optional[X]``
          annotations, for relationships, and for columns that have a
          default or server default
        - ``default``: the column's default (or server default) object, or
          None
    """
    mapper = sa_inspect(model_cls)

    # Collect annotations along the MRO (most-derived class first). The first
    # declaration seen wins, so a subclass that re-annotates an attribute
    # overrides the base/mixin declaration instead of being overwritten by it.
    all_annotations: Dict[str, Any] = {}
    for cls in model_cls.mro():
        for attr_name, annotation in getattr(cls, "__annotations__", {}).items():
            all_annotations.setdefault(attr_name, annotation)

    fields: Dict[str, Any] = {}
    for name, field_type in all_annotations.items():
        if name.startswith("_"):
            continue

        # Only Mapped[...] annotations describe model fields.
        if getattr(field_type, "__origin__", None) is not Mapped:
            continue

        # Extract the actual type from Mapped[Type]
        mapped_args = get_args(field_type)
        if not mapped_args:
            continue
        actual_type = mapped_args[0]

        relationship = mapper.relationships.get(name)
        rel_mapper = (
            getattr(relationship, "mapper", None) if relationship is not None else None
        )
        field_info: Dict[str, Any] = {
            "type": actual_type,
            "is_relationship": relationship is not None,
            "target_model": rel_mapper.class_ if rel_mapper is not None else None,
            "is_optional": False,
            "default": None,
        }

        # Unwrap optional annotations, both the `X | None` syntax
        # (types.UnionType) and typing.Optional[X] / Union[X, None].
        is_union = (
            isinstance(actual_type, types.UnionType)
            or getattr(actual_type, "__origin__", None) is Union
        )
        if is_union:
            union_members = get_args(actual_type)
            if type(None) in union_members:
                field_info["is_optional"] = True
                non_none_types = [
                    member for member in union_members if member is not type(None)
                ]
                if non_none_types:
                    field_info["type"] = non_none_types[0]

        if relationship is not None:
            # Relationship fields are response-oriented in generated schemas.
            # Keep them optional so create/update inputs can rely on FK columns.
            field_info["is_optional"] = True
        elif name in mapper.columns:
            column = mapper.columns[name]
            if column.default is not None or column.server_default is not None:
                # A column with any kind of default does not need a value
                # from the client. Prefer the Python-side default.
                field_info["default"] = (
                    column.default
                    if column.default is not None
                    else column.server_default
                )
                field_info["is_optional"] = True

        fields[name] = field_info

    return fields
def create_schema_from_model(
    model_cls: type[DeclarativeBase],
    schema_name: Optional[str] = None,
    include_relationships: bool = True,
    include_readonly_fields: bool = True,
) -> type[BaseSchema]:
    """
    Auto-generate a Pydantic schema from a SQLAlchemy model.

    The schema's bases are chosen from the model's fields:
    ``TimestampsSchemaMixin`` when both ``created_at`` and ``updated_at``
    exist, ``IDSchema`` when ``id`` exists, and always ``BaseSchema`` last.

    Relationship fields become nested schemas generated from the target
    model (without their own relationships, to avoid circular references).
    Self-referential relationships are skipped entirely.

    Args:
        model_cls: The SQLAlchemy model class
        schema_name: Optional name for the generated schema class; defaults
            to ``"<ModelName>Schema"``
        include_relationships: Whether to include relationship fields
        include_readonly_fields: When True, the ``id``, ``created_at`` and
            ``updated_at`` fields are wrapped in ``ReadOnly``. When False the
            fields are still emitted, just without the ``ReadOnly`` marker.

    Returns:
        A Pydantic schema class

    Raises:
        TypeError: If a non-relationship field has a type that
            ``convert_sqlalchemy_type_to_pydantic`` does not support.
    """
    import pydantic

    if schema_name is None:
        schema_name = f"{model_cls.__name__}Schema"

    # Get field information from the model
    model_fields = get_model_fields(model_cls)

    # Determine base classes - most specific first, BaseSchema always last.
    bases: List[type] = []
    if "created_at" in model_fields and "updated_at" in model_fields:
        bases.append(TimestampsSchemaMixin)
    if "id" in model_fields:
        bases.append(IDSchema)
    bases.append(BaseSchema)

    field_definitions: Dict[str, Any] = {}
    read_only_fields: List[str] = []

    for field_name, field_info in model_fields.items():
        is_relationship = field_info["is_relationship"]

        # Skip relationships if not requested
        if is_relationship and not include_relationships:
            continue

        is_optional = field_info["is_optional"]
        target_model = field_info["target_model"] if is_relationship else None

        if target_model:
            # Skip self-referential relationship to avoid infinite recursion
            if target_model is model_cls:
                continue

            # The relationship annotation is replaced by a nested schema, so
            # it is never passed through the scalar type converter. This
            # keeps forward-reference annotations (e.g. Mapped["Parent"])
            # from raising TypeError for a type that would be discarded.
            target_schema = create_schema_from_model(
                target_model,
                include_relationships=False,  # Avoid circular references
                include_readonly_fields=False,
            )
            is_many = getattr(field_info["type"], "__origin__", None) is list
            pydantic_type = List[target_schema] if is_many else target_schema
            if is_optional:
                pydantic_type = Optional[pydantic_type]
        else:
            # Plain column: convert the annotated type to a Pydantic type.
            pydantic_type = convert_sqlalchemy_type_to_pydantic(
                field_info["type"], is_optional
            )

        if include_readonly_fields and field_name in ("id", "created_at", "updated_at"):
            read_only_fields.append(field_name)

        # Add field to definitions - use proper Pydantic field format.
        # Don't include SQLAlchemy defaults as they're not JSON-serializable.
        if is_optional:
            field_definitions[field_name] = (pydantic_type, pydantic.Field(default=None))
        else:
            field_definitions[field_name] = (pydantic_type, ...)

    # Apply ReadOnly annotation to read-only fields
    if read_only_fields:
        from ._base import ReadOnly

        for field_name in read_only_fields:
            declared_type, field_default = field_definitions[field_name]
            field_definitions[field_name] = (ReadOnly[declared_type], field_default)

    # Create the schema class using pydantic.create_model
    schema_cls = pydantic.create_model(  # type: ignore
        schema_name,
        __doc__=f"Auto-generated schema for {model_cls.__name__}",
        __base__=tuple(bases),
        **field_definitions,
    )

    return schema_cls
# Python types that Pydantic understands directly and that are passed through.
_PASSTHROUGH_PYTHON_TYPES = (
    str,
    int,
    float,
    bool,
    bytes,
    dict,
    list,
    datetime,
    date,
    time,
    UUID,
    Decimal,
)

# SQLAlchemy column type names mapped to their Python equivalents.
_SQLALCHEMY_TYPE_NAME_MAP: Dict[str, type] = {
    "Text": str,
    "String": str,
    "Integer": int,
    "Float": float,
    "Boolean": bool,
    "DateTime": datetime,
    "Date": date,
    "Time": time,
}


def convert_sqlalchemy_type_to_pydantic(
    sqlalchemy_type: Any, is_optional: bool = False
) -> Any:
    """
    Convert a SQLAlchemy type to a Pydantic-compatible type.

    Accepted inputs:
      - ``typing.Any``
      - parameterized generics such as ``dict[str, Any]`` or ``list[int]``
        (anything with a non-None ``__origin__``), returned unchanged
      - plain Python types (str, int, float, bool, bytes, dict, list,
        datetime, date, time, UUID, Decimal), returned unchanged
      - ``enum.Enum`` subclasses and ``DeclarativeBase`` subclasses,
        returned unchanged
      - SQLAlchemy column types named Text, String, Integer, Float, Boolean,
        DateTime, Date or Time, given either as the class or as an instance
        (e.g. ``String(50)``)

    Args:
        sqlalchemy_type: The SQLAlchemy type
        is_optional: Whether the field is optional

    Returns:
        A Pydantic-compatible type, wrapped in ``Optional[...]`` when
        ``is_optional`` is true

    Raises:
        TypeError: If the type is not one of the supported kinds above.
    """
    if sqlalchemy_type is Any:
        pydantic_type = Any
    elif getattr(sqlalchemy_type, "__origin__", None) is not None:
        # Preserve parameterized container types like dict[str, Any] or
        # list[int]. Checked before the class tests because on some Python
        # versions isinstance(list[int], type) is true and issubclass()
        # would then raise.
        pydantic_type = sqlalchemy_type
    elif sqlalchemy_type in _PASSTHROUGH_PYTHON_TYPES:
        pydantic_type = sqlalchemy_type
    elif isinstance(sqlalchemy_type, type) and issubclass(sqlalchemy_type, enum.Enum):
        pydantic_type = sqlalchemy_type
    elif isinstance(sqlalchemy_type, type) and issubclass(
        sqlalchemy_type, DeclarativeBase
    ):
        # Relationship targets are replaced with nested schemas later.
        pydantic_type = sqlalchemy_type
    else:
        # Match by name: the object's own __name__ (or its string form), and
        # for instances also the name of their class so String(50) resolves
        # the same way as String.
        candidate_names = [getattr(sqlalchemy_type, "__name__", str(sqlalchemy_type))]
        if not isinstance(sqlalchemy_type, type):
            candidate_names.append(type(sqlalchemy_type).__name__)

        for candidate in candidate_names:
            if candidate in _SQLALCHEMY_TYPE_NAME_MAP:
                pydantic_type = _SQLALCHEMY_TYPE_NAME_MAP[candidate]
                break
        else:
            raise TypeError(
                f"Unsupported field type for auto-generated schema: {sqlalchemy_type!r}"
            )

    # Handle optional types
    if is_optional:
        pydantic_type = Optional[pydantic_type]

    return pydantic_type
def auto_generate_schema_for_view(
    view_cls: type, model_cls: type[DeclarativeBase], schema_name: Optional[str] = None
) -> type[BaseSchema]:
    """
    Build a schema for a view class from its model, without relationships.

    Args:
        view_cls: The view class; its name is used to derive the default
            schema name (``"<ViewName>Schema"``)
        model_cls: The SQLAlchemy model class
        schema_name: Optional name for the generated schema

    Returns:
        A Pydantic schema class
    """
    # Fall back to a name derived from the view when none was supplied.
    resolved_name = (
        schema_name if schema_name is not None else f"{view_cls.__name__}Schema"
    )
    return create_schema_from_model(
        model_cls, resolved_name, include_relationships=False
    )