Coverage for fastapi_restly / schemas / _generator.py: 92%

161 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-05 09:15 +0000

1""" 

2Schema generation utilities for auto-generating Pydantic schemas from SQLAlchemy models. 

3""" 

4 

5import enum 

6import inspect 

7import types 

8from datetime import date, datetime, time 

9from decimal import Decimal 

10from typing import Any, Dict, List, Optional, Union, get_args 

11from uuid import UUID 

12 

13from sqlalchemy import inspect as sa_inspect 

14from sqlalchemy.orm import DeclarativeBase, Mapped, RelationshipProperty 

15 

16from ._base import BaseSchema, IDSchema, TimestampsSchemaMixin 

17 

18 

def get_sqlalchemy_field_type(field: Any) -> Any:
    """
    Return the type information carried by a SQLAlchemy field.

    Resolution order:

    1. ``field.type`` when the object exposes a ``type`` attribute.
    2. ``field.__origin__`` when the object is a parameterized typing
       construct. For a parameterized alias this is the unsubscripted
       origin (e.g. ``list`` for ``List[int]``), not its type argument.
    3. ``typing.Any`` when neither attribute is present.

    Args:
        field: A SQLAlchemy Mapped field, column-like object, or annotation.

    Returns:
        The resolved type object, or ``Any`` as a fallback.
    """
    for attribute in ("type", "__origin__"):
        if hasattr(field, attribute):
            return getattr(field, attribute)
    # Neither attribute is available: the type cannot be determined.
    return Any

37 

38 

def is_relationship_field(field: Any) -> bool:
    """
    Tell whether *field* represents a SQLAlchemy relationship.

    Accepts either a ``RelationshipProperty`` itself, or any object whose
    ``property`` attribute is a ``RelationshipProperty`` (e.g. an
    ORM-mapped class attribute).

    Args:
        field: The object to test.

    Returns:
        True for relationship-backed fields, False otherwise.
    """
    # Short-circuit so ``field.property`` is only consulted when ``field``
    # is not itself a RelationshipProperty.
    return isinstance(field, RelationshipProperty) or isinstance(
        getattr(field, "property", None), RelationshipProperty
    )

52 

53 

def get_relationship_target_model(field: Any) -> Optional[type[DeclarativeBase]]:
    """
    Resolve the model class a relationship field points at.

    Resolution order:

    1. The ``mapper.class_`` of the underlying ``RelationshipProperty``
       (the field itself, or its ``property`` attribute).
    2. The field's ``type`` attribute: the first type argument of a
       ``list[...]``-style annotation, or the annotation itself when it is a
       ``DeclarativeBase`` subclass.

    Args:
        field: A SQLAlchemy relationship field.

    Returns:
        The target model class, or None when it cannot be determined or
        *field* is not a relationship.
    """
    if not is_relationship_field(field):
        return None

    rel_prop = (
        field
        if isinstance(field, RelationshipProperty)
        else getattr(field, "property", None)
    )
    if (
        rel_prop is not None
        and hasattr(rel_prop, "mapper")
        and hasattr(rel_prop.mapper, "class_")
    ):
        return rel_prop.mapper.class_

    # Fallback: inspect a ``type`` annotation carried by the field.
    if not hasattr(field, "type"):
        return None
    annotated = field.type

    if getattr(annotated, "__origin__", None) is list:
        # List[Model]: the element type is the target model.
        element_types = get_args(annotated)
        return element_types[0] if element_types else None

    if inspect.isclass(annotated) and issubclass(annotated, DeclarativeBase):
        return annotated

    return None

87 

88 

def _split_optional(annotation: Any) -> tuple[Any, bool]:
    """Strip ``None`` from a union annotation; return ``(inner_type, is_optional)``.

    Handles both ``X | None`` (``types.UnionType``) and ``typing.Optional`` /
    ``typing.Union``. When several non-None members remain they are
    recombined into a ``Union`` rather than truncated to the first member.
    Non-union annotations, and unions without ``None``, are returned unchanged
    with ``is_optional=False``.
    """
    is_union = isinstance(annotation, types.UnionType) or (
        getattr(annotation, "__origin__", None) is Union
    )
    if not is_union:
        return annotation, False

    members = get_args(annotation)
    if type(None) not in members:
        return annotation, False

    remaining = [member for member in members if member is not type(None)]
    if not remaining:
        return annotation, True
    if len(remaining) == 1:
        return remaining[0], True
    return Union[tuple(remaining)], True


