diff --git a/.gitignore b/.gitignore index e1a3186..9d6f21d 100644 --- a/.gitignore +++ b/.gitignore @@ -962,3 +962,4 @@ FodyWeavers.xsd # Additional files built by Visual Studio # End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode +database.db diff --git a/pyproject.toml b/pyproject.toml index ebb2e6e..cecd388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,16 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ 'rich', 'sqlmodel', 'typer' ] +dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf'] [project.urls] Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" Issues = "https://github.com/waylonwalker/sqlmodel-base/issues" Source = "https://github.com/waylonwalker/sqlmodel-base" +[project.scripts] +sqlmodel-base = "sqlmodel_base.cli:app" + [tool.hatch.version] path = "sqlmodel_base/__about__.py" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index 054ee62..80f91fb 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -1,18 +1,19 @@ +import json from typing import Optional +import typer +from iterfzf import iterfzf from pydantic import BaseModel, validator +from pydantic_core._pydantic_core import PydanticUndefinedType from rich.console import Console from sqlalchemy import func from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel_base.database import get_engine + console = Console() -def get_session(): - with Session(engine) as session: - yield session - - class PagedResult(BaseModel): items: list total: int @@ -22,8 +23,14 @@ class PagedResult(BaseModel): class Base(SQLModel): + @classmethod + @property + def engine(self): + engine = get_engine() + return engine + def create(self): - with Session(engine) as session: + with Session(self.engine) as session: validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() @@ -31,42 +38,99 @@ class Base(SQLModel): return self @classmethod - def get(cls, id): - with Session(engine) as session: + def interactive_create(cls, id: Optional[int] = None): + data = {} + for name, field in cls.__fields__.items(): + default = field.default + if ( + default is None or isinstance(default, PydanticUndefinedType) + ) and not field.is_required(): + default = "None" + if (isinstance(default, PydanticUndefinedType)) and field.is_required(): + default = None + value = typer.prompt(f"{name}: ", default=default) + if value and value != "" and value != "None": + data[name] = value + item = cls(**data).create() + console.print(item) + + @classmethod + def pick(cls): + all = cls.all() + item = iterfzf([item.model_dump_json() for item in all]) + if not item: + console.print("No item selected") + return + return cls.get(cls.parse_raw(item).id) + + @classmethod + def get(cls, id: int): + with Session(cls.engine) as session: return session.get(cls, id) @classmethod - def get_all(cls): - with Session(engine) as session: + def get_or_pick(cls, id: Optional[int] = None): + if id is None: + return cls.pick() + return cls.get(id=id) + + @classmethod + def all(cls): + with Session(cls.engine) as session: return session.exec(select(cls)).all() @classmethod - def get_count(cls): - with Session(engine) as session: - return session.exec(func.count(Hero.id)).scalar() + def count(cls): + with Session(cls.engine) as session: + return session.exec(func.count(cls.id)).scalar() @classmethod - def get_first(cls): - with Session(engine) as session: + def first(cls): + with Session(cls.engine) as session: return session.exec(select(cls).limit(1)).first() @classmethod - def get_last(cls): - with Session(engine) as session: + def last(cls): + with Session(cls.engine) as session: return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() @classmethod - def get_random(cls): - with Session(engine) as session: + def random(cls): + with Session(cls.engine) as session: return session.exec(select(cls).order_by(cls.id).limit(1)).first() @classmethod - def get_page(cls, page: int = 1, page_size: int = 20): - with Session(engine) as session: - items = session.exec( - select(cls).offset((page - 1) * page_size).limit(page_size) - ).all() - total = cls.get_count() + def get_page( + cls, + page: int = 1, + page_size: int = 20, + all: bool = False, + reverse: bool = False, + ): + with Session(cls.engine) as session: + if all: + items = session.exec(select(cls)).all() + page_size = len(items) + else: + if reverse: + items = session.exec( + select(cls) + .offset((page - 1) * page_size) + .limit(page_size) + .order_by(cls.id.desc()) + ).all() + else: + items = session.exec( + select(cls) + .offset((page - 1) * page_size) + .limit(page_size) + .order_by(cls.id) + ).all() + # items = session.exec( + # select(cls).offset((page - 1) * page_size).limit(page_size) + # ).all() + + total = cls.count() # determine if there is a next page if page * page_size < total: next_page = page + 1 @@ -82,39 +146,71 @@ class Base(SQLModel): ) def delete(self): - with Session(engine) as session: + with Session(self.engine) as session: session.delete(self) session.commit() return self def update(self): - with Session(engine) as session: + with Session(self.engine) as session: validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self + @classmethod + def interactive_update(cls, id: Optional[int] = None): + item = cls.get_or_pick(id=id) + if not item: + console.print("No item selected") + return + for field in item.__fields__.keys(): + value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None") + if ( + value + and value != "" + and value != "None" + and value != getattr(item, field) + ): + setattr(item, field, value) + item.update() + console.print(item) -class Hero(Base, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str - secret_name: str - age: Optional[int] = None + @classmethod + @property + def cli(cls): + app = typer.Typer() - @validator("age") - def validate_age(cls, v): - if v is None: - return v - if v > 0: - return v - return abs(v) + @app.command() + def get(id: int = typer.Option(None, help="Hero ID")): + console.print(cls.get_or_pick(id=id)) + @app.command() + def create(): + console.print(cls.interactive_create()) -sqlite_file_name = "database.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" + @app.command() + def list( + page: int = typer.Option(1, help="Page number"), + page_size: int = typer.Option(20, help="Page size"), + all: bool = typer.Option(False, help="Show all heroes"), + reverse: bool = typer.Option(False, help="Reverse order"), + ): + console.print( + cls.get_page( + page=page, + page_size=page_size, + all=all, + reverse=reverse, + ) + ) -engine = create_engine(sqlite_url) # , echo=True) + @app.command() + def update(): + console.print(cls.interactive_update()) + + return app # replace with alembic commands diff --git a/sqlmodel_base/cli.py b/sqlmodel_base/cli.py new file mode 100644 index 0000000..13e682b --- /dev/null +++ b/sqlmodel_base/cli.py @@ -0,0 +1,11 @@ +import typer + +from sqlmodel_base.hero.cli import hero_app + +app = typer.Typer() + +app.add_typer(hero_app, name="hero") + + +if __name__ == "__main__": + app() diff --git a/sqlmodel_base/database.py b/sqlmodel_base/database.py new file mode 100644 index 0000000..90ceedb --- /dev/null +++ b/sqlmodel_base/database.py @@ -0,0 +1,21 @@ +from functools import lru_cache + +from sqlmodel import Field, Session, SQLModel, create_engine, select + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + + +@lru_cache +def get_engine(): + from sqlmodel_base.hero.models import Hero + from sqlmodel_base.team.models import Team + + engine = create_engine(sqlite_url) + SQLModel.metadata.create_all(engine) + return engine + + +def get_session(): + with Session(get_engine()) as session: + yield session diff --git a/sqlmodel_base/hero/cli.py b/sqlmodel_base/hero/cli.py new file mode 100644 index 0000000..458d319 --- /dev/null +++ b/sqlmodel_base/hero/cli.py @@ -0,0 +1,72 @@ +import json + +import typer +from iterfzf import iterfzf +from rich.console import Console + +from sqlmodel_base.database import get_engine +from sqlmodel_base.hero.models import Hero +from sqlmodel_base.team.models import Team + +engine = get_engine() + +hero_app = Hero.cli +console = Console() + + +# @hero_app.callback() +# def hero(): +# "model cli" + + +# @hero_app.command() +# def get(id: int = typer.Option(None, help="Hero ID")): +# console.print(Hero.get_or_pick(id=id)) + + +# @hero_app.command() +# def list( +# page: int = typer.Option(1, help="Page number"), +# page_size: int = typer.Option(20, help="Page size"), +# all: bool = typer.Option(False, help="Show all heroes"), +# reverse: bool = typer.Option(False, help="Reverse order"), +# ): +# console.print( +# Hero.get_page(page=page, page_size=page_size, all=all, reverse=reverse) +# ) + + +# @hero_app.command() +# def create( +# name: str = typer.Option(..., help="Hero name", prompt=True), +# secret_name: str = typer.Option(..., help="Hero secret name", prompt=True), +# age: int = typer.Option(None, help="Hero age", prompt=True), +# ): +# hero = Hero( +# name=name, +# secret_name=secret_name, +# age=age, +# ).create() +# console.print(hero) + + +# @hero_app.command() +# def update( +# id: int = typer.Option(None, help="Hero ID"), +# name: str = typer.Option(None, help="Hero name"), +# secret_name: str = typer.Option(None, help="Hero secret name"), +# age: int = typer.Option(None, help="Hero age"), +# ): +# hero = Hero.interactive_update(id=id) +# console.print(hero) + + +# @hero_app.command() +# def create_heroes(): +# team_1 = Team.get(id=1) +# if not team_1: +# team_1 = Team(name="Team 1", headquarters="Headquarters 1").create() +# for _ in range(50): +# Hero(name="Deadpond", secret_name="Dive Wilson", team_id=team_1.id).create() +# Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() +# Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() diff --git a/sqlmodel_base/hero/models.py b/sqlmodel_base/hero/models.py new file mode 100644 index 0000000..fa90c87 --- /dev/null +++ b/sqlmodel_base/hero/models.py @@ -0,0 +1,26 @@ +from typing import Optional + +from pydantic import BaseModel, validator +from rich.console import Console +from sqlalchemy import func +from sqlmodel import Field, Session, SQLModel, create_engine, select + +from sqlmodel_base.base import Base + +console = Console() + + +class Hero(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + + @validator("age") + def validate_age(cls, v): + if v is None: + return v + if v > 0: + return v + return abs(v) diff --git a/sqlmodel_base/team/models.py b/sqlmodel_base/team/models.py new file mode 100644 index 0000000..e57384e --- /dev/null +++ b/sqlmodel_base/team/models.py @@ -0,0 +1,12 @@ +from typing import Optional + +from rich.console import Console +from sqlmodel import Field + +from sqlmodel_base.base import Base + + +class Team(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str