Coverage for fastapi_restly / db / _session.py: 93%
101 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-02 09:54 +0000
from collections.abc import AsyncIterator, Callable, Iterator
from contextlib import asynccontextmanager
from typing import Annotated, Any, cast

from fastapi import Depends
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession as SA_AsyncSession
from sqlalchemy.orm import Session as SA_Session
from sqlalchemy.orm import sessionmaker

from ._globals import fr_globals
try:
    import orjson
except ImportError:
    # orjson is optional. Passing ``None`` to the engine factories leaves
    # SQLAlchemy on its default JSON handling (stdlib ``json``).
    json_serializer = None
    json_deserializer = None
else:

    def orjson_serializer(obj):
        """Encode *obj* as a JSON ``str`` with orjson.

        Naive datetimes are serialized as UTC (``OPT_NAIVE_UTC``) and dict keys
        that are not strings are accepted (``OPT_NON_STR_KEYS``). orjson returns
        ``bytes``, so the result is decoded to ``str``.
        """
        flags = orjson.OPT_NAIVE_UTC | orjson.OPT_NON_STR_KEYS
        raw = orjson.dumps(obj, option=flags)
        return raw.decode()

    json_serializer = orjson_serializer
    json_deserializer = orjson.loads
def _setup_async_database_connection(
    async_database_url: str | None = None,
    *,
    async_engine: AsyncEngine | None = None,
    async_make_session: async_sessionmaker | None = None,
) -> async_sessionmaker:
    """Create or adopt the async session factory and register it globally.

    Precedence: an explicit ``async_make_session`` is used as-is. Otherwise a
    factory is built on ``async_engine``; if that is also missing, a new async
    engine is created from ``async_database_url`` (using the orjson-based JSON
    (de)serializers when orjson is installed).

    Args:
        async_database_url: URL for a new async engine; also stored on
            ``fr_globals.async_database_url`` (``None`` when an engine or
            factory was supplied instead).
        async_engine: Existing engine to bind a new factory to.
        async_make_session: Ready-made factory to register unchanged.

    Returns:
        The session factory that was registered on ``fr_globals``.
    """
    # Explicit ``is None`` checks (matching the sync variant) so that only a
    # missing argument triggers the fallback, never a merely falsy object.
    if async_make_session is None:
        if async_engine is None:
            async_engine = create_async_engine(
                async_database_url,  # type: ignore[arg-type]
                json_serializer=json_serializer,
                json_deserializer=json_deserializer,
            )
        # expire_on_commit=False keeps loaded attributes usable after commit
        # without triggering a refresh.
        async_make_session = async_sessionmaker(
            bind=async_engine, autoflush=False, expire_on_commit=False
        )

    fr_globals.async_database_url = async_database_url
    fr_globals.async_make_session = async_make_session
    return async_make_session
def _setup_database_connection(
    database_url: str | None = None,
    *,
    engine: Engine | None = None,
    make_session: sessionmaker | None = None,
) -> sessionmaker:
    """Create or adopt the sync session factory and register it globally.

    A supplied ``make_session`` wins outright. Failing that, a factory is bound
    to ``engine``, or to a freshly created engine for ``database_url`` when no
    engine was given either.

    Returns:
        The session factory now stored on ``fr_globals.make_session``.
    """
    if make_session is not None:
        factory = make_session
    else:
        bind = engine
        if bind is None:
            bind = create_engine(
                database_url,  # type: ignore[arg-type]
                json_serializer=json_serializer,
                json_deserializer=json_deserializer,
            )
        factory = sessionmaker(bind=bind, expire_on_commit=False)

    fr_globals.database_url = database_url
    fr_globals.make_session = factory
    return factory
def configure(
    *,
    async_database_url: str | None = None,
    async_engine: AsyncEngine | None = None,
    async_make_session: async_sessionmaker | None = None,
    database_url: str | None = None,
    engine: Engine | None = None,
    make_session: sessionmaker | None = None,
    session_generator: Callable[[], AsyncIterator[SA_AsyncSession]] | None = None,
    sync_session_generator: Callable[[], Iterator[SA_Session]] | None = None,
) -> None:
    """Configure FastAPI-Restly; call this once while the application starts.

    Supplying any of ``async_database_url``, ``async_engine`` or
    ``async_make_session`` sets up async support; supplying any of
    ``database_url``, ``engine`` or ``make_session`` sets up sync support.
    Both groups may be passed together when the application uses both.

    ``session_generator`` / ``sync_session_generator`` swap the built-in
    session factory for a custom one. Any argument left as ``None`` leaves the
    corresponding setting as it was.
    """
    async_options = (async_database_url, async_engine, async_make_session)
    if any(option is not None for option in async_options):
        _setup_async_database_connection(
            async_database_url=async_database_url,
            async_engine=async_engine,
            async_make_session=async_make_session,
        )

    sync_options = (database_url, engine, make_session)
    if any(option is not None for option in sync_options):
        _setup_database_connection(
            database_url=database_url,
            engine=engine,
            make_session=make_session,
        )

    if session_generator is not None:
        fr_globals.session_generator = session_generator
    if sync_session_generator is not None:
        fr_globals.sync_session_generator = sync_session_generator
