fix all the tests

This commit is contained in:
Waylon Walker 2023-06-21 10:29:50 -05:00
parent 7db07c7d35
commit a7e6f2c4e5
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
9 changed files with 128 additions and 137 deletions

View file

@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import SQLModel, Session from sqlmodel import Session, select
from learn_sql_model.api.websocket_connection_manager import manager from learn_sql_model.api.websocket_connection_manager import manager
from learn_sql_model.config import get_config, get_session from learn_sql_model.config import get_session
from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, HeroUpdate, Heros from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, HeroUpdate, Heros
hero_router = APIRouter() hero_router = APIRouter()
@ -10,7 +10,8 @@ hero_router = APIRouter()
@hero_router.on_event("startup") @hero_router.on_event("startup")
def on_startup() -> None: def on_startup() -> None:
SQLModel.metadata.create_all(get_config().database.engine) # SQLModel.metadata.create_all(get_config().database.engine)
...
@hero_router.get("/hero/{hero_id}") @hero_router.get("/hero/{hero_id}")
@ -32,7 +33,7 @@ async def post_hero(
session: Session = Depends(get_session), session: Session = Depends(get_session),
hero: HeroCreate, hero: HeroCreate,
) -> HeroRead: ) -> HeroRead:
"read all the heros" "create a hero"
db_hero = Hero.from_orm(hero) db_hero = Hero.from_orm(hero)
session.add(db_hero) session.add(db_hero)
session.commit() session.commit()
@ -47,7 +48,7 @@ async def patch_hero(
session: Session = Depends(get_session), session: Session = Depends(get_session),
hero: HeroUpdate, hero: HeroUpdate,
) -> HeroRead: ) -> HeroRead:
"read all the heros" "update a hero"
db_hero = session.get(Hero, hero.id) db_hero = session.get(Hero, hero.id)
if not db_hero: if not db_hero:
raise HTTPException(status_code=404, detail="Hero not found") raise HTTPException(status_code=404, detail="Hero not found")
@ -66,7 +67,7 @@ async def delete_hero(
session: Session = Depends(get_session), session: Session = Depends(get_session),
hero_id: int, hero_id: int,
): ):
"read all the heros" "delete a hero"
hero = session.get(Hero, hero_id) hero = session.get(Hero, hero_id)
if not hero: if not hero:
raise HTTPException(status_code=404, detail="Hero not found") raise HTTPException(status_code=404, detail="Hero not found")
@ -84,4 +85,4 @@ async def get_heros(
"get all heros" "get all heros"
statement = select(Hero) statement = select(Hero)
heros = session.exec(statement).all() heros = session.exec(statement).all()
return Heros(heros=heros) return Heros(__root__=heros)

View file

@ -33,7 +33,6 @@ def hero():
@hero_app.command() @hero_app.command()
@engorgio(typer=True)
def get( def get(
hero_id: Optional[int] = typer.Argument(default=None), hero_id: Optional[int] = typer.Argument(default=None),
) -> Union[Hero, List[Hero]]: ) -> Union[Hero, List[Hero]]:
@ -44,16 +43,11 @@ def get(
@hero_app.command() @hero_app.command()
@engorgio(typer=True) def list() -> Union[Hero, List[Hero]]:
def list(
where: Optional[str] = None,
offset: int = 0,
limit: Optional[int] = None,
) -> Union[Hero, List[Hero]]:
"list many heros" "list many heros"
heros = Heros.list(where=where, offset=offset, limit=limit) heros = Heros.list()
Console().print(hero) Console().print(heros)
return hero return heros
@hero_app.command() @hero_app.command()
@ -94,7 +88,6 @@ def delete(
@hero_app.command() @hero_app.command()
@engorgio(typer=True)
def populate( def populate(
n: int = 10, n: int = 10,
) -> Hero: ) -> Hero:

View file

@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
from fastapi import Depends from fastapi import Depends
from pydantic import BaseModel, BaseSettings, validator from pydantic import BaseModel, BaseSettings, validator
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlmodel import SQLModel, Session from sqlmodel import Session
from learn_sql_model.standard_config import load from learn_sql_model.standard_config import load
@ -71,7 +71,8 @@ class Config(BaseSettings):
return get_database(config=self) return get_database(config=self)
def init(self) -> None: def init(self) -> None:
SQLModel.metadata.create_all(self.database.engine) # SQLModel.metadata.create_all(self.database.engine)
...
def get_database(config: Config = None) -> Database: def get_database(config: Config = None) -> Database:
@ -88,7 +89,8 @@ def get_config(overrides: dict = {}) -> Config:
def get_session() -> "Session": def get_session() -> "Session":
config = get_config() config = get_config()
with Session(config.database.engine) as session: engine = create_engine(config.database_url)
with Session(engine) as session:
yield session yield session

View file

@ -1,23 +1,25 @@
from learn_sql_model.console import console
from learn_sql_model.optional import _optional_import_ from learn_sql_model.optional import _optional_import_
pygame = _optional_import_("pygame", group="game") pygame = _optional_import_("pygame", group="game")
class Light: class Light:
def __init__(self, game): def __init__(self, game):
self.game = game self.game = game
def render(self): def render(self):
mx, my = pygame.mouse.get_pos() mx, my = pygame.mouse.get_pos()
v = pygame.math.Vector2(mx - self.game.player.hero.x, my - self.game.player.hero.y) v = pygame.math.Vector2(
mx - self.game.player.hero.x, my - self.game.player.hero.y
)
v.scale_to_length(1000) v.scale_to_length(1000)
for r in range(0, 360): for r in range(0, 360):
_v = v.rotate(r) _v = v.rotate(r)
pygame.draw.line( pygame.draw.line(
self.game.screen, self.game.screen,
(255,250,205), (255, 250, 205),
(self.game.player.hero.x, self.game.player.hero.y), (self.game.player.hero.x, self.game.player.hero.y),
(self.game.player.hero.x + _v.x, self.game.player.hero.y + _v.y), (self.game.player.hero.x + _v.x, self.game.player.hero.y + _v.y),
50 50,
) )

View file

@ -2,7 +2,8 @@ from typing import Callable, Tuple
from pydantic import BaseModel from pydantic import BaseModel
from learn_sql_model.optional import _optional_import_ from learn_sql_model.optional import _optional_import_
pygame = _optional_import_('pygame', group='game')
pygame = _optional_import_("pygame", group="game")
screen_sizes = [ screen_sizes = [
@ -128,7 +129,6 @@ class Menu:
class Hamburger: class Hamburger:
def __init__(self, game): def __init__(self, game):
self.game = game self.game = game
self.hamburger_width = 50 self.hamburger_width = 50
self.bar_height = self.hamburger_width / 4 self.bar_height = self.hamburger_width / 4

View file

@ -49,7 +49,7 @@ class Player:
def quit(self): def quit(self):
try: try:
HeroDelete(id=self.hero.id).delete() HeroDelete(id=self.hero.id).delete()
except: except RuntimeError:
pass pass
def handle_events(self): def handle_events(self):
@ -153,13 +153,12 @@ class Player:
self.pos = pygame.math.Vector2(self.hero.x, self.hero.y) self.pos = pygame.math.Vector2(self.hero.x, self.hero.y)
if self.game.map.point_check_collision(self.pos.x, self.pos.y): if self.game.map.point_check_collision(self.pos.x, self.pos.y):
start_pos = pygame.math.Vector2(self.x_last, self.y_last) start_pos = pygame.math.Vector2(self.x_last, self.y_last)
end_pos = pygame.math.Vector2(self.hero.x, self.hero.y) end_pos = pygame.math.Vector2(self.hero.x, self.hero.y)
movement_vector = end_pos - start_pos movement_vector = end_pos - start_pos
try: try:
movement_direction = movement_vector.normalize() movement_direction = movement_vector.normalize()
except: except ZeroDivisionError:
end_pos = pygame.math.Vector2(self.hero.x + 128, self.hero.y + 128) end_pos = pygame.math.Vector2(self.hero.x + 128, self.hero.y + 128)
movement_vector = end_pos - start_pos movement_vector = end_pos - start_pos
movement_direction = movement_vector.normalize() movement_direction = movement_vector.normalize()

View file

@ -1,6 +1,5 @@
from typing import Optional from typing import Optional
from fastapi import HTTPException
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Field, Relationship, SQLModel from sqlmodel import Field, Relationship, SQLModel
@ -51,11 +50,11 @@ class HeroRead(HeroBase):
r = httpx.get(f"{config.api_client.url}/hero/{id}") r = httpx.get(f"{config.api_client.url}/hero/{id}")
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}") raise RuntimeError(f"{r.status_code}:\n {r.text}")
return hero return HeroRead.parse_obj(r.json())
class Heros(BaseModel): class Heros(BaseModel):
heros: list[Hero] __root__: list[Hero]
@classmethod @classmethod
def list( def list(
@ -64,7 +63,7 @@ class Heros(BaseModel):
r = httpx.get(f"{config.api_client.url}/heros/") r = httpx.get(f"{config.api_client.url}/heros/")
if r.status_code != 200: if r.status_code != 200:
raise RuntimeError(f"{r.status_code}:\n {r.text}") raise RuntimeError(f"{r.status_code}:\n {r.text}")
return Heros.parse_obj(r.json()) return Heros.parse_obj({"__root__": r.json()})
class HeroUpdate(SQLModel): class HeroUpdate(SQLModel):
@ -76,8 +75,8 @@ class HeroUpdate(SQLModel):
secret_name: Optional[str] = None secret_name: Optional[str] = None
age: Optional[int] = None age: Optional[int] = None
shoe_size: Optional[int] = None shoe_size: Optional[int] = None
x: int x: Optional[int]
y: int y: Optional[int]
pet_id: Optional[int] = Field(default=None, foreign_key="pet.id") pet_id: Optional[int] = Field(default=None, foreign_key="pet.id")
pet: Optional[Pet] = Relationship(back_populates="hero") pet: Optional[Pet] = Relationship(back_populates="hero")

View file

@ -1,14 +1,16 @@
from learn_sql_model.console import console from learn_sql_model.console import console
def test_default_console_not_quiet(capsys):
console.print("hello")
captured = capsys.readouterr()
assert captured.out == "hello\n"
def test_default_console_is_quiet(capsys): def test_default_console_is_quiet(capsys):
console.quiet = True
console.print("hello") console.print("hello")
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == "" assert captured.out == ""
def test_default_console_not_quiet(capsys):
console.quiet = False
console.print("hello")
captured = capsys.readouterr()
assert captured.out == "hello\n"

View file

@ -9,7 +9,8 @@ from learn_sql_model.api.app import app
from learn_sql_model.cli.hero import hero_app from learn_sql_model.cli.hero import hero_app
from learn_sql_model.config import get_config, get_session from learn_sql_model.config import get_config, get_session
from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.hero import HeroFactory
from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead from learn_sql_model.models import hero as hero_models
from learn_sql_model.models.hero import Hero, HeroCreate, HeroRead, Heros
runner = CliRunner() runner = CliRunner()
client = TestClient(app) client = TestClient(app)
@ -40,7 +41,7 @@ def client_fixture(session: Session):
def test_api_post(client: TestClient): def test_api_post(client: TestClient):
hero = HeroFactory().build(name="Steelman", age=25) hero = HeroFactory().build(name="Steelman", age=25)
hero_dict = hero.dict() hero_dict = hero.dict()
response = client.post("/hero/", json={"hero": hero_dict}) response = client.post("/hero/", json=hero_dict)
response_hero = Hero.parse_obj(response.json()) response_hero = Hero.parse_obj(response.json())
assert response.status_code == 200 assert response.status_code == 200
@ -49,8 +50,8 @@ def test_api_post(client: TestClient):
def test_api_read_heroes(session: Session, client: TestClient): def test_api_read_heroes(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_2 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48) hero_2 = HeroFactory().build(name="Rusty-Man", age=48)
session.add(hero_1) session.add(hero_1)
session.add(hero_2) session.add(hero_2)
session.commit() session.commit()
@ -72,7 +73,7 @@ def test_api_read_heroes(session: Session, client: TestClient):
def test_api_read_hero(session: Session, client: TestClient): def test_api_read_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") hero_1 = HeroFactory().build(name="Steelman", age=25)
session.add(hero_1) session.add(hero_1)
session.commit() session.commit()
@ -87,7 +88,7 @@ def test_api_read_hero(session: Session, client: TestClient):
def test_api_read_hero_404(session: Session, client: TestClient): def test_api_read_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") hero_1 = HeroFactory().build(name="Steelman", age=25)
session.add(hero_1) session.add(hero_1)
session.commit() session.commit()
@ -96,33 +97,31 @@ def test_api_read_hero_404(session: Session, client: TestClient):
def test_api_update_hero(session: Session, client: TestClient): def test_api_update_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") hero_1 = HeroFactory().build(name="Steelman", age=25)
session.add(hero_1) session.add(hero_1)
session.commit() session.commit()
response = client.patch( response = client.patch(f"/hero/", json={"name": "Deadpuddle", "id": hero_1.id})
f"/hero/", json={"hero": {"name": "Deadpuddle", "id": hero_1.id}}
)
data = response.json() data = response.json()
assert response.status_code == 200 assert response.status_code == 200
assert data["name"] == "Deadpuddle" assert data["name"] == "Deadpuddle"
assert data["secret_name"] == "Dive Wilson" assert data["secret_name"] == hero_1.secret_name
assert data["age"] is None assert data["age"] is hero_1.age
assert data["id"] == hero_1.id assert data["id"] == hero_1.id
def test_api_update_hero_404(session: Session, client: TestClient): def test_api_update_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") hero_1 = HeroFactory().build(name="Steelman", age=25)
session.add(hero_1) session.add(hero_1)
session.commit() session.commit()
response = client.patch(f"/hero/", json={"hero": {"name": "Deadpuddle", "id": 999}}) response = client.patch(f"/hero/", json={"name": "Deadpuddle", "id": 999})
assert response.status_code == 404 assert response.status_code == 404
def test_delete_hero(session: Session, client: TestClient): def test_delete_hero(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") hero_1 = HeroFactory().build(name="Steelman", age=25)
session.add(hero_1) session.add(hero_1)
session.commit() session.commit()
@ -136,7 +135,7 @@ def test_delete_hero(session: Session, client: TestClient):
def test_delete_hero_404(session: Session, client: TestClient): def test_delete_hero_404(session: Session, client: TestClient):
hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") hero_1 = HeroFactory().build(name="Steelman", age=25)
session.add(hero_1) session.add(hero_1)
session.commit() session.commit()
@ -165,68 +164,48 @@ def test_config_memory(mocker):
def test_cli_get(mocker): def test_cli_get(mocker):
mocker.patch( hero = HeroFactory().build(name="Steelman", age=25, id=1)
"learn_sql_model.config.Database.engine", hero = HeroRead(**hero.dict(exclude_none=True))
new_callable=lambda: create_engine( httpx = mocker.patch.object(hero_models, "httpx")
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool httpx.get.return_value = mocker.Mock()
), httpx.get.return_value.status_code = 200
) httpx.get.return_value.json.return_value = hero.dict()
config = get_config() result = runner.invoke(hero_app, ["get", "1"])
SQLModel.metadata.create_all(config.database.engine)
hero = HeroFactory().build(name="Steelman", age=25)
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "1"])
assert result.exit_code == 0 assert result.exit_code == 0
assert f"name='{hero.name}'" in result.stdout assert f"name='{hero.name}'" in result.stdout
assert f"secret_name='{hero.secret_name}'" in result.stdout assert f"secret_name='{hero.secret_name}'" in result.stdout
assert httpx.get.call_count == 1
def test_cli_get_404(mocker): def test_cli_get_404(mocker):
mocker.patch( hero = HeroFactory().build(name="Steelman", age=25, id=1)
"learn_sql_model.config.Database.engine", hero = HeroRead(**hero.dict(exclude_none=True))
new_callable=lambda: create_engine( httpx = mocker.patch.object(hero_models, "httpx")
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool httpx.get.return_value = mocker.Mock()
), httpx.get.return_value.status_code = 404
) httpx.get.return_value.text = "Hero not found"
httpx.get.return_value.json.return_value = hero.dict()
config = get_config() result = runner.invoke(hero_app, ["get", "999"])
SQLModel.metadata.create_all(config.database.engine) assert result.exit_code == 1
assert " ".join(result.exception.args[0].split()) == "404: Hero not found"
hero = HeroFactory().build(name="Steelman", age=25) assert httpx.get.call_count == 1
with config.database.session as session:
session.add(hero)
session.commit()
hero = session.get(Hero, hero.id)
result = runner.invoke(hero_app, ["get", "--hero-id", "999"])
assert result.exception.status_code == 404
assert result.exception.detail == "Hero not found"
def test_cli_list(mocker): def test_cli_list(mocker):
mocker.patch( hero_1 = HeroRead(
"learn_sql_model.config.Database.engine", **HeroFactory().build(name="Steelman", age=25, id=1).dict(exclude_none=True)
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
) )
hero_2 = HeroRead(
**HeroFactory().build(name="Hunk", age=52, id=2).dict(exclude_none=True)
)
heros = Heros(__root__=[hero_1, hero_2])
httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock()
httpx.get.return_value.status_code = 200
httpx.get.return_value.json.return_value = heros.dict()["__root__"]
config = get_config()
SQLModel.metadata.create_all(config.database.engine)
hero_1 = HeroFactory().build(name="Steelman", age=25)
hero_2 = HeroFactory().build(name="Hunk", age=52)
with config.database.session as session:
session.add(hero_1)
session.add(hero_2)
session.commit()
session.refresh(hero_1)
session.refresh(hero_2)
result = runner.invoke(hero_app, ["list"]) result = runner.invoke(hero_app, ["list"])
assert result.exit_code == 0 assert result.exit_code == 0
assert f"name='{hero_1.name}'" in result.stdout assert f"name='{hero_1.name}'" in result.stdout
@ -236,43 +215,57 @@ def test_cli_list(mocker):
def test_model_post(mocker): def test_model_post(mocker):
patch_httpx_post = mocker.patch( hero = HeroFactory().build(name="Steelman", age=25, id=1)
"httpx.post", return_value=mocker.Mock(status_code=200)
)
hero = HeroFactory().build(name="Steelman", age=25)
hero_create = HeroCreate(**hero.dict()) hero_create = HeroCreate(**hero.dict())
hero_create.post()
assert patch_httpx_post.call_count == 1 httpx = mocker.patch.object(hero_models, "httpx")
httpx.post.return_value = mocker.Mock()
httpx.post.return_value.status_code = 200
httpx.post.return_value.json.return_value = hero.dict()
result = hero_create.post()
assert result == hero
assert httpx.get.call_count == 0
assert httpx.post.call_count == 1
def test_model_post_500(mocker): def test_model_post_500(mocker):
patch_httpx_post = mocker.patch( hero = HeroFactory().build(name="Steelman", age=25, id=1)
"httpx.post", return_value=mocker.Mock(status_code=500)
)
hero = HeroFactory().build(name="Steelman", age=25)
hero_create = HeroCreate(**hero.dict()) hero_create = HeroCreate(**hero.dict())
httpx = mocker.patch.object(hero_models, "httpx")
httpx.post.return_value = mocker.Mock()
httpx.post.return_value.status_code = 500
httpx.post.return_value.json.return_value = hero.dict()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
hero_create.post() hero_create.post()
assert patch_httpx_post.call_count == 1 assert httpx.get.call_count == 0
assert httpx.post.call_count == 1
def test_model_read_hero(mocker, session: Session, client: TestClient): def test_model_read_hero(mocker, session: Session, client: TestClient):
mocker.patch( hero = HeroFactory().build(name="Steelman", age=25, id=1)
"learn_sql_model.config.Database.engine",
new_callable=lambda: create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
),
)
config = get_config() httpx = mocker.patch.object(hero_models, "httpx")
SQLModel.metadata.create_all(config.database.engine) httpx.get.return_value = mocker.Mock()
httpx.get.return_value.status_code = 200
hero = Hero(name="Deadpond", secret_name="Dive Wilson") httpx.get.return_value.json.return_value = hero.dict()
session = config.database.session
session.add(hero)
session.commit()
session.refresh(hero)
hero_read = HeroRead.get(id=hero.id) hero_read = HeroRead.get(id=hero.id)
assert hero_read.name == "Deadpond" assert hero_read.name == hero.name
assert hero_read.secret_name == "Dive Wilson" assert hero_read.secret_name == hero.secret_name
assert httpx.get.call_count == 1
assert httpx.post.call_count == 0
def test_model_read_hero_404(mocker, session: Session, client: TestClient):
hero = HeroFactory().build(name="Steelman", age=25, id=1)
httpx = mocker.patch.object(hero_models, "httpx")
httpx.get.return_value = mocker.Mock()
httpx.get.return_value.status_code = 404
httpx.get.return_value.text = "Hero not found"
with pytest.raises(RuntimeError) as e:
HeroRead.get(id=hero.id)
assert e.value.args[0] == "404: Hero not found"
assert httpx.get.call_count == 1
assert httpx.post.call_count == 0