Coverage for fastapi_restly / query / _v2.py: 92%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-02 09:54 +0000

1import functools 

2from collections import defaultdict 

3from typing import Any, Callable, Iterator, Optional, cast 

4 

5import pydantic 

6import sqlalchemy 

7from fastapi import HTTPException 

8from pydantic.fields import FieldInfo 

9from sqlalchemy import ColumnElement, Select 

10from sqlalchemy.orm import DeclarativeBase 

11from sqlalchemy.orm.attributes import InstrumentedAttribute 

12from sqlalchemy.orm.properties import ColumnProperty 

13from starlette.datastructures import QueryParams 

14 

15from ._shared import _escape_like_value, _unwrap_optional_annotation 

16 

# Alias for a pydantic model *class* (not an instance). Used for the resource
# schemas passed into this module and for the query-parameter models it creates.
SchemaType = type[pydantic.BaseModel]

18 

19 

def _is_string_field_v2(field: FieldInfo) -> bool:
    """
    Return True when the field's annotation is exactly ``str``.

    The annotation is first passed through ``_unwrap_optional_annotation`` so
    that an optional wrapper does not hide the ``str`` type. Subclasses of
    ``str`` do not count (identity comparison).
    """
    return _unwrap_optional_annotation(field.annotation) is str

24 

25 

def create_query_param_schema_v2(schema_cls: SchemaType) -> SchemaType:
    """
    Build a pydantic model describing every query parameter of the v2 interface.

    The generated model always contains ``page``, ``page_size`` and
    ``order_by``. For each (possibly nested, dot-joined) field of
    ``schema_cls`` it also contains:

    * the plain equality parameter and the ``__gte``/``__lte``/``__gt``/
      ``__lt``/``__ne`` variants, typed like the field;
    * ``__isnull`` typed as ``bool``;
    * ``__contains`` typed as ``str``, only for ``str`` fields.

    Every parameter is optional and defaults to ``None``. The model is named
    ``"QueryParamV2" + schema_cls.__name__``.
    """
    fields: dict[str, Any] = {
        "page": (Optional[int], None),
        "page_size": (Optional[int], None),
        "order_by": (Optional[str], None),
    }

    for query_name, field in _iter_fields_including_nested_v2(schema_cls):
        typed_entry = (Optional[_get_field_type_for_schema(field)], None)
        # Equality, range and inequality parameters all share the field's type.
        for suffix in ("", "__gte", "__lte", "__gt", "__lt", "__ne"):
            fields[query_name + suffix] = typed_entry
        fields[f"{query_name}__isnull"] = (Optional[bool], None)
        if _is_string_field_v2(field):
            fields[f"{query_name}__contains"] = (Optional[str], None)

    model_name = "QueryParamV2" + schema_cls.__name__
    return pydantic.create_model(model_name, **fields)  # type: ignore

52 

53 

def apply_query_modifiers_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select[Any]:
    """
    Apply filtering, sorting and pagination from URL query parameters to a query.

    Supported parameters::

        # Pagination
        page=2&page_size=50

        # Sorting
        order_by=name,-created_at

        # Filtering
        name=Bob&status=active&created_at__gte=2024-01-01

        # Contains (string fields)
        name__contains=john&email__contains=example

    The three steps run in this order: filtering, then sorting, then pagination.
    """
    filtered = apply_filtering_v2(query_params, select_query, model, schema_cls)
    ordered = apply_sorting_v2(query_params, filtered, model, schema_cls)
    paginated = apply_pagination_v2(query_params, ordered)
    return paginated

81 

82 

def apply_pagination_v2(
    query_params: QueryParams, select_query: Select[Any]
) -> Select[Any]:
    """
    Apply pagination using page and page_size parameters.

    ``page`` defaults to 1 and ``page_size`` to 100 when the parameter is
    absent or empty. Explicitly supplied values are validated: ``page`` must
    be >= 1 and ``page_size`` must be > 0. A value of 0 is rejected rather than
    silently replaced by the default.

    Raises:
        HTTPException: 400 if a value is not an integer or is out of range.
    """
    page = _get_int_v2(query_params, "page")
    if page is None:
        page = 1
    elif page < 1:
        raise HTTPException(400, "Invalid value for URL query parameter page: must be >= 1")

    page_size = _get_int_v2(query_params, "page_size")
    if page_size is None:
        page_size = 100
    elif page_size <= 0:
        raise HTTPException(400, "Invalid value for URL query parameter page_size: must be > 0")

    offset = (page - 1) * page_size
    return select_query.limit(page_size).offset(offset)

