diff --git a/learn_sql_model/api/hero.py b/learn_sql_model/api/hero.py index ec1ed24..deca520 100644 --- a/learn_sql_model/api/hero.py +++ b/learn_sql_model/api/hero.py @@ -4,15 +4,15 @@ from fastapi import APIRouter, Depends from sqlmodel import SQLModel from learn_sql_model.api.user import oauth2_scheme -from learn_sql_model.config import get_config +from learn_sql_model.config import Config, get_config from learn_sql_model.models.hero import Hero hero_router = APIRouter() @hero_router.on_event("startup") -def on_startup() -> None: - SQLModel.metadata.create_all(get_config().database.engine) +def on_startup(config: Config = Depends(get_config)) -> None: + SQLModel.metadata.create_all(config.database.engine) @hero_router.get("/items/") @@ -21,21 +21,22 @@ async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): @hero_router.get("/hero/{id}") -def get_hero(id: int) -> Hero: +def get_hero(id: int, config: Config = Depends(get_config)) -> Hero: "get one hero" - return Hero().get(id=id) + return Hero().get(id=id, config=config) @hero_router.post("/hero/") -def post_hero(hero: Hero) -> Hero: +def post_hero(hero: Hero, config: Config = Depends(get_config)) -> Hero: "read all the heros" - return hero.post() + hero.post(config=config) + return hero @hero_router.get("/heros/") -def get_heros() -> list[Hero]: +def get_heros(config: Config = Depends(get_config)) -> list[Hero]: "get all heros" - return Hero().get() + return Hero().get(config=config) # Alternatively # with get_config().database.session as session: # statement = select(Hero) diff --git a/learn_sql_model/cli/app.py b/learn_sql_model/cli/app.py index a367ab5..4224ba8 100644 --- a/learn_sql_model/cli/app.py +++ b/learn_sql_model/cli/app.py @@ -10,12 +10,12 @@ from learn_sql_model.cli.tui import tui_app app = typer.Typer( name="learn_sql_model", - help="A rich terminal report for coveragepy.", + help="learn-sql-model cli for managing the project", ) -app.add_typer(config_app) -app.add_typer(tui_app) -app.add_typer(model_app) -app.add_typer(api_app) +app.add_typer(config_app, name="config") +app.add_typer(tui_app, name="tui") +app.add_typer(model_app, name="model") +app.add_typer(api_app, name="api") app.add_typer(hero_app, name="hero") @@ -38,6 +38,17 @@ def version_callback(value: bool) -> None: raise typer.Exit() +@app.callback() +def main( + version: bool = typer.Option( + False, + callback=version_callback, + help="show the version of the learn-sql-model package.", + ), +): + "configuration cli" + + @app.command() def tui(ctx: typer.Context) -> None: Trogon(get_group(app), click_context=ctx).run() diff --git a/learn_sql_model/cli/hero.py b/learn_sql_model/cli/hero.py index b019494..f1fdb91 100644 --- a/learn_sql_model/cli/hero.py +++ b/learn_sql_model/cli/hero.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Union from pydantic_typer import expand_pydantic_args from rich.console import Console @@ -9,6 +9,7 @@ from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.pet import PetFactory from learn_sql_model.models.hero import Hero from learn_sql_model.models.pet import Pet +import sys hero_app = typer.Typer() @@ -21,7 +22,7 @@ def hero(): @hero_app.command() @expand_pydantic_args(typer=True) def get( - id: int = None, + id: Optional[int] = None, config: Config = None, ) -> Union[Hero, List[Hero]]: "get one hero" @@ -52,12 +53,11 @@ def populate( config: Config = None, ) -> Hero: "read all the heros" - config.init() if config is None: config = Config() if config.env == "prod": Console().print("populate is not supported in production") - return + sys.exit(1) for hero in HeroFactory().batch(n): pet = PetFactory().build() diff --git a/learn_sql_model/config.py b/learn_sql_model/config.py index 7b5c47b..46d6eaa 100644 --- a/learn_sql_model/config.py +++ b/learn_sql_model/config.py @@ -1,5 +1,7 @@ +from contextvars import ContextVar from typing import TYPE_CHECKING +from fastapi import Depends from pydantic import BaseModel, BaseSettings from sqlalchemy import create_engine from sqlmodel import SQLModel, Session @@ -24,6 +26,13 @@ class Database: self.config = get_config() else: self.config = config + self.db_state_default = { + "closed": None, + "conn": None, + "ctx": None, + "transactions": None, + } + self.db_state = ContextVar("db_state", default=self.db_state_default.copy()) @property def engine(self) -> "Engine": @@ -60,6 +69,24 @@ def get_database(config: Config = None) -> Database: return Database(config) +async def reset_db_state(config: Config = None) -> None: + if config is None: + config = get_config() + config.database.db._state._state.set(db_state_default.copy()) + config.database.db._state.reset() + + +def get_db(config: Config = None, reset_db_state=Depends(reset_db_state)): + if config is None: + config = get_config() + try: + config.database.db.connect() + yield + finally: + if not config.database.db.is_closed(): + config.database.db.close() + + def get_config(overrides: dict = {}) -> Config: raw_config = load("learn_sql_model") config = Config(**raw_config, **overrides) diff --git a/learn_sql_model/models.py b/learn_sql_model/models.py deleted file mode 100644 index a919bb7..0000000 --- a/learn_sql_model/models.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from learn_sql_model.models.fast_model import FastModel -from sqlmodel import Field - - -class Hero(FastModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str - secret_name: str - age: Optional[int] = None - # new_attribute: Optional[str] = None - # pets: List["Pet"] = Relationship(back_populates="hero") diff --git a/learn_sql_model/models/fast_model.py b/learn_sql_model/models/fast_model.py index 380f8fc..eab3e30 100644 --- a/learn_sql_model/models/fast_model.py +++ b/learn_sql_model/models/fast_model.py @@ -29,6 +29,8 @@ class FastModel(SQLModel): with config.database.session as session: session.add(self) session.commit() + session.refresh(self) + return def get( self, id: int = None, config: "Config" = None, where=None @@ -52,6 +54,16 @@ class FastModel(SQLModel): results = session.exec(statement).one() return results + def flags(self, config: "Config" = None) -> None: + if config is None: + config = get_config() + flags = [] + for k, v in self.dict().items(): + if v: + flags.append(f"--{k.replace('_', '-').lower()}") + flags.append(v) + return flags + # TODO # update # delete diff --git a/markata.toml b/markata.toml new file mode 100644 index 0000000..6a96bf9 --- /dev/null +++ b/markata.toml @@ -0,0 +1,275 @@ +# +# __ __ _ _ _ _ +# | \/ | __ _ _ __| | ____ _| |_ __ _ | |_ ___ _ __ ___ | | +# | |\/| |/ _` | '__| |/ / _` | __/ _` || __/ _ \| '_ ` _ \| | +# | | | | (_| | | | < (_| | || (_| || || (_) | | | | | | | +# |_| |_|\__,_|_| |_|\_\__,_|\__\__,_(_)__\___/|_| |_| |_|_| +# +# learn-sql-model.dev + +[markata.nav] +'learn-sql-model'='https://learn-sql-model.dev/' +'GitHub'='https://github.com/WaylonWalker/learn-sql-model' + +[markata] +# bump site version to bust GitHub actions cache +site_version = 13 + +## choose your markdown backend +# markdown_backend='markdown' +# markdown_backend='markdown2' +markdown_backend='markdown-it-py' + +# 2 weeks in seconds +default_cache_expire = 1209600 +# subroute = "docs" + +## Markata Setup +output_dir = "markout" +assets_dir = "static" +hooks = [ + "markata.plugins.publish_source", + "markata.plugins.docs", + "default", + ] +disabled_hooks = [ +'markata.plugins.heading_link', +'markata.plugins.manifest', +'markata.plugins.rss' +] + +## Site Config +url = "https://learn-sql-model.dev" +title = "Learn SQLModel's Docs" +description = "Documentation for using the Learn SQLModel" +rss_description = "Learn SQLModel docs" +author_name = "Waylon Walker" +author_email = "waylon@waylonwalaker.com" +icon = "favicon.ico" +lang = "en" +# post_template = "pages/templates/post_template.html" +repo_url = "https://github.com/waylonwalker/learn-sql-model" +repo_branch = "main" +theme_color = "#322D39" +background_color = "#B73CF6" +start_url = "/" +site_name = "Learn SQLModel's Docs" +short_name = "ww" +display = "minimal-ui" +twitter_card = "summary_large_image" +twitter_creator = "@_waylonwalker" +twitter_site = "@_waylonwalker" + +# markdown_it flavor +# [markata.markdown_it_py] +# config='gfm-like' +# # markdown_it built-in plugins +# enable = [ "table" ] +# disable = [ "image" ] + +# # markdown_it built-in plugin options +# [markata.markdown_it_py.options_update] +# linkify = true +# html = true +# typographer = true +# highlight = 'markata.plugins.md_it_highlight_code:highlight_code' + +# # add custom markdown_it plugins +# [[markata.markdown_it_py.plugins]] +# plugin = "mdit_py_plugins.admon:admon_plugin" + +# [[markata.markdown_it_py.plugins]] +# plugin = "mdit_py_plugins.admon:admon_plugin" + +# [[markata.markdown_it_py.plugins]] +# plugin = "mdit_py_plugins.attrs:attrs_plugin" +# config = {spans = true} + +# [[markata.markdown_it_py.plugins]] +# plugin = "mdit_py_plugins.attrs:attrs_block_plugin" + +# [[markata.markdown_it_py.plugins]] +# plugin = "markata.plugins.mdit_details:details_plugin" + +# [[markata.markdown_it_py.plugins]] +# plugin = "mdit_py_plugins.anchors:anchors_plugin" + +# [markata.markdown_it_py.plugins.config] +# permalink = true +# permalinkSymbol = '' + +# [[markata.markdown_it_py.plugins]] +# plugin = "markata.plugins.md_it_wikilinks:wikilinks_plugin" +# config = {markata = "markata"} + +# markata feeds +# creating pages of posts +# [markata.feeds_config] + +## feed template +# [markata.feeds.] +# title="Project Gallery" +## python eval to True adds post to the feed +# filter="'project-gallery' in path" +## the key to sort on +# sort='title' +## the template for each post to use when added to the page +# card_template=""" +# """ + +[[markata.feeds]] +slug='project-gallery' +title="Project Gallery" +filter="'project-gallery' in str(path)" +sort='title' +card_template=""" +
  • +

    {{ title }}

    + +{{ article_html }} +
  • +""" + +[[markata.feeds]] +slug='docs' +title="Documentation" +filter='"markata" not in slug and "tests" not in slug and "404" not in slug' +sort='slug' +card_template="
  • {{ title }}

    {{ description }}

  • " + +[[markata.feeds]] +slug='all' +title="All Learn SQLModel Modules" +filter="True" +card_template=""" +
  • + + {{ title }} +

    + {{ article_html[:article_html.find('

    ')] }} +

    + +
  • +""" + +[[markata.feeds]] +slug='core-modules' +title="Learn SQLModel Core Modules" +filter="'plugin' not in slug and 'test' not in slug and title.endswith('.py')" +card_template=""" +
  • + + {{ title }} +

    + {{ article_html[:article_html.find('

    ')] }} +

    + +
  • +""" + + +[markata.jinja_md] +ignore=[ +'jinja_md.md', +'post_template.md', +'publish_html.md', +] + +[[markata.head.meta]] +name = "og:author_email" +content = "waylon@waylonwalker.com" + +[markata.tui] +new_cmd=['tmux', 'popup', 'markata', 'new', 'post'] + +[[markata.tui.keymap]] +name='new' +key='n' + +[markata.summary] +grid_attr = ['tags', 'series'] + +[[markata.summary.filter_count]] +name='drafts' +filter="not published" +color='red' + +[[markata.summary.filter_count]] +name='articles' +color='dark_orange' + +[[markata.summary.filter_count]] +name='py_modules' +filter='"plugin" not in slug and "docs" not in str(path)' +color="yellow1" + +[[markata.summary.filter_count]] +name='published' +filter="published" +color='green1' + +[[markata.summary.filter_count]] +name='plugins' +filter='"plugin" in slug and "docs" not in str(path)' +color="blue" + +[[markata.summary.filter_count]] +name='docs' +filter="'docs' in str(path)" +color='purple' + +[markata.post_model] +include = ['date', 'description', 'published', 'slug', 'title', 'content', 'html'] +repr_include = ['date', 'description', 'published', 'slug', 'title', 'output_html'] + +[markata.render_markdown] +backend='markdown-it-py' + +# [markata.markdown_it_py] +# config='gfm-like' +# # markdown_it built-in plugins +# enable = [ "table" ] +# disable = [ "image" ] + +# # markdown_it built-in plugin options +# [markata.markdown_it_py.options_update] +# linkify = true +# html = true +# typographer = true +# highlight = 'markata.plugins.md_it_highlight_code:highlight_code' + +# add custom markdown_it plugins +[[markata.render_markdown.md_it_extensions]] +plugin = "mdit_py_plugins.admon:admon_plugin" + +[[markata.render_markdown.md_it_extensions]] +plugin = "mdit_py_plugins.admon:admon_plugin" + +[[markata.render_markdown.md_it_extensions]] +plugin = "mdit_py_plugins.attrs:attrs_plugin" +config = {spans = true} + +[[markata.render_markdown.md_it_extensions]] +plugin = "mdit_py_plugins.attrs:attrs_block_plugin" + +[[markata.render_markdown.md_it_extensions]] +plugin = "markata.plugins.mdit_details:details_plugin" + +[[markata.render_markdown.md_it_extensions]] +plugin = "mdit_py_plugins.anchors:anchors_plugin" + +[markata.render_markdown.md_it_extensions.config] +permalink = true +permalinkSymbol = '' + +[[markata.render_markdown.md_it_extensions]] +plugin = "markata.plugins.md_it_wikilinks:wikilinks_plugin" +config = {markata = "markata"} + +[markata.glob] +glob_patterns = "docs/**/*.md,CHANGELOG.md" +use_gitignore = true diff --git a/pyproject.toml b/pyproject.toml index d5c9a52..26de1d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,17 +57,20 @@ path = "learn_sql_model/__about__.py" dependencies = [ "black", "ipython", + "coverage[toml]", + "coverage-rich", + "markata", "mypy", "pyflyby", "pytest", - "pytest-cov", "pytest-mock", "ruff", "alembic", ] [tool.hatch.envs.default.scripts] test = "coverage run -m pytest" -cov = "coverage-rich" +cov = "coverage-rich report" +test-cov = ['test', 'cov'] lint = "ruff learn_sql_model" format = "black learn_sql_model" format-check = "black --check learn_sql_model" @@ -84,6 +87,7 @@ test-lint = "lint-test" python = ["37", "38", "39", "310", "311"] [tool.coverage.run] +source=["learn_sql_model"] branch = true parallel = true omit = [ diff --git a/tests/test_cli_app.py b/tests/test_cli_app.py new file mode 100644 index 0000000..a89a373 --- /dev/null +++ b/tests/test_cli_app.py @@ -0,0 +1,15 @@ +from typer.testing import CliRunner + +from learn_sql_model.cli.app import app + +runner = CliRunner() + + +def test_cli_app_version(): + result = runner.invoke(app, ["--version"]) + assert result.exit_code == 0 + + +def test_cli_help(): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 diff --git a/tests/test_console.py b/tests/test_console.py new file mode 100644 index 0000000..14274bd --- /dev/null +++ b/tests/test_console.py @@ -0,0 +1,14 @@ +from learn_sql_model.console import console + + +def test_default_console_not_quiet(capsys): + console.print("hello") + captured = capsys.readouterr() + assert captured.out == "hello\n" + + +def test_default_console_is_quiet(capsys): + console.quiet = True + console.print("hello") + captured = capsys.readouterr() + assert captured.out == "" diff --git a/tests/test_hero.py b/tests/test_hero.py index 976e8be..7eb42af 100644 --- a/tests/test_hero.py +++ b/tests/test_hero.py @@ -2,6 +2,9 @@ import tempfile from fastapi.testclient import TestClient import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import SQLModel from typer.testing import CliRunner from learn_sql_model.api.app import app @@ -18,11 +21,31 @@ client = TestClient(app) def config() -> Config: tmp_db = tempfile.NamedTemporaryFile(suffix=".db") config = get_config({"database_url": f"sqlite:///{tmp_db.name}"}) - return config + + engine = create_engine( + config.database_url, connect_args={"check_same_thread": False} + ) + TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + # breakpoint() + SQLModel.metadata.create_all(config.database.engine) + + # def override_get_db(): + # try: + # db = TestingSessionLocal() + # yield db + # finally: + # db.close() + def override_get_config(): + return config + + app.dependency_overrides[get_config] = override_get_config + + yield config + # tmp_db automatically deletes here def test_post_hero(config: Config) -> None: - config.init() # required for python api, and no existing db hero = HeroFactory().build(name="Batman", age=50, id=1) hero = hero.post(config=config) db_hero = Hero().get(id=1, config=config) @@ -31,48 +54,105 @@ def test_post_hero(config: Config) -> None: def test_update_hero(config: Config) -> None: - config.init() # required for python api, and no existing db hero = HeroFactory().build(name="Batman", age=50, id=1) hero = hero.post(config=config) db_hero = Hero().get(id=1, config=config) - db_hero.name = "Superman" + db_hero.name = "Superbman" hero = db_hero.post(config=config) db_hero = Hero().get(id=1, config=config) assert db_hero.age == 50 - assert db_hero.name == "Superman" + assert db_hero.name == "Superbman" + + +def test_cli_get(config): + hero = HeroFactory().build(name="Steelman", age=25, id=99) + hero.post(config=config) + result = runner.invoke( + hero_app, + ["get", "--id", 99, "--database-url", config.database_url], + ) + assert result.exit_code == 0 + db_hero = Hero().get(id=99, config=config) + assert db_hero.age == 25 + assert db_hero.name == "Steelman" def test_cli_create(config): + hero = HeroFactory().build(name="Steelman", age=25, id=99) result = runner.invoke( hero_app, [ "create", - "--name", - "Darth Vader", - "--secret-name", - "Anakin", - "--id", - "2", - "--age", - "100", + *hero.flags(config=config), "--database-url", config.database_url, ], ) assert result.exit_code == 0 - db_hero = Hero().get(id=2, config=config) - assert db_hero.age == 100 - assert db_hero.name == "Darth Vader" + db_hero = Hero().get(id=99, config=config) + assert db_hero.age == 25 + assert db_hero.name == "Steelman" -def test_read_main(config): - config.init() - hero = HeroFactory().build(name="Ironman", age=25, id=99) +def test_cli_populate(config): + result = runner.invoke( + hero_app, + [ + "populate", + "--n", + 10, + "--database-url", + config.database_url, + ], + ) + assert result.exit_code == 0 + db_hero = Hero().get(config=config) + assert len(db_hero) == 10 + + +def test_cli_populate_fails_prod(config): + result = runner.invoke( + hero_app, + ["populate", "--n", 10, "--database-url", config.database_url, "--env", "prod"], + ) + assert result.exit_code == 1 + assert result.output.strip() == "populate is not supported in production" + + +def test_api_read(config): + hero = HeroFactory().build(name="Steelman", age=25, id=99) hero_id = hero.id hero = hero.post(config=config) response = client.get(f"/hero/{hero_id}") assert response.status_code == 200 reponse_hero = Hero.parse_obj(response.json()) assert reponse_hero.id == hero_id - assert reponse_hero.name == "Ironman" + assert reponse_hero.name == "Steelman" assert reponse_hero.age == 25 + + +def test_api_post(config): + hero = HeroFactory().build(name="Steelman", age=25) + hero_dict = hero.dict() + response = client.post("/hero/", json={"hero": hero_dict}) + assert response.status_code == 200 + + response_hero = Hero.parse_obj(response.json()) + db_hero = Hero().get(id=response_hero.id, config=config) + + assert db_hero.name == "Steelman" + assert db_hero.age == 25 + + +def test_api_read_all(config): + hero = HeroFactory().build(name="Mothman", age=25, id=99) + hero_id = hero.id + hero = hero.post(config=config) + response = client.get("/heros/") + assert response.status_code == 200 + heros = response.json() + response_hero_json = [hero for hero in heros if hero["id"] == hero_id][0] + response_hero = Hero.parse_obj(response_hero_json) + assert response_hero.id == hero_id + assert response_hero.name == "Mothman" + assert response_hero.age == 25