def activate_savepoint_only_mode(
    make_session: async_sessionmaker | sessionmaker,
) -> None:
    """
    Intended for use in tests. Puts the session factory into savepoint-only mode so
    that no test data is ever committed to the database. Each test can roll back
    instantly by closing the session, leaving the database clean for the next test.

    This is done with "create_savepoint" mode and a wrapper on engine.connect() that
    begins the outer transaction before the Session can use it.
    https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#session-external-transaction

    Safe to call repeatedly: the engine's ``connect`` is wrapped only once, but
    ``make_session`` is always switched to "create_savepoint". That also covers
    a second session factory bound to an engine that is already patched.
    """
    engine = _get_sync_engine(make_session)

    # Our wrapper carries ``_original_connect``; if it is present the engine is
    # already patched and must not be wrapped a second time.
    if not hasattr(engine.connect, "_original_connect"):
        original_connect = engine.connect

        def _begin_on_connect():
            connection = original_connect()
            try:
                # Open the outer transaction before the Session sees the
                # connection, so the Session's own work runs in a SAVEPOINT.
                connection.begin()
            except BaseException:
                # Don't leak the connection if the transaction can't start.
                connection.close()
                raise
            return connection

        # Using setattr to silence pyright
        setattr(_begin_on_connect, "_original_connect", original_connect)
        engine.connect = _begin_on_connect

    # Applied unconditionally: configure() only updates the factory's keyword
    # arguments, so repeating it is harmless.
    make_session.configure(join_transaction_mode="create_savepoint")
def deactivate_savepoint_only_mode(
    make_session: async_sessionmaker | sessionmaker,
) -> None:
    """
    Reverts the effect of `activate_savepoint_only_mode`.

    Restores the original ``engine.connect`` (if it was wrapped) and removes the
    ``join_transaction_mode`` override from ``make_session``. New sessions then
    use SQLAlchemy's default join mode again. Safe to call when savepoint-only
    mode was never activated.
    """
    engine = _get_sync_engine(make_session)
    begin_on_connect = cast(Any, engine.connect)
    original_connect = getattr(begin_on_connect, "_original_connect", None)
    if original_connect is not None:
        # Restore the original connect saved by activate_savepoint_only_mode.
        engine.connect = original_connect

    # Drop the override rather than configuring ``None``: ``None`` is not one
    # of SQLAlchemy's join modes, whereas removing the key lets Session fall
    # back to its own default.
    make_session.kw.pop("join_transaction_mode", None)
def get_async_engine() -> AsyncEngine:
    """Return the engine bound to the async session factory from configure().

    Raises:
        RuntimeError: when no async session factory has been configured yet.
    """
    factory = fr_globals.async_make_session
    if factory is None:
        raise RuntimeError("Call fr.configure() before using get_async_engine().")
    return factory.kw["bind"]
def get_engine() -> Engine:
    """Return the engine bound to the sync session factory from configure().

    Raises:
        RuntimeError: when no sync session factory has been configured yet.
    """
    factory = fr_globals.make_session
    if factory is None:
        raise RuntimeError("Call fr.configure() before using get_engine().")
    return factory.kw["bind"]
def _get_sync_engine(make_session: async_sessionmaker | sessionmaker) -> Engine:
    """Return the synchronous engine behind *make_session*'s ``bind``.

    An async factory is bound to an ``AsyncEngine``; its wrapped
    ``sync_engine`` is returned instead. Any other bind is returned unchanged.
    """
    bind = make_session.kw["bind"]
    return bind.sync_engine if isinstance(bind, AsyncEngine) else bind
async def async_generate_session() -> AsyncIterator[SA_AsyncSession]:
    """FastAPI dependency for async database session.

    If a custom ``session_generator`` was registered via configure(), the
    session comes from it. That generator should be an async generator
    function yielding exactly one session, as any FastAPI yield-dependency
    must.

    Otherwise a session is opened from the configured async factory. When the
    generator is resumed after the ``yield``, the session is committed if it
    is still active; a failed commit is rolled back and re-raised. An
    exception thrown in at the ``yield`` skips the commit and the session is
    closed by the ``async with``.

    Raises:
        RuntimeError: if neither a custom generator nor an async session
            factory has been configured.
    """
    if fr_globals.session_generator is not None:
        # Delegate the way ``yield from`` does in the sync dependency. The
        # context manager wrapper throws any exception raised while the session
        # is in use into the custom generator (athrow), so its own
        # rollback/cleanup runs now. A plain ``async for`` would leave it
        # suspended until garbage collection.
        async with asynccontextmanager(fr_globals.session_generator)() as session:
            yield session
        return

    if fr_globals.async_make_session is None:
        raise RuntimeError("Call fr.configure() before using AsyncSessionDep.")

    # FastAPI does not support contextmanagers as dependency directly,
    # but it does support generators.
    async with fr_globals.async_make_session() as session:
        yield session
        if session.is_active:
            try:
                await session.commit()
            except Exception:
                await session.rollback()
                raise


# Annotated alias so endpoints can declare ``session: AsyncSessionDep``.
AsyncSessionDep = Annotated[SA_AsyncSession, Depends(async_generate_session)]
def generate_session() -> Iterator[SA_Session]:
    """FastAPI dependency for sync database session.

    If a custom ``sync_session_generator`` was registered via configure(), it
    is delegated to with ``yield from``, which also forwards exceptions thrown
    into this generator.

    Otherwise a session is opened from the configured sync factory. When the
    generator is resumed after the ``yield``, the session is committed if it
    is still active; a failed commit is rolled back and re-raised.

    Raises:
        RuntimeError: if neither a custom generator nor a sync session factory
            has been configured.
    """
    if fr_globals.sync_session_generator is not None:
        yield from fr_globals.sync_session_generator()
        return

    # Fail with an actionable message instead of "'NoneType' object is not
    # callable", consistent with get_engine().
    if fr_globals.make_session is None:
        raise RuntimeError("Call fr.configure() before using SessionDep.")

    with fr_globals.make_session() as session:
        yield session
        if session.is_active:
            try:
                session.commit()
            except Exception:
                session.rollback()
                raise


# Annotated alias so endpoints can declare ``session: SessionDep``.
SessionDep = Annotated[SA_Session, Depends(generate_session)]