293 lines
8.6 KiB
Python
293 lines
8.6 KiB
Python
from typing import List, Optional
|
|
|
|
import httpx
|
|
import typer
|
|
import uvicorn
|
|
from fastapi import APIRouter, Depends, FastAPI
|
|
from iterfzf import iterfzf
|
|
from pydantic import BaseModel
|
|
from pydantic_core._pydantic_core import PydanticUndefinedType
|
|
from rich.console import Console
|
|
from sqlalchemy import func
|
|
from sqlmodel import Session, SQLModel, select
|
|
|
|
from sqlmodel_base.database import get_engine, get_session
|
|
|
|
# Module-wide rich console used by the interactive helpers and CLI commands to print output.
console = Console()
|
|
|
|
|
|
class PagedResult(BaseModel):
    """One page of query results, as built by ``Base.get_page``."""

    # Rows on this page.
    items: list
    # Total number of rows in the queried table (not just this page).
    total: int
    # Requested page number; page 1 starts at offset 0.
    page: int
    # page + 1 when more rows remain after this page, otherwise None.
    next_page: Optional[int]
    # Rows per page; equals len(items) when all rows were requested.
    page_size: int
|
|
|
|
|
|
class Base(SQLModel):
    """SQLModel base class adding CRUD helpers, a FastAPI router and a Typer CLI.

    Subclasses are expected to define an ``id`` primary-key column; it is used
    for lookups, ordering and paging below.  If a subclass defines
    ``__table_class__``, database queries are issued against that class instead
    of the subclass itself.
    """

    # NOTE(review): stacking @classmethod on @property was deprecated in Python
    # 3.11 and no longer works in 3.13.  It is kept so existing ``Model.engine``
    # callers keep working; the methods of this class use ``_resolve_engine()``.
    @classmethod
    @property
    def engine(self):
        """Engine returned by ``sqlmodel_base.database.get_engine``."""
        engine = get_engine()
        return engine

    @classmethod
    def _resolve_engine(cls):
        """Return the database engine via a plain classmethod (works on every Python)."""
        return get_engine()

    @classmethod
    def _table(cls):
        """Return the class queries run against: ``__table_class__`` if set, else ``cls``."""
        return getattr(cls, "__table_class__", cls)

    @classmethod
    def _count(cls, session: Session) -> int:
        """Count the rows of ``cls._table()`` using an already-open ``session``."""
        return session.exec(select(func.count()).select_from(cls._table())).one()

    def create(self, session: Optional[Session] = None):
        """Persist this instance.

        Args:
            session: When a ``Session`` is given, the instance is validated,
                added, committed and refreshed in it, and ``self`` is returned.
                Otherwise the instance is POSTed to the locally running API
                (see ``run_api``) and the ``httpx.Response`` is returned.
        """
        if isinstance(session, Session):
            validated = self.model_validate(self)
            session.add(self.sqlmodel_update(validated))
            session.commit()
            session.refresh(self)
            return self

        # Send a JSON object (not a JSON-encoded string) to the exact route path
        # registered by ``router``; a trailing-slash variant would only get a
        # redirect, which httpx does not follow by default.
        return httpx.post(
            "http://localhost:8000/create",
            json=self.model_dump(mode="json"),
        )

    @classmethod
    def interactive_create(cls, id: Optional[int] = None):
        """Prompt on the terminal for every model field, create the item and print it.

        ``id`` is accepted for signature compatibility and is not used.  Answers
        that are empty or the literal "None" are left out, so model defaults apply.
        """
        data = {}
        for name, field in cls.model_fields.items():
            default = field.default
            is_undefined = isinstance(default, PydanticUndefinedType)
            if not field.is_required() and (default is None or is_undefined):
                # Optional field without a usable default: offer a "None" placeholder.
                default = "None"
            elif is_undefined:
                # Required field without a default: the prompt insists on a value.
                default = None
            value = typer.prompt(f"{name}: ", default=default)
            if value and value != "None":
                data[name] = value
        item = cls(**data).create()
        console.print(item)

    @classmethod
    def pick(cls):
        """Let the user choose a row with fzf and return it re-fetched by id.

        Returns None (after printing a notice) when nothing is selected.
        """
        # Map each row's JSON rendering to its id, so the choice needs no re-parsing.
        ids_by_label = {row.model_dump_json(): row.id for row in cls.all()}
        choice = iterfzf(list(ids_by_label))
        if not choice:
            console.print("No item selected")
            return None
        return cls.get(ids_by_label[choice])

    @classmethod
    def get(cls, id: int):
        """Fetch the row whose primary key is ``id``; return None if it does not exist.

        With ``__table_class__`` the raw table row is returned; otherwise the
        row is passed through ``cls.model_validate``.
        """
        with Session(cls._resolve_engine()) as session:
            if hasattr(cls, "__table_class__"):
                return session.get(cls.__table_class__, id)
            row = session.get(cls, id)
            return None if row is None else cls.model_validate(row)

    @classmethod
    def get_or_pick(cls, id: Optional[int] = None):
        """Return the row with ``id``, or let the user pick one when ``id`` is None."""
        if id is None:
            return cls.pick()
        return cls.get(id=id)

    @classmethod
    def all(cls) -> List:
        """Return every row (raw table rows with ``__table_class__``, else validated ``cls``)."""
        with Session(cls._resolve_engine()) as session:
            if hasattr(cls, "__table_class__"):
                return session.exec(select(cls.__table_class__)).all()
            return [cls.model_validate(row) for row in session.exec(select(cls)).all()]

    @classmethod
    def count(cls) -> int:
        """Return the number of rows in the backing table."""
        with Session(cls._resolve_engine()) as session:
            return cls._count(session)

    @classmethod
    def first(cls):
        """Return the row with the smallest id, validated as ``cls``; None if the table is empty."""
        table = cls._table()
        with Session(cls._resolve_engine()) as session:
            row = session.exec(select(table).order_by(table.id.asc()).limit(1)).first()
            return None if row is None else cls.model_validate(row)

    @classmethod
    def last(cls):
        """Return the row with the largest id as loaded by the session; None if empty."""
        table = cls._table()
        with Session(cls._resolve_engine()) as session:
            return session.exec(select(table).order_by(table.id.desc()).limit(1)).first()

    @classmethod
    def get_page(
        cls,
        page: int = 1,
        page_size: int = 20,
        all: bool = False,
        reverse: bool = False,
    ):
        """Return one page of rows ordered by id, wrapped in a ``PagedResult``.

        Args:
            page: Page number; page 1 starts at offset 0.
            page_size: Rows per page (ignored, and reported as ``len(items)``,
                when ``all`` is true).
            all: Return every row instead of a single page.
            reverse: Order by descending id instead of ascending.
        """
        table = cls._table()
        order = table.id.desc() if reverse else table.id.asc()
        statement = select(table).order_by(order)
        if not all:
            statement = statement.offset((page - 1) * page_size).limit(page_size)

        with Session(cls._resolve_engine()) as session:
            items = session.exec(statement).all()
            # Count with the same session and table, so it works even when
            # ``__table_class__`` does not inherit from Base.
            total = cls._count(session)

        if all:
            page_size = len(items)
        next_page = page + 1 if page * page_size < total else None

        return PagedResult(
            items=items,
            total=total,
            page=page,
            page_size=page_size,
            next_page=next_page,
        )

    def delete(self):
        """Delete this instance's row, commit, and return the instance."""
        with Session(self._resolve_engine()) as session:
            session.delete(self)
            session.commit()
            return self

    def update(self):
        """Validate this instance, write it back, commit, refresh and return it."""
        with Session(self._resolve_engine()) as session:
            validated = self.model_validate(self)
            session.add(self.sqlmodel_update(validated))
            session.commit()
            session.refresh(self)
            return self

    @classmethod
    def interactive_update(cls, id: Optional[int] = None):
        """Prompt for new values of every non-id field of an item, save and print it.

        The item is looked up by ``id`` or picked interactively.  Empty answers,
        "None", and unchanged values leave the field as it is.
        """
        item = cls.get_or_pick(id=id)
        if not item:
            console.print("No item selected")
            return
        for field in type(item).model_fields:
            if field == "id":
                continue
            current = getattr(item, field)
            value = typer.prompt(f"{field}: ", default=current or "None")
            if value and value != "None" and value != current:
                setattr(item, field, value)
        item.update()
        console.print(item)

    @classmethod
    def api(cls):
        """Return a FastAPI application serving this model's ``router``."""
        api = FastAPI(
            title="FastAPI",
            version="0.1.0",
        )
        api.include_router(cls.router())
        return api

    @classmethod
    def router(cls):
        """Build an APIRouter with get, list, create and update endpoints for ``cls``."""
        from fastapi import HTTPException

        router = APIRouter()

        @router.get("/")
        def get(id: int) -> cls:
            item = cls.get(id=id)
            if item is None:
                raise HTTPException(status_code=404, detail="Item not found")
            return item

        @router.get("/list", include_in_schema=False)
        @router.get("/list/")
        def get_page(
            page: int = 1,
            page_size: int = 20,
            all: bool = False,
            reverse: bool = False,
        ) -> PagedResult:
            return cls.get_page(
                page=page, page_size=page_size, all=all, reverse=reverse
            )

        @router.post("/create")
        def create(item: cls) -> cls:
            # Pass an explicit session; without one, ``create`` would POST back to
            # this very endpoint.
            with Session(cls._resolve_engine()) as session:
                return item.create(session=session)

        @router.put("/update")
        def update(item: cls) -> cls:
            # Load the stored row and copy only the fields the client sent; adding
            # the request object directly would attempt an INSERT.
            with Session(cls._resolve_engine()) as session:
                db_item = session.get(cls._table(), item.id)
                if db_item is None:
                    raise HTTPException(status_code=404, detail="Item not found")
                db_item.sqlmodel_update(item.model_dump(exclude_unset=True))
                session.add(db_item)
                session.commit()
                session.refresh(db_item)
                return db_item

        return router

    # NOTE(review): same deprecated @classmethod/@property chain as ``engine``
    # (broken on Python 3.13); kept so ``Model.cli`` still returns the Typer app.
    @classmethod
    @property
    def cli(cls):
        """Build a Typer app with ``get``, ``create``, ``list``, ``api`` and ``update`` commands."""
        app = typer.Typer()

        @app.command()
        def get(id: int = typer.Option(None, help="Hero ID")):
            console.print(cls.get_or_pick(id=id))

        @app.command()
        def create():
            # interactive_create prints the result itself and returns None.
            cls.interactive_create()

        @app.command("list")
        def list_items(
            page: int = typer.Option(1, help="Page number"),
            page_size: int = typer.Option(20, help="Page size"),
            all: bool = typer.Option(False, help="Show all heroes"),
            reverse: bool = typer.Option(False, help="Reverse order"),
        ):
            console.print(
                cls.get_page(
                    page=page,
                    page_size=page_size,
                    all=all,
                    reverse=reverse,
                )
            )

        @app.command()
        def api():
            cls.run_api()

        @app.command()
        def update():
            # interactive_update prints the result itself and returns None.
            cls.interactive_update()

        return app

    @classmethod
    def run_api(cls):
        """Serve ``cls.api()`` with uvicorn on 127.0.0.1:8000 (blocks until stopped)."""
        uvicorn.run(cls.api(), host="127.0.0.1", port=8000)