Coverage for fastapi_restly / query / _v2.py: 92%
252 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
1import functools
2from collections import defaultdict
3from typing import Any, Callable, Iterator, Optional, cast
5import pydantic
6import sqlalchemy
7from fastapi import HTTPException
8from pydantic.fields import FieldInfo
9from sqlalchemy import ColumnElement, Select
10from sqlalchemy.orm import DeclarativeBase
11from sqlalchemy.orm.attributes import InstrumentedAttribute
12from sqlalchemy.orm.properties import ColumnProperty
13from starlette.datastructures import QueryParams
15from ._shared import _escape_like_value, _unwrap_optional_annotation
# Alias for a pydantic model *class* (not an instance). Functions in this
# module take one as ``schema_cls`` and read its ``model_fields`` to decide
# which query parameters exist and how to parse their values.
SchemaType = type[pydantic.BaseModel]
def _is_string_field_v2(field: FieldInfo) -> bool:
    """Return True when the field's annotation, with Optional removed, is exactly ``str``."""
    return _unwrap_optional_annotation(field.annotation) is str
def create_query_param_schema_v2(schema_cls: SchemaType) -> SchemaType:
    """
    Build a pydantic model listing every query parameter of the v2 interface.

    The generated model (named ``QueryParamV2<SchemaName>``) contains ``page``,
    ``page_size`` and ``order_by``. For each field of ``schema_cls``
    (nested models flattened to dotted names), it also contains:

    * the plain name (equality filter);
    * ``__gte``, ``__lte``, ``__gt``, ``__lt`` and ``__ne`` with the field's type;
    * ``__isnull`` as a boolean;
    * ``__contains`` as a string, for ``str`` fields only.

    Every parameter is optional and defaults to ``None``.
    """
    fields: dict[str, Any] = {
        "page": (Optional[int], None),
        "page_size": (Optional[int], None),
        "order_by": (Optional[str], None),
    }
    for name, field in _iter_fields_including_nested_v2(schema_cls):
        optional_type = Optional[_get_field_type_for_schema(field)]
        fields[name] = (optional_type, None)
        for suffix in ("__gte", "__lte", "__gt", "__lt", "__ne"):
            fields[name + suffix] = (optional_type, None)
        fields[f"{name}__isnull"] = (Optional[bool], None)
        if _is_string_field_v2(field):
            fields[f"{name}__contains"] = (Optional[str], None)

    model_name = "QueryParamV2" + schema_cls.__name__
    return pydantic.create_model(model_name, **fields)  # type: ignore
def apply_query_modifiers_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select[Any]:
    """
    Apply filtering, then sorting, then pagination from URL query parameters.

    Supported parameters::

        # Pagination
        page=2&page_size=50

        # Sorting
        order_by=name,-created_at

        # Filtering
        name=Bob&status=active&created_at__gte=2024-01-01

        # Contains (string fields)
        name__contains=john&email__contains=example

    NOTE(review): filtering and sorting each call ``select_query.join`` for the
    relations they traverse, so a relation used by both a filter and a sort
    key may be joined twice. Confirm how SQLAlchemy handles that.
    """
    filtered = apply_filtering_v2(query_params, select_query, model, schema_cls)
    ordered = apply_sorting_v2(query_params, filtered, model, schema_cls)
    return apply_pagination_v2(query_params, ordered)
def apply_pagination_v2(
    query_params: QueryParams, select_query: Select[Any]
) -> Select[Any]:
    """
    Apply LIMIT/OFFSET pagination from the ``page`` and ``page_size`` parameters.

    ``page`` is 1-based and defaults to 1. ``page_size`` defaults to 100. A
    parameter that is absent or empty uses its default.

    Raises:
        HTTPException: 400 if either value is not an integer, if ``page`` is
            less than 1, or if ``page_size`` is not positive.
    """
    page = _get_int_v2(query_params, "page")
    if page is None:
        page = 1
    # Compare against None instead of using ``or``: with ``or``, an explicit
    # page=0 silently became page 1 and this guard could never see it.
    if page < 1:
        raise HTTPException(400, "Invalid value for URL query parameter page: must be >= 1")

    page_size = _get_int_v2(query_params, "page_size")
    if page_size is None:
        page_size = 100
    # Same reasoning: an explicit page_size=0 must be rejected, not replaced by 100.
    if page_size <= 0:
        raise HTTPException(400, "Invalid value for URL query parameter page_size: must be > 0")

    offset = (page - 1) * page_size
    return select_query.limit(page_size).offset(offset)
