This commit is contained in:
Waylon Walker 2023-06-21 16:50:39 -05:00
parent a7e6f2c4e5
commit eb448597c8
6 changed files with 379 additions and 251 deletions

View file

@ -55,8 +55,7 @@ def clear() -> Union[Hero, List[Hero]]:
"list many heros" "list many heros"
heros = Heros.list() heros = Heros.list()
for hero in heros.heros: for hero in heros.heros:
HeroDelete(id=hero.id).delete() HeroDelete.delete(id=hero.id)
return hero return hero
@ -81,10 +80,12 @@ def update(
@hero_app.command() @hero_app.command()
@engorgio(typer=True) @engorgio(typer=True)
def delete( def delete(
hero: HeroDelete, hero_id: Optional[int] = typer.Argument(default=None),
) -> Hero: ) -> Hero:
"delete a hero by id" "delete a hero by id"
hero.delete() hero = HeroDelete.delete(id=hero_id)
Console().print(hero)
return hero
@hero_app.command() @hero_app.command()

View file

@ -1,4 +1,4 @@
from typing import Optional from typing import Dict, Optional
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@ -93,9 +93,10 @@ class HeroUpdate(SQLModel):
class HeroDelete(BaseModel): class HeroDelete(BaseModel):
id: int id: int
def delete(self) -> Hero: @classmethod
def delete(self, id: int) -> Dict[str, bool]:
r = httpx.delete( r = httpx.delete(
f"{config.api_client.url}/hero/{self.id}", f"{config.api_client.url}/hero/{id}",
) )
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}") raise RuntimeError(f"{r.status_code}:\n {r.text}")

View file

@ -1,46 +1,62 @@
from typing import Optional from typing import Optional
from fastapi import Depends, HTTPException from fastapi import HTTPException
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel, Session, select from sqlmodel import Field, Relationship, SQLModel, Session, select
from learn_sql_model.config import config, get_config from learn_sql_model.config import config
from learn_sql_model.models.pet import Pet from learn_sql_model.models.pet import Pet
class {{modelname}}Base(SQLModel, table=False): class {{ model }}Base(SQLModel, table=False):
name: str
secret_name: str
x: int
y: int
size: int
age: Optional[int] = None
shoe_size: Optional[int] = None
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
class {{modelname}}({{modelname}}Base, table=True): class {{ model }}({{ model }}Base, table=True):
id: Optional[int] = Field(default=None, primary_key=True) id: Optional[int] = Field(default=None, primary_key=True)
class {{modelname}}Create({{modelname}}Base): class {{ model }}Create({{ model }}Base):
... ...
def post(self) -> {{modelname}}: def post(self) -> {{ model }}:
r = httpx.post( r = httpx.post(
f"{config.api_client.url}/{{modelname.lower()}}/", f"{config.api_client.url}/{{ model.lower() }}/",
json=self.dict(), json=self.dict(),
) )
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}") raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model }}.parse_obj(r.json())
class {{modelname}}Read({{modelname}}Base):
class {{ model }}Read({{ model }}Base):
id: int id: int
@classmethod @classmethod
def get( def get(
cls, cls,
id: int, id: int,
) -> {{modelname}}: ) -> {{ model }}:
with config.database.session as session: with config.database.session as session:
{{modelname.lower()}} = session.get({{modelname}}, id) {{ model.lower() }} = session.get({{ model }}, id)
if not {{modelname.lower()}}: if not {{ model.lower() }}:
raise HTTPException(status_code=404, detail="{{modelname}} not found") raise HTTPException(status_code=404, detail="{{ model }} not found")
return {{modelname.lower()}} return {{ model.lower() }}
class {{ model }}s(BaseModel):
{{ model.lower() }}s: list[{{ model }}]
@classmethod @classmethod
def list( def list(
@ -49,45 +65,70 @@ class {{modelname}}Read({{modelname}}Base):
offset=0, offset=0,
limit=None, limit=None,
session: Session = None, session: Session = None,
) -> {{modelname}}: ) -> {{ model }}:
# with config.database.session as session:
def get_{{ model.lower() }}s(session, where, offset, limit):
statement = select({{ model }})
if where != "None" and where is not None:
from sqlmodel import text
statement = statement.where(text(where))
statement = statement.offset(offset).limit(limit)
{{ model.lower() }}s = session.exec(statement).all()
return {{ model }}s({{ model.lower() }}s={{ model.lower() }}s)
if session is None: if session is None:
session = get_config().database.session
statement = select({{modelname}}) r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}s/")
if where != "None" and where is not None: if r.status_code != 200:
from sqlmodel import text raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model }}s.parse_obj(r.json())
statement = statement.where(text(where)) return get_{{ model.lower() }}s(session, where, offset, limit)
statement = statement.offset(offset).limit(limit)
{{modelname.lower()}}es = session.exec(statement).all()
return {{modelname.lower()}}es
class {{modelname}}Update(SQLModel): class {{ model }}Update(SQLModel):
# id is required to update the {{modelname.lower()}} # id is required to update the {{ model.lower() }}
id: int id: int
# all other fields, must match the model, but with Optional default None # all other fields, must match the model, but with Optional default None
name: Optional[str] = None
secret_name: Optional[str] = None
age: Optional[int] = None
shoe_size: Optional[int] = None
x: int
y: int
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{modelname.lower()}}") pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
def update(self, session: Session = None) -> {{ model }}:
if session is not None:
db_{{ model.lower() }} = session.get({{ model }}, self.id)
if not db_{{ model.lower() }}:
raise HTTPException(status_code=404, detail="{{ model }} not found")
for key, value in self.dict(exclude_unset=True).items():
setattr(db_{{ model.lower() }}, key, value)
session.add(db_{{ model.lower() }})
session.commit()
session.refresh(db_{{ model.lower() }})
return db_{{ model.lower() }}
def update(self) -> {{modelname}}:
r = httpx.patch( r = httpx.patch(
f"{config.api_client.url}/{{modelname.lower()}}/", f"{config.api_client.url}/{{ model.lower() }}/",
json=self.dict(), json=self.dict(),
) )
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}") raise RuntimeError(f"{r.status_code}:\n {r.text}")
class {{modelname}}Delete(BaseModel): class {{ model }}Delete(BaseModel):
id: int id: int
def delete(self) -> {{modelname}}: def delete(self) -> {{ model }}:
r = httpx.delete( r = httpx.delete(
f"{config.api_client.url}/{{modelname.lower()}}/{self.id}", f"{config.api_client.url}/{{ model.lower() }}/{self.id}",
) )
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}") raise RuntimeError(f"{r.status_code}:\n {r.text}")

View file

