Coverage for fastapi_restly / schemas / _base.py: 96%

210 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-02 09:54 +0000

1import types 

2from datetime import datetime 

3from typing import ( 

4 Annotated, 

5 Any, 

6 Generic, 

7 Optional, 

8 TypeVar, 

9 Union, 

10 get_args, 

11 get_origin, 

12) 

13 

14import pydantic 

15from fastapi import HTTPException 

16from pydantic.fields import Field, FieldInfo 

17from sqlalchemy import select 

18from sqlalchemy.exc import NoResultFound 

19from sqlalchemy.ext.asyncio.session import AsyncSession as SA_AsyncSession 

20from sqlalchemy.orm import DeclarativeBase 

21from sqlalchemy.orm.session import Session as SA_Session 

22 

23 

class BaseSchema(pydantic.BaseModel):
    # Common pydantic base for the schemas in this module (IDSchema derives from
    # it, and the resolve_* helpers accept any BaseSchema instance).
    # Allow validating SQLAlchemy model instances directly in request/response flows.
    # This keeps aliased fields working when FastAPI validates ORM objects.
    model_config = pydantic.ConfigDict(from_attributes=True)

28 

29 

class _Marker:
    """Named sentinel stored in ``Annotated`` metadata to tag a field.

    No ``__eq__`` is defined, so equality is identity: a membership test such as
    ``marker in metadata`` only matches that exact marker object.
    """

    def __init__(self, name: str):
        self.name = name

    def __repr__(self):
        # Shown as e.g. ``fr.ReadOnly`` when an annotated type is printed.
        return "fr.{}".format(self.name)

36 

37 

# Sentinel objects placed in Annotated metadata. The helpers further down detect
# tagged fields with `marker in field_info.metadata`, so these exact objects
# must be used (a new _Marker with the same name would not match).
readonly_marker = _Marker("ReadOnly")
writeonly_marker = _Marker("WriteOnly")

_T = TypeVar("_T")

# ReadOnly[X] / WriteOnly[X]: a field of type X tagged with the matching marker,
# plus a Field() whose json_schema_extra adds `readOnly` / `writeOnly` to the
# generated JSON schema.
ReadOnly = Annotated[_T, readonly_marker, Field(json_schema_extra={"readOnly": True})]
WriteOnly = Annotated[
    _T, writeonly_marker, Field(json_schema_extra={"writeOnly": True})
]

47 

48 

class TimestampsSchemaMixin(pydantic.BaseModel):
    # Adds `created_at` / `updated_at` datetime fields. Both are ReadOnly, so
    # OmitReadOnlyMixin drops them and get_writable_inputs skips them.
    created_at: ReadOnly[datetime]
    updated_at: ReadOnly[datetime]

52 

53 

# Type parameter of IDSchema[...]: the SQLAlchemy model class the id refers to.
SQLAlchemyModel = TypeVar("SQLAlchemyModel", bound=DeclarativeBase)

55 

56 

class IDSchema(BaseSchema, Generic[SQLAlchemyModel]):
    """Generic schema useful for serializing only the id of objects.
    Can be used as IDSchema[MyModel].
    """

    # NOTE: the docstring above is kept verbatim; pydantic exposes a model's
    # docstring as its JSON-schema description.

    # Keep this broad so relation-id payloads can target non-int primary keys.
    # The "before" validator below narrows it to the model's primary-key type.
    id: ReadOnly[Any]

    @classmethod
    def _get_sql_model_annotation(cls) -> type[DeclarativeBase] | None:
        """Return the model class from ``IDSchema[Model]``, or None.

        The MRO is searched so that subclasses of a parametrized schema, e.g.
        ``class AuthorRef(IDSchema[Author])``, resolve to ``Author`` as well.
        """
        for klass in cls.__mro__:
            metadata = getattr(klass, "__pydantic_generic_metadata__", None)
            args = metadata.get("args") if metadata else None
            if args:
                sql_model = args[0]
                # An unbound TypeVar (IDSchema[T]) is not a class.
                return sql_model if isinstance(sql_model, type) else None
        return None

    @classmethod
    def _get_sql_model_id_type(cls) -> Any:
        """Return the Python type of the target model's primary key, or None.

        Uses the model's ``id`` annotation when it is a real type (``Mapped[T]``
        is unwrapped to ``T``). Otherwise it falls back to the mapper's first
        primary-key column type.
        """
        from typing import ForwardRef

        sql_model = cls._get_sql_model_annotation()
        if sql_model is None:
            return None

        for model_cls in sql_model.mro():
            annotation = getattr(model_cls, "__annotations__", {}).get("id")
            if annotation is None:
                continue
            # Unwrap generic aliases such as Mapped[int] -> int.
            if get_origin(annotation) is not None:
                args = get_args(annotation)
                if args:
                    annotation = args[0]
            # String / forward-ref annotations (e.g. under
            # `from __future__ import annotations`, or Mapped["int"]) cannot be
            # handed to a TypeAdapter; use the mapper's column type instead.
            if isinstance(annotation, (str, ForwardRef)):
                break
            return annotation

        try:
            return sql_model.__mapper__.primary_key[0].type.python_type
        except Exception:
            # Best effort: no mapper, or a column type without a python_type.
            return None

    @pydantic.field_validator("id", mode="before", check_fields=False)
    @classmethod
    def _coerce_id_to_model_primary_key_type(cls, value: Any) -> Any:
        """Coerce an incoming id to the target model's primary-key type.

        The value is returned untouched when that type is unknown (None or Any).
        """
        id_type = cls._get_sql_model_id_type()
        if id_type in (None, Any):
            return value
        return pydantic.TypeAdapter(id_type).validate_python(value)

    def get_sql_model_annotation(self) -> type[DeclarativeBase] | None:
        """
        Return the SQLAlchemy model class this IDSchema was parametrized with.

        For a field declared as ``foo: IDSchema[Foo]`` this returns the ``Foo``
        class. It returns None when no model class can be determined.
        """
        return self._get_sql_model_annotation()