def _get_field_type_for_schema(field: FieldInfo) -> type:
    """
    Return the type to advertise for ``field`` in the query-parameter schema.

    Optional is stripped first. ``Any`` is passed through unchanged and plain
    classes are returned as-is. Any annotation that is not an instance of
    ``type`` is widened to ``object``.
    """
    annotation = _unwrap_optional_annotation(field.annotation)
    if annotation is not Any and not isinstance(annotation, type):
        return object
    return annotation
def _iter_fields_including_nested_v2(
    schema_cls: SchemaType,
    prefix: str = "",
    _ancestors: frozenset[type] = frozenset(),
) -> Iterator[tuple[str, FieldInfo]]:
    """
    Yield ``(query_name, field)`` for every leaf field of ``schema_cls``.

    A field whose annotation is a nested pydantic model is expanded
    recursively, and names are joined with ``.`` (e.g. ``address.city``). A
    field's alias is used in place of its attribute name when it has one.

    A nested schema that already appears on the current path (a
    self-referencing or mutually recursive schema) is skipped rather than
    expanded again. Without this guard, such schemas recursed until
    ``RecursionError``. ``_ancestors`` is internal bookkeeping for that check;
    callers should not pass it.
    """
    path = _ancestors | {schema_cls}
    for name, field in schema_cls.model_fields.items():
        field_name = field.alias or name
        full_name = f"{prefix}.{field_name}" if prefix else field_name
        nested = _get_nested_schema_v2(field)
        if nested is None:
            yield full_name, field
        elif nested not in path:
            yield from _iter_fields_including_nested_v2(nested, full_name, path)
def _get_int_v2(query_params: QueryParams, param_name: str) -> Optional[int]:
    """
    Read ``param_name`` from the query string as an integer.

    Returns None when the parameter is absent or empty.

    Raises:
        HTTPException: 400 if the value cannot be parsed by ``int()``.
    """
    value = query_params.get(param_name)
    if not value:
        return None
    try:
        return int(value)
    except ValueError as exc:
        # Chain the parse error so tracebacks show it as the cause, not as an
        # unrelated exception raised "during handling" of another.
        raise HTTPException(
            400,
            f"Invalid value for URL query parameter {param_name}: {value} is not an integer",
        ) from exc
def apply_sorting_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType | None = None,
) -> Select[Any]:
    """
    Apply ORDER BY from the ``order_by`` parameter.

    ``order_by`` is a comma-separated list of field paths. A leading ``-``
    sorts that field descending; otherwise it sorts ascending. Whitespace
    around each entry is ignored. Dotted paths (``author.name``) join the
    relations they traverse, and within this call each relation is joined
    at most once even when several sort keys go through it.

    Without ``order_by``, the query is ordered by ``model.id`` when the model
    has a non-None ``id`` attribute. Otherwise it is returned unchanged.

    Raises:
        HTTPException: 400 (raised during column resolution) for unknown fields.
    """
    sort_string = query_params.get("order_by")
    if not sort_string:
        id_column = getattr(model, "id", None)
        if id_column is not None:
            return select_query.order_by(id_column)
        return select_query

    joined: list[InstrumentedAttribute[Any]] = []
    for entry in sort_string.split(","):
        column_name = entry.strip()
        direction = sqlalchemy.asc
        if column_name.startswith("-"):
            direction = sqlalchemy.desc
            column_name = column_name[1:]
        joins, column = _get_sqlalchemy_column_v2(model, column_name, schema_cls)
        for join in joins:
            # Compare by identity: ``==`` on SQLAlchemy attributes builds SQL
            # expressions rather than testing equality.
            if not any(join is seen for seen in joined):
                joined.append(join)
                select_query = select_query.join(join)
        select_query = select_query.order_by(direction(column))
    return select_query
def _get_sqlalchemy_column_v2(
    model: type[DeclarativeBase], column_path: str, schema_cls: SchemaType | None = None
) -> tuple[list[InstrumentedAttribute[Any]], InstrumentedAttribute[Any]]:
    """
    Resolve ``column_path`` on ``model`` into ``(relations_to_join, column)``.

    Every item produced by ``_resolve_sqlalchemy_column_v2`` except the last
    is a relationship attribute to join, in path order. The last item is the
    target column attribute.
    """
    resolved = list(_resolve_sqlalchemy_column_v2(model, column_path, schema_cls))
    *relations, target = resolved
    joins = cast(list[InstrumentedAttribute[Any]], relations)
    column = cast(InstrumentedAttribute[Any], target)
    return joins, column