98 

99 

def _get_field_type_for_schema(field: FieldInfo) -> type:
    """
    Choose the Python type used for a field's query parameters.

    The annotation is first passed through ``_unwrap_optional_annotation``.
    ``Any`` and plain classes are returned unchanged. Anything else (an
    annotation that is not a class) falls back to ``object``.
    """
    unwrapped = _unwrap_optional_annotation(field.annotation)
    if unwrapped is not Any and not isinstance(unwrapped, type):
        return object
    return unwrapped

107 

108 

def _iter_fields_including_nested_v2(
    schema_cls: SchemaType,
    prefix: str = "",
    _ancestors: frozenset = frozenset(),
) -> Iterator[tuple[str, FieldInfo]]:
    """
    Yield ``(query_name, field)`` for every leaf field of ``schema_cls``.

    Each field is named by its alias if it has one, otherwise by its field
    name. Fields holding a nested pydantic model are expanded recursively, and
    their leaves are named ``<parent>.<child>``.

    A nested schema that is already being expanded higher up the current path
    (a self-referential or mutually recursive schema) is skipped. Expanding it
    would recurse forever.

    Args:
        schema_cls: Schema whose fields are listed.
        prefix: Dot-joined path of the enclosing fields ("" at the top level).
        _ancestors: Internal; schemas on the current expansion path.
    """
    on_path = _ancestors | {schema_cls}
    for name, field in schema_cls.model_fields.items():
        # Use alias if available, otherwise use field name
        field_name = field.alias or name
        full_name = f"{prefix}.{field_name}" if prefix else field_name
        nested = _get_nested_schema_v2(field)
        if nested is None:
            yield full_name, field
        elif nested not in on_path:
            yield from _iter_fields_including_nested_v2(nested, full_name, on_path)
        # else: cyclic reference -> no filterable leaves, skip it.

121 

122 

def _get_int_v2(query_params: QueryParams, param_name: str) -> Optional[int]:
    """
    Read an integer query parameter.

    Returns None when the parameter is missing or empty. Otherwise returns
    ``int(value)``.

    Raises:
        HTTPException: 400 if the value cannot be parsed as an integer.
    """
    raw = query_params.get(param_name)
    if raw:
        try:
            return int(raw)
        except ValueError:
            raise HTTPException(
                400,
                f"Invalid value for URL query parameter {param_name}: {raw} is not an integer",
            )
    return None

135 

136 

def apply_sorting_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType | None = None,
) -> Select[Any]:
    """
    Apply sorting using the order_by parameter (comma-separated, - for descending).

    Example: ``order_by=author.name,-created_at``. Surrounding whitespace on
    each item is ignored. Dotted names join the relationships on the path.
    Each relationship is joined at most once, even when several sort keys go
    through it.

    Without ``order_by``, the query is ordered by ``model.id`` when the model
    has that attribute, and is otherwise returned unchanged.

    Raises:
        HTTPException: 400 for names that do not resolve to a column.
    """
    sort_string = query_params.get("order_by")
    if not sort_string:
        id_column = getattr(model, "id", None)
        if id_column is not None:
            return select_query.order_by(id_column)
        return select_query

    joined: set[InstrumentedAttribute[Any]] = set()
    for item in sort_string.split(","):
        column_name = item.strip()
        order = sqlalchemy.asc
        if column_name.startswith("-"):
            order = sqlalchemy.desc
            column_name = column_name[1:]
        joins, column = _get_sqlalchemy_column_v2(model, column_name, schema_cls)
        for join in joins:
            # Joining the same relationship twice would duplicate the table
            # in the FROM clause.
            if join not in joined:
                select_query = select_query.join(join)
                joined.add(join)
        select_query = select_query.order_by(order(column))
    return select_query

164 

165 

def _get_sqlalchemy_column_v2(
    model: type[DeclarativeBase], column_path: str, schema_cls: SchemaType | None = None
) -> tuple[list[InstrumentedAttribute[Any]], InstrumentedAttribute[Any]]:
    """
    Split a resolved column path into ``(relationships_to_join, column)``.

    Consumes ``_resolve_sqlalchemy_column_v2``. Every element except the last
    is a relationship attribute (in path order), and the last is the column.
    """
    resolved = list(_resolve_sqlalchemy_column_v2(model, column_path, schema_cls))
    *relationships, target = resolved
    joins = cast(list[InstrumentedAttribute[Any]], relationships)
    column = cast(InstrumentedAttribute[Any], target)
    return joins, column

