This commit is contained in:
Waylon Walker 2023-06-09 16:04:58 -05:00
parent 1a0bf1adb9
commit c3db85a209
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
21 changed files with 647 additions and 658 deletions

View file

@ -1,15 +1,12 @@
import tempfile
from fastapi.testclient import TestClient
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from sqlmodel import SQLModel, Session, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from learn_sql_model.api.app import app
from learn_sql_model.cli.hero import hero_app
from learn_sql_model.config import Config, get_config
from learn_sql_model.config import get_config, get_session
from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.models.hero import Hero
@ -17,142 +14,193 @@ runner = CliRunner()
client = TestClient(app)
@pytest.fixture
def config() -> Config:
tmp_db = tempfile.NamedTemporaryFile(suffix=".db")
config = get_config({"database_url": f"sqlite:///{tmp_db.name}"})
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
config.database_url, connect_args={"check_same_thread": False}
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# breakpoint()
SQLModel.metadata.create_all(config.database.engine)
# def override_get_db():
# try:
# db = TestingSessionLocal()
# yield db
# finally:
# db.close()
def override_get_config():
return config
app.dependency_overrides[get_config] = override_get_config
yield config
# tmp_db automatically deletes here
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
def test_post_hero(config: Config) -> None:
hero = HeroFactory().build(name="Batman", age=50, id=1)
hero = hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
assert db_hero.age == 50
assert db_hero.name == "Batman"
@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session
app.dependency_overrides[get_session] = get_session_override
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_update_hero(config: Config) -> None:
hero = HeroFactory().build(name="Batman", age=50, id=1)
hero = hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
db_hero.name = "Superbman"
hero = db_hero.post(config=config)
db_hero = Hero().get(id=1, config=config)
assert db_hero.age == 50
assert db_hero.name == "Superbman"
def test_cli_get(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
hero.post(config=config)
result = runner.invoke(
hero_app,
["get", "--id", 99, "--database-url", config.database_url],
)
assert result.exit_code == 0
db_hero = Hero().get(id=99, config=config)
assert db_hero.age == 25
assert db_hero.name == "Steelman"
def test_cli_create(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
result = runner.invoke(
hero_app,
[
"create",
*hero.flags(config=config),
"--database-url",
config.database_url,
],
)
assert result.exit_code == 0
db_hero = Hero().get(id=99, config=config)
assert db_hero.age == 25
assert db_hero.name == "Steelman"
def test_cli_populate(config):
result = runner.invoke(
hero_app,
[
"populate",
"--n",
10,
"--database-url",
config.database_url,
],
)
assert result.exit_code == 0
db_hero = Hero().get(config=config)
assert len(db_hero) == 10
def test_cli_populate_fails_prod(config):
result = runner.invoke(
hero_app,
["populate", "--n", 10, "--database-url", config.database_url, "--env", "prod"],
)
assert result.exit_code == 1
assert result.output.strip() == "populate is not supported in production"
def test_api_read(config):
hero = HeroFactory().build(name="Steelman", age=25, id=99)
hero_id = hero.id
hero = hero.post(config=config)
response = client.get(f"/hero/{hero_id}")
assert response.status_code == 200
reponse_hero = Hero.parse_obj(response.json())
assert reponse_hero.id == hero_id
assert reponse_hero.name == "Steelman"
assert reponse_hero.age == 25
def test_api_post(config):
def test_api_post(client: TestClient):
hero = HeroFactory().build(name="Steelman", age=25)
hero_dict = hero.dict()
response = client.post("/hero/", json={"hero": hero_dict})
assert response.status_code == 200
response_hero = Hero.parse_obj(response.json())
db_hero = Hero().get(id=response_hero.id, config=config)
assert db_hero.name == "Steelman"
assert db_hero.age == 25
def test_api_read_all(config):
hero = HeroFactory().build(name="Mothman", age=25, id=99)
hero_id = hero.id
hero = hero.post(config=config)
response = client.get("/heros/")
assert response.status_code == 200
heros = response.json()
response_hero_json = [hero for hero in heros if hero["id"] == hero_id][0]
response_hero = Hero.parse_obj(response_hero_json)
assert response_hero.id == hero_id
assert response_hero.name == "Mothman"
assert response_hero.name == "Steelman"
assert response_hero.age == 25
def test_api_read_heroes(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_2 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
session.add(hero_1)
session.add(hero_2)
session.commit()
response = client.get("/heros/")
data = response.json()
assert response.status_code == 200
assert len(data) == 2
assert data[0]["name"] == hero_1.name
assert data[0]["secret_name"] == hero_1.secret_name
assert data[0]["age"] == hero_1.age
assert data[0]["id"] == hero_1.id
assert data[1]["name"] == hero_2.name
assert data[1]["secret_name"] == hero_2.secret_name
assert data[1]["age"] == hero_2.age
assert data[1]["id"] == hero_2.id
def test_api_read_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.get(f"/hero/999")
assert response.status_code == 404
def test_api_read_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.get(f"/hero/{hero_1.id}")
data = response.json()
assert response.status_code == 200
assert data["name"] == hero_1.name
assert data["secret_name"] == hero_1.secret_name
assert data["age"] == hero_1.age
assert data["id"] == hero_1.id
def test_api_update_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.patch(
f"/hero/", json={"hero": {"name": "Deadpuddle", "id": hero_1.id}}
)
data = response.json()
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson"
assert data["age"] is None
assert data["id"] == hero_1.id
def test_api_update_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.patch(f"/hero/", json={"hero": {"name": "Deadpuddle", "id": 999}})
assert response.status_code == 404
def test_delete_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.delete(f"/hero/{hero_1.id}")
hero_in_db = session.get(Hero, hero_1.id)
assert response.status_code == 200
assert hero_in_db is None
def test_delete_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
session.add(hero_1)
session.commit()
response = client.delete(f"/hero/999")
assert response.status_code == 404
def test_config_memory(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
heroes = session.exec(select(Hero)).all()
assert hero.name == "Steelman"
assert hero.age == 25
assert len(heroes) == 1
def test_cli_get(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "1"])
assert result.exit_code == 0
assert f"name='{hero.name}'" in result.stdout
assert f"secret_name='{hero.secret_name}'" in result.stdout
def test_cli_get_404(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "Hero not found"