Compare commits

..

No commits in common. "main" and "0.0.1" have entirely different histories.
main ... 0.0.1

9 changed files with 80 additions and 379 deletions

1
.gitignore vendored
View file

@ -962,4 +962,3 @@ FodyWeavers.xsd
# Additional files built by Visual Studio
# End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode
database.db

BIN
database.db Normal file

Binary file not shown.

View file

@ -24,15 +24,12 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf', 'fastapi', 'uvicorn', 'httpx']
dependencies = [ 'rich', 'sqlmodel', 'typer' ]
[project.urls]
Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme"
Issues = "https://github.com/waylonwalker/sqlmodel-base/issues"
Source = "https://github.com/waylonwalker/sqlmodel-base"
[project.scripts]
sqlmodel-base = "sqlmodel_base.cli:app"
Documentation = "https://github.com/unknown/sqlmodel-base#readme"
Issues = "https://github.com/unknown/sqlmodel-base/issues"
Source = "https://github.com/unknown/sqlmodel-base"
[tool.hatch.version]
path = "sqlmodel_base/__about__.py"

View file

@ -1,21 +1,18 @@
from typing import List, Optional
from typing import Optional
import httpx
import typer
import uvicorn
from fastapi import APIRouter, Depends, FastAPI
from iterfzf import iterfzf
from pydantic import BaseModel
from pydantic_core._pydantic_core import PydanticUndefinedType
from pydantic import BaseModel, validator
from rich.console import Console
from sqlalchemy import func
from sqlmodel import Session, SQLModel, select
from sqlmodel_base.database import get_engine, get_session
from sqlmodel import Field, Session, SQLModel, create_engine, select
console = Console()
def get_session():
with Session(engine) as session:
yield session
class PagedResult(BaseModel):
items: list
total: int
@ -25,134 +22,51 @@ class PagedResult(BaseModel):
class Base(SQLModel):
@classmethod
@property
def engine(self):
engine = get_engine()
return engine
def create(self, session: Optional[Session] = Depends(get_session)):
if isinstance(session, Session):
def create(self):
with Session(engine) as session:
validated = self.model_validate(self)
session.add(self.sqlmodel_update(validated))
session.commit()
session.refresh(self)
return self
else:
response = httpx.post(
"http://localhost:8000/create/", json=self.model_dump_json()
)
breakpoint()
return response
@classmethod
def interactive_create(cls, id: Optional[int] = None):
data = {}
for name, field in cls.__fields__.items():
default = field.default
if (
default is None or isinstance(default, PydanticUndefinedType)
) and not field.is_required():
default = "None"
if (isinstance(default, PydanticUndefinedType)) and field.is_required():
default = None
value = typer.prompt(f"{name}: ", default=default)
if value and value != "" and value != "None":
data[name] = value
item = cls(**data).create()
console.print(item)
def get(cls, id):
with Session(engine) as session:
return session.get(cls, id)
@classmethod
def pick(cls):
all = cls.all()
item = iterfzf([item.model_dump_json() for item in all])
if not item:
console.print("No item selected")
return
return cls.get(cls.parse_raw(item).id)
def get_all(cls):
with Session(engine) as session:
return session.exec(select(cls)).all()
@classmethod
def get(cls, id: int):
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
return session.get(cls.__table_class__, id)
return cls.model_validate(session.get(cls, id))
def get_count(cls):
with Session(engine) as session:
return session.exec(func.count(Hero.id)).scalar()
@classmethod
def get_or_pick(cls, id: Optional[int] = None):
if id is None:
return cls.pick()
return cls.get(id=id)
def get_first(cls):
with Session(engine) as session:
return session.exec(select(cls).limit(1)).first()
@classmethod
def all(cls) -> List:
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
return session.exec(select(cls.__table_class__)).all()
return [cls.model_validate(i) for i in session.exec(select(cls)).all()]
def get_last(cls):
with Session(engine) as session:
return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first()
@classmethod
def count(cls) -> int:
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
return session.exec(func.count(cls.__table_class__.id)).scalar()
return session.exec(func.count(cls.id)).scalar()
def get_random(cls):
with Session(engine) as session:
return session.exec(select(cls).order_by(cls.id).limit(1)).first()
@classmethod
def first(cls):
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
return cls.model_validate(
session.exec(select(table).order_by(table.id.asc()).limit(1)).first()
)
@classmethod
def last(cls):
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
return session.exec(
select(table).order_by(table.id.desc()).limit(1)
).first()
@classmethod
def get_page(
cls,
page: int = 1,
page_size: int = 20,
all: bool = False,
reverse: bool = False,
):
with Session(cls.engine) as session:
if hasattr(cls, "__table_class__"):
table = cls.__table_class__
else:
table = cls
if all:
items = session.exec(select(table)).all()
page_size = len(items)
else:
if reverse:
def get_page(cls, page: int = 1, page_size: int = 20):
with Session(engine) as session:
items = session.exec(
select(table)
.offset((page - 1) * page_size)
.limit(page_size)
.order_by(table.id.desc())
select(cls).offset((page - 1) * page_size).limit(page_size)
).all()
else:
items = session.exec(
select(table)
.offset((page - 1) * page_size)
.limit(page_size)
.order_by(table.id)
).all()
total = table.count()
total = cls.get_count()
# determine if there is a next page
if page * page_size < total:
next_page = page + 1
@ -168,126 +82,72 @@ class Base(SQLModel):
)
def delete(self):
with Session(self.engine) as session:
with Session(engine) as session:
session.delete(self)
session.commit()
return self
def update(self):
with Session(self.engine) as session:
with Session(engine) as session:
validated = self.model_validate(self)
session.add(self.sqlmodel_update(validated))
session.commit()
session.refresh(self)
return self
@classmethod
def interactive_update(cls, id: Optional[int] = None):
item = cls.get_or_pick(id=id)
if not item:
console.print("No item selected")
return
for field in item.__fields__.keys():
if field == "id":
continue
value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None")
if (
value
and value != ""
and value != "None"
and value != getattr(item, field)
):
setattr(item, field, value)
item.update()
console.print(item)
@classmethod
def api(cls):
api = FastAPI(
title="FastAPI",
version="0.1.0",
# docs_url=None,
# redoc_url=None,
# openapi_url=None,
# openapi_tags=tags_metadata,
# dependencies=[Depends(set_user), Depends(set_prefers)],
)
class Hero(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
api.include_router(cls.router())
@validator("age")
def validate_age(cls, v):
if v is None:
return v
if v > 0:
return v
return abs(v)
return api
@classmethod
def router(cls):
router = APIRouter()
# router.add_api_route("/get/", cls.get, methods=["GET"])
# router.add_api_route("/list/", cls.all, methods=["GET"])
# router.add_api_route("/create/", cls.create, methods=["POST"])
# router.add_api_route("/update/", cls.interactive_update, methods=["PUT"])
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
@router.get("/")
def get(id: int) -> cls:
return cls.get(id=id)
engine = create_engine(sqlite_url) # , echo=True)
@router.get("/list", include_in_schema=False)
@router.get("/list/")
def get_page(
page: int = 1,
page_size: int = 20,
all: bool = False,
reverse: bool = False,
) -> PagedResult:
return cls.get_page()
@router.post("/create")
def create(cls: cls) -> cls:
return cls.create()
# replace with alembic commands
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
@router.put("/update")
def update() -> cls:
return cls.update()
return router
def create_heroes():
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson").create()
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create()
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create()
@classmethod
@property
def cli(cls):
app = typer.Typer()
# with Session(engine) as session:
# session.add(hero_1)
# session.add(hero_2)
# session.add(hero_3)
#
# session.commit()
@app.command()
def get(id: int = typer.Option(None, help="Hero ID")):
console.print(cls.get_or_pick(id=id))
@app.command()
def create():
console.print(cls.interactive_create())
def page_heroes():
next_page = 1
while next_page:
page = Hero.get_page(page=next_page, page_size=2)
console.print(page)
next_page = page.next_page
@app.command()
def list(
page: int = typer.Option(1, help="Page number"),
page_size: int = typer.Option(20, help="Page size"),
all: bool = typer.Option(False, help="Show all heroes"),
reverse: bool = typer.Option(False, help="Reverse order"),
):
console.print(
cls.get_page(
page=page,
page_size=page_size,
all=all,
reverse=reverse,
)
)
@app.command()
def api():
cls.run_api()
def main():
create_db_and_tables()
create_heroes()
page_heroes()
@app.command()
def update():
console.print(cls.interactive_update())
return app
@classmethod
def run_api(cls):
uvicorn.run(cls.api(), host="127.0.0.1", port=8000)
if __name__ == "__main__":
main()

View file

@ -1,11 +0,0 @@
import typer
from sqlmodel_base.hero.cli import hero_app
app = typer.Typer()
app.add_typer(hero_app, name="hero")
if __name__ == "__main__":
app()

View file

@ -1,18 +0,0 @@
from functools import lru_cache
from sqlmodel import Session, SQLModel, create_engine
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
@lru_cache
def get_engine():
engine = create_engine(sqlite_url)
SQLModel.metadata.create_all(engine)
return engine
def get_session():
with Session(get_engine()) as session:
yield session

View file

@ -1,65 +0,0 @@
from rich.console import Console
from sqlmodel_base.database import get_engine
from sqlmodel_base.hero.models import Hero
engine = get_engine()
hero_app = Hero.cli
console = Console()
# @hero_app.callback()
# def hero():
# "model cli"
# @hero_app.command()
# def get(id: int = typer.Option(None, help="Hero ID")):
# console.print(Hero.get_or_pick(id=id))
# @hero_app.command()
# def list(
# page: int = typer.Option(1, help="Page number"),
# page_size: int = typer.Option(20, help="Page size"),
# all: bool = typer.Option(False, help="Show all heroes"),
# reverse: bool = typer.Option(False, help="Reverse order"),
# ):
# console.print(
# Hero.get_page(page=page, page_size=page_size, all=all, reverse=reverse)
# )
# @hero_app.command()
# def create(
# name: str = typer.Option(..., help="Hero name", prompt=True),
# secret_name: str = typer.Option(..., help="Hero secret name", prompt=True),
# age: int = typer.Option(None, help="Hero age", prompt=True),
# ):
# hero = Hero(
# name=name,
# secret_name=secret_name,
# age=age,
# ).create()
# console.print(hero)
# @hero_app.command()
# def update(
# id: int = typer.Option(None, help="Hero ID"),
# name: str = typer.Option(None, help="Hero name"),
# secret_name: str = typer.Option(None, help="Hero secret name"),
# age: int = typer.Option(None, help="Hero age"),
# ):
# hero = Hero.interactive_update(id=id)
# console.print(hero)
# @hero_app.command()
# def create_heroes():
# team_1 = Team.get(id=1)
# if not team_1:
# team_1 = Team(name="Team 1", headquarters="Headquarters 1").create()
# for _ in range(50):
# Hero(name="Deadpond", secret_name="Dive Wilson", team_id=team_1.id).create()
# Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create()
# Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create()

View file

@ -1,50 +0,0 @@
from typing import Optional
from pydantic import validator
from rich.console import Console
from sqlmodel import Field
from sqlmodel_base.base import Base
console = Console()
class HeroBase(Base):
name: str
secret_name: str
age: Optional[int] = None
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
@validator("age")
def validate_age(cls, v):
if v is None:
return v
if v > 0:
return v
return abs(v)
class Hero(HeroBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
class HeroCreate(HeroBase):
__table_class__ = Hero
pass
class HeroRead(HeroBase):
__table_class__ = Hero
id: int
class HeroUpdate(Base, table=False):
__table_class__ = Hero
name: Optional[str]
secret_name: Optional[str]
age: Optional[int]
team_id: Optional[int]
if __name__ == "__main__":
Hero.cli()

View file

@ -1,11 +0,0 @@
from typing import Optional
from sqlmodel import Field
from sqlmodel_base.base import Base
class Team(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
headquarters: str