Coverage for fastapi_restly / db / _session.py: 93%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-02 09:54 +0000

from collections.abc import AsyncIterator, Callable, Iterator
from contextlib import asynccontextmanager
from typing import Annotated, Any, cast

from fastapi import Depends
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession as SA_AsyncSession
from sqlalchemy.orm import Session as SA_Session
from sqlalchemy.orm import sessionmaker

from ._globals import fr_globals

12 

try:
    import orjson
except ImportError:
    # orjson is optional. Without it both hooks stay None and are handed to
    # create_engine() / create_async_engine() as such.
    json_serializer = json_deserializer = None
else:

    def orjson_serializer(obj):
        """Encode *obj* as a JSON ``str`` using orjson.

        ``OPT_NAIVE_UTC`` serializes naive datetimes as UTC and
        ``OPT_NON_STR_KEYS`` permits dict keys that are not ``str``.
        """
        flags = orjson.OPT_NAIVE_UTC | orjson.OPT_NON_STR_KEYS
        raw = orjson.dumps(obj, option=flags)
        # orjson.dumps returns bytes; decode to a str.
        return raw.decode()

    json_serializer, json_deserializer = orjson_serializer, orjson.loads

27 

28 

def _setup_async_database_connection(
    async_database_url: str | None = None,
    *,
    async_engine: AsyncEngine | None = None,
    async_make_session: async_sessionmaker | None = None,
) -> async_sessionmaker:
    """Build (if needed) and register the async session factory.

    Precedence: an explicit ``async_make_session`` is used as-is; otherwise a
    new ``async_sessionmaker`` is bound to ``async_engine``, which itself is
    created from ``async_database_url`` (with the module's JSON hooks) when not
    supplied.

    The given URL and the resulting factory are stored on ``fr_globals``.

    Returns:
        The registered ``async_sessionmaker``.
    """
    # Explicit ``is None`` checks (as in _setup_database_connection) so that
    # "not provided" is the only condition that triggers construction.
    if async_make_session is None:
        if async_engine is None:
            async_engine = create_async_engine(
                async_database_url,  # type: ignore[arg-type]
                json_serializer=json_serializer,
                json_deserializer=json_deserializer,
            )
        # Unlike the sync factory, autoflush is disabled here.
        async_make_session = async_sessionmaker(
            bind=async_engine, autoflush=False, expire_on_commit=False
        )

    fr_globals.async_database_url = async_database_url
    fr_globals.async_make_session = async_make_session
    return async_make_session

49 

50 

def _setup_database_connection(
    database_url: str | None = None,
    *,
    engine: Engine | None = None,
    make_session: sessionmaker | None = None,
) -> sessionmaker:
    """Build (if needed) and register the sync session factory.

    A supplied ``make_session`` wins. Otherwise a ``sessionmaker`` is bound to
    ``engine``, creating that engine from ``database_url`` (with the module's
    JSON hooks) when no engine was passed.

    Side effect: records the URL and the factory on ``fr_globals``.

    Returns:
        The registered ``sessionmaker``.
    """
    factory = make_session
    if factory is None:
        bound = engine
        if bound is None:
            bound = create_engine(
                database_url,  # type: ignore[arg-type]
                json_serializer=json_serializer,
                json_deserializer=json_deserializer,
            )
        factory = sessionmaker(bind=bound, expire_on_commit=False)

    fr_globals.database_url = database_url
    fr_globals.make_session = factory
    return factory

69 

70 

def configure(
    *,
    async_database_url: str | None = None,
    async_engine: AsyncEngine | None = None,
    async_make_session: async_sessionmaker | None = None,
    database_url: str | None = None,
    engine: Engine | None = None,
    make_session: sessionmaker | None = None,
    session_generator: Callable[[], AsyncIterator[SA_AsyncSession]] | None = None,
    sync_session_generator: Callable[[], Iterator[SA_Session]] | None = None,
) -> None:
    """Configure FastAPI-Restly's database access. Call once at startup.

    Async support is set up when any of ``async_database_url``,
    ``async_engine`` or ``async_make_session`` is given; sync support when any
    of ``database_url``, ``engine`` or ``make_session`` is given. Both can be
    configured in the same call.

    ``session_generator`` / ``sync_session_generator`` replace the built-in
    session factory used by the request dependencies with a custom one.
    """
    async_requested = any(
        option is not None
        for option in (async_database_url, async_engine, async_make_session)
    )
    sync_requested = any(
        option is not None for option in (database_url, engine, make_session)
    )

    if async_requested:
        _setup_async_database_connection(
            async_database_url=async_database_url,
            async_engine=async_engine,
            async_make_session=async_make_session,
        )
    if sync_requested:
        _setup_database_connection(
            database_url=database_url,
            engine=engine,
            make_session=make_session,
        )

    # Custom generators only override the globals when explicitly supplied.
    if session_generator is not None:
        fr_globals.session_generator = session_generator
    if sync_session_generator is not None:
        fr_globals.sync_session_generator = sync_session_generator

108 

109 

def activate_savepoint_only_mode(
    make_session: async_sessionmaker | sessionmaker,
) -> None:
    """
    Intended for use in tests. Puts the session factory into savepoint-only mode so
    that no test data is ever committed to the database. Each test can roll back
    instantly by closing the session, leaving the database clean for the next test.

    This is done with "create_savepoint" mode and a wrapper on engine.connect() that
    begins the outer transaction before the Session can use it.
    https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#session-external-transaction

    Repeated calls are safe: the engine is wrapped at most once, but every call
    configures the given factory, so several factories sharing one engine can
    each be activated.
    """
    engine = _get_sync_engine(make_session)

    # The wrapper installed below carries ``_original_connect``; if it is
    # already there, the engine is patched and must not be wrapped again.
    if not hasattr(engine.connect, "_original_connect"):
        original_connect = engine.connect

        def _begin_on_connect():
            connection = original_connect()
            # Open the outer transaction up front so the Session joins it
            # (through a SAVEPOINT) instead of owning the transaction.
            connection.begin()
            return connection

        # Using setattr to silence pyright
        setattr(_begin_on_connect, "_original_connect", original_connect)
        engine.connect = _begin_on_connect

    # Always configure the factory. Returning early when the engine was
    # already wrapped would leave a second factory on the same engine without
    # "create_savepoint" mode.
    make_session.configure(join_transaction_mode="create_savepoint")

140 

141 

def deactivate_savepoint_only_mode(
    make_session: async_sessionmaker | sessionmaker,
) -> None:
    """
    Reverts the effect of `activate_savepoint_only_mode`.

    Restores the engine's original ``connect`` (when it was wrapped) and
    removes the ``join_transaction_mode`` override from ``make_session``, so
    sessions created afterwards use SQLAlchemy's default mode again. Safe to
    call even if savepoint-only mode was never activated.
    """
    engine = _get_sync_engine(make_session)
    _begin_on_connect = cast(Any, engine.connect)
    if hasattr(_begin_on_connect, "_original_connect"):
        # Restore the original connect that was saved by activate_savepoint_only_mode
        engine.connect = _begin_on_connect._original_connect
    # If engine was never activated, there is nothing to restore.

    # Drop the override instead of configuring ``join_transaction_mode=None``.
    # None is not one of SQLAlchemy's join modes; it slips past Session's
    # validation only because it is falsy, and the transaction-joining logic
    # has no branch for it. Without the key, the factory falls back to the
    # Session default ("conservative_savepoint").
    make_session.kw.pop("join_transaction_mode", None)

157 

158 

def get_async_engine() -> AsyncEngine:
    """Return the async engine bound to the factory registered via configure().

    Raises:
        RuntimeError: if no async session factory has been configured.
    """
    factory = fr_globals.async_make_session
    if factory is None:
        raise RuntimeError("Call fr.configure() before using get_async_engine().")
    return factory.kw["bind"]

166 

167 

def get_engine() -> Engine:
    """Return the sync engine bound to the factory registered via configure().

    Raises:
        RuntimeError: if no sync session factory has been configured.
    """
    factory = fr_globals.make_session
    if factory is None:
        raise RuntimeError("Call fr.configure() before using get_engine().")
    return factory.kw["bind"]

175 

176 

def _get_sync_engine(make_session: async_sessionmaker | sessionmaker) -> Engine:
    """Return the synchronous engine behind *make_session*'s ``bind``.

    An ``AsyncEngine`` bind is unwrapped to its ``sync_engine``; any other
    bind is returned unchanged.
    """
    bind = make_session.kw["bind"]
    return bind.sync_engine if isinstance(bind, AsyncEngine) else bind

182 

183 

async def async_generate_session() -> AsyncIterator[SA_AsyncSession]:
    """FastAPI dependency for async database session.

    If a custom ``session_generator`` was registered via configure(), the
    session it yields is passed through. Otherwise a session from the
    configured ``async_make_session`` factory is yielded. When the generator
    is resumed (after the endpoint returns), that session is committed if it
    is still active; a failed commit is rolled back and re-raised.

    Raises:
        RuntimeError: if neither a custom generator nor an async session
            factory has been configured.
    """
    if fr_globals.session_generator is not None:
        # Enter the custom generator the way FastAPI enters generator
        # dependencies. A bare ``async for`` would not forward an exception
        # raised by the endpoint into the custom generator (so it could not
        # roll back), and it would leave the generator's cleanup pending on
        # that path. This mirrors the ``yield from`` delegation used in
        # generate_session.
        custom_session = asynccontextmanager(fr_globals.session_generator)
        async with custom_session() as session:
            yield session
        return

    if fr_globals.async_make_session is None:
        raise RuntimeError(
            "Call fr.configure() before using async_generate_session()."
        )

    # FastAPI does not support contextmanagers as dependency directly,
    # but it does support generators.
    async with fr_globals.async_make_session() as session:
        yield session
        if session.is_active:
            try:
                await session.commit()
            except Exception:
                await session.rollback()
                raise

201 

202 

# Annotated dependency alias: declare an endpoint parameter as
# ``session: AsyncSessionDep`` to receive a session from async_generate_session.
AsyncSessionDep = Annotated[
    SA_AsyncSession,
    Depends(async_generate_session),
]

204 

205 

def generate_session() -> Iterator[SA_Session]:
    """FastAPI dependency for sync database session.

    If a custom ``sync_session_generator`` was registered via configure(), it
    is delegated to. Otherwise a session from the configured ``make_session``
    factory is yielded. When the generator is resumed (after the endpoint
    returns), that session is committed if it is still active; a failed
    commit is rolled back and re-raised.

    Raises:
        RuntimeError: if neither a custom generator nor a sync session
            factory has been configured.
    """
    if fr_globals.sync_session_generator is not None:
        # ``yield from`` delegates send/throw/close to the custom generator,
        # so it sees exceptions raised by the endpoint and is closed with us.
        yield from fr_globals.sync_session_generator()
        return

    if fr_globals.make_session is None:
        raise RuntimeError("Call fr.configure() before using generate_session().")

    with fr_globals.make_session() as session:
        yield session
        if session.is_active:
            try:
                session.commit()
            except Exception:
                session.rollback()
                raise

220 

221 

# Annotated dependency alias: declare an endpoint parameter as
# ``session: SessionDep`` to receive a session from generate_session.
SessionDep = Annotated[
    SA_Session,
    Depends(generate_session),
]