wip
This commit is contained in:
parent
cd33982985
commit
9d6d509618
3 changed files with 106 additions and 35 deletions
|
|
@ -24,7 +24,7 @@ classifiers = [
|
||||||
"Programming Language :: Python :: Implementation :: CPython",
|
"Programming Language :: Python :: Implementation :: CPython",
|
||||||
"Programming Language :: Python :: Implementation :: PyPy",
|
"Programming Language :: Python :: Implementation :: PyPy",
|
||||||
]
|
]
|
||||||
dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn']
|
dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx']
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme"
|
Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme"
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import APIRouter, FastAPI
|
from fastapi import APIRouter, Depends, FastAPI
|
||||||
from iterfzf import iterfzf
|
from iterfzf import iterfzf
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic_core._pydantic_core import PydanticUndefinedType
|
from pydantic_core._pydantic_core import PydanticUndefinedType
|
||||||
|
|
@ -10,7 +11,7 @@ from rich.console import Console
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlmodel import Session, SQLModel, select
|
from sqlmodel import Session, SQLModel, select
|
||||||
|
|
||||||
from sqlmodel_base.database import get_engine
|
from sqlmodel_base.database import get_engine, get_session
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -30,13 +31,19 @@ class Base(SQLModel):
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
def create(self):
|
def create(self, session: Optional[Session] = Depends(get_session)):
|
||||||
with Session(self.engine) as session:
|
if isinstance(session, Session):
|
||||||
validated = self.model_validate(self)
|
validated = self.model_validate(self)
|
||||||
session.add(self.sqlmodel_update(validated))
|
session.add(self.sqlmodel_update(validated))
|
||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(self)
|
session.refresh(self)
|
||||||
return self
|
return self
|
||||||
|
else:
|
||||||
|
response = httpx.post(
|
||||||
|
"http://localhost:8000/create/", json=self.model_dump_json()
|
||||||
|
)
|
||||||
|
breakpoint()
|
||||||
|
return response
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def interactive_create(cls, id: Optional[int] = None):
|
def interactive_create(cls, id: Optional[int] = None):
|
||||||
|
|
@ -67,7 +74,9 @@ class Base(SQLModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, id: int):
|
def get(cls, id: int):
|
||||||
with Session(cls.engine) as session:
|
with Session(cls.engine) as session:
|
||||||
return session.get(cls, id)
|
if hasattr(cls, "__table_class__"):
|
||||||
|
return session.get(cls.__table_class__, id)
|
||||||
|
return cls.model_validate(session.get(cls, id))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_or_pick(cls, id: Optional[int] = None):
|
def get_or_pick(cls, id: Optional[int] = None):
|
||||||
|
|
@ -76,29 +85,40 @@ class Base(SQLModel):
|
||||||
return cls.get(id=id)
|
return cls.get(id=id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def all(cls):
|
def all(cls) -> List:
|
||||||
with Session(cls.engine) as session:
|
with Session(cls.engine) as session:
|
||||||
return session.exec(select(cls)).all()
|
if hasattr(cls, "__table_class__"):
|
||||||
|
return session.exec(select(cls.__table_class__)).all()
|
||||||
|
return [cls.model_validate(i) for i in session.exec(select(cls)).all()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def count(cls):
|
def count(cls) -> int:
|
||||||
with Session(cls.engine) as session:
|
with Session(cls.engine) as session:
|
||||||
|
if hasattr(cls, "__table_class__"):
|
||||||
|
return session.exec(func.count(cls.__table_class__.id)).scalar()
|
||||||
return session.exec(func.count(cls.id)).scalar()
|
return session.exec(func.count(cls.id)).scalar()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def first(cls):
|
def first(cls):
|
||||||
with Session(cls.engine) as session:
|
with Session(cls.engine) as session:
|
||||||
return session.exec(select(cls).limit(1)).first()
|
if hasattr(cls, "__table_class__"):
|
||||||
|
table = cls.__table_class__
|
||||||
|
else:
|
||||||
|
table = cls
|
||||||
|
return cls.model_validate(
|
||||||
|
session.exec(select(table).order_by(table.id.asc()).limit(1)).first()
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def last(cls):
|
def last(cls):
|
||||||
with Session(cls.engine) as session:
|
with Session(cls.engine) as session:
|
||||||
return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first()
|
if hasattr(cls, "__table_class__"):
|
||||||
|
table = cls.__table_class__
|
||||||
@classmethod
|
else:
|
||||||
def random(cls):
|
table = cls
|
||||||
with Session(cls.engine) as session:
|
return session.exec(
|
||||||
return session.exec(select(cls).order_by(cls.id).limit(1)).first()
|
select(table).order_by(table.id.desc()).limit(1)
|
||||||
|
).first()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_page(
|
def get_page(
|
||||||
|
|
@ -109,29 +129,30 @@ class Base(SQLModel):
|
||||||
reverse: bool = False,
|
reverse: bool = False,
|
||||||
):
|
):
|
||||||
with Session(cls.engine) as session:
|
with Session(cls.engine) as session:
|
||||||
|
if hasattr(cls, "__table_class__"):
|
||||||
|
table = cls.__table_class__
|
||||||
|
else:
|
||||||
|
table = cls
|
||||||
if all:
|
if all:
|
||||||
items = session.exec(select(cls)).all()
|
items = session.exec(select(table)).all()
|
||||||
page_size = len(items)
|
page_size = len(items)
|
||||||
else:
|
else:
|
||||||
if reverse:
|
if reverse:
|
||||||
items = session.exec(
|
items = session.exec(
|
||||||
select(cls)
|
select(table)
|
||||||
.offset((page - 1) * page_size)
|
.offset((page - 1) * page_size)
|
||||||
.limit(page_size)
|
.limit(page_size)
|
||||||
.order_by(cls.id.desc())
|
.order_by(table.id.desc())
|
||||||
).all()
|
).all()
|
||||||
else:
|
else:
|
||||||
items = session.exec(
|
items = session.exec(
|
||||||
select(cls)
|
select(table)
|
||||||
.offset((page - 1) * page_size)
|
.offset((page - 1) * page_size)
|
||||||
.limit(page_size)
|
.limit(page_size)
|
||||||
.order_by(cls.id)
|
.order_by(table.id)
|
||||||
).all()
|
).all()
|
||||||
# items = session.exec(
|
|
||||||
# select(cls).offset((page - 1) * page_size).limit(page_size)
|
|
||||||
# ).all()
|
|
||||||
|
|
||||||
total = cls.count()
|
total = table.count()
|
||||||
# determine if there is a next page
|
# determine if there is a next page
|
||||||
if page * page_size < total:
|
if page * page_size < total:
|
||||||
next_page = page + 1
|
next_page = page + 1
|
||||||
|
|
@ -167,6 +188,8 @@ class Base(SQLModel):
|
||||||
console.print("No item selected")
|
console.print("No item selected")
|
||||||
return
|
return
|
||||||
for field in item.__fields__.keys():
|
for field in item.__fields__.keys():
|
||||||
|
if field == "id":
|
||||||
|
continue
|
||||||
value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None")
|
value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None")
|
||||||
if (
|
if (
|
||||||
value
|
value
|
||||||
|
|
@ -183,9 +206,9 @@ class Base(SQLModel):
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
title="FastAPI",
|
title="FastAPI",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
docs_url=None,
|
# docs_url=None,
|
||||||
redoc_url=None,
|
# redoc_url=None,
|
||||||
openapi_url=None,
|
# openapi_url=None,
|
||||||
# openapi_tags=tags_metadata,
|
# openapi_tags=tags_metadata,
|
||||||
# dependencies=[Depends(set_user), Depends(set_prefers)],
|
# dependencies=[Depends(set_user), Depends(set_prefers)],
|
||||||
)
|
)
|
||||||
|
|
@ -197,10 +220,33 @@ class Base(SQLModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def router(cls):
|
def router(cls):
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
router.add_api_route("/get/", cls.get, methods=["GET"])
|
# router.add_api_route("/get/", cls.get, methods=["GET"])
|
||||||
router.add_api_route("/list/", cls.all, methods=["GET"])
|
# router.add_api_route("/list/", cls.all, methods=["GET"])
|
||||||
router.add_api_route("/create/", cls.interactive_create, methods=["POST"])
|
# router.add_api_route("/create/", cls.create, methods=["POST"])
|
||||||
router.add_api_route("/update/", cls.interactive_update, methods=["PUT"])
|
# router.add_api_route("/update/", cls.interactive_update, methods=["PUT"])
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
def get(id: int) -> cls:
|
||||||
|
return cls.get(id=id)
|
||||||
|
|
||||||
|
@router.get("/list", include_in_schema=False)
|
||||||
|
@router.get("/list/")
|
||||||
|
def get_page(
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
all: bool = False,
|
||||||
|
reverse: bool = False,
|
||||||
|
) -> PagedResult:
|
||||||
|
return cls.get_page()
|
||||||
|
|
||||||
|
@router.post("/create")
|
||||||
|
def create(cls: cls) -> cls:
|
||||||
|
return cls.create()
|
||||||
|
|
||||||
|
@router.put("/update")
|
||||||
|
def update() -> cls:
|
||||||
|
return cls.update()
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,7 @@ from sqlmodel_base.base import Base
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
class Hero(Base, table=True):
|
class HeroBase(Base):
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
name: str
|
name: str
|
||||||
secret_name: str
|
secret_name: str
|
||||||
age: Optional[int] = None
|
age: Optional[int] = None
|
||||||
|
|
@ -23,3 +22,29 @@ class Hero(Base, table=True):
|
||||||
if v > 0:
|
if v > 0:
|
||||||
return v
|
return v
|
||||||
return abs(v)
|
return abs(v)
|
||||||
|
|
||||||
|
|
||||||
|
class Hero(HeroBase, table=True):
|
||||||
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
|
||||||
|
|
||||||
|
class HeroCreate(HeroBase):
|
||||||
|
__table_class__ = Hero
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HeroRead(HeroBase):
|
||||||
|
__table_class__ = Hero
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class HeroUpdate(Base, table=False):
|
||||||
|
__table_class__ = Hero
|
||||||
|
name: Optional[str]
|
||||||
|
secret_name: Optional[str]
|
||||||
|
age: Optional[int]
|
||||||
|
team_id: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Hero.cli()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue