diff --git a/.gitignore b/.gitignore index e1a3186..9d6f21d 100644 --- a/.gitignore +++ b/.gitignore @@ -962,3 +962,4 @@ FodyWeavers.xsd # Additional files built by Visual Studio # End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode +database.db diff --git a/database.db b/database.db deleted file mode 100644 index 293c794..0000000 Binary files a/database.db and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml index e567efa..9393705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,15 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ 'rich', 'sqlmodel', 'typer' ] +dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx'] [project.urls] -Documentation = "https://github.com/unknown/sqlmodel-base#readme" -Issues = "https://github.com/unknown/sqlmodel-base/issues" -Source = "https://github.com/unknown/sqlmodel-base" +Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" +Issues = "https://github.com/waylonwalker/sqlmodel-base/issues" +Source = "https://github.com/waylonwalker/sqlmodel-base" + +[project.scripts] +sqlmodel-base = "sqlmodel_base.cli:app" [tool.hatch.version] path = "sqlmodel_base/__about__.py" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index f12571a..c3d1ee6 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -1,18 +1,21 @@ -from typing import Optional +from typing import List, Optional -from pydantic import BaseModel, validator +import httpx +import typer +import uvicorn +from fastapi import APIRouter, Depends, FastAPI +from iterfzf import iterfzf +from pydantic import BaseModel +from pydantic_core._pydantic_core import PydanticUndefinedType from rich.console import Console from sqlalchemy import func -from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel import Session, SQLModel, select + +from sqlmodel_base.database import get_engine, get_session console = Console() -def get_session(): - with Session(engine) as session: - yield session - - class PagedResult(BaseModel): items: list total: int @@ -22,51 +25,134 @@ class PagedResult(BaseModel): class Base(SQLModel): - def create(self): - with Session(engine) as session: + @classmethod + @property + def engine(self): + engine = get_engine() + return engine + + def create(self, session: Optional[Session] = Depends(get_session)): + if isinstance(session, Session): validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self + else: + response = httpx.post( + "http://localhost:8000/create/", json=self.model_dump_json() + ) + breakpoint() + return response @classmethod - def get(cls, id): - with Session(engine) as session: - return session.get(cls, id) + def interactive_create(cls, id: Optional[int] = None): + data = {} + for name, field in cls.__fields__.items(): + default = field.default + if ( + default is None or isinstance(default, PydanticUndefinedType) + ) and not field.is_required(): + default = "None" + if (isinstance(default, PydanticUndefinedType)) and field.is_required(): + default = None + value = typer.prompt(f"{name}: ", default=default) + if value and value != "" and value != "None": + data[name] = value + item = cls(**data).create() + console.print(item) @classmethod - def get_all(cls): - with Session(engine) as session: - return session.exec(select(cls)).all() + def pick(cls): + all = cls.all() + item = iterfzf([item.model_dump_json() for item in all]) + if not item: + console.print("No item selected") + return + return cls.get(cls.parse_raw(item).id) @classmethod - def get_count(cls): - with Session(engine) as session: - return session.exec(func.count(Hero.id)).scalar() + def get(cls, id: int): + with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + return session.get(cls.__table_class__, id) + return cls.model_validate(session.get(cls, id)) @classmethod - def get_first(cls): - with Session(engine) as session: - return session.exec(select(cls).limit(1)).first() + def get_or_pick(cls, id: Optional[int] = None): + if id is None: + return cls.pick() + return cls.get(id=id) @classmethod - def get_last(cls): - with Session(engine) as session: - return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() + def all(cls) -> List: + with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + return session.exec(select(cls.__table_class__)).all() + return [cls.model_validate(i) for i in session.exec(select(cls)).all()] @classmethod - def get_random(cls): - with Session(engine) as session: - return session.exec(select(cls).order_by(cls.id).limit(1)).first() + def count(cls) -> int: + with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + return session.exec(func.count(cls.__table_class__.id)).scalar() + return session.exec(func.count(cls.id)).scalar() @classmethod - def get_page(cls, page: int = 1, page_size: int = 20): - with Session(engine) as session: - items = session.exec( - select(cls).offset((page - 1) * page_size).limit(page_size) - ).all() - total = cls.get_count() + def first(cls): + with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls + return cls.model_validate( + session.exec(select(table).order_by(table.id.asc()).limit(1)).first() + ) + + @classmethod + def last(cls): + with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls + return session.exec( + select(table).order_by(table.id.desc()).limit(1) + ).first() + + @classmethod + def get_page( + cls, + page: int = 1, + page_size: int = 20, + all: bool = False, + reverse: bool = False, + ): + with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls + if all: + items = session.exec(select(table)).all() + page_size = len(items) + else: + if reverse: + items = session.exec( + select(table) + .offset((page - 1) * page_size) + .limit(page_size) + .order_by(table.id.desc()) + ).all() + else: + items = session.exec( + select(table) + .offset((page - 1) * page_size) + .limit(page_size) + .order_by(table.id) + ).all() + + total = table.count() # determine if there is a next page if page * page_size < total: next_page = page + 1 @@ -82,72 +168,126 @@ class Base(SQLModel): ) def delete(self): - with Session(engine) as session: + with Session(self.engine) as session: session.delete(self) session.commit() return self def update(self): - with Session(engine) as session: + with Session(self.engine) as session: validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self + @classmethod + def interactive_update(cls, id: Optional[int] = None): + item = cls.get_or_pick(id=id) + if not item: + console.print("No item selected") + return + for field in item.__fields__.keys(): + if field == "id": + continue + value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None") + if ( + value + and value != "" + and value != "None" + and value != getattr(item, field) + ): + setattr(item, field, value) + item.update() + console.print(item) -class Hero(Base, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str - secret_name: str - age: Optional[int] = None + @classmethod + def api(cls): + api = FastAPI( + title="FastAPI", + version="0.1.0", + # docs_url=None, + # redoc_url=None, + # openapi_url=None, + # openapi_tags=tags_metadata, + # dependencies=[Depends(set_user), Depends(set_prefers)], + ) - @validator("age") - def validate_age(cls, v): - if v is None: - return v - if v > 0: - return v - return abs(v) + api.include_router(cls.router()) + return api -sqlite_file_name = "database.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" + @classmethod + def router(cls): + router = APIRouter() + # router.add_api_route("/get/", cls.get, methods=["GET"]) + # router.add_api_route("/list/", cls.all, methods=["GET"]) + # router.add_api_route("/create/", cls.create, methods=["POST"]) + # router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) -engine = create_engine(sqlite_url) # , echo=True) + @router.get("/") + def get(id: int) -> cls: + return cls.get(id=id) + @router.get("/list", include_in_schema=False) + @router.get("/list/") + def get_page( + page: int = 1, + page_size: int = 20, + all: bool = False, + reverse: bool = False, + ) -> PagedResult: + return cls.get_page() -# replace with alembic commands -def create_db_and_tables(): - SQLModel.metadata.create_all(engine) + @router.post("/create") + def create(cls: cls) -> cls: + return cls.create() + @router.put("/update") + def update() -> cls: + return cls.update() -def create_heroes(): - hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson").create() - hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() - hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() + return router - # with Session(engine) as session: - # session.add(hero_1) - # session.add(hero_2) - # session.add(hero_3) - # - # session.commit() + @classmethod + @property + def cli(cls): + app = typer.Typer() + @app.command() + def get(id: int = typer.Option(None, help="Hero ID")): + console.print(cls.get_or_pick(id=id)) -def page_heroes(): - next_page = 1 - while next_page: - page = Hero.get_page(page=next_page, page_size=2) - console.print(page) - next_page = page.next_page + @app.command() + def create(): + console.print(cls.interactive_create()) + @app.command() + def list( + page: int = typer.Option(1, help="Page number"), + page_size: int = typer.Option(20, help="Page size"), + all: bool = typer.Option(False, help="Show all heroes"), + reverse: bool = typer.Option(False, help="Reverse order"), + ): + console.print( + cls.get_page( + page=page, + page_size=page_size, + all=all, + reverse=reverse, + ) + ) -def main(): - create_db_and_tables() - create_heroes() - page_heroes() + @app.command() + def api(): + cls.run_api() + @app.command() + def update(): + console.print(cls.interactive_update()) -if __name__ == "__main__": - main() + return app + + @classmethod + def run_api(cls): + uvicorn.run(cls.api(), host="127.0.0.1", port=8000) diff --git a/sqlmodel_base/cli.py b/sqlmodel_base/cli.py new file mode 100644 index 0000000..13e682b --- /dev/null +++ b/sqlmodel_base/cli.py @@ -0,0 +1,11 @@ +import typer + +from sqlmodel_base.hero.cli import hero_app + +app = typer.Typer() + +app.add_typer(hero_app, name="hero") + + +if __name__ == "__main__": + app() diff --git a/sqlmodel_base/database.py b/sqlmodel_base/database.py new file mode 100644 index 0000000..13f8dbc --- /dev/null +++ b/sqlmodel_base/database.py @@ -0,0 +1,18 @@ +from functools import lru_cache + +from sqlmodel import Session, SQLModel, create_engine + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + + +@lru_cache +def get_engine(): + engine = create_engine(sqlite_url) + SQLModel.metadata.create_all(engine) + return engine + + +def get_session(): + with Session(get_engine()) as session: + yield session diff --git a/sqlmodel_base/hero/cli.py b/sqlmodel_base/hero/cli.py new file mode 100644 index 0000000..784edef --- /dev/null +++ b/sqlmodel_base/hero/cli.py @@ -0,0 +1,65 @@ +from rich.console import Console + +from sqlmodel_base.database import get_engine +from sqlmodel_base.hero.models import Hero + +engine = get_engine() + +hero_app = Hero.cli +console = Console() +# @hero_app.callback() +# def hero(): +# "model cli" + + +# @hero_app.command() +# def get(id: int = typer.Option(None, help="Hero ID")): +# console.print(Hero.get_or_pick(id=id)) + + +# @hero_app.command() +# def list( +# page: int = typer.Option(1, help="Page number"), +# page_size: int = typer.Option(20, help="Page size"), +# all: bool = typer.Option(False, help="Show all heroes"), +# reverse: bool = typer.Option(False, help="Reverse order"), +# ): +# console.print( +# Hero.get_page(page=page, page_size=page_size, all=all, reverse=reverse) +# ) + + +# @hero_app.command() +# def create( +# name: str = typer.Option(..., help="Hero name", prompt=True), +# secret_name: str = typer.Option(..., help="Hero secret name", prompt=True), +# age: int = typer.Option(None, help="Hero age", prompt=True), +# ): +# hero = Hero( +# name=name, +# secret_name=secret_name, +# age=age, +# ).create() +# console.print(hero) + + +# @hero_app.command() +# def update( +# id: int = typer.Option(None, help="Hero ID"), +# name: str = typer.Option(None, help="Hero name"), +# secret_name: str = typer.Option(None, help="Hero secret name"), +# age: int = typer.Option(None, help="Hero age"), +# ): +# hero = Hero.interactive_update(id=id) +# console.print(hero) + + +# @hero_app.command() +# def create_heroes(): +# team_1 = Team.get(id=1) +# if not team_1: +# team_1 = Team(name="Team 1", headquarters="Headquarters 1").create() +# for _ in range(50): +# Hero(name="Deadpond", secret_name="Dive Wilson", team_id=team_1.id).create() +# Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() +# Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() diff --git a/sqlmodel_base/hero/models.py b/sqlmodel_base/hero/models.py new file mode 100644 index 0000000..1a34f17 --- /dev/null +++ b/sqlmodel_base/hero/models.py @@ -0,0 +1,50 @@ +from typing import Optional + +from pydantic import validator +from rich.console import Console +from sqlmodel import Field + +from sqlmodel_base.base import Base + +console = Console() + + +class HeroBase(Base): + name: str + secret_name: str + age: Optional[int] = None + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + + @validator("age") + def validate_age(cls, v): + if v is None: + return v + if v > 0: + return v + return abs(v) + + +class Hero(HeroBase, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class HeroCreate(HeroBase): + __table_class__ = Hero + pass + + +class HeroRead(HeroBase): + __table_class__ = Hero + id: int + + +class HeroUpdate(Base, table=False): + __table_class__ = Hero + name: Optional[str] + secret_name: Optional[str] + age: Optional[int] + team_id: Optional[int] + + +if __name__ == "__main__": + Hero.cli() diff --git a/sqlmodel_base/team/models.py b/sqlmodel_base/team/models.py new file mode 100644 index 0000000..791abe3 --- /dev/null +++ b/sqlmodel_base/team/models.py @@ -0,0 +1,11 @@ +from typing import Optional + +from sqlmodel import Field + +from sqlmodel_base.base import Base + + +class Team(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str