This commit is contained in:
Waylon Walker 2023-06-09 16:04:58 -05:00
parent 1a0bf1adb9
commit c3db85a209
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
21 changed files with 647 additions and 658 deletions

View file

@ -2,6 +2,7 @@
from learn_sql_model.api.websocket_connection_manager import manager from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import Config from learn_sql_model.config import Config
from learn_sql_model.config import get_config from learn_sql_model.config import get_config
from learn_sql_model.config import get_session
from learn_sql_model.console import console from learn_sql_model.console import console
from learn_sql_model.database import get_database from learn_sql_model.database import get_database
from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.hero import HeroFactory

View file

@ -12,8 +12,6 @@ RUN groupadd -f -g ${SMOKE_GID} smoke && \
RUN mkdir /home/smoke && chown -R smoke:smoke /home/smoke && mkdir /src && chown smoke:smoke /src RUN mkdir /home/smoke && chown -R smoke:smoke /home/smoke && mkdir /src && chown smoke:smoke /src
WORKDIR /home/smoke WORKDIR /home/smoke
RUN apt update && \ RUN apt update && \
apt upgrade -y && \ apt upgrade -y && \
apt install -y \ apt install -y \

View file

@ -1,18 +1,9 @@
from typing import Annotated from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import SQLModel, Session
from fastapi import APIRouter, Depends
from sqlmodel import SQLModel
from learn_sql_model.api.user import oauth2_scheme
from learn_sql_model.api.websocket_connection_manager import manager from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import Config, get_config from learn_sql_model.config import get_config, get_session
from learn_sql_model.models.hero import ( from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, HeroUpdate
Hero,
HeroCreate,
HeroDelete,
HeroRead,
HeroUpdate,
)
hero_router = APIRouter() hero_router = APIRouter()
@ -22,52 +13,73 @@ def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine) SQLModel.metadata.create_all(get_config().database.engine)
@hero_router.get("/items/") @hero_router.get("/hero/{hero_id}")
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): async def get_hero(
return {"token": token} *,
session: Session = Depends(get_session),
hero_id: int,
@hero_router.get("/hero/{id}") ) -> HeroRead:
async def get_hero(id: int, config: Config = Depends(get_config)) -> Hero:
"get one hero" "get one hero"
return Hero().get(id=id, config=config) hero = session.get(Hero, hero_id)
if not hero:
raise HTTPException(status_code=404, detail="Hero not found")
@hero_router.get("/h/{id}") return hero
async def get_h(id: int, config: Config = Depends(get_config)) -> Hero:
"get one hero"
return Hero().get(id=id, config=config)
@hero_router.post("/hero/") @hero_router.post("/hero/")
async def post_hero(hero: HeroCreate) -> HeroRead: async def post_hero(
*,
session: Session = Depends(get_session),
hero: HeroCreate,
) -> HeroRead:
"read all the heros" "read all the heros"
config = get_config() db_hero = Hero.from_orm(hero)
hero = hero.post(config=config) session.add(db_hero)
session.commit()
session.refresh(db_hero)
await manager.broadcast({hero.json()}, id=1) await manager.broadcast({hero.json()}, id=1)
return hero return db_hero
@hero_router.patch("/hero/") @hero_router.patch("/hero/")
async def patch_hero(hero: HeroUpdate) -> HeroRead: async def patch_hero(
*,
session: Session = Depends(get_session),
hero: HeroUpdate,
) -> HeroRead:
"read all the heros" "read all the heros"
config = get_config() db_hero = session.get(Hero, hero.id)
hero = hero.update(config=config) if not db_hero:
raise HTTPException(status_code=404, detail="Hero not found")
for key, value in hero.dict(exclude_unset=True).items():
setattr(db_hero, key, value)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
await manager.broadcast({hero.json()}, id=1) await manager.broadcast({hero.json()}, id=1)
return hero return db_hero
@hero_router.delete("/hero/{hero_id}") @hero_router.delete("/hero/{hero_id}")
async def delete_hero(hero_id: int): async def delete_hero(
*,
session: Session = Depends(get_session),
hero_id: int,
):
"read all the heros" "read all the heros"
hero = HeroDelete(id=hero_id) hero = session.get(Hero, hero_id)
config = get_config() if not hero:
hero = hero.delete(config=config) raise HTTPException(status_code=404, detail="Hero not found")
session.delete(hero)
session.commit()
await manager.broadcast(f"deleted hero {hero_id}", id=1) await manager.broadcast(f"deleted hero {hero_id}", id=1)
return hero return {"ok": True}
@hero_router.get("/heros/") @hero_router.get("/heros/")
async def get_heros(config: Config = Depends(get_config)) -> list[Hero]: async def get_heros(
*,
session: Session = Depends(get_session),
) -> list[Hero]:
"get all heros" "get all heros"
return Hero().get(config=config) return HeroRead.list(session=session)

View file

@ -1,45 +0,0 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlmodel import SQLModel
from learn_sql_model.api.user import oauth2_scheme
from learn_sql_model.config import Config, get_config
from learn_sql_model.models.new import new
new_router = APIRouter()
@new_router.on_event("startup")
def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine)
@new_router.get("/items/")
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
return {"token": token}
@new_router.get("/new/{id}")
def get_new(id: int, config: Config = Depends(get_config)) -> new:
"get one new"
return new().get(id=id, config=config)
@new_router.get("/h/{id}")
def get_h(id: int, config: Config = Depends(get_config)) -> new:
"get one new"
return new().get(id=id, config=config)
@new_router.post("/new/")
def post_new(new: new, config: Config = Depends(get_config)) -> new:
"read all the news"
new.post(config=config)
return new
@new_router.get("/news/")
def get_news(config: Config = Depends(get_config)) -> list[new]:
"get all news"
return new().get(config=config)

