add cli
This commit is contained in:
parent
85554e2169
commit
a21dbb08d4
8 changed files with 285 additions and 43 deletions
|
|
@ -1,18 +1,19 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from iterfzf import iterfzf
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic_core._pydantic_core import PydanticUndefinedType
|
||||
from rich.console import Console
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
from sqlmodel_base.database import get_engine
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def get_session():
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
class PagedResult(BaseModel):
|
||||
items: list
|
||||
total: int
|
||||
|
|
@ -22,8 +23,14 @@ class PagedResult(BaseModel):
|
|||
|
||||
|
||||
class Base(SQLModel):
|
||||
@classmethod
|
||||
@property
|
||||
def engine(self):
|
||||
engine = get_engine()
|
||||
return engine
|
||||
|
||||
def create(self):
|
||||
with Session(engine) as session:
|
||||
with Session(self.engine) as session:
|
||||
validated = self.model_validate(self)
|
||||
session.add(self.sqlmodel_update(validated))
|
||||
session.commit()
|
||||
|
|
@ -31,42 +38,99 @@ class Base(SQLModel):
|
|||
return self
|
||||
|
||||
@classmethod
|
||||
def get(cls, id):
|
||||
with Session(engine) as session:
|
||||
def interactive_create(cls, id: Optional[int] = None):
|
||||
data = {}
|
||||
for name, field in cls.__fields__.items():
|
||||
default = field.default
|
||||
if (
|
||||
default is None or isinstance(default, PydanticUndefinedType)
|
||||
) and not field.is_required():
|
||||
default = "None"
|
||||
if (isinstance(default, PydanticUndefinedType)) and field.is_required():
|
||||
default = None
|
||||
value = typer.prompt(f"{name}: ", default=default)
|
||||
if value and value != "" and value != "None":
|
||||
data[name] = value
|
||||
item = cls(**data).create()
|
||||
console.print(item)
|
||||
|
||||
@classmethod
|
||||
def pick(cls):
|
||||
all = cls.all()
|
||||
item = iterfzf([item.model_dump_json() for item in all])
|
||||
if not item:
|
||||
console.print("No item selected")
|
||||
return
|
||||
return cls.get(cls.parse_raw(item).id)
|
||||
|
||||
@classmethod
|
||||
def get(cls, id: int):
|
||||
with Session(cls.engine) as session:
|
||||
return session.get(cls, id)
|
||||
|
||||
@classmethod
|
||||
def get_all(cls):
|
||||
with Session(engine) as session:
|
||||
def get_or_pick(cls, id: Optional[int] = None):
|
||||
if id is None:
|
||||
return cls.pick()
|
||||
return cls.get(id=id)
|
||||
|
||||
@classmethod
|
||||
def all(cls):
|
||||
with Session(cls.engine) as session:
|
||||
return session.exec(select(cls)).all()
|
||||
|
||||
@classmethod
|
||||
def get_count(cls):
|
||||
with Session(engine) as session:
|
||||
return session.exec(func.count(Hero.id)).scalar()
|
||||
def count(cls):
|
||||
with Session(cls.engine) as session:
|
||||
return session.exec(func.count(cls.id)).scalar()
|
||||
|
||||
@classmethod
|
||||
def get_first(cls):
|
||||
with Session(engine) as session:
|
||||
def first(cls):
|
||||
with Session(cls.engine) as session:
|
||||
return session.exec(select(cls).limit(1)).first()
|
||||
|
||||
@classmethod
|
||||
def get_last(cls):
|
||||
with Session(engine) as session:
|
||||
def last(cls):
|
||||
with Session(cls.engine) as session:
|
||||
return session.exec(select(cls).order_by(cls.id.desc()).limit(1)).first()
|
||||
|
||||
@classmethod
|
||||
def get_random(cls):
|
||||
with Session(engine) as session:
|
||||
def random(cls):
|
||||
with Session(cls.engine) as session:
|
||||
return session.exec(select(cls).order_by(cls.id).limit(1)).first()
|
||||
|
||||
@classmethod
|
||||
def get_page(cls, page: int = 1, page_size: int = 20):
|
||||
with Session(engine) as session:
|
||||
items = session.exec(
|
||||
select(cls).offset((page - 1) * page_size).limit(page_size)
|
||||
).all()
|
||||
total = cls.get_count()
|
||||
def get_page(
|
||||
cls,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
all: bool = False,
|
||||
reverse: bool = False,
|
||||
):
|
||||
with Session(cls.engine) as session:
|
||||
if all:
|
||||
items = session.exec(select(cls)).all()
|
||||
page_size = len(items)
|
||||
else:
|
||||
if reverse:
|
||||
items = session.exec(
|
||||
select(cls)
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
.order_by(cls.id.desc())
|
||||
).all()
|
||||
else:
|
||||
items = session.exec(
|
||||
select(cls)
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
.order_by(cls.id)
|
||||
).all()
|
||||
# items = session.exec(
|
||||
# select(cls).offset((page - 1) * page_size).limit(page_size)
|
||||
# ).all()
|
||||
|
||||
total = cls.count()
|
||||
# determine if there is a next page
|
||||
if page * page_size < total:
|
||||
next_page = page + 1
|
||||
|
|
@ -82,39 +146,71 @@ class Base(SQLModel):
|
|||
)
|
||||
|
||||
def delete(self):
|
||||
with Session(engine) as session:
|
||||
with Session(self.engine) as session:
|
||||
session.delete(self)
|
||||
session.commit()
|
||||
return self
|
||||
|
||||
def update(self):
|
||||
with Session(engine) as session:
|
||||
with Session(self.engine) as session:
|
||||
validated = self.model_validate(self)
|
||||
session.add(self.sqlmodel_update(validated))
|
||||
session.commit()
|
||||
session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def interactive_update(cls, id: Optional[int] = None):
|
||||
item = cls.get_or_pick(id=id)
|
||||
if not item:
|
||||
console.print("No item selected")
|
||||
return
|
||||
for field in item.__fields__.keys():
|
||||
value = typer.prompt(f"{field}: ", default=getattr(item, field) or "None")
|
||||
if (
|
||||
value
|
||||
and value != ""
|
||||
and value != "None"
|
||||
and value != getattr(item, field)
|
||||
):
|
||||
setattr(item, field, value)
|
||||
item.update()
|
||||
console.print(item)
|
||||
|
||||
class Hero(Base, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
name: str
|
||||
secret_name: str
|
||||
age: Optional[int] = None
|
||||
@classmethod
|
||||
@property
|
||||
def cli(cls):
|
||||
app = typer.Typer()
|
||||
|
||||
@validator("age")
|
||||
def validate_age(cls, v):
|
||||
if v is None:
|
||||
return v
|
||||
if v > 0:
|
||||
return v
|
||||
return abs(v)
|
||||
@app.command()
|
||||
def get(id: int = typer.Option(None, help="Hero ID")):
|
||||
console.print(cls.get_or_pick(id=id))
|
||||
|
||||
@app.command()
|
||||
def create():
|
||||
console.print(cls.interactive_create())
|
||||
|
||||
sqlite_file_name = "database.db"
|
||||
sqlite_url = f"sqlite:///{sqlite_file_name}"
|
||||
@app.command()
|
||||
def list(
|
||||
page: int = typer.Option(1, help="Page number"),
|
||||
page_size: int = typer.Option(20, help="Page size"),
|
||||
all: bool = typer.Option(False, help="Show all heroes"),
|
||||
reverse: bool = typer.Option(False, help="Reverse order"),
|
||||
):
|
||||
console.print(
|
||||
cls.get_page(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
all=all,
|
||||
reverse=reverse,
|
||||
)
|
||||
)
|
||||
|
||||
engine = create_engine(sqlite_url) # , echo=True)
|
||||
@app.command()
|
||||
def update():
|
||||
console.print(cls.interactive_update())
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# replace with alembic commands
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue