diff --git a/.gitignore b/.gitignore index 9d6f21d..243fb3a 100644 --- a/.gitignore +++ b/.gitignore @@ -963,3 +963,5 @@ FodyWeavers.xsd # End of https://www.toptal.com/developers/gitignore/api/vim,node,data,emacs,python,pycharm,executable,sublimetext,visualstudio,visualstudiocode database.db +database.db +database.db diff --git a/README.md b/README.md index 969cb55..6044f26 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ learn-sql-model hero get --id 0 ## Use python to manage Heros ```python -from learn_sql_model.models import Hero +from learn_sql_model.models.hero import Hero # create a hero bruce = Hero(name="Batman", secret_name="Bruce Wayne") bruce.post() diff --git a/learn_sql_model/api.py b/learn_sql_model/api.py index b110c33..fbedc68 100644 --- a/learn_sql_model/api.py +++ b/learn_sql_model/api.py @@ -1,9 +1,7 @@ from typing import Union from fastapi import FastAPI -import httpx -from learn_sql_model.console import console from learn_sql_model.models.hero import Hero from learn_sql_model.models.pet import Pet @@ -13,53 +11,3 @@ models = Union[Hero, Pet] # from learn_sql_model.models import Hero app = FastAPI() - - -app.post("/heroes/") - - -def post(self: models) -> None: - - try: - httpx.post("http://localhost:5000/heroes/", json=self.dict()) - except httpx.ConnectError: - console.log("local failover") - post_local(self) - - -def post_local(self: models) -> None: - from learn_sql_model.config import config - - with config.session as session: - session.add(self) - session.commit() - - -def get(self: models, instance: models = None) -> list[models]: - "read all the heros" - from learn_sql_model.config import config - - with config.session as session: - if instance is None: - heroes = session.exec(select(self)).all() - return heroes - else: - hero = session.exec(select(self).where(self.id == instance.id)).all().one() - return hero - - -@app.post("/heroes/") -def create_hero(hero: Hero): - post(hero) - - -@app.get("/heroes/") -def read_heroes() -> list[Hero]: - "read all the heros" - return get(Hero) - - -@app.get("/hero/") -def read_heroes(hero: Hero) -> list[Hero]: - "read all the heros" - return get(Hero, hero) diff --git a/learn_sql_model/api/hero.py b/learn_sql_model/api/hero.py index 01fb8bd..3560421 100644 --- a/learn_sql_model/api/hero.py +++ b/learn_sql_model/api/hero.py @@ -28,4 +28,4 @@ def post_hero(hero: Hero) -> Hero: @hero_router.get("/heros/") def get_heros() -> list[Hero]: "get all heros" - return Hero.get() + return Hero().get() diff --git a/learn_sql_model/cli/api.py b/learn_sql_model/cli/api.py index 209bbc6..7c9af84 100644 --- a/learn_sql_model/cli/api.py +++ b/learn_sql_model/cli/api.py @@ -2,7 +2,7 @@ import typer import uvicorn from learn_sql_model.cli.common import verbose_callback -from learn_sql_model.config import config +from learn_sql_model.config import get_config api_app = typer.Typer() @@ -26,4 +26,4 @@ def run( help="show the log messages", ), ): - uvicorn.run("learn_sql_model.api.app:app", port=config.port, log_level="info") + uvicorn.run("learn_sql_model.api.app:app", port=get_config().port, log_level="info") diff --git a/learn_sql_model/cli/config.py b/learn_sql_model/cli/config.py index a42a450..3e6d622 100644 --- a/learn_sql_model/cli/config.py +++ b/learn_sql_model/cli/config.py @@ -2,7 +2,7 @@ from rich.console import Console import typer from learn_sql_model.cli.common import verbose_callback -from learn_sql_model.config import config as configuration +from learn_sql_model.config import get_config config_app = typer.Typer() @@ -26,4 +26,4 @@ def show( help="show the log messages", ), ): - Console().print(configuration) + Console().print(get_config()) diff --git a/learn_sql_model/cli/hero.py b/learn_sql_model/cli/hero.py index e15dfd1..11c2e99 100644 --- a/learn_sql_model/cli/hero.py +++ b/learn_sql_model/cli/hero.py @@ -17,7 +17,7 @@ def hero(): @hero_app.command() def get(id: int = None) -> Union[Hero, List[Hero]]: "get one hero" - hero = Hero.get(item_id=id) + hero = Hero().get(item_id=id) Console().print(hero) return hero diff --git a/learn_sql_model/config.py b/learn_sql_model/config.py index 16b37f5..c1b6a41 100644 --- a/learn_sql_model/config.py +++ b/learn_sql_model/config.py @@ -1,18 +1,40 @@ from typing import TYPE_CHECKING from pydantic import BaseSettings -from sqlmodel import SQLModel, Session, create_engine +from sqlalchemy import create_engine +from sqlmodel import SQLModel, Session -from learn_sql_model.models.hero import Hero -from learn_sql_model.models.pet import Pet from learn_sql_model.standard_config import load -models = [Hero, Pet] - if TYPE_CHECKING: from sqlalchemy import Engine +class Database: + def __init__(self, config: "Config" = None) -> None: + if config is None: + + self.config = get_config() + else: + self.config = config + self.create_db_and_tables() + + @property + def engine(self) -> "Engine": + return create_engine(self.config.database_url) + + def session(self) -> "Session": + return Session(self.engine) + + def create_db_and_tables(self) -> None: + from learn_sql_model.models.hero import Hero + from learn_sql_model.models.pet import Pet + + __all__ = [Hero, Pet] + + SQLModel.metadata.create_all(self.engine) + + class Config(BaseSettings): database_url: str = "sqlite:///database.db" port: int = 5000 @@ -21,27 +43,19 @@ class Config(BaseSettings): env_prefix = "LEARN_SQL_MODEL_" @property - def engine(self) -> "Engine": - return create_engine(self.database_url) + def database(self) -> Database: + return get_database(config=self) - @property - def session(self) -> "Session": - return Session(self.engine) - def create_db_and_tables(self) -> None: - SQLModel.metadata.create_all(self.engine) +def get_database(config: Config = None) -> Database: - # def create_endpoints(self) -> None: - # for model in models: - # app.post("/heroes/")(Hero.post_local) - # app.get("/heroes/")(Hero.read_heroes) + if config is None: + config = get_config() + + return Database(config) def get_config(overrides: dict = {}) -> Config: raw_config = load("learn_sql_model") config = Config(**raw_config, **overrides) - config.create_db_and_tables() return config - - -config = get_config() diff --git a/learn_sql_model/database.py b/learn_sql_model/database.py new file mode 100644 index 0000000..e69de29 diff --git a/learn_sql_model/models.py b/learn_sql_model/models.py new file mode 100644 index 0000000..a919bb7 --- /dev/null +++ b/learn_sql_model/models.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import Optional + +from learn_sql_model.models.fast_model import FastModel +from sqlmodel import Field + + +class Hero(FastModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + # new_attribute: Optional[str] = None + # pets: List["Pet"] = Relationship(back_populates="hero") diff --git a/learn_sql_model/models/fast_model.py b/learn_sql_model/models/fast_model.py index c5cc569..0c40bd5 100644 --- a/learn_sql_model/models/fast_model.py +++ b/learn_sql_model/models/fast_model.py @@ -2,6 +2,8 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import SQLModel, select +from learn_sql_model.config import get_config + if TYPE_CHECKING: from learn_sql_model.config import Config @@ -19,28 +21,36 @@ class FastModel(SQLModel): def post(self, config: "Config" = None) -> None: if config is None: - from learn_sql_model.config import get_config config = get_config() self.pre_post() - with config.session as session: + instance = self.__class__(**self.dict()) + + with config.database.session() as session: session.add(self) session.commit() + return instance - @classmethod - def get( - self, item_id: int = None, config: "Config" = None - ) -> Optional["FastModel"]: + def get(self, id: int = None, config: "Config" = None) -> Optional["FastModel"]: if config is None: - from learn_sql_model.config import get_config config = get_config() self.pre_get() - with config.session as session: - if item_id is None: - return session.exec(select(self)).all() - return session.exec(select(self).where(self.id == item_id)).one() + with config.database.session() as session: + if id is None: + print("get all") + statement = select(self.__class__) + results = session.exec(statement).all() + else: + print("get by id") + statement = select(self.__class__).where(self.__class__.id == id) + results = session.exec(statement).one() + return results + + # TODO + # update + # delete diff --git a/tests/test_hero.py b/tests/test_hero.py index 40b8be4..e36b28c 100644 --- a/tests/test_hero.py +++ b/tests/test_hero.py @@ -1,7 +1,6 @@ import tempfile import pytest -from sqlmodel import Session from learn_sql_model.config import Config, get_config from learn_sql_model.factories.hero import HeroFactory @@ -11,15 +10,25 @@ Hero @pytest.fixture -def config() -> Session: +def config() -> Config: tmp_db = tempfile.NamedTemporaryFile(suffix=".db") config = get_config({"database_url": f"sqlite:///{tmp_db.name}"}) - config.create_db_and_tables() return config def test_post_hero(config: Config) -> None: - hero = HeroFactory().build(name="Batman", age=50) - hero.post(config=config) - assert hero.get(hero.id) == hero - breakpoint() + hero = HeroFactory().build(name="Batman", age=50, id=1) + hero = hero.post(config=config) + db_hero = Hero().get(hero.id, config=config) + assert db_hero == hero + + +def test_update_hero(config: Config) -> None: + hero = HeroFactory().build(name="Batman", age=50, id=1) + hero = hero.post(config=config) + db_hero = Hero().get(id=hero.id, config=config) + assert db_hero.dict() == hero.dict() + db_hero.name = "Superman" + hero = db_hero.post(config=config) + db_hero = Hero().get(id=hero.id, config=config) + assert db_hero.dict() == hero.dict()