112 

113 

class IDStampsSchema(TimestampsSchemaMixin, IDSchema):
    # Combines the read-only `id` from IDSchema with the read-only
    # `created_at` / `updated_at` fields from TimestampsSchemaMixin.
    pass

116 

117 

async def async_resolve_ids_to_sqlalchemy_objects(
    session: SA_AsyncSession, schema_obj: BaseSchema
) -> None:
    """
    Replace IDSchema values on ``schema_obj`` with SQLAlchemy instances, in place.

    Only fields present in ``schema_obj.model_fields_set`` are inspected:

    - a single ``IDSchema`` is loaded with ``session.get_one``;
    - a list containing ``IDSchema`` objects is loaded with one
      ``SELECT ... WHERE id IN (...)`` query. Duplicate ids are collapsed, and
      the resulting list follows the order in which the ids were supplied.

    Values whose target SQLAlchemy model cannot be determined are left as-is.

    Raises:
        HTTPException: status 404 when any referenced id is not in the database.
    """
    for field in schema_obj.model_fields_set:
        value = getattr(schema_obj, field, None)

        if isinstance(value, IDSchema):
            sql_model = value.get_sql_model_annotation()
            if not sql_model:
                continue

            try:
                sql_model_obj = await session.get_one(sql_model, value.id)
            except NoResultFound as e:
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {value.id}"
                ) from e
            setattr(schema_obj, field, sql_model_obj)

        elif isinstance(value, list):
            # Take the model from the first IDSchema item; the list may not start
            # with one. All IDSchemas are assumed to target the same model.
            first = next((i for i in value if isinstance(i, IDSchema)), None)
            if first is None:
                continue
            sql_model = first.get_sql_model_annotation()
            if not sql_model:
                continue

            # De-duplicate while keeping the caller's order; duplicates used to
            # make the row count mismatch and raise a bogus 404.
            ids = list(dict.fromkeys(obj.id for obj in value))
            query = select(sql_model).where(sql_model.id.in_(ids))
            result = await session.scalars(query)
            found = {o.id: o for o in result}

            missing_ids = set(ids).difference(found)
            if missing_ids:
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {missing_ids}"
                )

            # The IN query returns rows in arbitrary order; restore input order.
            setattr(schema_obj, field, [found[i] for i in ids])

163 

164 

