diff --git a/pyproject.toml b/pyproject.toml index 5a88aa7..9393705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn'] +dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx'] [project.urls] Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index 75d9ecb..c3d1ee6 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import List, Optional +import httpx import typer import uvicorn -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, Depends, FastAPI from iterfzf import iterfzf from pydantic import BaseModel from pydantic_core._pydantic_core import PydanticUndefinedType @@ -10,7 +11,7 @@ from rich.console import Console from sqlalchemy import func from sqlmodel import Session, SQLModel, select -from sqlmodel_base.database import get_engine +from sqlmodel_base.database import get_engine, get_session console = Console() @@ -30,13 +31,19 @@ class Base(SQLModel): engine = get_engine() return engine - def create(self): - with Session(self.engine) as session: + def create(self, session: Optional[Session] = Depends(get_session)): + if isinstance(session, Session): validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self + else: + response = httpx.post( + "http://localhost:8000/create/", json=self.model_dump_json() + ) + breakpoint() + return response @classmethod def interactive_create(cls, id: Optional[int] = None): @@ -67,7 +74,9 @@ class Base(SQLModel): @classmethod def get(cls, id: int): with Session(cls.engine) as session: - return session.get(cls, id) + if hasattr(cls, "__table_class__"): + return session.get(cls.__table_class__, id) + return cls.model_validate(session.get(cls, id)) @classmethod def get_or_pick(cls, id: Optional[int] = None): @@ -76,29 +85,40 @@ class Base(SQLModel): return cls.get(id=id) @classmethod - def all(cls): + def all(cls) -> List: with Session(cls.engine) as session: - return session.exec(select(cls)).all() + if hasattr(cls, "__table_class__"): + return session.exec(select(cls.__table_class__)).all() + return [cls.model_validate(i) for i in session.exec(select(cls)).all()] @classmethod - def count(cls): + def count(cls) -> int: with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + return session.exec(func.count(cls.__table_class__.id)).scalar() return session.exec(func.count(cls.id)).scalar() @classmethod def first(cls): with Session(cls.engine) as session: - return session.exec(select(cls).limit(1)).first() + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls + return cls.model_validate( + session.exec(select(table).order_by(table.id.asc()).limit(1)).first() + ) @classmethod def last(cls): with Session(cls.engine) as session: - return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() - - @classmethod - def random(cls): - with Session(cls.engine) as session: - return session.exec(select(cls).order_by(cls.id).limit(1)).first() + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls + return session.exec( + select(table).order_by(table.id.desc()).limit(1) + ).first() @classmethod def get_page( @@ -109,29 +129,30 @@ class Base(SQLModel): reverse: bool = False, ): with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls if all: - items = session.exec(select(cls)).all() + items = session.exec(select(table)).all() page_size = len(items) else: if reverse: items = session.exec( - select(cls) + select(table) .offset((page - 1) * page_size) .limit(page_size) - .order_by(cls.id.desc()) + .order_by(table.id.desc()) ).all() else: items = session.exec( - select(cls) + select(table) .offset((page - 1) * page_size) .limit(page_size) - .order_by(cls.id) + .order_by(table.id) ).all() - # items = session.exec( - # select(cls).offset((page - 1) * page_size).limit(page_size) - # ).all() - total = cls.count() + total = table.count() # determine if there is a next page if page * page_size < total: next_page = page + 1 @@ -167,6 +188,8 @@ class Base(SQLModel): console.print("No item selected") return for field in item.__fields__.keys(): + if field == "id": + continue value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None") if ( value @@ -183,9 +206,9 @@ class Base(SQLModel): api = FastAPI( title="FastAPI", version="0.1.0", - docs_url=None, - redoc_url=None, - openapi_url=None, + # docs_url=None, + # redoc_url=None, + # openapi_url=None, # openapi_tags=tags_metadata, # dependencies=[Depends(set_user), Depends(set_prefers)], ) @@ -197,10 +220,33 @@ class Base(SQLModel): @classmethod def router(cls): router = APIRouter() - router.add_api_route("/get/", cls.get, methods=["GET"]) - router.add_api_route("/list/", cls.all, methods=["GET"]) - router.add_api_route("/create/", cls.interactive_create, methods=["POST"]) - router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) + # router.add_api_route("/get/", cls.get, methods=["GET"]) + # router.add_api_route("/list/", cls.all, methods=["GET"]) + # router.add_api_route("/create/", cls.create, methods=["POST"]) + # router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) + + @router.get("/") + def get(id: int) -> cls: + return cls.get(id=id) + + @router.get("/list", include_in_schema=False) + @router.get("/list/") + def get_page( + page: int = 1, + page_size: int = 20, + all: bool = False, + reverse: bool = False, + ) -> PagedResult: + return cls.get_page() + + @router.post("/create") + def create(cls: cls) -> cls: + return cls.create() + + @router.put("/update") + def update() -> cls: + return cls.update() + return router @classmethod diff --git a/sqlmodel_base/hero/models.py b/sqlmodel_base/hero/models.py index 2f399eb..1a34f17 100644 --- a/sqlmodel_base/hero/models.py +++ b/sqlmodel_base/hero/models.py @@ -9,8 +9,7 @@ from sqlmodel_base.base import Base console = Console() -class Hero(Base, table=True): - id: Optional[int] = Field(default=None, primary_key=True) +class HeroBase(Base): name: str secret_name: str age: Optional[int] = None @@ -23,3 +22,29 @@ class Hero(Base, table=True): if v > 0: return v return abs(v) + + +class Hero(HeroBase, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class HeroCreate(HeroBase): + __table_class__ = Hero + pass + + +class HeroRead(HeroBase): + __table_class__ = Hero + id: int + + +class HeroUpdate(Base, table=False): + __table_class__ = Hero + name: Optional[str] + secret_name: Optional[str] + age: Optional[int] + team_id: Optional[int] + + +if __name__ == "__main__": + Hero.cli()