diff --git a/.pyflyby b/.pyflyby index 784d779..c61c116 100644 --- a/.pyflyby +++ b/.pyflyby @@ -1,19 +1,20 @@ -# from learn_sql_model.config import config + +from learn_sql_model.api.websocket_connection_manager import manager from learn_sql_model.config import Config from learn_sql_model.config import get_config -from learn_sql_model.database import get_database from learn_sql_model.console import console - -# models - -from learn_sql_model.models.fast_model import FastModel - -from learn_sql_model.models.hero import Hero -from learn_sql_model.models.hero import HeroCreate -from learn_sql_model.models.hero import HeroRead -from learn_sql_model.models.hero import HeroUpdate -from learn_sql_model.models.hero import HeroDelete - +from learn_sql_model.database import get_database from learn_sql_model.factories.hero import HeroFactory from learn_sql_model.factories.pet import PetFactory +from learn_sql_model.models.fast_model import FastModel +from learn_sql_model.models.hero import Hero +from learn_sql_model.models.hero import HeroCreate +from learn_sql_model.models.hero import HeroDelete +from learn_sql_model.models.hero import HeroRead +from learn_sql_model.models.hero import HeroUpdate +from learn_sql_model.models.new import new +from learn_sql_model.models.new import newCreate +from learn_sql_model.models.new import newDelete +from learn_sql_model.models.new import newRead +from learn_sql_model.models.new import newUpdate from learn_sql_model.models.pet import Pet diff --git a/client.py b/client.py new file mode 100644 index 0000000..7437da2 --- /dev/null +++ b/client.py @@ -0,0 +1,42 @@ +import time + +from rich.console import Console +from websocket import create_connection + +from learn_sql_model.models.hero import Hero + + +def connect(): + id = 1 + url = f"ws://localhost:5000/ws/{id}" + Console().log(f"connecting to: {url}") + ws = create_connection(url) + Console().log(f"connected to: {url}") + return ws + + +data = [] + + +def watch(ws): + while ws.connected: + try: + data.append(ws.recv()) + if data[-1].startswith("{"): + Console().log(Hero.parse_raw(data[-1])) + else: + Console().log(data[-1]) + except Exception as e: + Console().log("failed to recieve data") + Console().log(e) + + +if __name__ == "__main__": + while True: + try: + ws = connect() + watch(ws) + except Exception as e: + Console().log("failed to connect") + Console().log(e) + time.sleep(1) diff --git a/client_sender.py b/client_sender.py new file mode 100644 index 0000000..7809f4c --- /dev/null +++ b/client_sender.py @@ -0,0 +1,14 @@ +import time + +from rich.console import Console +from websocket import create_connection + +id = 1 +url = f"ws://localhost:5000/ws/{id}" +Console().log(f"connecting to: {url}") +ws = create_connection(url) + +data = [] +while True: + ws.send("hello".encode()) + time.sleep(1) diff --git a/learn_sql_model/api/app.py b/learn_sql_model/api/app.py index 8b64875..5d82280 100644 --- a/learn_sql_model/api/app.py +++ b/learn_sql_model/api/app.py @@ -2,8 +2,24 @@ from fastapi import FastAPI from learn_sql_model.api.hero import hero_router from learn_sql_model.api.user import user_router +from learn_sql_model.api.websocket import web_socket_router + +# from fastapi_socketio import SocketManager + app = FastAPI() +# socket_manager = SocketManager(app=app) app.include_router(hero_router) app.include_router(user_router) +app.include_router(web_socket_router) + + +# @app.sio.on("join") +# def handle_join(sid, *args, **kwargs): +# app.sio.emit("lobby", "User joined") + + +# @app.sio.on("leave") +# def handle_leave(sid, *args, **kwargs): +# sm.emit("lobby", "User left") diff --git a/learn_sql_model/api/hero.py b/learn_sql_model/api/hero.py index c177246..5f115c9 100644 --- a/learn_sql_model/api/hero.py +++ b/learn_sql_model/api/hero.py @@ -4,8 +4,15 @@ from fastapi import APIRouter, Depends from sqlmodel import SQLModel from learn_sql_model.api.user import oauth2_scheme +from learn_sql_model.api.websocket_connection_manager import manager from learn_sql_model.config import Config, get_config -from learn_sql_model.models.hero import Hero +from learn_sql_model.models.hero import ( + Hero, + HeroCreate, + HeroDelete, + HeroRead, + HeroUpdate, +) hero_router = APIRouter() @@ -21,30 +28,46 @@ async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): @hero_router.get("/hero/{id}") -def get_hero(id: int, config: Config = Depends(get_config)) -> Hero: +async def get_hero(id: int, config: Config = Depends(get_config)) -> Hero: "get one hero" return Hero().get(id=id, config=config) @hero_router.get("/h/{id}") -def get_h(id: int, config: Config = Depends(get_config)) -> Hero: +async def get_h(id: int, config: Config = Depends(get_config)) -> Hero: "get one hero" return Hero().get(id=id, config=config) @hero_router.post("/hero/") -def post_hero(hero: Hero, config: Config = Depends(get_config)) -> Hero: +async def post_hero(hero: HeroCreate) -> HeroRead: "read all the heros" - hero.post(config=config) + config = get_config() + hero = hero.post(config=config) + await manager.broadcast({hero.json()}, id=1) + return hero + + +@hero_router.patch("/hero/") +async def patch_hero(hero: HeroUpdate) -> HeroRead: + "read all the heros" + config = get_config() + hero = hero.update(config=config) + await manager.broadcast({hero.json()}, id=1) + return hero + + +@hero_router.delete("/hero/{hero_id}") +async def delete_hero(hero_id: int): + "read all the heros" + hero = HeroDelete(id=hero_id) + config = get_config() + hero = hero.delete(config=config) + await manager.broadcast(f"deleted hero {hero_id}", id=1) return hero @hero_router.get("/heros/") -def get_heros(config: Config = Depends(get_config)) -> list[Hero]: +async def get_heros(config: Config = Depends(get_config)) -> list[Hero]: "get all heros" return Hero().get(config=config) - # Alternatively - # with get_config().database.session as session: - # statement = select(Hero) - # results = session.exec(statement).all() - # return results diff --git a/learn_sql_model/api/new.py b/learn_sql_model/api/new.py new file mode 100644 index 0000000..813ec54 --- /dev/null +++ b/learn_sql_model/api/new.py @@ -0,0 +1,45 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlmodel import SQLModel + +from learn_sql_model.api.user import oauth2_scheme +from learn_sql_model.config import Config, get_config +from learn_sql_model.models.new import new + +new_router = APIRouter() + + +@new_router.on_event("startup") +def on_startup() -> None: + SQLModel.metadata.create_all(get_config().database.engine) + + +@new_router.get("/items/") +async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): + return {"token": token} + + +@new_router.get("/new/{id}") +def get_new(id: int, config: Config = Depends(get_config)) -> new: + "get one new" + return new().get(id=id, config=config) + + +@new_router.get("/h/{id}") +def get_h(id: int, config: Config = Depends(get_config)) -> new: + "get one new" + return new().get(id=id, config=config) + + +@new_router.post("/new/") +def post_new(new: new, config: Config = Depends(get_config)) -> new: + "read all the news" + new.post(config=config) + return new + + +@new_router.get("/news/") +def get_news(config: Config = Depends(get_config)) -> list[new]: + "get all news" + return new().get(config=config) diff --git a/learn_sql_model/api/websocket.py b/learn_sql_model/api/websocket.py new file mode 100644 index 0000000..e0e8b5e --- /dev/null +++ b/learn_sql_model/api/websocket.py @@ -0,0 +1,72 @@ +from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse + +from learn_sql_model.api.websocket_connection_manager import manager + +web_socket_router = APIRouter() + +html = """ + + + + Chat + + +

WebSocket Chat

+
+ + +
+ + + + +""" + + +@web_socket_router.get("/watch") +async def get(): + return HTMLResponse(html) + + +@web_socket_router.websocket("/ws/{id}") +async def websocket_endpoint_connect(websocket: WebSocket, id: int): + await manager.connect(websocket, id) + try: + while True: + data = await websocket.receive_text() + await websocket.send_text(f"[gold]You Said: {data}") + await manager.broadcast(f"[blue]USER: {data}", id) + + except WebSocketDisconnect: + manager.disconnect(websocket, id) + await manager.broadcast(f"Client #{id} left the chat", id) + + +@web_socket_router.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await manager.connect(websocket) + try: + while True: + data = await websocket.receive_text() + await manager.broadcast(f"[blue]USER: {data}") + + except WebSocketDisconnect: + manager.disconnect(websocket, id) + await manager.broadcast(f"Client #{client_id} left the chat", id) diff --git a/learn_sql_model/api/websocket_connection_manager.py b/learn_sql_model/api/websocket_connection_manager.py new file mode 100644 index 0000000..d832d53 --- /dev/null +++ b/learn_sql_model/api/websocket_connection_manager.py @@ -0,0 +1,41 @@ +from typing import Dict + +from fastapi import WebSocket + + +class ConnectionManager: + def __init__(self): + self.active_connections: Dict[str, list[WebSocket]] = {} + + async def connect(self, websocket: WebSocket, id: str): + print("connecting...", id) + if id not in self.active_connections: + self.active_connections[id] = [] + await websocket.accept() + self.active_connections[id].append(websocket) + + def disconnect(self, websocket: WebSocket, id: str): + if id not in self.active_connections: + return + self.active_connections[id].remove(websocket) + + async def send_personal_message(self, message: str, websocket: WebSocket): + await websocket.send_text(message) + + async def broadcast(self, message: str, id: str): + if id not in self.active_connections: + return + print(f"i go this message {message}") + print( + f"I am going to send it to {len(self.active_connections[id])} connections" + ) + for connection in self.active_connections[id]: + print("sending it to ", connection) + try: + await connection.send_text(message) + except Exception: + self.disconnect(connection, id) + print("sent it to ", connection) + + +manager = ConnectionManager() diff --git a/learn_sql_model/cli/app.py b/learn_sql_model/cli/app.py index a6f8b94..a5e1c21 100644 --- a/learn_sql_model/cli/app.py +++ b/learn_sql_model/cli/app.py @@ -5,6 +5,7 @@ from typer.main import get_group from learn_sql_model.cli.api import api_app from learn_sql_model.cli.config import config_app from learn_sql_model.cli.hero import hero_app +from learn_sql_model.cli.model import model_app app = typer.Typer( name="learn_sql_model", @@ -12,7 +13,7 @@ app = typer.Typer( ) app.add_typer(config_app, name="config") # app.add_typer(tui_app, name="tui") -# app.add_typer(model_app, name="model") +app.add_typer(model_app, name="model") app.add_typer(api_app, name="api") app.add_typer(hero_app, name="hero") diff --git a/learn_sql_model/cli/hero.py b/learn_sql_model/cli/hero.py index 768d389..267c413 100644 --- a/learn_sql_model/cli/hero.py +++ b/learn_sql_model/cli/hero.py @@ -2,6 +2,7 @@ import sys from typing import List, Optional, Union from engorgio import engorgio +import httpx from rich.console import Console import typer @@ -58,17 +59,17 @@ def create( config: Config = None, ) -> Hero: "read all the heros" - # config.init() - hero = hero.post(config=config) - Console().print(hero) - return hero - # config.init() - # with Session(config.database.engine) as session: - # db_hero = Hero.from_orm(hero) - # session.add(db_hero) - # session.commit() - # session.refresh(db_hero) - # return db_hero + + r = httpx.post( + f"{config.api_client.url}/hero/", + json=hero.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") + + # hero = hero.post(config=config) + # Console().print(hero) + # return hero @hero_app.command() @@ -78,9 +79,12 @@ def update( config: Config = None, ) -> Hero: "read all the heros" - hero = hero.update(config=config) - Console().print(hero) - return hero + r = httpx.patch( + f"{config.api_client.url}/hero/", + json=hero.dict(), + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") @hero_app.command() @@ -90,10 +94,11 @@ def delete( config: Config = None, ) -> Hero: "read all the heros" - # config.init() - hero = hero.delete(config=config) - return hero - # Console().print(hero) + r = httpx.delete( + f"{config.api_client.url}/hero/{hero.id}", + ) + if r.status_code != 200: + raise RuntimeError(f"{r.status_code}:\n {r.text}") @hero_app.command() diff --git a/learn_sql_model/cli/model.py b/learn_sql_model/cli/model.py index 826333d..df6a27d 100644 --- a/learn_sql_model/cli/model.py +++ b/learn_sql_model/cli/model.py @@ -1,7 +1,9 @@ +from pathlib import Path + import alembic -import typer from alembic.config import Config from copier import run_auto +import typer from learn_sql_model.cli.common import verbose_callback @@ -26,18 +28,19 @@ def create( callback=verbose_callback, help="show the log messages", ), - template=Path('templates/model') - run_auto(template, Path('.')) +): + template = Path("templates/model") + run_auto(str(template), ".") -@ model_app.command() +@model_app.command() def create_revision( - verbose: bool=typer.Option( + verbose: bool = typer.Option( False, callback=verbose_callback, help="show the log messages", ), - message: str=typer.Option( + message: str = typer.Option( prompt=True, ), ): @@ -51,23 +54,23 @@ def create_revision( alembic.command.upgrade(config=alembic_cfg, revision="head") -@ model_app.command() +@model_app.command() def checkout( - verbose: bool=typer.Option( + verbose: bool = typer.Option( False, callback=verbose_callback, help="show the log messages", ), - revision: str=typer.Option("head"), + revision: str = typer.Option("head"), ): alembic_cfg = Config("alembic.ini") alembic.command.upgrade(config=alembic_cfg, revision="head") -@ model_app.command() +@model_app.command() def populate( - verbose: bool=typer.Option( + verbose: bool = typer.Option( False, callback=verbose_callback, help="show the log messages", diff --git a/learn_sql_model/cli/new.py b/learn_sql_model/cli/new.py new file mode 100644 index 0000000..20c90cf --- /dev/null +++ b/learn_sql_model/cli/new.py @@ -0,0 +1,107 @@ +import sys +from typing import List, Optional, Union + +from engorgio import engorgio +from rich.console import Console +import typer + +from learn_sql_model.config import Config, get_config +from learn_sql_model.factories.new import newFactory +from learn_sql_model.factories.pet import PetFactory +from learn_sql_model.models.new import ( + new, + newCreate, + newDelete, + newRead, + newUpdate, +) + +new_app = typer.Typer() + + +@new_app.callback() +def new(): + "model cli" + + +@new_app.command() +@engorgio(typer=True) +def get( + id: Optional[int] = typer.Argument(default=None), + config: Config = None, +) -> Union[new, List[new]]: + "get one new" + config.init() + new = newRead.get(id=id, config=config) + Console().print(new) + return new + + +@new_app.command() +@engorgio(typer=True) +def list( + where: Optional[str] = None, + config: Config = None, + offset: int = 0, + limit: Optional[int] = None, +) -> Union[new, List[new]]: + "get one new" + new = newRead.list(config=config, where=where, offset=offset, limit=limit) + Console().print(new) + return new + + +@new_app.command() +@engorgio(typer=True) +def create( + new: newCreate, + config: Config = None, +) -> new: + "read all the news" + # config.init() + new = new.post(config=config) + Console().print(new) + return new + + +@new_app.command() +@engorgio(typer=True) +def update( + new: newUpdate, + config: Config = None, +) -> new: + "read all the news" + new = new.update(config=config) + Console().print(new) + return new + + +@new_app.command() +@engorgio(typer=True) +def delete( + new: newDelete, + config: Config = None, +) -> new: + "read all the news" + # config.init() + new = new.delete(config=config) + return new + + +@new_app.command() +@engorgio(typer=True) +def populate( + new: new, + n: int = 10, +) -> new: + "read all the news" + config = get_config() + if config.env == "prod": + Console().print("populate is not supported in production") + sys.exit(1) + + for new in newFactory().batch(n): + pet = PetFactory().build() + new.pet = pet + Console().print(new) + new.post(config=config) diff --git a/learn_sql_model/config.py b/learn_sql_model/config.py index 6e77623..62d28f6 100644 --- a/learn_sql_model/config.py +++ b/learn_sql_model/config.py @@ -20,6 +20,13 @@ class ApiServer(BaseModel): host: str = "0.0.0.0" +class ApiClient(BaseModel): + host: str = "0.0.0.0" + port: int = 5000 + protocol: str = "http" + url: str = f"{protocol}://{host}:{port}" + + class Database: def __init__(self, config: "Config" = None) -> None: if config is None: @@ -48,6 +55,7 @@ class Config(BaseSettings): env: str = "dev" database_url: str = "sqlite:///database.db" api_server: ApiServer = ApiServer() + api_client: ApiClient = ApiClient() class Config: extra = "ignore" diff --git a/learn_sql_model/factories/new.py b/learn_sql_model/factories/new.py new file mode 100644 index 0000000..8432c9c --- /dev/null +++ b/learn_sql_model/factories/new.py @@ -0,0 +1,14 @@ +from faker import Faker +from polyfactory.factories.pydantic_factory import ModelFactory + +from learn_sql_model.models.new import new + + +class newFactory(ModelFactory[new]): + __model__ = new + __faker__ = Faker(locale="en_US") + __set_as_default_factory_for_type__ = True + id = None + + __random_seed__ = 10 + diff --git a/learn_sql_model/models/new.py b/learn_sql_model/models/new.py new file mode 100644 index 0000000..5acca4f --- /dev/null +++ b/learn_sql_model/models/new.py @@ -0,0 +1,99 @@ +from typing import Optional + +from fastapi import HTTPException +from pydantic import BaseModel +from sqlmodel import Field, Relationship, SQLModel, Session, select + +from learn_sql_model.config import Config +from learn_sql_model.models.pet import Pet + + +class newBase(SQLModel, table=False): + + +class new(newBase, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class newCreate(newBase): + ... + + def post(self, config: Config) -> new: + config.init() + with Session(config.database.engine) as session: + db_new = new.from_orm(self) + session.add(db_new) + session.commit() + session.refresh(db_new) + return db_new + + +class newRead(newBase): + id: int + + @classmethod + def get( + cls, + config: Config, + id: int, + ) -> new: + + with config.database.session as session: + new = session.get(new, id) + if not new: + raise HTTPException(status_code=404, detail="new not found") + return new + + @classmethod + def list( + self, + config: Config, + where=None, + offset=0, + limit=None, + ) -> new: + + with config.database.session as session: + statement = select(new) + if where != "None": + from sqlmodel import text + + statement = statement.where(text(where)) + statement = statement.offset(offset).limit(limit) + newes = session.exec(statement).all() + return newes + + +class newUpdate(SQLModel): + # id is required to get the new + id: int + + # all other fields, must match the model, but with Optional default None + + def update(self, config: Config) -> new: + with Session(config.database.engine) as session: + db_new = session.get(new, self.id) + if not db_new: + raise HTTPException(status_code=404, detail="new not found") + new_data = self.dict(exclude_unset=True) + for key, value in new_data.items(): + if value is not None: + setattr(db_new, key, value) + session.add(db_new) + session.commit() + session.refresh(db_new) + return db_new + + +class newDelete(BaseModel): + id: int + + def delete(self, config: Config) -> new: + config.init() + with Session(config.database.engine) as session: + new = session.get(new, self.id) + if not new: + raise HTTPException(status_code=404, detail="new not found") + session.delete(new) + session.commit() + return {"ok": True} diff --git a/pyproject.toml b/pyproject.toml index 3c9919a..b0c9f41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,10 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ +"python-socketio[client]", +"fastapi-socketio", "anyconfig", + "copier", "engorgio", "fastapi", "httpx", diff --git a/templates/model/.pyflyby-{{modelname.lower()}}.jinja b/templates/model/.pyflyby-{{modelname.lower()}}.jinja new file mode 100644 index 0000000..ee3e946 --- /dev/null +++ b/templates/model/.pyflyby-{{modelname.lower()}}.jinja @@ -0,0 +1,5 @@ +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}} +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}Create +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}Read +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}Update +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}}Delete diff --git a/templates/model/copier.yml b/templates/model/copier.yml index f48b700..242b75b 100644 --- a/templates/model/copier.yml +++ b/templates/model/copier.yml @@ -3,5 +3,10 @@ _exclude: - README.md - .git - copier.yml -name: +modelname: type: str + +_tasks: + - "cat .pyflyby-{{modelname}} >> .pyflyby" + - "sort -u -o .pyflyby .pyflyby" + - "rm .pyflyby-{{modelname}}" diff --git a/templates/model/learn_sql_model/api/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/api/{{modelname.lower()}}.py.jinja new file mode 100644 index 0000000..663dea6 --- /dev/null +++ b/templates/model/learn_sql_model/api/{{modelname.lower()}}.py.jinja @@ -0,0 +1,45 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlmodel import SQLModel + +from learn_sql_model.api.user import oauth2_scheme +from learn_sql_model.config import Config, get_config +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}} + +{{modelname.lower()}}_router = APIRouter() + + +@{{modelname.lower()}}_router.on_event("startup") +def on_startup() -> None: + SQLModel.metadata.create_all(get_config().database.engine) + + +@{{modelname.lower()}}_router.get("/items/") +async def read_items(token: Annotated[str, Depends(oauth2_scheme)]): + return {"token": token} + + +@{{modelname.lower()}}_router.get("/{{modelname.lower()}}/{id}") +def get_{{modelname.lower()}}(id: int, config: Config = Depends(get_config)) -> {{modelname}}: + "get one {{modelname.lower()}}" + return {{modelname}}().get(id=id, config=config) + + +@{{modelname.lower()}}_router.get("/h/{id}") +def get_h(id: int, config: Config = Depends(get_config)) -> {{modelname}}: + "get one {{modelname.lower()}}" + return {{modelname}}().get(id=id, config=config) + + +@{{modelname.lower()}}_router.post("/{{modelname.lower()}}/") +def post_{{modelname.lower()}}({{modelname.lower()}}: {{modelname}}, config: Config = Depends(get_config)) -> {{modelname.lower()}}: + "read all the {{modelname.lower()}}s" + {{modelname.lower()}}.post(config=config) + return {{modelname.lower()}} + + +@{{modelname.lower()}}_router.get("/{{modelname.lower()}}s/") +def get_{{modelname.lower()}}s(config: Config = Depends(get_config)) -> list[{{modelname}}]: + "get all {{modelname.lower()}}s" + return {{modelname}}().get(config=config) diff --git a/templates/model/learn_sql_model/cli/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/cli/{{modelname.lower()}}.py.jinja new file mode 100644 index 0000000..80f2057 --- /dev/null +++ b/templates/model/learn_sql_model/cli/{{modelname.lower()}}.py.jinja @@ -0,0 +1,107 @@ +import sys +from typing import List, Optional, Union + +from engorgio import engorgio +from rich.console import Console +import typer + +from learn_sql_model.config import Config, get_config +from learn_sql_model.factories.{{modelname.lower()}} import {{modelname}}Factory +from learn_sql_model.factories.pet import PetFactory +from learn_sql_model.models.{{modelname.lower()}} import ( + {{modelname}}, + {{modelname}}Create, + {{modelname}}Delete, + {{modelname}}Read, + {{modelname}}Update, +) + +{{modelname.lower()}}_app = typer.Typer() + + +@{{modelname.lower()}}_app.callback() +def {{modelname.lower()}}(): + "model cli" + + +@{{modelname.lower()}}_app.command() +@engorgio(typer=True) +def get( + id: Optional[int] = typer.Argument(default=None), + config: Config = None, +) -> Union[{{modelname}}, List[{{modelname.lower()}}]]: + "get one {{modelname.lower()}}" + config.init() + {{modelname.lower()}} = {{modelname}}Read.get(id=id, config=config) + Console().print({{modelname.lower()}}) + return {{modelname.lower()}} + + +@{{modelname.lower()}}_app.command() +@engorgio(typer=True) +def list( + where: Optional[str] = None, + config: Config = None, + offset: int = 0, + limit: Optional[int] = None, +) -> Union[{{modelname}}, List[{{modelname.lower()}}]]: + "get one {{modelname.lower()}}" + {{modelname.lower()}} = {{modelname}}Read.list(config=config, where=where, offset=offset, limit=limit) + Console().print({{modelname.lower()}}) + return {{modelname.lower()}} + + +@{{modelname.lower()}}_app.command() +@engorgio(typer=True) +def create( + {{modelname.lower()}}: {{modelname}}Create, + config: Config = None, +) -> {{modelname}}: + "read all the {{modelname.lower()}}s" + # config.init() + {{modelname.lower()}} = {{modelname.lower()}}.post(config=config) + Console().print({{modelname.lower()}}) + return {{modelname.lower()}} + + +@{{modelname.lower()}}_app.command() +@engorgio(typer=True) +def update( + {{modelname.lower()}}: {{modelname}}Update, + config: Config = None, +) -> {{modelname}}: + "read all the {{modelname.lower()}}s" + {{modelname.lower()}} = {{modelname.lower()}}.update(config=config) + Console().print({{modelname.lower()}}) + return {{modelname.lower()}} + + +@{{modelname.lower()}}_app.command() +@engorgio(typer=True) +def delete( + {{modelname.lower()}}: {{modelname}}Delete, + config: Config = None, +) -> {{modelname}}: + "read all the {{modelname.lower()}}s" + # config.init() + {{modelname.lower()}} = {{modelname.lower()}}.delete(config=config) + return {{modelname.lower()}} + + +@{{modelname.lower()}}_app.command() +@engorgio(typer=True) +def populate( + {{modelname.lower()}}: {{modelname}}, + n: int = 10, +) -> {{modelname}}: + "read all the {{modelname.lower()}}s" + config = get_config() + if config.env == "prod": + Console().print("populate is not supported in production") + sys.exit(1) + + for {{modelname.lower()}} in {{modelname}}Factory().batch(n): + pet = PetFactory().build() + {{modelname.lower()}}.pet = pet + Console().print({{modelname.lower()}}) + {{modelname.lower()}}.post(config=config) diff --git a/templates/model/learn_sql_model/factories/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/factories/{{modelname.lower()}}.py.jinja new file mode 100644 index 0000000..995f10b --- /dev/null +++ b/templates/model/learn_sql_model/factories/{{modelname.lower()}}.py.jinja @@ -0,0 +1,14 @@ +from faker import Faker +from polyfactory.factories.pydantic_factory import ModelFactory + +from learn_sql_model.models.{{modelname.lower()}} import {{modelname}} + + +class {{modelname}}Factory(ModelFactory[{{modelname.lower()}}]): + __model__ = {{modelname}} + __faker__ = Faker(locale="en_US") + __set_as_default_factory_for_type__ = True + id = None + + __random_seed__ = 10 + diff --git a/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja b/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja new file mode 100644 index 0000000..0aef84c --- /dev/null +++ b/templates/model/learn_sql_model/models/{{modelname.lower()}}.py.jinja @@ -0,0 +1,99 @@ +from typing import Optional + +from fastapi import HTTPException +from pydantic import BaseModel +from sqlmodel import Field, Relationship, SQLModel, Session, select + +from learn_sql_model.config import Config +from learn_sql_model.models.pet import Pet + + +class {{modelname}}Base(SQLModel, table=False): + + +class {{modelname}}({{modelname.lower()}}Base, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class {{modelname}}Create({{modelname.lower()}}Base): + ... + + def post(self, config: Config) -> {{modelname}}: + config.init() + with Session(config.database.engine) as session: + db_{{modelname.lower()}} = {{modelname}}.from_orm(self) + session.add(db_{{modelname.lower()}}) + session.commit() + session.refresh(db_{{modelname.lower()}}) + return db_{{modelname.lower()}} + + +class {{modelname}}Read({{modelname.lower()}}Base): + id: int + + @classmethod + def get( + cls, + config: Config, + id: int, + ) -> {{modelname}}: + + with config.database.session as session: + {{modelname.lower()}} = session.get({{modelname}}, id) + if not {{modelname.lower()}}: + raise HTTPException(status_code=404, detail="{{modelname}} not found") + return {{modelname.lower()}} + + @classmethod + def list( + self, + config: Config, + where=None, + offset=0, + limit=None, + ) -> {{modelname}}: + + with config.database.session as session: + statement = select({{modelname}}) + if where != "None": + from sqlmodel import text + + statement = statement.where(text(where)) + statement = statement.offset(offset).limit(limit) + {{modelname.lower()}}es = session.exec(statement).all() + return {{modelname.lower()}}es + + +class {{modelname}}Update(SQLModel): + # id is required to get the {{modelname.lower()}} + id: int + + # all other fields, must match the model, but with Optional default None + + def update(self, config: Config) -> {{modelname}}: + with Session(config.database.engine) as session: + db_{{modelname.lower()}} = session.get({{modelname}}, self.id) + if not db_{{modelname.lower()}}: + raise HTTPException(status_code=404, detail="{{modelname}} not found") + {{modelname.lower()}}_data = self.dict(exclude_unset=True) + for key, value in {{modelname.lower()}}_data.items(): + if value is not None: + setattr(db_{{modelname.lower()}}, key, value) + session.add(db_{{modelname.lower()}}) + session.commit() + session.refresh(db_{{modelname.lower()}}) + return db_{{modelname.lower()}} + + +class {{modelname}}Delete(BaseModel): + id: int + + def delete(self, config: Config) -> {{modelname}}: + config.init() + with Session(config.database.engine) as session: + {{modelname.lower()}} = session.get({{modelname}}, self.id) + if not {{modelname.lower()}}: + raise HTTPException(status_code=404, detail="{{modelname}} not found") + session.delete({{modelname.lower()}}) + session.commit() + return {"ok": True}