wip
This commit is contained in:
parent
1a0bf1adb9
commit
c3db85a209
21 changed files with 647 additions and 658 deletions
|
|
@ -1,15 +1,12 @@
|
|||
import tempfile
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel import SQLModel, Session, select
|
||||
from sqlmodel.pool import StaticPool
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from learn_sql_model.api.app import app
|
||||
from learn_sql_model.cli.hero import hero_app
|
||||
from learn_sql_model.config import Config, get_config
|
||||
from learn_sql_model.config import get_config, get_session
|
||||
from learn_sql_model.factories.hero import HeroFactory
|
||||
from learn_sql_model.models.hero import Hero
|
||||
|
||||
|
|
@ -17,142 +14,193 @@ runner = CliRunner()
|
|||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config() -> Config:
|
||||
tmp_db = tempfile.NamedTemporaryFile(suffix=".db")
|
||||
config = get_config({"database_url": f"sqlite:///{tmp_db.name}"})
|
||||
|
||||
@pytest.fixture(name="session")
|
||||
def session_fixture():
|
||||
engine = create_engine(
|
||||
config.database_url, connect_args={"check_same_thread": False}
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# breakpoint()
|
||||
SQLModel.metadata.create_all(config.database.engine)
|
||||
|
||||
# def override_get_db():
|
||||
# try:
|
||||
# db = TestingSessionLocal()
|
||||
# yield db
|
||||
# finally:
|
||||
# db.close()
|
||||
def override_get_config():
|
||||
return config
|
||||
|
||||
app.dependency_overrides[get_config] = override_get_config
|
||||
|
||||
yield config
|
||||
# tmp_db automatically deletes here
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def test_post_hero(config: Config) -> None:
|
||||
hero = HeroFactory().build(name="Batman", age=50, id=1)
|
||||
hero = hero.post(config=config)
|
||||
db_hero = Hero().get(id=1, config=config)
|
||||
assert db_hero.age == 50
|
||||
assert db_hero.name == "Batman"
|
||||
@pytest.fixture(name="client")
|
||||
def client_fixture(session: Session):
|
||||
def get_session_override():
|
||||
return session
|
||||
|
||||
app.dependency_overrides[get_session] = get_session_override
|
||||
|
||||
client = TestClient(app)
|
||||
yield client
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_update_hero(config: Config) -> None:
|
||||
hero = HeroFactory().build(name="Batman", age=50, id=1)
|
||||
hero = hero.post(config=config)
|
||||
db_hero = Hero().get(id=1, config=config)
|
||||
db_hero.name = "Superbman"
|
||||
hero = db_hero.post(config=config)
|
||||
db_hero = Hero().get(id=1, config=config)
|
||||
assert db_hero.age == 50
|
||||
assert db_hero.name == "Superbman"
|
||||
|
||||
|
||||
def test_cli_get(config):
|
||||
hero = HeroFactory().build(name="Steelman", age=25, id=99)
|
||||
hero.post(config=config)
|
||||
result = runner.invoke(
|
||||
hero_app,
|
||||
["get", "--id", 99, "--database-url", config.database_url],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
db_hero = Hero().get(id=99, config=config)
|
||||
assert db_hero.age == 25
|
||||
assert db_hero.name == "Steelman"
|
||||
|
||||
|
||||
def test_cli_create(config):
|
||||
hero = HeroFactory().build(name="Steelman", age=25, id=99)
|
||||
result = runner.invoke(
|
||||
hero_app,
|
||||
[
|
||||
"create",
|
||||
*hero.flags(config=config),
|
||||
"--database-url",
|
||||
config.database_url,
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
db_hero = Hero().get(id=99, config=config)
|
||||
assert db_hero.age == 25
|
||||
assert db_hero.name == "Steelman"
|
||||
|
||||
|
||||
def test_cli_populate(config):
|
||||
result = runner.invoke(
|
||||
hero_app,
|
||||
[
|
||||
"populate",
|
||||
"--n",
|
||||
10,
|
||||
"--database-url",
|
||||
config.database_url,
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
db_hero = Hero().get(config=config)
|
||||
assert len(db_hero) == 10
|
||||
|
||||
|
||||
def test_cli_populate_fails_prod(config):
|
||||
result = runner.invoke(
|
||||
hero_app,
|
||||
["populate", "--n", 10, "--database-url", config.database_url, "--env", "prod"],
|
||||
)
|
||||
assert result.exit_code == 1
|
||||
assert result.output.strip() == "populate is not supported in production"
|
||||
|
||||
|
||||
def test_api_read(config):
|
||||
hero = HeroFactory().build(name="Steelman", age=25, id=99)
|
||||
hero_id = hero.id
|
||||
hero = hero.post(config=config)
|
||||
response = client.get(f"/hero/{hero_id}")
|
||||
assert response.status_code == 200
|
||||
reponse_hero = Hero.parse_obj(response.json())
|
||||
assert reponse_hero.id == hero_id
|
||||
assert reponse_hero.name == "Steelman"
|
||||
assert reponse_hero.age == 25
|
||||
|
||||
|
||||
def test_api_post(config):
|
||||
def test_api_post(client: TestClient):
|
||||
hero = HeroFactory().build(name="Steelman", age=25)
|
||||
hero_dict = hero.dict()
|
||||
response = client.post("/hero/", json={"hero": hero_dict})
|
||||
assert response.status_code == 200
|
||||
|
||||
response_hero = Hero.parse_obj(response.json())
|
||||
db_hero = Hero().get(id=response_hero.id, config=config)
|
||||
|
||||
assert db_hero.name == "Steelman"
|
||||
assert db_hero.age == 25
|
||||
|
||||
|
||||
def test_api_read_all(config):
|
||||
hero = HeroFactory().build(name="Mothman", age=25, id=99)
|
||||
hero_id = hero.id
|
||||
hero = hero.post(config=config)
|
||||
response = client.get("/heros/")
|
||||
assert response.status_code == 200
|
||||
heros = response.json()
|
||||
response_hero_json = [hero for hero in heros if hero["id"] == hero_id][0]
|
||||
response_hero = Hero.parse_obj(response_hero_json)
|
||||
assert response_hero.id == hero_id
|
||||
assert response_hero.name == "Mothman"
|
||||
assert response_hero.name == "Steelman"
|
||||
assert response_hero.age == 25
|
||||
|
||||
|
||||
def test_api_read_heroes(session: Session, client: TestClient):
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||
hero_2 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
|
||||
session.add(hero_1)
|
||||
session.add(hero_2)
|
||||
session.commit()
|
||||
|
||||
response = client.get("/heros/")
|
||||
data = response.json()
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
assert len(data) == 2
|
||||
assert data[0]["name"] == hero_1.name
|
||||
assert data[0]["secret_name"] == hero_1.secret_name
|
||||
assert data[0]["age"] == hero_1.age
|
||||
assert data[0]["id"] == hero_1.id
|
||||
assert data[1]["name"] == hero_2.name
|
||||
assert data[1]["secret_name"] == hero_2.secret_name
|
||||
assert data[1]["age"] == hero_2.age
|
||||
assert data[1]["id"] == hero_2.id
|
||||
|
||||
|
||||
def test_api_read_hero(session: Session, client: TestClient):
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||
session.add(hero_1)
|
||||
session.commit()
|
||||
|
||||
response = client.get(f"/hero/999")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_api_read_hero_404(session: Session, client: TestClient):
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||
session.add(hero_1)
|
||||
session.commit()
|
||||
|
||||
response = client.get(f"/hero/{hero_1.id}")
|
||||
data = response.json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data["name"] == hero_1.name
|
||||
assert data["secret_name"] == hero_1.secret_name
|
||||
assert data["age"] == hero_1.age
|
||||
assert data["id"] == hero_1.id
|
||||
|
||||
|
||||
def test_api_update_hero(session: Session, client: TestClient):
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||
session.add(hero_1)
|
||||
session.commit()
|
||||
|
||||
response = client.patch(
|
||||
f"/hero/", json={"hero": {"name": "Deadpuddle", "id": hero_1.id}}
|
||||
)
|
||||
data = response.json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert data["name"] == "Deadpuddle"
|
||||
assert data["secret_name"] == "Dive Wilson"
|
||||
assert data["age"] is None
|
||||
assert data["id"] == hero_1.id
|
||||
|
||||
|
||||
def test_api_update_hero_404(session: Session, client: TestClient):
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||
session.add(hero_1)
|
||||
session.commit()
|
||||
|
||||
response = client.patch(f"/hero/", json={"hero": {"name": "Deadpuddle", "id": 999}})
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_delete_hero(session: Session, client: TestClient):
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||
session.add(hero_1)
|
||||
session.commit()
|
||||
|
||||
response = client.delete(f"/hero/{hero_1.id}")
|
||||
|
||||
hero_in_db = session.get(Hero, hero_1.id)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
assert hero_in_db is None
|
||||
|
||||
|
||||
def test_delete_hero_404(session: Session, client: TestClient):
|
||||
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson")
|
||||
session.add(hero_1)
|
||||
session.commit()
|
||||
|
||||
response = client.delete(f"/hero/999")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_config_memory(mocker):
|
||||
mocker.patch(
|
||||
"learn_sql_model.config.Database.engine",
|
||||
new_callable=lambda: create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
),
|
||||
)
|
||||
config = get_config()
|
||||
SQLModel.metadata.create_all(config.database.engine)
|
||||
hero = HeroFactory().build(name="Steelman", age=25)
|
||||
with config.database.session as session:
|
||||
session.add(hero)
|
||||
session.commit()
|
||||
hero = session.get(Hero, hero.id)
|
||||
heroes = session.exec(select(Hero)).all()
|
||||
assert hero.name == "Steelman"
|
||||
assert hero.age == 25
|
||||
assert len(heroes) == 1
|
||||
|
||||
|
||||
def test_cli_get(mocker):
|
||||
mocker.patch(
|
||||
"learn_sql_model.config.Database.engine",
|
||||
new_callable=lambda: create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
),
|
||||
)
|
||||
|
||||
config = get_config()
|
||||
SQLModel.metadata.create_all(config.database.engine)
|
||||
|
||||
hero = HeroFactory().build(name="Steelman", age=25)
|
||||
with config.database.session as session:
|
||||
session.add(hero)
|
||||
session.commit()
|
||||
hero = session.get(Hero, hero.id)
|
||||
result = runner.invoke(hero_app, ["get", "--hero-id", "1"])
|
||||
assert result.exit_code == 0
|
||||
assert f"name='{hero.name}'" in result.stdout
|
||||
assert f"secret_name='{hero.secret_name}'" in result.stdout
|
||||
|
||||
|
||||
def test_cli_get_404(mocker):
|
||||
mocker.patch(
|
||||
"learn_sql_model.config.Database.engine",
|
||||
new_callable=lambda: create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
),
|
||||
)
|
||||
|
||||
config = get_config()
|
||||
SQLModel.metadata.create_all(config.database.engine)
|
||||
|
||||
hero = HeroFactory().build(name="Steelman", age=25)
|
||||
with config.database.session as session:
|
||||
session.add(hero)
|
||||
session.commit()
|
||||
hero = session.get(Hero, hero.id)
|
||||
result = runner.invoke(hero_app, ["get", "--hero-id", "999"])
|
||||
assert result.exception.status_code == 404
|
||||
assert result.exception.detail == "Hero not found"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue