Coverage for fastapi_restly/query/_v1.py: 95%
189 statements
coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
1import functools
2from collections import defaultdict
3from typing import Any, Callable, Iterator, Optional, cast
5import pydantic
6import sqlalchemy
7from fastapi import HTTPException
8from pydantic.fields import FieldInfo
9from sqlalchemy import ColumnElement, Select
10from sqlalchemy.orm import DeclarativeBase
11from sqlalchemy.orm.attributes import InstrumentedAttribute
12from sqlalchemy.orm.properties import ColumnProperty
13from starlette.datastructures import QueryParams
15from ._shared import _escape_like_value, _unwrap_optional_annotation
# Alias for a pydantic model *class* (not an instance). Used throughout this module
# for the schema from which query parameters are derived and against which filter
# values are validated.
SchemaType = type[pydantic.BaseModel]
def create_query_param_schema(schema_cls: SchemaType) -> SchemaType:
    """
    Build a pydantic model describing every query parameter accepted for ``schema_cls``.

    The generated model always contains optional ``limit``, ``offset`` and ``sort``.
    For every leaf field of ``schema_cls`` (nested models are flattened into
    dot-joined names), it adds an optional ``filter[<field>]`` typed like that field.
    String fields additionally get an optional ``contains[<field>]`` parameter.

    The returned model is named ``"QueryParam" + schema_cls.__name__``.
    """
    fields: dict[str, Any] = {
        "limit": (int | None, None),
        "offset": (int | None, None),
        "sort": (str | None, None),
    }
    for field_path, field_info in _iter_fields_including_nested(schema_cls):
        fields[f"filter[{field_path}]"] = (Optional[field_info.annotation], None)
        if _is_string_field(field_info):
            # String fields also support case-insensitive substring search.
            fields[f"contains[{field_path}]"] = (Optional[str], None)
        # TODO: Implement matching as OR-filters via a ``match[<field>]`` parameter.
    return pydantic.create_model("QueryParam" + schema_cls.__name__, **fields)
