This commit is contained in:
Waylon Walker 2023-06-22 16:27:52 -05:00
parent e86e432102
commit 28eda9e899
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
5 changed files with 121 additions and 184 deletions

View file

@ -1,86 +1,89 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import SQLModel, Session
from sqlmodel import Session, select
from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import get_config, get_session
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}, {{modelname}}Create, {{modelname}}Read, {{modelname}}Update
from learn_sql_model.config import get_session
from learn_sql_model.models.{{ modelname }} import {{ modelname }}, {{ modelname }}Create, {{ modelname }}Read, {{ modelname }}Update, {{ modelname }}s
{{modelname.lower()}}_router = APIRouter()
{{ modelname }}_router = APIRouter()
@{{modelname.lower()}}_router.on_event("startup")
@{{ modelname }}_router.on_event("startup")
def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine)
# SQLModel.metadata.create_all(get_config().database.engine)
...
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def get_{{modelname.lower()}}(
@{{ modelname }}_router.get("/{{ modelname }}/{{{ modelname }}_id}")
def get_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
) -> {{modelname}}Read:
"get one {{modelname.lower()}}"
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
return {{modelname.lower()}}
{{ modelname }}_id: int,
) -> {{ modelname }}Read:
"get one {{ modelname }}"
{{ modelname }} = session.get({{ modelname }}, {{ modelname }}_id)
if not {{ modelname }}:
raise HTTPException(status_code=404, detail="{{ modelname }} not found")
return {{ modelname }}
@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/")
async def post_{{modelname.lower()}}(
@{{ modelname }}_router.post("/{{ modelname }}/")
def post_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Create,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = {{modelname}}.from_orm({{modelname.lower()}})
session.add(db_{{modelname.lower()}})
{{ modelname }}: {{ modelname }}Create,
) -> {{ modelname }}Read:
"create a {{ modelname }}"
db_{{ modelname }} = {{ modelname }}.from_orm({{ modelname }})
session.add(db_{{ modelname }})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
session.refresh(db_{{ modelname }})
await manager.broadcast({{{ modelname }}.json()}, id=1)
return db_{{ modelname }}
@{{modelname.lower()}}_router.patch("/{{modelname.lower()}}/")
async def patch_{{modelname.lower()}}(
@{{ modelname }}_router.patch("/{{ modelname }}/")
def patch_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}: {{modelname}}Update,
) -> {{modelname}}Read:
"read all the {{modelname.lower()}}s"
db_{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
if not db_{{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
for key, value in {{modelname.lower()}}.dict(exclude_unset=True).items():
setattr(db_{{modelname.lower()}}, key, value)
session.add(db_{{modelname.lower()}})
{{ modelname }}: {{ modelname }}Update,
) -> {{ modelname }}Read:
"update a {{ modelname }}"
db_{{ modelname }} = session.get({{ modelname }}, {{ modelname }}.id)
if not db_{{ modelname }}:
raise HTTPException(status_code=404, detail="{{ modelname }} not found")
for key, value in {{ modelname }}.dict(exclude_unset=True).items():
setattr(db_{{ modelname }}, key, value)
session.add(db_{{ modelname }})
session.commit()
session.refresh(db_{{modelname.lower()}})
await manager.broadcast({{{modelname.lower()}}.json()}, id=1)
return db_{{modelname.lower()}}
session.refresh(db_{{ modelname }})
await manager.broadcast({{{ modelname }}.json()}, id=1)
return db_{{ modelname }}
@{{modelname.lower()}}_router.delete("/{{modelname.lower()}}/{{{modelname.lower()}}_id}")
async def delete_{{modelname.lower()}}(
@{{ modelname }}_router.delete("/{{ modelname }}/{{{ modelname }}_id}")
def delete_{{ modelname }}(
*,
session: Session = Depends(get_session),
{{modelname.lower()}}_id: int,
{{ modelname }}_id: int,
):
"read all the {{modelname.lower()}}s"
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}_id)
if not {{modelname.lower()}}:
raise HTTPException(status_code=404, detail="{{modelname}} not found")
session.delete({{modelname.lower()}})
"delete a {{ modelname }}"
{{ modelname }} = session.get({{ modelname }}, {{ modelname }}_id)
if not {{ modelname }}:
raise HTTPException(status_code=404, detail="{{ modelname }} not found")
session.delete({{ modelname }})
session.commit()
await manager.broadcast(f"deleted {{modelname.lower()}} {{{modelname.lower()}}_id}", id=1)
await manager.broadcast(f"deleted {{ modelname }} {{{ modelname }}_id}", id=1)
return {"ok": True}
@{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/")
async def get_{{modelname.lower()}}s(
@{{ modelname }}_router.get("/{{ modelname }}s/")
def get_{{ modelname }}s(
*,
session: Session = Depends(get_session),
) -> list[{{modelname}}]:
"get all {{modelname.lower()}}s"
return {{modelname}}Read.list(session=session)
) -> {{ modelname }}s:
"get all {{ modelname }}s"
statement = select({{ modelname }})
{{ modelname }}s = session.exec(statement).all()
return {{ modelname }}s(__root__={{ modelname }}s)

View file

@ -1,14 +1,12 @@
from faker import Faker
from polyfactory.factories.pydantic_factory import ModelFactory
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
from learn_sql_model.factories.pet import PetFactory
from learn_sql_model.models.{{ modelname }} import {{ modelname }}
from learn_sql_model.models.pet import Pet
class {{modelname}}Factory(ModelFactory[{{modelname.lower()}}]):
__model__ = {{modelname}}
class {{ modelname }}Factory(ModelFactory[{{ modelname }}]):
__model__ = {{ modelname }}
__faker__ = Faker(locale="en_US")
__set_as_default_factory_for_type__ = True
id = None
__random_seed__ = 10

View file

@ -1,134 +1,81 @@
from typing import Optional
from typing import Dict, Optional
from fastapi import HTTPException
import httpx
from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select
from sqlmodel import Field, SQLModel
from learn_sql_model.config import config
from learn_sql_model.models.pet import Pet
class {{ model }}Base(SQLModel, table=False):
name: str
secret_name: str
x: int
y: int
size: int
age: Optional[int] = None
shoe_size: Optional[int] = None
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
class {{ modelname }}Base(SQLModel, table=False):
# put model attributes here
class {{ model }}({{ model }}Base, table=True):
class {{ modelname }}({{ modelname }}Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class {{ model }}Create({{ model }}Base):
class {{ modelname }}Create({{ modelname }}Base):
...
def post(self) -> {{ model }}:
def post(self) -> {{ modelname }}:
r = httpx.post(
f"{config.api_client.url}/{{ model.lower() }}/",
f"{config.api_client.url}/{{ modelname }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model }}.parse_obj(r.json())
return {{ modelname }}.parse_obj(r.json())
class {{ model }}Read({{ model }}Base):
class {{ modelname }}Read({{ modelname }}Base):
id: int
@classmethod
def get(
cls,
id: int,
) -> {{ model }}:
with config.database.session as session:
{{ model.lower() }} = session.get({{ model }}, id)
if not {{ model.lower() }}:
raise HTTPException(status_code=404, detail="{{ model }} not found")
return {{ model.lower() }}
) -> {{ modelname }}:
r = httpx.get(f"{config.api_client.url}/{{ modelname }}/{id}")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ modelname }}Read.parse_obj(r.json())
class {{ model }}s(BaseModel):
{{ model.lower() }}s: list[{{ model }}]
class {{ modelname }}s(BaseModel):
__root__: list[{{ modelname }}]
@classmethod
def list(
self,
where=None,
offset=0,
limit=None,
session: Session = None,
) -> {{ model }}:
# with config.database.session as session:
def get_{{ model.lower() }}s(session, where, offset, limit):
statement = select({{ model }})
if where != "None" and where is not None:
from sqlmodel import text
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
{{ model.lower() }}s = session.exec(statement).all()
return {{ model }}s({{ model.lower() }}s={{ model.lower() }}s)
if session is None:
r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}s/")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model }}s.parse_obj(r.json())
return get_{{ model.lower() }}s(session, where, offset, limit)
) -> {{ modelname }}:
r = httpx.get(f"{config.api_client.url}/{{ modelname }}s/")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ modelname }}s.parse_obj({"__root__": r.json()})
class {{ model }}Update(SQLModel):
# id is required to update the {{ model.lower() }}
class {{ modelname }}Update(SQLModel):
# id is required to update the {{ modelname }}
id: int
# all other fields, must match the model, but with Optional default None
name: Optional[str] = None
secret_name: Optional[str] = None
age: Optional[int] = None
shoe_size: Optional[int] = None
x: int
y: int
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
def update(self, session: Session = None) -> {{ model }}:
if session is not None:
db_{{ model.lower() }} = session.get({{ model }}, self.id)
if not db_{{ model.lower() }}:
raise HTTPException(status_code=404, detail="{{ model }} not found")
for key, value in self.dict(exclude_unset=True).items():
setattr(db_{{ model.lower() }}, key, value)
session.add(db_{{ model.lower() }})
session.commit()
session.refresh(db_{{ model.lower() }})
return db_{{ model.lower() }}
def update(self) -> {{ modelname }}:
r = httpx.patch(
f"{config.api_client.url}/{{ model.lower() }}/",
f"{config.api_client.url}/{{ modelname }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
class {{ model }}Delete(BaseModel):
class {{ modelname }}Delete(BaseModel):
id: int
def delete(self) -> {{ model }}:
@classmethod
def delete(self, id: int) -> Dict[str, bool]:
r = httpx.delete(
f"{config.api_client.url}/{{ model.lower() }}/{self.id}",
f"{config.api_client.url}/{{ modelname }}/{id}",
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")