View file

@ -69,4 +69,4 @@ async def websocket_endpoint(websocket: WebSocket):
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(websocket, id) manager.disconnect(websocket, id)
await manager.broadcast(f"Client #{client_id} left the chat", id) await manager.broadcast(f"Client #{id} left the chat", id)

View file

@ -1,3 +1,4 @@
import httpx
from rich.console import Console from rich.console import Console
import typer import typer
import uvicorn import uvicorn
@ -38,15 +39,11 @@ def status(
help="show the log messages", help="show the log messages",
), ),
): ):
import httpx
config = get_config() config = get_config()
host = config.api_server.host url = config.api_client.url
port = config.api_server.port
url = f"http://{host}:{port}/docs"
try: try:
r = httpx.get(url) r = httpx.get(url + "/docs")
if r.status_code == 200: if r.status_code == 200:
Console().print(f"[green]API: ([gold1]{url}[green]) is running") Console().print(f"[green]API: ([gold1]{url}[green]) is running")
else: else:
@ -59,7 +56,7 @@ def status(
Console().print( Console().print(
f"[green]database: ([gold1]{config.database.engine}[green]) is running" f"[green]database: ([gold1]{config.database.engine}[green]) is running"
) )
except Exception as e: except Exception:
Console().print( Console().print(
f"[red]database: ([gold1]{config.database.engine}[red]) is not running" f"[red]database: ([gold1]{config.database.engine}[red]) is not running"
) )

View file

@ -2,13 +2,11 @@ import sys
from typing import List, Optional, Union from typing import List, Optional, Union
from engorgio import engorgio from engorgio import engorgio
import httpx
from rich.console import Console from rich.console import Console
import typer import typer
from learn_sql_model.config import Config, get_config from learn_sql_model.config import get_config
from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.hero import ( from learn_sql_model.models.hero import (
Hero, Hero,
HeroCreate, HeroCreate,
@ -19,6 +17,8 @@ from learn_sql_model.models.hero import (
hero_app = typer.Typer() hero_app = typer.Typer()
config = get_config()
@hero_app.callback() @hero_app.callback()
def hero(): def hero():
@ -28,12 +28,10 @@ def hero():
@hero_app.command() @hero_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def get( def get(
id: Optional[int] = typer.Argument(default=None), hero_id: Optional[int] = typer.Argument(default=None),
config: Config = None,
) -> Union[Hero, List[Hero]]: ) -> Union[Hero, List[Hero]]:
"get one hero" "get one hero"
config.init() hero = HeroRead.get(id=hero_id)
hero = HeroRead.get(id=id, config=config)
Console().print(hero) Console().print(hero)
return hero return hero
@ -42,12 +40,11 @@ def get(
@engorgio(typer=True) @engorgio(typer=True)
def list( def list(
where: Optional[str] = None, where: Optional[str] = None,
config: Config = None,
offset: int = 0, offset: int = 0,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> Union[Hero, List[Hero]]: ) -> Union[Hero, List[Hero]]:
"get one hero" "list many heros"
hero = HeroRead.list(config=config, where=where, offset=offset, limit=limit) hero = HeroRead.list(where=where, offset=offset, limit=limit)
Console().print(hero) Console().print(hero)
return hero return hero
@ -56,65 +53,39 @@ def list(
@engorgio(typer=True) @engorgio(typer=True)
def create( def create(
hero: HeroCreate, hero: HeroCreate,
config: Config = None,
) -> Hero: ) -> Hero:
"read all the heros" "create one hero"
hero.post()
r = httpx.post(
f"{config.api_client.url}/hero/",
json=hero.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
# hero = hero.post(config=config)
# Console().print(hero)
# return hero
@hero_app.command() @hero_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def update( def update(
hero: HeroUpdate, hero: HeroUpdate,
config: Config = None,
) -> Hero: ) -> Hero:
"read all the heros" "update one hero"
r = httpx.patch( hero.update()
f"{config.api_client.url}/hero/",
json=hero.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
@hero_app.command() @hero_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def delete( def delete(
hero: HeroDelete, hero: HeroDelete,
config: Config = None,
) -> Hero: ) -> Hero:
"read all the heros" "delete a hero by id"
r = httpx.delete( hero.delete()
f"{config.api_client.url}/hero/{hero.id}",
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
@hero_app.command() @hero_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def populate( def populate(
hero: Hero,
n: int = 10, n: int = 10,
) -> Hero: ) -> Hero:
"read all the heros" "Create n number of heros"
config = get_config()
if config.env == "prod": if config.env == "prod":
Console().print("populate is not supported in production") Console().print("populate is not supported in production")
sys.exit(1) sys.exit(1)
for hero in HeroFactory().batch(n): for hero in HeroFactory().batch(n):
pet = PetFactory().build() hero = HeroCreate(**hero.dict())
hero.pet = pet hero.post()
Console().print(hero)
hero.post(config=config)

View file

@ -44,7 +44,6 @@ def create_revision(
prompt=True, prompt=True,
), ),
): ):
alembic_cfg = Config("alembic.ini") alembic_cfg = Config("alembic.ini")
alembic.command.revision( alembic.command.revision(
config=alembic_cfg, config=alembic_cfg,
@ -63,7 +62,6 @@ def checkout(
), ),
revision: str = typer.Option("head"), revision: str = typer.Option("head"),
): ):
alembic_cfg = Config("alembic.ini") alembic_cfg = Config("alembic.ini")
alembic.command.upgrade(config=alembic_cfg, revision="head") alembic.command.upgrade(config=alembic_cfg, revision="head")

View file

@ -1,107 +0,0 @@
import sys
from typing import List, Optional, Union
from engorgio import engorgio
from rich.console import Console
import typer
from learn_sql_model.config import Config, get_config
from learn_sql_model.factories.new import newFactory
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.new import (
new,
newCreate,
newDelete,
newRead,
newUpdate,
)
new_app = typer.Typer()
@new_app.callback()
def new():
"model cli"
@new_app.command()
@engorgio(typer=True)
def get(
id: Optional[int] = typer.Argument(default=None),
config: Config = None,
) -> Union[new, List[new]]:
"get one new"
config.init()
new = newRead.get(id=id, config=config)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def list(
where: Optional[str] = None,
config: Config = None,
offset: int = 0,
limit: Optional[int] = None,
) -> Union[new, List[new]]:
"get one new"
new = newRead.list(config=config, where=where, offset=offset, limit=limit)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def create(
new: newCreate,
config: Config = None,
) -> new:
"read all the news"
# config.init()
new = new.post(config=config)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def update(
new: newUpdate,
config: Config = None,
) -> new:
"read all the news"
new = new.update(config=config)
Console().print(new)
return new
@new_app.command()
@engorgio(typer=True)
def delete(
new: newDelete,
config: Config = None,
) -> new:
"read all the news"
# config.init()
new = new.delete(config=config)
return new
@new_app.command()
@engorgio(typer=True)
def populate(
new: new,
n: int = 10,
) -> new:
"read all the news"
config = get_config()
if config.env == "prod":
Console().print("populate is not supported in production")
sys.exit(1)
for new in newFactory().batch(n):
pet = PetFactory().build()
new.pet = pet
Console().print(new)
new.post(config=config)

View file

@ -30,7 +30,6 @@ class ApiClient(BaseModel):
class Database: class Database:
def __init__(self, config: "Config" = None) -> None: def __init__(self, config: "Config" = None) -> None:
if config is None: if config is None:
self.config = get_config() self.config = get_config()
else: else:
self.config = config self.config = config
@ -71,17 +70,26 @@ class Config(BaseSettings):
def get_database(config: Config = None) -> Database: def get_database(config: Config = None) -> Database:
if config is None: if config is None:
config = get_config() config = get_config()
return Database(config) return Database(config)
def get_config(overrides: dict = {}) -> Config:
raw_config = load("learn_sql_model")
config = Config(**raw_config, **overrides)
return config
def get_session(config: Config = None) -> "Session":
with Session(config.database.engine) as session:
yield session
async def reset_db_state(config: Config = None) -> None: async def reset_db_state(config: Config = None) -> None:
if config is None: if config is None:
config = get_config() config = get_config()
config.database.db._state._state.set(db_state_default.copy()) config.database.db._state._state.set(config.database.db_state_default.copy())
config.database.db._state.reset() config.database.db._state.reset()
@ -96,7 +104,4 @@ def get_db(config: Config = None, reset_db_state=Depends(reset_db_state)):
config.database.db.close() config.database.db.close()
def get_config(overrides: dict = {}) -> Config: config = get_config()
raw_config = load("learn_sql_model")
config = Config(**raw_config, **overrides)
return config

View file

@ -1,14 +0,0 @@
from faker import Faker
from polyfactory.factories.pydantic_factory import ModelFactory
from learn_sql_model.models.new import new
class newFactory(ModelFactory[new]):
__model__ = new
__faker__ = Faker(locale="en_US")
__set_as_default_factory_for_type__ = True
id = None
__random_seed__ = 10

View file

@ -21,7 +21,6 @@ class FastModel(SQLModel):
def post(self, config: "Config" = None) -> None: def post(self, config: "Config" = None) -> None:
if config is None: if config is None:
config = get_config() config = get_config()
self.pre_post() self.pre_post()
@ -36,7 +35,6 @@ class FastModel(SQLModel):
self, id: int = None, config: "Config" = None, where=None self, id: int = None, config: "Config" = None, where=None
) -> Optional["FastModel"]: ) -> Optional["FastModel"]:
if config is None: if config is None:
config = get_config() config = get_config()
self.pre_get() self.pre_get()

View file

@ -1,10 +1,11 @@
from typing import Optional from typing import Optional
from fastapi import HTTPException from fastapi import HTTPException
import httpx
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select from sqlmodel import Field, Relationship, SQLModel, Session, select
from learn_sql_model.config import Config from learn_sql_model.config import config, get_session
from learn_sql_model.models.pet import Pet from learn_sql_model.models.pet import Pet
@ -25,14 +26,13 @@ class Hero(HeroBase, table=True):
class HeroCreate(HeroBase): class HeroCreate(HeroBase):
... ...
def post(self, config: Config) -> Hero: def post(self) -> Hero:
config.init() r = httpx.post(
with Session(config.database.engine) as session: f"{config.api_client.url}/hero/",
db_hero = Hero.from_orm(self) json=self.dict(),
session.add(db_hero) )
session.commit() if r.status_code != 200:
session.refresh(db_hero) raise RuntimeError(f"{r.status_code}:\n {r.text}")
return db_hero
class HeroRead(HeroBase): class HeroRead(HeroBase):
@ -41,10 +41,8 @@ class HeroRead(HeroBase):
@classmethod @classmethod
def get( def get(
cls, cls,
config: Config,
id: int, id: int,
) -> Hero: ) -> Hero:
with config.database.session as session: with config.database.session as session:
hero = session.get(Hero, id) hero = session.get(Hero, id)
if not hero: if not hero:
@ -54,25 +52,25 @@ class HeroRead(HeroBase):
@classmethod @classmethod
def list( def list(
self, self,
config: Config,
where=None, where=None,
offset=0, offset=0,
limit=None, limit=None,
session: Session = get_session,
) -> Hero: ) -> Hero:
# with config.database.session as session:
with config.database.session as session: statement = select(Hero)
statement = select(Hero) if where != "None" and where is not None:
if where != "None": from sqlmodel import text
from sqlmodel import text
statement = statement.where(text(where)) statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit) statement = statement.offset(offset).limit(limit)
heroes = session.exec(statement).all() heroes = session.exec(statement).all()
return heroes return heroes
class HeroUpdate(SQLModel): class HeroUpdate(SQLModel):
# id is required to get the hero # id is required to update the hero
id: int id: int
# all other fields, must match the model, but with Optional default None # all other fields, must match the model, but with Optional default None
@ -84,30 +82,22 @@ class HeroUpdate(SQLModel):
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="hero") pet: Optional[Pet] = Relationship(back_populates="hero")
def update(self, config: Config) -> Hero: def update(self) -> Hero:
with Session(config.database.engine) as session: r = httpx.patch(
db_hero = session.get(Hero, self.id) f"{config.api_client.url}/hero/",
if not db_hero: json=self.dict(),
raise HTTPException(status_code=404, detail="Hero not found") )
hero_data = self.dict(exclude_unset=True) if r.status_code != 200:
for key, value in hero_data.items(): raise RuntimeError(f"{r.status_code}:\n {r.text}")
if value is not None:
setattr(db_hero, key, value)
session.add(db_hero)
session.commit()
session.refresh(db_hero)
return db_hero
class HeroDelete(BaseModel): class HeroDelete(BaseModel):
id: int id: int
def delete(self, config: Config) -> Hero: def delete(self) -> Hero:
config.init() r = httpx.delete(
with Session(config.database.engine) as session: f"{config.api_client.url}/hero/{self.id}",
hero = session.get(Hero, self.id) )
if not hero: if r.status_code != 200:
raise HTTPException(status_code=404, detail="Hero not found") raise RuntimeError(f"{r.status_code}:\n {r.text}")
session.delete(hero) return {"ok": True}
session.commit()
return {"ok": True}

View file

@ -1,99 +0,0 @@
from typing import Optional
from fastapi import HTTPException
from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select
from learn_sql_model.config import Config
from learn_sql_model.models.pet import Pet
class newBase(SQLModel, table=False):
class new(newBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class newCreate(newBase):
...
def post(self, config: Config) -> new:
config.init()
with Session(config.database.engine) as session:
db_new = new.from_orm(self)
session.add(db_new)
session.commit()
session.refresh(db_new)
return db_new
class newRead(newBase):
id: int
@classmethod
def get(
cls,
config: Config,
id: int,
) -> new:
with config.database.session as session:
new = session.get(new, id)
if not new:
raise HTTPException(status_code=404, detail="new not found")
return new
@classmethod
def list(
self,
config: Config,
where=None,
offset=0,
limit=None,
) -> new:
with config.database.session as session:
statement = select(new)
if where != "None":
from sqlmodel import text
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
newes = session.exec(statement).all()
return newes
class newUpdate(SQLModel):
# id is required to get the new
id: int
# all other fields, must match the model, but with Optional default None
def update(self, config: Config) -> new:
with Session(config.database.engine) as session:
db_new = session.get(new, self.id)
if not db_new:
raise HTTPException(status_code=404, detail="new not found")
new_data = self.dict(exclude_unset=True)
for key, value in new_data.items():
if value is not None:
setattr(db_new, key, value)
session.add(db_new)
session.commit()
session.refresh(db_new)
return db_new
class newDelete(BaseModel):
id: int
def delete(self, config: Config) -> new:
config.init()
with Session(config.database.engine) as session:
new = session.get(new, self.id)
if not new:
raise HTTPException(status_code=404, detail="new not found")
session.delete(new)
session.commit()
return {"ok": True}

View file

@ -77,14 +77,17 @@ dependencies = [
[tool.hatch.envs.default.scripts] [tool.hatch.envs.default.scripts]
test = "coverage run -m pytest" test = "coverage run -m pytest"
cov = "coverage-rich report" cov = "coverage-rich report"
test-cov = ['test', 'cov'] cov-erase = "coverage erase"
lint = "ruff learn_sql_model" lint = "ruff learn_sql_model"
format = "black learn_sql_model" format = "black learn_sql_model"
format-check = "black --check learn_sql_model" format-check = "black --check learn_sql_model"
fix_ruff = "ruff --fix learn_sql_model"
fix = ['format', 'fix_ruff']
build-docs = "markata build" build-docs = "markata build"
lint-test = [ lint-test = [
"lint", "lint",
"format-check", "format-check",
"cov-erase",
"test", "test",
"cov", "cov",
] ]

View file

@ -18,3 +18,9 @@ def read_heroes(hero: Hero) -> list[Hero]:
def read_heros() -> list[Hero]: def read_heros() -> list[Hero]:
"read all the heros" "read all the heros"
return Hero.get() return Hero.get()
@app.patch("/heros/")
def update_heros() -> list[Hero]:
"read all the heros"
return Hero.get()

View file

@ -1,11 +1,9 @@
from typing import Annotated from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import SQLModel, Session
from fastapi import APIRouter, Depends from learn_sql_model.api.websocket_connection_manager import manager
from sqlmodel import SQLModel from learn_sql_model.config import get_config, get_session
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}, {{modelname}}Create, {{modelname}}Read, {{modelname}}Update
from learn_sql_model.api.user import oauth2_scheme
from learn_sql_model.config import Config, get_config
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
{{modelname.lower()}}_router = APIRouter() {{modelname.lower()}}_router = APIRouter()
@ -15,31 +13,74 @@ def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine) SQLModel.metadata.create_all(get_config().database.engine)
@{{modelname.lower()}}_router.get("/items/") @{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): async def get_{{modelname.lower()}}(
return {"token": token} *,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{id}") ) -> {{modelname}}Read:
def get_{{modelname.lower()}}(id: int, config: Config = Depends(get_config)) -> {{modelname}}:
"get one {{modelname.lower()}}" "get one {{modelname.lower()}}"
return {{modelname}}().get(id=id, config=config) {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
@{{modelname.lower()}}_router.get("/h/{id}")
def get_h(id: int, config: Config = Depends(get_config)) -> {{modelname}}:
"get one {{modelname.lower()}}"
return {{modelname}}().get(id=id, config=config)
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
def post_{{modelname.lower()}}({{modelname.lower()}}: {{modelname}}, config: Config = Depends(get_config)) -> {{modelname.lower()}}:
"read all the {{modelname.lower()}}s"
{{modelname.lower()}}.post(config=config)
return {{modelname.lower()}} return {{modelname.lower()}}
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
async def post_{{modelname.lower()}}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Create,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = {{modelname}}.from_orm({{modelname.lower()}})
session.add(db_{{modelname.lower()}})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
@{{modelname.lower()}}_router.patch("/{{modelname.lower()}}/")
async def patch_{{modelname.lower()}}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Update,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
if not db_{{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
for key, value in {{modelname.lower()}}.dict(exclude_unset=True).items():
setattr(db_{{modelname.lower()}}, key, value)
session.add(db_{{modelname.lower()}})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
@{{modelname.lower()}}_router.delete("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def delete_{{modelname.lower()}}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
):
"read all the {{modelname.lower()}}s"
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
session.delete({{modelname.lower()}})
session.commit()
await manager.broadcast(f"deleted {{modelname.lower()}} {{{modelname.lower()}}_id}", id=1)
return {"ok": True}
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/") @{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/")
def get_{{modelname.lower()}}s(config: Config = Depends(get_config)) -> list[{{modelname}}]: async def get_{{modelname.lower()}}s(
*,
session: Session = Depends(get_session),
) -> list[{{modelname}}]:
"get all {{modelname.lower()}}s" "get all {{modelname.lower()}}s"
return {{modelname}}().get(config=config) return {{modelname}}Read.list(session=session)

View file

@ -5,9 +5,8 @@ from engorgio import engorgio
from rich.console import Console from rich.console import Console
import typer import typer
from learn_sql_model.config import Config, get_config from learn_sql_model.config import get_config
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.{{modelname.lower()}} import ( from learn_sql_model.models.{{modelname.lower()}} import (
{{modelname}}, {{modelname}},
{{modelname}}Create, {{modelname}}Create,
@ -18,6 +17,8 @@ from learn_sql_model.models.{{modelname.lower()}} import (
{{modelname.lower()}}_app = typer.Typer() {{modelname.lower()}}_app = typer.Typer()
config = get_config()
@{{modelname.lower()}}_app.callback() @{{modelname.lower()}}_app.callback()
def {{modelname.lower()}}(): def {{modelname.lower()}}():
@ -27,12 +28,10 @@ def {{modelname.lower()}}():
@{{modelname.lower()}}_app.command() @{{modelname.lower()}}_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def get( def get(
id: Optional[int] = typer.Argument(default=None), {{modelname.lower()}}_id: Optional[int] = typer.Argument(default=None),
config: Config = None, ) -> Union[{{modelname}}, List[{{modelname}}]]:
) -> Union[{{modelname}}, List[{{modelname.lower()}}]]:
"get one {{modelname.lower()}}" "get one {{modelname.lower()}}"
config.init() {{modelname.lower()}} = {{modelname}}Read.get(id={{modelname.lower()}}_id)
{{modelname.lower()}} = {{modelname}}Read.get(id=id, config=config)
Console().print({{modelname.lower()}}) Console().print({{modelname.lower()}})
return {{modelname.lower()}} return {{modelname.lower()}}
@ -41,12 +40,11 @@ def get(
@engorgio(typer=True) @engorgio(typer=True)
def list( def list(
where: Optional[str] = None, where: Optional[str] = None,
config: Config = None,
offset: int = 0, offset: int = 0,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> Union[{{modelname}}, List[{{modelname.lower()}}]]: ) -> Union[{{modelname}}, List[{{modelname}}]]:
"get one {{modelname.lower()}}" "list many {{modelname.lower()}}s"
{{modelname.lower()}} = {{modelname}}Read.list(config=config, where=where, offset=offset, limit=limit) {{modelname.lower()}} = {{modelname}}Read.list(where=where, offset=offset, limit=limit)
Console().print({{modelname.lower()}}) Console().print({{modelname.lower()}})
return {{modelname.lower()}} return {{modelname.lower()}}
@ -55,53 +53,40 @@ def list(
@engorgio(typer=True) @engorgio(typer=True)
def create( def create(
{{modelname.lower()}}: {{modelname}}Create, {{modelname.lower()}}: {{modelname}}Create,
config: Config = None,
) -> {{modelname}}: ) -> {{modelname}}:
"read all the {{modelname.lower()}}s" "create one {{modelname.lower()}}"
# config.init() {{modelname.lower()}}.post()
{{modelname.lower()}} = {{modelname.lower()}}.post(config=config)
Console().print({{modelname.lower()}})
return {{modelname.lower()}}
@{{modelname.lower()}}_app.command() @{{modelname.lower()}}_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def update( def update(
{{modelname.lower()}}: {{modelname}}Update, {{modelname.lower()}}: {{modelname}}Update,
config: Config = None,
) -> {{modelname}}: ) -> {{modelname}}:
"read all the {{modelname.lower()}}s" "update one {{modelname.lower()}}"
{{modelname.lower()}} = {{modelname.lower()}}.update(config=config) {{modelname.lower()}}.update()
Console().print({{modelname.lower()}})
return {{modelname.lower()}}
@{{modelname.lower()}}_app.command() @{{modelname.lower()}}_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def delete( def delete(
{{modelname.lower()}}: {{modelname}}Delete, {{modelname.lower()}}: {{modelname}}Delete,
config: Config = None,
) -> {{modelname}}: ) -> {{modelname}}:
"read all the {{modelname.lower()}}s" "delete a {{modelname.lower()}} by id"
# config.init() {{modelname.lower()}}.delete()
{{modelname.lower()}} = {{modelname.lower()}}.delete(config=config)
return {{modelname.lower()}}
@{{modelname.lower()}}_app.command() @{{modelname.lower()}}_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def populate( def populate(
{{modelname.lower()}}: {{modelname}},
n: int = 10, n: int = 10,
) -> {{modelname}}: ) -> {{modelname}}:
"read all the {{modelname.lower()}}s" "Create n number of {{modelname.lower()}}s"
config = get_config()
if config.env == "prod": if config.env == "prod":
Console().print("populate is not supported in production") Console().print("populate is not supported in production")
sys.exit(1) sys.exit(1)
for {{modelname.lower()}} in {{modelname}}Factory().batch(n): for {{modelname.lower()}} in {{modelname}}Factory().batch(n):
pet = PetFactory().build() {{modelname.lower()}} = {{modelname}}Create(**{{modelname.lower()}}.dict())
{{modelname.lower()}}.pet = pet {{modelname.lower()}}.post()
Console().print({{modelname.lower()}})
{{modelname.lower()}}.post(config=config)

View file

@ -1,43 +1,41 @@
from typing import Optional from typing import Optional
from fastapi import HTTPException from fastapi import Depends, HTTPException
import httpx
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select from sqlmodel import Field, Relationship, SQLModel, Session, select
from learn_sql_model.config import Config from learn_sql_model.config import config, get_session
from learn_sql_model.models.pet import Pet from learn_sql_model.models.pet import Pet
class {{modelname}}Base(SQLModel, table=False): class {{modelname}}Base(SQLModel, table=False):
class {{modelname}}({{modelname.lower()}}Base, table=True): class {{modelname}}({{modelname}}Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
class {{modelname}}Create({{modelname.lower()}}Base): class {{modelname}}Create({{modelname}}Base):
... ...
def post(self, config: Config) -> {{modelname}}: def post(self) -> {{modelname}}:
config.init() r = httpx.post(
with Session(config.database.engine) as session: f"{config.api_client.url}/{{modelname.lower()}}/",
db_{{modelname.lower()}} = {{modelname}}.from_orm(self) json=self.dict(),
session.add(db_{{modelname.lower()}}) )
session.commit() if r.status_code != 200:
session.refresh(db_{{modelname.lower()}}) raise RuntimeError(f"{r.status_code}:\n {r.text}")
return db_{{modelname.lower()}}
class {{modelname}}Read({{modelname.lower()}}Base): class {{modelname}}Read({{modelname}}Base):
id: int id: int
@classmethod @classmethod
def get( def get(
cls, cls,
config: Config,
id: int, id: int,
) -> {{modelname}}: ) -> {{modelname}}:
with config.database.session as session: with config.database.session as session:
{{modelname.lower()}} = session.get({{modelname}}, id) {{modelname.lower()}} = session.get({{modelname}}, id)
if not {{modelname.lower()}}: if not {{modelname.lower()}}:
@ -47,53 +45,49 @@ class {{modelname}}Read({{modelname.lower()}}Base):
@classmethod @classmethod
def list( def list(
self, self,
config: Config,
where=None, where=None,
offset=0, offset=0,
limit=None, limit=None,
session: Session = get_session,
) -> {{modelname}}: ) -> {{modelname}}:
# with config.database.session as session:
with config.database.session as session: statement = select({{modelname}})
statement = select({{modelname}}) if where != "None" and where is not None:
if where != "None": from sqlmodel import text
from sqlmodel import text
statement = statement.where(text(where)) statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit) statement = statement.offset(offset).limit(limit)
{{modelname.lower()}}es = session.exec(statement).all() {{modelname.lower()}}es = session.exec(statement).all()
return {{modelname.lower()}}es return {{modelname.lower()}}es
class {{modelname}}Update(SQLModel): class {{modelname}}Update(SQLModel):
# id is required to get the {{modelname.lower()}} # id is required to update the {{modelname.lower()}}
id: int id: int
# all other fields, must match the model, but with Optional default None # all other fields, must match the model, but with Optional default None
def update(self, config: Config) -> {{modelname}}: pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
with Session(config.database.engine) as session: pet: Optional[Pet] = Relationship(back_populates="{{modelname.lower()}}")
db_{{modelname.lower()}} = session.get({{modelname}}, self.id)
if not db_{{modelname.lower()}}: def update(self) -> {{modelname}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found") r = httpx.patch(
{{modelname.lower()}}_data = self.dict(exclude_unset=True) f"{config.api_client.url}/{{modelname.lower()}}/",
for key, value in {{modelname.lower()}}_data.items(): json=self.dict(),
if value is not None: )
setattr(db_{{modelname.lower()}}, key, value) if r.status_code != 200:
session.add(db_{{modelname.lower()}}) raise RuntimeError(f"{r.status_code}:\n {r.text}")
session.commit()
session.refresh(db_{{modelname.lower()}})
return db_{{modelname.lower()}}
class {{modelname}}Delete(BaseModel): class {{modelname}}Delete(BaseModel):
id: int id: int
def delete(self, config: Config) -> {{modelname}}: def delete(self) -> {{modelname}}:
config.init() r = httpx.delete(
with Session(config.database.engine) as session: f"{config.api_client.url}/{{modelname.lower()}}/{self.id}",
{{modelname.lower()}} = session.get({{modelname}}, self.id) )
if not {{modelname.lower()}}: if r.status_code != 200:
raise HTTPException(status_code=404, detail="{{modelname}} not found") raise RuntimeError(f"{r.status_code}:\n {r.text}")
session.delete({{modelname.lower()}}) return {"ok": True}
session.commit()
return {"ok": True}

View file

@ -0,0 +1,207 @@
from fastapi.testclient import TestClient
import pytest
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from learn_sql_model.api.app import app
from learn_sql_model.config import get_config, get_session
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
runner = CliRunner()
client = TestClient(app)
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session
app.dependency_overrides[get_session] = get_session_override
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_api_post(client: TestClient):
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
{{modelname.lower()}}_dict = {{modelname.lower()}}.dict()
response = client.post("/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {{modelname.lower()}}_dict})
response_{{modelname.lower()}} = {{modelname}}.parse_obj(response.json())
assert response.status_code == 200
assert response_{{modelname.lower()}}.name == "Steelman"
assert response_{{modelname.lower()}}.age == 25
def test_api_read_{{modelname.lower()}}es(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
{{modelname.lower()}}_2 = {{modelname}}(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
session.add({{modelname.lower()}}_1)
session.add({{modelname.lower()}}_2)
session.commit()
response = client.get("/{{modelname.lower()}}s/")
data = response.json()
assert response.status_code == 200
assert len(data) == 2
assert data[0]["name"] == {{modelname.lower()}}_1.name
assert data[0]["secret_name"] == {{modelname.lower()}}_1.secret_name
assert data[0]["age"] == {{modelname.lower()}}_1.age
assert data[0]["id"] == {{modelname.lower()}}_1.id
assert data[1]["name"] == {{modelname.lower()}}_2.name
assert data[1]["secret_name"] == {{modelname.lower()}}_2.secret_name
assert data[1]["age"] == {{modelname.lower()}}_2.age
assert data[1]["id"] == {{modelname.lower()}}_2.id
def test_api_read_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.get(f"/{{modelname.lower()}}/999")
assert response.status_code == 404
def test_api_read_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.get(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
data = response.json()
assert response.status_code == 200
assert data["name"] == {{modelname.lower()}}_1.name
assert data["secret_name"] == {{modelname.lower()}}_1.secret_name
assert data["age"] == {{modelname.lower()}}_1.age
assert data["id"] == {{modelname.lower()}}_1.id
def test_api_update_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.patch(
f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": {{modelname.lower()}}_1.id}}
)
data = response.json()
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson"
assert data["age"] is None
assert data["id"] == {{modelname.lower()}}_1.id
def test_api_update_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.patch(f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": 999}})
assert response.status_code == 404
def test_delete_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
{{modelname.lower()}}_in_db = session.get({{modelname}}, {{modelname.lower()}}_1.id)
assert response.status_code == 200
assert {{modelname.lower()}}_in_db is None
def test_delete_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/999")
assert response.status_code == 404
def test_config_memory(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
{{modelname.lower()}}es = session.exec(select({{modelname}})).all()
assert {{modelname.lower()}}.name == "Steelman"
assert {{modelname.lower()}}.age == 25
assert len({{modelname.lower()}}es) == 1
def test_cli_get(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "1"])
assert result.exit_code == 0
assert f"name='{{{modelname.lower()}}.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}.secret_name}'" in result.stdout
def test_cli_get_404(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "{{modelname}} not found"

View file

@ -1,15 +1,12 @@
import tempfile
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, Session, select
from sqlmodel import SQLModel from sqlmodel.pool import StaticPool
from typer.testing import CliRunner from typer.testing import CliRunner
from learn_sql_model.api.app import app from learn_sql_model.api.app import app
from learn_sql_model.cli.hero import hero_app from learn_sql_model.config import get_config, get_session
from learn_sql_model.config import Config, get_config
from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.models.hero import Hero from learn_sql_model.models.hero import Hero
@ -17,142 +14,193 @@ runner = CliRunner()
client = TestClient(app) client = TestClient(app)
@pytest.fixture @pytest.fixture(name="session")
def config() -> Config: def session_fixture():
tmp_db = tempfile.NamedTemporaryFile(suffix=".db")
config = get_config({"database_url": f"sqlite:///{tmp_db.name}"})
engine = create_engine( engine = create_engine(
config.database_url, connect_args={"check_same_thread": False} "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
) )
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SQLModel.metadata.create_all(engine)
with Session(engine) as session:
# breakpoint() yield session
SQLModel.metadata.create_all(config.database.engine)
# def override_get_db():
# try:
# db = TestingSessionLocal()
# yield db
# finally:
# db.close()
def override_get_config():
return config
app.dependency_overrides[get_config] = override_get_config
yield config
# tmp_db automatically deletes here
def test_post_hero(config: Config) -> None: @pytest.fixture(name="client")
hero = HeroFactory().build(name="Batman", age=50, id=1) def client_fixture(session: Session):
hero = hero.post(config=config) def get_session_override():
db_hero = Hero().get(id=1, config=config) return session
assert db_hero.age == 50
assert db_hero.name == "Batman" app.dependency_overrides[get_session] = get_session_override
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_update_hero(config: Config) -> None: def test_api_post(client: TestClient):
hero = HeroFactory().build(name="Batman", age=50, id=1)
hero = hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
db_hero.name = "Superbman"
hero = db_hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
assert db_hero.age == 50
assert db_hero.name == "Superbman"
def test_cli_get(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
hero.post(config=config)
result = runner.invoke(
hero_app,
["get", "--id", 99, "--database-url", config.database_url],
)
assert result.exit_code == 0
db_hero = Hero().get(id=99, config=config)
assert db_hero.age == 25
assert db_hero.name == "Steelman"
def test_cli_create(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
result = runner.invoke(
hero_app,
[
"create",
*hero.flags(config=config),
"--database-url",
config.database_url,
],
)
assert result.exit_code == 0
db_hero = Hero().get(id=99, config=config)
assert db_hero.age == 25
assert db_hero.name == "Steelman"
def test_cli_populate(config):
result = runner.invoke(
hero_app,
[
"populate",
"--n",
10,
"--database-url",
config.database_url,
],
)
assert result.exit_code == 0
db_hero = Hero().get(config=config)
assert len(db_hero) == 10
def test_cli_populate_fails_prod(config):
result = runner.invoke(
hero_app,
["populate", "--n", 10, "--database-url", config.database_url, "--env", "prod"],
)
assert result.exit_code == 1
assert result.output.strip() == "populate is not supported in production"
def test_api_read(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
hero_id = hero.id
hero = hero.post(config=config)
response = client.get(f"/hero/{hero_id}")
assert response.status_code == 200
reponse_hero = Hero.parse_obj(response.json())
assert reponse_hero.id == hero_id
assert reponse_hero.name == "Steelman"
assert reponse_hero.age == 25
def test_api_post(config):
hero = HeroFactory().build(name="Steelman", age=25) hero = HeroFactory().build(name="Steelman", age=25)
hero_dict = hero.dict() hero_dict = hero.dict()
response = client.post("/hero/", json={"hero": hero_dict}) response = client.post("/hero/", json={"hero": hero_dict})
assert response.status_code == 200
response_hero = Hero.parse_obj(response.json()) response_hero = Hero.parse_obj(response.json())
db_hero = Hero().get(id=response_hero.id, config=config)
assert db_hero.name == "Steelman"
assert db_hero.age == 25
def test_api_read_all(config):
hero = HeroFactory().build(name="Mothman", age=25, id=99)
hero_id = hero.id
hero = hero.post(config=config)
response = client.get("/heros/")
assert response.status_code == 200 assert response.status_code == 200
heros = response.json() assert response_hero.name == "Steelman"
response_hero_json = [hero for hero in heros if hero["id"] == hero_id][0]
response_hero = Hero.parse_obj(response_hero_json)
assert response_hero.id == hero_id
assert response_hero.name == "Mothman"
assert response_hero.age == 25 assert response_hero.age == 25
def test_api_read_heroes(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
session.add(hero_1)
session.add(hero_2)
session.commit()
response = client.get("/heros/")
data = response.json()
assert response.status_code == 200
assert len(data) == 2
assert data[0]["name"] == hero_1.name
assert data[0]["secret_name"] == hero_1.secret_name
assert data[0]["age"] == hero_1.age
assert data[0]["id"] == hero_1.id
assert data[1]["name"] == hero_2.name
assert data[1]["secret_name"] == hero_2.secret_name
assert data[1]["age"] == hero_2.age
assert data[1]["id"] == hero_2.id
def test_api_read_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.get(f"/hero/999")
assert response.status_code == 404
def test_api_read_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.get(f"/hero/{hero_1.id}")
data = response.json()
assert response.status_code == 200
assert data["name"] == hero_1.name
assert data["secret_name"] == hero_1.secret_name
assert data["age"] == hero_1.age
assert data["id"] == hero_1.id
def test_api_update_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.patch(
f"/hero/", json={"hero": {"name": "Deadpuddle", "id": hero_1.id}}
)
data = response.json()
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson"
assert data["age"] is None
assert data["id"] == hero_1.id
def test_api_update_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.patch(f"/hero/", json={"hero": {"name": "Deadpuddle", "id": 999}})
assert response.status_code == 404
def test_delete_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.delete(f"/hero/{hero_1.id}")
hero_in_db = session.get(Hero, hero_1.id)
assert response.status_code == 200
assert hero_in_db is None
def test_delete_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.delete(f"/hero/999")
assert response.status_code == 404
def test_config_memory(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
heroes = session.exec(select(Hero)).all()
assert hero.name == "Steelman"
assert hero.age == 25
assert len(heroes) == 1
def test_cli_get(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "1"])
assert result.exit_code == 0
assert f"name='{hero.name}'" in result.stdout
assert f"secret_name='{hero.secret_name}'" in result.stdout
def test_cli_get_404(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "Hero not found"