import json
from typing import List, Optional

import httpx
import typer
import uvicorn
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from iterfzf import iterfzf
from pydantic import BaseModel
from pydantic_core._pydantic_core import PydanticUndefinedType
from rich.console import Console
from sqlalchemy import func
from sqlmodel import Session, SQLModel, select

from sqlmodel_base.database import get_engine, get_session

console = Console()


def _table_for(model):
    """Return ``model.__table_class__`` when it is defined, else ``model``."""
    return getattr(model, "__table_class__", model)


class PagedResult(BaseModel):
    """One page of query results plus the bookkeeping needed to page on."""

    items: list
    total: int
    page: int
    next_page: Optional[int]
    page_size: int


class Base(SQLModel):
    """SQLModel base class bundling CRUD helpers, a FastAPI router and a Typer CLI.

    A subclass may define ``__table_class__``; when present, queries are run
    against that class instead of the subclass itself.
    """

    # NOTE(review): stacking @classmethod on @property is deprecated since
    # Python 3.11 and no longer works on 3.13. Kept as-is so that
    # ``cls.engine`` / ``cls.cli`` remain attribute-style accesses; replace
    # with a dedicated class-property descriptor before moving to 3.13.
    @classmethod
    @property
    def engine(self):
        """Engine returned by ``get_engine()``."""
        return get_engine()

    def create(self, session: Optional[Session] = Depends(get_session)):
        """Persist this instance.

        With a real ``Session`` the instance is validated, added, committed
        and refreshed, then returned. Otherwise (for example a direct call,
        where ``session`` is still the unresolved ``Depends`` marker) the
        instance is posted to the local API and the ``httpx.Response`` is
        returned.
        """
        if isinstance(session, Session):
            validated = self.model_validate(self)
            session.add(self.sqlmodel_update(validated))
            session.commit()
            session.refresh(self)
            return self

        # ``json=`` expects a JSON-serialisable object; handing it the string
        # from ``model_dump_json()`` would encode the payload twice.
        return httpx.post(
            "http://localhost:8000/create",
            json=self.model_dump(mode="json"),
        )

    @classmethod
    def interactive_create(cls, id: Optional[int] = None):
        """Prompt for every model field, create the item, print and return it.

        ``id`` is accepted for interface compatibility and is not used.
        """
        data = {}
        for name, field in cls.model_fields.items():
            default = field.default
            undefined = isinstance(default, PydanticUndefinedType)
            if field.is_required():
                if undefined:
                    # A ``None`` default makes typer insist on an answer.
                    default = None
            elif default is None or undefined:
                # Optional field: show "None" as a skippable placeholder.
                default = "None"
            value = typer.prompt(f"{name}: ", default=default)
            if value and value != "None":
                data[name] = value
        item = cls(**data).create()
        console.print(item)
        return item

    @classmethod
    def pick(cls):
        """Let the user choose one stored item through ``iterfzf``.

        Returns the item re-read via ``get``, or ``None`` when nothing was
        selected.
        """
        choices = [item.model_dump_json() for item in cls.all()]
        selected = iterfzf(choices)
        if not selected:
            console.print("No item selected")
            return None
        return cls.get(json.loads(selected)["id"])

    @classmethod
    def get(cls, id: int):
        """Return the item with primary key ``id``, or ``None`` if it is absent."""
        with Session(cls.engine) as session:
            if hasattr(cls, "__table_class__"):
                return session.get(cls.__table_class__, id)
            row = session.get(cls, id)
            if row is None:
                return None
            return cls.model_validate(row)

    @classmethod
    def get_or_pick(cls, id: Optional[int] = None):
        """Return ``get(id)`` when ``id`` is given, otherwise ``pick()``."""
        if id is None:
            return cls.pick()
        return cls.get(id=id)

    @classmethod
    def all(cls) -> List:
        """Return every stored item."""
        with Session(cls.engine) as session:
            if hasattr(cls, "__table_class__"):
                return session.exec(select(cls.__table_class__)).all()
            return [cls.model_validate(row) for row in session.exec(select(cls)).all()]

    @classmethod
    def count(cls) -> int:
        """Return the result of ``count(id)`` over the table."""
        table = _table_for(cls)
        with Session(cls.engine) as session:
            return session.exec(func.count(table.id)).scalar()

    @classmethod
    def first(cls):
        """Return the item with the lowest ``id`` validated as ``cls``, or ``None``."""
        table = _table_for(cls)
        with Session(cls.engine) as session:
            row = session.exec(
                select(table).order_by(table.id.asc()).limit(1)
            ).first()
            if row is None:
                return None
            return cls.model_validate(row)

    @classmethod
    def last(cls):
        """Return the row with the highest ``id`` as loaded, or ``None``."""
        table = _table_for(cls)
        with Session(cls.engine) as session:
            return session.exec(
                select(table).order_by(table.id.desc()).limit(1)
            ).first()

    @classmethod
    def get_page(
        cls,
        page: int = 1,
        page_size: int = 20,
        all: bool = False,
        reverse: bool = False,
    ):
        """Return a ``PagedResult`` of items ordered by ``id``.

        Args:
            page: 1-based page number.
            page_size: Maximum number of items per page.
            all: Return every row; ``page_size`` is then reported as the
                number of rows returned.
            reverse: Order by descending ``id`` instead of ascending.
        """
        table = _table_for(cls)
        ordering = table.id.desc() if reverse else table.id
        statement = select(table).order_by(ordering)
        if not all:
            statement = statement.offset((page - 1) * page_size).limit(page_size)

        with Session(cls.engine) as session:
            items = session.exec(statement).all()
            if all:
                page_size = len(items)
            total = cls.count()
            # There is a next page only while rows remain beyond this one.
            next_page = page + 1 if page * page_size < total else None
            return PagedResult(
                items=items,
                total=total,
                page=page,
                page_size=page_size,
                next_page=next_page,
            )

    def delete(self):
        """Delete this instance in a fresh session, commit, and return it."""
        # NOTE(review): ``self`` may not be attached to any session here;
        # confirm ``session.delete`` accepts it for the models in use.
        with Session(self.engine) as session:
            session.delete(self)
            session.commit()
            return self

    def update(self):
        """Validate this instance, add it to a fresh session, commit, refresh."""
        # NOTE(review): ``session.add`` on an instance that carries an
        # existing primary key but was never attached to a session may be
        # treated as an INSERT; confirm against the models in use.
        with Session(self.engine) as session:
            validated = self.model_validate(self)
            session.add(self.sqlmodel_update(validated))
            session.commit()
            session.refresh(self)
            return self

    @classmethod
    def interactive_update(cls, id: Optional[int] = None):
        """Prompt for new values for each non-``id`` field, save, print, return.

        Returns ``None`` when no item could be found or selected.
        """
        item = cls.get_or_pick(id=id)
        if not item:
            console.print("No item selected")
            return None
        for field in type(item).model_fields:
            if field == "id":
                continue
            current = getattr(item, field)
            value = typer.prompt(f"{field}: ", default=current or "None")
            if value and value != "None" and value != current:
                setattr(item, field, value)
        item.update()
        console.print(item)
        return item

    @classmethod
    def api(cls):
        """Build a FastAPI application that serves ``cls.router()``."""
        api = FastAPI(
            title="FastAPI",
            version="0.1.0",
        )
        api.include_router(cls.router())
        return api

    @classmethod
    def router(cls):
        """Build an ``APIRouter`` with get, list, create and update routes."""
        router = APIRouter()

        @router.get("/")
        def get(id: int) -> cls:
            item = cls.get(id=id)
            if item is None:
                raise HTTPException(
                    status_code=404,
                    detail=f"{cls.__name__} {id} not found",
                )
            return item

        @router.get("/list", include_in_schema=False)
        @router.get("/list/")
        def get_page(
            page: int = 1,
            page_size: int = 20,
            all: bool = False,
            reverse: bool = False,
        ) -> PagedResult:
            return cls.get_page(
                page=page,
                page_size=page_size,
                all=all,
                reverse=reverse,
            )

        @router.post("/create")
        def create(item: cls, session: Session = Depends(get_session)) -> cls:
            # Pass the injected session explicitly; without it ``create``
            # takes its HTTP fallback and would post back to this endpoint.
            return item.create(session=session)

        @router.put("/update")
        def update(item: cls) -> cls:
            return item.update()

        return router

    # NOTE(review): same @classmethod/@property stacking as ``engine``.
    @classmethod
    @property
    def cli(cls):
        """Typer application exposing get, create, list, api and update commands."""
        app = typer.Typer()

        @app.command()
        def get(id: Optional[int] = typer.Option(None, help="Hero ID")):
            console.print(cls.get_or_pick(id=id))

        @app.command()
        def create():
            # interactive_create prints the created item itself.
            cls.interactive_create()

        @app.command("list")
        def list_items(
            page: int = typer.Option(1, help="Page number"),
            page_size: int = typer.Option(20, help="Page size"),
            all: bool = typer.Option(False, help="Show all heroes"),
            reverse: bool = typer.Option(False, help="Reverse order"),
        ):
            console.print(
                cls.get_page(
                    page=page,
                    page_size=page_size,
                    all=all,
                    reverse=reverse,
                )
            )

        @app.command()
        def api():
            cls.run_api()

        @app.command()
        def update():
            # interactive_update prints the updated item itself.
            cls.interactive_update()

        return app

    @classmethod
    def run_api(cls, host: str = "127.0.0.1", port: int = 8000):
        """Serve ``cls.api()`` with uvicorn on ``host``:``port``."""
        uvicorn.run(cls.api(), host=host, port=port)