def resolve_ids_to_sqlalchemy_objects(
    session: SA_Session, schema_obj: BaseSchema
) -> None:
    """
    Replace IDSchema values on ``schema_obj`` with SQLAlchemy instances, in place.

    Synchronous counterpart of ``async_resolve_ids_to_sqlalchemy_objects``.
    Only fields present in ``schema_obj.model_fields_set`` are inspected:

    - a single ``IDSchema`` is loaded with ``session.get_one``;
    - a list containing ``IDSchema`` objects is loaded with one
      ``SELECT ... WHERE id IN (...)`` query. Duplicate ids are collapsed, and
      the resulting list follows the order in which the ids were supplied.

    Values whose target SQLAlchemy model cannot be determined are left as-is.

    Raises:
        HTTPException: status 404 when any referenced id is not in the database.
    """
    for field in schema_obj.model_fields_set:
        value = getattr(schema_obj, field, None)

        if isinstance(value, IDSchema):
            sql_model = value.get_sql_model_annotation()
            if not sql_model:
                continue

            try:
                sql_model_obj = session.get_one(sql_model, value.id)
            except NoResultFound as e:
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {value.id}"
                ) from e
            setattr(schema_obj, field, sql_model_obj)

        elif isinstance(value, list):
            # Take the model from the first IDSchema item; the list may not start
            # with one. All IDSchemas are assumed to target the same model.
            first = next((i for i in value if isinstance(i, IDSchema)), None)
            if first is None:
                continue
            sql_model = first.get_sql_model_annotation()
            if not sql_model:
                continue

            # De-duplicate while keeping the caller's order; duplicates used to
            # make the row count mismatch and raise a bogus 404.
            ids = list(dict.fromkeys(obj.id for obj in value))
            query = select(sql_model).where(sql_model.id.in_(ids))
            found = {o.id: o for o in session.scalars(query)}

            missing_ids = set(ids).difference(found)
            if missing_ids:
                raise HTTPException(
                    status_code=404, detail=f"Id not found for {field}: {missing_ids}"
                )

            # The IN query returns rows in arbitrary order; restore input order.
            setattr(schema_obj, field, [found[i] for i in ids])

210 

211 

def get_read_only_fields(model_cls: type[pydantic.BaseModel]) -> list[str]:
    """Get all fields from a model annotated as ReadOnly[]

    Names are returned in the model's field-declaration order.
    """
    # Same readonly test used everywhere else in this module.
    return [
        name
        for name, info in model_cls.model_fields.items()
        if _is_readonly(info)
    ]

221 

222 

def is_readonly_field(
    model: pydantic.BaseModel | type[pydantic.BaseModel], field_name: str
) -> bool:
    """Check if a specific field is marked as readonly.

    ``model`` may be a model class or an instance of one. Unknown field names
    yield False.
    """
    model_cls = model.__class__ if isinstance(model, pydantic.BaseModel) else model
    return _is_readonly(model_cls.model_fields.get(field_name))

231 

232 

def _is_readonly(field_info: FieldInfo | None) -> bool:
    """Return True when ``field_info``'s metadata carries the ReadOnly marker.

    None, or a field without (or with empty) metadata, counts as not read-only.
    """
    # getattr on None also falls back to the default, covering the None case.
    metadata = getattr(field_info, "metadata", None)
    return bool(metadata) and readonly_marker in metadata

240 

241 

def _is_writeonly(field_info: FieldInfo | None) -> bool:
    """Return True when ``field_info``'s metadata carries the WriteOnly marker.

    None, or a field without (or with empty) metadata, counts as not write-only.
    """
    # getattr on None also falls back to the default, covering the None case.
    metadata = getattr(field_info, "metadata", None)
    return bool(metadata) and writeonly_marker in metadata

249 

250 

def get_write_only_fields(model_cls: type[pydantic.BaseModel]) -> list[str]:
    """Get all fields from a model annotated as WriteOnly[]

    Names are returned in the model's field-declaration order.
    """
    return [
        name
        for name, info in model_cls.model_fields.items()
        if _is_writeonly(info)
    ]

259 

260 

def is_field_writeonly(
    model_cls: pydantic.BaseModel | type[pydantic.BaseModel], field_name: str
) -> bool:
    """Check if a specific field is marked as writeonly.

    ``model_cls`` may be a model class or an instance of one. Unknown field
    names yield False.
    """
    cls_ = (
        model_cls.__class__
        if isinstance(model_cls, pydantic.BaseModel)
        else model_cls
    )
    return _is_writeonly(cls_.model_fields.get(field_name))

267 

268 

def create_model_without_read_only_fields(
    model_cls: type[pydantic.BaseModel],
) -> type[pydantic.BaseModel]:
    """
    Derive a ``Create<Name>`` model from ``model_cls`` with read-only fields removed.

    The new class subclasses ``OmitReadOnlyMixin`` and ``model_cls``. It keeps
    ``model_cls.__module__`` and gets ``model_cls``'s docstring extended with a
    note that read-only fields have been removed.
    """
    bases = (OmitReadOnlyMixin, model_cls)
    namespace = {
        "__module__": model_cls.__module__,
        "__doc__": (model_cls.__doc__ or "") + "\nRead-only fields have been removed.",
    }
    return type(f"Create{model_cls.__name__}", bases, namespace)

286 

287 

class OmitReadOnlyMixin(pydantic.BaseModel):
    """
    Mixin for pydantic models that removes all fields marked as ReadOnly.
    """

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        super().__pydantic_init_subclass__(**kwargs)

        fields = cls.model_fields
        # Snapshot the names first: deleting from a dict while iterating it raises.
        doomed = [name for name, info in fields.items() if _is_readonly(info)]
        for name in doomed:
            del fields[name]

        # Regenerate the validator/serializer from the pruned field set.
        cls.model_rebuild(force=True)

