diff --git a/learn_sql_model/cli/hero.py b/learn_sql_model/cli/hero.py index 00047f8..fcb67c5 100644 --- a/learn_sql_model/cli/hero.py +++ b/learn_sql_model/cli/hero.py @@ -55,8 +55,7 @@ def clear() -> Union[Hero, List[Hero]]: "list many heros" heros = Heros.list() for hero in heros.heros: - HeroDelete(id=hero.id).delete() - + HeroDelete.delete(id=hero.id) return hero @@ -81,10 +80,12 @@ def update( @hero_app.command() @engorgio(typer=True) def delete( - hero: HeroDelete, + hero_id: Optional[int] = typer.Argument(default=None), ) -> Hero: "delete a hero by id" - hero.delete() + hero = HeroDelete.delete(id=hero_id) + Console().print(hero) + return hero @hero_app.command() diff --git a/learn_sql_model/models/hero.py b/learn_sql_model/models/hero.py index f96ca58..14e17b1 100644 --- a/learn_sql_model/models/hero.py +++ b/learn_sql_model/models/hero.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional import httpx from pydantic import BaseModel @@ -93,9 +93,10 @@ class HeroUpdate(SQLModel): class HeroDelete(BaseModel): id: int - def delete(self) -> Hero: + @classmethod + def delete(self, id: int) -> Dict[str, bool]: r = httpx.delete( - f"{config.api_client.url}/hero/{self.id}", + f"{config.api_client.url}/hero/{id}", ) if r.status_code != 200: raise RuntimeError(f"{r.status_code}:\n {r.text}") diff --git a/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja index 6fed597..47a310f 100644 --- a/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja +++ b/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja @@ -1,46 +1,62 @@ from typing import Optional -from fastapi import Depends, HTTPException +from fastapi import HTTPException import httpx from pydantic import BaseModel from sqlmodel import Field, Relationship, SQLModel, Session, select -from learn_sql_model.config import config, get_config +from learn_sql_model.config import config from learn_sql_model.models.pet import Pet -class {{modelname}}Base(SQLModel, table=False): +class {{ model }}Base(SQLModel, table=False): + name: str + secret_name: str + x: int + y: int + size: int + age: Optional[int] = None + shoe_size: Optional[int] = None + + pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") + pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}") -class {{modelname}}({{modelname}}Base, table=True): +class {{ model }}({{ model }}Base, table=True): id: Optional[int] = Field(default=None, primary_key=True) -class {{modelname}}Create({{modelname}}Base): +class {{ model }}Create({{ model }}Base): ... - def post(self) -> {{modelname}}: + def post(self) -> {{ model }}: r = httpx.post( - f"{config.api_client.url}/{{modelname.lower()}}/", + f"{config.api_client.url}/{{ model.lower() }}/", json=self.dict(), ) if r.status_code != 200: raise RuntimeError(f"{r.status_code}:\n {r.text}") + return {{ model }}.parse_obj(r.json()) -class {{modelname}}Read({{modelname}}Base): + +class {{ model }}Read({{ model }}Base): id: int @classmethod def get( cls, id: int, - ) -> {{modelname}}: + ) -> {{ model }}: with config.database.session as session: - {{modelname.lower()}} = session.get({{modelname}}, id) - if not {{modelname.lower()}}: - raise HTTPException(status_code=404, detail="{{modelname}} not found") - return {{modelname.lower()}} + {{ model.lower() }} = session.get({{ model }}, id) + if not {{ model.lower() }}: + raise HTTPException(status_code=404, detail="{{ model }} not found") + return {{ model.lower() }} + + +class {{ model }}s(BaseModel): + {{ model.lower() }}s: list[{{ model }}] @classmethod def list( @@ -49,45 +65,70 @@ class {{modelname}}Read({{modelname}}Base): offset=0, limit=None, session: Session = None, - ) -> {{modelname}}: + ) -> {{ model }}: + # with config.database.session as session: + + def get_{{ model.lower() }}s(session, where, offset, limit): + statement = select({{ model }}) + if where != "None" and where is not None: + from sqlmodel import text + + statement = statement.where(text(where)) + statement = statement.offset(offset).limit(limit) + {{ model.lower() }}s = session.exec(statement).all() + return {{ model }}s({{ model.lower() }}s={{ model.lower() }}s) if session is None: - session = get_config().database.session - statement = select({{modelname}}) - if where != "None" and where is not None: - from sqlmodel import text + r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}s/") + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + return {{ model }}s.parse_obj(r.json()) - statement = statement.where(text(where)) - statement = statement.offset(offset).limit(limit) - {{modelname.lower()}}es = session.exec(statement).all() - return {{modelname.lower()}}es + return get_{{ model.lower() }}s(session, where, offset, limit) -class {{modelname}}Update(SQLModel): - # id is required to update the {{modelname.lower()}} +class {{ model }}Update(SQLModel): + # id is required to update the {{ model.lower() }} id: int # all other fields, must match the model, but with Optional default None + name: Optional[str] = None + secret_name: Optional[str] = None + age: Optional[int] = None + shoe_size: Optional[int] = None + x: int + y: int pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") - pet: Optional[Pet] = Relationship(back_populates="{{modelname.lower()}}") + pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}") + + def update(self, session: Session = None) -> {{ model }}: + if session is not None: + db_{{ model.lower() }} = session.get({{ model }}, self.id) + if not db_{{ model.lower() }}: + raise HTTPException(status_code=404, detail="{{ model }} not found") + for key, value in self.dict(exclude_unset=True).items(): + setattr(db_{{ model.lower() }}, key, value) + session.add(db_{{ model.lower() }}) + session.commit() + session.refresh(db_{{ model.lower() }}) + return db_{{ model.lower() }} - def update(self) -> {{modelname}}: r = httpx.patch( - f"{config.api_client.url}/{{modelname.lower()}}/", + f"{config.api_client.url}/{{ model.lower() }}/", json=self.dict(), ) if r.status_code != 200: raise RuntimeError(f"{r.status_code}:\n {r.text}") -class {{modelname}}Delete(BaseModel): +class {{ model }}Delete(BaseModel): id: int - def delete(self) -> {{modelname}}: + def delete(self) -> {{ model }}: r = httpx.delete( - f"{config.api_client.url}/{{modelname.lower()}}/{self.id}", + f"{config.api_client.url}/{{ model.lower() }}/{self.id}", ) if r.status_code != 200: raise RuntimeError(f"{r.status_code}:\n {r.text}") diff --git a/templates/model/tests/{{modelname.lower()}}.py.jinja b/templates/model/tests/{{modelname.lower()}}.py.jinja index 39da70f..7aa9ba5 100644 --- a/templates/model/tests/{{modelname.lower()}}.py.jinja +++ b/templates/model/tests/{{modelname.lower()}}.py.jinja @@ -1,234 +1,103 @@ -from fastapi.testclient import TestClient -import pytest -from sqlalchemy import create_engine -from sqlmodel import SQLModel, Session, select -from sqlmodel.pool import StaticPool -from typer.testing import CliRunner +from typing import Optional -from learn_sql_model.api.app import app -from learn_sql_model.config import get_config, get_session -from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory -from learn_sql_model.models.{{modelname.lower()}} import {{modelname}} +import httpx +from pydantic import BaseModel +from sqlmodel import Field, Relationship, SQLModel -runner = CliRunner() -client = TestClient(app) +from learn_sql_model.config import config +from learn_sql_model.models.pet import Pet -@pytest.fixture(name="session") -def session_fixture(): - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) - SQLModel.metadata.create_all(engine) - with Session(engine) as session: - yield session +class {{ model.lower }}Base(SQLModel, table=False): + name: str + secret_name: str + x: int + y: int + size: int + age: Optional[int] = None + shoe_size: Optional[int] = None + + pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") + pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}") -@pytest.fixture(name="client") -def client_fixture(session: Session): - def get_session_override(): - return session - - app.dependency_overrides[get_session] = get_session_override - - client = TestClient(app) - yield client - app.dependency_overrides.clear() +class {{ model.lower }}({{ model.lower }}Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) -def test_api_post(client: TestClient): - {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) - {{modelname.lower()}}_dict = {{modelname.lower()}}.dict() - response = client.post("/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {{modelname.lower()}}_dict}) - response_{{modelname.lower()}} = {{modelname}}.parse_obj(response.json()) +class {{ model.lower }}Create({{ model.lower }}Base): + ... - assert response.status_code == 200 - assert response_{{modelname.lower()}}.name == "Steelman" - assert response_{{modelname.lower()}}.age == 25 + def post(self) -> {{ model.lower }}: + r = httpx.post( + f"{config.api_client.url}/{{ model.lower() }}/", + json=self.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + + return {{ model.lower }}.parse_obj(r.json()) -def test_api_read_{{modelname.lower()}}es(session: Session, client: TestClient): - {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") - {{modelname.lower()}}_2 = {{modelname}}(name="Rusty-Man", secret_name="Tommy Sharp", age=48) - session.add({{modelname.lower()}}_1) - session.add({{modelname.lower()}}_2) - session.commit() +class {{ model.lower }}Read({{ model.lower }}Base): + id: int - response = client.get("/{{modelname.lower()}}s/") - data = response.json() - - assert response.status_code == 200 - - assert len(data) == 2 - assert data[0]["name"] == {{modelname.lower()}}_1.name - assert data[0]["secret_name"] == {{modelname.lower()}}_1.secret_name - assert data[0]["age"] == {{modelname.lower()}}_1.age - assert data[0]["id"] == {{modelname.lower()}}_1.id - assert data[1]["name"] == {{modelname.lower()}}_2.name - assert data[1]["secret_name"] == {{modelname.lower()}}_2.secret_name - assert data[1]["age"] == {{modelname.lower()}}_2.age - assert data[1]["id"] == {{modelname.lower()}}_2.id + @classmethod + def get( + cls, + id: int, + ) -> {{ model.lower }}: + r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}/{id}") + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + return {{ model.lower() }} -def test_api_read_{{modelname.lower()}}(session: Session, client: TestClient): - {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") - session.add({{modelname.lower()}}_1) - session.commit() +class {{ model.lower }}s(BaseModel): + {{ model.lower() }}s: list[{{ model.lower }}] - response = client.get(f"/{{modelname.lower()}}/999") - assert response.status_code == 404 + @classmethod + def list( + self, + ) -> {{ model.lower }}: + r = httpx.get(f"{config.api_client.url}/{{ model.lower() }}s/") + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + return {{ model.lower }}s.parse_obj(r.json()) -def test_api_read_{{modelname.lower()}}_404(session: Session, client: TestClient): - {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") - session.add({{modelname.lower()}}_1) - session.commit() +class {{ model.lower }}Update(SQLModel): + # id is required to update the {{ model.lower() }} + id: int - response = client.get(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}") - data = response.json() + # all other fields, must match the model, but with Optional default None + name: Optional[str] = None + secret_name: Optional[str] = None + age: Optional[int] = None + shoe_size: Optional[int] = None + x: int + y: int - assert response.status_code == 200 - assert data["name"] == {{modelname.lower()}}_1.name - assert data["secret_name"] == {{modelname.lower()}}_1.secret_name - assert data["age"] == {{modelname.lower()}}_1.age - assert data["id"] == {{modelname.lower()}}_1.id + pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") + pet: Optional[Pet] = Relationship(back_populates="{{ model.lower() }}") + + def update(self) -> {{ model.lower }}: + r = httpx.patch( + f"{config.api_client.url}/{{ model.lower() }}/", + json=self.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") -def test_api_update_{{modelname.lower()}}(session: Session, client: TestClient): - {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") - session.add({{modelname.lower()}}_1) - session.commit() +class {{ model.lower }}Delete(BaseModel): + id: int - response = client.patch( - f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": {{modelname.lower()}}_1.id}} - ) - data = response.json() + def delete(self) -> {{ model.lower }}: + r = httpx.delete( + f"{config.api_client.url}/{{ model.lower() }}/{self.id}", + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + return {"ok": True} - assert response.status_code == 200 - assert data["name"] == "Deadpuddle" - assert data["secret_name"] == "Dive Wilson" - assert data["age"] is None - assert data["id"] == {{modelname.lower()}}_1.id - - -def test_api_update_{{modelname.lower()}}_404(session: Session, client: TestClient): - {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") - session.add({{modelname.lower()}}_1) - session.commit() - - response = client.patch(f"/{{modelname.lower()}}/", json={"{{modelname.lower()}}": {"name": "Deadpuddle", "id": 999}}) - assert response.status_code == 404 - - -def test_delete_{{modelname.lower()}}(session: Session, client: TestClient): - {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") - session.add({{modelname.lower()}}_1) - session.commit() - - response = client.delete(f"/{{modelname.lower()}}/{{{modelname.lower()}}_1.id}") - - {{modelname.lower()}}_in_db = session.get({{modelname}}, {{modelname.lower()}}_1.id) - - assert response.status_code == 200 - - assert {{modelname.lower()}}_in_db is None - - -def test_delete_{{modelname.lower()}}_404(session: Session, client: TestClient): - {{modelname.lower()}}_1 = {{modelname}}(name="Deadpond", secret_name="Dive Wilson") - session.add({{modelname.lower()}}_1) - session.commit() - - response = client.delete(f"/{{modelname.lower()}}/999") - assert response.status_code == 404 - - -def test_config_memory(mocker): - mocker.patch( - "learn_sql_model.config.Database.engine", - new_callable=lambda: create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ), - ) - config = get_config() - SQLModel.metadata.create_all(config.database.engine) - {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) - with config.database.session as session: - session.add({{modelname.lower()}}) - session.commit() - {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id) - {{modelname.lower()}}es = session.exec(select({{modelname}})).all() - assert {{modelname.lower()}}.name == "Steelman" - assert {{modelname.lower()}}.age == 25 - assert len({{modelname.lower()}}es) == 1 - - -def test_cli_get(mocker): - mocker.patch( - "learn_sql_model.config.Database.engine", - new_callable=lambda: create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ), - ) - - config = get_config() - SQLModel.metadata.create_all(config.database.engine) - - {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) - with config.database.session as session: - session.add({{modelname.lower()}}) - session.commit() - {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id) - result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "1"]) - assert result.exit_code == 0 - assert f"name='{{{modelname.lower()}}.name}'" in result.stdout - assert f"secret_name='{{{modelname.lower()}}.secret_name}'" in result.stdout - - -def test_cli_get_404(mocker): - mocker.patch( - "learn_sql_model.config.Database.engine", - new_callable=lambda: create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ), - ) - - config = get_config() - SQLModel.metadata.create_all(config.database.engine) - - {{modelname.lower()}} = {{modelname}}Factory().build(name="Steelman", age=25) - with config.database.session as session: - session.add({{modelname.lower()}}) - session.commit() - {{modelname.lower()}} = session.get({{modelname}}, {{modelname.lower()}}.id) - result = runner.invoke({{modelname.lower()}}_app, ["get", "--{{modelname.lower()}}-id", "999"]) - assert result.exception.status_code == 404 - assert result.exception.detail == "{{modelname}} not found" - - -def test_cli_list(mocker): - mocker.patch( - "learn_sql_model.config.Database.engine", - new_callable=lambda: create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ), - ) - - config = get_config() - SQLModel.metadata.create_all(config.database.engine) - - {{modelname.lower()}}_1 = {{modelname}}Factory().build(name="Steelman", age=25) - {{modelname.lower()}}_2 = {{modelname}}Factory().build(name="Hunk", age=52) - - with config.database.session as session: - session.add({{modelname.lower()}}_1) - session.add({{modelname.lower()}}_2) - session.commit() - session.refresh({{modelname.lower()}}_1) - session.refresh({{modelname.lower()}}_2) - result = runner.invoke({{modelname.lower()}}_app, ["list"]) - assert result.exit_code == 0 - assert f"name='{{{modelname.lower()}}_1.name}'" in result.stdout - assert f"secret_name='{{{modelname.lower()}}_1.secret_name}'" in result.stdout - assert f"name='{{{modelname.lower()}}_2.name}'" in result.stdout - assert f"secret_name='{{{modelname.lower()}}_2.secret_name}'" in result.stdout diff --git a/tests/test_hero.py b/tests/test_hero.py index 9f74a2b..d76308e 100644 --- a/tests/test_hero.py +++ b/tests/test_hero.py @@ -10,7 +10,7 @@ from learn_sql_model.cli.hero import hero_app from learn_sql_model.config import get_config, get_session from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.models import hero as hero_models -from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, Heros +from learn_sql_model.models.hero import Hero, HeroCreate, HeroDelete, HeroRead, Heros runner = CliRunner() client = TestClient(app) @@ -176,6 +176,8 @@ def test_cli_get(mocker): assert f"name='{hero.name}'" in result.stdout assert f"secret_name='{hero.secret_name}'" in result.stdout assert httpx.get.call_count == 1 + assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 0 def test_cli_get_404(mocker): @@ -191,6 +193,8 @@ def test_cli_get_404(mocker): assert result.exit_code == 1 assert " ".join(result.exception.args[0].split()) == "404: Hero not found" assert httpx.get.call_count == 1 + assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 0 def test_cli_list(mocker): @@ -226,6 +230,7 @@ def test_model_post(mocker): assert result == hero assert httpx.get.call_count == 0 assert httpx.post.call_count == 1 + assert httpx.delete.call_count == 0 def test_model_post_500(mocker): @@ -240,9 +245,10 @@ def test_model_post_500(mocker): hero_create.post() assert httpx.get.call_count == 0 assert httpx.post.call_count == 1 + assert httpx.delete.call_count == 0 -def test_model_read_hero(mocker, session: Session, client: TestClient): +def test_model_read_hero(mocker): hero = HeroFactory().build(name="Steelman", age=25, id=1) httpx = mocker.patch.object(hero_models, "httpx") @@ -255,9 +261,10 @@ def test_model_read_hero(mocker, session: Session, client: TestClient): assert hero_read.secret_name == hero.secret_name assert httpx.get.call_count == 1 assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 0 -def test_model_read_hero_404(mocker, session: Session, client: TestClient): +def test_model_read_hero_404(mocker): hero = HeroFactory().build(name="Steelman", age=25, id=1) httpx = mocker.patch.object(hero_models, "httpx") httpx.get.return_value = mocker.Mock() @@ -269,3 +276,68 @@ def test_model_read_hero_404(mocker, session: Session, client: TestClient): assert e.value.args[0] == "404: Hero not found" assert httpx.get.call_count == 1 assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 0 + + +def test_model_delete_hero(mocker): + hero = HeroFactory().build(name="Steelman", age=25, id=1) + + httpx = mocker.patch.object(hero_models, "httpx") + httpx.delete.return_value = mocker.Mock() + httpx.delete.return_value.status_code = 200 + httpx.delete.return_value.json.return_value = hero.dict() + + hero_delete = HeroDelete.delete(id=hero.id) + assert hero_delete == {"ok": True} + assert httpx.get.call_count == 0 + assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 1 + + +def test_model_delete_hero_404(mocker): + hero = HeroFactory().build(name="Steelman", age=25, id=1) + + httpx = mocker.patch.object(hero_models, "httpx") + httpx.delete.return_value = mocker.Mock() + httpx.delete.return_value.status_code = 404 + httpx.get.return_value.text = "Hero not found" + + with pytest.raises(RuntimeError) as e: + HeroDelete.delete(id=hero.id) + assert e.value.args[0] == "404: Hero not found" + assert httpx.get.call_count == 0 + assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 1 + + +def test_cli_delete_hero(mocker): + hero = HeroFactory().build(name="Steelman", age=25, id=1) + + httpx = mocker.patch.object(hero_models, "httpx") + httpx.delete.return_value = mocker.Mock() + httpx.delete.return_value.status_code = 200 + httpx.delete.return_value.json.return_value = hero.dict() + + result = runner.invoke(hero_app, ["delete", "--hero-id", "1"]) + assert result.exit_code == 0 + assert "{'ok': True}" in result.stdout + assert httpx.get.call_count == 0 + assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 1 + + +def test_cli_delete_hero_404(mocker): + hero = HeroFactory().build(name="Steelman", age=25, id=1) + + httpx = mocker.patch.object(hero_models, "httpx") + httpx.delete.return_value = mocker.Mock() + httpx.delete.return_value.status_code = 404 + httpx.delete.return_value.text = "Hero not found" + httpx.delete.return_value.json.return_value = hero.dict() + + result = runner.invoke(hero_app, ["delete", "--hero-id", "999"]) + assert result.exit_code == 1 + assert " ".join(result.exception.args[0].split()) == "404: Hero not found" + assert httpx.get.call_count == 0 + assert httpx.post.call_count == 0 + assert httpx.delete.call_count == 1 diff --git a/tmp.py b/tmp.py new file mode 100644 index 0000000..a4947c0 --- /dev/null +++ b/tmp.py @@ -0,0 +1,144 @@ +import sqlite3 + +from graphviz import Digraph + + +def generate_er_diagram(database_path, output_path): + # Connect to the SQLite database + conn = sqlite3.connect(database_path) + cursor = conn.cursor() + + # Get the table names from the database + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = cursor.fetchall() + + # Create a new Digraph + dot = Digraph(format="png") + dot.attr(rankdir="TD") + + # Iterate over the tables + for table in tables: + table_name = table[0] + dot.node(table_name, shape="box") + cursor.execute(f"PRAGMA table_info({table_name});") + columns = cursor.fetchall() + + # Add the columns to the table node + for column in columns: + column_name = column[1] + dot.node(f"{table_name}.{column_name}", label=column_name, shape="oval") + dot.edge(table_name, f"{table_name}.{column_name}") + + # Check for foreign key relationships + cursor.execute(f"PRAGMA foreign_key_list({table_name});") + foreign_keys = cursor.fetchall() + + # Add dotted lines for foreign key relationships + for foreign_key in foreign_keys: + from_column = foreign_key[3] + to_table = foreign_key[2] + to_column = foreign_key[4] + dot.node(f"{to_table}.{to_column}", shape="oval") + dot.edge( + f"{table_name}.{from_column}", f"{to_table}.{to_column}", style="dotted" + ) + + # Render and save the diagram + dot.render(output_path.replace(".png", ""), cleanup=True) + + # Close the database connection + cursor.close() + conn.close() + + +def generate_markdown(database_path, output_path, er_diagram_path): + # Connect to the SQLite database + conn = sqlite3.connect(database_path) + cursor = conn.cursor() + + # Get the table names from the database + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = cursor.fetchall() + + with open(output_path, "w") as f: + # Write the ER Diagram image + f.write(f"![ER Diagram]({er_diagram_path})\n\n---\n\n") + + # Iterate over the tables + for table in tables: + table_name = table[0] + + f.write(f"## Table: {table_name}\n\n") + + # Get the table columns + cursor.execute(f"PRAGMA table_info({table_name});") + columns = cursor.fetchall() + + f.write("### First 5 rows\n\n") + cursor.execute(f"SELECT * FROM {table_name} LIMIT 5;") + rows = cursor.fetchall() + f.write(f'| {" | ".join([c[1] for c in columns])} |\n') + f.write("|") + for column in columns: + # --- + f.write(f'{"-"*(len(column[1]) + 2)}|') + f.write("\n") + for row in rows: + f.write(f'| {" | ".join([str(r) for r in row])} |\n') + f.write("\n") + + cursor.execute(f"PRAGMA foreign_key_list({table_name});") + foreign_keys = cursor.fetchall() + + # Add dotted lines for foreign key relationships + fkeys = {} + for foreign_key in foreign_keys: + from_column = foreign_key[3] + to_table = foreign_key[2] + to_column = foreign_key[4] + fkeys[from_column] = f"{to_table}.{to_column}" + + # Replace 'description' with the actual column name in the table that contains the description, if applicable + try: + cursor.execute(f"SELECT description FROM {table_name} LIMIT 1;") + description = cursor.fetchone() + if description: + f.write(f"### Description\n\n{description[0]}\n\n") + except: + ... + + # Write the table columns + f.write("### Columns\n\n") + f.write("| Column Name | Type | Foreign Key | Example Value |\n") + f.write("|-------------|------|-------------|---------------|\n") + + for column in columns: + + column_name = column[1] + column_type = column[2] + fkey = "" + if column_name in fkeys: + fkey = fkeys[column_name] + f.write(f"| {column_name} | {column_type} | {fkey} | | |\n") + + f.write("\n") + + # Get the count of records + cursor.execute(f"SELECT COUNT(*) FROM {table_name};") + records_count = cursor.fetchone()[0] + f.write( + f"### Records Count\n\nThe table {table_name} contains {records_count} records.\n\n---\n\n" + ) + + # Close the database connection + cursor.close() + conn.close() + + +# Usage example +database_path = "database.db" +md_output_path = "database.md" +er_output_path = "er_diagram.png" + +generate_er_diagram(database_path, er_output_path) +generate_markdown(database_path, md_output_path, er_output_path)