Coverage for fastapi_restly / views / _async.py: 97%

109 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-02 09:54 +0000

1from typing import Any, Sequence 

2 

3import fastapi 

4import sqlalchemy 

5from sqlalchemy import func, select 

6from sqlalchemy.orm import DeclarativeBase 

7 

8from ..db import AsyncSessionDep 

9from ..query import apply_query_modifiers, use_query_modifier_version 

10from ..schemas import ( 

11 BaseSchema, 

12 IDSchema, 

13 async_resolve_ids_to_sqlalchemy_objects, 

14 get_writable_inputs, 

15 is_readonly_field, 

16) 

17from ._base import BaseRestView, _accepts_init_kwarg, delete, get, patch, post 

18 

19 

class AsyncRestView(BaseRestView):
    """
    Async CRUD/REST view for database objects.

    Subclass it and point it at a model/schema pair::

        class FooView(AsyncRestView):
            prefix = "/foo"
            schema = FooSchema
            model = Foo

    Here ``Foo`` is a SQLAlchemy model and ``FooSchema`` a Pydantic model.

    The route handlers (``index``, ``get``, ``post``, ``patch``, ``delete``)
    are thin wrappers; the overridable behaviour lives in the ``on_*`` hooks
    and in the ``make_new_object`` / ``update_object`` / ``save_object`` /
    ``delete_object`` helpers.
    """

    # NOTE: the handlers named get/post/patch/delete rebind those names inside
    # the class namespace once they are defined. ``index`` relies on the
    # imported ``get`` decorator, so it has to stay above the ``get`` method.
    # The handlers deliberately carry ``#`` comments instead of docstrings so
    # that no docstring-derived description is added to the routes.

    session: AsyncSessionDep

    @get("/")
    async def index(self, query_params: Any) -> Any:
        # GET "/": list objects, wrapped in pagination metadata when enabled.
        rows = await self.on_list(query_params)
        if self.include_pagination_metadata:
            total = await self.count_index(query_params)
            return self._build_pagination_payload(query_params, rows, total)
        return [self.to_response_schema(row) for row in rows]

    async def on_list(
        self,
        query_params: Any,
        query: sqlalchemy.Select[Any] | None = None,
    ) -> Sequence[Any]:
        """
        Hook behind GET "/": return the list of objects to expose.

        ``query`` lets a caller supply a pre-narrowed select; when omitted a
        plain ``select(self.model)`` is used. Relationship loader options and
        the query modifiers derived from ``query_params`` are applied on top.

        Overriding is encouraged, e.g.:

            async def on_list(self, query_params, query=None):
                query = make_my_query()
                objs = await super().on_list(query_params, query)
                return add_my_info(objs)
        """
        stmt = sqlalchemy.select(self.model) if query is None else query

        eager_options = self.get_relationship_loader_options()
        if eager_options:
            stmt = stmt.options(*eager_options)

        params = self._to_query_params(query_params)
        with use_query_modifier_version(self.get_query_modifier_version()):
            stmt = apply_query_modifiers(params, stmt, self.model, self.schema)

        result = await self.session.scalars(stmt)
        return result.all()

    async def count_index(self, query_params: Any) -> int:
        """
        Count the rows matching ``query_params``, ignoring ordering and paging.

        The count is computed over ``select(self.model)`` with the query
        modifiers applied; a custom ``query`` passed to ``on_list`` is not
        taken into account here.
        """
        params = self._to_query_params(query_params)
        with use_query_modifier_version(self.get_query_modifier_version()):
            unfiltered = sqlalchemy.select(self.model)
            filtered = apply_query_modifiers(
                params, unfiltered, self.model, self.schema
            )

        # Strip ORDER BY / LIMIT / OFFSET so the count covers every match.
        unpaged = filtered.order_by(None).limit(None).offset(None)
        count_stmt = select(func.count()).select_from(unpaged.subquery())
        total = await self.session.scalar(count_stmt)
        return int(total or 0)

    @get("/{id}")
    async def get(self, id: Any) -> Any:
        # GET "/{id}": fetch one object and serialise it.
        found = await self.on_get(id)
        return self.to_response_schema(found)

    async def on_get(self, id: Any) -> Any:
        """
        Hook behind GET "/{id}": return a single object.

        Raises ``fastapi.HTTPException(404)`` when no object has that id.
        Also used by the update and delete hooks to look up their target.
        Feel free to override this method.
        """
        eager_options = self.get_relationship_loader_options()
        found = await self.session.get(self.model, id, options=eager_options)
        if found is not None:
            return found
        raise fastapi.HTTPException(404)

    @post("/")
    async def post(self, schema_obj: BaseSchema) -> Any:
        # POST "/": create an object.
        # schema_obj type is set in before_include_view
        created = await self.on_create(schema_obj)
        return self.to_response_schema(created)

    async def on_create(self, schema_obj: BaseSchema) -> Any:
        """
        Hook behind POST "/": build a new object and persist it.

        Feel free to override this method.
        """
        fresh = await self.make_new_object(schema_obj)
        return await self.save_object(fresh)

    @patch("/{id}")
    async def patch(self, id: Any, schema_obj: BaseSchema) -> Any:
        # PATCH "/{id}": partially update an object.
        updated = await self.on_update(id, schema_obj)
        return self.to_response_schema(updated)

    async def on_update(self, id: Any, schema_obj: BaseSchema) -> Any:
        """
        Hook behind PATCH "/{id}": partially update an existing object.

        The target is looked up through ``on_get`` (so a missing id yields a
        404), then modified and persisted.
        Feel free to override this method.
        """
        target = await self.on_get(id)
        target = await self.update_object(target, schema_obj)
        return await self.save_object(target)

    @delete("/{id}")
    async def delete(self, id: Any) -> fastapi.Response:
        # DELETE "/{id}": remove an object.
        return await self.on_delete(id)

    async def on_delete(self, id: Any) -> fastapi.Response:
        """
        Hook behind DELETE "/{id}": look up the object, delete it and answer
        with an empty 204 response.
        """
        target = await self.on_get(id)
        await self.delete_object(target)
        return fastapi.Response(status_code=204)

    async def delete_object(self, obj: DeclarativeBase) -> None:
        """
        Delete ``obj`` from the database and flush the session.

        Called by ``on_delete`` after ``on_get()`` has looked the object up.
        Feel free to override this method.
        """
        db = self.session
        await db.delete(obj)
        await db.flush()

    async def make_new_object(self, schema_obj: BaseSchema) -> DeclarativeBase:
        """
        Build a new model instance from ``schema_obj`` and add it to the
        session (no flush happens here).

        Read-only fields are skipped. For a field whose name ends in ``_id``:
        an ``IDSchema`` value is reduced to its ``id``; a SQLAlchemy object is
        reduced to its ``id`` and, when the model has the matching relation
        attribute and accepts it as an init kwarg, also passed under the
        relation name.
        Feel free to override this method.
        """
        # Presumably swaps ID references on schema_obj for loaded ORM objects;
        # the DeclarativeBase branch below depends on that -- verify in schemas.
        await async_resolve_ids_to_sqlalchemy_objects(self.session, schema_obj)

        init_kwargs: dict[str, Any] = {}
        for field_name, value in schema_obj:
            if is_readonly_field(self.schema, field_name):
                continue

            if not field_name.endswith("_id"):
                init_kwargs[field_name] = value
            elif isinstance(value, IDSchema):
                init_kwargs[field_name] = value.id
            elif isinstance(value, DeclarativeBase):
                init_kwargs[field_name] = value.id
                relation_name = field_name.removesuffix("_id")
                relation_is_settable = hasattr(
                    self.model, relation_name
                ) and _accepts_init_kwarg(self.model, relation_name)
                if relation_is_settable:
                    init_kwargs[relation_name] = value
            else:
                init_kwargs[field_name] = value

        new_obj = self.model(**init_kwargs)
        self.session.add(new_obj)
        return new_obj

    async def update_object(
        self, obj: DeclarativeBase, schema_obj: BaseSchema
    ) -> DeclarativeBase:
        """
        Copy the writable inputs of ``schema_obj`` onto ``obj`` and return it.

        Only the fields returned by ``get_writable_inputs`` are touched.
        ``_id`` fields follow the same rules as in ``make_new_object``, except
        that the relation attribute is set whenever ``obj`` has it.
        Feel free to override this method.
        """
        await async_resolve_ids_to_sqlalchemy_objects(self.session, schema_obj)

        writable = get_writable_inputs(schema_obj, self.schema)
        for field_name, value in writable.items():
            if not field_name.endswith("_id"):
                setattr(obj, field_name, value)
            elif isinstance(value, IDSchema):
                setattr(obj, field_name, value.id)
            elif isinstance(value, DeclarativeBase):
                setattr(obj, field_name, value.id)
                relation_name = field_name.removesuffix("_id")
                if hasattr(obj, relation_name):
                    setattr(obj, relation_name, value)
            else:
                setattr(obj, field_name, value)
        return obj

    async def save_object(self, obj: DeclarativeBase) -> DeclarativeBase:
        """
        Flush pending changes to the database, refresh ``obj`` from it and
        return ``obj``.

        Feel free to override this method.
        """
        db = self.session
        await db.flush()
        await db.refresh(obj)
        return obj

202 return obj