Coverage for fastapi_restly / query / _v1.py: 95%

189 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-02 09:54 +0000

1import functools 

2from collections import defaultdict 

3from typing import Any, Callable, Iterator, Optional, cast 

4 

5import pydantic 

6import sqlalchemy 

7from fastapi import HTTPException 

8from pydantic.fields import FieldInfo 

9from sqlalchemy import ColumnElement, Select 

10from sqlalchemy.orm import DeclarativeBase 

11from sqlalchemy.orm.attributes import InstrumentedAttribute 

12from sqlalchemy.orm.properties import ColumnProperty 

13from starlette.datastructures import QueryParams 

14 

15from ._shared import _escape_like_value, _unwrap_optional_annotation 

16 

17SchemaType = type[pydantic.BaseModel] 

18 

19 

def create_query_param_schema(schema_cls: SchemaType) -> SchemaType:
    """
    Build a pydantic model describing every supported URL query parameter.

    The generated model always has optional ``limit``, ``offset`` and ``sort``
    fields. For each (possibly nested, dot-separated) field of ``schema_cls``
    an optional ``filter[<name>]`` field is added; fields annotated as ``str``
    additionally get an optional ``contains[<name>]`` field.

    The model is named ``"QueryParam" + schema_cls.__name__``.
    """
    param_fields = {
        "limit": (int | None, None),
        "offset": (int | None, None),
        "sort": (str | None, None),
    }
    for field_path, field_info in _iter_fields_including_nested(schema_cls):
        param_fields[f"filter[{field_path}]"] = (
            Optional[field_info.annotation],
            None,
        )

        # Substring search is only offered for string fields.
        if _is_string_field(field_info):
            param_fields[f"contains[{field_path}]"] = (Optional[str], None)

        # TODO: Implement matching as OR-filters via ``match[<name>]`` parameters.

    return pydantic.create_model(
        "QueryParam" + schema_cls.__name__, **param_fields
    )

45 

46 

