This commit is contained in:
Waylon Walker 2023-05-23 08:55:35 -05:00
parent daf81343bf
commit a2b33b25f8
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
11 changed files with 479 additions and 55 deletions

View file

@ -4,15 +4,15 @@ from fastapi import APIRouter, Depends
from sqlmodel import SQLModel
from learn_sql_model.api.user import oauth2_scheme
from learn_sql_model.config import get_config
from learn_sql_model.config import Config, get_config
from learn_sql_model.models.hero import Hero
hero_router = APIRouter()
@hero_router.on_event("startup")
def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine)
def on_startup(config: Config = Depends(get_config)) -> None:
SQLModel.metadata.create_all(config.database.engine)
@hero_router.get("/items/")
@ -21,21 +21,22 @@ async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
@hero_router.get("/hero/{id}")
def get_hero(id: int) -> Hero:
def get_hero(id: int, config: Config = Depends(get_config)) -> Hero:
"get one hero"
return Hero().get(id=id)
return Hero().get(id=id, config=config)
@hero_router.post("/hero/")
def post_hero(hero: Hero) -> Hero:
def post_hero(hero: Hero, config: Config = Depends(get_config)) -> Hero:
"read all the heros"
return hero.post()
hero.post(config=config)
return hero
@hero_router.get("/heros/")
def get_heros() -> list[Hero]:
def get_heros(config: Config = Depends(get_config)) -> list[Hero]:
"get all heros"
return Hero().get()
return Hero().get(config=config)
# Alternatively
# with get_config().database.session as session:
# statement = select(Hero)

View file

@ -10,12 +10,12 @@ from learn_sql_model.cli.tui import tui_app
app = typer.Typer(
name="learn_sql_model",
help="A rich terminal report for coveragepy.",
help="learn-sql-model cli for managing the project",
)
app.add_typer(config_app)
app.add_typer(tui_app)
app.add_typer(model_app)
app.add_typer(api_app)
app.add_typer(config_app, name="config")
app.add_typer(tui_app, name="tui")
app.add_typer(model_app, name="model")
app.add_typer(api_app, name="api")
app.add_typer(hero_app, name="hero")
@ -38,6 +38,17 @@ def version_callback(value: bool) -> None:
raise typer.Exit()
@app.callback()
def main(
version: bool = typer.Option(
False,
callback=version_callback,
help="show the version of the learn-sql-model package.",
),
):
"configuration cli"
@app.command()
def tui(ctx: typer.Context) -> None:
Trogon(get_group(app), click_context=ctx).run()

View file

@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Optional, Union
from pydantic_typer import expand_pydantic_args
from rich.console import Console
@ -9,6 +9,7 @@ from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.hero import Hero
from learn_sql_model.models.pet import Pet
import sys
hero_app = typer.Typer()
@ -21,7 +22,7 @@ def hero():
@hero_app.command()
@expand_pydantic_args(typer=True)
def get(
id: int = None,
id: Optional[int] = None,
config: Config = None,
) -> Union[Hero, List[Hero]]:
"get one hero"
@ -52,12 +53,11 @@ def populate(
config: Config = None,
) -> Hero:
"read all the heros"
config.init()
if config is None:
config = Config()
if config.env == "prod":
Console().print("populate is not supported in production")
return
sys.exit(1)
for hero in HeroFactory().batch(n):
pet = PetFactory().build()

View file

@ -1,5 +1,7 @@
from contextvars import ContextVar
from typing import TYPE_CHECKING
from fastapi import Depends
from pydantic import BaseModel, BaseSettings
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session
@ -24,6 +26,13 @@ class Database:
self.config = get_config()
else:
self.config = config
self.db_state_default = {
"closed": None,
"conn": None,
"ctx": None,
"transactions": None,
}
self.db_state = ContextVar("db_state", default=self.db_state_default.copy())
@property
def engine(self) -> "Engine":
@ -60,6 +69,24 @@ def get_database(config: Config = None) -> Database:
return Database(config)
async def reset_db_state(config: Config = None) -> None:
if config is None:
config = get_config()
config.database.db._state._state.set(db_state_default.copy())
config.database.db._state.reset()
def get_db(config: Config = None, reset_db_state=Depends(reset_db_state)):
if config is None:
config = get_config()
try:
config.database.db.connect()
yield
finally:
if not config.database.db.is_closed():
config.database.db.close()
def get_config(overrides: dict = {}) -> Config:
raw_config = load("learn_sql_model")
config = Config(**raw_config, **overrides)

View file

@ -1,15 +0,0 @@
from __future__ import annotations
from typing import Optional
from learn_sql_model.models.fast_model import FastModel
from sqlmodel import Field
class Hero(FastModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
# new_attribute: Optional[str] = None
# pets: List["Pet"] = Relationship(back_populates="hero")

View file

@ -29,6 +29,8 @@ class FastModel(SQLModel):
with config.database.session as session:
session.add(self)
session.commit()
session.refresh(self)
return
def get(
self, id: int = None, config: "Config" = None, where=None
@ -52,6 +54,16 @@ class FastModel(SQLModel):
results = session.exec(statement).one()
return results
def flags(self, config: "Config" = None) -> None:
if config is None:
config = get_config()
flags = []
for k, v in self.dict().items():
if v:
flags.append(f"--{k.replace('_', '-').lower()}")
flags.append(v)
return flags
# TODO
# update
# delete