Coverage for fastapi_restly / views / _sync.py: 98%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-05 09:15 +0000

1from typing import Any, Sequence, TypeVar 

2 

3import fastapi 

4import sqlalchemy 

5from sqlalchemy import func, select 

6from sqlalchemy.orm import DeclarativeBase, Session 

7 

8from ..db import SessionDep 

9from ..query import apply_query_modifiers, use_query_modifier_version 

10from ..schemas import ( 

11 BaseSchema, 

12 IDSchema, 

13 get_writable_inputs, 

14 is_readonly_field, 

15 resolve_ids_to_sqlalchemy_objects, 

16) 

17from ._base import BaseRestView, _accepts_init_kwarg, delete, get, patch, post 

18 

# Type variable bound to SQLAlchemy declarative models, so that
# make_new_object's return type matches the concrete model class passed in.
T = TypeVar("T", bound=DeclarativeBase)

20 

21 

def make_new_object(
    session: Session,
    model_cls: type[T],
    schema_obj: BaseSchema,
    schema_cls: type[BaseSchema] | None = None,
) -> T:
    """Construct a new ``model_cls`` instance from ``schema_obj`` and add it to ``session``.

    ID references inside ``schema_obj`` are first resolved in place via
    ``resolve_ids_to_sqlalchemy_objects``. Every field yielded by iterating
    ``schema_obj`` is then turned into a constructor keyword argument:

    * fields that ``is_readonly_field`` reports as read-only on ``schema_cls``
      are skipped (no such filtering happens when ``schema_cls`` is ``None``);
    * a field whose name ends in ``_id`` and holds an ``IDSchema`` contributes
      that schema's ``.id``;
    * a field whose name ends in ``_id`` and holds an ORM object contributes the
      object's ``.id``, and additionally the object itself under the field name
      minus ``_id`` when the model has that attribute and ``_accepts_init_kwarg``
      allows it as a constructor keyword;
    * any other field is passed through unchanged.

    The new instance is added to the session but not flushed.

    Returns:
        The newly constructed model instance.
    """
    resolve_ids_to_sqlalchemy_objects(session, schema_obj)

    kwargs = {}
    for name, field_value in schema_obj:
        # Read-only schema fields must never be fed into the model constructor.
        if schema_cls is not None and is_readonly_field(schema_cls, name):
            continue

        is_fk_field = name.endswith("_id")
        if is_fk_field and isinstance(field_value, IDSchema):
            kwargs[name] = field_value.id
        elif is_fk_field and isinstance(field_value, DeclarativeBase):
            kwargs[name] = field_value.id
            # Also hand over the related object itself under the relationship
            # name ("owner_id" -> "owner") when the constructor accepts it.
            rel_name = name[: -len("_id")]
            if hasattr(model_cls, rel_name) and _accepts_init_kwarg(model_cls, rel_name):
                kwargs[rel_name] = field_value
        else:
            kwargs[name] = field_value

    new_obj = model_cls(**kwargs)
    session.add(new_obj)
    return new_obj

47 

48 

def update_object(
    session: Session,
    obj: DeclarativeBase,
    schema_obj: BaseSchema,
    schema_cls: type[BaseSchema] | None = None,
) -> DeclarativeBase:
    """Copy the writable inputs of ``schema_obj`` onto the existing ORM object ``obj``.

    ID references inside ``schema_obj`` are first resolved in place via
    ``resolve_ids_to_sqlalchemy_objects``. The fields applied are those returned
    by ``get_writable_inputs(schema_obj, schema_cls)`` (presumably the fields the
    client is allowed to write; confirm against that helper). For each one:

    * a name not ending in ``_id`` is assigned as-is;
    * an ``_id`` field holding an ``IDSchema`` is assigned that schema's ``.id``;
    * an ``_id`` field holding an ORM object is assigned the object's ``.id``, and
      the object itself is also assigned to the attribute named by the field
      minus ``_id`` when ``obj`` has such an attribute;
    * any other ``_id`` value is assigned as-is.

    Nothing is flushed here.

    Returns:
        The same ``obj``, mutated in place.
    """
    resolve_ids_to_sqlalchemy_objects(session, schema_obj)
    writable = get_writable_inputs(schema_obj, schema_cls)

    for name, field_value in writable.items():
        if not name.endswith("_id"):
            setattr(obj, name, field_value)
            continue

        if isinstance(field_value, IDSchema):
            setattr(obj, name, field_value.id)
        elif isinstance(field_value, DeclarativeBase):
            setattr(obj, name, field_value.id)
            # Keep the relationship attribute ("owner") in step with its FK ("owner_id").
            rel_name = name[: -len("_id")]
            if hasattr(obj, rel_name):
                setattr(obj, rel_name, field_value)
        else:
            setattr(obj, name, field_value)

    return obj