def apply_query_modifiers(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select:
    """
    Apply pagination, sorting and filtering from URL query parameters to a SQL query.

    The parameter families roughly follow JSON:API:
    https://jsonapi.org/format/#query-parameters-families

    Examples::

        # Pagination
        limit=100&offset=200

        # Sorting (a leading "-" sorts descending)
        sort=field1,-field2

        # Equality filter
        filter[foo_id]=1&filter[name]=Bob

        # OR-values
        filter[id]=1,2,3

        # Range filters
        filter[created_at]=>=2024-01-01&filter[created_at]=<2025-01-01

        # NULL checks
        filter[foo_id]=!null

        # Case-insensitive contains
        contains[name]=john&contains[email]=example
    """
    paginated = apply_pagination(query_params, select_query)
    ordered = apply_sorting(query_params, paginated, model)
    return apply_filtering(query_params, ordered, model, schema_cls)
def _iter_fields_including_nested(
    schema_cls: SchemaType, prefix: str = ""
) -> Iterator[tuple[str, FieldInfo]]:
    """
    Yield ``(dotted_name, field_info)`` for every leaf field of ``schema_cls``.

    A field for which ``_get_nested_schema`` returns a model is not yielded itself.
    Instead, that model's fields are yielded recursively with the parent's name
    prepended (e.g. ``"author.name"``). ``prefix`` is the dotted path of the
    enclosing field and is empty at the top level.
    """
    for field_name, field_info in schema_cls.model_fields.items():
        dotted_name = field_name if not prefix else prefix + "." + field_name
        nested_schema = _get_nested_schema(field_info)
        if not nested_schema:
            yield dotted_name, field_info
            continue
        yield from _iter_fields_including_nested(nested_schema, dotted_name)
def _is_string_field(field: FieldInfo) -> bool:
    """Return True when the field's annotation, after ``_unwrap_optional_annotation``, is exactly ``str``."""
    unwrapped = _unwrap_optional_annotation(field.annotation)
    return unwrapped is str
def apply_pagination(query_params: QueryParams, select_query: Select) -> Select:
    """
    Apply the ``limit`` and ``offset`` URL query parameters to ``select_query``.

    ``limit`` is processed before ``offset``. A missing or empty parameter is
    ignored. A value ``_get_int`` cannot parse, or a negative value, is rejected
    with HTTP 400.
    """
    for param_name in ("limit", "offset"):
        amount = _get_int(query_params, param_name)
        if amount is None:
            continue
        if amount < 0:
            raise HTTPException(
                400,
                f"Invalid value for URL query parameter {param_name}: must be non-negative",
            )
        # Dispatches to Select.limit / Select.offset, which return a new Select.
        select_query = getattr(select_query, param_name)(amount)
    return select_query
120def _get_int(query_params: QueryParams, param_name: str) -> int | None:
121 value = query_params.get(param_name)
122 if not value:
123 return None
124 try:
125 return int(value)
126 except ValueError:
127 raise HTTPException(
128 400,
129 f"Invalid value for URL query parameter {param_name}: {value} is not an integer",
130 )
def apply_sorting(
    query_params: QueryParams, select_query: Select, model: type[DeclarativeBase]
) -> Select:
    """
    Apply the ``sort`` URL query parameter to ``select_query``.

    ``sort`` is a comma-separated list of (possibly dotted) attribute paths. A
    leading ``-`` sorts that key descending; otherwise the key sorts ascending.
    Relationships on dotted paths are joined, each at most once, even when several
    sort keys traverse the same relationship.

    Without a ``sort`` parameter, the query is ordered by ``model.id`` if the model
    has such an attribute, and is left unordered otherwise.

    Raises:
        HTTPException: 400 (via ``_get_sqlalchemy_column``) for an unknown path.
    """
    sort_string = query_params.get("sort")
    if not sort_string:
        # Try to apply a default ordering
        id_column = getattr(model, "id", None)
        # TODO: Maybe check if this is a UUID and don't sort in that case?
        # Compare against None explicitly rather than relying on the truthiness
        # of a SQLAlchemy attribute object.
        if id_column is not None:
            return select_query.order_by(id_column)
        return select_query  # Unordered

    joined: set[InstrumentedAttribute] = set()
    for sort_key in sort_string.split(","):
        order = sqlalchemy.asc
        column_name = sort_key
        if column_name.startswith("-"):
            order = sqlalchemy.desc
            column_name = column_name[1:]
        joins, column = _get_sqlalchemy_column(model, column_name)
        for join in joins:
            # Avoid emitting the same JOIN twice when sort keys share a
            # relationship, e.g. sort=author.name,author.email.
            if join not in joined:
                joined.add(join)
                select_query = select_query.join(join)
        select_query = select_query.order_by(order(column))
    return select_query
def _get_sqlalchemy_column(
    model: type[DeclarativeBase], column_path: str
) -> tuple[list[InstrumentedAttribute], InstrumentedAttribute]:
    """
    Look up the mapped column for a dotted ``column_path`` on ``model``.

    Thin wrapper around ``_resolve_sqlalchemy_column``. Returns ``(joins, column)``,
    where ``joins`` are the relationship attributes to join, in path order.
    """
    return _resolve_sqlalchemy_column(model, column_path)
def _resolve_sqlalchemy_column(
    model: type[DeclarativeBase], column_name: str
) -> tuple[list[InstrumentedAttribute], InstrumentedAttribute]:
    """
    Resolve a dot-separated attribute path to its SQLAlchemy column.

    Every segment except the last must name a relationship, i.e. an
    ``InstrumentedAttribute`` whose property exposes a ``mapper``. The last segment
    must name a column attribute (``ColumnProperty``) on the model reached by
    following those relationships.

    Returns:
        ``(joins, column)``. ``joins`` lists the relationship attributes to join,
        in path order. ``column`` is the final column attribute.

    Example:
        For ``column_name="upload.created_by.email"`` this returns
        ``([Model.upload, Upload.created_by], User.email)``, where ``Upload`` and
        ``User`` stand for the related model classes.

    Raises:
        HTTPException: 400, naming the full requested path, if any segment does
        not resolve as described above.
    """
    *relation_names, attribute_name = column_name.split(".")
    joins: list[InstrumentedAttribute] = []
    current_model = model

    for relation_name in relation_names:
        rel = getattr(current_model, relation_name, None)
        # Only relationships (attributes whose property has a mapper) can be traversed.
        if not isinstance(rel, InstrumentedAttribute) or not hasattr(rel.property, "mapper"):
            raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
        joins.append(rel)
        # Continue resolving on the related model class.
        current_model = rel.property.mapper.class_

    column = getattr(current_model, attribute_name, None)
    # Only plain mapped columns are accepted; relationships and non-instrumented
    # attributes are rejected.
    if not isinstance(column, InstrumentedAttribute) or not isinstance(
        column.property, ColumnProperty
    ):
        raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")

    return joins, cast(InstrumentedAttribute, column)
def apply_filtering(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select:
    """
    Apply ``filter[...]`` and ``contains[...]`` URL query parameters to a SQL query.

    Relationships needed by dotted attribute paths are joined once each, in the
    order they are first encountered. Each path lists a relationship after the one
    that leads to it, so parents are always joined before their children. All
    clauses targeting the same column are AND-ed together.
    """
    filter_filters, filter_joins = _apply_filter_parameters(
        query_params, select_query, model, schema_cls
    )
    contains_filters, contains_joins = _apply_contains_parameters(
        query_params, select_query, model, schema_cls
    )

    # Ordered de-duplication. A set would join in arbitrary (hash) order, which can
    # place a nested relationship (e.g. Upload.created_by) before the relationship
    # that brings its parent into the query (e.g. Model.upload).
    for join in _unique_in_order([*filter_joins, *contains_joins]):
        select_query = select_query.join(join)

    filters: dict[InstrumentedAttribute, list[ColumnElement]] = defaultdict(list)
    for partial_filters in (filter_filters, contains_filters):
        for column, clauses in partial_filters.items():
            filters[column].extend(clauses)

    # Every collected clause must hold, so they are AND-ed. Each filter[...] clause
    # is itself an OR over comma-separated values, which gives CNF (preferred over DNF).
    for clauses in filters.values():
        and_clause = clauses[0] if len(clauses) == 1 else sqlalchemy.and_(*clauses)
        select_query = select_query.where(and_clause)

    return select_query


def _unique_in_order(joins: list[InstrumentedAttribute]) -> list[InstrumentedAttribute]:
    """Return ``joins`` without duplicates, keeping the first occurrence of each."""
    return list(dict.fromkeys(joins))


def _apply_filter_parameters(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute, list[ColumnElement]], list[InstrumentedAttribute]
]:
    """
    Collect clauses and joins for ``filter[<path>]`` query parameters.

    A comma-separated value becomes an OR over its parts. Each part may carry an
    operator prefix, which ``_make_where_clause`` handles. Values are user input,
    so they are parsed with ``_parse_value`` against ``schema_cls``.
    ``select_query`` is accepted for signature compatibility but is not used.

    Returns:
        ``(filters, joins)``: one clause per parameter, grouped per column, and the
        relationship attributes to join, de-duplicated in first-seen order.
    """
    filters: dict[InstrumentedAttribute, list[ColumnElement]] = defaultdict(list)
    joins: dict[InstrumentedAttribute, None] = {}  # insertion-ordered set

    for key, raw_value in query_params.multi_items():
        if not (key.startswith("filter[") and key.endswith("]")):
            continue
        column_name = key[len("filter[") : -1]
        column_joins, column = _get_sqlalchemy_column(model, column_name)
        joins.update(dict.fromkeys(column_joins))

        # Create a parser/validator for the filter values, which are user input.
        parser = functools.partial(_parse_value, schema_cls, column_name)
        split_values = raw_value.split(",")  # Filtering on empty strings is also OK
        clauses = [_make_where_clause(column, v, parser) for v in split_values]
        or_clause = clauses[0] if len(clauses) == 1 else sqlalchemy.or_(*clauses)
        filters[column].append(or_clause)

    return filters, list(joins)


def _apply_contains_parameters(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute, list[ColumnElement]], list[InstrumentedAttribute]
]:
    """
    Collect clauses and joins for ``contains[<path>]`` query parameters.

    The value is split on whitespace. Every term must appear in the column as a
    case-insensitive substring (``ILIKE``, AND logic), with LIKE wildcards escaped
    by ``_escape_like_value``. A value with no terms is ignored, and its
    relationships are not joined. ``select_query`` and ``schema_cls`` are accepted
    for signature compatibility but are not used.

    Returns:
        ``(filters, joins)``: clauses grouped per column, and the relationship
        attributes to join, de-duplicated in first-seen order.
    """
    filters: dict[InstrumentedAttribute, list[ColumnElement]] = defaultdict(list)
    joins: dict[InstrumentedAttribute, None] = {}  # insertion-ordered set

    for key, raw_value in query_params.multi_items():
        if not (key.startswith("contains[") and key.endswith("]")):
            continue
        column_name = key[len("contains[") : -1]
        # Resolve even for empty values, so unknown attributes still produce a 400.
        column_joins, column = _get_sqlalchemy_column(model, column_name)

        # str.split() with no argument never yields empty or whitespace-only terms.
        clauses = [
            column.ilike(f"%{_escape_like_value(term)}%", escape="\\")
            for term in raw_value.split()
        ]
        if not clauses:
            # Nothing to match: do not add an (inner) join that would drop rows.
            continue
        joins.update(dict.fromkeys(column_joins))
        clause = clauses[0] if len(clauses) == 1 else sqlalchemy.and_(*clauses)
        filters[column].append(clause)

    return filters, list(joins)
