This commit is contained in:
Waylon Walker 2024-03-01 07:10:57 -06:00
parent 85554e2169
commit a21dbb08d4
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
8 changed files with 285 additions and 43 deletions

1
.gitignore vendored
View file

@ -962,3 +962,4 @@ FodyWeavers.xsd
# Additional files built by Visual Studio # Additional files built by Visual Studio
# End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode # End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode
database.db

View file

@ -24,13 +24,16 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: Implementation :: PyPy",
] ]
dependencies = [ 'rich', 'sqlmodel', 'typer' ] dependencies = [ 'rich', 'sqlmodel', 'typer', 'iterfzf']
[project.urls] [project.urls]
Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme" Documentation = "https://github.com/waylonwalker/sqlmodel-base#readme"
Issues = "https://github.com/waylonwalker/sqlmodel-base/issues" Issues = "https://github.com/waylonwalker/sqlmodel-base/issues"
Source = "https://github.com/waylonwalker/sqlmodel-base" Source = "https://github.com/waylonwalker/sqlmodel-base"
[project.scripts]
sqlmodel-base = "sqlmodel_base.cli:app"
[tool.hatch.version] [tool.hatch.version]
path = "sqlmodel_base/__about__.py" path = "sqlmodel_base/__about__.py"

View file

@ -1,18 +1,19 @@
import json
from typing import Optional from typing import Optional
import typer
from iterfzf import iterfzf
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from pydantic_core._pydantic_core import PydanticUndefinedType
from rich.console import Console from rich.console import Console
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import Field, Session, SQLModel, create_engine, select from sqlmodel import Field, Session, SQLModel, create_engine, select
from sqlmodel_base.database import get_engine
console = Console() console = Console()
def get_session():
with Session(engine) as session:
yield session
class PagedResult(BaseModel): class PagedResult(BaseModel):
items: list items: list
total: int total: int
@ -22,8 +23,14 @@ class PagedResult(BaseModel):
class Base(SQLModel): class Base(SQLModel):
@classmethod
@property
def engine(self):
engine = get_engine()
return engine
def create(self): def create(self):
with Session(engine) as session: with Session(self.engine) as session:
validated = self.model_validate(self) validated = self.model_validate(self)
session.add(self.sqlmodel_update(validated)) session.add(self.sqlmodel_update(validated))
session.commit() session.commit()
@ -31,42 +38,99 @@ class Base(SQLModel):
return self return self
@classmethod @classmethod
def get(cls, id): def interactive_create(cls, id: Optional[int] = None):
with Session(engine) as session: data = {}
for name, field in cls.__fields__.items():
default = field.default
if (
default is None or isinstance(default, PydanticUndefinedType)
) and not field.is_required():
default = "None"
if (isinstance(default, PydanticUndefinedType)) and field.is_required():
default = None
value = typer.prompt(f"{name}: ", default=default)
if value and value != "" and value != "None":
data[name] = value
item = cls(**data).create()
console.print(item)
@classmethod
def pick(cls):
all = cls.all()
item = iterfzf([item.model_dump_json() for item in all])
if not item:
console.print("No item selected")
return
return cls.get(cls.parse_raw(item).id)
@classmethod
def get(cls, id: int):
with Session(cls.engine) as session:
return session.get(cls, id) return session.get(cls, id)
@classmethod @classmethod
def get_all(cls): def get_or_pick(cls, id: Optional[int] = None):
with Session(engine) as session: if id is None:
return cls.pick()
return cls.get(id=id)
@classmethod
def all(cls):
with Session(cls.engine) as session:
return session.exec(select(cls)).all() return session.exec(select(cls)).all()
@classmethod @classmethod
def get_count(cls): def count(cls):
with Session(engine) as session: with Session(cls.engine) as session:
return session.exec(func.count(Hero.id)).scalar() return session.exec(func.count(cls.id)).scalar()
@classmethod @classmethod
def get_first(cls): def first(cls):
with Session(engine) as session: with Session(cls.engine) as session:
return session.exec(select(cls).limit(1)).first() return session.exec(select(cls).limit(1)).first()
@classmethod @classmethod
def get_last(cls): def last(cls):
with Session(engine) as session: with Session(cls.engine) as session:
return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first() return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first()
@classmethod @classmethod
def get_random(cls): def random(cls):
with Session(engine) as session: with Session(cls.engine) as session:
return session.exec(select(cls).order_by(cls.id).limit(1)).first() return session.exec(select(cls).order_by(cls.id).limit(1)).first()
@classmethod @classmethod
def get_page(cls, page: int = 1, page_size: int = 20): def get_page(
with Session(engine) as session: cls,
items = session.exec( page: int = 1,
select(cls).offset((page - 1) * page_size).limit(page_size) page_size: int = 20,
).all() all: bool = False,
total = cls.get_count() reverse: bool = False,
):
with Session(cls.engine) as session:
if all:
items = session.exec(select(cls)).all()
page_size = len(items)
else:
if reverse:
items = session.exec(
select(cls)
.offset((page - 1) * page_size)
.limit(page_size)
.order_by(cls.id.desc())
).all()
else:
items = session.exec(
select(cls)
.offset((page - 1) * page_size)
.limit(page_size)
.order_by(cls.id)
).all()
# items = session.exec(
# select(cls).offset((page - 1) * page_size).limit(page_size)
# ).all()
total = cls.count()
# determine if there is a next page # determine if there is a next page
if page * page_size < total: if page * page_size < total:
next_page = page + 1 next_page = page + 1
@ -82,39 +146,71 @@ class Base(SQLModel):
) )
def delete(self): def delete(self):
with Session(engine) as session: with Session(self.engine) as session:
session.delete(self) session.delete(self)
session.commit() session.commit()
return self return self
def update(self): def update(self):
with Session(engine) as session: with Session(self.engine) as session:
validated = self.model_validate(self) validated = self.model_validate(self)
session.add(self.sqlmodel_update(validated)) session.add(self.sqlmodel_update(validated))
session.commit() session.commit()
session.refresh(self) session.refresh(self)
return self return self
@classmethod
def interactive_update(cls, id: Optional[int] = None):
item = cls.get_or_pick(id=id)
if not item:
console.print("No item selected")
return
for field in item.__fields__.keys():
value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None")
if (
value
and value != ""
and value != "None"
and value != getattr(item, field)
):
setattr(item, field, value)
item.update()
console.print(item)
class Hero(Base, table=True): @classmethod
id: Optional[int] = Field(default=None, primary_key=True) @property
name: str def cli(cls):
secret_name: str app = typer.Typer()
age: Optional[int] = None
@validator("age") @app.command()
def validate_age(cls, v): def get(id: int = typer.Option(None, help="Hero ID")):
if v is None: console.print(cls.get_or_pick(id=id))
return v
if v > 0:
return v
return abs(v)
@app.command()
def create():
console.print(cls.interactive_create())
sqlite_file_name = "database.db" @app.command()
sqlite_url = f"sqlite:///{sqlite_file_name}" def list(
page: int = typer.Option(1, help="Page number"),
page_size: int = typer.Option(20, help="Page size"),
all: bool = typer.Option(False, help="Show all heroes"),
reverse: bool = typer.Option(False, help="Reverse order"),
):
console.print(
cls.get_page(
page=page,
page_size=page_size,
all=all,
reverse=reverse,
)
)
engine = create_engine(sqlite_url) # , echo=True) @app.command()
def update():
console.print(cls.interactive_update())
return app
# replace with alembic commands # replace with alembic commands

