diff --git a/learn_sql_model/cli/hero.py b/learn_sql_model/cli/hero.py index f9c0243..b019494 100644 --- a/learn_sql_model/cli/hero.py +++ b/learn_sql_model/cli/hero.py @@ -2,10 +2,9 @@ from typing import List, Union from pydantic_typer import expand_pydantic_args from rich.console import Console -from sqlmodel import SQLModel import typer -from learn_sql_model.config import Config, get_config +from learn_sql_model.config import Config from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.pet import PetFactory from learn_sql_model.models.hero import Hero @@ -17,12 +16,16 @@ hero_app = typer.Typer() @hero_app.callback() def hero(): "model cli" - SQLModel.metadata.create_all(get_config().database.engine) @hero_app.command() -def get(id: int = None) -> Union[Hero, List[Hero]]: +@expand_pydantic_args(typer=True) +def get( + id: int = None, + config: Config = None, +) -> Union[Hero, List[Hero]]: "get one hero" + config.init() hero = Hero().get(id=id) Console().print(hero) return hero @@ -36,6 +39,7 @@ def create( config: Config = None, ) -> Hero: "read all the heros" + config.init() hero.pet = pet hero = hero.post(config=config) Console().print(hero) @@ -48,6 +52,7 @@ def populate( config: Config = None, ) -> Hero: "read all the heros" + config.init() if config is None: config = Config() if config.env == "prod": diff --git a/learn_sql_model/config.py b/learn_sql_model/config.py index 27aabff..7b5c47b 100644 --- a/learn_sql_model/config.py +++ b/learn_sql_model/config.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from pydantic import BaseModel, BaseSettings from sqlalchemy import create_engine -from sqlmodel import Session +from sqlmodel import SQLModel, Session from learn_sql_model.standard_config import load @@ -24,7 +24,6 @@ class Database: self.config = get_config() else: self.config = config - self.create_db_and_tables() @property def engine(self) -> "Engine": @@ -49,6 +48,9 @@ class Config(BaseSettings): def database(self) -> Database: return get_database(config=self) + def init(self) -> None: + SQLModel.metadata.create_all(self.database.engine) + def get_database(config: Config = None) -> Database: diff --git a/tests/test_hero.py b/tests/test_hero.py index e36b28c..976e8be 100644 --- a/tests/test_hero.py +++ b/tests/test_hero.py @@ -1,12 +1,17 @@ import tempfile +from fastapi.testclient import TestClient import pytest +from typer.testing import CliRunner +from learn_sql_model.api.app import app +from learn_sql_model.cli.hero import hero_app from learn_sql_model.config import Config, get_config from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.models.hero import Hero -Hero +runner = CliRunner() +client = TestClient(app) @pytest.fixture @@ -17,18 +22,57 @@ def config() -> Config: def test_post_hero(config: Config) -> None: + config.init() # required for python api, and no existing db hero = HeroFactory().build(name="Batman", age=50, id=1) hero = hero.post(config=config) - db_hero = Hero().get(hero.id, config=config) - assert db_hero == hero + db_hero = Hero().get(id=1, config=config) + assert db_hero.age == 50 + assert db_hero.name == "Batman" def test_update_hero(config: Config) -> None: + config.init() # required for python api, and no existing db hero = HeroFactory().build(name="Batman", age=50, id=1) hero = hero.post(config=config) - db_hero = Hero().get(id=hero.id, config=config) - assert db_hero.dict() == hero.dict() + db_hero = Hero().get(id=1, config=config) db_hero.name = "Superman" hero = db_hero.post(config=config) - db_hero = Hero().get(id=hero.id, config=config) - assert db_hero.dict() == hero.dict() + db_hero = Hero().get(id=1, config=config) + assert db_hero.age == 50 + assert db_hero.name == "Superman" + + +def test_cli_create(config): + result = runner.invoke( + hero_app, + [ + "create", + "--name", + "Darth Vader", + "--secret-name", + "Anakin", + "--id", + "2", + "--age", + "100", + "--database-url", + config.database_url, + ], + ) + assert result.exit_code == 0 + db_hero = Hero().get(id=2, config=config) + assert db_hero.age == 100 + assert db_hero.name == "Darth Vader" + + +def test_read_main(config): + config.init() + hero = HeroFactory().build(name="Ironman", age=25, id=99) + hero_id = hero.id + hero = hero.post(config=config) + response = client.get(f"/hero/{hero_id}") + assert response.status_code == 200 + reponse_hero = Hero.parse_obj(response.json()) + assert reponse_hero.id == hero_id + assert reponse_hero.name == "Ironman" + assert reponse_hero.age == 25