def _parse_value(schema_cls: SchemaType, column_name: str, value: str) -> Any:
    """
    Parse and validate a raw filter value against the schema field it targets.

    ``column_name`` may be a dotted path such as ``"blog.user.name"``. Every segment
    but the last must be a field for which ``_get_nested_schema`` returns a pydantic
    model. The last segment is validated on that nested model.

    Returns:
        The value as converted by pydantic for that field.

    Raises:
        HTTPException: 400 if the path does not lead through nested schemas, or if
        pydantic rejects the value.
    """
    *relation_names, field_name = column_name.split(".")
    schema = schema_cls
    for relation_name in relation_names:
        nested = _get_nested_schema(schema.model_fields.get(relation_name))
        if not nested:
            raise HTTPException(400, f"Invalid attribute in URL query: {column_name}")
        schema = nested

    # Hacky stuff to validate (i.e. parse) a single field.
    # https://github.com/pydantic/pydantic/discussions/7367
    try:
        obj = schema.__pydantic_validator__.validate_assignment(
            schema.model_construct(), field_name, value
        )
    except pydantic.ValidationError:
        # Invalid user input is a client error (400), like a non-integer limit/offset,
        # rather than an unhandled exception.
        raise HTTPException(
            400,
            f"Invalid value for URL query attribute {column_name}: {value}",
        ) from None
    return getattr(obj, field_name)
348def _get_nested_schema(field: FieldInfo | None) -> SchemaType | None:
349 if field is None: 349 ↛ 350line 349 didn't jump to line 350 because the condition on line 349 was never true
350 return None
352 annotation = _unwrap_optional_annotation(field.annotation)
354 if isinstance(annotation, type) and issubclass(annotation, pydantic.BaseModel):
355 return annotation
357 return None
360def _make_where_clause(
361 column: InstrumentedAttribute, filter_value: str, parser: Callable
362) -> ColumnElement:
363 if filter_value.startswith(">="):
364 value = parser(filter_value[2:])
365 return column >= value
366 elif filter_value.startswith("<="):
367 value = parser(filter_value[2:])
368 return column <= value
369 elif filter_value.startswith(">"):
370 value = parser(filter_value[1:])
371 return column > value
372 elif filter_value.startswith("<"):
373 value = parser(filter_value[1:])
374 return column < value
375 elif filter_value.startswith("!"):
376 value = filter_value[1:]
377 if value == "null":
378 return column.isnot(None)
379 value = parser(value)
380 return column != value
381 else:
382 if filter_value == "null":
383 return column.is_(None)
384 value = parser(filter_value)
385 return column == value