173 

174 

def _is_relationship_attr_v2(attr: object) -> bool:
    """True if ``attr`` is a mapped attribute whose property exposes a ``mapper``."""
    return isinstance(attr, InstrumentedAttribute) and hasattr(attr.property, "mapper")


def _is_column_attr_v2(attr: object) -> bool:
    """True if ``attr`` is a mapped attribute backed by a ``ColumnProperty``."""
    return isinstance(attr, InstrumentedAttribute) and isinstance(
        attr.property, ColumnProperty
    )


def _find_relation_field_v2(
    schema_cls: SchemaType, key: str
) -> tuple[str, FieldInfo] | None:
    """Find the schema field for a relation path segment: alias match first, then field name."""
    for name, field in schema_cls.model_fields.items():
        if field.alias == key:
            return name, field
    field = schema_cls.model_fields.get(key)
    return (key, field) if field is not None else None


def _schema_field_name_for_key_v2(schema_cls: SchemaType, key: str) -> str | None:
    """
    Translate a leaf query-parameter name into a field name of ``schema_cls``.

    A field whose alias equals ``key`` wins. Otherwise ``key`` is accepted as a
    plain field name only when the schema config sets ``populate_by_name`` or
    that field has no alias. Returns None when nothing matches.
    """
    fields = schema_cls.model_fields
    for name, field in fields.items():
        if field.alias == key:
            return name
    field = fields.get(key)
    if field is None:
        return None
    config = getattr(schema_cls, "model_config", pydantic.ConfigDict())
    if config.get("populate_by_name", False) or not field.alias:
        return key
    return None


def _resolve_sqlalchemy_column_v2(
    model: type[DeclarativeBase], column_name: str, schema_cls: SchemaType | None = None
) -> Iterator[type[DeclarativeBase] | InstrumentedAttribute[Any]]:
    """
    Resolve a (possibly dotted) query-parameter name to SQLAlchemy attributes.

    For ``rel.sub.col`` this yields the relationship attribute for each
    ``rel``/``sub`` segment (in path order, ready to be joined) and finally the
    column attribute. A segment is first looked up directly on the model. If
    that fails, it is translated from pydantic alias to field name using
    ``schema_cls``. The remainder of a dotted path is resolved against the
    relationship field's *nested* schema, so the nested schema's own aliases
    and config apply.

    Raises:
        HTTPException: 400 if a segment is not a relationship or the final
            name is not a column.
    """
    if "." in column_name:
        relation, _, column_part = column_name.partition(".")
        rel_field = (
            _find_relation_field_v2(schema_cls, relation) if schema_cls is not None else None
        )
        rel = getattr(model, relation, None)
        if not _is_relationship_attr_v2(rel) and rel_field is not None:
            # The segment may be a pydantic alias; retry with the field name.
            rel = getattr(model, rel_field[0], None)
        if not _is_relationship_attr_v2(rel):
            raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
        rel_attr = cast(InstrumentedAttribute[Any], rel)
        nested_schema = _get_nested_schema_v2(rel_field[1] if rel_field is not None else None)
        yield rel_attr
        yield from _resolve_sqlalchemy_column_v2(
            rel_attr.property.mapper.class_, column_part, nested_schema
        )
        return

    # Try the name directly on the model first.
    column = getattr(model, column_name, None)
    if not _is_column_attr_v2(column) and schema_cls is not None:
        # Fall back to translating an alias (or accepted field name) via the schema.
        field_name = _schema_field_name_for_key_v2(schema_cls, column_name)
        if field_name:
            column = getattr(model, field_name, None)

    if not _is_column_attr_v2(column):
        raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
    yield cast(InstrumentedAttribute[Any], column)

237 

238 

# Operators accepted after ``__`` in a filter parameter name. "eq" is accepted
# as an explicit spelling of the plain ``field=value`` equality filter.
_FILTER_OPERATORS_V2 = frozenset(
    {"eq", "ne", "gt", "gte", "lt", "lte", "contains", "isnull"}
)


def apply_filtering_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select[Any]:
    """
    Apply filtering using direct field names and __suffixes for range/null filtering.

    Every parameter other than ``page``, ``page_size`` and ``order_by`` is a
    filter, and all resulting clauses are ANDed onto ``select_query``.

    Relationships referenced by dotted names (``author.name=...``) are joined
    exactly once each, in the order they are first encountered. The order
    matters for multi-level paths such as ``a.b.c``: the join to ``a`` must
    come before the join to ``b``.

    Raises:
        HTTPException: 400 for unknown attributes or operators, or values that
            fail validation against ``schema_cls``.
    """
    standard_filters, standard_joins = _apply_standard_parameters_v2(
        query_params, select_query, model, schema_cls
    )
    suffix_filters, suffix_joins = _apply_suffix_parameters_v2(
        query_params, select_query, model, schema_cls
    )

    # dict.fromkeys de-duplicates while keeping first-seen (path) order.
    for join in dict.fromkeys([*standard_joins, *suffix_joins]):
        select_query = select_query.join(join)

    filters: dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]] = defaultdict(
        list
    )
    for partial_filters in (standard_filters, suffix_filters):
        for column, clauses in partial_filters.items():
            filters[column].extend(clauses)

    for clauses in filters.values():
        and_clause = clauses[0] if len(clauses) == 1 else sqlalchemy.and_(*clauses)
        select_query = select_query.where(and_clause)
    return select_query


def _apply_standard_parameters_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]],
    list[InstrumentedAttribute[Any]],
]:
    """
    Handle standard field parameters (no suffix).

    ``field=a,b`` becomes ``field == a OR field == b``. Repeating a parameter
    adds another clause for the same column, which the caller ANDs.
    ``select_query`` is not used here.

    Returns:
        A mapping of column to clauses, and the relationships to join
        (de-duplicated, in first-seen order).
    """
    filters: dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]] = defaultdict(
        list
    )
    joins: dict[InstrumentedAttribute[Any], None] = {}

    for key, raw_value in query_params.multi_items():
        if key in ("page", "page_size", "order_by") or "__" in key:
            continue

        column_joins, column = _get_sqlalchemy_column_v2(model, key, schema_cls)
        joins.update(dict.fromkeys(column_joins))

        parser = functools.partial(_parse_value_v2, schema_cls, key)
        clauses = [
            _make_where_clause_v2(column, v, "eq", parser) for v in raw_value.split(",")
        ]
        or_clause = clauses[0] if len(clauses) == 1 else sqlalchemy.or_(*clauses)
        filters[column].append(or_clause)

    return filters, list(joins)


def _apply_suffix_parameters_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]],
    list[InstrumentedAttribute[Any]],
]:
    """
    Handle parameters with __suffixes (gte, lte, gt, lt, ne, isnull, contains).

    * ``isnull`` takes a boolean and becomes ``IS NULL`` / ``IS NOT NULL``.
    * ``contains`` splits its value on whitespace and ORs case-insensitive
      substring matches.
    * ``ne`` splits on commas and ANDs the inequalities, so every listed value
      is excluded.
    * The range operators (and ``eq``) split on commas and OR the comparisons.

    ``select_query`` is not used here.

    Returns:
        A mapping of column to clauses, and the relationships to join
        (de-duplicated, in first-seen order).

    Raises:
        HTTPException: 400 for an unknown operator, an invalid ``isnull``
            value, or an unresolvable column.
    """
    filters: dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]] = defaultdict(
        list
    )
    joins: dict[InstrumentedAttribute[Any], None] = {}

    for key, raw_value in query_params.multi_items():
        if key in ("page", "page_size", "order_by") or "__" not in key:
            continue

        column_name, op = key.split("__", 1)
        column_joins, column = _get_sqlalchemy_column_v2(model, column_name, schema_cls)
        if op not in _FILTER_OPERATORS_V2:
            # Reject typos such as ``price__gteq`` instead of letting them
            # degrade into an equality filter.
            raise HTTPException(400, f"Invalid attribute in URL query: {key}")
        joins.update(dict.fromkeys(column_joins))

        if op == "isnull":
            try:
                is_null = pydantic.TypeAdapter(bool).validate_python(raw_value)
            except pydantic.ValidationError as exc:
                raise HTTPException(
                    400, f"Invalid attribute in URL query: {key}"
                ) from exc
            filters[column].append(column.is_(None) if is_null else column.isnot(None))
            continue

        parser = functools.partial(_parse_value_v2, schema_cls, column_name)
        # For contains, split by whitespace; for other operators, split by comma
        split_values = raw_value.split() if op == "contains" else raw_value.split(",")
        clauses = [_make_where_clause_v2(column, v, op, parser) for v in split_values]
        if not clauses:
            continue
        # "x != a OR x != b" would match nearly every row; exclusions must be ANDed.
        combine = sqlalchemy.and_ if op == "ne" else sqlalchemy.or_
        filters[column].append(clauses[0] if len(clauses) == 1 else combine(*clauses))

    return filters, list(joins)