def _field_name_for_alias_v2(schema_cls: SchemaType, alias: str) -> str | None:
    """Return the name of the first field of ``schema_cls`` whose alias is ``alias``, else None."""
    for name, field in schema_cls.model_fields.items():
        if field.alias == alias:
            return name
    return None


def _is_column_attribute_v2(attr: Any) -> bool:
    """Return True if ``attr`` is an ``InstrumentedAttribute`` backed by a ``ColumnProperty``."""
    return isinstance(attr, InstrumentedAttribute) and isinstance(
        attr.property, ColumnProperty
    )


def _is_relationship_attribute_v2(attr: Any) -> bool:
    """Return True if ``attr`` is an ``InstrumentedAttribute`` whose property has a ``mapper``."""
    return isinstance(attr, InstrumentedAttribute) and hasattr(attr.property, "mapper")


def _resolve_sqlalchemy_column_v2(
    model: type[DeclarativeBase], column_name: str, schema_cls: SchemaType | None = None
) -> Iterator[type[DeclarativeBase] | InstrumentedAttribute[Any]]:
    """
    Walk a (possibly dotted) query field path on ``model``.

    Yields the relationship attribute for each leading dotted segment, in
    order, and finally the column attribute named by the last segment.

    Each segment is first looked up directly as a model attribute. If that
    does not yield a usable attribute and ``schema_cls`` is given, the
    segment is treated as a pydantic alias, and the aliased field's name is
    looked up on the model instead.

    When descending into a relationship, the rest of the path is resolved
    with the nested pydantic schema declared for that field, so aliases on
    nested schemas work. If no nested schema is declared, the current schema
    is passed down unchanged.

    Raises:
        HTTPException: 400 if a leading segment is not a relationship, or the
            final segment is not a column attribute.
    """
    if "." in column_name:
        relation, _, column_part = column_name.partition(".")
        attr_name = relation
        rel = getattr(model, relation, None)
        if not _is_relationship_attribute_v2(rel) and schema_cls is not None:
            # Nested query names are built from aliases, so accept an aliased relation.
            aliased = _field_name_for_alias_v2(schema_cls, relation)
            if aliased is not None:
                attr_name = aliased
                rel = getattr(model, aliased, None)
        if not _is_relationship_attribute_v2(rel):
            raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
        related_model = rel.property.mapper.class_
        nested_schema = None
        if schema_cls is not None:
            nested_schema = _get_nested_schema_v2(schema_cls.model_fields.get(attr_name))
        yield rel
        yield from _resolve_sqlalchemy_column_v2(
            related_model,
            column_part,
            nested_schema if nested_schema is not None else schema_cls,
        )
        return

    column = getattr(model, column_name, None)
    if not _is_column_attribute_v2(column) and schema_cls is not None:
        # Only an alias can point at a different model attribute. A plain
        # field name would be looked up exactly like the direct lookup above.
        field_name = _field_name_for_alias_v2(schema_cls, column_name)
        if field_name is not None:
            column = getattr(model, field_name, None)
    if not _is_column_attribute_v2(column):
        raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
    yield cast(InstrumentedAttribute[Any], column)
def _order_joins_v2(
    joins: set[InstrumentedAttribute[Any]], model: type[DeclarativeBase]
) -> list[InstrumentedAttribute[Any]]:
    """
    Order relationship joins so each one starts from an entity already in the query.

    A join is placed once the class it is bound to (``rel.class_``) is
    ``model`` or the target (``rel.property.mapper.class_``) of an earlier
    join. Joins that never become placeable are appended at the end, so every
    input join is still returned exactly once.
    """
    # Sort by string form (e.g. "Author.books") for a deterministic order,
    # because set iteration order is arbitrary.
    pending = sorted(joins, key=str)
    reachable: set[Any] = {model}
    ordered: list[InstrumentedAttribute[Any]] = []
    while pending:
        ready = [rel for rel in pending if getattr(rel, "class_", None) in reachable]
        if not ready:
            ordered.extend(pending)
            break
        # Track by id(): ``==`` on SQLAlchemy attributes builds SQL expressions.
        ready_ids = {id(rel) for rel in ready}
        for rel in ready:
            ordered.append(rel)
            reachable.add(rel.property.mapper.class_)
        pending = [rel for rel in pending if id(rel) not in ready_ids]
    return ordered


