diff --git a/.pyflyby b/.pyflyby index c61c116..5e453ff 100644 --- a/.pyflyby +++ b/.pyflyby @@ -2,6 +2,7 @@ from learn_sql_model.api.websocket_connection_manager import manager from learn_sql_model.config import Config from learn_sql_model.config import get_config +from learn_sql_model.config import get_session from learn_sql_model.console import console from learn_sql_model.database import get_database from learn_sql_model.factories.hero import HeroFactory diff --git a/Dockerfile.dev b/Dockerfile.dev index dcfba07..c636e7b 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -12,8 +12,6 @@ RUN groupadd -f -g ${SMOKE_GID} smoke && \ RUN mkdir /home/smoke && chown -R smoke:smoke /home/smoke && mkdir /src && chown smoke:smoke /src WORKDIR /home/smoke - - RUN apt update && \ apt upgrade -y && \ apt install -y \ diff --git a/learn_sql_model/api/hero.py b/learn_sql_model/api/hero.py index 5f115c9..0f14b4d 100644 --- a/learn_sql_model/api/hero.py +++ b/learn_sql_model/api/hero.py @@ -1,18 +1,9 @@ -from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import SQLModel, Session -from fastapi import APIRouter, Depends -from sqlmodel import SQLModel - -from learn_sql_model.api.user import oauth2_scheme from learn_sql_model.api.websocket_connection_manager import manager -from learn_sql_model.config import Config, get_config -from learn_sql_model.models.hero import ( - Hero, - HeroCreate, - HeroDelete, - HeroRead, - HeroUpdate, -) +from learn_sql_model.config import get_config, get_session +from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, HeroUpdate hero_router = APIRouter() @@ -22,52 +13,73 @@ def on_startup() -> None: SQLModel.metadata.create_all(get_config().database.engine) -@hero_router.get("/items/") -async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): - return {"token": token} - - -@hero_router.get("/hero/{id}") -async def get_hero(id: int, config: Config = Depends(get_config)) -> Hero: +@hero_router.get("/hero/{hero_id}") +async def get_hero( + *, + session: Session = Depends(get_session), + hero_id: int, +) -> HeroRead: "get one hero" - return Hero().get(id=id, config=config) - - -@hero_router.get("/h/{id}") -async def get_h(id: int, config: Config = Depends(get_config)) -> Hero: - "get one hero" - return Hero().get(id=id, config=config) + hero = session.get(Hero, hero_id) + if not hero: + raise HTTPException(status_code=404, detail="Hero not found") + return hero @hero_router.post("/hero/") -async def post_hero(hero: HeroCreate) -> HeroRead: +async def post_hero( + *, + session: Session = Depends(get_session), + hero: HeroCreate, +) -> HeroRead: "read all the heros" - config = get_config() - hero = hero.post(config=config) + db_hero = Hero.from_orm(hero) + session.add(db_hero) + session.commit() + session.refresh(db_hero) await manager.broadcast({hero.json()}, id=1) - return hero + return db_hero @hero_router.patch("/hero/") -async def patch_hero(hero: HeroUpdate) -> HeroRead: +async def patch_hero( + *, + session: Session = Depends(get_session), + hero: HeroUpdate, +) -> HeroRead: "read all the heros" - config = get_config() - hero = hero.update(config=config) + db_hero = session.get(Hero, hero.id) + if not db_hero: + raise HTTPException(status_code=404, detail="Hero not found") + for key, value in hero.dict(exclude_unset=True).items(): + setattr(db_hero, key, value) + session.add(db_hero) + session.commit() + session.refresh(db_hero) await manager.broadcast({hero.json()}, id=1) - return hero + return db_hero @hero_router.delete("/hero/{hero_id}") -async def delete_hero(hero_id: int): +async def delete_hero( + *, + session: Session = Depends(get_session), + hero_id: int, +): "read all the heros" - hero = HeroDelete(id=hero_id) - config = get_config() - hero = hero.delete(config=config) + hero = session.get(Hero, hero_id) + if not hero: + raise HTTPException(status_code=404, detail="Hero not found") + session.delete(hero) + session.commit() await manager.broadcast(f"deleted hero {hero_id}", id=1) - return hero + return {"ok": True} @hero_router.get("/heros/") -async def get_heros(config: Config = Depends(get_config)) -> list[Hero]: +async def get_heros( + *, + session: Session = Depends(get_session), +) -> list[Hero]: "get all heros" - return Hero().get(config=config) + return HeroRead.list(session=session) diff --git a/learn_sql_model/api/new.py b/learn_sql_model/api/new.py deleted file mode 100644 index 813ec54..0000000 --- a/learn_sql_model/api/new.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Annotated - -from fastapi import APIRouter, Depends -from sqlmodel import SQLModel - -from learn_sql_model.api.user import oauth2_scheme -from learn_sql_model.config import Config, get_config -from learn_sql_model.models.new import new - -new_router = APIRouter() - - -@new_router.on_event("startup") -def on_startup() -> None: - SQLModel.metadata.create_all(get_config().database.engine) - - -@new_router.get("/items/") -async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): - return {"token": token} - - -@new_router.get("/new/{id}") -def get_new(id: int, config: Config = Depends(get_config)) -> new: - "get one new" - return new().get(id=id, config=config) - - -@new_router.get("/h/{id}") -def get_h(id: int, config: Config = Depends(get_config)) -> new: - "get one new" - return new().get(id=id, config=config) - - -@new_router.post("/new/") -def post_new(new: new, config: Config = Depends(get_config)) -> new: - "read all the news" - new.post(config=config) - return new - - -@new_router.get("/news/") -def get_news(config: Config = Depends(get_config)) -> list[new]: - "get all news" - return new().get(config=config) diff --git a/learn_sql_model/api/websocket.py b/learn_sql_model/api/websocket.py index e0e8b5e..670f90f 100644 --- a/learn_sql_model/api/websocket.py +++ b/learn_sql_model/api/websocket.py @@ -69,4 +69,4 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: manager.disconnect(websocket, id) - await manager.broadcast(f"Client #{client_id} left the chat", id) + await manager.broadcast(f"Client #{id} left the chat", id) diff --git a/learn_sql_model/cli/api.py b/learn_sql_model/cli/api.py index 76657cc..24ccb13 100644 --- a/learn_sql_model/cli/api.py +++ b/learn_sql_model/cli/api.py @@ -1,3 +1,4 @@ +import httpx from rich.console import Console import typer import uvicorn @@ -38,15 +39,11 @@ def status( help="show the log messages", ), ): - import httpx - config = get_config() - host = config.api_server.host - port = config.api_server.port - url = f"http://{host}:{port}/docs" + url = config.api_client.url try: - r = httpx.get(url) + r = httpx.get(url + "/docs") if r.status_code == 200: Console().print(f"[green]API: ([gold1]{url}[green]) is running") else: @@ -59,7 +56,7 @@ def status( Console().print( f"[green]database: ([gold1]{config.database.engine}[green]) is running" ) - except Exception as e: + except Exception: Console().print( f"[red]database: ([gold1]{config.database.engine}[red]) is not running" ) diff --git a/learn_sql_model/cli/hero.py b/learn_sql_model/cli/hero.py index 267c413..b3e938a 100644 --- a/learn_sql_model/cli/hero.py +++ b/learn_sql_model/cli/hero.py @@ -2,13 +2,11 @@ import sys from typing import List, Optional, Union from engorgio import engorgio -import httpx from rich.console import Console import typer -from learn_sql_model.config import Config, get_config +from learn_sql_model.config import get_config from learn_sql_model.factories.hero import HeroFactory -from learn_sql_model.factories.pet import PetFactory from learn_sql_model.models.hero import ( Hero, HeroCreate, @@ -19,6 +17,8 @@ from learn_sql_model.models.hero import ( hero_app = typer.Typer() +config = get_config() + @hero_app.callback() def hero(): @@ -28,12 +28,10 @@ def hero(): @hero_app.command() @engorgio(typer=True) def get( - id: Optional[int] = typer.Argument(default=None), - config: Config = None, + hero_id: Optional[int] = typer.Argument(default=None), ) -> Union[Hero, List[Hero]]: "get one hero" - config.init() - hero = HeroRead.get(id=id, config=config) + hero = HeroRead.get(id=hero_id) Console().print(hero) return hero @@ -42,12 +40,11 @@ def get( @engorgio(typer=True) def list( where: Optional[str] = None, - config: Config = None, offset: int = 0, limit: Optional[int] = None, ) -> Union[Hero, List[Hero]]: - "get one hero" - hero = HeroRead.list(config=config, where=where, offset=offset, limit=limit) + "list many heros" + hero = HeroRead.list(where=where, offset=offset, limit=limit) Console().print(hero) return hero @@ -56,65 +53,39 @@ def list( @engorgio(typer=True) def create( hero: HeroCreate, - config: Config = None, ) -> Hero: - "read all the heros" - - r = httpx.post( - f"{config.api_client.url}/hero/", - json=hero.dict(), - ) - if r.status_code != 200: - raise RuntimeError(f"{r.status_code}:\n {r.text}") - - # hero = hero.post(config=config) - # Console().print(hero) - # return hero + "create one hero" + hero.post() @hero_app.command() @engorgio(typer=True) def update( hero: HeroUpdate, - config: Config = None, ) -> Hero: - "read all the heros" - r = httpx.patch( - f"{config.api_client.url}/hero/", - json=hero.dict(), - ) - if r.status_code != 200: - raise RuntimeError(f"{r.status_code}:\n {r.text}") + "update one hero" + hero.update() @hero_app.command() @engorgio(typer=True) def delete( hero: HeroDelete, - config: Config = None, ) -> Hero: - "read all the heros" - r = httpx.delete( - f"{config.api_client.url}/hero/{hero.id}", - ) - if r.status_code != 200: - raise RuntimeError(f"{r.status_code}:\n {r.text}") + "delete a hero by id" + hero.delete() @hero_app.command() @engorgio(typer=True) def populate( - hero: Hero, n: int = 10, ) -> Hero: - "read all the heros" - config = get_config() + "Create n number of heros" if config.env == "prod": Console().print("populate is not supported in production") sys.exit(1) for hero in HeroFactory().batch(n): - pet = PetFactory().build() - hero.pet = pet - Console().print(hero) - hero.post(config=config) + hero = HeroCreate(**hero.dict()) + hero.post() diff --git a/learn_sql_model/cli/model.py b/learn_sql_model/cli/model.py index df6a27d..53bb419 100644 --- a/learn_sql_model/cli/model.py +++ b/learn_sql_model/cli/model.py @@ -44,7 +44,6 @@ def create_revision( prompt=True, ), ): - alembic_cfg = Config("alembic.ini") alembic.command.revision( config=alembic_cfg, @@ -63,7 +62,6 @@ def checkout( ), revision: str = typer.Option("head"), ): - alembic_cfg = Config("alembic.ini") alembic.command.upgrade(config=alembic_cfg, revision="head") diff --git a/learn_sql_model/cli/new.py b/learn_sql_model/cli/new.py deleted file mode 100644 index 20c90cf..0000000 --- a/learn_sql_model/cli/new.py +++ /dev/null @@ -1,107 +0,0 @@ -import sys -from typing import List, Optional, Union - -from engorgio import engorgio -from rich.console import Console -import typer - -from learn_sql_model.config import Config, get_config -from learn_sql_model.factories.new import newFactory -from learn_sql_model.factories.pet import PetFactory -from learn_sql_model.models.new import ( - new, - newCreate, - newDelete, - newRead, - newUpdate, -) - -new_app = typer.Typer() - - -@new_app.callback() -def new(): - "model cli" - - -@new_app.command() -@engorgio(typer=True) -def get( - id: Optional[int] = typer.Argument(default=None), - config: Config = None, -) -> Union[new, List[new]]: - "get one new" - config.init() - new = newRead.get(id=id, config=config) - Console().print(new) - return new - - -@new_app.command() -@engorgio(typer=True) -def list( - where: Optional[str] = None, - config: Config = None, - offset: int = 0, - limit: Optional[int] = None, -) -> Union[new, List[new]]: - "get one new" - new = newRead.list(config=config, where=where, offset=offset, limit=limit) - Console().print(new) - return new - - -@new_app.command() -@engorgio(typer=True) -def create( - new: newCreate, - config: Config = None, -) -> new: - "read all the news" - # config.init() - new = new.post(config=config) - Console().print(new) - return new - - -@new_app.command() -@engorgio(typer=True) -def update( - new: newUpdate, - config: Config = None, -) -> new: - "read all the news" - new = new.update(config=config) - Console().print(new) - return new - - -@new_app.command() -@engorgio(typer=True) -def delete( - new: newDelete, - config: Config = None, -) -> new: - "read all the news" - # config.init() - new = new.delete(config=config) - return new - - -@new_app.command() -@engorgio(typer=True) -def populate( - new: new, - n: int = 10, -) -> new: - "read all the news" - config = get_config() - if config.env == "prod": - Console().print("populate is not supported in production") - sys.exit(1) - - for new in newFactory().batch(n): - pet = PetFactory().build() - new.pet = pet - Console().print(new) - new.post(config=config) diff --git a/learn_sql_model/config.py b/learn_sql_model/config.py index 62d28f6..500642b 100644 --- a/learn_sql_model/config.py +++ b/learn_sql_model/config.py @@ -30,7 +30,6 @@ class ApiClient(BaseModel): class Database: def __init__(self, config: "Config" = None) -> None: if config is None: - self.config = get_config() else: self.config = config @@ -71,17 +70,26 @@ class Config(BaseSettings): def get_database(config: Config = None) -> Database: - if config is None: config = get_config() - return Database(config) +def get_config(overrides: dict = {}) -> Config: + raw_config = load("learn_sql_model") + config = Config(**raw_config, **overrides) + return config + + +def get_session(config: Config = None) -> "Session": + with Session(config.database.engine) as session: + yield session + + async def reset_db_state(config: Config = None) -> None: if config is None: config = get_config() - config.database.db._state._state.set(db_state_default.copy()) + config.database.db._state._state.set(config.database.db_state_default.copy()) config.database.db._state.reset() @@ -96,7 +104,4 @@ def get_db(config: Config = None, reset_db_state=Depends(reset_db_state)): config.database.db.close() -def get_config(overrides: dict = {}) -> Config: - raw_config = load("learn_sql_model") - config = Config(**raw_config, **overrides) - return config +config = get_config() diff --git a/learn_sql_model/factories/new.py b/learn_sql_model/factories/new.py deleted file mode 100644 index 8432c9c..0000000 --- a/learn_sql_model/factories/new.py +++ /dev/null @@ -1,14 +0,0 @@ -from faker import Faker -from polyfactory.factories.pydantic_factory import ModelFactory - -from learn_sql_model.models.new import new - - -class newFactory(ModelFactory[new]): - __model__ = new - __faker__ = Faker(locale="en_US") - __set_as_default_factory_for_type__ = True - id = None - - __random_seed__ = 10 - diff --git a/learn_sql_model/models/fast_model.py b/learn_sql_model/models/fast_model.py index 1b75c91..14d2a9f 100644 --- a/learn_sql_model/models/fast_model.py +++ b/learn_sql_model/models/fast_model.py @@ -21,7 +21,6 @@ class FastModel(SQLModel): def post(self, config: "Config" = None) -> None: if config is None: - config = get_config() self.pre_post() @@ -36,7 +35,6 @@ class FastModel(SQLModel): self, id: int = None, config: "Config" = None, where=None ) -> Optional["FastModel"]: if config is None: - config = get_config() self.pre_get() diff --git a/learn_sql_model/models/hero.py b/learn_sql_model/models/hero.py index 2f7c791..f349905 100644 --- a/learn_sql_model/models/hero.py +++ b/learn_sql_model/models/hero.py @@ -1,10 +1,11 @@ from typing import Optional from fastapi import HTTPException +import httpx from pydantic import BaseModel from sqlmodel import Field, Relationship, SQLModel, Session, select -from learn_sql_model.config import Config +from learn_sql_model.config import config, get_session from learn_sql_model.models.pet import Pet @@ -25,14 +26,13 @@ class Hero(HeroBase, table=True): class HeroCreate(HeroBase): ... - def post(self, config: Config) -> Hero: - config.init() - with Session(config.database.engine) as session: - db_hero = Hero.from_orm(self) - session.add(db_hero) - session.commit() - session.refresh(db_hero) - return db_hero + def post(self) -> Hero: + r = httpx.post( + f"{config.api_client.url}/hero/", + json=self.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") class HeroRead(HeroBase): @@ -41,10 +41,8 @@ class HeroRead(HeroBase): @classmethod def get( cls, - config: Config, id: int, ) -> Hero: - with config.database.session as session: hero = session.get(Hero, id) if not hero: @@ -54,25 +52,25 @@ class HeroRead(HeroBase): @classmethod def list( self, - config: Config, where=None, offset=0, limit=None, + session: Session = get_session, ) -> Hero: + # with config.database.session as session: - with config.database.session as session: - statement = select(Hero) - if where != "None": - from sqlmodel import text + statement = select(Hero) + if where != "None" and where is not None: + from sqlmodel import text - statement = statement.where(text(where)) - statement = statement.offset(offset).limit(limit) - heroes = session.exec(statement).all() + statement = statement.where(text(where)) + statement = statement.offset(offset).limit(limit) + heroes = session.exec(statement).all() return heroes class HeroUpdate(SQLModel): - # id is required to get the hero + # id is required to update the hero id: int # all other fields, must match the model, but with Optional default None @@ -84,30 +82,22 @@ class HeroUpdate(SQLModel): pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") pet: Optional[Pet] = Relationship(back_populates="hero") - def update(self, config: Config) -> Hero: - with Session(config.database.engine) as session: - db_hero = session.get(Hero, self.id) - if not db_hero: - raise HTTPException(status_code=404, detail="Hero not found") - hero_data = self.dict(exclude_unset=True) - for key, value in hero_data.items(): - if value is not None: - setattr(db_hero, key, value) - session.add(db_hero) - session.commit() - session.refresh(db_hero) - return db_hero + def update(self) -> Hero: + r = httpx.patch( + f"{config.api_client.url}/hero/", + json=self.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") class HeroDelete(BaseModel): id: int - def delete(self, config: Config) -> Hero: - config.init() - with Session(config.database.engine) as session: - hero = session.get(Hero, self.id) - if not hero: - raise HTTPException(status_code=404, detail="Hero not found") - session.delete(hero) - session.commit() - return {"ok": True} + def delete(self) -> Hero: + r = httpx.delete( + f"{config.api_client.url}/hero/{self.id}", + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + return {"ok": True} diff --git a/learn_sql_model/models/new.py b/learn_sql_model/models/new.py deleted file mode 100644 index 5acca4f..0000000 --- a/learn_sql_model/models/new.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Optional - -from fastapi import HTTPException -from pydantic import BaseModel -from sqlmodel import Field, Relationship, SQLModel, Session, select - -from learn_sql_model.config import Config -from learn_sql_model.models.pet import Pet - - -class newBase(SQLModel, table=False): - - -class new(newBase, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - - -class newCreate(newBase): - ... - - def post(self, config: Config) -> new: - config.init() - with Session(config.database.engine) as session: - db_new = new.from_orm(self) - session.add(db_new) - session.commit() - session.refresh(db_new) - return db_new - - -class newRead(newBase): - id: int - - @classmethod - def get( - cls, - config: Config, - id: int, - ) -> new: - - with config.database.session as session: - new = session.get(new, id) - if not new: - raise HTTPException(status_code=404, detail="new not found") - return new - - @classmethod - def list( - self, - config: Config, - where=None, - offset=0, - limit=None, - ) -> new: - - with config.database.session as session: - statement = select(new) - if where != "None": - from sqlmodel import text - - statement = statement.where(text(where)) - statement = statement.offset(offset).limit(limit) - newes = session.exec(statement).all() - return newes - - -class newUpdate(SQLModel): - # id is required to get the new - id: int - - # all other fields, must match the model, but with Optional default None - - def update(self, config: Config) -> new: - with Session(config.database.engine) as session: - db_new = session.get(new, self.id) - if not db_new: - raise HTTPException(status_code=404, detail="new not found") - new_data = self.dict(exclude_unset=True) - for key, value in new_data.items(): - if value is not None: - setattr(db_new, key, value) - session.add(db_new) - session.commit() - session.refresh(db_new) - return db_new - - -class newDelete(BaseModel): - id: int - - def delete(self, config: Config) -> new: - config.init() - with Session(config.database.engine) as session: - new = session.get(new, self.id) - if not new: - raise HTTPException(status_code=404, detail="new not found") - session.delete(new) - session.commit() - return {"ok": True} diff --git a/pyproject.toml b/pyproject.toml index b0c9f41..ec33f87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,14 +77,17 @@ dependencies = [ [tool.hatch.envs.default.scripts] test = "coverage run -m pytest" cov = "coverage-rich report" -test-cov = ['test', 'cov'] +cov-erase = "coverage erase" lint = "ruff learn_sql_model" format = "black learn_sql_model" format-check = "black --check learn_sql_model" +fix_ruff = "ruff --fix learn_sql_model" +fix = ['format', 'fix_ruff'] build-docs = "markata build" lint-test = [ "lint", "format-check", + "cov-erase", "test", "cov", ] diff --git a/templates/api.py b/templates/api.py index 782b7da..7686d79 100644 --- a/templates/api.py +++ b/templates/api.py @@ -18,3 +18,9 @@ def read_heroes(hero: Hero) -> list[Hero]: def read_heros() -> list[Hero]: "read all the heros" return Hero.get() + + +@app.patch("/heros/") +def update_heros() -> list[Hero]: + "read all the heros" + return Hero.get() diff --git a/templates/model/learn_sql_model/api/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/api/{{modelname.lower()}}.py.jinja index 663dea6..e1b86b7 100644 --- a/templates/model/learn_sql_model/api/{{modelname.lower()}}.py.jinja +++ b/templates/model/learn_sql_model/api/{{modelname.lower()}}.py.jinja @@ -1,11 +1,9 @@ -from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import SQLModel, Session -from fastapi import APIRouter, Depends -from sqlmodel import SQLModel - -from learn_sql_model.api.user import oauth2_scheme -from learn_sql_model.config import Config, get_config -from learn_sql_model.models.{{modelname.lower()}} import {{modelname}} +from learn_sql_model.api.websocket_connection_manager import manager +from learn_sql_model.config import get_config, get_session +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}, {{modelname}}Create, {{modelname}}Read, {{modelname}}Update {{modelname.lower()}}_router = APIRouter() @@ -15,31 +13,74 @@ def on_startup() -> None: SQLModel.metadata.create_all(get_config().database.engine) -@{{modelname.lower()}}_router.get("/items/") -async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): - return {"token": token} - - -@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{id}") -def get_{{modelname.lower()}}(id: int, config: Config = Depends(get_config)) -> {{modelname}}: +@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{{{modelname.lower()}}_id}") +async def get_{{modelname.lower()}}( + *, + session: Session = Depends(get_session), + {{modelname.lower()}}_id: int, +) -> {{modelname}}Read: "get one {{modelname.lower()}}" - return {{modelname}}().get(id=id, config=config) - - -@{{modelname.lower()}}_router.get("/h/{id}") -def get_h(id: int, config: Config = Depends(get_config)) -> {{modelname}}: - "get one {{modelname.lower()}}" - return {{modelname}}().get(id=id, config=config) - - -@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/") -def post_{{modelname.lower()}}({{modelname.lower()}}: {{modelname}}, config: Config = Depends(get_config)) -> {{modelname.lower()}}: - "read all the {{modelname.lower()}}s" - {{modelname.lower()}}.post(config=config) + {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id) + if not {{modelname.lower()}}: + raise HTTPException(status_code=404, detail="{{modelname}} not found") return {{modelname.lower()}} +@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/") +async def post_{{modelname.lower()}}( + *, + session: Session = Depends(get_session), + {{modelname.lower()}}: {{modelname}}Create, +) -> {{modelname}}Read: + "read all the {{modelname.lower()}}s" + db_{{modelname.lower()}} = {{modelname}}.from_orm({{modelname.lower()}}) + session.add(db_{{modelname.lower()}}) + session.commit() + session.refresh(db_{{modelname.lower()}}) + await manager.broadcast({{{modelname.lower()}}.json()}, id=1) + return db_{{modelname.lower()}} + + +@{{modelname.lower()}}_router.patch("/{{modelname.lower()}}/") +async def patch_{{modelname.lower()}}( + *, + session: Session = Depends(get_session), + {{modelname.lower()}}: {{modelname}}Update, +) -> {{modelname}}Read: + "read all the {{modelname.lower()}}s" + db_{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id) + if not db_{{modelname.lower()}}: + raise HTTPException(status_code=404, detail="{{modelname}} not found") + for key, value in {{modelname.lower()}}.dict(exclude_unset=True).items(): + setattr(db_{{modelname.lower()}}, key, value) + session.add(db_{{modelname.lower()}}) + session.commit() + session.refresh(db_{{modelname.lower()}}) + await manager.broadcast({{{modelname.lower()}}.json()}, id=1) + return db_{{modelname.lower()}} + + +@{{modelname.lower()}}_router.delete("/{{modelname.lower()}}/{{{modelname.lower()}}_id}") +async def delete_{{modelname.lower()}}( + *, + session: Session = Depends(get_session), + {{modelname.lower()}}_id: int, +): + "read all the {{modelname.lower()}}s" + {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id) + if not {{modelname.lower()}}: + raise HTTPException(status_code=404, detail="{{modelname}} not found") + session.delete({{modelname.lower()}}) + session.commit() + await manager.broadcast(f"deleted {{modelname.lower()}} {{{modelname.lower()}}_id}", id=1) + return {"ok": True} + + @{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/") -def get_{{modelname.lower()}}s(config: Config = Depends(get_config)) -> list[{{modelname}}]: +async def get_{{modelname.lower()}}s( + *, + session: Session = Depends(get_session), +) -> list[{{modelname}}]: "get all {{modelname.lower()}}s" - return {{modelname}}().get(config=config) + return {{modelname}}Read.list(session=session) + diff --git a/templates/model/learn_sql_model/cli/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/cli/{{modelname.lower()}}.py.jinja index 80f2057..bb0391b 100644 --- a/templates/model/learn_sql_model/cli/{{modelname.lower()}}.py.jinja +++ b/templates/model/learn_sql_model/cli/{{modelname.lower()}}.py.jinja @@ -5,9 +5,8 @@ from engorgio import engorgio from rich.console import Console import typer -from learn_sql_model.config import Config, get_config +from learn_sql_model.config import get_config from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory -from learn_sql_model.factories.pet import PetFactory from learn_sql_model.models.{{modelname.lower()}} import ( {{modelname}}, {{modelname}}Create, @@ -18,6 +17,8 @@ from learn_sql_model.models.{{modelname.lower()}} import ( {{modelname.lower()}}_app = typer.Typer() +config = get_config() + @{{modelname.lower()}}_app.callback() def {{modelname.lower()}}(): @@ -27,12 +28,10 @@ def {{modelname.lower()}}(): @{{modelname.lower()}}_app.command() @engorgio(typer=True) def get( - id: Optional[int] = typer.Argument(default=None), - config: Config = None, -) -> Union[{{modelname}}, List[{{modelname.lower()}}]]: + {{modelname.lower()}}_id: Optional[int] = typer.Argument(default=None), +) -> Union[{{modelname}}, List[{{modelname}}]]: "get one {{modelname.lower()}}" - config.init() - {{modelname.lower()}} = {{modelname}}Read.get(id=id, config=config) + {{modelname.lower()}} = {{modelname}}Read.get(id={{modelname.lower()}}_id) Console().print({{modelname.lower()}}) return {{modelname.lower()}} @@ -41,12 +40,11 @@ def get( @engorgio(typer=True) def list( where: Optional[str] = None, - config: Config = None, offset: int = 0, limit: Optional[int] = None, -) -> Union[{{modelname}}, List[{{modelname.lower()}}]]: - "get one {{modelname.lower()}}" - {{modelname.lower()}} = {{modelname}}Read.list(config=config, where=where, offset=offset, limit=limit) +) -> Union[{{modelname}}, List[{{modelname}}]]: + "list many {{modelname.lower()}}s" + {{modelname.lower()}} = {{modelname}}Read.list(where=where, offset=offset, limit=limit) Console().print({{modelname.lower()}}) return {{modelname.lower()}} @@ -55,53 +53,40 @@ def list( @engorgio(typer=True) def create( {{modelname.lower()}}: {{modelname}}Create, - config: Config = None, ) -> {{modelname}}: - "read all the {{modelname.lower()}}s" - # config.init() - {{modelname.lower()}} = {{modelname.lower()}}.post(config=config) - Console().print({{modelname.lower()}}) - return {{modelname.lower()}} + "create one {{modelname.lower()}}" + {{modelname.lower()}}.post() @{{modelname.lower()}}_app.command() @engorgio(typer=True) def update( {{modelname.lower()}}: {{modelname}}Update, - config: Config = None, ) -> {{modelname}}: - "read all the {{modelname.lower()}}s" - {{modelname.lower()}} = {{modelname.lower()}}.update(config=config) - Console().print({{modelname.lower()}}) - return {{modelname.lower()}} + "update one {{modelname.lower()}}" + {{modelname.lower()}}.update() @{{modelname.lower()}}_app.command() @engorgio(typer=True) def delete( {{modelname.lower()}}: {{modelname}}Delete, - config: Config = None, ) -> {{modelname}}: - "read all the {{modelname.lower()}}s" - # config.init() - {{modelname.lower()}} = {{modelname.lower()}}.delete(config=config) - return {{modelname.lower()}} + "delete a {{modelname.lower()}} by id" + {{modelname.lower()}}.delete() @{{modelname.lower()}}_app.command() @engorgio(typer=True) def populate( - {{modelname.lower()}}: {{modelname}}, n: int = 10, ) -> {{modelname}}: - "read all the {{modelname.lower()}}s" - config = get_config() + "Create n number of {{modelname.lower()}}s" if config.env == "prod": Console().print("populate is not supported in production") sys.exit(1) for {{modelname.lower()}} in {{modelname}}Factory().batch(n): - pet = PetFactory().build() - {{modelname.lower()}}.pet = pet - Console().print({{modelname.lower()}}) - {{modelname.lower()}}.post(config=config) + {{modelname.lower()}} = {{modelname}}Create(**{{modelname.lower()}}.dict()) + {{modelname.lower()}}.post() + diff --git a/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja index 0aef84c..992c33c 100644 --- a/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja +++ b/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja @@ -1,43 +1,41 @@ from typing import Optional -from fastapi import HTTPException +from fastapi import Depends, HTTPException +import httpx from pydantic import BaseModel from sqlmodel import Field, Relationship, SQLModel, Session, select -from learn_sql_model.config import Config +from learn_sql_model.config import config, get_session from learn_sql_model.models.pet import Pet class {{modelname}}Base(SQLModel, table=False): -class {{modelname}}({{modelname.lower()}}Base, table=True): +class {{modelname}}({{modelname}}Base, table=True): id: Optional[int] = Field(default=None, primary_key=True) -class {{modelname}}Create({{modelname.lower()}}Base): +class {{modelname}}Create({{modelname}}Base): ... - def post(self, config: Config) -> {{modelname}}: - config.init() - with Session(config.database.engine) as session: - db_{{modelname.lower()}} = {{modelname}}.from_orm(self) - session.add(db_{{modelname.lower()}}) - session.commit() - session.refresh(db_{{modelname.lower()}}) - return db_{{modelname.lower()}} + def post(self) -> {{modelname}}: + r = httpx.post( + f"{config.api_client.url}/{{modelname.lower()}}/", + json=self.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") -class {{modelname}}Read({{modelname.lower()}}Base): +class {{modelname}}Read({{modelname}}Base): id: int @classmethod def get( cls, - config: Config, id: int, ) -> {{modelname}}: - with config.database.session as session: {{modelname.lower()}} = session.get({{modelname}}, id) if not {{modelname.lower()}}: @@ -47,53 +45,49 @@ class {{modelname}}Read({{modelname.lower()}}Base): @classmethod def list( self, - config: Config, where=None, offset=0, limit=None, + session: Session = get_session, ) -> {{modelname}}: + # with config.database.session as session: - with config.database.session as session: - statement = select({{modelname}}) - if where != "None": - from sqlmodel import text + statement = select({{modelname}}) + if where != "None" and where is not None: + from sqlmodel import text - statement = statement.where(text(where)) - statement = statement.offset(offset).limit(limit) - {{modelname.lower()}}es = session.exec(statement).all() + statement = statement.where(text(where)) + statement = statement.offset(offset).limit(limit) + {{modelname.lower()}}es = session.exec(statement).all() return {{modelname.lower()}}es class {{modelname}}Update(SQLModel): - # id is required to get the {{modelname.lower()}} + # id is required to update the {{modelname.lower()}} id: int # all other fields, must match the model, but with Optional default None - def update(self, config: Config) -> {{modelname}}: - with Session(config.database.engine) as session: - db_{{modelname.lower()}} = session.get({{modelname}}, self.id) - if not db_{{modelname.lower()}}: - raise HTTPException(status_code=404, detail="{{modelname}} not found") - {{modelname.lower()}}_data = self.dict(exclude_unset=True) - for key, value in {{modelname.lower()}}_data.items(): - if value is not None: - setattr(db_{{modelname.lower()}}, key, value) - session.add(db_{{modelname.lower()}}) - session.commit() - session.refresh(db_{{modelname.lower()}}) - return db_{{modelname.lower()}} + pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") + pet: Optional[Pet] = Relationship(back_populates="{{modelname.lower()}}") + + def update(self) -> {{modelname}}: + r = httpx.patch( + f"{config.api_client.url}/{{modelname.lower()}}/", + json=self.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") class {{modelname}}Delete(BaseModel): id: int - def delete(self, config: Config) -> {{modelname}}: - config.init() - with Session(config.database.engine) as session: - {{modelname.lower()}} = session.get({{modelname}}, self.id) - if not {{modelname.lower()}}: - raise HTTPException(status_code=404, detail="{{modelname}} not found") - session.delete({{modelname.lower()}}) - session.commit() - return {"ok": True} + def delete(self) -> {{modelname}}: + r = httpx.delete( + f"{config.api_client.url}/{{modelname.lower()}}/{self.id}", + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + return {"ok": True} + diff --git a/templates/model/tests/{{modelname.lower()}}.py.jinja b/templates/model/tests/{{modelname.lower()}}.py.jinja new file mode 100644 index 0000000..c7a4d1b --- /dev/null +++ b/templates/model/tests/{{modelname.lower()}}.py.jinja @@ -0,0 +1,207 @@ +from fastapi.testclient import TestClient +import pytest +from sqlalchemy import create_engine +from sqlmodel import SQLModel, Session, select +from sqlmodel.pool import StaticPool +from typer.testing import CliRunner + +from learn_sql_model.api.app import app +from learn_sql_model.config import get_config, get_session +from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}} + +runner = CliRunner() +client = TestClient(app) + + +@pytest.fixture(name="session") +def session_fixture(): + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + + +@pytest.fixture(name="client") +def client_fixture(session: Session): + def get_session_override(): + return session + + app.dependency_overrides[get_session] = get_session_override + + client = TestClient(app) + yield client + app.dependency_overrides.clear() + + +def test_api_post(client: TestClient): + {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) + {{modelname.lower()}}_dict = {{modelname.lower()}}.dict() + response = client.post("/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {{modelname.lower()}}_dict}) + response_{{modelname.lower()}} = {{modelname}}.parse_obj(response.json()) + + assert response.status_code == 200 + assert response_{{modelname.lower()}}.name == "Steelman" + assert response_{{modelname.lower()}}.age == 25 + + +def test_api_read_{{modelname.lower()}}es(session: Session, client: TestClient): + {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") + {{modelname.lower()}}_2 = {{modelname}}(name="Rusty-Man", secret_name="Tommy Sharp", age=48) + session.add({{modelname.lower()}}_1) + session.add({{modelname.lower()}}_2) + session.commit() + + response = client.get("/{{modelname.lower()}}s/") + data = response.json() + + assert response.status_code == 200 + + assert len(data) == 2 + assert data[0]["name"] == {{modelname.lower()}}_1.name + assert data[0]["secret_name"] == {{modelname.lower()}}_1.secret_name + assert data[0]["age"] == {{modelname.lower()}}_1.age + assert data[0]["id"] == {{modelname.lower()}}_1.id + assert data[1]["name"] == {{modelname.lower()}}_2.name + assert data[1]["secret_name"] == {{modelname.lower()}}_2.secret_name + assert data[1]["age"] == {{modelname.lower()}}_2.age + assert data[1]["id"] == {{modelname.lower()}}_2.id + + +def test_api_read_{{modelname.lower()}}(session: Session, client: TestClient): + {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") + session.add({{modelname.lower()}}_1) + session.commit() + + response = client.get(f"/{{modelname.lower()}}/999") + assert response.status_code == 404 + + +def test_api_read_{{modelname.lower()}}_404(session: Session, client: TestClient): + {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") + session.add({{modelname.lower()}}_1) + session.commit() + + response = client.get(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}") + data = response.json() + + assert response.status_code == 200 + assert data["name"] == {{modelname.lower()}}_1.name + assert data["secret_name"] == {{modelname.lower()}}_1.secret_name + assert data["age"] == {{modelname.lower()}}_1.age + assert data["id"] == {{modelname.lower()}}_1.id + + +def test_api_update_{{modelname.lower()}}(session: Session, client: TestClient): + {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") + session.add({{modelname.lower()}}_1) + session.commit() + + response = client.patch( + f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": {{modelname.lower()}}_1.id}} + ) + data = response.json() + + assert response.status_code == 200 + assert data["name"] == "Deadpuddle" + assert data["secret_name"] == "Dive Wilson" + assert data["age"] is None + assert data["id"] == {{modelname.lower()}}_1.id + + +def test_api_update_{{modelname.lower()}}_404(session: Session, client: TestClient): + {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") + session.add({{modelname.lower()}}_1) + session.commit() + + response = client.patch(f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": 999}}) + assert response.status_code == 404 + + +def test_delete_{{modelname.lower()}}(session: Session, client: TestClient): + {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") + session.add({{modelname.lower()}}_1) + session.commit() + + response = client.delete(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}") + + {{modelname.lower()}}_in_db = session.get({{modelname}}, {{modelname.lower()}}_1.id) + + assert response.status_code == 200 + + assert {{modelname.lower()}}_in_db is None + + +def test_delete_{{modelname.lower()}}_404(session: Session, client: TestClient): + {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") + session.add({{modelname.lower()}}_1) + session.commit() + + response = client.delete(f"/{{modelname.lower()}}/999") + assert response.status_code == 404 + + +def test_config_memory(mocker): + mocker.patch( + "learn_sql_model.config.Database.engine", + new_callable=lambda: create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ), + ) + config = get_config() + SQLModel.metadata.create_all(config.database.engine) + {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) + with config.database.session as session: + session.add({{modelname.lower()}}) + session.commit() + {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id) + {{modelname.lower()}}es = session.exec(select({{modelname}})).all() + assert {{modelname.lower()}}.name == "Steelman" + assert {{modelname.lower()}}.age == 25 + assert len({{modelname.lower()}}es) == 1 + + +def test_cli_get(mocker): + mocker.patch( + "learn_sql_model.config.Database.engine", + new_callable=lambda: create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ), + ) + + config = get_config() + SQLModel.metadata.create_all(config.database.engine) + + {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) + with config.database.session as session: + session.add({{modelname.lower()}}) + session.commit() + {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id) + result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "1"]) + assert result.exit_code == 0 + assert f"name='{{{modelname.lower()}}.name}'" in result.stdout + assert f"secret_name='{{{modelname.lower()}}.secret_name}'" in result.stdout + + +def test_cli_get_404(mocker): + mocker.patch( + "learn_sql_model.config.Database.engine", + new_callable=lambda: create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ), + ) + + config = get_config() + SQLModel.metadata.create_all(config.database.engine) + + {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) + with config.database.session as session: + session.add({{modelname.lower()}}) + session.commit() + {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id) + result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "999"]) + assert result.exception.status_code == 404 + assert result.exception.detail == "{{modelname}} not found" + diff --git a/tests/test_hero.py b/tests/test_hero.py index 7eb42af..261ac61 100644 --- a/tests/test_hero.py +++ b/tests/test_hero.py @@ -1,15 +1,12 @@ -import tempfile - from fastapi.testclient import TestClient import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlmodel import SQLModel +from sqlmodel import SQLModel, Session, select +from sqlmodel.pool import StaticPool from typer.testing import CliRunner from learn_sql_model.api.app import app -from learn_sql_model.cli.hero import hero_app -from learn_sql_model.config import Config, get_config +from learn_sql_model.config import get_config, get_session from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.models.hero import Hero @@ -17,142 +14,193 @@ runner = CliRunner() client = TestClient(app) -@pytest.fixture -def config() -> Config: - tmp_db = tempfile.NamedTemporaryFile(suffix=".db") - config = get_config({"database_url": f"sqlite:///{tmp_db.name}"}) - +@pytest.fixture(name="session") +def session_fixture(): engine = create_engine( - config.database_url, connect_args={"check_same_thread": False} + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool ) - TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - # breakpoint() - SQLModel.metadata.create_all(config.database.engine) - - # def override_get_db(): - # try: - # db = TestingSessionLocal() - # yield db - # finally: - # db.close() - def override_get_config(): - return config - - app.dependency_overrides[get_config] = override_get_config - - yield config - # tmp_db automatically deletes here + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session -def test_post_hero(config: Config) -> None: - hero = HeroFactory().build(name="Batman", age=50, id=1) - hero = hero.post(config=config) - db_hero = Hero().get(id=1, config=config) - assert db_hero.age == 50 - assert db_hero.name == "Batman" +@pytest.fixture(name="client") +def client_fixture(session: Session): + def get_session_override(): + return session + + app.dependency_overrides[get_session] = get_session_override + + client = TestClient(app) + yield client + app.dependency_overrides.clear() -def test_update_hero(config: Config) -> None: - hero = HeroFactory().build(name="Batman", age=50, id=1) - hero = hero.post(config=config) - db_hero = Hero().get(id=1, config=config) - db_hero.name = "Superbman" - hero = db_hero.post(config=config) - db_hero = Hero().get(id=1, config=config) - assert db_hero.age == 50 - assert db_hero.name == "Superbman" - - -def test_cli_get(config): - hero = HeroFactory().build(name="Steelman", age=25, id=99) - hero.post(config=config) - result = runner.invoke( - hero_app, - ["get", "--id", 99, "--database-url", config.database_url], - ) - assert result.exit_code == 0 - db_hero = Hero().get(id=99, config=config) - assert db_hero.age == 25 - assert db_hero.name == "Steelman" - - -def test_cli_create(config): - hero = HeroFactory().build(name="Steelman", age=25, id=99) - result = runner.invoke( - hero_app, - [ - "create", - *hero.flags(config=config), - "--database-url", - config.database_url, - ], - ) - assert result.exit_code == 0 - db_hero = Hero().get(id=99, config=config) - assert db_hero.age == 25 - assert db_hero.name == "Steelman" - - -def test_cli_populate(config): - result = runner.invoke( - hero_app, - [ - "populate", - "--n", - 10, - "--database-url", - config.database_url, - ], - ) - assert result.exit_code == 0 - db_hero = Hero().get(config=config) - assert len(db_hero) == 10 - - -def test_cli_populate_fails_prod(config): - result = runner.invoke( - hero_app, - ["populate", "--n", 10, "--database-url", config.database_url, "--env", "prod"], - ) - assert result.exit_code == 1 - assert result.output.strip() == "populate is not supported in production" - - -def test_api_read(config): - hero = HeroFactory().build(name="Steelman", age=25, id=99) - hero_id = hero.id - hero = hero.post(config=config) - response = client.get(f"/hero/{hero_id}") - assert response.status_code == 200 - reponse_hero = Hero.parse_obj(response.json()) - assert reponse_hero.id == hero_id - assert reponse_hero.name == "Steelman" - assert reponse_hero.age == 25 - - -def test_api_post(config): +def test_api_post(client: TestClient): hero = HeroFactory().build(name="Steelman", age=25) hero_dict = hero.dict() response = client.post("/hero/", json={"hero": hero_dict}) - assert response.status_code == 200 - response_hero = Hero.parse_obj(response.json()) - db_hero = Hero().get(id=response_hero.id, config=config) - assert db_hero.name == "Steelman" - assert db_hero.age == 25 - - -def test_api_read_all(config): - hero = HeroFactory().build(name="Mothman", age=25, id=99) - hero_id = hero.id - hero = hero.post(config=config) - response = client.get("/heros/") assert response.status_code == 200 - heros = response.json() - response_hero_json = [hero for hero in heros if hero["id"] == hero_id][0] - response_hero = Hero.parse_obj(response_hero_json) - assert response_hero.id == hero_id - assert response_hero.name == "Mothman" + assert response_hero.name == "Steelman" assert response_hero.age == 25 + + +def test_api_read_heroes(session: Session, client: TestClient): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48) + session.add(hero_1) + session.add(hero_2) + session.commit() + + response = client.get("/heros/") + data = response.json() + + assert response.status_code == 200 + + assert len(data) == 2 + assert data[0]["name"] == hero_1.name + assert data[0]["secret_name"] == hero_1.secret_name + assert data[0]["age"] == hero_1.age + assert data[0]["id"] == hero_1.id + assert data[1]["name"] == hero_2.name + assert data[1]["secret_name"] == hero_2.secret_name + assert data[1]["age"] == hero_2.age + assert data[1]["id"] == hero_2.id + + +def test_api_read_hero(session: Session, client: TestClient): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + session.add(hero_1) + session.commit() + + response = client.get(f"/hero/999") + assert response.status_code == 404 + + +def test_api_read_hero_404(session: Session, client: TestClient): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + session.add(hero_1) + session.commit() + + response = client.get(f"/hero/{hero_1.id}") + data = response.json() + + assert response.status_code == 200 + assert data["name"] == hero_1.name + assert data["secret_name"] == hero_1.secret_name + assert data["age"] == hero_1.age + assert data["id"] == hero_1.id + + +def test_api_update_hero(session: Session, client: TestClient): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + session.add(hero_1) + session.commit() + + response = client.patch( + f"/hero/", json={"hero": {"name": "Deadpuddle", "id": hero_1.id}} + ) + data = response.json() + + assert response.status_code == 200 + assert data["name"] == "Deadpuddle" + assert data["secret_name"] == "Dive Wilson" + assert data["age"] is None + assert data["id"] == hero_1.id + + +def test_api_update_hero_404(session: Session, client: TestClient): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + session.add(hero_1) + session.commit() + + response = client.patch(f"/hero/", json={"hero": {"name": "Deadpuddle", "id": 999}}) + assert response.status_code == 404 + + +def test_delete_hero(session: Session, client: TestClient): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + session.add(hero_1) + session.commit() + + response = client.delete(f"/hero/{hero_1.id}") + + hero_in_db = session.get(Hero, hero_1.id) + + assert response.status_code == 200 + + assert hero_in_db is None + + +def test_delete_hero_404(session: Session, client: TestClient): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + session.add(hero_1) + session.commit() + + response = client.delete(f"/hero/999") + assert response.status_code == 404 + + +def test_config_memory(mocker): + mocker.patch( + "learn_sql_model.config.Database.engine", + new_callable=lambda: create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ), + ) + config = get_config() + SQLModel.metadata.create_all(config.database.engine) + hero = HeroFactory().build(name="Steelman", age=25) + with config.database.session as session: + session.add(hero) + session.commit() + hero = session.get(Hero, hero.id) + heroes = session.exec(select(Hero)).all() + assert hero.name == "Steelman" + assert hero.age == 25 + assert len(heroes) == 1 + + +def test_cli_get(mocker): + mocker.patch( + "learn_sql_model.config.Database.engine", + new_callable=lambda: create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ), + ) + + config = get_config() + SQLModel.metadata.create_all(config.database.engine) + + hero = HeroFactory().build(name="Steelman", age=25) + with config.database.session as session: + session.add(hero) + session.commit() + hero = session.get(Hero, hero.id) + result = runner.invoke(hero_app, ["get", "--hero-id", "1"]) + assert result.exit_code == 0 + assert f"name='{hero.name}'" in result.stdout + assert f"secret_name='{hero.secret_name}'" in result.stdout + + +def test_cli_get_404(mocker): + mocker.patch( + "learn_sql_model.config.Database.engine", + new_callable=lambda: create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ), + ) + + config = get_config() + SQLModel.metadata.create_all(config.database.engine) + + hero = HeroFactory().build(name="Steelman", age=25) + with config.database.session as session: + session.add(hero) + session.commit() + hero = session.get(Hero, hero.id) + result = runner.invoke(hero_app, ["get", "--hero-id", "999"]) + assert result.exception.status_code == 404 + assert result.exception.detail == "Hero not found"