diff --git a/.gitignore b/.gitignore index 9d6f21d..e1a3186 100644 --- a/.gitignore +++ b/.gitignore @@ -962,4 +962,3 @@ FodyWeavers.xsd # Additional files built by Visual Studio # End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode -database.db diff --git a/database.db b/database.db new file mode 100644 index 0000000..293c794 Binary files /dev/null and b/database.db differ diff --git a/pyproject.toml b/pyproject.toml index 9393705..e567efa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,15 +24,12 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx'] +dependencies = [ 'rich', 'sqlmodel', 'typer' ] [project.urls] -Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" -Issues = "https://github.com/waylonwalker/sqlmodel-base/issues" -Source = "https://github.com/waylonwalker/sqlmodel-base" - -[project.scripts] -sqlmodel-base = "sqlmodel_base.cli:app" +Documentation = "https://github.com/unknown/sqlmodel-base#readme" +Issues = "https://github.com/unknown/sqlmodel-base/issues" +Source = "https://github.com/unknown/sqlmodel-base" [tool.hatch.version] path = "sqlmodel_base/__about__.py" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index c3d1ee6..f12571a 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -1,21 +1,18 @@ -from typing import List, Optional +from typing import Optional -import httpx -import typer -import uvicorn -from fastapi import APIRouter, Depends, FastAPI -from iterfzf import iterfzf -from pydantic import BaseModel -from pydantic_core._pydantic_core import PydanticUndefinedType +from pydantic import BaseModel, validator from rich.console import Console from sqlalchemy import func -from sqlmodel import Session, SQLModel, select - -from sqlmodel_base.database import get_engine, get_session +from sqlmodel import Field, Session, SQLModel, create_engine, select console = Console() +def get_session(): + with Session(engine) as session: + yield session + + class PagedResult(BaseModel): items: list total: int @@ -25,134 +22,51 @@ class PagedResult(BaseModel): class Base(SQLModel): - @classmethod - @property - def engine(self): - engine = get_engine() - return engine - - def create(self, session: Optional[Session] = Depends(get_session)): - if isinstance(session, Session): + def create(self): + with Session(engine) as session: validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self - else: - response = httpx.post( - "http://localhost:8000/create/", json=self.model_dump_json() - ) - breakpoint() - return response @classmethod - def interactive_create(cls, id: Optional[int] = None): - data = {} - for name, field in cls.__fields__.items(): - default = field.default - if ( - default is None or isinstance(default, PydanticUndefinedType) - ) and not field.is_required(): - default = "None" - if (isinstance(default, PydanticUndefinedType)) and field.is_required(): - default = None - value = typer.prompt(f"{name}: ", default=default) - if value and value != "" and value != "None": - data[name] = value - item = cls(**data).create() - console.print(item) + def get(cls, id): + with Session(engine) as session: + return session.get(cls, id) @classmethod - def pick(cls): - all = cls.all() - item = iterfzf([item.model_dump_json() for item in all]) - if not item: - console.print("No item selected") - return - return cls.get(cls.parse_raw(item).id) + def get_all(cls): + with Session(engine) as session: + return session.exec(select(cls)).all() @classmethod - def get(cls, id: int): - with Session(cls.engine) as session: - if hasattr(cls, "__table_class__"): - return session.get(cls.__table_class__, id) - return cls.model_validate(session.get(cls, id)) + def get_count(cls): + with Session(engine) as session: + return session.exec(func.count(Hero.id)).scalar() @classmethod - def get_or_pick(cls, id: Optional[int] = None): - if id is None: - return cls.pick() - return cls.get(id=id) + def get_first(cls): + with Session(engine) as session: + return session.exec(select(cls).limit(1)).first() @classmethod - def all(cls) -> List: - with Session(cls.engine) as session: - if hasattr(cls, "__table_class__"): - return session.exec(select(cls.__table_class__)).all() - return [cls.model_validate(i) for i in session.exec(select(cls)).all()] + def get_last(cls): + with Session(engine) as session: + return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() @classmethod - def count(cls) -> int: - with Session(cls.engine) as session: - if hasattr(cls, "__table_class__"): - return session.exec(func.count(cls.__table_class__.id)).scalar() - return session.exec(func.count(cls.id)).scalar() + def get_random(cls): + with Session(engine) as session: + return session.exec(select(cls).order_by(cls.id).limit(1)).first() @classmethod - def first(cls): - with Session(cls.engine) as session: - if hasattr(cls, "__table_class__"): - table = cls.__table_class__ - else: - table = cls - return cls.model_validate( - session.exec(select(table).order_by(table.id.asc()).limit(1)).first() - ) - - @classmethod - def last(cls): - with Session(cls.engine) as session: - if hasattr(cls, "__table_class__"): - table = cls.__table_class__ - else: - table = cls - return session.exec( - select(table).order_by(table.id.desc()).limit(1) - ).first() - - @classmethod - def get_page( - cls, - page: int = 1, - page_size: int = 20, - all: bool = False, - reverse: bool = False, - ): - with Session(cls.engine) as session: - if hasattr(cls, "__table_class__"): - table = cls.__table_class__ - else: - table = cls - if all: - items = session.exec(select(table)).all() - page_size = len(items) - else: - if reverse: - items = session.exec( - select(table) - .offset((page - 1) * page_size) - .limit(page_size) - .order_by(table.id.desc()) - ).all() - else: - items = session.exec( - select(table) - .offset((page - 1) * page_size) - .limit(page_size) - .order_by(table.id) - ).all() - - total = table.count() + def get_page(cls, page: int = 1, page_size: int = 20): + with Session(engine) as session: + items = session.exec( + select(cls).offset((page - 1) * page_size).limit(page_size) + ).all() + total = cls.get_count() # determine if there is a next page if page * page_size < total: next_page = page + 1 @@ -168,126 +82,72 @@ class Base(SQLModel): ) def delete(self): - with Session(self.engine) as session: + with Session(engine) as session: session.delete(self) session.commit() return self def update(self): - with Session(self.engine) as session: + with Session(engine) as session: validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self - @classmethod - def interactive_update(cls, id: Optional[int] = None): - item = cls.get_or_pick(id=id) - if not item: - console.print("No item selected") - return - for field in item.__fields__.keys(): - if field == "id": - continue - value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None") - if ( - value - and value != "" - and value != "None" - and value != getattr(item, field) - ): - setattr(item, field, value) - item.update() - console.print(item) - @classmethod - def api(cls): - api = FastAPI( - title="FastAPI", - version="0.1.0", - # docs_url=None, - # redoc_url=None, - # openapi_url=None, - # openapi_tags=tags_metadata, - # dependencies=[Depends(set_user), Depends(set_prefers)], - ) +class Hero(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None - api.include_router(cls.router()) + @validator("age") + def validate_age(cls, v): + if v is None: + return v + if v > 0: + return v + return abs(v) - return api - @classmethod - def router(cls): - router = APIRouter() - # router.add_api_route("/get/", cls.get, methods=["GET"]) - # router.add_api_route("/list/", cls.all, methods=["GET"]) - # router.add_api_route("/create/", cls.create, methods=["POST"]) - # router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" - @router.get("/") - def get(id: int) -> cls: - return cls.get(id=id) +engine = create_engine(sqlite_url) # , echo=True) - @router.get("/list", include_in_schema=False) - @router.get("/list/") - def get_page( - page: int = 1, - page_size: int = 20, - all: bool = False, - reverse: bool = False, - ) -> PagedResult: - return cls.get_page() - @router.post("/create") - def create(cls: cls) -> cls: - return cls.create() +# replace with alembic commands +def create_db_and_tables(): + SQLModel.metadata.create_all(engine) - @router.put("/update") - def update() -> cls: - return cls.update() - return router +def create_heroes(): + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson").create() + hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() + hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() - @classmethod - @property - def cli(cls): - app = typer.Typer() + # with Session(engine) as session: + # session.add(hero_1) + # session.add(hero_2) + # session.add(hero_3) + # + # session.commit() - @app.command() - def get(id: int = typer.Option(None, help="Hero ID")): - console.print(cls.get_or_pick(id=id)) - @app.command() - def create(): - console.print(cls.interactive_create()) +def page_heroes(): + next_page = 1 + while next_page: + page = Hero.get_page(page=next_page, page_size=2) + console.print(page) + next_page = page.next_page - @app.command() - def list( - page: int = typer.Option(1, help="Page number"), - page_size: int = typer.Option(20, help="Page size"), - all: bool = typer.Option(False, help="Show all heroes"), - reverse: bool = typer.Option(False, help="Reverse order"), - ): - console.print( - cls.get_page( - page=page, - page_size=page_size, - all=all, - reverse=reverse, - ) - ) - @app.command() - def api(): - cls.run_api() +def main(): + create_db_and_tables() + create_heroes() + page_heroes() - @app.command() - def update(): - console.print(cls.interactive_update()) - return app - - @classmethod - def run_api(cls): - uvicorn.run(cls.api(), host="127.0.0.1", port=8000) +if __name__ == "__main__": + main() diff --git a/sqlmodel_base/cli.py b/sqlmodel_base/cli.py deleted file mode 100644 index 13e682b..0000000 --- a/sqlmodel_base/cli.py +++ /dev/null @@ -1,11 +0,0 @@ -import typer - -from sqlmodel_base.hero.cli import hero_app - -app = typer.Typer() - -app.add_typer(hero_app, name="hero") - - -if __name__ == "__main__": - app() diff --git a/sqlmodel_base/database.py b/sqlmodel_base/database.py deleted file mode 100644 index 13f8dbc..0000000 --- a/sqlmodel_base/database.py +++ /dev/null @@ -1,18 +0,0 @@ -from functools import lru_cache - -from sqlmodel import Session, SQLModel, create_engine - -sqlite_file_name = "database.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" - - -@lru_cache -def get_engine(): - engine = create_engine(sqlite_url) - SQLModel.metadata.create_all(engine) - return engine - - -def get_session(): - with Session(get_engine()) as session: - yield session diff --git a/sqlmodel_base/hero/cli.py b/sqlmodel_base/hero/cli.py deleted file mode 100644 index 784edef..0000000 --- a/sqlmodel_base/hero/cli.py +++ /dev/null @@ -1,65 +0,0 @@ -from rich.console import Console - -from sqlmodel_base.database import get_engine -from sqlmodel_base.hero.models import Hero - -engine = get_engine() - -hero_app = Hero.cli -console = Console() -# @hero_app.callback() -# def hero(): -# "model cli" - - -# @hero_app.command() -# def get(id: int = typer.Option(None, help="Hero ID")): -# console.print(Hero.get_or_pick(id=id)) - - -# @hero_app.command() -# def list( -# page: int = typer.Option(1, help="Page number"), -# page_size: int = typer.Option(20, help="Page size"), -# all: bool = typer.Option(False, help="Show all heroes"), -# reverse: bool = typer.Option(False, help="Reverse order"), -# ): -# console.print( -# Hero.get_page(page=page, page_size=page_size, all=all, reverse=reverse) -# ) - - -# @hero_app.command() -# def create( -# name: str = typer.Option(..., help="Hero name", prompt=True), -# secret_name: str = typer.Option(..., help="Hero secret name", prompt=True), -# age: int = typer.Option(None, help="Hero age", prompt=True), -# ): -# hero = Hero( -# name=name, -# secret_name=secret_name, -# age=age, -# ).create() -# console.print(hero) - - -# @hero_app.command() -# def update( -# id: int = typer.Option(None, help="Hero ID"), -# name: str = typer.Option(None, help="Hero name"), -# secret_name: str = typer.Option(None, help="Hero secret name"), -# age: int = typer.Option(None, help="Hero age"), -# ): -# hero = Hero.interactive_update(id=id) -# console.print(hero) - - -# @hero_app.command() -# def create_heroes(): -# team_1 = Team.get(id=1) -# if not team_1: -# team_1 = Team(name="Team 1", headquarters="Headquarters 1").create() -# for _ in range(50): -# Hero(name="Deadpond", secret_name="Dive Wilson", team_id=team_1.id).create() -# Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() -# Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() diff --git a/sqlmodel_base/hero/models.py b/sqlmodel_base/hero/models.py deleted file mode 100644 index 1a34f17..0000000 --- a/sqlmodel_base/hero/models.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Optional - -from pydantic import validator -from rich.console import Console -from sqlmodel import Field - -from sqlmodel_base.base import Base - -console = Console() - - -class HeroBase(Base): - name: str - secret_name: str - age: Optional[int] = None - team_id: Optional[int] = Field(default=None, foreign_key="team.id") - - @validator("age") - def validate_age(cls, v): - if v is None: - return v - if v > 0: - return v - return abs(v) - - -class Hero(HeroBase, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - - -class HeroCreate(HeroBase): - __table_class__ = Hero - pass - - -class HeroRead(HeroBase): - __table_class__ = Hero - id: int - - -class HeroUpdate(Base, table=False): - __table_class__ = Hero - name: Optional[str] - secret_name: Optional[str] - age: Optional[int] - team_id: Optional[int] - - -if __name__ == "__main__": - Hero.cli() diff --git a/sqlmodel_base/team/models.py b/sqlmodel_base/team/models.py deleted file mode 100644 index 791abe3..0000000 --- a/sqlmodel_base/team/models.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Optional - -from sqlmodel import Field - -from sqlmodel_base.base import Base - - -class Team(Base, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str = Field(index=True) - headquarters: str