wip
This commit is contained in:
parent
1a0bf1adb9
commit
c3db85a209
21 changed files with 647 additions and 658 deletions
1
.pyflyby
1
.pyflyby
|
|
@ -2,6 +2,7 @@
|
||||||
from learn_sql_model.api.websocket_connection_manager import manager
|
from learn_sql_model.api.websocket_connection_manager import manager
|
||||||
from learn_sql_model.config import Config
|
from learn_sql_model.config import Config
|
||||||
from learn_sql_model.config import get_config
|
from learn_sql_model.config import get_config
|
||||||
|
from learn_sql_model.config import get_session
|
||||||
from learn_sql_model.console import console
|
from learn_sql_model.console import console
|
||||||
from learn_sql_model.database import get_database
|
from learn_sql_model.database import get_database
|
||||||
from learn_sql_model.factories.hero import HeroFactory
|
from learn_sql_model.factories.hero import HeroFactory
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,6 @@ RUN groupadd -f -g ${SMOKE_GID} smoke && \
|
||||||
RUN mkdir /home/smoke && chown -R smoke:smoke /home/smoke && mkdir /src && chown smoke:smoke /src
|
RUN mkdir /home/smoke && chown -R smoke:smoke /home/smoke && mkdir /src && chown smoke:smoke /src
|
||||||
WORKDIR /home/smoke
|
WORKDIR /home/smoke
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
RUN apt update && \
|
RUN apt update && \
|
||||||
apt upgrade -y && \
|
apt upgrade -y && \
|
||||||
apt install -y \
|
apt install -y \
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,9 @@
|
||||||
from typing import Annotated
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlmodel import SQLModel, Session
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from sqlmodel import SQLModel
|
|
||||||
|
|
||||||
from learn_sql_model.api.user import oauth2_scheme
|
|
||||||
from learn_sql_model.api.websocket_connection_manager import manager
|
from learn_sql_model.api.websocket_connection_manager import manager
|
||||||
from learn_sql_model.config import Config, get_config
|
from learn_sql_model.config import get_config, get_session
|
||||||
from learn_sql_model.models.hero import (
|
from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, HeroUpdate
|
||||||
Hero,
|
|
||||||
HeroCreate,
|
|
||||||
HeroDelete,
|
|
||||||
HeroRead,
|
|
||||||
HeroUpdate,
|
|
||||||
)
|
|
||||||
|
|
||||||
hero_router = APIRouter()
|
hero_router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -22,52 +13,73 @@ def on_startup() -> None:
|
||||||
SQLModel.metadata.create_all(get_config().database.engine)
|
SQLModel.metadata.create_all(get_config().database.engine)
|
||||||
|
|
||||||
|
|
||||||
@hero_router.get("/items/")
|
@hero_router.get("/hero/{hero_id}")
|
||||||
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
|
async def get_hero(
|
||||||
return {"token": token}
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
hero_id: int,
|
||||||
@hero_router.get("/hero/{id}")
|
) -> HeroRead:
|
||||||
async def get_hero(id: int, config: Config = Depends(get_config)) -> Hero:
|
|
||||||
"get one hero"
|
"get one hero"
|
||||||
return Hero().get(id=id, config=config)
|
hero = session.get(Hero, hero_id)
|
||||||
|
if not hero:
|
||||||
|
raise HTTPException(status_code=404, detail="Hero not found")
|
||||||
@hero_router.get("/h/{id}")
|
return hero
|
||||||
async def get_h(id: int, config: Config = Depends(get_config)) -> Hero:
|
|
||||||
"get one hero"
|
|
||||||
return Hero().get(id=id, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
@hero_router.post("/hero/")
|
@hero_router.post("/hero/")
|
||||||
async def post_hero(hero: HeroCreate) -> HeroRead:
|
async def post_hero(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
hero: HeroCreate,
|
||||||
|
) -> HeroRead:
|
||||||
"read all the heros"
|
"read all the heros"
|
||||||
config = get_config()
|
db_hero = Hero.from_orm(hero)
|
||||||
hero = hero.post(config=config)
|
session.add(db_hero)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(db_hero)
|
||||||
await manager.broadcast({hero.json()}, id=1)
|
await manager.broadcast({hero.json()}, id=1)
|
||||||
return hero
|
return db_hero
|
||||||
|
|
||||||
|
|
||||||
@hero_router.patch("/hero/")
|
@hero_router.patch("/hero/")
|
||||||
async def patch_hero(hero: HeroUpdate) -> HeroRead:
|
async def patch_hero(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
hero: HeroUpdate,
|
||||||
|
) -> HeroRead:
|
||||||
"read all the heros"
|
"read all the heros"
|
||||||
config = get_config()
|
db_hero = session.get(Hero, hero.id)
|
||||||
hero = hero.update(config=config)
|
if not db_hero:
|
||||||
|
raise HTTPException(status_code=404, detail="Hero not found")
|
||||||
|
for key, value in hero.dict(exclude_unset=True).items():
|
||||||
|
setattr(db_hero, key, value)
|
||||||
|
session.add(db_hero)
|
||||||
|
session.commit()
|
||||||
|
session.refresh(db_hero)
|
||||||
await manager.broadcast({hero.json()}, id=1)
|
await manager.broadcast({hero.json()}, id=1)
|
||||||
return hero
|
return db_hero
|
||||||
|
|
||||||
|
|
||||||
@hero_router.delete("/hero/{hero_id}")
|
@hero_router.delete("/hero/{hero_id}")
|
||||||
async def delete_hero(hero_id: int):
|
async def delete_hero(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
hero_id: int,
|
||||||
|
):
|
||||||
"read all the heros"
|
"read all the heros"
|
||||||
hero = HeroDelete(id=hero_id)
|
hero = session.get(Hero, hero_id)
|
||||||
config = get_config()
|
if not hero:
|
||||||
hero = hero.delete(config=config)
|
raise HTTPException(status_code=404, detail="Hero not found")
|
||||||
|
session.delete(hero)
|
||||||
|
session.commit()
|
||||||
await manager.broadcast(f"deleted hero {hero_id}", id=1)
|
await manager.broadcast(f"deleted hero {hero_id}", id=1)
|
||||||
return hero
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@hero_router.get("/heros/")
|
@hero_router.get("/heros/")
|
||||||
async def get_heros(config: Config = Depends(get_config)) -> list[Hero]:
|
async def get_heros(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
) -> list[Hero]:
|
||||||
"get all heros"
|
"get all heros"
|
||||||
return Hero().get(config=config)
|
return HeroRead.list(session=session)
|
||||||
|
|
|
||||||
|
|
@ -1,45 +0,0 @@
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from sqlmodel import SQLModel
|
|
||||||
|
|
||||||
from learn_sql_model.api.user import oauth2_scheme
|
|
||||||
from learn_sql_model.config import Config, get_config
|
|
||||||
from learn_sql_model.models.new import new
|
|
||||||
|
|
||||||
new_router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
@new_router.on_event("startup")
|
|
||||||
def on_startup() -> None:
|
|
||||||
SQLModel.metadata.create_all(get_config().database.engine)
|
|
||||||
|
|
||||||
|
|
||||||
@new_router.get("/items/")
|
|
||||||
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
|
|
||||||
return {"token": token}
|
|
||||||
|
|
||||||
|
|
||||||
@new_router.get("/new/{id}")
|
|
||||||
def get_new(id: int, config: Config = Depends(get_config)) -> new:
|
|
||||||
"get one new"
|
|
||||||
return new().get(id=id, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
@new_router.get("/h/{id}")
|
|
||||||
def get_h(id: int, config: Config = Depends(get_config)) -> new:
|
|
||||||
"get one new"
|
|
||||||
return new().get(id=id, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
@new_router.post("/new/")
|
|
||||||
def post_new(new: new, config: Config = Depends(get_config)) -> new:
|
|
||||||
"read all the news"
|
|
||||||
new.post(config=config)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
@new_router.get("/news/")
|
|
||||||
def get_news(config: Config = Depends(get_config)) -> list[new]:
|
|
||||||
"get all news"
|
|
||||||
return new().get(config=config)
|
|
||||||
|
|
@ -69,4 +69,4 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
manager.disconnect(websocket, id)
|
manager.disconnect(websocket, id)
|
||||||
await manager.broadcast(f"Client #{client_id} left the chat", id)
|
await manager.broadcast(f"Client #{id} left the chat", id)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import httpx
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
import typer
|
import typer
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
@ -38,15 +39,11 @@ def status(
|
||||||
help="show the log messages",
|
help="show the log messages",
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
import httpx
|
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
host = config.api_server.host
|
url = config.api_client.url
|
||||||
port = config.api_server.port
|
|
||||||
url = f"http://{host}:{port}/docs"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = httpx.get(url)
|
r = httpx.get(url + "/docs")
|
||||||
if r.status_code == 200:
|
if r.status_code == 200:
|
||||||
Console().print(f"[green]API: ([gold1]{url}[green]) is running")
|
Console().print(f"[green]API: ([gold1]{url}[green]) is running")
|
||||||
else:
|
else:
|
||||||
|
|
@ -59,7 +56,7 @@ def status(
|
||||||
Console().print(
|
Console().print(
|
||||||
f"[green]database: ([gold1]{config.database.engine}[green]) is running"
|
f"[green]database: ([gold1]{config.database.engine}[green]) is running"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
Console().print(
|
Console().print(
|
||||||
f"[red]database: ([gold1]{config.database.engine}[red]) is not running"
|
f"[red]database: ([gold1]{config.database.engine}[red]) is not running"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,11 @@ import sys
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from engorgio import engorgio
|
from engorgio import engorgio
|
||||||
import httpx
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from learn_sql_model.config import Config, get_config
|
from learn_sql_model.config import get_config
|
||||||
from learn_sql_model.factories.hero import HeroFactory
|
from learn_sql_model.factories.hero import HeroFactory
|
||||||
from learn_sql_model.factories.pet import PetFactory
|
|
||||||
from learn_sql_model.models.hero import (
|
from learn_sql_model.models.hero import (
|
||||||
Hero,
|
Hero,
|
||||||
HeroCreate,
|
HeroCreate,
|
||||||
|
|
@ -19,6 +17,8 @@ from learn_sql_model.models.hero import (
|
||||||
|
|
||||||
hero_app = typer.Typer()
|
hero_app = typer.Typer()
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
|
||||||
@hero_app.callback()
|
@hero_app.callback()
|
||||||
def hero():
|
def hero():
|
||||||
|
|
@ -28,12 +28,10 @@ def hero():
|
||||||
@hero_app.command()
|
@hero_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def get(
|
def get(
|
||||||
id: Optional[int] = typer.Argument(default=None),
|
hero_id: Optional[int] = typer.Argument(default=None),
|
||||||
config: Config = None,
|
|
||||||
) -> Union[Hero, List[Hero]]:
|
) -> Union[Hero, List[Hero]]:
|
||||||
"get one hero"
|
"get one hero"
|
||||||
config.init()
|
hero = HeroRead.get(id=hero_id)
|
||||||
hero = HeroRead.get(id=id, config=config)
|
|
||||||
Console().print(hero)
|
Console().print(hero)
|
||||||
return hero
|
return hero
|
||||||
|
|
||||||
|
|
@ -42,12 +40,11 @@ def get(
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def list(
|
def list(
|
||||||
where: Optional[str] = None,
|
where: Optional[str] = None,
|
||||||
config: Config = None,
|
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
) -> Union[Hero, List[Hero]]:
|
) -> Union[Hero, List[Hero]]:
|
||||||
"get one hero"
|
"list many heros"
|
||||||
hero = HeroRead.list(config=config, where=where, offset=offset, limit=limit)
|
hero = HeroRead.list(where=where, offset=offset, limit=limit)
|
||||||
Console().print(hero)
|
Console().print(hero)
|
||||||
return hero
|
return hero
|
||||||
|
|
||||||
|
|
@ -56,65 +53,39 @@ def list(
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def create(
|
def create(
|
||||||
hero: HeroCreate,
|
hero: HeroCreate,
|
||||||
config: Config = None,
|
|
||||||
) -> Hero:
|
) -> Hero:
|
||||||
"read all the heros"
|
"create one hero"
|
||||||
|
hero.post()
|
||||||
r = httpx.post(
|
|
||||||
f"{config.api_client.url}/hero/",
|
|
||||||
json=hero.dict(),
|
|
||||||
)
|
|
||||||
if r.status_code != 200:
|
|
||||||
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
|
||||||
|
|
||||||
# hero = hero.post(config=config)
|
|
||||||
# Console().print(hero)
|
|
||||||
# return hero
|
|
||||||
|
|
||||||
|
|
||||||
@hero_app.command()
|
@hero_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def update(
|
def update(
|
||||||
hero: HeroUpdate,
|
hero: HeroUpdate,
|
||||||
config: Config = None,
|
|
||||||
) -> Hero:
|
) -> Hero:
|
||||||
"read all the heros"
|
"update one hero"
|
||||||
r = httpx.patch(
|
hero.update()
|
||||||
f"{config.api_client.url}/hero/",
|
|
||||||
json=hero.dict(),
|
|
||||||
)
|
|
||||||
if r.status_code != 200:
|
|
||||||
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
|
||||||
|
|
||||||
|
|
||||||
@hero_app.command()
|
@hero_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def delete(
|
def delete(
|
||||||
hero: HeroDelete,
|
hero: HeroDelete,
|
||||||
config: Config = None,
|
|
||||||
) -> Hero:
|
) -> Hero:
|
||||||
"read all the heros"
|
"delete a hero by id"
|
||||||
r = httpx.delete(
|
hero.delete()
|
||||||
f"{config.api_client.url}/hero/{hero.id}",
|
|
||||||
)
|
|
||||||
if r.status_code != 200:
|
|
||||||
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
|
||||||
|
|
||||||
|
|
||||||
@hero_app.command()
|
@hero_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def populate(
|
def populate(
|
||||||
hero: Hero,
|
|
||||||
n: int = 10,
|
n: int = 10,
|
||||||
) -> Hero:
|
) -> Hero:
|
||||||
"read all the heros"
|
"Create n number of heros"
|
||||||
config = get_config()
|
|
||||||
if config.env == "prod":
|
if config.env == "prod":
|
||||||
Console().print("populate is not supported in production")
|
Console().print("populate is not supported in production")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
for hero in HeroFactory().batch(n):
|
for hero in HeroFactory().batch(n):
|
||||||
pet = PetFactory().build()
|
hero = HeroCreate(**hero.dict())
|
||||||
hero.pet = pet
|
hero.post()
|
||||||
Console().print(hero)
|
|
||||||
hero.post(config=config)
|
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,6 @@ def create_revision(
|
||||||
prompt=True,
|
prompt=True,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
|
||||||
alembic_cfg = Config("alembic.ini")
|
alembic_cfg = Config("alembic.ini")
|
||||||
alembic.command.revision(
|
alembic.command.revision(
|
||||||
config=alembic_cfg,
|
config=alembic_cfg,
|
||||||
|
|
@ -63,7 +62,6 @@ def checkout(
|
||||||
),
|
),
|
||||||
revision: str = typer.Option("head"),
|
revision: str = typer.Option("head"),
|
||||||
):
|
):
|
||||||
|
|
||||||
alembic_cfg = Config("alembic.ini")
|
alembic_cfg = Config("alembic.ini")
|
||||||
alembic.command.upgrade(config=alembic_cfg, revision="head")
|
alembic.command.upgrade(config=alembic_cfg, revision="head")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,107 +0,0 @@
|
||||||
import sys
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
from engorgio import engorgio
|
|
||||||
from rich.console import Console
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from learn_sql_model.config import Config, get_config
|
|
||||||
from learn_sql_model.factories.new import newFactory
|
|
||||||
from learn_sql_model.factories.pet import PetFactory
|
|
||||||
from learn_sql_model.models.new import (
|
|
||||||
new,
|
|
||||||
newCreate,
|
|
||||||
newDelete,
|
|
||||||
newRead,
|
|
||||||
newUpdate,
|
|
||||||
)
|
|
||||||
|
|
||||||
new_app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@new_app.callback()
|
|
||||||
def new():
|
|
||||||
"model cli"
|
|
||||||
|
|
||||||
|
|
||||||
@new_app.command()
|
|
||||||
@engorgio(typer=True)
|
|
||||||
def get(
|
|
||||||
id: Optional[int] = typer.Argument(default=None),
|
|
||||||
config: Config = None,
|
|
||||||
) -> Union[new, List[new]]:
|
|
||||||
"get one new"
|
|
||||||
config.init()
|
|
||||||
new = newRead.get(id=id, config=config)
|
|
||||||
Console().print(new)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
@new_app.command()
|
|
||||||
@engorgio(typer=True)
|
|
||||||
def list(
|
|
||||||
where: Optional[str] = None,
|
|
||||||
config: Config = None,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
) -> Union[new, List[new]]:
|
|
||||||
"get one new"
|
|
||||||
new = newRead.list(config=config, where=where, offset=offset, limit=limit)
|
|
||||||
Console().print(new)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
@new_app.command()
|
|
||||||
@engorgio(typer=True)
|
|
||||||
def create(
|
|
||||||
new: newCreate,
|
|
||||||
config: Config = None,
|
|
||||||
) -> new:
|
|
||||||
"read all the news"
|
|
||||||
# config.init()
|
|
||||||
new = new.post(config=config)
|
|
||||||
Console().print(new)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
@new_app.command()
|
|
||||||
@engorgio(typer=True)
|
|
||||||
def update(
|
|
||||||
new: newUpdate,
|
|
||||||
config: Config = None,
|
|
||||||
) -> new:
|
|
||||||
"read all the news"
|
|
||||||
new = new.update(config=config)
|
|
||||||
Console().print(new)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
@new_app.command()
|
|
||||||
@engorgio(typer=True)
|
|
||||||
def delete(
|
|
||||||
new: newDelete,
|
|
||||||
config: Config = None,
|
|
||||||
) -> new:
|
|
||||||
"read all the news"
|
|
||||||
# config.init()
|
|
||||||
new = new.delete(config=config)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
@new_app.command()
|
|
||||||
@engorgio(typer=True)
|
|
||||||
def populate(
|
|
||||||
new: new,
|
|
||||||
n: int = 10,
|
|
||||||
) -> new:
|
|
||||||
"read all the news"
|
|
||||||
config = get_config()
|
|
||||||
if config.env == "prod":
|
|
||||||
Console().print("populate is not supported in production")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
for new in newFactory().batch(n):
|
|
||||||
pet = PetFactory().build()
|
|
||||||
new.pet = pet
|
|
||||||
Console().print(new)
|
|
||||||
new.post(config=config)
|
|
||||||
|
|
@ -30,7 +30,6 @@ class ApiClient(BaseModel):
|
||||||
class Database:
|
class Database:
|
||||||
def __init__(self, config: "Config" = None) -> None:
|
def __init__(self, config: "Config" = None) -> None:
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
||||||
self.config = get_config()
|
self.config = get_config()
|
||||||
else:
|
else:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
@ -71,17 +70,26 @@ class Config(BaseSettings):
|
||||||
|
|
||||||
|
|
||||||
def get_database(config: Config = None) -> Database:
|
def get_database(config: Config = None) -> Database:
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
return Database(config)
|
return Database(config)
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(overrides: dict = {}) -> Config:
|
||||||
|
raw_config = load("learn_sql_model")
|
||||||
|
config = Config(**raw_config, **overrides)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_session(config: Config = None) -> "Session":
|
||||||
|
with Session(config.database.engine) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
async def reset_db_state(config: Config = None) -> None:
|
async def reset_db_state(config: Config = None) -> None:
|
||||||
if config is None:
|
if config is None:
|
||||||
config = get_config()
|
config = get_config()
|
||||||
config.database.db._state._state.set(db_state_default.copy())
|
config.database.db._state._state.set(config.database.db_state_default.copy())
|
||||||
config.database.db._state.reset()
|
config.database.db._state.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -96,7 +104,4 @@ def get_db(config: Config = None, reset_db_state=Depends(reset_db_state)):
|
||||||
config.database.db.close()
|
config.database.db.close()
|
||||||
|
|
||||||
|
|
||||||
def get_config(overrides: dict = {}) -> Config:
|
config = get_config()
|
||||||
raw_config = load("learn_sql_model")
|
|
||||||
config = Config(**raw_config, **overrides)
|
|
||||||
return config
|
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
from faker import Faker
|
|
||||||
from polyfactory.factories.pydantic_factory import ModelFactory
|
|
||||||
|
|
||||||
from learn_sql_model.models.new import new
|
|
||||||
|
|
||||||
|
|
||||||
class newFactory(ModelFactory[new]):
|
|
||||||
__model__ = new
|
|
||||||
__faker__ = Faker(locale="en_US")
|
|
||||||
__set_as_default_factory_for_type__ = True
|
|
||||||
id = None
|
|
||||||
|
|
||||||
__random_seed__ = 10
|
|
||||||
|
|
||||||
|
|
@ -21,7 +21,6 @@ class FastModel(SQLModel):
|
||||||
|
|
||||||
def post(self, config: "Config" = None) -> None:
|
def post(self, config: "Config" = None) -> None:
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
self.pre_post()
|
self.pre_post()
|
||||||
|
|
@ -36,7 +35,6 @@ class FastModel(SQLModel):
|
||||||
self, id: int = None, config: "Config" = None, where=None
|
self, id: int = None, config: "Config" = None, where=None
|
||||||
) -> Optional["FastModel"]:
|
) -> Optional["FastModel"]:
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
self.pre_get()
|
self.pre_get()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
import httpx
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import Field, Relationship, SQLModel, Session, select
|
from sqlmodel import Field, Relationship, SQLModel, Session, select
|
||||||
|
|
||||||
from learn_sql_model.config import Config
|
from learn_sql_model.config import config, get_session
|
||||||
from learn_sql_model.models.pet import Pet
|
from learn_sql_model.models.pet import Pet
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -25,14 +26,13 @@ class Hero(HeroBase, table=True):
|
||||||
class HeroCreate(HeroBase):
|
class HeroCreate(HeroBase):
|
||||||
...
|
...
|
||||||
|
|
||||||
def post(self, config: Config) -> Hero:
|
def post(self) -> Hero:
|
||||||
config.init()
|
r = httpx.post(
|
||||||
with Session(config.database.engine) as session:
|
f"{config.api_client.url}/hero/",
|
||||||
db_hero = Hero.from_orm(self)
|
json=self.dict(),
|
||||||
session.add(db_hero)
|
)
|
||||||
session.commit()
|
if r.status_code != 200:
|
||||||
session.refresh(db_hero)
|
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
||||||
return db_hero
|
|
||||||
|
|
||||||
|
|
||||||
class HeroRead(HeroBase):
|
class HeroRead(HeroBase):
|
||||||
|
|
@ -41,10 +41,8 @@ class HeroRead(HeroBase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(
|
def get(
|
||||||
cls,
|
cls,
|
||||||
config: Config,
|
|
||||||
id: int,
|
id: int,
|
||||||
) -> Hero:
|
) -> Hero:
|
||||||
|
|
||||||
with config.database.session as session:
|
with config.database.session as session:
|
||||||
hero = session.get(Hero, id)
|
hero = session.get(Hero, id)
|
||||||
if not hero:
|
if not hero:
|
||||||
|
|
@ -54,25 +52,25 @@ class HeroRead(HeroBase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def list(
|
def list(
|
||||||
self,
|
self,
|
||||||
config: Config,
|
|
||||||
where=None,
|
where=None,
|
||||||
offset=0,
|
offset=0,
|
||||||
limit=None,
|
limit=None,
|
||||||
|
session: Session = get_session,
|
||||||
) -> Hero:
|
) -> Hero:
|
||||||
|
# with config.database.session as session:
|
||||||
|
|
||||||
with config.database.session as session:
|
statement = select(Hero)
|
||||||
statement = select(Hero)
|
if where != "None" and where is not None:
|
||||||
if where != "None":
|
from sqlmodel import text
|
||||||
from sqlmodel import text
|
|
||||||
|
|
||||||
statement = statement.where(text(where))
|
statement = statement.where(text(where))
|
||||||
statement = statement.offset(offset).limit(limit)
|
statement = statement.offset(offset).limit(limit)
|
||||||
heroes = session.exec(statement).all()
|
heroes = session.exec(statement).all()
|
||||||
return heroes
|
return heroes
|
||||||
|
|
||||||
|
|
||||||
class HeroUpdate(SQLModel):
|
class HeroUpdate(SQLModel):
|
||||||
# id is required to get the hero
|
# id is required to update the hero
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
# all other fields, must match the model, but with Optional default None
|
# all other fields, must match the model, but with Optional default None
|
||||||
|
|
@ -84,30 +82,22 @@ class HeroUpdate(SQLModel):
|
||||||
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
|
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
|
||||||
pet: Optional[Pet] = Relationship(back_populates="hero")
|
pet: Optional[Pet] = Relationship(back_populates="hero")
|
||||||
|
|
||||||
def update(self, config: Config) -> Hero:
|
def update(self) -> Hero:
|
||||||
with Session(config.database.engine) as session:
|
r = httpx.patch(
|
||||||
db_hero = session.get(Hero, self.id)
|
f"{config.api_client.url}/hero/",
|
||||||
if not db_hero:
|
json=self.dict(),
|
||||||
raise HTTPException(status_code=404, detail="Hero not found")
|
)
|
||||||
hero_data = self.dict(exclude_unset=True)
|
if r.status_code != 200:
|
||||||
for key, value in hero_data.items():
|
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
||||||
if value is not None:
|
|
||||||
setattr(db_hero, key, value)
|
|
||||||
session.add(db_hero)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(db_hero)
|
|
||||||
return db_hero
|
|
||||||
|
|
||||||
|
|
||||||
class HeroDelete(BaseModel):
|
class HeroDelete(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
def delete(self, config: Config) -> Hero:
|
def delete(self) -> Hero:
|
||||||
config.init()
|
r = httpx.delete(
|
||||||
with Session(config.database.engine) as session:
|
f"{config.api_client.url}/hero/{self.id}",
|
||||||
hero = session.get(Hero, self.id)
|
)
|
||||||
if not hero:
|
if r.status_code != 200:
|
||||||
raise HTTPException(status_code=404, detail="Hero not found")
|
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
||||||
session.delete(hero)
|
return {"ok": True}
|
||||||
session.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
|
||||||
|
|
@ -1,99 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlmodel import Field, Relationship, SQLModel, Session, select
|
|
||||||
|
|
||||||
from learn_sql_model.config import Config
|
|
||||||
from learn_sql_model.models.pet import Pet
|
|
||||||
|
|
||||||
|
|
||||||
class newBase(SQLModel, table=False):
|
|
||||||
|
|
||||||
|
|
||||||
class new(newBase, table=True):
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
|
|
||||||
|
|
||||||
class newCreate(newBase):
|
|
||||||
...
|
|
||||||
|
|
||||||
def post(self, config: Config) -> new:
|
|
||||||
config.init()
|
|
||||||
with Session(config.database.engine) as session:
|
|
||||||
db_new = new.from_orm(self)
|
|
||||||
session.add(db_new)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(db_new)
|
|
||||||
return db_new
|
|
||||||
|
|
||||||
|
|
||||||
class newRead(newBase):
|
|
||||||
id: int
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(
|
|
||||||
cls,
|
|
||||||
config: Config,
|
|
||||||
id: int,
|
|
||||||
) -> new:
|
|
||||||
|
|
||||||
with config.database.session as session:
|
|
||||||
new = session.get(new, id)
|
|
||||||
if not new:
|
|
||||||
raise HTTPException(status_code=404, detail="new not found")
|
|
||||||
return new
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def list(
|
|
||||||
self,
|
|
||||||
config: Config,
|
|
||||||
where=None,
|
|
||||||
offset=0,
|
|
||||||
limit=None,
|
|
||||||
) -> new:
|
|
||||||
|
|
||||||
with config.database.session as session:
|
|
||||||
statement = select(new)
|
|
||||||
if where != "None":
|
|
||||||
from sqlmodel import text
|
|
||||||
|
|
||||||
statement = statement.where(text(where))
|
|
||||||
statement = statement.offset(offset).limit(limit)
|
|
||||||
newes = session.exec(statement).all()
|
|
||||||
return newes
|
|
||||||
|
|
||||||
|
|
||||||
class newUpdate(SQLModel):
|
|
||||||
# id is required to get the new
|
|
||||||
id: int
|
|
||||||
|
|
||||||
# all other fields, must match the model, but with Optional default None
|
|
||||||
|
|
||||||
def update(self, config: Config) -> new:
|
|
||||||
with Session(config.database.engine) as session:
|
|
||||||
db_new = session.get(new, self.id)
|
|
||||||
if not db_new:
|
|
||||||
raise HTTPException(status_code=404, detail="new not found")
|
|
||||||
new_data = self.dict(exclude_unset=True)
|
|
||||||
for key, value in new_data.items():
|
|
||||||
if value is not None:
|
|
||||||
setattr(db_new, key, value)
|
|
||||||
session.add(db_new)
|
|
||||||
session.commit()
|
|
||||||
session.refresh(db_new)
|
|
||||||
return db_new
|
|
||||||
|
|
||||||
|
|
||||||
class newDelete(BaseModel):
|
|
||||||
id: int
|
|
||||||
|
|
||||||
def delete(self, config: Config) -> new:
|
|
||||||
config.init()
|
|
||||||
with Session(config.database.engine) as session:
|
|
||||||
new = session.get(new, self.id)
|
|
||||||
if not new:
|
|
||||||
raise HTTPException(status_code=404, detail="new not found")
|
|
||||||
session.delete(new)
|
|
||||||
session.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
@ -77,14 +77,17 @@ dependencies = [
|
||||||
[tool.hatch.envs.default.scripts]
|
[tool.hatch.envs.default.scripts]
|
||||||
test = "coverage run -m pytest"
|
test = "coverage run -m pytest"
|
||||||
cov = "coverage-rich report"
|
cov = "coverage-rich report"
|
||||||
test-cov = ['test', 'cov']
|
cov-erase = "coverage erase"
|
||||||
lint = "ruff learn_sql_model"
|
lint = "ruff learn_sql_model"
|
||||||
format = "black learn_sql_model"
|
format = "black learn_sql_model"
|
||||||
format-check = "black --check learn_sql_model"
|
format-check = "black --check learn_sql_model"
|
||||||
|
fix_ruff = "ruff --fix learn_sql_model"
|
||||||
|
fix = ['format', 'fix_ruff']
|
||||||
build-docs = "markata build"
|
build-docs = "markata build"
|
||||||
lint-test = [
|
lint-test = [
|
||||||
"lint",
|
"lint",
|
||||||
"format-check",
|
"format-check",
|
||||||
|
"cov-erase",
|
||||||
"test",
|
"test",
|
||||||
"cov",
|
"cov",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -18,3 +18,9 @@ def read_heroes(hero: Hero) -> list[Hero]:
|
||||||
def read_heros() -> list[Hero]:
|
def read_heros() -> list[Hero]:
|
||||||
"read all the heros"
|
"read all the heros"
|
||||||
return Hero.get()
|
return Hero.get()
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch("/heros/")
|
||||||
|
def update_heros() -> list[Hero]:
|
||||||
|
"read all the heros"
|
||||||
|
return Hero.get()
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
from typing import Annotated
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlmodel import SQLModel, Session
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from learn_sql_model.api.websocket_connection_manager import manager
|
||||||
from sqlmodel import SQLModel
|
from learn_sql_model.config import get_config, get_session
|
||||||
|
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}, {{modelname}}Create, {{modelname}}Read, {{modelname}}Update
|
||||||
from learn_sql_model.api.user import oauth2_scheme
|
|
||||||
from learn_sql_model.config import Config, get_config
|
|
||||||
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
|
|
||||||
|
|
||||||
{{modelname.lower()}}_router = APIRouter()
|
{{modelname.lower()}}_router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -15,31 +13,74 @@ def on_startup() -> None:
|
||||||
SQLModel.metadata.create_all(get_config().database.engine)
|
SQLModel.metadata.create_all(get_config().database.engine)
|
||||||
|
|
||||||
|
|
||||||
@{{modelname.lower()}}_router.get("/items/")
|
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
|
||||||
async def read_items(token: Annotated[str, Depends(oauth2_scheme)]):
|
async def get_{{modelname.lower()}}(
|
||||||
return {"token": token}
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
{{modelname.lower()}}_id: int,
|
||||||
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{id}")
|
) -> {{modelname}}Read:
|
||||||
def get_{{modelname.lower()}}(id: int, config: Config = Depends(get_config)) -> {{modelname}}:
|
|
||||||
"get one {{modelname.lower()}}"
|
"get one {{modelname.lower()}}"
|
||||||
return {{modelname}}().get(id=id, config=config)
|
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
|
||||||
|
if not {{modelname.lower()}}:
|
||||||
|
raise HTTPException(status_code=404, detail="{{modelname}} not found")
|
||||||
@{{modelname.lower()}}_router.get("/h/{id}")
|
|
||||||
def get_h(id: int, config: Config = Depends(get_config)) -> {{modelname}}:
|
|
||||||
"get one {{modelname.lower()}}"
|
|
||||||
return {{modelname}}().get(id=id, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
|
|
||||||
def post_{{modelname.lower()}}({{modelname.lower()}}: {{modelname}}, config: Config = Depends(get_config)) -> {{modelname.lower()}}:
|
|
||||||
"read all the {{modelname.lower()}}s"
|
|
||||||
{{modelname.lower()}}.post(config=config)
|
|
||||||
return {{modelname.lower()}}
|
return {{modelname.lower()}}
|
||||||
|
|
||||||
|
|
||||||
|
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
|
||||||
|
async def post_{{modelname.lower()}}(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
{{modelname.lower()}}: {{modelname}}Create,
|
||||||
|
) -> {{modelname}}Read:
|
||||||
|
"read all the {{modelname.lower()}}s"
|
||||||
|
db_{{modelname.lower()}} = {{modelname}}.from_orm({{modelname.lower()}})
|
||||||
|
session.add(db_{{modelname.lower()}})
|
||||||
|
session.commit()
|
||||||
|
session.refresh(db_{{modelname.lower()}})
|
||||||
|
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
|
||||||
|
return db_{{modelname.lower()}}
|
||||||
|
|
||||||
|
|
||||||
|
@{{modelname.lower()}}_router.patch("/{{modelname.lower()}}/")
|
||||||
|
async def patch_{{modelname.lower()}}(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
{{modelname.lower()}}: {{modelname}}Update,
|
||||||
|
) -> {{modelname}}Read:
|
||||||
|
"read all the {{modelname.lower()}}s"
|
||||||
|
db_{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
|
||||||
|
if not db_{{modelname.lower()}}:
|
||||||
|
raise HTTPException(status_code=404, detail="{{modelname}} not found")
|
||||||
|
for key, value in {{modelname.lower()}}.dict(exclude_unset=True).items():
|
||||||
|
setattr(db_{{modelname.lower()}}, key, value)
|
||||||
|
session.add(db_{{modelname.lower()}})
|
||||||
|
session.commit()
|
||||||
|
session.refresh(db_{{modelname.lower()}})
|
||||||
|
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
|
||||||
|
return db_{{modelname.lower()}}
|
||||||
|
|
||||||
|
|
||||||
|
@{{modelname.lower()}}_router.delete("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
|
||||||
|
async def delete_{{modelname.lower()}}(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
{{modelname.lower()}}_id: int,
|
||||||
|
):
|
||||||
|
"read all the {{modelname.lower()}}s"
|
||||||
|
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
|
||||||
|
if not {{modelname.lower()}}:
|
||||||
|
raise HTTPException(status_code=404, detail="{{modelname}} not found")
|
||||||
|
session.delete({{modelname.lower()}})
|
||||||
|
session.commit()
|
||||||
|
await manager.broadcast(f"deleted {{modelname.lower()}} {{{modelname.lower()}}_id}", id=1)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/")
|
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/")
|
||||||
def get_{{modelname.lower()}}s(config: Config = Depends(get_config)) -> list[{{modelname}}]:
|
async def get_{{modelname.lower()}}s(
|
||||||
|
*,
|
||||||
|
session: Session = Depends(get_session),
|
||||||
|
) -> list[{{modelname}}]:
|
||||||
"get all {{modelname.lower()}}s"
|
"get all {{modelname.lower()}}s"
|
||||||
return {{modelname}}().get(config=config)
|
return {{modelname}}Read.list(session=session)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,8 @@ from engorgio import engorgio
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from learn_sql_model.config import Config, get_config
|
from learn_sql_model.config import get_config
|
||||||
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
|
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
|
||||||
from learn_sql_model.factories.pet import PetFactory
|
|
||||||
from learn_sql_model.models.{{modelname.lower()}} import (
|
from learn_sql_model.models.{{modelname.lower()}} import (
|
||||||
{{modelname}},
|
{{modelname}},
|
||||||
{{modelname}}Create,
|
{{modelname}}Create,
|
||||||
|
|
@ -18,6 +17,8 @@ from learn_sql_model.models.{{modelname.lower()}} import (
|
||||||
|
|
||||||
{{modelname.lower()}}_app = typer.Typer()
|
{{modelname.lower()}}_app = typer.Typer()
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
|
||||||
|
|
||||||
@{{modelname.lower()}}_app.callback()
|
@{{modelname.lower()}}_app.callback()
|
||||||
def {{modelname.lower()}}():
|
def {{modelname.lower()}}():
|
||||||
|
|
@ -27,12 +28,10 @@ def {{modelname.lower()}}():
|
||||||
@{{modelname.lower()}}_app.command()
|
@{{modelname.lower()}}_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def get(
|
def get(
|
||||||
id: Optional[int] = typer.Argument(default=None),
|
{{modelname.lower()}}_id: Optional[int] = typer.Argument(default=None),
|
||||||
config: Config = None,
|
) -> Union[{{modelname}}, List[{{modelname}}]]:
|
||||||
) -> Union[{{modelname}}, List[{{modelname.lower()}}]]:
|
|
||||||
"get one {{modelname.lower()}}"
|
"get one {{modelname.lower()}}"
|
||||||
config.init()
|
{{modelname.lower()}} = {{modelname}}Read.get(id={{modelname.lower()}}_id)
|
||||||
{{modelname.lower()}} = {{modelname}}Read.get(id=id, config=config)
|
|
||||||
Console().print({{modelname.lower()}})
|
Console().print({{modelname.lower()}})
|
||||||
return {{modelname.lower()}}
|
return {{modelname.lower()}}
|
||||||
|
|
||||||
|
|
@ -41,12 +40,11 @@ def get(
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def list(
|
def list(
|
||||||
where: Optional[str] = None,
|
where: Optional[str] = None,
|
||||||
config: Config = None,
|
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
limit: Optional[int] = None,
|
limit: Optional[int] = None,
|
||||||
) -> Union[{{modelname}}, List[{{modelname.lower()}}]]:
|
) -> Union[{{modelname}}, List[{{modelname}}]]:
|
||||||
"get one {{modelname.lower()}}"
|
"list many {{modelname.lower()}}s"
|
||||||
{{modelname.lower()}} = {{modelname}}Read.list(config=config, where=where, offset=offset, limit=limit)
|
{{modelname.lower()}} = {{modelname}}Read.list(where=where, offset=offset, limit=limit)
|
||||||
Console().print({{modelname.lower()}})
|
Console().print({{modelname.lower()}})
|
||||||
return {{modelname.lower()}}
|
return {{modelname.lower()}}
|
||||||
|
|
||||||
|
|
@ -55,53 +53,40 @@ def list(
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def create(
|
def create(
|
||||||
{{modelname.lower()}}: {{modelname}}Create,
|
{{modelname.lower()}}: {{modelname}}Create,
|
||||||
config: Config = None,
|
|
||||||
) -> {{modelname}}:
|
) -> {{modelname}}:
|
||||||
"read all the {{modelname.lower()}}s"
|
"create one {{modelname.lower()}}"
|
||||||
# config.init()
|
{{modelname.lower()}}.post()
|
||||||
{{modelname.lower()}} = {{modelname.lower()}}.post(config=config)
|
|
||||||
Console().print({{modelname.lower()}})
|
|
||||||
return {{modelname.lower()}}
|
|
||||||
|
|
||||||
|
|
||||||
@{{modelname.lower()}}_app.command()
|
@{{modelname.lower()}}_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def update(
|
def update(
|
||||||
{{modelname.lower()}}: {{modelname}}Update,
|
{{modelname.lower()}}: {{modelname}}Update,
|
||||||
config: Config = None,
|
|
||||||
) -> {{modelname}}:
|
) -> {{modelname}}:
|
||||||
"read all the {{modelname.lower()}}s"
|
"update one {{modelname.lower()}}"
|
||||||
{{modelname.lower()}} = {{modelname.lower()}}.update(config=config)
|
{{modelname.lower()}}.update()
|
||||||
Console().print({{modelname.lower()}})
|
|
||||||
return {{modelname.lower()}}
|
|
||||||
|
|
||||||
|
|
||||||
@{{modelname.lower()}}_app.command()
|
@{{modelname.lower()}}_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def delete(
|
def delete(
|
||||||
{{modelname.lower()}}: {{modelname}}Delete,
|
{{modelname.lower()}}: {{modelname}}Delete,
|
||||||
config: Config = None,
|
|
||||||
) -> {{modelname}}:
|
) -> {{modelname}}:
|
||||||
"read all the {{modelname.lower()}}s"
|
"delete a {{modelname.lower()}} by id"
|
||||||
# config.init()
|
{{modelname.lower()}}.delete()
|
||||||
{{modelname.lower()}} = {{modelname.lower()}}.delete(config=config)
|
|
||||||
return {{modelname.lower()}}
|
|
||||||
|
|
||||||
|
|
||||||
@{{modelname.lower()}}_app.command()
|
@{{modelname.lower()}}_app.command()
|
||||||
@engorgio(typer=True)
|
@engorgio(typer=True)
|
||||||
def populate(
|
def populate(
|
||||||
{{modelname.lower()}}: {{modelname}},
|
|
||||||
n: int = 10,
|
n: int = 10,
|
||||||
) -> {{modelname}}:
|
) -> {{modelname}}:
|
||||||
"read all the {{modelname.lower()}}s"
|
"Create n number of {{modelname.lower()}}s"
|
||||||
config = get_config()
|
|
||||||
if config.env == "prod":
|
if config.env == "prod":
|
||||||
Console().print("populate is not supported in production")
|
Console().print("populate is not supported in production")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
for {{modelname.lower()}} in {{modelname}}Factory().batch(n):
|
for {{modelname.lower()}} in {{modelname}}Factory().batch(n):
|
||||||
pet = PetFactory().build()
|
{{modelname.lower()}} = {{modelname}}Create(**{{modelname.lower()}}.dict())
|
||||||
{{modelname.lower()}}.pet = pet
|
{{modelname.lower()}}.post()
|
||||||
Console().print({{modelname.lower()}})
|
|
||||||
{{modelname.lower()}}.post(config=config)
|
|
||||||
|
|
|
||||||
|
|
@ -1,43 +1,41 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
|
import httpx
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import Field, Relationship, SQLModel, Session, select
|
from sqlmodel import Field, Relationship, SQLModel, Session, select
|
||||||
|
|
||||||
from learn_sql_model.config import Config
|
from learn_sql_model.config import config, get_session
|
||||||
from learn_sql_model.models.pet import Pet
|
from learn_sql_model.models.pet import Pet
|
||||||
|
|
||||||
|
|
||||||
class {{modelname}}Base(SQLModel, table=False):
|
class {{modelname}}Base(SQLModel, table=False):
|
||||||
|
|
||||||
|
|
||||||
class {{modelname}}({{modelname.lower()}}Base, table=True):
|
class {{modelname}}({{modelname}}Base, table=True):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
class {{modelname}}Create({{modelname.lower()}}Base):
|
class {{modelname}}Create({{modelname}}Base):
|
||||||
...
|
...
|
||||||
|
|
||||||
def post(self, config: Config) -> {{modelname}}:
|
def post(self) -> {{modelname}}:
|
||||||
config.init()
|
r = httpx.post(
|
||||||
with Session(config.database.engine) as session:
|
f"{config.api_client.url}/{{modelname.lower()}}/",
|
||||||
db_{{modelname.lower()}} = {{modelname}}.from_orm(self)
|
json=self.dict(),
|
||||||
session.add(db_{{modelname.lower()}})
|
)
|
||||||
session.commit()
|
if r.status_code != 200:
|
||||||
session.refresh(db_{{modelname.lower()}})
|
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
||||||
return db_{{modelname.lower()}}
|
|
||||||
|
|
||||||
|
|
||||||
class {{modelname}}Read({{modelname.lower()}}Base):
|
class {{modelname}}Read({{modelname}}Base):
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(
|
def get(
|
||||||
cls,
|
cls,
|
||||||
config: Config,
|
|
||||||
id: int,
|
id: int,
|
||||||
) -> {{modelname}}:
|
) -> {{modelname}}:
|
||||||
|
|
||||||
with config.database.session as session:
|
with config.database.session as session:
|
||||||
{{modelname.lower()}} = session.get({{modelname}}, id)
|
{{modelname.lower()}} = session.get({{modelname}}, id)
|
||||||
if not {{modelname.lower()}}:
|
if not {{modelname.lower()}}:
|
||||||
|
|
@ -47,53 +45,49 @@ class {{modelname}}Read({{modelname.lower()}}Base):
|
||||||
@classmethod
|
@classmethod
|
||||||
def list(
|
def list(
|
||||||
self,
|
self,
|
||||||
config: Config,
|
|
||||||
where=None,
|
where=None,
|
||||||
offset=0,
|
offset=0,
|
||||||
limit=None,
|
limit=None,
|
||||||
|
session: Session = get_session,
|
||||||
) -> {{modelname}}:
|
) -> {{modelname}}:
|
||||||
|
# with config.database.session as session:
|
||||||
|
|
||||||
with config.database.session as session:
|
statement = select({{modelname}})
|
||||||
statement = select({{modelname}})
|
if where != "None" and where is not None:
|
||||||
if where != "None":
|
from sqlmodel import text
|
||||||
from sqlmodel import text
|
|
||||||
|
|
||||||
statement = statement.where(text(where))
|
statement = statement.where(text(where))
|
||||||
statement = statement.offset(offset).limit(limit)
|
statement = statement.offset(offset).limit(limit)
|
||||||
{{modelname.lower()}}es = session.exec(statement).all()
|
{{modelname.lower()}}es = session.exec(statement).all()
|
||||||
return {{modelname.lower()}}es
|
return {{modelname.lower()}}es
|
||||||
|
|
||||||
|
|
||||||
class {{modelname}}Update(SQLModel):
|
class {{modelname}}Update(SQLModel):
|
||||||
# id is required to get the {{modelname.lower()}}
|
# id is required to update the {{modelname.lower()}}
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
# all other fields, must match the model, but with Optional default None
|
# all other fields, must match the model, but with Optional default None
|
||||||
|
|
||||||
def update(self, config: Config) -> {{modelname}}:
|
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
|
||||||
with Session(config.database.engine) as session:
|
pet: Optional[Pet] = Relationship(back_populates="{{modelname.lower()}}")
|
||||||
db_{{modelname.lower()}} = session.get({{modelname}}, self.id)
|
|
||||||
if not db_{{modelname.lower()}}:
|
def update(self) -> {{modelname}}:
|
||||||
raise HTTPException(status_code=404, detail="{{modelname}} not found")
|
r = httpx.patch(
|
||||||
{{modelname.lower()}}_data = self.dict(exclude_unset=True)
|
f"{config.api_client.url}/{{modelname.lower()}}/",
|
||||||
for key, value in {{modelname.lower()}}_data.items():
|
json=self.dict(),
|
||||||
if value is not None:
|
)
|
||||||
setattr(db_{{modelname.lower()}}, key, value)
|
if r.status_code != 200:
|
||||||
session.add(db_{{modelname.lower()}})
|
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
||||||
session.commit()
|
|
||||||
session.refresh(db_{{modelname.lower()}})
|
|
||||||
return db_{{modelname.lower()}}
|
|
||||||
|
|
||||||
|
|
||||||
class {{modelname}}Delete(BaseModel):
|
class {{modelname}}Delete(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
def delete(self, config: Config) -> {{modelname}}:
|
def delete(self) -> {{modelname}}:
|
||||||
config.init()
|
r = httpx.delete(
|
||||||
with Session(config.database.engine) as session:
|
f"{config.api_client.url}/{{modelname.lower()}}/{self.id}",
|
||||||
{{modelname.lower()}} = session.get({{modelname}}, self.id)
|
)
|
||||||
if not {{modelname.lower()}}:
|
if r.status_code != 200:
|
||||||
raise HTTPException(status_code=404, detail="{{modelname}} not found")
|
raise RuntimeError(f"{r.status_code}:\n {r.text}")
|
||||||
session.delete({{modelname.lower()}})
|
return {"ok": True}
|
||||||
session.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
|
||||||
207
templates/model/tests/{{modelname.lower()}}.py.jinja
Normal file
207
templates/model/tests/{{modelname.lower()}}.py.jinja
Normal file
|
|
@ -0,0 +1,207 @@
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlmodel import SQLModel, Session, select
|
||||||
|
from sqlmodel.pool import StaticPool
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from learn_sql_model.api.app import app
|
||||||
|
from learn_sql_model.config import get_config, get_session
|
||||||
|
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
|
||||||
|
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="session")
|
||||||
|
def session_fixture():
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
|
)
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
with Session(engine) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="client")
|
||||||
|
def client_fixture(session: Session):
|
||||||
|
def get_session_override():
|
||||||
|
return session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = get_session_override
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
yield client
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_post(client: TestClient):
|
||||||
|
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
|
||||||
|
{{modelname.lower()}}_dict = {{modelname.lower()}}.dict()
|
||||||
|
response = client.post("/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {{modelname.lower()}}_dict})
|
||||||
|
response_{{modelname.lower()}} = {{modelname}}.parse_obj(response.json())
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response_{{modelname.lower()}}.name == "Steelman"
|
||||||
|
assert response_{{modelname.lower()}}.age == 25
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_read_{{modelname.lower()}}es(session: Session, client: TestClient):
|
||||||
|
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
{{modelname.lower()}}_2 = {{modelname}}(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
|
||||||
|
session.add({{modelname.lower()}}_1)
|
||||||
|
session.add({{modelname.lower()}}_2)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get("/{{modelname.lower()}}s/")
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
assert len(data) == 2
|
||||||
|
assert data[0]["name"] == {{modelname.lower()}}_1.name
|
||||||
|
assert data[0]["secret_name"] == {{modelname.lower()}}_1.secret_name
|
||||||
|
assert data[0]["age"] == {{modelname.lower()}}_1.age
|
||||||
|
assert data[0]["id"] == {{modelname.lower()}}_1.id
|
||||||
|
assert data[1]["name"] == {{modelname.lower()}}_2.name
|
||||||
|
assert data[1]["secret_name"] == {{modelname.lower()}}_2.secret_name
|
||||||
|
assert data[1]["age"] == {{modelname.lower()}}_2.age
|
||||||
|
assert data[1]["id"] == {{modelname.lower()}}_2.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_read_{{modelname.lower()}}(session: Session, client: TestClient):
|
||||||
|
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add({{modelname.lower()}}_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get(f"/{{modelname.lower()}}/999")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_read_{{modelname.lower()}}_404(session: Session, client: TestClient):
|
||||||
|
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add({{modelname.lower()}}_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert data["name"] == {{modelname.lower()}}_1.name
|
||||||
|
assert data["secret_name"] == {{modelname.lower()}}_1.secret_name
|
||||||
|
assert data["age"] == {{modelname.lower()}}_1.age
|
||||||
|
assert data["id"] == {{modelname.lower()}}_1.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_update_{{modelname.lower()}}(session: Session, client: TestClient):
|
||||||
|
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add({{modelname.lower()}}_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": {{modelname.lower()}}_1.id}}
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert data["name"] == "Deadpuddle"
|
||||||
|
assert data["secret_name"] == "Dive Wilson"
|
||||||
|
assert data["age"] is None
|
||||||
|
assert data["id"] == {{modelname.lower()}}_1.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_update_{{modelname.lower()}}_404(session: Session, client: TestClient):
|
||||||
|
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add({{modelname.lower()}}_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.patch(f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": 999}})
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_{{modelname.lower()}}(session: Session, client: TestClient):
|
||||||
|
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add({{modelname.lower()}}_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.delete(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
|
||||||
|
|
||||||
|
{{modelname.lower()}}_in_db = session.get({{modelname}}, {{modelname.lower()}}_1.id)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
assert {{modelname.lower()}}_in_db is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_{{modelname.lower()}}_404(session: Session, client: TestClient):
|
||||||
|
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add({{modelname.lower()}}_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.delete(f"/{{modelname.lower()}}/999")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_memory(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"learn_sql_model.config.Database.engine",
|
||||||
|
new_callable=lambda: create_engine(
|
||||||
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
|
),
|
||||||
|
)
|
||||||
|
config = get_config()
|
||||||
|
SQLModel.metadata.create_all(config.database.engine)
|
||||||
|
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
|
||||||
|
with config.database.session as session:
|
||||||
|
session.add({{modelname.lower()}})
|
||||||
|
session.commit()
|
||||||
|
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
|
||||||
|
{{modelname.lower()}}es = session.exec(select({{modelname}})).all()
|
||||||
|
assert {{modelname.lower()}}.name == "Steelman"
|
||||||
|
assert {{modelname.lower()}}.age == 25
|
||||||
|
assert len({{modelname.lower()}}es) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_get(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"learn_sql_model.config.Database.engine",
|
||||||
|
new_callable=lambda: create_engine(
|
||||||
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
SQLModel.metadata.create_all(config.database.engine)
|
||||||
|
|
||||||
|
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
|
||||||
|
with config.database.session as session:
|
||||||
|
session.add({{modelname.lower()}})
|
||||||
|
session.commit()
|
||||||
|
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
|
||||||
|
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "1"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert f"name='{{{modelname.lower()}}.name}'" in result.stdout
|
||||||
|
assert f"secret_name='{{{modelname.lower()}}.secret_name}'" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_get_404(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"learn_sql_model.config.Database.engine",
|
||||||
|
new_callable=lambda: create_engine(
|
||||||
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
SQLModel.metadata.create_all(config.database.engine)
|
||||||
|
|
||||||
|
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
|
||||||
|
with config.database.session as session:
|
||||||
|
session.add({{modelname.lower()}})
|
||||||
|
session.commit()
|
||||||
|
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
|
||||||
|
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "999"])
|
||||||
|
assert result.exception.status_code == 404
|
||||||
|
assert result.exception.detail == "{{modelname}} not found"
|
||||||
|
|
||||||
|
|
@ -1,15 +1,12 @@
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlmodel import SQLModel, Session, select
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel.pool import StaticPool
|
||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
from learn_sql_model.api.app import app
|
from learn_sql_model.api.app import app
|
||||||
from learn_sql_model.cli.hero import hero_app
|
from learn_sql_model.config import get_config, get_session
|
||||||
from learn_sql_model.config import Config, get_config
|
|
||||||
from learn_sql_model.factories.hero import HeroFactory
|
from learn_sql_model.factories.hero import HeroFactory
|
||||||
from learn_sql_model.models.hero import Hero
|
from learn_sql_model.models.hero import Hero
|
||||||
|
|
||||||
|
|
@ -17,142 +14,193 @@ runner = CliRunner()
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(name="session")
|
||||||
def config() -> Config:
|
def session_fixture():
|
||||||
tmp_db = tempfile.NamedTemporaryFile(suffix=".db")
|
|
||||||
config = get_config({"database_url": f"sqlite:///{tmp_db.name}"})
|
|
||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
config.database_url, connect_args={"check_same_thread": False}
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
)
|
)
|
||||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
with Session(engine) as session:
|
||||||
# breakpoint()
|
yield session
|
||||||
SQLModel.metadata.create_all(config.database.engine)
|
|
||||||
|
|
||||||
# def override_get_db():
|
|
||||||
# try:
|
|
||||||
# db = TestingSessionLocal()
|
|
||||||
# yield db
|
|
||||||
# finally:
|
|
||||||
# db.close()
|
|
||||||
def override_get_config():
|
|
||||||
return config
|
|
||||||
|
|
||||||
app.dependency_overrides[get_config] = override_get_config
|
|
||||||
|
|
||||||
yield config
|
|
||||||
# tmp_db automatically deletes here
|
|
||||||
|
|
||||||
|
|
||||||
def test_post_hero(config: Config) -> None:
|
@pytest.fixture(name="client")
|
||||||
hero = HeroFactory().build(name="Batman", age=50, id=1)
|
def client_fixture(session: Session):
|
||||||
hero = hero.post(config=config)
|
def get_session_override():
|
||||||
db_hero = Hero().get(id=1, config=config)
|
return session
|
||||||
assert db_hero.age == 50
|
|
||||||
assert db_hero.name == "Batman"
|
app.dependency_overrides[get_session] = get_session_override
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
yield client
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
def test_update_hero(config: Config) -> None:
|
def test_api_post(client: TestClient):
|
||||||
hero = HeroFactory().build(name="Batman", age=50, id=1)
|
|
||||||
hero = hero.post(config=config)
|
|
||||||
db_hero = Hero().get(id=1, config=config)
|
|
||||||
db_hero.name = "Superbman"
|
|
||||||
hero = db_hero.post(config=config)
|
|
||||||
db_hero = Hero().get(id=1, config=config)
|
|
||||||
assert db_hero.age == 50
|
|
||||||
assert db_hero.name == "Superbman"
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_get(config):
|
|
||||||
hero = HeroFactory().build(name="Steelman", age=25, id=99)
|
|
||||||
hero.post(config=config)
|
|
||||||
result = runner.invoke(
|
|
||||||
hero_app,
|
|
||||||
["get", "--id", 99, "--database-url", config.database_url],
|
|
||||||
)
|
|
||||||
assert result.exit_code == 0
|
|
||||||
db_hero = Hero().get(id=99, config=config)
|
|
||||||
assert db_hero.age == 25
|
|
||||||
assert db_hero.name == "Steelman"
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_create(config):
|
|
||||||
hero = HeroFactory().build(name="Steelman", age=25, id=99)
|
|
||||||
result = runner.invoke(
|
|
||||||
hero_app,
|
|
||||||
[
|
|
||||||
"create",
|
|
||||||
*hero.flags(config=config),
|
|
||||||
"--database-url",
|
|
||||||
config.database_url,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert result.exit_code == 0
|
|
||||||
db_hero = Hero().get(id=99, config=config)
|
|
||||||
assert db_hero.age == 25
|
|
||||||
assert db_hero.name == "Steelman"
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_populate(config):
|
|
||||||
result = runner.invoke(
|
|
||||||
hero_app,
|
|
||||||
[
|
|
||||||
"populate",
|
|
||||||
"--n",
|
|
||||||
10,
|
|
||||||
"--database-url",
|
|
||||||
config.database_url,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert result.exit_code == 0
|
|
||||||
db_hero = Hero().get(config=config)
|
|
||||||
assert len(db_hero) == 10
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_populate_fails_prod(config):
|
|
||||||
result = runner.invoke(
|
|
||||||
hero_app,
|
|
||||||
["populate", "--n", 10, "--database-url", config.database_url, "--env", "prod"],
|
|
||||||
)
|
|
||||||
assert result.exit_code == 1
|
|
||||||
assert result.output.strip() == "populate is not supported in production"
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_read(config):
|
|
||||||
hero = HeroFactory().build(name="Steelman", age=25, id=99)
|
|
||||||
hero_id = hero.id
|
|
||||||
hero = hero.post(config=config)
|
|
||||||
response = client.get(f"/hero/{hero_id}")
|
|
||||||
assert response.status_code == 200
|
|
||||||
reponse_hero = Hero.parse_obj(response.json())
|
|
||||||
assert reponse_hero.id == hero_id
|
|
||||||
assert reponse_hero.name == "Steelman"
|
|
||||||
assert reponse_hero.age == 25
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_post(config):
|
|
||||||
hero = HeroFactory().build(name="Steelman", age=25)
|
hero = HeroFactory().build(name="Steelman", age=25)
|
||||||
hero_dict = hero.dict()
|
hero_dict = hero.dict()
|
||||||
response = client.post("/hero/", json={"hero": hero_dict})
|
response = client.post("/hero/", json={"hero": hero_dict})
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
response_hero = Hero.parse_obj(response.json())
|
response_hero = Hero.parse_obj(response.json())
|
||||||
db_hero = Hero().get(id=response_hero.id, config=config)
|
|
||||||
|
|
||||||
assert db_hero.name == "Steelman"
|
|
||||||
assert db_hero.age == 25
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_read_all(config):
|
|
||||||
hero = HeroFactory().build(name="Mothman", age=25, id=99)
|
|
||||||
hero_id = hero.id
|
|
||||||
hero = hero.post(config=config)
|
|
||||||
response = client.get("/heros/")
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
heros = response.json()
|
assert response_hero.name == "Steelman"
|
||||||
response_hero_json = [hero for hero in heros if hero["id"] == hero_id][0]
|
|
||||||
response_hero = Hero.parse_obj(response_hero_json)
|
|
||||||
assert response_hero.id == hero_id
|
|
||||||
assert response_hero.name == "Mothman"
|
|
||||||
assert response_hero.age == 25
|
assert response_hero.age == 25
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_read_heroes(session: Session, client: TestClient):
|
||||||
|
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
hero_2 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
|
||||||
|
session.add(hero_1)
|
||||||
|
session.add(hero_2)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get("/heros/")
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
assert len(data) == 2
|
||||||
|
assert data[0]["name"] == hero_1.name
|
||||||
|
assert data[0]["secret_name"] == hero_1.secret_name
|
||||||
|
assert data[0]["age"] == hero_1.age
|
||||||
|
assert data[0]["id"] == hero_1.id
|
||||||
|
assert data[1]["name"] == hero_2.name
|
||||||
|
assert data[1]["secret_name"] == hero_2.secret_name
|
||||||
|
assert data[1]["age"] == hero_2.age
|
||||||
|
assert data[1]["id"] == hero_2.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_read_hero(session: Session, client: TestClient):
|
||||||
|
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add(hero_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get(f"/hero/999")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_read_hero_404(session: Session, client: TestClient):
|
||||||
|
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add(hero_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.get(f"/hero/{hero_1.id}")
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert data["name"] == hero_1.name
|
||||||
|
assert data["secret_name"] == hero_1.secret_name
|
||||||
|
assert data["age"] == hero_1.age
|
||||||
|
assert data["id"] == hero_1.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_update_hero(session: Session, client: TestClient):
|
||||||
|
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add(hero_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.patch(
|
||||||
|
f"/hero/", json={"hero": {"name": "Deadpuddle", "id": hero_1.id}}
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert data["name"] == "Deadpuddle"
|
||||||
|
assert data["secret_name"] == "Dive Wilson"
|
||||||
|
assert data["age"] is None
|
||||||
|
assert data["id"] == hero_1.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_update_hero_404(session: Session, client: TestClient):
|
||||||
|
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add(hero_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.patch(f"/hero/", json={"hero": {"name": "Deadpuddle", "id": 999}})
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_hero(session: Session, client: TestClient):
|
||||||
|
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add(hero_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.delete(f"/hero/{hero_1.id}")
|
||||||
|
|
||||||
|
hero_in_db = session.get(Hero, hero_1.id)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
assert hero_in_db is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_hero_404(session: Session, client: TestClient):
|
||||||
|
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||||
|
session.add(hero_1)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = client.delete(f"/hero/999")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_memory(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"learn_sql_model.config.Database.engine",
|
||||||
|
new_callable=lambda: create_engine(
|
||||||
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
|
),
|
||||||
|
)
|
||||||
|
config = get_config()
|
||||||
|
SQLModel.metadata.create_all(config.database.engine)
|
||||||
|
hero = HeroFactory().build(name="Steelman", age=25)
|
||||||
|
with config.database.session as session:
|
||||||
|
session.add(hero)
|
||||||
|
session.commit()
|
||||||
|
hero = session.get(Hero, hero.id)
|
||||||
|
heroes = session.exec(select(Hero)).all()
|
||||||
|
assert hero.name == "Steelman"
|
||||||
|
assert hero.age == 25
|
||||||
|
assert len(heroes) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_get(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"learn_sql_model.config.Database.engine",
|
||||||
|
new_callable=lambda: create_engine(
|
||||||
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
SQLModel.metadata.create_all(config.database.engine)
|
||||||
|
|
||||||
|
hero = HeroFactory().build(name="Steelman", age=25)
|
||||||
|
with config.database.session as session:
|
||||||
|
session.add(hero)
|
||||||
|
session.commit()
|
||||||
|
hero = session.get(Hero, hero.id)
|
||||||
|
result = runner.invoke(hero_app, ["get", "--hero-id", "1"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert f"name='{hero.name}'" in result.stdout
|
||||||
|
assert f"secret_name='{hero.secret_name}'" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_get_404(mocker):
|
||||||
|
mocker.patch(
|
||||||
|
"learn_sql_model.config.Database.engine",
|
||||||
|
new_callable=lambda: create_engine(
|
||||||
|
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
SQLModel.metadata.create_all(config.database.engine)
|
||||||
|
|
||||||
|
hero = HeroFactory().build(name="Steelman", age=25)
|
||||||
|
with config.database.session as session:
|
||||||
|
session.add(hero)
|
||||||
|
session.commit()
|
||||||
|
hero = session.get(Hero, hero.id)
|
||||||
|
result = runner.invoke(hero_app, ["get", "--hero-id", "999"])
|
||||||
|
assert result.exception.status_code == 404
|
||||||
|
assert result.exception.detail == "Hero not found"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue