This commit is contained in:
Waylon Walker 2025-11-22 21:59:30 -06:00
parent cd33982985
commit 9d6d509618
3 changed files with 106 additions and 35 deletions

View file

@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: Implementation :: PyPy",
] ]
dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn'] dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx']
[project.urls] [project.urls]
Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme"

View file

@ -1,8 +1,9 @@
from typing import Optional from typing import List, Optional
import httpx
import typer import typer
import uvicorn import uvicorn
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, Depends, FastAPI
from iterfzf import iterfzf from iterfzf import iterfzf
from pydantic import BaseModel from pydantic import BaseModel
from pydantic_core._pydantic_core import PydanticUndefinedType from pydantic_core._pydantic_core import PydanticUndefinedType
@ -10,7 +11,7 @@ from rich.console import Console
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import Session, SQLModel, select from sqlmodel import Session, SQLModel, select
from sqlmodel_base.database import get_engine from sqlmodel_base.database import get_engine, get_session
console = Console() console = Console()
@ -30,13 +31,19 @@ class Base(SQLModel):
engine = get_engine() engine = get_engine()
return engine return engine
def create(self): def create(self, session: Optional[Session] = Depends(get_session)):
with Session(self.engine) as session: if isinstance(session, Session):
validated = self.model_validate(self) validated = self.model_validate(self)
session.add(self.sqlmodel_update(validated)) session.add(self.sqlmodel_update(validated))
session.commit() session.commit()
session.refresh(self) session.refresh(self)
return self return self
else:
response = httpx.post(
"http://localhost:8000/create/", json=self.model_dump_json()
)
breakpoint()
return response
@classmethod @classmethod
def interactive_create(cls, id: Optional[int] = None): def interactive_create(cls, id: Optional[int] = None):
@ -67,7 +74,9 @@ class Base(SQLModel):
@classmethod @classmethod
def get(cls, id: int): def get(cls, id: int):
with Session(cls.engine) as session: with Session(cls.engine) as session:
return session.get(cls, id) if hasattr(cls, "__table_class__"):
return session.get(cls.__table_class__, id)
return cls.model_validate(session.get(cls, id))
@classmethod @classmethod
def get_or_pick(cls, id: Optional[int] = None): def get_or_pick(cls, id: Optional[int] = None):
@ -76,29 +85,40 @@ class Base(SQLModel):
return cls.get(id=id) return cls.get(id=id)
@classmethod @classmethod
def all(cls): def all(cls) -> List:
with Session(cls.engine) as session: with Session(cls.engine) as session:
return session.exec(select(cls)).all() if hasattr(cls, "__table_class__"):
return session.exec(select(cls.__table_class__)).all()
return [cls.model_validate(i) for i in session.exec(select(cls)).all()]
@classmethod @classmethod
def count(cls): def count(cls) -> int:
with Session(cls.engine) as session: with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
return session.exec(func.count(cls.__table_class__.id)).scalar()
return session.exec(func.count(cls.id)).scalar() return session.exec(func.count(cls.id)).scalar()
@classmethod @classmethod
def first(cls): def first(cls):
with Session(cls.engine) as session: with Session(cls.engine) as session:
return session.exec(select(cls).limit(1)).first() if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
return cls.model_validate(
session.exec(select(table).order_by(table.id.asc()).limit(1)).first()
)
@classmethod @classmethod
def last(cls): def last(cls):
with Session(cls.engine) as session: with Session(cls.engine) as session:
return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() if hasattr(cls, "__table_class__"):
table = cls.__table_class__
@classmethod else:
def random(cls): table = cls
with Session(cls.engine) as session: return session.exec(
return session.exec(select(cls).order_by(cls.id).limit(1)).first() select(table).order_by(table.id.desc()).limit(1)
).first()
@classmethod @classmethod
def get_page( def get_page(
@ -109,29 +129,30 @@ class Base(SQLModel):
reverse: bool = False, reverse: bool = False,
): ):
with Session(cls.engine) as session: with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
if all: if all:
items = session.exec(select(cls)).all() items = session.exec(select(table)).all()
page_size = len(items) page_size = len(items)
else: else:
if reverse: if reverse:
items = session.exec( items = session.exec(
select(cls) select(table)
.offset((page - 1) * page_size) .offset((page - 1) * page_size)
.limit(page_size) .limit(page_size)
.order_by(cls.id.desc()) .order_by(table.id.desc())
).all() ).all()
else: else:
items = session.exec( items = session.exec(
select(cls) select(table)
.offset((page - 1) * page_size) .offset((page - 1) * page_size)
.limit(page_size) .limit(page_size)
.order_by(cls.id) .order_by(table.id)
).all() ).all()
# items = session.exec(
# select(cls).offset((page - 1) * page_size).limit(page_size)
# ).all()
total = cls.count() total = table.count()
# determine if there is a next page # determine if there is a next page
if page * page_size < total: if page * page_size < total:
next_page = page + 1 next_page = page + 1
@ -167,6 +188,8 @@ class Base(SQLModel):
console.print("No item selected") console.print("No item selected")
return return
for field in item.__fields__.keys(): for field in item.__fields__.keys():
if field == "id":
continue
value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None") value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None")
if ( if (
value value
@ -183,9 +206,9 @@ class Base(SQLModel):
api = FastAPI( api = FastAPI(
title="FastAPI", title="FastAPI",
version="0.1.0", version="0.1.0",
docs_url=None, # docs_url=None,
redoc_url=None, # redoc_url=None,
openapi_url=None, # openapi_url=None,
# openapi_tags=tags_metadata, # openapi_tags=tags_metadata,
# dependencies=[Depends(set_user), Depends(set_prefers)], # dependencies=[Depends(set_user), Depends(set_prefers)],
) )
@ -197,10 +220,33 @@ class Base(SQLModel):
@classmethod @classmethod
def router(cls): def router(cls):
router = APIRouter() router = APIRouter()
router.add_api_route("/get/", cls.get, methods=["GET"]) # router.add_api_route("/get/", cls.get, methods=["GET"])
router.add_api_route("/list/", cls.all, methods=["GET"]) # router.add_api_route("/list/", cls.all, methods=["GET"])
router.add_api_route("/create/", cls.interactive_create, methods=["POST"]) # router.add_api_route("/create/", cls.create, methods=["POST"])
router.add_api_route("/update/", cls.interactive_update, methods=["PUT"]) # router.add_api_route("/update/", cls.interactive_update, methods=["PUT"])
@router.get("/")
def get(id: int) -> cls:
return cls.get(id=id)
@router.get("/list", include_in_schema=False)
@router.get("/list/")
def get_page(
page: int = 1,
page_size: int = 20,
all: bool = False,
reverse: bool = False,
) -> PagedResult:
return cls.get_page()
@router.post("/create")
def create(cls: cls) -> cls:
return cls.create()
@router.put("/update")
def update() -> cls:
return cls.update()
return router return router
@classmethod @classmethod

View file

@ -9,8 +9,7 @@ from sqlmodel_base.base import Base
console = Console() console = Console()
class Hero(Base, table=True): class HeroBase(Base):
id: Optional[int] = Field(default=None, primary_key=True)
name: str name: str
secret_name: str secret_name: str
age: Optional[int] = None age: Optional[int] = None
@ -23,3 +22,29 @@ class Hero(Base, table=True):
if v > 0: if v > 0:
return v return v
return abs(v) return abs(v)
class Hero(HeroBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class HeroCreate(HeroBase):
__table_class__ = Hero
pass
class HeroRead(HeroBase):
__table_class__ = Hero
id: int
class HeroUpdate(Base, table=False):
__table_class__ = Hero
name: Optional[str]
secret_name: Optional[str]
age: Optional[int]
team_id: Optional[int]
if __name__ == "__main__":
Hero.cli()