This commit is contained in:
Waylon Walker 2025-11-22 21:59:30 -06:00
parent cd33982985
commit 9d6d509618
3 changed files with 106 additions and 35 deletions

View file

@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn']
dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx']
[project.urls]
Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme"

View file

@ -1,8 +1,9 @@
from typing import Optional
from typing import List, Optional
import httpx
import typer
import uvicorn
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter, Depends, FastAPI
from iterfzf import iterfzf
from pydantic import BaseModel
from pydantic_core._pydantic_core import PydanticUndefinedType
@ -10,7 +11,7 @@ from rich.console import Console
from sqlalchemy import func
from sqlmodel import Session, SQLModel, select
from sqlmodel_base.database import get_engine
from sqlmodel_base.database import get_engine, get_session
console = Console()
@ -30,13 +31,19 @@ class Base(SQLModel):
engine = get_engine()
return engine
def create(self):
with Session(self.engine) as session:
def create(self, session: Optional[Session] = Depends(get_session)):
if isinstance(session, Session):
validated = self.model_validate(self)
session.add(self.sqlmodel_update(validated))
session.commit()
session.refresh(self)
return self
else:
response = httpx.post(
"http://localhost:8000/create/", json=self.model_dump_json()
)
breakpoint()
return response
@classmethod
def interactive_create(cls, id: Optional[int] = None):
@ -67,7 +74,9 @@ class Base(SQLModel):
@classmethod
def get(cls, id: int):
with Session(cls.engine) as session:
return session.get(cls, id)
if hasattr(cls, "__table_class__"):
return session.get(cls.__table_class__, id)
return cls.model_validate(session.get(cls, id))
@classmethod
def get_or_pick(cls, id: Optional[int] = None):
@ -76,29 +85,40 @@ class Base(SQLModel):
return cls.get(id=id)
@classmethod
def all(cls):
def all(cls) -> List:
with Session(cls.engine) as session:
return session.exec(select(cls)).all()
if hasattr(cls, "__table_class__"):
return session.exec(select(cls.__table_class__)).all()
return [cls.model_validate(i) for i in session.exec(select(cls)).all()]
@classmethod
def count(cls):
def count(cls) -> int:
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
return session.exec(func.count(cls.__table_class__.id)).scalar()
return session.exec(func.count(cls.id)).scalar()
@classmethod
def first(cls):
with Session(cls.engine) as session:
return session.exec(select(cls).limit(1)).first()
if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
return cls.model_validate(
session.exec(select(table).order_by(table.id.asc()).limit(1)).first()
)
@classmethod
def last(cls):
with Session(cls.engine) as session:
return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first()
@classmethod
def random(cls):
with Session(cls.engine) as session:
return session.exec(select(cls).order_by(cls.id).limit(1)).first()
if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
return session.exec(
select(table).order_by(table.id.desc()).limit(1)
).first()
@classmethod
def get_page(
@ -109,29 +129,30 @@ class Base(SQLModel):
reverse: bool = False,
):
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
if all:
items = session.exec(select(cls)).all()
items = session.exec(select(table)).all()
page_size = len(items)
else:
if reverse:
items = session.exec(
select(cls)
select(table)
.offset((page - 1) * page_size)
.limit(page_size)
.order_by(cls.id.desc())
.order_by(table.id.desc())
).all()
else:
items = session.exec(
select(cls)
select(table)
.offset((page - 1) * page_size)
.limit(page_size)
.order_by(cls.id)
.order_by(table.id)
).all()
# items = session.exec(
# select(cls).offset((page - 1) * page_size).limit(page_size)
# ).all()
total = cls.count()
total = table.count()
# determine if there is a next page
if page * page_size < total:
next_page = page + 1
@ -167,6 +188,8 @@ class Base(SQLModel):
console.print("No item selected")
return
for field in item.__fields__.keys():
if field == "id":
continue
value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None")
if (
value
@ -183,9 +206,9 @@ class Base(SQLModel):
api = FastAPI(
title="FastAPI",
version="0.1.0",
docs_url=None,
redoc_url=None,
openapi_url=None,
# docs_url=None,
# redoc_url=None,
# openapi_url=None,
# openapi_tags=tags_metadata,
# dependencies=[Depends(set_user), Depends(set_prefers)],
)
@ -197,10 +220,33 @@ class Base(SQLModel):
@classmethod
def router(cls):
router = APIRouter()
router.add_api_route("/get/", cls.get, methods=["GET"])
router.add_api_route("/list/", cls.all, methods=["GET"])
router.add_api_route("/create/", cls.interactive_create, methods=["POST"])
router.add_api_route("/update/", cls.interactive_update, methods=["PUT"])
# router.add_api_route("/get/", cls.get, methods=["GET"])
# router.add_api_route("/list/", cls.all, methods=["GET"])
# router.add_api_route("/create/", cls.create, methods=["POST"])
# router.add_api_route("/update/", cls.interactive_update, methods=["PUT"])
@router.get("/")
def get(id: int) -> cls:
return cls.get(id=id)
@router.get("/list", include_in_schema=False)
@router.get("/list/")
def get_page(
page: int = 1,
page_size: int = 20,
all: bool = False,
reverse: bool = False,
) -> PagedResult:
return cls.get_page()
@router.post("/create")
def create(cls: cls) -> cls:
return cls.create()
@router.put("/update")
def update() -> cls:
return cls.update()
return router
@classmethod

View file

@ -9,8 +9,7 @@ from sqlmodel_base.base import Base
console = Console()
class Hero(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class HeroBase(Base):
name: str
secret_name: str
age: Optional[int] = None
@ -23,3 +22,29 @@ class Hero(Base, table=True):
if v > 0:
return v
return abs(v)
class Hero(HeroBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class HeroCreate(HeroBase):
__table_class__ = Hero
pass
class HeroRead(HeroBase):
__table_class__ = Hero
id: int
class HeroUpdate(Base, table=False):
__table_class__ = Hero
name: Optional[str]
secret_name: Optional[str]
age: Optional[int]
team_id: Optional[int]
if __name__ == "__main__":
Hero.cli()