This commit is contained in:
Waylon Walker 2023-06-22 16:27:52 -05:00
parent e86e432102
commit 28eda9e899
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
5 changed files with 121 additions and 184 deletions

View file

@ -1,7 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import get_session
from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, HeroUpdate, Heros
@ -38,7 +37,7 @@ def post_hero(
session.add(db_hero)
session.commit()
session.refresh(db_hero)
await manager.broadcast({hero.json()}, id=1)
# await manager.broadcast({hero.json()}, id=1)
return db_hero
@ -57,7 +56,7 @@ def patch_hero(
session.add(db_hero)
session.commit()
session.refresh(db_hero)
await manager.broadcast({hero.json()}, id=1)
# await manager.broadcast({hero.json()}, id=1)
return db_hero
@ -73,7 +72,7 @@ def delete_hero(
raise HTTPException(status_code=404, detail="Hero not found")
session.delete(hero)
session.commit()
await manager.broadcast(f"deleted hero {hero_id}", id=1)
# await manager.broadcast(f"deleted hero {hero_id}", id=1)
return {"ok": True}

View file

@ -1,86 +1,89 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import SQLModel, Session
from sqlmodel import Session, select
from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import get_config, get_session
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}, {{modelname}}Create, {{modelname}}Read, {{modelname}}Update
from learn_sql_model.config import get_session
from learn_sql_model.models.{{ modelname }} import {{ modelname }}, {{ modelname }}Create, {{ modelname }}Read, {{ modelname }}Update, {{ modelname }}s
{{modelname.lower()}}_router = APIRouter()
{{ modelname }}_router = APIRouter()
@{{modelname.lower()}}_router.on_event("startup")
@{{ modelname }}_router.on_event("startup")
def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine)
# SQLModel.metadata.create_all(get_config().database.engine)
...
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def get_{{modelname.lower()}}(
@{{ modelname }}_router.get("/{{ modelname }}/{{{ modelname }}_id}")
def get_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
) -> {{modelname}}Read:
"get one {{modelname.lower()}}"
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
return {{modelname.lower()}}
{{ modelname }}_id: int,
) -> {{ modelname }}Read:
"get one {{ modelname }}"
{{ modelname }} = session.get({{ modelname }}, {{ modelname }}_id)
if not {{ modelname }}:
raise HTTPException(status_code=404, detail="{{ modelname }} not found")
return {{ modelname }}
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
async def post_{{modelname.lower()}}(
@{{ modelname }}_router.post("/{{ modelname }}/")
def post_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Create,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = {{modelname}}.from_orm({{modelname.lower()}})
session.add(db_{{modelname.lower()}})
{{ modelname }}: {{ modelname }}Create,
) -> {{ modelname }}Read:
"create a {{ modelname }}"
db_{{ modelname }} = {{ modelname }}.from_orm({{ modelname }})
session.add(db_{{ modelname }})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
session.refresh(db_{{ modelname }})
await manager.broadcast({{{ modelname }}.json()}, id=1)
return db_{{ modelname }}
@{{modelname.lower()}}_router.patch("/{{modelname.lower()}}/")
async def patch_{{modelname.lower()}}(
@{{ modelname }}_router.patch("/{{ modelname }}/")
def patch_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Update,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
if not db_{{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
for key, value in {{modelname.lower()}}.dict(exclude_unset=True).items():
setattr(db_{{modelname.lower()}}, key, value)
session.add(db_{{modelname.lower()}})
{{ modelname }}: {{ modelname }}Update,
) -> {{ modelname }}Read:
"update a {{ modelname }}"
db_{{ modelname }} = session.get({{ modelname }}, {{ modelname }}.id)
if not db_{{ modelname }}:
raise HTTPException(status_code=404, detail="{{ modelname }} not found")
for key, value in {{ modelname }}.dict(exclude_unset=True).items():
setattr(db_{{ modelname }}, key, value)
session.add(db_{{ modelname }})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
session.refresh(db_{{ modelname }})
await manager.broadcast({{{ modelname }}.json()}, id=1)
return db_{{ modelname }}
@{{modelname.lower()}}_router.delete("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def delete_{{modelname.lower()}}(
@{{ modelname }}_router.delete("/{{ modelname }}/{{{ modelname }}_id}")
def delete_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
{{ modelname }}_id: int,
):
"read all the {{modelname.lower()}}s"
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
session.delete({{modelname.lower()}})
"delete a {{ modelname }}"
{{ modelname }} = session.get({{ modelname }}, {{ modelname }}_id)
if not {{ modelname }}:
raise HTTPException(status_code=404, detail="{{ modelname }} not found")
session.delete({{ modelname }})
session.commit()
await manager.broadcast(f"deleted {{modelname.lower()}} {{{modelname.lower()}}_id}", id=1)
await manager.broadcast(f"deleted {{ modelname }} {{{ modelname }}_id}", id=1)
return {"ok": True}
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/")
async def get_{{modelname.lower()}}s(
@{{ modelname }}_router.get("/{{ modelname }}s/")
def get_{{ modelname }}s(
*,
session: Session = Depends(get_session),
) -> list[{{modelname}}]:
"get all {{modelname.lower()}}s"
return {{modelname}}Read.list(session=session)
) -> {{ modelname }}s:
"get all {{ modelname }}s"
statement = select({{ modelname }})
{{ modelname }}s = session.exec(statement).all()
return {{ modelname }}s(__root__={{ modelname }}s)

View file

@ -1,14 +1,12 @@
from faker import Faker
from polyfactory.factories.pydantic_factory import ModelFactory
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.{{ modelname }} import {{ modelname }}
from learn_sql_model.models.pet import Pet
class {{modelname}}Factory(ModelFactory[{{modelname.lower()}}]):
__model__ = {{modelname}}
class {{ modelname }}Factory(ModelFactory[{{ modelname }}]):
__model__ = {{ modelname }}
__faker__ = Faker(locale="en_US")
__set_as_default_factory_for_type__ = True
id = None
__random_seed__ = 10

View file

@ -1,134 +1,81 @@
from typing import Optional
from typing import Dict, Optional
from fastapi import HTTPException
import httpx
from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select
from sqlmodel import Field, SQLModel
from learn_sql_model.config import config
from learn_sql_model.models.pet import Pet
class {{ model }}Base(SQLModel, table=False):
name: str
secret_name: str
x: int
y: int
size: int
age: Optional[int] = None
shoe_size: Optional[int] = None
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
class {{ modelname }}Base(SQLModel, table=False):
# put model attributes here
class {{ model }}({{ model }}Base, table=True):
class {{ modelname }}({{ modelname }}Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class {{ model }}Create({{ model }}Base):
class {{ modelname }}Create({{ modelname }}Base):
...
def post(self) -> {{ model }}:
def post(self) -> {{ modelname }}:
r = httpx.post(
f"{config.api_client.url}/{{ model.lower() }}/",
f"{config.api_client.url}/{{ modelname }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model }}.parse_obj(r.json())
return {{ modelname }}.parse_obj(r.json())
class {{ model }}Read({{ model }}Base):
class {{ modelname }}Read({{ modelname }}Base):
id: int
@classmethod
def get(
cls,
id: int,
) -> {{ model }}:
with config.database.session as session:
{{ model.lower() }} = session.get({{ model }}, id)
if not {{ model.lower() }}:
raise HTTPException(status_code=404, detail="{{ model }} not found")
return {{ model.lower() }}
) -> {{ modelname }}:
r = httpx.get(f"{config.api_client.url}/{{ modelname }}/{id}")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ modelname }}Read.parse_obj(r.json())
class {{ model }}s(BaseModel):
{{ model.lower() }}s: list[{{ model }}]
class {{ modelname }}s(BaseModel):
__root__: list[{{ modelname }}]
@classmethod
def list(
self,
where=None,
offset=0,
limit=None,
session: Session = None,
) -> {{ model }}:
# with config.database.session as session:
def get_{{ model.lower() }}s(session, where, offset, limit):
statement = select({{ model }})
if where != "None" and where is not None:
from sqlmodel import text
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
{{ model.lower() }}s = session.exec(statement).all()
return {{ model }}s({{ model.lower() }}s={{ model.lower() }}s)
if session is None:
r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}s/")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model }}s.parse_obj(r.json())
return get_{{ model.lower() }}s(session, where, offset, limit)
) -> {{ modelname }}:
r = httpx.get(f"{config.api_client.url}/{{ modelname }}s/")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ modelname }}s.parse_obj({"__root__": r.json()})
class {{ model }}Update(SQLModel):
# id is required to update the {{ model.lower() }}
class {{ modelname }}Update(SQLModel):
# id is required to update the {{ modelname }}
id: int
# all other fields, must match the model, but with Optional default None
name: Optional[str] = None
secret_name: Optional[str] = None
age: Optional[int] = None
shoe_size: Optional[int] = None
x: int
y: int
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
def update(self, session: Session = None) -> {{ model }}:
if session is not None:
db_{{ model.lower() }} = session.get({{ model }}, self.id)
if not db_{{ model.lower() }}:
raise HTTPException(status_code=404, detail="{{ model }} not found")
for key, value in self.dict(exclude_unset=True).items():
setattr(db_{{ model.lower() }}, key, value)
session.add(db_{{ model.lower() }})
session.commit()
session.refresh(db_{{ model.lower() }})
return db_{{ model.lower() }}
def update(self) -> {{ modelname }}:
r = httpx.patch(
f"{config.api_client.url}/{{ model.lower() }}/",
f"{config.api_client.url}/{{ modelname }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
class {{ model }}Delete(BaseModel):
class {{ modelname }}Delete(BaseModel):
id: int
def delete(self) -> {{ model }}:
@classmethod
def delete(self, id: int) -> Dict[str, bool]:
r = httpx.delete(
f"{config.api_client.url}/{{ model.lower() }}/{self.id}",
f"{config.api_client.url}/{{ modelname }}/{id}",
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")

View file

@ -39,19 +39,18 @@ def client_fixture(session: Session):
def test_api_post(client: TestClient):
hero = HeroFactory().build(name="Steelman", age=25)
hero = HeroFactory().build()
hero_dict = hero.dict()
response = client.post("/hero/", json=hero_dict)
response_hero = Hero.parse_obj(response.json())
assert response.status_code == 200
assert response_hero.name == "Steelman"
assert response_hero.age == 25
assert response_hero.name == hero.name
def test_api_read_heroes(session: Session, client: TestClient):
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_2 = HeroFactory().build(name="Rusty-Man", age=48)
hero_1 = HeroFactory().build()
hero_2 = HeroFactory().build()
session.add(hero_1)
session.add(hero_2)
session.commit()
@ -64,16 +63,14 @@ def test_api_read_heroes(session: Session, client: TestClient):
assert len(data) == 2
assert data[0]["name"] == hero_1.name
assert data[0]["secret_name"] == hero_1.secret_name
assert data[0]["age"] == hero_1.age
assert data[0]["id"] == hero_1.id
assert data[1]["name"] == hero_2.name
assert data[1]["secret_name"] == hero_2.secret_name
assert data[1]["age"] == hero_2.age
assert data[1]["id"] == hero_2.id
def test_api_read_hero(session: Session, client: TestClient):
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_1 = HeroFactory().build()
session.add(hero_1)
session.commit()
@ -83,12 +80,11 @@ def test_api_read_hero(session: Session, client: TestClient):
assert response.status_code == 200
assert data["name"] == hero_1.name
assert data["secret_name"] == hero_1.secret_name
assert data["age"] == hero_1.age
assert data["id"] == hero_1.id
def test_api_read_hero_404(session: Session, client: TestClient):
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_1 = HeroFactory().build()
session.add(hero_1)
session.commit()
@ -97,7 +93,7 @@ def test_api_read_hero_404(session: Session, client: TestClient):
def test_api_update_hero(session: Session, client: TestClient):
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_1 = HeroFactory().build()
session.add(hero_1)
session.commit()
@ -107,12 +103,11 @@ def test_api_update_hero(session: Session, client: TestClient):
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == hero_1.secret_name
assert data["age"] is hero_1.age
assert data["id"] == hero_1.id
def test_api_update_hero_404(session: Session, client: TestClient):
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_1 = HeroFactory().build()
session.add(hero_1)
session.commit()
@ -121,7 +116,7 @@ def test_api_update_hero_404(session: Session, client: TestClient):
def test_delete_hero(session: Session, client: TestClient):
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_1 = HeroFactory().build()
session.add(hero_1)
session.commit()
@ -135,7 +130,7 @@ def test_delete_hero(session: Session, client: TestClient):
def test_delete_hero_404(session: Session, client: TestClient):
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_1 = HeroFactory().build()
session.add(hero_1)
session.commit()
@ -152,19 +147,18 @@ def test_config_memory(mocker):
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
hero = HeroFactory().build()
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
heroes = session.exec(select(Hero)).all()
assert hero.name == "Steelman"
assert hero.age == 25
assert len(heroes) == 1
db_hero = session.get(Hero, hero.id)
db_heroes = session.exec(select(Hero)).all()
assert db_hero.name == hero.name
assert len(db_heroes) == 1
def test_cli_get(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
hero = HeroRead(**hero.dict(exclude_none=True))
httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock()
@ -181,7 +175,7 @@ def test_cli_get(mocker):
def test_cli_get_404(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
hero = HeroRead(**hero.dict(exclude_none=True))
httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock()
@ -198,12 +192,8 @@ def test_cli_get_404(mocker):
def test_cli_list(mocker):
hero_1 = HeroRead(
**HeroFactory().build(name="Steelman", age=25, id=1).dict(exclude_none=True)
)
hero_2 = HeroRead(
**HeroFactory().build(name="Hunk", age=52, id=2).dict(exclude_none=True)
)
hero_1 = HeroRead(**HeroFactory().build().dict(exclude_none=True))
hero_2 = HeroRead(**HeroFactory().build().dict(exclude_none=True))
heros = Heros(__root__=[hero_1, hero_2])
httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock()
@ -219,7 +209,7 @@ def test_cli_list(mocker):
def test_model_post(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
hero_create = HeroCreate(**hero.dict())
httpx = mocker.patch.object(hero_models, "httpx")
@ -234,7 +224,7 @@ def test_model_post(mocker):
def test_model_post_500(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
hero_create = HeroCreate(**hero.dict())
httpx = mocker.patch.object(hero_models, "httpx")
@ -249,7 +239,7 @@ def test_model_post_500(mocker):
def test_model_read_hero(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock()
@ -265,7 +255,7 @@ def test_model_read_hero(mocker):
def test_model_read_hero_404(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock()
httpx.get.return_value.status_code = 404
@ -280,7 +270,7 @@ def test_model_read_hero_404(mocker):
def test_model_delete_hero(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()
@ -295,7 +285,7 @@ def test_model_delete_hero(mocker):
def test_model_delete_hero_404(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()
@ -311,7 +301,7 @@ def test_model_delete_hero_404(mocker):
def test_cli_delete_hero(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()
@ -327,7 +317,7 @@ def test_cli_delete_hero(mocker):
def test_cli_delete_hero_404(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
hero = HeroFactory().build()
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()