diff --git a/.{{package_name}}-copier-answers.yml b/.pydantic-typer-copier-answers.yml similarity index 100% rename from .{{package_name}}-copier-answers.yml rename to .pydantic-typer-copier-answers.yml diff --git a/README.md b/README.md index 4bf52eb..0fc1973 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # pydantic-typer +pydantic-typer is a Python package that provides a decorator that can be used to write functions that accept a Pydantic model, but at runtime will allow them to pass in fields to create the models on the fly. + https://user-images.githubusercontent.com/22648375/235036031-a9dc6589-e350-4a18-9114-6568cb362f74.mp4 ## Installation diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..2208ec0 --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,6 @@ +"""Example usage of expand_pydantic_args. + +SPDX-FileCopyrightText: 2023-present Waylon S. Walker + +SPDX-License-Identifier: MIT +""" diff --git a/examples/person.py b/examples/person.py new file mode 100644 index 0000000..62a60b4 --- /dev/null +++ b/examples/person.py @@ -0,0 +1,27 @@ +"""Example usage of expand_pydantic_args with the Person model. + +SPDX-FileCopyrightText: 2023-present Waylon S. Walker + +SPDX-License-Identifier: MIT +""" +from pyannotate_runtime import collect_types + +from pydantic_typer import expand_pydantic_args +from tests.models import Person + + +@expand_pydantic_args() +def get_person(person: Person, thing: str = None) -> Person: + """Mydocstring.""" + from rich import print + + print(str(thing)) + print(person) + + +if __name__ == "__main__": + collect_types.init_types_collection() + with collect_types.collect(): + person = get_person(name="John", age=1, r=1, g=1, b=1, a=1, length=1) + + collect_types.dump_stats("type_info.json") diff --git a/examples/person_cli.py b/examples/person_cli.py new file mode 100644 index 0000000..af3faad --- /dev/null +++ b/examples/person_cli.py @@ -0,0 +1,37 @@ +"""Example usage of expand_pydantic_args with the Person model as a typer cli. + +SPDX-FileCopyrightText: 2023-present Waylon S. Walker + +SPDX-License-Identifier: MIT +""" +import typer + +from pydantic_typer import expand_pydantic_args +from tests.models import Person + +app = typer.Typer( + name="pydantic_typer", + help="a demo app", +) + + +@app.callback() +def main() -> None: + """Set up typer.""" + return + + +@app.command() +@expand_pydantic_args(typer=True) +def get_person(person: Person, thing: str, another: str = "this") -> Person: + """Get a person's information.""" + from rich import print + + print(thing) + print(another) + + print(person) + + +if __name__ == "__main__": + typer.run(get_person) diff --git a/pydantic_typer/__about__.py b/pydantic_typer/__about__.py index 90da1db..af0123b 100644 --- a/pydantic_typer/__about__.py +++ b/pydantic_typer/__about__.py @@ -1,4 +1,9 @@ -# SPDX-FileCopyrightText: 2023-present Waylon S. Walker -# -# SPDX-License-Identifier: MIT +"""About pydantic_typer. + +Sets metadata about pydantic_typer. + +SPDX-FileCopyrightText: 2023-present Waylon S. Walker + +SPDX-License-Identifier: MIT +""" __version__ = "0.0.0.dev1" diff --git a/pydantic_typer/__init__.py b/pydantic_typer/__init__.py index 8e4f310..dd40910 100644 --- a/pydantic_typer/__init__.py +++ b/pydantic_typer/__init__.py @@ -1,44 +1,26 @@ -# SPDX-FileCopyrightText: 2023-present Waylon S. Walker -## -# SPDX-License-Identifier: MIT +"""pydantic_typer. -from functools import wraps +SPDX-FileCopyrightText: 2023-present Waylon S. Walker + +SPDX-License-Identifier: MIT +""" import inspect -from typing import Callable, Optional +from functools import wraps +from typing import Any, Callable, Dict, Optional -from pydantic import BaseModel, Field import typer +from pydantic.fields import ModelField __all__ = ["typer"] -class Alpha(BaseModel): - a: int - - -class Color(BaseModel): - r: int - g: int - b: int - alpha: Alpha - - -class Hair(BaseModel): - color: Color - length: int - - -class Person(BaseModel): - name: str - other_name: Optional[str] = None - age: int - email: Optional[str] - pet: str = "dog" - address: str = Field("123 Main St", description="Where the person calls home.") - hair: Hair - - -def make_annotation(name, field, names, typer=False): +def _make_annotation( + name: str, + field: ModelField, + names: Dict[str, str], + *, + typer: bool = False, +) -> str: panel_name = names.get(name) next_name = panel_name while next_name is not None: @@ -66,17 +48,24 @@ def make_annotation(name, field, names, typer=False): default = f' = typer.Option("{field.default}", help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")' else: default = f'="{field.default}"' + elif typer: + default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)' else: - if typer: - default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)' - else: - default = "" + default = "" if typer: return f"{name}: {annotation}{default}" return f"{name}: {annotation}{default}" -def make_signature(func, wrapper, typer=False, more_args={}): +def _make_signature( + func: Callable, + wrapper: Callable, + *, + typer: bool = False, + more_args: Optional[Dict] = None, +): + if more_args is None: + more_args = {} sig = inspect.signature(func) names = {} for name, param in sig.parameters.items(): @@ -88,7 +77,7 @@ def make_signature(func, wrapper, typer=False, more_args={}): more_args[name] = param while any( - [hasattr(param.annotation, "__fields__") for name, param in more_args.items()] + hasattr(param.annotation, "__fields__") for name, param in more_args.items() ): keys_to_remove = [] for name, param in more_args.items(): @@ -99,7 +88,6 @@ def make_signature(func, wrapper, typer=False, more_args={}): if name not in param.annotation.__fields__.keys(): keys_to_remove.append(name) more_args = {**more_args, **param.annotation.__fields__} - # names[name] = param.annotation.__name__ for field in param.annotation.__fields__: names[field] = param.annotation.__name__ @@ -109,9 +97,8 @@ def make_signature(func, wrapper, typer=False, more_args={}): wrapper.__doc__ = ( func.__doc__ or "" ) + f"\nalso accepts {more_args.keys()} in place of person model" - # fields = Person.__fields__ raw_args = [ - make_annotation( + _make_annotation( name, field, names=names, @@ -130,34 +117,27 @@ def {func.__name__}({aargs}{', ' if aargs else ''}{kwargs}): '''{func.__doc__}''' return wrapper({call_args}) """ - # new_func_sig = f"""{func.__name__}({args}{', ' if args else ''}{kwargs})""" - # import typing - - # from makefun import create_function - - # __all__ = ["typing"] - - # new_func = create_function(new_func_sig, func, inject_as_first_arg=True) - - # signature = inspect.Signature() - # signature.add("a", inspect.Parameter(default=1)) - # signature.add("b", inspect.Parameter(default=2)) - # signature.return_annotation = int - # func.signature = signature - # signature = inspect.Signature( - # a=Parameter(default=1), b=Parameter(default=2), return_annotation=int - # ) - exec(new_func_str, locals(), globals()) + exec(new_func_str, locals(), globals()) # noqa: S102 new_func = globals()[func.__name__] sig = inspect.signature(new_func) - for name, param in sig.parameters.items(): + for param in sig.parameters.values(): if hasattr(param.annotation, "__fields__"): - return make_signature(new_func, wrapper, typer=typer, more_args=more_args) + return _make_signature(new_func, wrapper, typer=typer, more_args=more_args) return new_func -def _expand_param(param, kwargs, models=None): +def _expand_param( + param: inspect.Parameter, + kwargs: Dict[str, Any], + models: Optional[Dict[str, str]] = None, +) -> Any: + """Further expands params with a Pydantic annotation, given a param. + + Recursively creates an instance of any param.annotation that has __fields__ + using the expanded kwargs.y: + using the expanded kwargs. + """ models = {} for field_name, field in param.annotation.__fields__.items(): if hasattr(field.annotation, "__fields__"): @@ -165,7 +145,12 @@ def _expand_param(param, kwargs, models=None): return param.annotation(**kwargs, **models) -def _expand_kwargs(func, kwargs): +def _expand_kwargs(func: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Expand kwargs with Pydantic annotations given a function. + + Inspects the arguments of the func and expands any of the kwargs with a + Pydantic annotation, to add its fields to the kwargs. + """ sig = inspect.signature(func) updated_kwargs = {} for name, value in kwargs.items(): @@ -182,34 +167,21 @@ def _expand_kwargs(func, kwargs): elif hasattr(param.annotation, "__fields__"): updated_kwargs[name] = _expand_param(param, kwargs) # its something else so pass it - # else: - # updated_kwargs[name] = kwargs[name] return updated_kwargs -def expand_pydantic_args(typer: bool = False) -> Callable: - def decorator(func: Callable) -> Callable[..., any]: +def expand_pydantic_args(*, typer: bool = False) -> Callable: + """Expand Pydantic keyword arguments. + + Decorator function to expand arguments of pydantic models to accept the + individual fields of Models. + """ + + def decorator(func: Callable) -> Callable[..., Any]: @wraps(func) def wrapper(*args, **kwargs): - return func(**_expand_kwargs(func, kwargs)) + return func(*args, **_expand_kwargs(func, kwargs)) - return make_signature(func, wrapper, typer=typer) + return _make_signature(func, wrapper, typer=typer) return decorator - - -def get_person_vanilla(person: Person) -> Person: - from rich import print - - print(person) - return person - - -@expand_pydantic_args() -def get_person(person: Person, thing: str = None) -> Person: - """mydocstring""" - from rich import print - - print(str(thing)) - - print(person) diff --git a/pydantic_typer/__main__.py b/pydantic_typer/__main__.py deleted file mode 100644 index 04b645e..0000000 --- a/pydantic_typer/__main__.py +++ /dev/null @@ -1,9 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present Waylon S. Walker -# -# SPDX-License-Identifier: MIT -import sys - -if __name__ == '__main__': - from .cli import {{python_package}} - - sys.exit({{python_package}}()) diff --git a/pydantic_typer/cli/__init__.py b/pydantic_typer/cli/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pydantic_typer/cli/app.py b/pydantic_typer/cli/app.py deleted file mode 100644 index c629c96..0000000 --- a/pydantic_typer/cli/app.py +++ /dev/null @@ -1,62 +0,0 @@ -import typer - -from pydantic_typer import Person, expand_pydantic_args -from pydantic_typer.cli.common import verbose_callback -from pydantic_typer.cli.config import config_app -from pydantic_typer.cli.tui import tui_app - -app = typer.Typer( - name="pydantic_typer", - help="A rich terminal report for coveragepy.", -) -app.add_typer(config_app) -app.add_typer(tui_app) - - -def version_callback(value: bool) -> None: - """Callback function to print the version of the pydantic-typer package. - - Args: - value (bool): Boolean value to determine if the version should be printed. - - Raises: - typer.Exit: If the value is True, the version will be printed and the program will exit. - - Example: - version_callback(True) - """ - if value: - from pydantic_typer.__about__ import __version__ - - typer.echo(f"{__version__}") - raise typer.Exit() - - -@app.callback() -def main( - version: bool = typer.Option( - None, - "--version", - callback=version_callback, - is_eager=True, - ), - verbose: bool = typer.Option( - False, - callback=verbose_callback, - help="show the log messages", - ), -) -> None: - return - - -@app.command() -@expand_pydantic_args -def get_person(person: Person) -> Person: - """mydocstring""" - from rich import print - - print(person) - - -if __name__ == "__main__": - typer.run(main) diff --git a/pydantic_typer/cli/common.py b/pydantic_typer/cli/common.py deleted file mode 100644 index 2957684..0000000 --- a/pydantic_typer/cli/common.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic_typer.console import console - - -def verbose_callback(value: bool) -> None: - if value: - console.quiet = False diff --git a/pydantic_typer/cli/config.py b/pydantic_typer/cli/config.py deleted file mode 100644 index 6699fe7..0000000 --- a/pydantic_typer/cli/config.py +++ /dev/null @@ -1,29 +0,0 @@ -from rich.console import Console -import typer - -from pydantic_typer.cli.common import verbose_callback -from pydantic_typer.config import config as configuration - -config_app = typer.Typer() - - -@config_app.callback() -def config( - verbose: bool = typer.Option( - False, - callback=verbose_callback, - help="show the log messages", - ), -): - "configuration cli" - - -@config_app.command() -def show( - verbose: bool = typer.Option( - False, - callback=verbose_callback, - help="show the log messages", - ), -): - Console().print(configuration) diff --git a/pydantic_typer/cli/tui.py b/pydantic_typer/cli/tui.py deleted file mode 100644 index 0608e8a..0000000 --- a/pydantic_typer/cli/tui.py +++ /dev/null @@ -1,18 +0,0 @@ -import typer - -from pydantic_typer.cli.common import verbose_callback -from pydantic_typer.tui.app import run_app - -tui_app = typer.Typer() - - -@tui_app.callback(invoke_without_command=True) -def i( - verbose: bool = typer.Option( - False, - callback=verbose_callback, - help="show the log messages", - ), -): - "interactive tui" - run_app() diff --git a/pydantic_typer/config.py b/pydantic_typer/config.py deleted file mode 100644 index fbec538..0000000 --- a/pydantic_typer/config.py +++ /dev/null @@ -1,3 +0,0 @@ -from pydantic_typer.standard_config import load - -config = load("pydantic_typer") diff --git a/pydantic_typer/console.py b/pydantic_typer/console.py deleted file mode 100644 index d160979..0000000 --- a/pydantic_typer/console.py +++ /dev/null @@ -1,4 +0,0 @@ -from rich.console import Console - -console = Console() -console.quiet = True diff --git a/pydantic_typer/standard_config.py b/pydantic_typer/standard_config.py deleted file mode 100644 index 0f99499..0000000 --- a/pydantic_typer/standard_config.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Standard Config. -A module to load tooling config from a users project space. - -Inspired from frustrations that some tools have a tool.ini, .tool.ini, -setup.cfg, or pyproject.toml. Some allow for global configs, some don't. Some -properly follow the users home directory, others end up in a weird temp -directory. Windows home directory is only more confusing. Some will even -respect the users `$XDG_HOME` directory. - - -This file is for any project that can be configured in plain text such as `ini` -or `toml` and not requiring a .py file. Just name your tool and let users put -config where it makes sense to them, no need to figure out resolution order. - -## Usage: - -``` python -from standard_config import load - -# Retrieve any overrides from the user -overrides = {'setting': True} -config = load('my_tool', overrides) -``` - -## Resolution Order - -* First global file with a tool key -* First local file with a tool key -* Environment variables prefixed with `TOOL` -* Overrides - -### Tool Specific Ini files - -Ini file formats must include a `` key. - -``` ini -[my_tool] -setting = True -``` - -### pyproject.toml - -Toml files must include a `tool.` key - -``` toml -[tool.my_tool] -setting = True -``` - -### setup.cfg - -setup.cfg files must include a `tool:` key - -``` ini -[tool:my_tool] -setting = True -``` - - -### global files to consider - -* /tool.ini -* /.tool -* /.tool.ini -* /.config/tool.ini -* /.config/.tool -* /.config/.tool.ini - -### local files to consider - -* /tool.ini -* /.tool -* /.tool.ini -* /pyproject.toml -* /setup.cfg - -""" - -import os -from pathlib import Path -from typing import Dict, List, Union - -import anyconfig - -# path_spec_type = List[Dict[str, Union[Path, str, List[str\}\}\}\} -path_spec_type = List - - -def _get_global_path_specs(tool: str) -> path_spec_type: - """ - Generate a list of standard pathspecs for global config files. - - Args: - tool (str): name of the tool to configure - """ - try: - home = Path(os.environ["XDG_HOME"]) - except KeyError: - home = Path.home() - - return [ - {"path_specs": home / f"{tool}.ini", "ac_parser": "ini", "keys": [tool]}, - {"path_specs": home / f".{tool}", "ac_parser": "ini", "keys": [tool]}, - {"path_specs": home / f".{tool}.ini", "ac_parser": "ini", "keys": [tool]}, - { - "path_specs": home / ".config" / f"{tool}.ini", - "ac_parser": "ini", - "keys": [tool], - }, - { - "path_specs": home / ".config" / f".{tool}", - "ac_parser": "ini", - "keys": [tool], - }, - { - "path_specs": home / ".config" / f".{tool}.ini", - "ac_parser": "ini", - "keys": [tool], - }, - ] - - -def _get_local_path_specs(tool: str, project_home: Union[str, Path]) -> path_spec_type: - """ - Generate a list of standard pathspecs for local, project directory config files. - - Args: - tool (str): name of the tool to configure - """ - return [ - { - "path_specs": Path(project_home) / f"{tool}.ini", - "ac_parser": "ini", - "keys": [tool], - }, - { - "path_specs": Path(project_home) / f".{tool}", - "ac_parser": "ini", - "keys": [tool], - }, - { - "path_specs": Path(project_home) / f".{tool}.ini", - "ac_parser": "ini", - "keys": [tool], - }, - { - "path_specs": Path(project_home) / f"{tool}.yml", - "ac_parser": "yaml", - "keys": [tool], - }, - { - "path_specs": Path(project_home) / f".{tool}.yml", - "ac_parser": "yaml", - "keys": [tool], - }, - { - "path_specs": Path(project_home) / f"{tool}.toml", - "ac_parser": "toml", - "keys": [tool], - }, - { - "path_specs": Path(project_home) / f".{tool}.toml", - "ac_parser": "toml", - "keys": [tool], - }, - { - "path_specs": Path(project_home) / "pyproject.toml", - "ac_parser": "toml", - "keys": ["tool", tool], - }, - { - "path_specs": Path(project_home) / "setup.cfg", - "ac_parser": "ini", - "keys": [f"tool.{tool}"], - }, - ] - - -def _get_attrs(attrs: list, config: Dict) -> Dict: - """Get nested config data from a list of keys. - - specifically written for pyproject.toml which needs to get `tool` then `` - """ - for attr in attrs: - config = config[attr] - return config - - -def _load_files(config_path_specs: path_spec_type) -> Dict: - """Use anyconfig to load config files stopping at the first one that exists. - - config_path_specs (list): a list of pathspecs and keys to load - """ - for file in config_path_specs: - if file["path_specs"].exists(): - config = anyconfig.load(**file) - else: - # ignore missing files - continue - - try: - return _get_attrs(file["keys"], config) - except KeyError: - # ignore incorrect keys - continue - - return {} - - -def _load_env(tool: str) -> Dict: - """Load config from environment variables. - - Args: - tool (str): name of the tool to configure - """ - vars = [var for var in os.environ.keys() if var.startswith(tool.upper())] - return { - var.lower().strip(tool.lower()).strip("_").strip("-"): os.environ[var] - for var in vars - } - - -def load(tool: str, project_home: Union[Path, str] = ".", overrides: Dict = {}) -> Dict: - """Load tool config from standard config files. - - Resolution Order - - * First global file with a tool key - * First local file with a tool key - * Environment variables prefixed with `TOOL` - * Overrides - - Args: - tool (str): name of the tool to configure - """ - global_config = _load_files(_get_global_path_specs(tool)) - local_config = _load_files(_get_local_path_specs(tool, project_home)) - env_config = _load_env(tool) - return {**global_config, **local_config, **env_config, **overrides} diff --git a/pydantic_typer/tui/app.css b/pydantic_typer/tui/app.css deleted file mode 100644 index 7ed9fce..0000000 --- a/pydantic_typer/tui/app.css +++ /dev/null @@ -1,18 +0,0 @@ -Screen { - align: center middle; - layers: main footer; -} - -Sidebar { - height: 100vh; - width: auto; - min-width: 20; - background: $secondary-background-darken-2; - dock: left; - margin-right: 1; - layer: main; -} - -Footer { - layer: footer; -} diff --git a/pydantic_typer/tui/app.py b/pydantic_typer/tui/app.py deleted file mode 100644 index a213b27..0000000 --- a/pydantic_typer/tui/app.py +++ /dev/null @@ -1,62 +0,0 @@ -from pathlib import Path - -from textual.app import App, ComposeResult -from textual.containers import Container -from textual.css.query import NoMatches -from textual.widgets import Footer, Static - -from pydantic_typer.config import config - -config["tui"] = {} -config["tui"]["bindings"] = {} - - -class Sidebar(Static): - def compose(self) -> ComposeResult: - yield Container( - Static("sidebar"), - id="sidebar", - ) - - -class Tui(App): - """A Textual app to manage requests.""" - - CSS_PATH = Path("__file__").parent / "app.css" - BINDINGS = [tuple(b.values()) for b in config["tui"]["bindings"]] - - def compose(self) -> ComposeResult: - """Create child widgets for the app.""" - yield Container(Static("hello world")) - yield Footer() - - def action_toggle_dark(self) -> None: - """An action to toggle dark mode.""" - self.dark = not self.dark - - def action_toggle_sidebar(self): - try: - self.query_one("PromptSidebar").remove() - except NoMatches: - self.mount(Sidebar()) - - -def run_app(): - import os - import sys - - from textual.features import parse_features - - dev = "--dev" in sys.argv - features = set(parse_features(os.environ.get("TEXTUAL", ""))) - if dev: - features.add("debug") - features.add("devtools") - - os.environ["TEXTUAL"] = ",".join(sorted(features)) - app = Tui() - app.run() - - -if __name__ == "__main__": - run_app() diff --git a/pyproject.toml b/pyproject.toml index 1090428..64d7c90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,8 +52,9 @@ dependencies = [ "pytest", "pytest-cov", "pytest-mock", - "pytest-rich", + 'polyfactory', "ruff", + 'pyannotate', "black", ] [tool.hatch.envs.default.scripts] @@ -89,9 +90,77 @@ exclude_lines = [ ] [tool.pytest.ini_options] -addopts = "-ra -q --rich" -asyncio_mode = "auto" +addopts = "-ra -q" testpaths = ["tests"] [tool.coverage_rich] fail-under=80 + +[tool.ruff] +ignore = ["E501", "D211", "D212", "D213"] +target-version = "py37" + + + +select = [ +"F", # Pyflakes +"E", # Error +"W", # Warning +"C90", # mccabe +"I", # isort +"N", # pep8-naming +"D", # pydocstyle +"UP", # pyupgrade +"YTT", # flake8-2020 +# "ANN", # flake8-annotations +"S", # flake8-bandit +"BLE", # flake8-blind-except +"FBT", # flake8-boolean-trap +"B", # flake8-bugbear +"A", # flake8-builtins +"COM", # flake8-commas +"C4", # flake8-comprehensions +"DTZ", # flake8-datetimez +"T10", # flake8-debugger +"DJ", # flake8-django +"EM", # flake8-errmsg +"EXE", # flake8-executable +"ISC", # flake8-implicit-str-concat +"ICN", # flake8-import-conventions +"G", # flake8-logging-format +"INP", # flake8-no-pep420 +"PIE", # flake8-pie +"T20", # flake8-print +"PYI", # flake8-pyi +"PT", # flake8-pytest-style +"Q", # flake8-quotes +"RSE", # flake8-raise +"RET", # flake8-return +"SLF", # flake8-self +"SIM", # flake8-simplify +"TID", # flake8-tidy-imports +"TCH", # flake8-type-checking +"INT", # flake8-gettext +"ARG", # flake8-unused-arguments +"PTH", # flake8-use-pathlib +"ERA", # eradicate +"PD", # pandas-vet +"PGH", # pygrep-hooks +"PL", # Pylint +"PLC", # Convention +"PLE", # Error +"PLR", # Refactor +"PLW", # Warning +"TRY", # tryceratops +"NPY", # NumPy-specific rules +"RUF", # Ruff-specific rules +] +[tool.ruff.mccabe] +# Flag errors (`C901`) whenever the complexity level exceeds 5. +max-complexity = 13 + +[tool.ruff.pylint] +max-branches = 13 + +[tool.ruff.per-file-ignores] +'tests/**' = ["D100", "D101", "D102", "D103", "D104", "D105", "S101"] diff --git a/tests/__init__.py b/tests/__init__.py index 2a9f0e4..ac8dd57 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,6 @@ -# SPDX-FileCopyrightText: 2023-present Waylon S. Walker -# -# SPDX-License-Identifier: MIT +"""Tests. + +SPDX-FileCopyrightText: 2023-present Waylon S. Walker + +SPDX-License-Identifier: MIT +""" diff --git a/tests/models.py b/tests/models.py new file mode 100644 index 0000000..8c02bb4 --- /dev/null +++ b/tests/models.py @@ -0,0 +1,123 @@ +"""Models defines a set of classes for representing people and their hair. + +Classes: + +* `Alpha`: A class for representing an alpha value. +* `Color`: A class for representing a color. +* `Hair`: A class for representing hair. +* `Person`: A class for representing a person. + +""" + +from typing import Optional + +from polyfactory.factories.pydantic_factory import ModelFactory +from pydantic import BaseModel, Field + + +class Alpha(BaseModel): + + """A class for representing an alpha value.""" + + a: int = Field( + ..., + description="The alpha value.", + ) + + +class Color(BaseModel): + + """A class for representing a color.""" + + r: int = Field( + ..., + description="The red component of the color.", + ) + g: int = Field( + ..., + description="The green component of the color.", + ) + b: int = Field( + ..., + description="The blue component of the color.", + ) + alpha: Alpha = Field( + ..., + description="The alpha value of the color.", + ) + + +class Hair(BaseModel): + + """A class for representing hair.""" + + color: Color = Field( + ..., + description="The color of the hair.", + ) + length: int = Field( + ..., + description="The length of the hair.", + ) + + +class Person(BaseModel): + + """A class for representing a person.""" + + name: str = Field( + ..., + description="The name of the person.", + ) + alias: Optional[str] = Field( + None, + description="An optional other name for the person.", + ) + age: int = Field( + ..., + description="The age of the person.", + ) + email: Optional[str] = Field( + None, + description="An optional email address for the person.", + ) + pet: str = Field( + "dog", + description="The person's pet.", + ) + address: str = Field( + "123 Main St", + description="Where the person calls home.", + ) + hair: Hair = Field( + ..., + description="The person's hair.", + ) + + +class AlphaFactory(ModelFactory[Alpha]): + + """A class for generating an alpha value.""" + + __model__ = Alpha + + +class ColorFactory(ModelFactory[Color]): + + """A class for generating a color.""" + + __model__ = Color + + +class HairFactory(ModelFactory[Hair]): + + """A class for generating hair.""" + + __model__ = Hair + + +class PersonFactory(ModelFactory[Person]): + + """A class for generating a person.""" + + __model__ = Person diff --git a/tests/test_person.py b/tests/test_person.py new file mode 100644 index 0000000..e0ce69e --- /dev/null +++ b/tests/test_person.py @@ -0,0 +1,164 @@ +"""Example usage of expand_pydantic_args with the Person model. + +SPDX-FileCopyrightText: 2023-present Waylon S. Walker + +SPDX-License-Identifier: MIT +""" + +import inspect + +import pytest + +from pydantic_typer import expand_pydantic_args + +from . import models + +# this one is broken +# def test_no_pydantic() -> None: +# @expand_pydantic_args() +# def get_person(alpha) -> None: +# """Mydocstring.""" + + +def test_single_signature() -> None: + @expand_pydantic_args() + def get_person(alpha: models.Alpha) -> None: + """Mydocstring.""" + return alpha + + sig = inspect.signature(get_person) + params = sig.parameters + assert "a" in params + + assert "alpha" not in params + + +@pytest.mark.parametrize( + "alpha", + models.AlphaFactory().batch(size=5), +) +def test_single_instance(alpha: models.Alpha) -> None: + @expand_pydantic_args() + def get_person(alpha: models.Alpha) -> None: + """Mydocstring.""" + return alpha + + assert get_person(**alpha.dict()) == alpha + # this should maybe work + # assert get_person(models.Alpha(a=1)) == models.Alpha(a=1) + + +def test_one_nest_signature() -> None: + @expand_pydantic_args() + def get_person(color: models.Color) -> None: + """Mydocstring.""" + return color + + sig = inspect.signature(get_person) + params = sig.parameters + assert "r" in params + assert "g" in params + assert "b" in params + assert "a" in params + + assert "color" not in params + assert "alpha" not in params + + +@pytest.mark.parametrize( + "color", + models.ColorFactory().batch(size=5), +) +def test_one_nest_instance(color: models.Color) -> None: + @expand_pydantic_args() + def get_person(color: models.Color) -> None: + """Mydocstring.""" + return color + + assert get_person(**color.dict(exclude={"alpha"}), **color.alpha.dict()) == color + + +def test_two_nest_signature() -> None: + @expand_pydantic_args() + def get_person(hair: models.Hair) -> None: + """Mydocstring.""" + return hair + + sig = inspect.signature(get_person) + params = sig.parameters + assert "length" in params + assert "r" in params + assert "g" in params + assert "b" in params + assert "a" in params + + assert "hair" not in params + assert "color" not in params + assert "alpha" not in params + + +@pytest.mark.parametrize( + "hair", + models.HairFactory().batch(size=5), +) +def test_two_nest_instance(hair: models.Hair) -> None: + @expand_pydantic_args() + def get_person(hair: models.Hair) -> None: + """Mydocstring.""" + return hair + + assert ( + get_person( + **hair.dict(exclude={"color"}), + **hair.color.dict(exclude={"alpha"}), + **hair.color.alpha.dict() + ) == + hair + ) + + +def test_three_nest_signature() -> None: + @expand_pydantic_args() + def get_person(person: models.Person) -> None: + """Mydocstring.""" + return person + + sig = inspect.signature(get_person) + params = sig.parameters + assert "name" in params + assert "alias" in params + assert "age" in params + assert "email" in params + assert "pet" in params + assert "address" in params + assert "length" in params + assert "r" in params + assert "g" in params + assert "b" in params + assert "a" in params + + assert "person" not in params + assert "hair" not in params + assert "color" not in params + assert "alpha" not in params + + +@pytest.mark.parametrize( + "person", + models.PersonFactory().batch(size=5), +) +def test_three_nest_instance(person: models.Person) -> None: + @expand_pydantic_args() + def get_person(person: models.Person) -> None: + """Mydocstring.""" + return person + + assert ( + get_person( + **person.dict(exclude={"hair"}), + **person.hair.dict(exclude={"color"}), + **person.hair.color.dict(exclude={"alpha"}), + **person.hair.color.alpha.dict() + ) == + person + )