@ -1,234 +1,103 @@
from fastapi.testclient import TestClient from typing import Optional
import pytest
from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session, select
from sqlmodel.pool import StaticPool
from typer.testing import CliRunner
from learn_sql_model.api.app import app import httpx
from learn_sql_model.config import get_config, get_session from pydantic import BaseModel
from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory from sqlmodel import Field, Relationship, SQLModel
from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}
runner = CliRunner() from learn_sql_model.config import config
client = TestClient(app) from learn_sql_model.models.pet import Pet
@pytest.fixture(name="session") class {{ model.lower }}Base(SQLModel, table=False):
def session_fixture(): name: str
engine = create_engine( secret_name: str
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool x: int
) y: int
SQLModel.metadata.create_all(engine) size: int
with Session(engine) as session: age: Optional[int] = None
yield session shoe_size: Optional[int] = None
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
@pytest.fixture(name="client") class {{ model.lower }}({{ model.lower }}Base, table=True):
def client_fixture(session: Session): id: Optional[int] = Field(default=None, primary_key=True)
def get_session_override():
return session
app.dependency_overrides[get_session] = get_session_override
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_api_post(client: TestClient): class {{ model.lower }}Create({{ model.lower }}Base):
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) ...
{{modelname.lower()}}_dict = {{modelname.lower()}}.dict()
response = client.post("/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {{modelname.lower()}}_dict})
response_{{modelname.lower()}} = {{modelname}}.parse_obj(response.json())
assert response.status_code == 200 def post(self) -> {{ model.lower }}:
assert response_{{modelname.lower()}}.name == "Steelman" r = httpx.post(
assert response_{{modelname.lower()}}.age == 25 f"{config.api_client.url}/{{ model.lower() }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model.lower }}.parse_obj(r.json())
def test_api_read_{{modelname.lower()}}es(session: Session, client: TestClient): class {{ model.lower }}Read({{ model.lower }}Base):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") id: int
{{modelname.lower()}}_2 = {{modelname}}(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
session.add({{modelname.lower()}}_1)
session.add({{modelname.lower()}}_2)
session.commit()
response = client.get("/{{modelname.lower()}}s/") @classmethod
data = response.json() def get(
cls,
assert response.status_code == 200 id: int,
) -> {{ model.lower }}:
assert len(data) == 2 r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}/{id}")
assert data[0]["name"] == {{modelname.lower()}}_1.name if r.status_code != 200:
assert data[0]["secret_name"] == {{modelname.lower()}}_1.secret_name raise RuntimeError(f"{r.status_code}:\n {r.text}")
assert data[0]["age"] == {{modelname.lower()}}_1.age return {{ model.lower() }}
assert data[0]["id"] == {{modelname.lower()}}_1.id
assert data[1]["name"] == {{modelname.lower()}}_2.name
assert data[1]["secret_name"] == {{modelname.lower()}}_2.secret_name
assert data[1]["age"] == {{modelname.lower()}}_2.age
assert data[1]["id"] == {{modelname.lower()}}_2.id
def test_api_read_{{modelname.lower()}}(session: Session, client: TestClient): class {{ model.lower }}s(BaseModel):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") {{ model.lower() }}s: list[{{ model.lower }}]
session.add({{modelname.lower()}}_1)
session.commit()
response = client.get(f"/{{modelname.lower()}}/999") @classmethod
assert response.status_code == 404 def list(
self,
) -> {{ model.lower }}:
r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}s/")
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {{ model.lower }}s.parse_obj(r.json())
def test_api_read_{{modelname.lower()}}_404(session: Session, client: TestClient): class {{ model.lower }}Update(SQLModel):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") # id is required to update the {{ model.lower() }}
session.add({{modelname.lower()}}_1) id: int
session.commit()
response = client.get(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}") # all other fields, must match the model, but with Optional default None
data = response.json() name: Optional[str] = None
secret_name: Optional[str] = None
age: Optional[int] = None
shoe_size: Optional[int] = None
x: int
y: int
assert response.status_code == 200 pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
assert data["name"] == {{modelname.lower()}}_1.name pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}")
assert data["secret_name"] == {{modelname.lower()}}_1.secret_name
assert data["age"] == {{modelname.lower()}}_1.age def update(self) -> {{ model.lower }}:
assert data["id"] == {{modelname.lower()}}_1.id r = httpx.patch(
f"{config.api_client.url}/{{ model.lower() }}/",
json=self.dict(),
)
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
def test_api_update_{{modelname.lower()}}(session: Session, client: TestClient): class {{ model.lower }}Delete(BaseModel):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") id: int
session.add({{modelname.lower()}}_1)
session.commit()
response = client.patch( def delete(self) -> {{ model.lower }}:
f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": {{modelname.lower()}}_1.id}} r = httpx.delete(
) f"{config.api_client.url}/{{ model.lower() }}/{self.id}",
data = response.json() )
if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}")
return {"ok": True}
assert response.status_code == 200
assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson"
assert data["age"] is None
assert data["id"] == {{modelname.lower()}}_1.id
def test_api_update_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.patch(f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": 999}})
assert response.status_code == 404
def test_delete_{{modelname.lower()}}(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}")
{{modelname.lower()}}_in_db = session.get({{modelname}}, {{modelname.lower()}}_1.id)
assert response.status_code == 200
assert {{modelname.lower()}}_in_db is None
def test_delete_{{modelname.lower()}}_404(session: Session, client: TestClient):
{{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson")
session.add({{modelname.lower()}}_1)
session.commit()
response = client.delete(f"/{{modelname.lower()}}/999")
assert response.status_code == 404
def test_config_memory(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
{{modelname.lower()}}es = session.exec(select({{modelname}})).all()
assert {{modelname.lower()}}.name == "Steelman"
assert {{modelname.lower()}}.age == 25
assert len({{modelname.lower()}}es) == 1
def test_cli_get(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "1"])
assert result.exit_code == 0
assert f"name='{{{modelname.lower()}}.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}.secret_name}'" in result.stdout
def test_cli_get_404(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25)
with config.database.session as session:
session.add({{modelname.lower()}})
session.commit()
{{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id)
result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "{{modelname}} not found"
def test_cli_list(mocker):
mocker.patch(
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
{{modelname.lower()}}_1 = {{modelname}}Factory().build(name="Steelman", age=25)
{{modelname.lower()}}_2 = {{modelname}}Factory().build(name="Hunk", age=52)
with config.database.session as session:
session.add({{modelname.lower()}}_1)
session.add({{modelname.lower()}}_2)
session.commit()
session.refresh({{modelname.lower()}}_1)
session.refresh({{modelname.lower()}}_2)
result = runner.invoke({{modelname.lower()}}_app, ["list"])
assert result.exit_code == 0
assert f"name='{{{modelname.lower()}}_1.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}_1.secret_name}'" in result.stdout
assert f"name='{{{modelname.lower()}}_2.name}'" in result.stdout
assert f"secret_name='{{{modelname.lower()}}_2.secret_name}'" in result.stdout

View file

@ -10,7 +10,7 @@ from learn_sql_model.cli.hero import hero_app
from learn_sql_model.config import get_config, get_session from learn_sql_model.config import get_config, get_session
from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.models import hero as hero_models from learn_sql_model.models import hero as hero_models
from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, Heros from learn_sql_model.models.hero import Hero, HeroCreate, HeroDelete, HeroRead, Heros
runner = CliRunner() runner = CliRunner()
client = TestClient(app) client = TestClient(app)
@ -176,6 +176,8 @@ def test_cli_get(mocker):
assert f"name='{hero.name}'" in result.stdout assert f"name='{hero.name}'" in result.stdout
assert f"secret_name='{hero.secret_name}'" in result.stdout assert f"secret_name='{hero.secret_name}'" in result.stdout
assert httpx.get.call_count == 1 assert httpx.get.call_count == 1
assert httpx.post.call_count == 0
assert httpx.delete.call_count == 0
def test_cli_get_404(mocker): def test_cli_get_404(mocker):
@ -191,6 +193,8 @@ def test_cli_get_404(mocker):
assert result.exit_code == 1 assert result.exit_code == 1
assert " ".join(result.exception.args[0].split()) == "404: Hero not found" assert " ".join(result.exception.args[0].split()) == "404: Hero not found"
assert httpx.get.call_count == 1 assert httpx.get.call_count == 1
assert httpx.post.call_count == 0
assert httpx.delete.call_count == 0
def test_cli_list(mocker): def test_cli_list(mocker):
@ -226,6 +230,7 @@ def test_model_post(mocker):
assert result == hero assert result == hero
assert httpx.get.call_count == 0 assert httpx.get.call_count == 0
assert httpx.post.call_count == 1 assert httpx.post.call_count == 1
assert httpx.delete.call_count == 0
def test_model_post_500(mocker): def test_model_post_500(mocker):
@ -240,9 +245,10 @@ def test_model_post_500(mocker):
hero_create.post() hero_create.post()
assert httpx.get.call_count == 0 assert httpx.get.call_count == 0
assert httpx.post.call_count == 1 assert httpx.post.call_count == 1
assert httpx.delete.call_count == 0
def test_model_read_hero(mocker, session: Session, client: TestClient): def test_model_read_hero(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1) hero = HeroFactory().build(name="Steelman", age=25, id=1)
httpx = mocker.patch.object(hero_models, "httpx") httpx = mocker.patch.object(hero_models, "httpx")
@ -255,9 +261,10 @@ def test_model_read_hero(mocker, session: Session, client: TestClient):
assert hero_read.secret_name == hero.secret_name assert hero_read.secret_name == hero.secret_name
assert httpx.get.call_count == 1 assert httpx.get.call_count == 1
assert httpx.post.call_count == 0 assert httpx.post.call_count == 0
assert httpx.delete.call_count == 0
def test_model_read_hero_404(mocker, session: Session, client: TestClient): def test_model_read_hero_404(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1) hero = HeroFactory().build(name="Steelman", age=25, id=1)
httpx = mocker.patch.object(hero_models, "httpx") httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock() httpx.get.return_value = mocker.Mock()
@ -269,3 +276,68 @@ def test_model_read_hero_404(mocker, session: Session, client: TestClient):
assert e.value.args[0] == "404: Hero not found" assert e.value.args[0] == "404: Hero not found"
assert httpx.get.call_count == 1 assert httpx.get.call_count == 1
assert httpx.post.call_count == 0 assert httpx.post.call_count == 0
assert httpx.delete.call_count == 0
def test_model_delete_hero(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()
httpx.delete.return_value.status_code = 200
httpx.delete.return_value.json.return_value = hero.dict()
hero_delete = HeroDelete.delete(id=hero.id)
assert hero_delete == {"ok": True}
assert httpx.get.call_count == 0
assert httpx.post.call_count == 0
assert httpx.delete.call_count == 1
def test_model_delete_hero_404(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()
httpx.delete.return_value.status_code = 404
httpx.get.return_value.text = "Hero not found"
with pytest.raises(RuntimeError) as e:
HeroDelete.delete(id=hero.id)
assert e.value.args[0] == "404: Hero not found"
assert httpx.get.call_count == 0
assert httpx.post.call_count == 0
assert httpx.delete.call_count == 1
def test_cli_delete_hero(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()
httpx.delete.return_value.status_code = 200
httpx.delete.return_value.json.return_value = hero.dict()
result = runner.invoke(hero_app, ["delete", "--hero-id", "1"])
assert result.exit_code == 0
assert "{'ok': True}" in result.stdout
assert httpx.get.call_count == 0
assert httpx.post.call_count == 0
assert httpx.delete.call_count == 1
def test_cli_delete_hero_404(mocker):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
httpx = mocker.patch.object(hero_models, "httpx")
httpx.delete.return_value = mocker.Mock()
httpx.delete.return_value.status_code = 404
httpx.delete.return_value.text = "Hero not found"
httpx.delete.return_value.json.return_value = hero.dict()
result = runner.invoke(hero_app, ["delete", "--hero-id", "999"])
assert result.exit_code == 1
assert " ".join(result.exception.args[0].split()) == "404: Hero not found"
assert httpx.get.call_count == 0
assert httpx.post.call_count == 0
assert httpx.delete.call_count == 1

144
tmp.py Normal file
View file

@ -0,0 +1,144 @@
import sqlite3
from graphviz import Digraph
def generate_er_diagram(database_path, output_path):
# Connect to the SQLite database
conn = sqlite3.connect(database_path)
cursor = conn.cursor()
# Get the table names from the database
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
# Create a new Digraph
dot = Digraph(format="png")
dot.attr(rankdir="TD")
# Iterate over the tables
for table in tables:
table_name = table[0]
dot.node(table_name, shape="box")
cursor.execute(f"PRAGMA table_info({table_name});")
columns = cursor.fetchall()
# Add the columns to the table node
for column in columns:
column_name = column[1]
dot.node(f"{table_name}.{column_name}", label=column_name, shape="oval")
dot.edge(table_name, f"{table_name}.{column_name}")
# Check for foreign key relationships
cursor.execute(f"PRAGMA foreign_key_list({table_name});")
foreign_keys = cursor.fetchall()
# Add dotted lines for foreign key relationships
for foreign_key in foreign_keys:
from_column = foreign_key[3]
to_table = foreign_key[2]
to_column = foreign_key[4]
dot.node(f"{to_table}.{to_column}", shape="oval")
dot.edge(
f"{table_name}.{from_column}", f"{to_table}.{to_column}", style="dotted"
)
# Render and save the diagram
dot.render(output_path.replace(".png", ""), cleanup=True)
# Close the database connection
cursor.close()
conn.close()
def generate_markdown(database_path, output_path, er_diagram_path):
# Connect to the SQLite database
conn = sqlite3.connect(database_path)
cursor = conn.cursor()
# Get the table names from the database
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
with open(output_path, "w") as f:
# Write the ER Diagram image
f.write(f"![ER Diagram]({er_diagram_path})\n\n---\n\n")
# Iterate over the tables
for table in tables:
table_name = table[0]
f.write(f"## Table: {table_name}\n\n")
# Get the table columns
cursor.execute(f"PRAGMA table_info({table_name});")
columns = cursor.fetchall()
f.write("### First 5 rows\n\n")
cursor.execute(f"SELECT * FROM {table_name} LIMIT 5;")
rows = cursor.fetchall()
f.write(f'| {" | ".join([c[1] for c in columns])} |\n')
f.write("|")
for column in columns:
# ---
f.write(f'{"-"*(len(column[1]) + 2)}|')
f.write("\n")
for row in rows:
f.write(f'| {" | ".join([str(r) for r in row])} |\n')
f.write("\n")
cursor.execute(f"PRAGMA foreign_key_list({table_name});")
foreign_keys = cursor.fetchall()
# Add dotted lines for foreign key relationships
fkeys = {}
for foreign_key in foreign_keys:
from_column = foreign_key[3]
to_table = foreign_key[2]
to_column = foreign_key[4]
fkeys[from_column] = f"{to_table}.{to_column}"
# Replace 'description' with the actual column name in the table that contains the description, if applicable
try:
cursor.execute(f"SELECT description FROM {table_name} LIMIT 1;")
description = cursor.fetchone()
if description:
f.write(f"### Description\n\n{description[0]}\n\n")
except:
...
# Write the table columns
f.write("### Columns\n\n")
f.write("| Column Name | Type | Foreign Key | Example Value |\n")
f.write("|-------------|------|-------------|---------------|\n")
for column in columns:
column_name = column[1]
column_type = column[2]
fkey = ""
if column_name in fkeys:
fkey = fkeys[column_name]
f.write(f"| {column_name} | {column_type} | {fkey} | | |\n")
f.write("\n")
# Get the count of records
cursor.execute(f"SELECT COUNT(*) FROM {table_name};")
records_count = cursor.fetchone()[0]
f.write(
f"### Records Count\n\nThe table {table_name} contains {records_count} records.\n\n---\n\n"
)
# Close the database connection
cursor.close()
conn.close()
# Usage example
database_path = "database.db"
md_output_path = "database.md"
er_output_path = "er_diagram.png"
generate_er_diagram(database_path, er_output_path)
generate_markdown(database_path, md_output_path, er_output_path)