308 

309 

def rebase_with_model_config(
    base: tuple[type, ...], model_cls: type[pydantic.BaseModel]
) -> type[pydantic.BaseModel]:
    """Create ``<ModelName>ModelConfig``, a subclass of ``base``.

    Its ``model_config`` is a shallow copy of ``model_cls.model_config``, so
    later changes to either config do not leak into the other.
    """
    return types.new_class(
        f"{model_cls.__name__}ModelConfig",
        base,
        exec_body=lambda ns: ns.update(model_config=model_cls.model_config.copy()),
    )

319 

320 

def create_model_with_optional_fields(
    model_cls: type[pydantic.BaseModel],
) -> type[pydantic.BaseModel]:
    """
    Derive an ``Update<Name>`` model from ``model_cls`` for partial updates.

    The new class subclasses ``PatchMixin``, ``OmitReadOnlyMixin`` and
    ``model_cls``, in that order. Read-only fields are removed, and every
    remaining field is made optional with ``None`` as its default. It keeps
    ``model_cls.__module__`` and extends the docstring with a note about this.
    """
    bases = (PatchMixin, OmitReadOnlyMixin, model_cls)
    namespace = {
        "__module__": model_cls.__module__,
        "__doc__": (model_cls.__doc__ or "")
        + "\nRead-only fields have been removed and all fields are optional.",
    }
    return type(f"Update{model_cls.__name__}", bases, namespace)

341 

342 

def _as_optional(annotation: Any) -> Any:
    """Return ``annotation`` widened to also accept None (unchanged if it already does)."""
    if isinstance(annotation, types.UnionType):
        union_args = get_args(annotation)
        if type(None) in union_args:
            return annotation
        # Rebuild `X | Y` in typing.Optional/Union form so FieldInfo.annotation
        # stays in the typing style the rest of pydantic's machinery expects.
        return Optional[Union[union_args]]
    if get_origin(annotation) is Union and type(None) in get_args(annotation):
        return annotation
    return Optional[annotation]


class PatchMixin(pydantic.BaseModel):
    """
    A mixin for pydantic classes that makes all fields optional and replaces defaults
    with None.
    """

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        super().__pydantic_init_subclass__(**kwargs)

        for field in cls.model_fields.values():
            field.default = None
            # A default_factory wins over `default` when pydantic builds the
            # schema, so it must be cleared for the field to default to None.
            field.default_factory = None
            # Wrap in Optional only when needed, avoiding Optional[Optional[T]].
            field.annotation = _as_optional(field.annotation)  # type: ignore[assignment]

        cls.model_rebuild(force=True)

371 

372 

def getattrs(obj: Any, *attrs: str, default: Any = None) -> Any:
    """
    Follow a chain of attribute lookups, e.g. ``getattrs(o, "a", "b")`` is ``o.a.b``.

    Returns ``default`` as soon as any attribute in the chain is missing (its
    lookup raises AttributeError). With no ``attrs``, ``obj`` itself is returned.
    Each attribute is looked up exactly once, so properties along the chain are
    evaluated only once (the previous hasattr+getattr pair evaluated them twice).
    """
    missing = object()
    for attr in attrs:
        obj = getattr(obj, attr, missing)
        if obj is missing:
            return default
    return obj

382 

383 

def set_schema_title(schema_cls: type[pydantic.BaseModel]) -> None:
    """Use the class name as the schema title.

    This makes the schema title match the model name in the OpenAPI schema.
    ``schema_cls.model_config`` is mutated in place.
    """
    config = schema_cls.model_config
    config["title"] = schema_cls.__name__

389 

390 

def get_writable_inputs(
    schema_obj: BaseSchema, schema_cls: type[pydantic.BaseModel] | None = None
) -> dict[str, Any]:
    """
    Collect the writable, explicitly provided inputs of ``schema_obj``.

    A field is kept only if it appears in ``schema_obj.model_fields_set``
    (i.e. it was supplied as input) and it is not ReadOnly.

    Args:
        schema_obj: The schema object to extract writable fields from.
        schema_cls: The schema class consulted for readonly markers. Defaults
            to ``schema_obj.__class__``.

    Returns:
        A ``{field_name: value}`` dict in the schema's iteration order.
    """
    readonly_source = schema_obj.__class__ if schema_cls is None else schema_cls
    provided = schema_obj.model_fields_set
    return {
        name: value
        for name, value in schema_obj
        if name in provided and not is_readonly_field(readonly_source, name)
    }

419 

420 return updated_fields