68 

69 

def save_object(session: Session, obj: DeclarativeBase) -> DeclarativeBase:
    """Flush pending changes and reload ``obj`` from the database.

    No commit is issued here; the changes stay inside the session's current
    transaction.

    Returns:
        The same ``obj``, refreshed.
    """
    # Send pending INSERT/UPDATE statements so the row exists in the database...
    session.flush()
    # ...then re-read its attributes, picking up database-generated values.
    session.refresh(obj)
    return obj

74 

75 

class RestView(BaseRestView):
    """
    Synchronous CRUD/REST view exposing a SQLAlchemy model over HTTP.

    Basic usage::

        class FooView(RestView):
            prefix = "/foo"
            schema = FooSchema
            model = Foo

    ``Foo`` is a SQLAlchemy model and ``FooSchema`` a Pydantic model. The route
    handlers delegate to overridable ``on_*`` hooks. Those hooks use the
    overridable helpers ``make_new_object``, ``update_object``, ``save_object``
    and ``delete_object``.
    """

    # NOTE: the route decorators ``get``/``post``/``patch``/``delete`` used in
    # this class body are the module-level imports. A decorator expression is
    # evaluated before the method it decorates is bound, so each decorator must
    # stay above its same-named method.
    session: SessionDep  # type: ignore[reportIncompatibleVariableOverride]

    # GET "/": list objects, optionally wrapped with pagination metadata.
    # (Route handlers carry comments rather than docstrings, because FastAPI
    # would publish a docstring as the endpoint's OpenAPI description.)
    @get("/")
    def index(self, query_params: Any) -> Any:
        objs = self.on_list(query_params)
        if self.include_pagination_metadata:
            total = self.count_index(query_params)
            return self._build_pagination_payload(query_params, objs, total)
        return [self.to_response_schema(item) for item in objs]

    def on_list(
        self,
        query_params: Any,
        query: sqlalchemy.Select[Any] | None = None,
    ) -> Sequence[Any]:
        """
        Handle a GET request on "/" and return the list of matching objects.

        ``query`` may be passed to narrow the selection; it defaults to
        ``select(self.model)``. Relationship loader options and the query
        modifiers derived from ``query_params`` are applied on top of it.
        Feel free to override this method, e.g.::

            def on_list(self, query_params, query=None):
                query = make_my_query()
                objs = super().on_list(query_params, query)
                return add_my_info(objs)
        """
        stmt = sqlalchemy.select(self.model) if query is None else query
        options = self.get_relationship_loader_options()
        if options:
            stmt = stmt.options(*options)
        params = self._to_query_params(query_params)
        with use_query_modifier_version(self.get_query_modifier_version()):
            stmt = apply_query_modifiers(params, stmt, self.model, self.schema)
        return self.session.scalars(stmt).all()

    def count_index(self, query_params: Any) -> int:
        """
        Count the rows matched by ``query_params``, ignoring ordering and pagination.

        The query modifiers are applied to a plain ``select(self.model)``. A
        custom ``query`` given to ``on_list`` is therefore not reflected here.
        """
        params = self._to_query_params(query_params)
        with use_query_modifier_version(self.get_query_modifier_version()):
            filtered = apply_query_modifiers(
                params, sqlalchemy.select(self.model), self.model, self.schema
            )
        # Drop ORDER BY / LIMIT / OFFSET so the count covers every matching row.
        unpaginated = filtered.order_by(None).limit(None).offset(None)
        total = self.session.scalar(
            select(func.count()).select_from(unpaginated.subquery())
        )
        return int(total or 0)

    # GET "/{id}": return a single serialized object.
    @get("/{id}")
    def get(self, id: Any) -> Any:
        return self.to_response_schema(self.on_get(id))

    def on_get(self, id: Any) -> Any:
        """
        Handle a GET request on "/{id}" and return the object with that primary key.

        Relationship loader options are applied to the lookup. Raises
        ``fastapi.HTTPException(404)`` when no such object exists.
        Feel free to override this method.
        """
        options = self.get_relationship_loader_options()
        found = self.session.get(self.model, id, options=options)
        if found is None:
            raise fastapi.HTTPException(404)
        return found

    # POST "/": create an object from the request body and return it serialized.
    @post("/")
    def post(
        self, schema_obj: BaseSchema
    ) -> Any:  # schema_obj type is set in before_include_view
        return self.to_response_schema(self.on_create(schema_obj))

    def on_create(self, schema_obj: BaseSchema) -> Any:
        """
        Handle a POST request on "/" by building a new object with
        ``make_new_object`` and persisting it with ``save_object``.
        Feel free to override this method.
        """
        return self.save_object(self.make_new_object(schema_obj))

    # PATCH "/{id}": partially update an object and return it serialized.
    @patch("/{id}")
    def patch(self, id: Any, schema_obj: BaseSchema) -> Any:
        return self.to_response_schema(self.on_update(id, schema_obj))

    def on_update(self, id: Any, schema_obj: BaseSchema) -> Any:
        """
        Handle a PATCH request on "/{id}".

        Fetches the object via ``on_get`` (404 if missing), applies
        ``schema_obj`` with ``update_object`` and persists the result with
        ``save_object``. Feel free to override this method.
        """
        existing = self.on_get(id)
        updated = self.update_object(existing, schema_obj)
        return self.save_object(updated)

    # DELETE "/{id}": the response is produced by on_delete.
    @delete("/{id}")
    def delete(self, id: Any) -> fastapi.Response:
        return self.on_delete(id)

    def on_delete(self, id: Any) -> fastapi.Response:
        """
        Handle a DELETE request on "/{id}".

        Looks the object up via ``on_get`` (so a missing id yields a 404),
        removes it with ``delete_object`` and returns an empty 204 response.
        Feel free to override this method.
        """
        target = self.on_get(id)
        self.delete_object(target)
        return fastapi.Response(status_code=204)

    def delete_object(self, obj: DeclarativeBase) -> None:
        """
        Delete ``obj`` through the session and flush; no commit is issued here.
        Feel free to override this method.
        """
        self.session.delete(obj)
        self.session.flush()

    def make_new_object(self, schema_obj: BaseSchema) -> DeclarativeBase:
        """
        Build a new ``self.model`` instance from ``schema_obj``. This delegates
        to the module-level ``make_new_object``, with ``self.schema`` used for
        read-only filtering. Feel free to override this method.
        """
        return make_new_object(self.session, self.model, schema_obj, self.schema)

    def update_object(self, obj: DeclarativeBase, schema_obj: BaseSchema) -> DeclarativeBase:
        """
        Apply ``schema_obj`` onto ``obj``. This delegates to the module-level
        ``update_object``, with ``self.schema`` passed as the schema class.
        Feel free to override this method.
        """
        return update_object(self.session, obj, schema_obj, self.schema)

    def save_object(self, obj: DeclarativeBase) -> DeclarativeBase:
        """
        Persist ``obj`` by delegating to the module-level ``save_object``
        (flush, then refresh). Feel free to override this method.
        """
        return save_object(self.session, obj)