def apply_query_modifiers(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select:
    """
    Apply pagination, sorting, and filtering from URL query parameters to a SQL query.

    The three steps run in that order, each on the result of the previous one.

    Roughly follows JSONAPI query parameter families:
    https://jsonapi.org/format/#query-parameters-families

    Common examples::

        # Pagination
        limit=100&offset=200

        # Sorting
        sort=field1,-field2

        # Equality filter
        filter[foo_id]=1&filter[name]=Bob

        # OR-values
        filter[id]=1,2,3

        # Range filters
        filter[created_at]=>=2024-01-01&filter[created_at]=<2025-01-01

        # NULL checks
        filter[foo_id]=!null

        # Case-insensitive contains
        contains[name]=john&contains[email]=example
    """
    # NOTE(review): sorting and filtering each add their own joins, so a
    # relation referenced by both ``sort`` and ``filter[...]`` is joined by
    # both steps -- confirm this produces valid SQL.
    paginated = apply_pagination(query_params, select_query)
    ordered = apply_sorting(query_params, paginated, model)
    return apply_filtering(query_params, ordered, model, schema_cls)

86 

87 

def _iter_fields_including_nested(
    schema_cls: SchemaType, prefix: str = ""
) -> Iterator[tuple[str, FieldInfo]]:
    """
    Yield ``(dotted_name, field)`` for every leaf field of ``schema_cls``.

    Fields whose annotation is itself a pydantic model are not yielded;
    instead their own fields are yielded recursively, with names prefixed by
    the parent field name and a dot (e.g. ``"blog.user.name"``).
    """
    # NOTE(review): there is no cycle guard; a self-referencing schema would
    # presumably recurse without end -- confirm such schemas are not used here.
    dotted_prefix = f"{prefix}." if prefix else ""
    for name, field in schema_cls.model_fields.items():
        full_name = dotted_prefix + name
        nested = _get_nested_schema(field)
        if not nested:
            yield full_name, field
            continue
        yield from _iter_fields_including_nested(nested, full_name)

98 

99 

def _is_string_field(field: FieldInfo) -> bool:
    """Return True when the field's annotation, with Optional unwrapped, is exactly ``str``."""
    return _unwrap_optional_annotation(field.annotation) is str

104 

105 

def apply_pagination(query_params: QueryParams, select_query: Select) -> Select:
    """
    Apply the ``limit`` and ``offset`` URL query parameters to the query.

    Either parameter may be absent, in which case it is not applied.

    Raises:
        HTTPException: 400 when a value is not an integer (raised by
            ``_get_int``) or is negative.
    """
    # ``limit`` is fully validated before ``offset`` is even parsed, so an
    # invalid limit is always the first error reported.
    requested_limit = _get_int(query_params, "limit")
    if requested_limit is not None and requested_limit < 0:
        raise HTTPException(400, "Invalid value for URL query parameter limit: must be non-negative")
    paginated = (
        select_query
        if requested_limit is None
        else select_query.limit(requested_limit)
    )

    requested_offset = _get_int(query_params, "offset")
    if requested_offset is not None and requested_offset < 0:
        raise HTTPException(400, "Invalid value for URL query parameter offset: must be non-negative")
    if requested_offset is None:
        return paginated
    return paginated.offset(requested_offset)

118 

119 

def _get_int(query_params: QueryParams, param_name: str) -> int | None:
    """
    Read an integer URL query parameter.

    Returns:
        The parsed integer, or ``None`` when the parameter is missing or empty.

    Raises:
        HTTPException: 400 when the value cannot be parsed by ``int()``.
    """
    value = query_params.get(param_name)
    if not value:
        # Missing and empty ("limit=") are both treated as "not supplied".
        return None
    try:
        return int(value)
    except ValueError as exc:
        # Chain explicitly so the original parse error stays attached as the cause.
        raise HTTPException(
            400,
            f"Invalid value for URL query parameter {param_name}: {value} is not an integer",
        ) from exc

131 

132 

def apply_sorting(
    query_params: QueryParams, select_query: Select, model: type[DeclarativeBase]
) -> Select:
    """
    Apply the ``sort`` URL query parameter to the query.

    ``sort`` is a comma-separated list of (optionally dot-separated) column
    paths; a leading ``-`` requests descending order. Relationships along a
    path are joined. Without a ``sort`` parameter the query is ordered by
    ``model.id`` when the model has such an attribute, and left unordered
    otherwise.

    Raises:
        HTTPException: 400 when a path does not resolve to a column
            (raised while resolving the column).
    """
    sort_string = query_params.get("sort")
    if not sort_string:
        # Try to apply a default ordering
        id_column = getattr(model, "id", None)
        # TODO: Maybe check if this is a UUID and dont sort in that case?
        # Compare against None explicitly: the truth value of ORM attributes
        # and SQL expressions is not a reliable presence test.
        if id_column is None:
            return select_query  # Unordered
        return select_query.order_by(id_column)

    # Relationships already joined for an earlier sort key; joining the same
    # relationship again would add a duplicate JOIN to the statement.
    joined: set[InstrumentedAttribute] = set()
    for column_name in sort_string.split(","):
        order = sqlalchemy.asc
        if column_name.startswith("-"):
            order = sqlalchemy.desc
            column_name = column_name[1:]
        joins, column = _get_sqlalchemy_column(model, column_name)
        for join in joins:
            if join in joined:
                continue
            joined.add(join)
            select_query = select_query.join(join)
        select_query = select_query.order_by(order(column))
    return select_query

156 

157 

def _get_sqlalchemy_column(
    model: type[DeclarativeBase], column_path: str
) -> tuple[list[InstrumentedAttribute], InstrumentedAttribute]:
    """Return ``(joins, column)`` for a column path; delegates to ``_resolve_sqlalchemy_column``."""
    return _resolve_sqlalchemy_column(model, column_path)

164 

165 

def _resolve_sqlalchemy_column(
    model: type[DeclarativeBase], column_name: str
) -> tuple[list[InstrumentedAttribute], InstrumentedAttribute]:
    """
    Resolve a dot-separated column path to its SQLAlchemy column.

    Every path segment except the last must name a relationship on the model
    reached so far; the last segment must name a column on the final model.

    Returns a tuple of (joins, column) where joins lists the relationship
    attributes to join, in path order, and column is the InstrumentedAttribute
    of the final column.

    Example:
        For column_name="upload.created_by.email" the joins are the
        ``upload`` relationship of ``model`` followed by the ``created_by``
        relationship of the upload model, and the column is ``email`` on the
        model that ``created_by`` points to.

    Raises:
        HTTPException: 400 when a segment is not a relationship, or the last
            segment is not a mapped column.
    """
    *relation_names, leaf_name = column_name.split(".")
    segments = [*relation_names, leaf_name]

    joins = []
    current_model = model
    for index, relation in enumerate(relation_names):
        rel = getattr(current_model, relation, None)
        is_relationship = isinstance(rel, InstrumentedAttribute) and hasattr(
            rel.property, "mapper"
        )
        if not is_relationship:
            # Report the part of the path that was still unresolved.
            unresolved = ".".join(segments[index:])
            raise HTTPException(400, f"Invalid attribute in URL query: {unresolved}")

        joins.append(rel)
        # Continue resolving on the model the relationship points to.
        current_model = rel.property.mapper.class_

    column = getattr(current_model, leaf_name, None)
    is_column = isinstance(column, InstrumentedAttribute) and isinstance(
        column.property, ColumnProperty
    )
    if not is_column:
        raise HTTPException(400, f"Invalid attribute in URL query: {leaf_name}")

    return joins, cast(InstrumentedAttribute, column)

210 

211 

def apply_filtering(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> Select:
    """
    Apply filtering through URL query parameters on a SQL query.

    Collects the clauses produced by ``filter[...]`` and ``contains[...]``
    parameters, joins every relationship they reference (each one once, with
    parent relationships before the relationships reached through them) and
    adds one WHERE condition per filtered column; several clauses on the same
    column are combined with AND.

    Raises:
        HTTPException: 400 when a parameter names an invalid attribute.
    """
    filter_filters, filter_joins = _apply_filter_parameters(
        query_params, select_query, model, schema_cls
    )
    contains_filters, contains_joins = _apply_contains_parameters(
        query_params, select_query, model, schema_cls
    )

    # A dict is used as an insertion-ordered set: iterating a plain set could
    # join a nested relationship before the relationship leading to it.
    ordered_joins: dict[InstrumentedAttribute, None] = {}
    for join in (*filter_joins, *contains_joins):
        ordered_joins.setdefault(join, None)
    for join in ordered_joins:
        select_query = select_query.join(join)

    # Merge the clauses of both parameter families per column.
    filters: dict[InstrumentedAttribute, list[ColumnElement]] = defaultdict(list)
    for partial_filters in (filter_filters, contains_filters):
        for column, clauses in partial_filters.items():
            filters[column].extend(clauses)

    for or_clauses in filters.values():
        if len(or_clauses) == 1:
            and_clause = or_clauses[0]
        else:
            and_clause = sqlalchemy.and_(*or_clauses)  # CNF > DNF
        select_query = select_query.where(and_clause)

    return select_query


def _apply_filter_parameters(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute, list[ColumnElement]], list[InstrumentedAttribute]
]:
    """
    Handle filter[field] parameters.

    Each occurrence of a ``filter[<path>]`` parameter yields one clause for
    the resolved column; comma-separated values inside one occurrence are
    OR-ed together. ``select_query`` is accepted but not used.

    Returns:
        ``(filters, joins)``: clauses grouped per column, and the
        relationships to join, de-duplicated and in the order they were
        first encountered.
    """
    filters: dict[InstrumentedAttribute, list[ColumnElement]] = defaultdict(list)
    joins: dict[InstrumentedAttribute, None] = {}  # insertion-ordered set

    for key, raw_value in query_params.multi_items():
        if not (key.startswith("filter[") and key.endswith("]")):
            continue
        column_name = key[len("filter[") : -1]
        column_joins, column = _get_sqlalchemy_column(model, column_name)
        joins.update(dict.fromkeys(column_joins))

        # Create a parser/validator for the filter values. Which is user input after all.
        parser = functools.partial(_parse_value, schema_cls, column_name)
        split_values = raw_value.split(",")  # Filtering on empty strings is also OK
        clauses = [_make_where_clause(column, v, parser) for v in split_values]
        if len(clauses) == 1:
            or_clause = clauses[0]
        else:
            or_clause = sqlalchemy.or_(*clauses)
        filters[column].append(or_clause)

    return filters, list(joins)


def _apply_contains_parameters(
    query_params: QueryParams,
    select_query: Select,
    model: type[DeclarativeBase],
    schema_cls: SchemaType,
) -> tuple[
    dict[InstrumentedAttribute, list[ColumnElement]], list[InstrumentedAttribute]
]:
    """
    Handle contains[field] parameters for string fields.

    The value is split on whitespace; every word becomes a case-insensitive
    ``ILIKE '%word%'`` clause (with LIKE wildcards escaped) and the words are
    AND-ed together. ``select_query`` and ``schema_cls`` are accepted but not
    used.

    Returns:
        ``(filters, joins)``: clauses grouped per column, and the
        relationships to join, de-duplicated and in the order they were
        first encountered.
    """
    filters: dict[InstrumentedAttribute, list[ColumnElement]] = defaultdict(list)
    joins: dict[InstrumentedAttribute, None] = {}  # insertion-ordered set

    for key, raw_value in query_params.multi_items():
        if not (key.startswith("contains[") and key.endswith("]")):
            continue
        column_name = key[len("contains[") : -1]
        column_joins, column = _get_sqlalchemy_column(model, column_name)
        joins.update(dict.fromkeys(column_joins))

        # The value is used as plain text, so no schema parsing is needed.
        # ``str.split()`` without arguments never yields empty or
        # whitespace-only items, so no extra filtering is required.
        clauses = [
            column.ilike(f"%{_escape_like_value(word)}%", escape="\\")
            for word in raw_value.split()
        ]
        if not clauses:
            continue
        if len(clauses) == 1:
            clause = clauses[0]
        else:
            clause = sqlalchemy.and_(*clauses)
        filters[column].append(clause)

    return filters, list(joins)

326 

327 

def _parse_value(schema_cls: SchemaType, column_name: str, value: str) -> Any:
    """
    Parse and validate a raw filter value against the schema field it targets.

    ``column_name`` may be a dot-separated path (e.g. "blog.user.name"); each
    leading segment must be a field whose annotation is a nested pydantic
    model, and the value is validated against the final field.

    Raises:
        HTTPException: 400 when a leading path segment is not a nested schema.
    """
    target_schema = schema_cls
    remaining_path = column_name

    # Walk down the nested schemas until only the leaf field name is left.
    while "." in remaining_path:
        relation, _, rest = remaining_path.partition(".")
        nested_schema = _get_nested_schema(target_schema.model_fields.get(relation))
        if not nested_schema:
            raise HTTPException(400, f"Invalid attribute in URL query: {remaining_path}")
        target_schema = nested_schema
        remaining_path = rest

    # Hacky stuff to validate (i.e. parse) a single field.
    # https://github.com/pydantic/pydantic/discussions/7367
    unvalidated_instance = target_schema.model_construct()
    validated = target_schema.__pydantic_validator__.validate_assignment(
        unvalidated_instance, remaining_path, value
    )
    return getattr(validated, remaining_path)

346 

347 

def _get_nested_schema(field: FieldInfo | None) -> SchemaType | None:
    """
    Return the pydantic model a field is annotated with, if any.

    The annotation is unwrapped with ``_unwrap_optional_annotation`` first.
    Returns ``None`` when ``field`` is ``None`` or the unwrapped annotation is
    not a ``pydantic.BaseModel`` subclass.
    """
    if field is None:
        return None

    annotation = _unwrap_optional_annotation(field.annotation)
    # isinstance guard first: issubclass() raises for non-class annotations.
    is_model_class = isinstance(annotation, type) and issubclass(
        annotation, pydantic.BaseModel
    )
    return annotation if is_model_class else None

358 

359 

def _make_where_clause(
    column: InstrumentedAttribute, filter_value: str, parser: Callable
) -> ColumnElement:
    """
    Translate one raw filter value into a SQL comparison on ``column``.

    Supported prefixes: ``>=``, ``<=``, ``>``, ``<`` (range comparisons) and
    ``!`` (negation). Without a prefix the value is compared for equality.
    The literal ``null`` (optionally negated) becomes an IS NULL / IS NOT NULL
    check and is not passed to ``parser``; every other operand is.
    """
    # Two-character operators come first so ">=" is not mistaken for ">".
    range_comparisons = (
        (">=", lambda parsed: column >= parsed),
        ("<=", lambda parsed: column <= parsed),
        (">", lambda parsed: column > parsed),
        ("<", lambda parsed: column < parsed),
    )
    for prefix, compare in range_comparisons:
        if filter_value.startswith(prefix):
            return compare(parser(filter_value[len(prefix):]))

    negated = filter_value.startswith("!")
    operand = filter_value[1:] if negated else filter_value

    if operand == "null":
        return column.isnot(None) if negated else column.is_(None)

    parsed_operand = parser(operand)
    if negated:
        return column != parsed_operand
    return column == parsed_operand