def apply_filtering_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select[Any]:
    """
    Apply WHERE clauses from plain field parameters and ``__suffix`` parameters.

    Plain parameters (``name=Bob``) become equality filters. Suffixed ones
    (``age__gte=18``, ``email__isnull=true``, ...) become the corresponding
    conditions; see ``_apply_standard_parameters_v2`` and
    ``_apply_suffix_parameters_v2``.

    Relations needed by dotted paths are joined once each. Where it can be
    determined, they are joined parent before child (see ``_order_joins_v2``).
    Clauses collected for the same column are combined with AND, and each
    column's condition is added with its own ``where`` call.
    """
    standard_filters, standard_joins = _apply_standard_parameters_v2(
        query_params, select_query, model, schema_cls
    )
    suffix_filters, suffix_joins = _apply_suffix_parameters_v2(
        query_params, select_query, model, schema_cls
    )

    # The helpers return unordered sets. A nested path such as
    # ``author.publisher.name`` needs ``author`` joined before ``publisher``.
    for join in _order_joins_v2(set(standard_joins) | set(suffix_joins), model):
        select_query = select_query.join(join)

    filters: dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]] = defaultdict(list)
    for source in (standard_filters, suffix_filters):
        for column, clauses in source.items():
            filters[column].extend(clauses)

    for clauses in filters.values():
        condition = clauses[0] if len(clauses) == 1 else sqlalchemy.and_(*clauses)
        select_query = select_query.where(condition)
    return select_query
def _apply_standard_parameters_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]],
    set[InstrumentedAttribute[Any]],
]:
    """
    Collect equality filters from query parameters that have no ``__`` suffix.

    ``page``, ``page_size`` and ``order_by`` are skipped. A comma-separated
    value (``status=a,b``) matches any of the listed values (OR). Repeating
    a parameter adds one more condition to that column's list.
    ``select_query`` is not used here.

    Returns:
        ``(filters, joins)``: the conditions grouped by column attribute, and
        the relationship attributes that must be joined.
    """
    reserved = ("page", "page_size", "order_by")
    filters: dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]] = defaultdict(list)
    joins: set[InstrumentedAttribute[Any]] = set()

    for key, raw_value in query_params.multi_items():
        if "__" in key or key in reserved:
            continue
        relations, column = _get_sqlalchemy_column_v2(model, key, schema_cls)
        joins.update(relations)

        parse = functools.partial(_parse_value_v2, schema_cls, key)
        alternatives = [
            _make_where_clause_v2(column, candidate, "eq", parse)
            for candidate in raw_value.split(",")
        ]
        if len(alternatives) == 1:
            filters[column].append(alternatives[0])
        else:
            filters[column].append(sqlalchemy.or_(*alternatives))

    return filters, joins
def _apply_suffix_parameters_v2(
    query_params: QueryParams,
    select_query: Select[Any],
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]],
    set[InstrumentedAttribute[Any]],
]:
    """
    Collect filters from ``field__op`` query parameters.

    Supported operators:

    * ``gte``, ``lte``, ``gt``, ``lt``, ``ne``, ``eq``: comparisons; a
      comma-separated value is OR-ed.
    * ``contains``: case-insensitive substring match; whitespace-separated
      words are OR-ed.
    * ``isnull``: ``true``/``false`` for IS NULL / IS NOT NULL.

    The key is split on its last ``__``, so the operator is always the final
    segment. ``select_query`` is not used here.

    Returns:
        ``(filters, joins)``: the conditions grouped by column attribute, and
        the relationship attributes that must be joined.

    Raises:
        HTTPException: 400 for an unknown operator, an unknown field, or an
            invalid value.
    """
    supported_ops = {"gte", "lte", "gt", "lt", "ne", "eq", "contains", "isnull"}
    filters: dict[InstrumentedAttribute[Any], list[ColumnElement[Any]]] = defaultdict(
        list
    )
    joins: set[InstrumentedAttribute[Any]] = set()

    for key, raw_value in query_params.multi_items():
        if key in ("page", "page_size", "order_by") or "__" not in key:
            continue

        column_name, _, op = key.rpartition("__")
        if op not in supported_ops:
            # Without this check, an unknown suffix (e.g. a typo like
            # ``age__gtee``) reaches _make_where_clause_v2's fallback branch
            # and silently becomes an equality filter.
            raise HTTPException(400, f"Invalid filter operator in URL query: {key}")

        column_joins, column = _get_sqlalchemy_column_v2(model, column_name, schema_cls)
        joins.update(column_joins)

        if op == "isnull":
            try:
                is_null = pydantic.TypeAdapter(bool).validate_python(raw_value)
            except pydantic.ValidationError as exc:
                raise HTTPException(
                    400, f"Invalid attribute in URL query: {key}"
                ) from exc
            filters[column].append(column.is_(None) if is_null else column.isnot(None))
            continue

        parser = functools.partial(_parse_value_v2, schema_cls, column_name)
        # ``contains`` splits on whitespace (str.split() drops empty parts);
        # the other operators split on commas.
        split_values = raw_value.split() if op == "contains" else raw_value.split(",")
        clauses = [_make_where_clause_v2(column, v, op, parser) for v in split_values]
        if not clauses:
            continue
        filters[column].append(
            clauses[0] if len(clauses) == 1 else sqlalchemy.or_(*clauses)
        )

    return filters, joins
def _schema_field_name_v2(schema_cls: SchemaType, query_name: str) -> str | None:
    """
    Map a name used in the query string to a field name of ``schema_cls``.

    A field whose alias equals ``query_name`` takes precedence. Otherwise the
    plain field name is accepted if either:

    * ``model_config`` sets ``populate_by_name`` or ``validate_by_name``, or
    * that particular field has no alias.

    Returns None if nothing matches.
    """
    fields = schema_cls.model_fields
    for name, field in fields.items():
        if field.alias == query_name:
            return name
    field = fields.get(query_name)
    if field is None:
        return None
    config = getattr(schema_cls, "model_config", pydantic.ConfigDict())
    by_name = config.get("populate_by_name", False) or config.get(
        "validate_by_name", False
    )
    if by_name or not field.alias:
        return query_name
    return None


def _parse_value_v2(schema_cls: SchemaType, column_name: str, value: str) -> Any:
    """
    Convert a raw query-string ``value`` into the Python value of a schema field.

    ``column_name`` may be a dotted path (``author.name``). Each leading
    segment must name a field annotated with a nested pydantic model,
    resolved by alias first and then by plain field name. The final segment
    is mapped to a field with ``_schema_field_name_v2``. The value is then
    validated with pydantic's assignment validation for that field, so the
    field's type coercion applies.

    Raises:
        HTTPException: 400 if the path does not resolve to a field or the
            value fails validation.
    """
    if "." in column_name:
        relation, _, column_part = column_name.partition(".")
        # Nested query names are built from aliases, so resolve the relation by
        # alias too. Fall back to the plain field name, as before.
        relation_field = _schema_field_name_v2(schema_cls, relation) or relation
        schema = _get_nested_schema_v2(schema_cls.model_fields.get(relation_field))
        if not schema:
            raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
        return _parse_value_v2(schema, column_part, value)

    field_name = _schema_field_name_v2(schema_cls, column_name)
    if field_name is None:
        raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")

    # NOTE(review): validation runs against an unvalidated model_construct()
    # instance. For schemas configured as frozen, pydantic may reject the
    # assignment itself. Confirm that filters still work for frozen schemas.
    try:
        obj = schema_cls.__pydantic_validator__.validate_assignment(
            schema_cls.model_construct(), field_name, value
        )
        return getattr(obj, field_name)
    except Exception as exc:
        raise HTTPException(
            400, f"Invalid attribute in URL query: {column_name}"
        ) from exc
def _get_nested_schema_v2(field: FieldInfo | None) -> SchemaType | None:
    """
    Return the pydantic model class annotated on ``field``, with Optional unwrapped.

    Returns None when ``field`` is None or its annotation is not a
    ``pydantic.BaseModel`` subclass.
    """
    if field is not None:
        candidate = _unwrap_optional_annotation(field.annotation)
        if isinstance(candidate, type) and issubclass(candidate, pydantic.BaseModel):
            return candidate
    return None
def _make_where_clause_v2(
    column: InstrumentedAttribute[Any], filter_value: str, op: str, parser: Callable
) -> ColumnElement[Any]:
    """
    Build one SQL condition comparing ``column`` with ``filter_value``.

    ``op`` is one of ``gte``, ``lte``, ``gt``, ``lt``, ``ne`` or ``contains``.
    Any other value is treated as equality. For every operator except
    ``contains``, the raw string is first converted with ``parser``.

    ``contains`` runs the raw string through ``_escape_like_value``
    (presumably escaping LIKE wildcards, with backslash as the escape
    character) and matches it with ``ILIKE '%value%'``.
    """
    if op == "contains":
        # Matched as text, so the raw value is not parsed.
        pattern = f"%{_escape_like_value(filter_value)}%"
        return column.ilike(pattern, escape="\\")

    comparisons: dict[str, Callable[[Any, Any], Any]] = {
        "gte": lambda col, val: col >= val,
        "lte": lambda col, val: col <= val,
        "gt": lambda col, val: col > val,
        "lt": lambda col, val: col < val,
        "ne": lambda col, val: col != val,
    }
    compare = comparisons.get(op, lambda col, val: col == val)
    return compare(column, parser(filter_value))