11
sqlmodel_base/cli.py Normal file
View file

@ -0,0 +1,11 @@
import typer
from sqlmodel_base.hero.cli import hero_app
app = typer.Typer()
app.add_typer(hero_app, name="hero")
if __name__ == "__main__":
app()

21
sqlmodel_base/database.py Normal file
View file

@ -0,0 +1,21 @@
from functools import lru_cache
from sqlmodel import Field, Session, SQLModel, create_engine, select
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
@lru_cache
def get_engine():
from sqlmodel_base.hero.models import Hero
from sqlmodel_base.team.models import Team
engine = create_engine(sqlite_url)
SQLModel.metadata.create_all(engine)
return engine
def get_session():
with Session(get_engine()) as session:
yield session

72
sqlmodel_base/hero/cli.py Normal file
View file

@ -0,0 +1,72 @@
import json
import typer
from iterfzf import iterfzf
from rich.console import Console
from sqlmodel_base.database import get_engine
from sqlmodel_base.hero.models import Hero
from sqlmodel_base.team.models import Team
engine = get_engine()
hero_app = Hero.cli
console = Console()
# @hero_app.callback()
# def hero():
# "model cli"
# @hero_app.command()
# def get(id: int = typer.Option(None, help="Hero ID")):
# console.print(Hero.get_or_pick(id=id))
# @hero_app.command()
# def list(
# page: int = typer.Option(1, help="Page number"),
# page_size: int = typer.Option(20, help="Page size"),
# all: bool = typer.Option(False, help="Show all heroes"),
# reverse: bool = typer.Option(False, help="Reverse order"),
# ):
# console.print(
# Hero.get_page(page=page, page_size=page_size, all=all, reverse=reverse)
# )
# @hero_app.command()
# def create(
# name: str = typer.Option(..., help="Hero name", prompt=True),
# secret_name: str = typer.Option(..., help="Hero secret name", prompt=True),
# age: int = typer.Option(None, help="Hero age", prompt=True),
# ):
# hero = Hero(
# name=name,
# secret_name=secret_name,
# age=age,
# ).create()
# console.print(hero)
# @hero_app.command()
# def update(
# id: int = typer.Option(None, help="Hero ID"),
# name: str = typer.Option(None, help="Hero name"),
# secret_name: str = typer.Option(None, help="Hero secret name"),
# age: int = typer.Option(None, help="Hero age"),
# ):
# hero = Hero.interactive_update(id=id)
# console.print(hero)
# @hero_app.command()
# def create_heroes():
# team_1 = Team.get(id=1)
# if not team_1:
# team_1 = Team(name="Team 1", headquarters="Headquarters 1").create()
# for _ in range(50):
# Hero(name="Deadpond", secret_name="Dive Wilson", team_id=team_1.id).create()
# Hero(name="Spider-Boy", secret_name="Pedro Parqueador").create()
# Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48).create()

View file

@ -0,0 +1,26 @@
from typing import Optional
from pydantic import BaseModel, validator
from rich.console import Console
from sqlalchemy import func
from sqlmodel import Field, Session, SQLModel, create_engine, select
from sqlmodel_base.base import Base
console = Console()
class Hero(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
@validator("age")
def validate_age(cls, v):
if v is None:
return v
if v > 0:
return v
return abs(v)

View file

@ -0,0 +1,12 @@
from typing import Optional
from rich.console import Console
from sqlmodel import Field
from sqlmodel_base.base import Base
class Team(Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
headquarters: str