This commit is contained in:
Waylon Walker 2023-06-09 16:04:58 -05:00
parent 1a0bf1adb9
commit c3db85a209
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
21 changed files with 647 additions and 658 deletions

View file

@ -2,6 +2,7 @@
from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import Config
from learn_sql_model.config import get_config
from learn_sql_model.config import get_session
from learn_sql_model.console import console
from learn_sql_model.database import get_database
from learn_sql_model.factories.hero import HeroFactory

View file

@ -12,8 +12,6 @@ RUN groupadd -f -g ${SMOKE_GID} smoke && \
RUN mkdir /home/smoke && chown -R smoke:smoke /home/smoke && mkdir /src && chown smoke:smoke /src
WORKDIR /home/smoke
RUN apt update && \
apt upgrade -y && \
apt install -y \

View file

@ -1,18 +1,9 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import SQLModel, Session
from fastapi import APIRouter, Depends
from sqlmodel import SQLModel
from learn_sql_model.api.user import oauth2_scheme
from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import Config, get_config
from learn_sql_model.models.hero import (
Hero,
HeroCreate,
HeroDelete,
HeroRead,
HeroUpdate,
)
from learn_sql_model.config import get_config, get_session
from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, HeroUpdate
hero_router = APIRouter()
@ -22,52 +13,73 @@ def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine)
@hero_router.get("/items/")
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
return {"token": token}
@hero_router.get("/hero/{id}")
async def get_hero(id: int, config: Config = Depends(get_config)) -> Hero:
@hero_router.get("/hero/{hero_id}")
async def get_hero(
*,
session: Session = Depends(get_session),
hero_id: int,
) -> HeroRead:
"get one hero"
return Hero().get(id=id, config=config)
@hero_router.get("/h/{id}")
async def get_h(id: int, config: Config = Depends(get_config)) -> Hero:
"get one hero"
return Hero().get(id=id, config=config)
hero = session.get(Hero, hero_id)
if not hero:
raise HTTPException(status_code=404, detail="Hero not found")
return hero
@hero_router.post("/hero/")
async def post_hero(hero: HeroCreate) -> HeroRead:
async def post_hero(
*,
session: Session = Depends(get_session),
hero: HeroCreate,
) -> HeroRead:
"read all the heros"
config = get_config()
hero = hero.post(config=config)
db_hero = Hero.from_orm(hero)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
await manager.broadcast({hero.json()}, id=1)
return hero
return db_hero
@hero_router.patch("/hero/")
async def patch_hero(hero: HeroUpdate) -> HeroRead:
async def patch_hero(
*,
session: Session = Depends(get_session),
hero: HeroUpdate,
) -> HeroRead:
"read all the heros"
config = get_config()
hero = hero.update(config=config)
db_hero = session.get(Hero, hero.id)
if not db_hero:
raise HTTPException(status_code=404, detail="Hero not found")
for key, value in hero.dict(exclude_unset=True).items():
setattr(db_hero, key, value)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
await manager.broadcast({hero.json()}, id=1)
return hero
return db_hero
@hero_router.delete("/hero/{hero_id}")
async def delete_hero(hero_id: int):
async def delete_hero(
*,
session: Session = Depends(get_session),
hero_id: int,
):
"read all the heros"
hero = HeroDelete(id=hero_id)
config = get_config()
hero = hero.delete(config=config)
hero = session.get(Hero, hero_id)
if not hero:
raise HTTPException(status_code=404, detail="Hero not found")
session.delete(hero)
session.commit()
await manager.broadcast(f"deleted hero {hero_id}", id=1)
return hero
return {"ok": True}
@hero_router.get("/heros/")
async def get_heros(config: Config = Depends(get_config)) -> list[Hero]:
async def get_heros(
*,
session: Session = Depends(get_session),
) -> list[Hero]:
"get all heros"
return Hero().get(config=config)
return HeroRead.list(session=session)

View file

@ -1,45 +0,0 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlmodel import SQLModel
from learn_sql_model.api.user import oauth2_scheme
from learn_sql_model.config import Config, get_config
from learn_sql_model.models.new import new
new_router = APIRouter()
@new_router.on_event("startup")
def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine)
@new_router.get("/items/")
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
return {"token": token}
@new_router.get("/new/{id}")
def get_new(id: int, config: Config = Depends(get_config)) -> new:
"get one new"
return new().get(id=id, config=config)
@new_router.get("/h/{id}")
def get_h(id: int, config: Config = Depends(get_config)) -> new:
"get one new"
return new().get(id=id, config=config)
@new_router.post("/new/")
def post_new(new: new, config: Config = Depends(get_config)) -> new:
"read all the news"
new.post(config=config)
return new
@new_router.get("/news/")
def get_news(config: Config = Depends(get_config)) -> list[new]:
"get all news"
return new().get(config=config)

