From 85554e21697dd3e8ec6c0fddf92ea643030fec0a Mon Sep 17 00:00:00 2001 From: "Waylon S. Walker" Date: Thu, 29 Feb 2024 15:05:57 -0600 Subject: [PATCH 1/6] wip --- pyproject.toml | 6 +++--- sqlmodel_base/base.py | 7 ------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e567efa..ebb2e6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,9 @@ classifiers = [ dependencies = [ 'rich', 'sqlmodel', 'typer' ] [project.urls] -Documentation = "https://github.com/unknown/sqlmodel-base#readme" -Issues = "https://github.com/unknown/sqlmodel-base/issues" -Source = "https://github.com/unknown/sqlmodel-base" +Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" +Issues = "https://github.com/waylonwalker/sqlmodel-base/issues" +Source = "https://github.com/waylonwalker/sqlmodel-base" [tool.hatch.version] path = "sqlmodel_base/__about__.py" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index f12571a..054ee62 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -127,13 +127,6 @@ def create_heroes(): hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() - # with Session(engine) as session: - # session.add(hero_1) - # session.add(hero_2) - # session.add(hero_3) - # - # session.commit() - def page_heroes(): next_page = 1 From a21dbb08d4933a61f606e908500a25a611e8b0c5 Mon Sep 17 00:00:00 2001 From: "Waylon S. Walker" Date: Fri, 1 Mar 2024 07:10:57 -0600 Subject: [PATCH 2/6] add cli --- .gitignore | 1 + pyproject.toml | 5 +- sqlmodel_base/base.py | 180 +++++++++++++++++++++++++++-------- sqlmodel_base/cli.py | 11 +++ sqlmodel_base/database.py | 21 ++++ sqlmodel_base/hero/cli.py | 72 ++++++++++++++ sqlmodel_base/hero/models.py | 26 +++++ sqlmodel_base/team/models.py | 12 +++ 8 files changed, 285 insertions(+), 43 deletions(-) create mode 100644 sqlmodel_base/cli.py create mode 100644 sqlmodel_base/database.py create mode 100644 sqlmodel_base/hero/cli.py create mode 100644 sqlmodel_base/hero/models.py create mode 100644 sqlmodel_base/team/models.py diff --git a/.gitignore b/.gitignore index e1a3186..9d6f21d 100644 --- a/.gitignore +++ b/.gitignore @@ -962,3 +962,4 @@ FodyWeavers.xsd # Additional files built by Visual Studio # End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode +database.db diff --git a/pyproject.toml b/pyproject.toml index ebb2e6e..cecd388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,16 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ 'rich', 'sqlmodel', 'typer' ] +dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf'] [project.urls] Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" Issues = "https://github.com/waylonwalker/sqlmodel-base/issues" Source = "https://github.com/waylonwalker/sqlmodel-base" +[project.scripts] +sqlmodel-base = "sqlmodel_base.cli:app" + [tool.hatch.version] path = "sqlmodel_base/__about__.py" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index 054ee62..80f91fb 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -1,18 +1,19 @@ +import json from typing import Optional +import typer +from iterfzf import iterfzf from pydantic import BaseModel, validator +from pydantic_core._pydantic_core import PydanticUndefinedType from rich.console import Console from sqlalchemy import func from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel_base.database import get_engine + console = Console() -def get_session(): - with Session(engine) as session: - yield session - - class PagedResult(BaseModel): items: list total: int @@ -22,8 +23,14 @@ class PagedResult(BaseModel): class Base(SQLModel): + @classmethod + @property + def engine(self): + engine = get_engine() + return engine + def create(self): - with Session(engine) as session: + with Session(self.engine) as session: validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() @@ -31,42 +38,99 @@ class Base(SQLModel): return self @classmethod - def get(cls, id): - with Session(engine) as session: + def interactive_create(cls, id: Optional[int] = None): + data = {} + for name, field in cls.__fields__.items(): + default = field.default + if ( + default is None or isinstance(default, PydanticUndefinedType) + ) and not field.is_required(): + default = "None" + if (isinstance(default, PydanticUndefinedType)) and field.is_required(): + default = None + value = typer.prompt(f"{name}: ", default=default) + if value and value != "" and value != "None": + data[name] = value + item = cls(**data).create() + console.print(item) + + @classmethod + def pick(cls): + all = cls.all() + item = iterfzf([item.model_dump_json() for item in all]) + if not item: + console.print("No item selected") + return + return cls.get(cls.parse_raw(item).id) + + @classmethod + def get(cls, id: int): + with Session(cls.engine) as session: return session.get(cls, id) @classmethod - def get_all(cls): - with Session(engine) as session: + def get_or_pick(cls, id: Optional[int] = None): + if id is None: + return cls.pick() + return cls.get(id=id) + + @classmethod + def all(cls): + with Session(cls.engine) as session: return session.exec(select(cls)).all() @classmethod - def get_count(cls): - with Session(engine) as session: - return session.exec(func.count(Hero.id)).scalar() + def count(cls): + with Session(cls.engine) as session: + return session.exec(func.count(cls.id)).scalar() @classmethod - def get_first(cls): - with Session(engine) as session: + def first(cls): + with Session(cls.engine) as session: return session.exec(select(cls).limit(1)).first() @classmethod - def get_last(cls): - with Session(engine) as session: + def last(cls): + with Session(cls.engine) as session: return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() @classmethod - def get_random(cls): - with Session(engine) as session: + def random(cls): + with Session(cls.engine) as session: return session.exec(select(cls).order_by(cls.id).limit(1)).first() @classmethod - def get_page(cls, page: int = 1, page_size: int = 20): - with Session(engine) as session: - items = session.exec( - select(cls).offset((page - 1) * page_size).limit(page_size) - ).all() - total = cls.get_count() + def get_page( + cls, + page: int = 1, + page_size: int = 20, + all: bool = False, + reverse: bool = False, + ): + with Session(cls.engine) as session: + if all: + items = session.exec(select(cls)).all() + page_size = len(items) + else: + if reverse: + items = session.exec( + select(cls) + .offset((page - 1) * page_size) + .limit(page_size) + .order_by(cls.id.desc()) + ).all() + else: + items = session.exec( + select(cls) + .offset((page - 1) * page_size) + .limit(page_size) + .order_by(cls.id) + ).all() + # items = session.exec( + # select(cls).offset((page - 1) * page_size).limit(page_size) + # ).all() + + total = cls.count() # determine if there is a next page if page * page_size < total: next_page = page + 1 @@ -82,39 +146,71 @@ class Base(SQLModel): ) def delete(self): - with Session(engine) as session: + with Session(self.engine) as session: session.delete(self) session.commit() return self def update(self): - with Session(engine) as session: + with Session(self.engine) as session: validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self + @classmethod + def interactive_update(cls, id: Optional[int] = None): + item = cls.get_or_pick(id=id) + if not item: + console.print("No item selected") + return + for field in item.__fields__.keys(): + value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None") + if ( + value + and value != "" + and value != "None" + and value != getattr(item, field) + ): + setattr(item, field, value) + item.update() + console.print(item) -class Hero(Base, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str - secret_name: str - age: Optional[int] = None + @classmethod + @property + def cli(cls): + app = typer.Typer() - @validator("age") - def validate_age(cls, v): - if v is None: - return v - if v > 0: - return v - return abs(v) + @app.command() + def get(id: int = typer.Option(None, help="Hero ID")): + console.print(cls.get_or_pick(id=id)) + @app.command() + def create(): + console.print(cls.interactive_create()) -sqlite_file_name = "database.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" + @app.command() + def list( + page: int = typer.Option(1, help="Page number"), + page_size: int = typer.Option(20, help="Page size"), + all: bool = typer.Option(False, help="Show all heroes"), + reverse: bool = typer.Option(False, help="Reverse order"), + ): + console.print( + cls.get_page( + page=page, + page_size=page_size, + all=all, + reverse=reverse, + ) + ) -engine = create_engine(sqlite_url) # , echo=True) + @app.command() + def update(): + console.print(cls.interactive_update()) + + return app # replace with alembic commands diff --git a/sqlmodel_base/cli.py b/sqlmodel_base/cli.py new file mode 100644 index 0000000..13e682b --- /dev/null +++ b/sqlmodel_base/cli.py @@ -0,0 +1,11 @@ +import typer + +from sqlmodel_base.hero.cli import hero_app + +app = typer.Typer() + +app.add_typer(hero_app, name="hero") + + +if __name__ == "__main__": + app() diff --git a/sqlmodel_base/database.py b/sqlmodel_base/database.py new file mode 100644 index 0000000..90ceedb --- /dev/null +++ b/sqlmodel_base/database.py @@ -0,0 +1,21 @@ +from functools import lru_cache + +from sqlmodel import Field, Session, SQLModel, create_engine, select + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + + +@lru_cache +def get_engine(): + from sqlmodel_base.hero.models import Hero + from sqlmodel_base.team.models import Team + + engine = create_engine(sqlite_url) + SQLModel.metadata.create_all(engine) + return engine + + +def get_session(): + with Session(get_engine()) as session: + yield session diff --git a/sqlmodel_base/hero/cli.py b/sqlmodel_base/hero/cli.py new file mode 100644 index 0000000..458d319 --- /dev/null +++ b/sqlmodel_base/hero/cli.py @@ -0,0 +1,72 @@ +import json + +import typer +from iterfzf import iterfzf +from rich.console import Console + +from sqlmodel_base.database import get_engine +from sqlmodel_base.hero.models import Hero +from sqlmodel_base.team.models import Team + +engine = get_engine() + +hero_app = Hero.cli +console = Console() + + +# @hero_app.callback() +# def hero(): +# "model cli" + + +# @hero_app.command() +# def get(id: int = typer.Option(None, help="Hero ID")): +# console.print(Hero.get_or_pick(id=id)) + + +# @hero_app.command() +# def list( +# page: int = typer.Option(1, help="Page number"), +# page_size: int = typer.Option(20, help="Page size"), +# all: bool = typer.Option(False, help="Show all heroes"), +# reverse: bool = typer.Option(False, help="Reverse order"), +# ): +# console.print( +# Hero.get_page(page=page, page_size=page_size, all=all, reverse=reverse) +# ) + + +# @hero_app.command() +# def create( +# name: str = typer.Option(..., help="Hero name", prompt=True), +# secret_name: str = typer.Option(..., help="Hero secret name", prompt=True), +# age: int = typer.Option(None, help="Hero age", prompt=True), +# ): +# hero = Hero( +# name=name, +# secret_name=secret_name, +# age=age, +# ).create() +# console.print(hero) + + +# @hero_app.command() +# def update( +# id: int = typer.Option(None, help="Hero ID"), +# name: str = typer.Option(None, help="Hero name"), +# secret_name: str = typer.Option(None, help="Hero secret name"), +# age: int = typer.Option(None, help="Hero age"), +# ): +# hero = Hero.interactive_update(id=id) +# console.print(hero) + + +# @hero_app.command() +# def create_heroes(): +# team_1 = Team.get(id=1) +# if not team_1: +# team_1 = Team(name="Team 1", headquarters="Headquarters 1").create() +# for _ in range(50): +# Hero(name="Deadpond", secret_name="Dive Wilson", team_id=team_1.id).create() +# Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() +# Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() diff --git a/sqlmodel_base/hero/models.py b/sqlmodel_base/hero/models.py new file mode 100644 index 0000000..fa90c87 --- /dev/null +++ b/sqlmodel_base/hero/models.py @@ -0,0 +1,26 @@ +from typing import Optional + +from pydantic import BaseModel, validator +from rich.console import Console +from sqlalchemy import func +from sqlmodel import Field, Session, SQLModel, create_engine, select + +from sqlmodel_base.base import Base + +console = Console() + + +class Hero(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + + @validator("age") + def validate_age(cls, v): + if v is None: + return v + if v > 0: + return v + return abs(v) diff --git a/sqlmodel_base/team/models.py b/sqlmodel_base/team/models.py new file mode 100644 index 0000000..e57384e --- /dev/null +++ b/sqlmodel_base/team/models.py @@ -0,0 +1,12 @@ +from typing import Optional + +from rich.console import Console +from sqlmodel import Field + +from sqlmodel_base.base import Base + + +class Team(Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + headquarters: str From f42e8e4807de356c66fa3cf27648c8a784529d0e Mon Sep 17 00:00:00 2001 From: "Waylon S. Walker" Date: Fri, 1 Mar 2024 10:23:13 -0600 Subject: [PATCH 3/6] create fastapi --- pyproject.toml | 2 +- sqlmodel_base/base.py | 67 +++++++++++++++++++----------------- sqlmodel_base/database.py | 4 +-- sqlmodel_base/hero/cli.py | 4 --- sqlmodel_base/hero/models.py | 5 ++- sqlmodel_base/team/models.py | 1 - 6 files changed, 40 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cecd388..5a88aa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf'] +dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn'] [project.urls] Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index 80f91fb..75d9ecb 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -1,13 +1,14 @@ -import json from typing import Optional import typer +import uvicorn +from fastapi import APIRouter, FastAPI from iterfzf import iterfzf -from pydantic import BaseModel, validator +from pydantic import BaseModel from pydantic_core._pydantic_core import PydanticUndefinedType from rich.console import Console from sqlalchemy import func -from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel import Session, SQLModel, select from sqlmodel_base.database import get_engine @@ -177,6 +178,31 @@ class Base(SQLModel): item.update() console.print(item) + @classmethod + def api(cls): + api = FastAPI( + title="FastAPI", + version="0.1.0", + docs_url=None, + redoc_url=None, + openapi_url=None, + # openapi_tags=tags_metadata, + # dependencies=[Depends(set_user), Depends(set_prefers)], + ) + + api.include_router(cls.router()) + + return api + + @classmethod + def router(cls): + router = APIRouter() + router.add_api_route("/get/", cls.get, methods=["GET"]) + router.add_api_route("/list/", cls.all, methods=["GET"]) + router.add_api_route("/create/", cls.interactive_create, methods=["POST"]) + router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) + return router + @classmethod @property def cli(cls): @@ -206,37 +232,16 @@ class Base(SQLModel): ) ) + @app.command() + def api(): + cls.run_api() + @app.command() def update(): console.print(cls.interactive_update()) return app - -# replace with alembic commands -def create_db_and_tables(): - SQLModel.metadata.create_all(engine) - - -def create_heroes(): - hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson").create() - hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create() - hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create() - - -def page_heroes(): - next_page = 1 - while next_page: - page = Hero.get_page(page=next_page, page_size=2) - console.print(page) - next_page = page.next_page - - -def main(): - create_db_and_tables() - create_heroes() - page_heroes() - - -if __name__ == "__main__": - main() + @classmethod + def run_api(cls): + uvicorn.run(cls.api(), host="127.0.0.1", port=8000) diff --git a/sqlmodel_base/database.py b/sqlmodel_base/database.py index 90ceedb..96e1617 100644 --- a/sqlmodel_base/database.py +++ b/sqlmodel_base/database.py @@ -1,6 +1,6 @@ from functools import lru_cache -from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel import Session, SQLModel, create_engine sqlite_file_name = "database.db" sqlite_url = f"sqlite:///{sqlite_file_name}" @@ -8,8 +8,6 @@ sqlite_url = f"sqlite:///{sqlite_file_name}" @lru_cache def get_engine(): - from sqlmodel_base.hero.models import Hero - from sqlmodel_base.team.models import Team engine = create_engine(sqlite_url) SQLModel.metadata.create_all(engine) diff --git a/sqlmodel_base/hero/cli.py b/sqlmodel_base/hero/cli.py index 458d319..123b337 100644 --- a/sqlmodel_base/hero/cli.py +++ b/sqlmodel_base/hero/cli.py @@ -1,12 +1,8 @@ -import json -import typer -from iterfzf import iterfzf from rich.console import Console from sqlmodel_base.database import get_engine from sqlmodel_base.hero.models import Hero -from sqlmodel_base.team.models import Team engine = get_engine() diff --git a/sqlmodel_base/hero/models.py b/sqlmodel_base/hero/models.py index fa90c87..2f399eb 100644 --- a/sqlmodel_base/hero/models.py +++ b/sqlmodel_base/hero/models.py @@ -1,9 +1,8 @@ from typing import Optional -from pydantic import BaseModel, validator +from pydantic import validator from rich.console import Console -from sqlalchemy import func -from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel import Field from sqlmodel_base.base import Base diff --git a/sqlmodel_base/team/models.py b/sqlmodel_base/team/models.py index e57384e..791abe3 100644 --- a/sqlmodel_base/team/models.py +++ b/sqlmodel_base/team/models.py @@ -1,6 +1,5 @@ from typing import Optional -from rich.console import Console from sqlmodel import Field from sqlmodel_base.base import Base From 841723103cfcd8e8353f1beb653c5a2193f53cfb Mon Sep 17 00:00:00 2001 From: "Waylon S. Walker" Date: Fri, 1 Mar 2024 10:25:48 -0600 Subject: [PATCH 4/6] format --- sqlmodel_base/database.py | 1 - sqlmodel_base/hero/cli.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/sqlmodel_base/database.py b/sqlmodel_base/database.py index 96e1617..13f8dbc 100644 --- a/sqlmodel_base/database.py +++ b/sqlmodel_base/database.py @@ -8,7 +8,6 @@ sqlite_url = f"sqlite:///{sqlite_file_name}" @lru_cache def get_engine(): - engine = create_engine(sqlite_url) SQLModel.metadata.create_all(engine) return engine diff --git a/sqlmodel_base/hero/cli.py b/sqlmodel_base/hero/cli.py index 123b337..784edef 100644 --- a/sqlmodel_base/hero/cli.py +++ b/sqlmodel_base/hero/cli.py @@ -1,4 +1,3 @@ - from rich.console import Console from sqlmodel_base.database import get_engine @@ -8,8 +7,6 @@ engine = get_engine() hero_app = Hero.cli console = Console() - - # @hero_app.callback() # def hero(): # "model cli" From cd3398298541df4f8a3ff404054a0ebd0ad0b6ac Mon Sep 17 00:00:00 2001 From: "Waylon S. Walker" Date: Fri, 1 Mar 2024 10:26:10 -0600 Subject: [PATCH 5/6] remove the database --- database.db | Bin 8192 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 database.db diff --git a/database.db b/database.db deleted file mode 100644 index 293c794720b4e76d6c7ad0904e61698c8aea302e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8192 zcmeI#TWAzl7{Ku}*|}d+O`2pi=EP`C+9Vp&+IY_&#$RbV9oio+dlUWwIPVO&J-sn^gzt=x@cMMBL2s~}7c3LGCOEI1pPf19w3BsQr_EVr|7}0B z@7Y)F@9b&2-)^<5tasKE>qqO7b<#Ru?Xa4yGV@RKk$Ky^U>-Hc&A8cMI>ztDuf|Q| zoN?IjjCNy%q3f^ohx+&W8U2vHSKp}D=#utayRTi>PH6|VVQrlj)tLHJ{Yjluzg4GH zS8Y+tmA{n7${pp3azdFWbS|6fN+0t1 zaC)kc%SOLJ*fhmXg6sssj#I1;vMGe!PqAYl>mlsd6gvvCUm@&1irovcUn1-m6gvX4 zdk}V*VuwI>H^L?-AS=KK(+^AyD4@%$ZkW}IK_5> zY$w8YP;3lj+Yxpv#cl!F%?P`RVmE^9283Ntv27sxIl``^*v~-rQ-p1$*tHUU`Hh^qB!mgs&k3e=M!mgm$I*?tCu(cG+LAC~Amr-mr$Sy_L zB@`P4*(!vsq}U3O{SaZxDYgt`7bEN Date: Sat, 22 Nov 2025 21:59:30 -0600 Subject: [PATCH 6/6] wip --- pyproject.toml | 2 +- sqlmodel_base/base.py | 110 +++++++++++++++++++++++++---------- sqlmodel_base/hero/models.py | 29 ++++++++- 3 files changed, 106 insertions(+), 35 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a88aa7..9393705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn'] +dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx'] [project.urls] Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" diff --git a/sqlmodel_base/base.py b/sqlmodel_base/base.py index 75d9ecb..c3d1ee6 100644 --- a/sqlmodel_base/base.py +++ b/sqlmodel_base/base.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import List, Optional +import httpx import typer import uvicorn -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter, Depends, FastAPI from iterfzf import iterfzf from pydantic import BaseModel from pydantic_core._pydantic_core import PydanticUndefinedType @@ -10,7 +11,7 @@ from rich.console import Console from sqlalchemy import func from sqlmodel import Session, SQLModel, select -from sqlmodel_base.database import get_engine +from sqlmodel_base.database import get_engine, get_session console = Console() @@ -30,13 +31,19 @@ class Base(SQLModel): engine = get_engine() return engine - def create(self): - with Session(self.engine) as session: + def create(self, session: Optional[Session] = Depends(get_session)): + if isinstance(session, Session): validated = self.model_validate(self) session.add(self.sqlmodel_update(validated)) session.commit() session.refresh(self) return self + else: + response = httpx.post( + "http://localhost:8000/create/", json=self.model_dump_json() + ) + breakpoint() + return response @classmethod def interactive_create(cls, id: Optional[int] = None): @@ -67,7 +74,9 @@ class Base(SQLModel): @classmethod def get(cls, id: int): with Session(cls.engine) as session: - return session.get(cls, id) + if hasattr(cls, "__table_class__"): + return session.get(cls.__table_class__, id) + return cls.model_validate(session.get(cls, id)) @classmethod def get_or_pick(cls, id: Optional[int] = None): @@ -76,29 +85,40 @@ class Base(SQLModel): return cls.get(id=id) @classmethod - def all(cls): + def all(cls) -> List: with Session(cls.engine) as session: - return session.exec(select(cls)).all() + if hasattr(cls, "__table_class__"): + return session.exec(select(cls.__table_class__)).all() + return [cls.model_validate(i) for i in session.exec(select(cls)).all()] @classmethod - def count(cls): + def count(cls) -> int: with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + return session.exec(func.count(cls.__table_class__.id)).scalar() return session.exec(func.count(cls.id)).scalar() @classmethod def first(cls): with Session(cls.engine) as session: - return session.exec(select(cls).limit(1)).first() + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls + return cls.model_validate( + session.exec(select(table).order_by(table.id.asc()).limit(1)).first() + ) @classmethod def last(cls): with Session(cls.engine) as session: - return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() - - @classmethod - def random(cls): - with Session(cls.engine) as session: - return session.exec(select(cls).order_by(cls.id).limit(1)).first() + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls + return session.exec( + select(table).order_by(table.id.desc()).limit(1) + ).first() @classmethod def get_page( @@ -109,29 +129,30 @@ class Base(SQLModel): reverse: bool = False, ): with Session(cls.engine) as session: + if hasattr(cls, "__table_class__"): + table = cls.__table_class__ + else: + table = cls if all: - items = session.exec(select(cls)).all() + items = session.exec(select(table)).all() page_size = len(items) else: if reverse: items = session.exec( - select(cls) + select(table) .offset((page - 1) * page_size) .limit(page_size) - .order_by(cls.id.desc()) + .order_by(table.id.desc()) ).all() else: items = session.exec( - select(cls) + select(table) .offset((page - 1) * page_size) .limit(page_size) - .order_by(cls.id) + .order_by(table.id) ).all() - # items = session.exec( - # select(cls).offset((page - 1) * page_size).limit(page_size) - # ).all() - total = cls.count() + total = table.count() # determine if there is a next page if page * page_size < total: next_page = page + 1 @@ -167,6 +188,8 @@ class Base(SQLModel): console.print("No item selected") return for field in item.__fields__.keys(): + if field == "id": + continue value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None") if ( value @@ -183,9 +206,9 @@ class Base(SQLModel): api = FastAPI( title="FastAPI", version="0.1.0", - docs_url=None, - redoc_url=None, - openapi_url=None, + # docs_url=None, + # redoc_url=None, + # openapi_url=None, # openapi_tags=tags_metadata, # dependencies=[Depends(set_user), Depends(set_prefers)], ) @@ -197,10 +220,33 @@ class Base(SQLModel): @classmethod def router(cls): router = APIRouter() - router.add_api_route("/get/", cls.get, methods=["GET"]) - router.add_api_route("/list/", cls.all, methods=["GET"]) - router.add_api_route("/create/", cls.interactive_create, methods=["POST"]) - router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) + # router.add_api_route("/get/", cls.get, methods=["GET"]) + # router.add_api_route("/list/", cls.all, methods=["GET"]) + # router.add_api_route("/create/", cls.create, methods=["POST"]) + # router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) + + @router.get("/") + def get(id: int) -> cls: + return cls.get(id=id) + + @router.get("/list", include_in_schema=False) + @router.get("/list/") + def get_page( + page: int = 1, + page_size: int = 20, + all: bool = False, + reverse: bool = False, + ) -> PagedResult: + return cls.get_page() + + @router.post("/create") + def create(cls: cls) -> cls: + return cls.create() + + @router.put("/update") + def update() -> cls: + return cls.update() + return router @classmethod diff --git a/sqlmodel_base/hero/models.py b/sqlmodel_base/hero/models.py index 2f399eb..1a34f17 100644 --- a/sqlmodel_base/hero/models.py +++ b/sqlmodel_base/hero/models.py @@ -9,8 +9,7 @@ from sqlmodel_base.base import Base console = Console() -class Hero(Base, table=True): - id: Optional[int] = Field(default=None, primary_key=True) +class HeroBase(Base): name: str secret_name: str age: Optional[int] = None @@ -23,3 +22,29 @@ class Hero(Base, table=True): if v > 0: return v return abs(v) + + +class Hero(HeroBase, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class HeroCreate(HeroBase): + __table_class__ = Hero + pass + + +class HeroRead(HeroBase): + __table_class__ = Hero + id: int + + +class HeroUpdate(Base, table=False): + __table_class__ = Hero + name: Optional[str] + secret_name: Optional[str] + age: Optional[int] + team_id: Optional[int] + + +if __name__ == "__main__": + Hero.cli()