def get_model_fields(model_cls: type[DeclarativeBase]) -> Dict[str, Any]:
    """
    Extract field information from a SQLAlchemy model.

    Only public attributes annotated as ``Mapped[...]`` (on the model or any
    class in its MRO) are considered.

    Args:
        model_cls: A SQLAlchemy model class.

    Returns:
        Dictionary mapping field names to a dict with the keys:

        - ``"type"``: the annotated Python type, with ``None`` stripped from
          optional unions.
        - ``"is_relationship"``: whether the mapper has a relationship of
          that name.
        - ``"target_model"``: the relationship's target class, or None.
        - ``"is_optional"``: True for nullable annotations, relationships,
          and columns with a (server) default.
        - ``"default"``: the column's default / server default, else None.
    """
    mapper = sa_inspect(model_cls)

    # Collect annotations across the MRO. The MRO runs from the most derived
    # class towards ``object``, so the first occurrence of a name must win:
    # this lets a subclass re-annotate an inherited field. (``dict.update``
    # here would let base classes overwrite subclass annotations.)
    all_annotations: Dict[str, Any] = {}
    for cls in model_cls.mro():
        for attr_name, annotation in getattr(cls, "__annotations__", {}).items():
            all_annotations.setdefault(attr_name, annotation)

    fields: Dict[str, Any] = {}
    for name, field_type in all_annotations.items():
        if name.startswith("_"):
            continue

        # Only ``Mapped[...]`` annotations describe mapped attributes.
        if getattr(field_type, "__origin__", None) is not Mapped:
            continue

        mapped_args = get_args(field_type)
        if not mapped_args:
            continue

        actual_type, is_optional = _split_optional(mapped_args[0])
        relationship = mapper.relationships.get(name)
        rel_mapper = (
            getattr(relationship, "mapper", None) if relationship is not None else None
        )

        field_info: Dict[str, Any] = {
            "type": actual_type,
            "is_relationship": relationship is not None,
            "target_model": rel_mapper.class_ if rel_mapper is not None else None,
            "is_optional": is_optional,
            "default": None,
        }

        if relationship is not None:
            # Relationship fields are response-oriented in generated schemas.
            # Keep them optional so create/update inputs can rely on FK columns.
            field_info["is_optional"] = True
        elif name in mapper.columns:
            column = mapper.columns[name]
            if column.default is not None or column.server_default is not None:
                field_info["default"] = column.default or column.server_default
                field_info["is_optional"] = True

        fields[name] = field_info

    return fields

169 

170 

def _is_collection_annotation(annotation: Any) -> bool:
    """Return True when *annotation* is a parameterized list, set or frozenset."""
    return getattr(annotation, "__origin__", None) in (list, set, frozenset)


def create_schema_from_model(
    model_cls: type[DeclarativeBase],
    schema_name: Optional[str] = None,
    include_relationships: bool = True,
    include_readonly_fields: bool = True,
) -> type[BaseSchema]:
    """
    Auto-generate a Pydantic schema from a SQLAlchemy model.

    Column fields are converted with ``convert_sqlalchemy_type_to_pydantic``.
    Relationship fields with a known target model become nested schemas,
    generated without their own relationships to avoid circular references.
    Collection relationships (list/set/frozenset) become ``List[...]``.
    Self-referential relationships are skipped.

    Args:
        model_cls: The SQLAlchemy model class.
        schema_name: Optional name for the generated schema class; defaults
            to ``"<ModelName>Schema"``.
        include_relationships: Whether to include relationship fields.
        include_readonly_fields: Whether to mark ``id``, ``created_at`` and
            ``updated_at`` as ``ReadOnly``.

    Returns:
        A Pydantic schema class.

    Raises:
        TypeError: If a non-relationship field has an unsupported type.
    """
    import pydantic
    from pydantic import Field

    if schema_name is None:
        schema_name = f"{model_cls.__name__}Schema"

    model_fields = get_model_fields(model_cls)

    # Base classes, most specific first.
    bases: List[type] = []
    if "created_at" in model_fields and "updated_at" in model_fields:
        bases.append(TimestampsSchemaMixin)
    if "id" in model_fields:
        bases.append(IDSchema)
    bases.append(BaseSchema)

    field_definitions: Dict[str, Any] = {}
    read_only_fields: List[str] = []

    for field_name, field_info in model_fields.items():
        if field_info["is_relationship"] and not include_relationships:
            continue

        if include_readonly_fields and field_name in ("id", "created_at", "updated_at"):
            read_only_fields.append(field_name)

        target_model = (
            field_info["target_model"] if field_info["is_relationship"] else None
        )
        if target_model:
            # Skip self-referential relationships to avoid infinite recursion.
            if target_model is model_cls:
                continue
            # The annotation is replaced by a nested schema, so it is never fed
            # to the column-type converter: forward references such as
            # ``Mapped["Child"]`` would otherwise raise TypeError.
            target_schema = create_schema_from_model(
                target_model,
                include_relationships=False,  # Avoid circular references
                include_readonly_fields=False,
            )
            if _is_collection_annotation(field_info["type"]):
                pydantic_type = List[target_schema]  # Many relationship
            else:
                pydantic_type = target_schema  # One relationship
            if field_info["is_optional"]:
                pydantic_type = Optional[pydantic_type]
        else:
            pydantic_type = convert_sqlalchemy_type_to_pydantic(
                field_info["type"], field_info["is_optional"]
            )

        # SQLAlchemy defaults are not JSON-serializable, so optional fields
        # default to None instead of the column default.
        if field_info["is_optional"]:
            field_definitions[field_name] = (pydantic_type, Field(default=None))
        else:
            field_definitions[field_name] = (pydantic_type, ...)

    if read_only_fields:
        from ._base import ReadOnly

        for field_name in read_only_fields:
            if field_name in field_definitions:
                annotation, field_default = field_definitions[field_name]
                field_definitions[field_name] = (ReadOnly[annotation], field_default)

    schema_cls = pydantic.create_model(  # type: ignore
        schema_name,
        __doc__=f"Auto-generated schema for {model_cls.__name__}",
        __base__=tuple(bases),
        **field_definitions,
    )

    return schema_cls

294 

295 

# Python types that are already valid Pydantic field types and pass through
# unchanged.
_PASSTHROUGH_PYTHON_TYPES = (
    str,
    int,
    float,
    bool,
    bytes,
    dict,
    list,
    datetime,
    date,
    time,
    UUID,
    Decimal,
)

# SQLAlchemy generic type names mapped to the Python type their values load
# as (Numeric/Uuid assume the default ``asdecimal=True`` / ``as_uuid=True``).
_SQLALCHEMY_TYPE_NAMES: Dict[str, Any] = {
    "String": str,
    "Text": str,
    "Unicode": str,
    "UnicodeText": str,
    "Integer": int,
    "BigInteger": int,
    "SmallInteger": int,
    "Float": float,
    "Numeric": Decimal,
    "Boolean": bool,
    "DateTime": datetime,
    "Date": date,
    "Time": time,
    "Uuid": UUID,
    "LargeBinary": bytes,
}


def _python_type_for_sqlalchemy_name(sqlalchemy_type: Any) -> Any:
    """Look up *sqlalchemy_type* in ``_SQLALCHEMY_TYPE_NAMES`` by name.

    Classes (e.g. ``String``) are matched by ``__name__``. Instances (e.g.
    ``String(50)``, as held by ``Column.type``) have no ``__name__`` and
    ``str()`` renders their SQL form (``"VARCHAR(50)"``), so their class name
    is tried as well. Returns None when nothing matches.
    """
    candidates = [getattr(sqlalchemy_type, "__name__", str(sqlalchemy_type))]
    if not isinstance(sqlalchemy_type, type):
        candidates.append(type(sqlalchemy_type).__name__)
    for candidate in candidates:
        if candidate in _SQLALCHEMY_TYPE_NAMES:
            return _SQLALCHEMY_TYPE_NAMES[candidate]
    return None


def convert_sqlalchemy_type_to_pydantic(
    sqlalchemy_type: Any, is_optional: bool = False
) -> Any:
    """
    Convert a SQLAlchemy / Python type to a Pydantic-compatible type.

    Accepts plain Python types, Enum subclasses, model classes (kept as-is;
    relationship targets are swapped for nested schemas by the caller),
    parameterized generics such as ``dict[str, Any]``, and SQLAlchemy generic
    types given either as classes (``String``) or instances (``String(50)``).

    Args:
        sqlalchemy_type: The type to convert.
        is_optional: Whether to wrap the result in ``Optional[...]``.

    Returns:
        A Pydantic-compatible type.

    Raises:
        TypeError: If the type is not supported.
    """
    if sqlalchemy_type is Any:
        pydantic_type = Any
    elif sqlalchemy_type in _PASSTHROUGH_PYTHON_TYPES:
        pydantic_type = sqlalchemy_type
    elif isinstance(sqlalchemy_type, type) and issubclass(sqlalchemy_type, enum.Enum):
        pydantic_type = sqlalchemy_type
    elif isinstance(sqlalchemy_type, type) and issubclass(
        sqlalchemy_type, DeclarativeBase
    ):
        # Relationship targets are replaced with nested schemas later.
        pydantic_type = sqlalchemy_type
    elif getattr(sqlalchemy_type, "__origin__", None) is not None:
        # Preserve parameterized container types like dict[str, Any] or list[int].
        pydantic_type = sqlalchemy_type
    else:
        pydantic_type = _python_type_for_sqlalchemy_name(sqlalchemy_type)
        if pydantic_type is None:
            raise TypeError(
                f"Unsupported field type for auto-generated schema: {sqlalchemy_type!r}"
            )

    if is_optional:
        pydantic_type = Optional[pydantic_type]

    return pydantic_type

363 

364 

def auto_generate_schema_for_view(
    view_cls: type, model_cls: type[DeclarativeBase], schema_name: Optional[str] = None
) -> type[BaseSchema]:
    """
    Build a schema for a view class that did not declare one.

    The schema is generated from *model_cls* without relationship fields.

    Args:
        view_cls: The view class; its name seeds the default schema name.
        model_cls: The SQLAlchemy model class to generate the schema from.
        schema_name: Optional explicit schema name; defaults to
            ``"<ViewName>Schema"``.

    Returns:
        A Pydantic schema class.
    """
    resolved_name = (
        schema_name if schema_name is not None else f"{view_cls.__name__}Schema"
    )
    return create_schema_from_model(model_cls, resolved_name, include_relationships=False)