View file

@ -69,4 +69,4 @@ async def websocket_endpoint(websocket: WebSocket):
except WebSocketDisconnect:
manager.disconnect(websocket, id)
await manager.broadcast(f"Client #{client_id} left the chat", id)
await manager.broadcast(f"Client #{id} left the chat", id)

View file

@ -1,3 +1,4 @@
import httpx
from rich.console import Console
import typer
import uvicorn
@ -38,15 +39,11 @@ def status(
help="show the log messages",
),
):
import httpx
config = get_config()
host = config.api_server.host
port = config.api_server.port
url = f"http://{host}:{port}/docs"
url = config.api_client.url
try:
r = httpx.get(url)
r = httpx.get(url + "/docs")
if r.status_code == 200:
Console().print(f"[green]API: ([gold1]{url}[green]) is running")
else:
@ -59,7 +56,7 @@ def status(
Console().print(
f"[green]database: ([gold1]{config.database.engine}[green]) is running"
)
except Exception as e:
except Exception:
Console().print(
f"[red]database: ([gold1]{config.database.engine}[red]) is not running"
)

View file

@ -2,13 +2,11 @@ import sys
from typing import List, Optional, Union
from engorgio import engorgio
import httpx
from rich.console import Console
import typer
from learn_sql_model.config import Config, get_config
from learn_sql_model.config import get_config
from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.hero import (
Hero,
HeroCreate,
@ -19,6 +17,8 @@ from learn_sql_model.models.hero import (
hero_app = typer.Typer()
config = get_config()
@hero_app.callback()
def hero():
@ -28,12 +28,10 @@ def hero():
@hero_app.command()
@engorgio(typer=True)
def get(
id: Optional[int] = typer.Argument(default=None),
config: Config = None,
hero_id: Optional[int] = typer.Argument(default=None),
) -> Union[Hero, List[Hero]]:
"get one hero"
config.init()
hero = HeroRead.get(id=id, config=config)
hero = HeroRead.get(id=hero_id)
Console().print(hero)
return hero
@ -42,12 +40,11 @@ def get(
@engorgio(typer=True)
def list(
where: Optional[str] = None,
config: Config = None,
offset: int = 0,
limit: Optional[int] = None,
) -> Union[Hero, List[Hero]]:
"get one hero"
hero = HeroRead.list(config=config, where=where, offset=offset, limit=limit)
"list many heros"
hero = HeroRead.list(where=where, offset=offset, limit=limit)
Console().print(hero)
return hero
@ -56,65 +53,39 @@ def list(
@engorgio(typer=True)
def create(
hero: HeroCreate,
config: Config = None,
) -> Hero:
"read all the heros"
r = httpx.post(
f"{config.api_client.url}/hero/",
json=hero.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
# hero = hero.post(config=config)
# Console().print(hero)
# return hero
"create one hero"
hero.post()
@hero_app.command()
@engorgio(typer=True)
def update(
hero: HeroUpdate,
config: Config = None,
) -> Hero:
"read all the heros"
r = httpx.patch(
f"{config.api_client.url}/hero/",
json=hero.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
"update one hero"
hero.update()
@hero_app.command()
@engorgio(typer=True)
def delete(
hero: HeroDelete,
config: Config = None,
) -> Hero:
"read all the heros"
r = httpx.delete(
f"{config.api_client.url}/hero/{hero.id}",
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
"delete a hero by id"
hero.delete()
@hero_app.command()
@engorgio(typer=True)
def populate(
hero: Hero,
n: int = 10,
) -> Hero:
"read all the heros"
config = get_config()
"Create n number of heros"
if config.env == "prod":
Console().print("populate is not supported in production")
sys.exit(1)
for hero in HeroFactory().batch(n):
pet = PetFactory().build()
hero.pet = pet
Console().print(hero)
hero.post(config=config)
hero = HeroCreate(**hero.dict())
hero.post()

View file

@ -44,7 +44,6 @@ def create_revision(
prompt=True,
),
):
alembic_cfg = Config("alembic.ini")
alembic.command.revision(
config=alembic_cfg,
@ -63,7 +62,6 @@ def checkout(
),
revision: str = typer.Option("head"),
):
alembic_cfg = Config("alembic.ini")
alembic.command.upgrade(config=alembic_cfg, revision="head")

View file

@ -1,107 +0,0 @@
import sys
from typing import List, Optional, Union
from engorgio import engorgio
from rich.console import Console
import typer
from learn_sql_model.config import Config, get_config
from learn_sql_model.factories.new import newFactory
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.new import (
new,
newCreate,
newDelete,
newRead,
newUpdate,
)
new_app = typer.Typer()
@new_app.callback()
def new():
"model cli"
@new_app.command()
@engorgio(typer=True)
def get(
id: Optional[int] = typer.Argument(default=None),
config: Config = None,
) -> Union[new, List[new]]:
"get one new"
config.init()
new = newRead.get(id=id, config=config)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def list(
where: Optional[str] = None,
config: Config = None,
offset: int = 0,
limit: Optional[int] = None,
) -> Union[new, List[new]]:
"get one new"
new = newRead.list(config=config, where=where, offset=offset, limit=limit)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def create(
new: newCreate,
config: Config = None,
) -> new:
"read all the news"
# config.init()
new = new.post(config=config)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def update(
new: newUpdate,
config: Config = None,
) -> new:
"read all the news"
new = new.update(config=config)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def delete(
new: newDelete,
config: Config = None,
) -> new:
"read all the news"
# config.init()
new = new.delete(config=config)
return new
@new_app.command()
@engorgio(typer=True)
def populate(
new: new,
n: int = 10,
) -> new:
"read all the news"
config = get_config()
if config.env == "prod":
Console().print("populate is not supported in production")
sys.exit(1)
for new in newFactory().batch(n):
pet = PetFactory().build()
new.pet = pet
Console().print(new)
new.post(config=config)

View file

@ -30,7 +30,6 @@ class ApiClient(BaseModel):
class Database:
def __init__(self, config: "Config" = None) -> None:
if config is None:
self.config = get_config()
else:
self.config = config
@ -71,17 +70,26 @@ class Config(BaseSettings):
def get_database(config: Config = None) -> Database:
if config is None:
config = get_config()
return Database(config)
def get_config(overrides: dict = {}) -> Config:
raw_config = load("learn_sql_model")
config = Config(**raw_config, **overrides)
return config
def get_session(config: Config = None) -> "Session":
with Session(config.database.engine) as session:
yield session
async def reset_db_state(config: Config = None) -> None:
if config is None:
config = get_config()
config.database.db._state._state.set(db_state_default.copy())
config.database.db._state._state.set(config.database.db_state_default.copy())
config.database.db._state.reset()
@ -96,7 +104,4 @@ def get_db(config: Config = None, reset_db_state=Depends(reset_db_state)):
config.database.db.close()
def get_config(overrides: dict = {}) -> Config:
raw_config = load("learn_sql_model")
config = Config(**raw_config, **overrides)
return config
config = get_config()

View file

@ -1,14 +0,0 @@
from faker import Faker
from polyfactory.factories.pydantic_factory import ModelFactory
from learn_sql_model.models.new import new
class newFactory(ModelFactory[new]):
__model__ = new
__faker__ = Faker(locale="en_US")
__set_as_default_factory_for_type__ = True
id = None
__random_seed__ = 10

View file

@ -21,7 +21,6 @@ class FastModel(SQLModel):
def post(self, config: "Config" = None) -> None:
if config is None:
config = get_config()
self.pre_post()
@ -36,7 +35,6 @@ class FastModel(SQLModel):
self, id: int = None, config: "Config" = None, where=None
) -> Optional["FastModel"]:
if config is None:
config = get_config()
self.pre_get()

View file

@ -1,10 +1,11 @@
from typing import Optional
from fastapi import HTTPException
import httpx
from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select
from learn_sql_model.config import Config
from learn_sql_model.config import config, get_session
from learn_sql_model.models.pet import Pet
@ -25,14 +26,13 @@ class Hero(HeroBase, table=True):
class HeroCreate(HeroBase):
...
def post(self, config: Config) -> Hero:
config.init()
with Session(config.database.engine) as session:
db_hero = Hero.from_orm(self)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
return db_hero
def post(self) -> Hero:
r = httpx.post(
f"{config.api_client.url}/hero/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
class HeroRead(HeroBase):
@ -41,10 +41,8 @@ class HeroRead(HeroBase):
@classmethod
def get(
cls,
config: Config,
id: int,
) -> Hero:
with config.database.session as session:
hero = session.get(Hero, id)
if not hero:
@ -54,25 +52,25 @@ class HeroRead(HeroBase):
@classmethod
def list(
self,
config: Config,
where=None,
offset=0,
limit=None,
session: Session = get_session,
) -> Hero:
# with config.database.session as session:
with config.database.session as session:
statement = select(Hero)
if where != "None":
from sqlmodel import text
statement = select(Hero)
if where != "None" and where is not None:
from sqlmodel import text
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
heroes = session.exec(statement).all()
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
heroes = session.exec(statement).all()
return heroes
class HeroUpdate(SQLModel):
# id is required to get the hero
# id is required to update the hero
id: int
# all other fields, must match the model, but with Optional default None
@ -84,30 +82,22 @@ class HeroUpdate(SQLModel):
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="hero")
def update(self, config: Config) -> Hero:
with Session(config.database.engine) as session:
db_hero = session.get(Hero, self.id)
if not db_hero:
raise HTTPException(status_code=404, detail="Hero not found")
hero_data = self.dict(exclude_unset=True)
for key, value in hero_data.items():
if value is not None:
setattr(db_hero, key, value)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
return db_hero
def update(self) -> Hero:
r = httpx.patch(
f"{config.api_client.url}/hero/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
class HeroDelete(BaseModel):
id: int
def delete(self, config: Config) -> Hero:
config.init()
with Session(config.database.engine) as session:
hero = session.get(Hero, self.id)
if not hero:
raise HTTPException(status_code=404, detail="Hero not found")
session.delete(hero)
session.commit()
return {"ok": True}
def delete(self) -> Hero:
r = httpx.delete(
f"{config.api_client.url}/hero/{self.id}",
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {"ok": True}

View file

@ -1,99 +0,0 @@
from typing import Optional
from fastapi import HTTPException
from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select
from learn_sql_model.config import Config
from learn_sql_model.models.pet import Pet
class newBase(SQLModel, table=False):
class new(newBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class newCreate(newBase):
...
def post(self, config: Config) -> new:
config.init()
with Session(config.database.engine) as session:
db_new = new.from_orm(self)
session.add(db_new)
session.commit()
session.refresh(db_new)
return db_new
class newRead(newBase):
id: int
@classmethod
def get(
cls,
config: Config,
id: int,
) -> new:
with config.database.session as session:
new = session.get(new, id)
if not new:
raise HTTPException(status_code=404, detail="new not found")
return new
@classmethod
def list(
self,
config: Config,
where=None,
offset=0,
limit=None,
) -> new:
with config.database.session as session:
statement = select(new)
if where != "None":
from sqlmodel import text
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
newes = session.exec(statement).all()
return newes
class newUpdate(SQLModel):
# id is required to get the new
id: int
# all other fields, must match the model, but with Optional default None
def update(self, config: Config) -> new:
with Session(config.database.engine) as session:
db_new = session.get(new, self.id)
if not db_new:
raise HTTPException(status_code=404, detail="new not found")
new_data = self.dict(exclude_unset=True)
for key, value in new_data.items():
if value is not None:
setattr(db_new, key, value)
session.add(db_new)
session.commit()
session.refresh(db_new)
return db_new
class newDelete(BaseModel):
id: int
def delete(self, config: Config) -> new:
config.init()
with Session(config.database.engine) as session:
new = session.get(new, self.id)
if not new:
raise HTTPException(status_code=404, detail="new not found")
session.delete(new)
session.commit()
return {"ok": True}

View file

@ -77,14 +77,17 @@ dependencies = [
[tool.hatch.envs.default.scripts]
test = "coverage run -m pytest"
cov = "coverage-rich report"
test-cov = ['test', 'cov']
cov-erase = "coverage erase"
lint = "ruff learn_sql_model"
format = "black learn_sql_model"
format-check = "black --check learn_sql_model"
fix_ruff = "ruff --fix learn_sql_model"
fix = ['format', 'fix_ruff']
build-docs = "markata build"
lint-test = [
"lint",
"format-check",
"cov-erase",
"test",
"cov",
]

View file

@ -18,3 +18,9 @@ def read_heroes(hero: Hero) -> list[Hero]:
def read_heros() -> list[Hero]:
"read all the heros"
return Hero.get()
@app.patch("/heros/")
def update_heros() -> list[Hero]:
"read all the heros"
return Hero.get()

View file

@ -1,11 +1,9 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import SQLModel, Session
from fastapi import APIRouter, Depends
from sqlmodel import SQLModel
from learn_sql_model.api.user import oauth2_scheme
from learn_sql_model.config import Config, get_config
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import get_config, get_session
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}, {{modelname}}Create, {{modelname}}Read, {{modelname}}Update
{{modelname.lower()}}_router = APIRouter()
@ -15,31 +13,74 @@ def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine)
@{{modelname.lower()}}_router.get("/items/")
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
return {"token": token}
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{id}")
def get_{{modelname.lower()}}(id: int, config: Config = Depends(get_config)) -> {{modelname}}:
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def get_{{modelname.lower()}}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
) -> {{modelname}}Read:
"get one {{modelname.lower()}}"
return {{modelname}}().get(id=id, config=config)
@{{modelname.lower()}}_router.get("/h/{id}")
def get_h(id: int, config: Config = Depends(get_config)) -> {{modelname}}:
"get one {{modelname.lower()}}"
return {{modelname}}().get(id=id, config=config)
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
def post_{{modelname.lower()}}({{modelname.lower()}}: {{modelname}}, config: Config = Depends(get_config)) -> {{modelname.lower()}}:
"read all the {{modelname.lower()}}s"
{{modelname.lower()}}.post(config=config)
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
return {{modelname.lower()}}
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
async def post_{{modelname.lower()}}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Create,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = {{modelname}}.from_orm({{modelname.lower()}})
session.add(db_{{modelname.lower()}})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
@{{modelname.lower()}}_router.patch("/{{modelname.lower()}}/")
async def patch_{{modelname.lower()}}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Update,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
if not db_{{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
for key, value in {{modelname.lower()}}.dict(exclude_unset=True).items():
setattr(db_{{modelname.lower()}}, key, value)
session.add(db_{{modelname.lower()}})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
@{{modelname.lower()}}_router.delete("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def delete_{{modelname.lower()}}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
):
"read all the {{modelname.lower()}}s"
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
session.delete({{modelname.lower()}})
session.commit()
await manager.broadcast(f"deleted {{modelname.lower()}} {{{modelname.lower()}}_id}", id=1)
return {"ok": True}
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/")
def get_{{modelname.lower()}}s(config: Config = Depends(get_config)) -> list[{{modelname}}]:
async def get_{{modelname.lower()}}s(
*,
session: Session = Depends(get_session),
) -> list[{{modelname}}]:
"get all {{modelname.lower()}}s"
return {{modelname}}().get(config=config)
return {{modelname}}Read.list(session=session)

View file

@ -5,9 +5,8 @@ from engorgio import engorgio
from rich.console import Console
import typer
from learn_sql_model.config import Config, get_config
from learn_sql_model.config import get_config
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.{{modelname.lower()}} import (
{{modelname}},
{{modelname}}Create,
@ -18,6 +17,8 @@ from learn_sql_model.models.{{modelname.lower()}} import (
{{modelname.lower()}}_app = typer.Typer()
config = get_config()
@{{modelname.lower()}}_app.callback()
def {{modelname.lower()}}():
@ -27,12 +28,10 @@ def {{modelname.lower()}}():
@{{modelname.lower()}}_app.command()
@engorgio(typer=True)
def get(
id: Optional[int] = typer.Argument(default=None),
config: Config = None,
) -> Union[{{modelname}}, List[{{modelname.lower()}}]]:
{{modelname.lower()}}_id: Optional[int] = typer.Argument(default=None),
) -> Union[{{modelname}}, List[{{modelname}}]]:
"get one {{modelname.lower()}}"
config.init()
{{modelname.lower()}} = {{modelname}}Read.get(id=id, config=config)
{{modelname.lower()}} = {{modelname}}Read.get(id={{modelname.lower()}}_id)
Console().print({{modelname.lower()}})
return {{modelname.lower()}}
@ -41,12 +40,11 @@ def get(
@engorgio(typer=True)
def list(
where: Optional[str] = None,
config: Config = None,
offset: int = 0,
limit: Optional[int] = None,
) -> Union[{{modelname}}, List[{{modelname.lower()}}]]:
"get one {{modelname.lower()}}"
{{modelname.lower()}} = {{modelname}}Read.list(config=config, where=where, offset=offset, limit=limit)
) -> Union[{{modelname}}, List[{{modelname}}]]:
"list many {{modelname.lower()}}s"
{{modelname.lower()}} = {{modelname}}Read.list(where=where, offset=offset, limit=limit)
Console().print({{modelname.lower()}})
return {{modelname.lower()}}
@ -55,53 +53,40 @@ def list(
@engorgio(typer=True)
def create(
{{modelname.lower()}}: {{modelname}}Create,
config: Config = None,
) -> {{modelname}}:
"read all the {{modelname.lower()}}s"
# config.init()
{{modelname.lower()}} = {{modelname.lower()}}.post(config=config)
Console().print({{modelname.lower()}})
return {{modelname.lower()}}
"create one {{modelname.lower()}}"
{{modelname.lower()}}.post()
@{{modelname.lower()}}_app.command()
@engorgio(typer=True)
def update(
{{modelname.lower()}}: {{modelname}}Update,
config: Config = None,
) -> {{modelname}}:
"read all the {{modelname.lower()}}s"
{{modelname.lower()}} = {{modelname.lower()}}.update(config=config)
Console().print({{modelname.lower()}})
return {{modelname.lower()}}
"update one {{modelname.lower()}}"
{{modelname.lower()}}.update()
@{{modelname.lower()}}_app.command()
@engorgio(typer=True)
def delete(
{{modelname.lower()}}: {{modelname}}Delete,
config: Config = None,
) -> {{modelname}}:
"read all the {{modelname.lower()}}s"
# config.init()
{{modelname.lower()}} = {{modelname.lower()}}.delete(config=config)
return {{modelname.lower()}}
"delete a {{modelname.lower()}} by id"
{{modelname.lower()}}.delete()
@{{modelname.lower()}}_app.command()
@engorgio(typer=True)
def populate(
{{modelname.lower()}}: {{modelname}},
n: int = 10,
) -> {{modelname}}:
"read all the {{modelname.lower()}}s"
config = get_config()
"Create n number of {{modelname.lower()}}s"
if config.env == "prod":
Console().print("populate is not supported in production")
sys.exit(1)
for {{modelname.lower()}} in {{modelname}}Factory().batch(n):
pet = PetFactory().build()
{{modelname.lower()}}.pet = pet
Console().print({{modelname.lower()}})
{{modelname.lower()}}.post(config=config)
{{modelname.lower()}} = {{modelname}}Create(**{{modelname.lower()}}.dict())
{{modelname.lower()}}.post()

View file

@ -1,43 +1,41 @@
from typing import Optional
from fastapi import HTTPException
from fastapi import Depends, HTTPException
import httpx
from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select
from learn_sql_model.config import Config
from learn_sql_model.config import config, get_session
from learn_sql_model.models.pet import Pet
class {{modelname}}Base(SQLModel, table=False):
class {{modelname}}({{modelname.lower()}}Base, table=True):
class {{modelname}}({{modelname}}Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class {{modelname}}Create({{modelname.lower()}}Base):
class {{modelname}}Create({{modelname}}Base):
...
def post(self, config: Config) -> {{modelname}}:
config.init()
with Session(config.database.engine) as session:
db_{{modelname.lower()}} = {{modelname}}.from_orm(self)
session.add(db_{{modelname.lower()}})
session.commit()
session.refresh(db_{{modelname.lower()}})
return db_{{modelname.lower()}}
def post(self) -> {{modelname}}:
r = httpx.post(
f"{config.api_client.url}/{{modelname.lower()}}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
class {{modelname}}Read({{modelname.lower()}}Base):
class {{modelname}}Read({{modelname}}Base):
id: int
@classmethod
def get(
cls,
config: Config,
id: int,
) -> {{modelname}}:
with config.database.session as session:
{{modelname.lower()}} = session.get({{modelname}}, id)
if not {{modelname.lower()}}:
@ -47,53 +45,49 @@ class {{modelname}}Read({{modelname.lower()}}Base):
@classmethod
def list(
self,
config: Config,
where=None,
offset=0,
limit=None,
session: Session = get_session,
) -> {{modelname}}:
# with config.database.session as session:
with config.database.session as session:
statement = select({{modelname}})
if where != "None":
from sqlmodel import text
statement = select({{modelname}})
if where != "None" and where is not None:
from sqlmodel import text
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
{{modelname.lower()}}es = session.exec(statement).all()
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
{{modelname.lower()}}es = session.exec(statement).all()
return {{modelname.lower()}}es
class {{modelname}}Update(SQLModel):
# id is required to get the {{modelname.lower()}}
# id is required to update the {{modelname.lower()}}
id: int
# all other fields, must match the model, but with Optional default None
def update(self, config: Config) -> {{modelname}}:
with Session(config.database.engine) as session:
db_{{modelname.lower()}} = session.get({{modelname}}, self.id)
if not db_{{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
{{modelname.lower()}}_data = self.dict(exclude_unset=True)
for key, value in {{modelname.lower()}}_data.items():
if value is not None:
setattr(db_{{modelname.lower()}}, key, value)
session.add(db_{{modelname.lower()}})
session.commit()
session.refresh(db_{{modelname.lower()}})
return db_{{modelname.lower()}}
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{modelname.lower()}}")
def update(self) -> {{modelname}}:
r = httpx.patch(
f"{config.api_client.url}/{{modelname.lower()}}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
class {{modelname}}Delete(BaseModel):
id: int
def delete(self, config: Config) -> {{modelname}}:
config.init()
with Session(config.database.engine) as session:
{{modelname.lower()}} = session.get({{modelname}}, self.id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
session.delete({{modelname.lower()}})
session.commit()
return {"ok": True}
def delete(self) -> {{modelname}}:
r = httpx.delete(
f"{config.api_client.url}/{{modelname.lower()}}/{self.id}",
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {"ok": True}

View file

@ -0,0 +1,207 @@
from fastapi.testclient import TestClient
import pytest
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from learn_sql_model.api.app import app
from learn_sql_model.config import get_config, get_session
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
runner = CliRunner()
client = TestClient(app)
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session
app.dependency_overrides[get_session] = get_session_override
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_api_post(client: TestClient):
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
{{modelname.lower()}}_dict = {{modelname.lower()}}.dict()
response = client.post("/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {{modelname.lower()}}_dict})
response_{{modelname.lower()}} = {{modelname}}.parse_obj(response.json())
assert response.status_code == 200
assert response_{{modelname.lower()}}.name == "Steelman"
assert response_{{modelname.lower()}}.age == 25
def test_api_read_{{modelname.lower()}}es(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
{{modelname.lower()}}_2 = {{modelname}}(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
session.add({{modelname.lower()}}_1)
session.add({{modelname.lower()}}_2)
session.commit()
response = client.get("/{{modelname.lower()}}s/")
data = response.json()
assert response.status_code == 200
assert len(data) == 2
assert data[0]["name"] == {{modelname.lower()}}_1.name
assert data[0]["secret_name"] == {{modelname.lower()}}_1.secret_name
assert data[0]["age"] == {{modelname.lower()}}_1.age
assert data[0]["id"] == {{modelname.lower()}}_1.id
assert data[1]["name"] == {{modelname.lower()}}_2.name
assert data[1]["secret_name"] == {{modelname.lower()}}_2.secret_name
assert data[1]["age"] == {{modelname.lower()}}_2.age
assert data[1]["id"] == {{modelname.lower()}}_2.id
def test_api_read_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.get(f"/{{modelname.lower()}}/999")
assert response.status_code == 404
def test_api_read_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.get(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
data = response.json()
assert response.status_code == 200
assert data["name"] == {{modelname.lower()}}_1.name
assert data["secret_name"] == {{modelname.lower()}}_1.secret_name
assert data["age"] == {{modelname.lower()}}_1.age
assert data["id"] == {{modelname.lower()}}_1.id
def test_api_update_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.patch(
f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": {{modelname.lower()}}_1.id}}
)
data = response.json()
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson"
assert data["age"] is None
assert data["id"] == {{modelname.lower()}}_1.id
def test_api_update_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.patch(f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": 999}})
assert response.status_code == 404
def test_delete_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
{{modelname.lower()}}_in_db = session.get({{modelname}}, {{modelname.lower()}}_1.id)
assert response.status_code == 200
assert {{modelname.lower()}}_in_db is None
def test_delete_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/999")
assert response.status_code == 404
def test_config_memory(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
{{modelname.lower()}}es = session.exec(select({{modelname}})).all()
assert {{modelname.lower()}}.name == "Steelman"
assert {{modelname.lower()}}.age == 25
assert len({{modelname.lower()}}es) == 1
def test_cli_get(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "1"])
assert result.exit_code == 0
assert f"name='{{{modelname.lower()}}.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}.secret_name}'" in result.stdout
def test_cli_get_404(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "{{modelname}} not found"

View file

@ -1,15 +1,12 @@
import tempfile
from fastapi.testclient import TestClient
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from sqlmodel import SQLModel, Session, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from learn_sql_model.api.app import app
from learn_sql_model.cli.hero import hero_app
from learn_sql_model.config import Config, get_config
from learn_sql_model.config import get_config, get_session
from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.models.hero import Hero
@ -17,142 +14,193 @@ runner = CliRunner()
client = TestClient(app)
@pytest.fixture
def config() -> Config:
tmp_db = tempfile.NamedTemporaryFile(suffix=".db")
config = get_config({"database_url": f"sqlite:///{tmp_db.name}"})
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
config.database_url, connect_args={"check_same_thread": False}
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# breakpoint()
SQLModel.metadata.create_all(config.database.engine)
# def override_get_db():
# try:
# db = TestingSessionLocal()
# yield db
# finally:
# db.close()
def override_get_config():
return config
app.dependency_overrides[get_config] = override_get_config
yield config
# tmp_db automatically deletes here
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
def test_post_hero(config: Config) -> None:
hero = HeroFactory().build(name="Batman", age=50, id=1)
hero = hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
assert db_hero.age == 50
assert db_hero.name == "Batman"
@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session
app.dependency_overrides[get_session] = get_session_override
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_update_hero(config: Config) -> None:
hero = HeroFactory().build(name="Batman", age=50, id=1)
hero = hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
db_hero.name = "Superbman"
hero = db_hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
assert db_hero.age == 50
assert db_hero.name == "Superbman"
def test_cli_get(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
hero.post(config=config)
result = runner.invoke(
hero_app,
["get", "--id", 99, "--database-url", config.database_url],
)
assert result.exit_code == 0
db_hero = Hero().get(id=99, config=config)
assert db_hero.age == 25
assert db_hero.name == "Steelman"
def test_cli_create(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
result = runner.invoke(
hero_app,
[
"create",
*hero.flags(config=config),
"--database-url",
config.database_url,
],
)
assert result.exit_code == 0
db_hero = Hero().get(id=99, config=config)
assert db_hero.age == 25
assert db_hero.name == "Steelman"
def test_cli_populate(config):
result = runner.invoke(
hero_app,
[
"populate",
"--n",
10,
"--database-url",
config.database_url,
],
)
assert result.exit_code == 0
db_hero = Hero().get(config=config)
assert len(db_hero) == 10
def test_cli_populate_fails_prod(config):
result = runner.invoke(
hero_app,
["populate", "--n", 10, "--database-url", config.database_url, "--env", "prod"],
)
assert result.exit_code == 1
assert result.output.strip() == "populate is not supported in production"
def test_api_read(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
hero_id = hero.id
hero = hero.post(config=config)
response = client.get(f"/hero/{hero_id}")
assert response.status_code == 200
reponse_hero = Hero.parse_obj(response.json())
assert reponse_hero.id == hero_id
assert reponse_hero.name == "Steelman"
assert reponse_hero.age == 25
def test_api_post(config):
def test_api_post(client: TestClient):
hero = HeroFactory().build(name="Steelman", age=25)
hero_dict = hero.dict()
response = client.post("/hero/", json={"hero": hero_dict})
assert response.status_code == 200
response_hero = Hero.parse_obj(response.json())
db_hero = Hero().get(id=response_hero.id, config=config)
assert db_hero.name == "Steelman"
assert db_hero.age == 25
def test_api_read_all(config):
hero = HeroFactory().build(name="Mothman", age=25, id=99)
hero_id = hero.id
hero = hero.post(config=config)
response = client.get("/heros/")
assert response.status_code == 200
heros = response.json()
response_hero_json = [hero for hero in heros if hero["id"] == hero_id][0]
response_hero = Hero.parse_obj(response_hero_json)
assert response_hero.id == hero_id
assert response_hero.name == "Mothman"
assert response_hero.name == "Steelman"
assert response_hero.age == 25
def test_api_read_heroes(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
session.add(hero_1)
session.add(hero_2)
session.commit()
response = client.get("/heros/")
data = response.json()
assert response.status_code == 200
assert len(data) == 2
assert data[0]["name"] == hero_1.name
assert data[0]["secret_name"] == hero_1.secret_name
assert data[0]["age"] == hero_1.age
assert data[0]["id"] == hero_1.id
assert data[1]["name"] == hero_2.name
assert data[1]["secret_name"] == hero_2.secret_name
assert data[1]["age"] == hero_2.age
assert data[1]["id"] == hero_2.id
def test_api_read_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.get(f"/hero/999")
assert response.status_code == 404
def test_api_read_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.get(f"/hero/{hero_1.id}")
data = response.json()
assert response.status_code == 200
assert data["name"] == hero_1.name
assert data["secret_name"] == hero_1.secret_name
assert data["age"] == hero_1.age
assert data["id"] == hero_1.id
def test_api_update_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.patch(
f"/hero/", json={"hero": {"name": "Deadpuddle", "id": hero_1.id}}
)
data = response.json()
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson"
assert data["age"] is None
assert data["id"] == hero_1.id
def test_api_update_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.patch(f"/hero/", json={"hero": {"name": "Deadpuddle", "id": 999}})
assert response.status_code == 404
def test_delete_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.delete(f"/hero/{hero_1.id}")
hero_in_db = session.get(Hero, hero_1.id)
assert response.status_code == 200
assert hero_in_db is None
def test_delete_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.delete(f"/hero/999")
assert response.status_code == 404
def test_config_memory(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
heroes = session.exec(select(Hero)).all()
assert hero.name == "Steelman"
assert hero.age == 25
assert len(heroes) == 1
def test_cli_get(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "1"])
assert result.exit_code == 0
assert f"name='{hero.name}'" in result.stdout
assert f"secret_name='{hero.secret_name}'" in result.stdout
def test_cli_get_404(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "Hero not found"