This commit is contained in:
Waylon Walker 2023-05-23 08:55:35 -05:00
parent daf81343bf
commit a2b33b25f8
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
11 changed files with 479 additions and 55 deletions

View file

@ -1,5 +1,7 @@
from contextvars import ContextVar
from typing import TYPE_CHECKING
from fastapi import Depends
from pydantic import BaseModel, BaseSettings
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session
@ -24,6 +26,13 @@ class Database:
self.config = get_config()
else:
self.config = config
self.db_state_default = {
"closed": None,
"conn": None,
"ctx": None,
"transactions": None,
}
self.db_state = ContextVar("db_state", default=self.db_state_default.copy())
@property
def engine(self) -> "Engine":
@ -60,6 +69,24 @@ def get_database(config: Config = None) -> Database:
return Database(config)
async def reset_db_state(config: Config = None) -> None:
if config is None:
config = get_config()
config.database.db._state._state.set(db_state_default.copy())
config.database.db._state.reset()
def get_db(config: Config = None, reset_db_state=Depends(reset_db_state)):
if config is None:
config = get_config()
try:
config.database.db.connect()
yield
finally:
if not config.database.db.is_closed():
config.database.db.close()
def get_config(overrides: dict = {}) -> Config:
raw_config = load("learn_sql_model")
config = Config(**raw_config, **overrides)