This commit is contained in:
Waylon Walker 2023-06-21 16:50:39 -05:00
parent a7e6f2c4e5
commit eb448597c8
6 changed files with 379 additions and 251 deletions

View file

@ -1,234 +1,103 @@
from fastapi.testclient import TestClient
import pytest
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from typing import Optional
from learn_sql_model.api.app import app
from learn_sql_model.config import get_config, get_session
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
import httpx
from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel
runner = CliRunner()
client = TestClient(app)
from learn_sql_model.config import config
from learn_sql_model.models.pet import Pet
@pytest.fixture(name="session")
def session_fixture():
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
yield session
class {{ model.lower }}Base(SQLModel, table=False):
name: str
secret_name: str
x: int
y: int
size: int
age: Optional[int] = None
shoe_size: Optional[int] = None
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
@pytest.fixture(name="client")
def client_fixture(session: Session):
def get_session_override():
return session
app.dependency_overrides[get_session] = get_session_override
client = TestClient(app)
yield client
app.dependency_overrides.clear()
class {{ model.lower }}({{ model.lower }}Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
def test_api_post(client: TestClient):
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
{{modelname.lower()}}_dict = {{modelname.lower()}}.dict()
response = client.post("/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {{modelname.lower()}}_dict})
response_{{modelname.lower()}} = {{modelname}}.parse_obj(response.json())
class {{ model.lower }}Create({{ model.lower }}Base):
...
assert response.status_code == 200
assert response_{{modelname.lower()}}.name == "Steelman"
assert response_{{modelname.lower()}}.age == 25
def post(self) -> {{ model.lower }}:
r = httpx.post(
f"{config.api_client.url}/{{ model.lower() }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model.lower }}.parse_obj(r.json())
def test_api_read_{{modelname.lower()}}es(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
{{modelname.lower()}}_2 = {{modelname}}(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
session.add({{modelname.lower()}}_1)
session.add({{modelname.lower()}}_2)
session.commit()
class {{ model.lower }}Read({{ model.lower }}Base):
id: int
response = client.get("/{{modelname.lower()}}s/")
data = response.json()
assert response.status_code == 200
assert len(data) == 2
assert data[0]["name"] == {{modelname.lower()}}_1.name
assert data[0]["secret_name"] == {{modelname.lower()}}_1.secret_name
assert data[0]["age"] == {{modelname.lower()}}_1.age
assert data[0]["id"] == {{modelname.lower()}}_1.id
assert data[1]["name"] == {{modelname.lower()}}_2.name
assert data[1]["secret_name"] == {{modelname.lower()}}_2.secret_name
assert data[1]["age"] == {{modelname.lower()}}_2.age
assert data[1]["id"] == {{modelname.lower()}}_2.id
@classmethod
def get(
cls,
id: int,
) -> {{ model.lower }}:
r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}/{id}")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model.lower() }}
def test_api_read_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
class {{ model.lower }}s(BaseModel):
{{ model.lower() }}s: list[{{ model.lower }}]
response = client.get(f"/{{modelname.lower()}}/999")
assert response.status_code == 404
@classmethod
def list(
self,
) -> {{ model.lower }}:
r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}s/")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model.lower }}s.parse_obj(r.json())
def test_api_read_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
class {{ model.lower }}Update(SQLModel):
# id is required to update the {{ model.lower() }}
id: int
response = client.get(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
data = response.json()
# all other fields, must match the model, but with Optional default None
name: Optional[str] = None
secret_name: Optional[str] = None
age: Optional[int] = None
shoe_size: Optional[int] = None
x: int
y: int
assert response.status_code == 200
assert data["name"] == {{modelname.lower()}}_1.name
assert data["secret_name"] == {{modelname.lower()}}_1.secret_name
assert data["age"] == {{modelname.lower()}}_1.age
assert data["id"] == {{modelname.lower()}}_1.id
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
def update(self) -> {{ model.lower }}:
r = httpx.patch(
f"{config.api_client.url}/{{ model.lower() }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
def test_api_update_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
class {{ model.lower }}Delete(BaseModel):
id: int
response = client.patch(
f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": {{modelname.lower()}}_1.id}}
)
data = response.json()
def delete(self) -> {{ model.lower }}:
r = httpx.delete(
f"{config.api_client.url}/{{ model.lower() }}/{self.id}",
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {"ok": True}
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson"
assert data["age"] is None
assert data["id"] == {{modelname.lower()}}_1.id
def test_api_update_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.patch(f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": 999}})
assert response.status_code == 404
def test_delete_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
{{modelname.lower()}}_in_db = session.get({{modelname}}, {{modelname.lower()}}_1.id)
assert response.status_code == 200
assert {{modelname.lower()}}_in_db is None
def test_delete_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/999")
assert response.status_code == 404
def test_config_memory(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
{{modelname.lower()}}es = session.exec(select({{modelname}})).all()
assert {{modelname.lower()}}.name == "Steelman"
assert {{modelname.lower()}}.age == 25
assert len({{modelname.lower()}}es) == 1
def test_cli_get(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "1"])
assert result.exit_code == 0
assert f"name='{{{modelname.lower()}}.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}.secret_name}'" in result.stdout
def test_cli_get_404(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "{{modelname}} not found"
def test_cli_list(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}}_1 = {{modelname}}Factory().build(name="Steelman", age=25)
{{modelname.lower()}}_2 = {{modelname}}Factory().build(name="Hunk", age=52)
with config.database.session as session:
session.add({{modelname.lower()}}_1)
session.add({{modelname.lower()}}_2)
session.commit()
session.refresh({{modelname.lower()}}_1)
session.refresh({{modelname.lower()}}_2)
result = runner.invoke({{modelname.lower()}}_app, ["list"])
assert result.exit_code == 0
assert f"name='{{{modelname.lower()}}_1.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}_1.secret_name}'" in result.stdout
assert f"name='{{{modelname.lower()}}_2.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}_2.secret_name}'" in result.stdout