371 

372 

def _parse_value_v2(schema_cls: SchemaType, column_name: str, value: str) -> Any:
    """
    Validate a raw query-string value with the schema field it targets.

    Dotted names (``author.name``) descend into the nested schema of the
    relation field. The relation segment matches a field alias first, and
    otherwise a field name, because query keys are built from aliases.

    A leaf name matches a field alias first. Otherwise it may match a field
    name, but only when the schema sets ``populate_by_name`` or that field has
    no alias. The value is then validated through pydantic assignment
    validation on an unvalidated ``model_construct()`` instance, and the
    coerced attribute is returned.

    Raises:
        HTTPException: 400 if the name does not resolve or validation fails.
    """
    fields = schema_cls.model_fields

    if "." in column_name:
        relation, _, column_part = column_name.partition(".")
        relation_field = next((f for f in fields.values() if f.alias == relation), None)
        if relation_field is None:
            relation_field = fields.get(relation)
        schema = _get_nested_schema_v2(relation_field)
        if not schema:
            raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
        return _parse_value_v2(schema, column_part, value)

    field_name = next(
        (name for name, field in fields.items() if field.alias == column_name), None
    )
    if field_name is None and column_name in fields:
        config = getattr(schema_cls, "model_config", pydantic.ConfigDict())
        if config.get("populate_by_name", False) or not fields[column_name].alias:
            field_name = column_name

    if field_name is None:
        raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")

    try:
        obj = schema_cls.__pydantic_validator__.validate_assignment(
            schema_cls.model_construct(), field_name, value
        )
        return getattr(obj, field_name)
    except Exception as exc:
        # Any validation failure is the client's fault: report it as a 400,
        # keeping the original error as the cause for debugging.
        raise HTTPException(
            400, f"Invalid attribute in URL query: {column_name}"
        ) from exc

423 

424 

def _get_nested_schema_v2(field: FieldInfo | None) -> SchemaType | None:
    """
    Return the pydantic model class stored in ``field``, if any.

    The annotation is passed through ``_unwrap_optional_annotation`` first.
    Returns None for a missing field or for any annotation that is not a
    ``pydantic.BaseModel`` subclass.
    """
    if field is not None:
        candidate = _unwrap_optional_annotation(field.annotation)
        is_model_class = isinstance(candidate, type) and issubclass(
            candidate, pydantic.BaseModel
        )
        if is_model_class:
            return candidate
    return None

432 

433 

def _make_where_clause_v2(
    column: InstrumentedAttribute[Any], filter_value: str, op: str, parser: Callable
) -> ColumnElement[Any]:
    """
    Build a single comparison clause for ``column``.

    ``op`` is one of ``gte``, ``lte``, ``gt``, ``lt``, ``ne`` or ``contains``.
    Any other value is treated as equality. All operators except ``contains``
    convert ``filter_value`` with ``parser`` before comparing. ``contains``
    escapes LIKE wildcards in the raw string and emits a case-insensitive
    ``ILIKE '%value%'`` with backslash as the escape character.
    """
    if op == "contains":
        pattern = f"%{_escape_like_value(filter_value)}%"
        return column.ilike(pattern, escape="\\")

    comparators = {
        "gte": lambda lhs, rhs: lhs >= rhs,
        "lte": lambda lhs, rhs: lhs <= rhs,
        "gt": lambda lhs, rhs: lhs > rhs,
        "lt": lambda lhs, rhs: lhs < rhs,
        "ne": lambda lhs, rhs: lhs != rhs,
    }
    compare = comparators.get(op, lambda lhs, rhs: lhs == rhs)
    return compare(column, parser(filter_value))