Compare commits

..

10 commits

Author SHA1 Message Date
9e1c83f438
Update README.md 2023-04-28 21:17:18 -05:00
0d5771fd57
one models for testing 2023-04-28 21:13:46 -05:00
ab0cd3e664
add tests 2023-04-28 21:11:48 -05:00
5ae6b86712
wip 2023-04-28 20:00:40 -05:00
fec28df75f
ruff fix 2023-04-28 14:12:29 -05:00
562cf35587
remove unused console 2023-04-28 14:12:16 -05:00
9be9a1ff36
setup pyannotate 2023-04-28 10:24:21 -05:00
0a68509e41
linting 2023-04-28 10:24:21 -05:00
d5916db01d
ruff fix 2023-04-28 10:24:21 -05:00
22b7eaf593
clean up and refactor 2023-04-28 10:24:20 -05:00
22 changed files with 504 additions and 546 deletions

View file

@ -1,5 +1,7 @@
# pydantic-typer
pydantic-typer is a Python package that provides a decorator that can be used to write functions that accept a Pydantic model, but at runtime will allow them to pass in fields to create the models on the fly.
https://user-images.githubusercontent.com/22648375/235036031-a9dc6589-e350-4a18-9114-6568cb362f74.mp4
## Installation

6
examples/__init__.py Normal file
View file

@ -0,0 +1,6 @@
"""Example usage of expand_pydantic_args.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""

27
examples/person.py Normal file
View file

@ -0,0 +1,27 @@
"""Example usage of expand_pydantic_args with the Person model.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
from pyannotate_runtime import collect_types
from pydantic_typer import expand_pydantic_args
from tests.models import Person
@expand_pydantic_args()
def get_person(person: Person, thing: str = None) -> Person:
"""Mydocstring."""
from rich import print
print(str(thing))
print(person)
if __name__ == "__main__":
collect_types.init_types_collection()
with collect_types.collect():
person = get_person(name="John", age=1, r=1, g=1, b=1, a=1, length=1)
collect_types.dump_stats("type_info.json")

37
examples/person_cli.py Normal file
View file

@ -0,0 +1,37 @@
"""Example usage of expand_pydantic_args with the Person model as a typer cli.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
import typer
from pydantic_typer import expand_pydantic_args
from tests.models import Person
app = typer.Typer(
name="pydantic_typer",
help="a demo app",
)
@app.callback()
def main() -> None:
"""Set up typer."""
return
@app.command()
@expand_pydantic_args(typer=True)
def get_person(person: Person, thing: str, another: str = "this") -> Person:
"""Get a person's information."""
from rich import print
print(thing)
print(another)
print(person)
if __name__ == "__main__":
typer.run(get_person)

View file

@ -1,4 +1,9 @@
# SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
#
# SPDX-License-Identifier: MIT
"""About pydantic_typer.
Sets metadata about pydantic_typer.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
__version__ = "0.0.0.dev1"

View file

@ -1,44 +1,26 @@
# SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
##
# SPDX-License-Identifier: MIT
"""pydantic_typer.
from functools import wraps
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
import inspect
from typing import Callable, Optional
from functools import wraps
from typing import Any, Callable, Dict, Optional
from pydantic import BaseModel, Field
import typer
from pydantic.fields import ModelField
__all__ = ["typer"]
class Alpha(BaseModel):
a: int
class Color(BaseModel):
r: int
g: int
b: int
alpha: Alpha
class Hair(BaseModel):
color: Color
length: int
class Person(BaseModel):
name: str
other_name: Optional[str] = None
age: int
email: Optional[str]
pet: str = "dog"
address: str = Field("123 Main St", description="Where the person calls home.")
hair: Hair
def make_annotation(name, field, names, typer=False):
def _make_annotation(
name: str,
field: ModelField,
names: Dict[str, str],
*,
typer: bool = False,
) -> str:
panel_name = names.get(name)
next_name = panel_name
while next_name is not None:
@ -66,17 +48,24 @@ def make_annotation(name, field, names, typer=False):
default = f' = typer.Option("{field.default}", help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")'
else:
default = f'="{field.default}"'
elif typer:
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
else:
if typer:
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
else:
default = ""
default = ""
if typer:
return f"{name}: {annotation}{default}"
return f"{name}: {annotation}{default}"
def make_signature(func, wrapper, typer=False, more_args={}):
def _make_signature(
func: Callable,
wrapper: Callable,
*,
typer: bool = False,
more_args: Optional[Dict] = None,
):
if more_args is None:
more_args = {}
sig = inspect.signature(func)
names = {}
for name, param in sig.parameters.items():
@ -88,7 +77,7 @@ def make_signature(func, wrapper, typer=False, more_args={}):
more_args[name] = param
while any(
[hasattr(param.annotation, "__fields__") for name, param in more_args.items()]
hasattr(param.annotation, "__fields__") for name, param in more_args.items()
):
keys_to_remove = []
for name, param in more_args.items():
@ -99,7 +88,6 @@ def make_signature(func, wrapper, typer=False, more_args={}):
if name not in param.annotation.__fields__.keys():
keys_to_remove.append(name)
more_args = {**more_args, **param.annotation.__fields__}
# names[name] = param.annotation.__name__
for field in param.annotation.__fields__:
names[field] = param.annotation.__name__
@ -109,9 +97,8 @@ def make_signature(func, wrapper, typer=False, more_args={}):
wrapper.__doc__ = (
func.__doc__ or ""
) + f"\nalso accepts {more_args.keys()} in place of person model"
# fields = Person.__fields__
raw_args = [
make_annotation(
_make_annotation(
name,
field,
names=names,
@ -130,34 +117,27 @@ def {func.__name__}({aargs}{', ' if aargs else ''}{kwargs}):
'''{func.__doc__}'''
return wrapper({call_args})
"""
# new_func_sig = f"""{func.__name__}({args}{', ' if args else ''}{kwargs})"""
# import typing
# from makefun import create_function
# __all__ = ["typing"]
# new_func = create_function(new_func_sig, func, inject_as_first_arg=True)
# signature = inspect.Signature()
# signature.add("a", inspect.Parameter(default=1))
# signature.add("b", inspect.Parameter(default=2))
# signature.return_annotation = int
# func.signature = signature
# signature = inspect.Signature(
# a=Parameter(default=1), b=Parameter(default=2), return_annotation=int
# )
exec(new_func_str, locals(), globals())
exec(new_func_str, locals(), globals()) # noqa: S102
new_func = globals()[func.__name__]
sig = inspect.signature(new_func)
for name, param in sig.parameters.items():
for param in sig.parameters.values():
if hasattr(param.annotation, "__fields__"):
return make_signature(new_func, wrapper, typer=typer, more_args=more_args)
return _make_signature(new_func, wrapper, typer=typer, more_args=more_args)
return new_func
def _expand_param(param, kwargs, models=None):
def _expand_param(
param: inspect.Parameter,
kwargs: Dict[str, Any],
models: Optional[Dict[str, str]] = None,
) -> Any:
"""Further expands params with a Pydantic annotation, given a param.
Recursively creates an instance of any param.annotation that has __fields__
using the expanded kwargs.y:
using the expanded kwargs.
"""
models = {}
for field_name, field in param.annotation.__fields__.items():
if hasattr(field.annotation, "__fields__"):
@ -165,7 +145,12 @@ def _expand_param(param, kwargs, models=None):
return param.annotation(**kwargs, **models)
def _expand_kwargs(func, kwargs):
def _expand_kwargs(func: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Expand kwargs with Pydantic annotations given a function.
Inspects the arguments of the func and expands any of the kwargs with a
Pydantic annotation, to add its fields to the kwargs.
"""
sig = inspect.signature(func)
updated_kwargs = {}
for name, value in kwargs.items():
@ -182,34 +167,21 @@ def _expand_kwargs(func, kwargs):
elif hasattr(param.annotation, "__fields__"):
updated_kwargs[name] = _expand_param(param, kwargs)
# its something else so pass it
# else:
# updated_kwargs[name] = kwargs[name]
return updated_kwargs
def expand_pydantic_args(typer: bool = False) -> Callable:
def decorator(func: Callable) -> Callable[..., any]:
def expand_pydantic_args(*, typer: bool = False) -> Callable:
"""Expand Pydantic keyword arguments.
Decorator function to expand arguments of pydantic models to accept the
individual fields of Models.
"""
def decorator(func: Callable) -> Callable[..., Any]:
@wraps(func)
def wrapper(*args, **kwargs):
return func(**_expand_kwargs(func, kwargs))
return func(*args, **_expand_kwargs(func, kwargs))
return make_signature(func, wrapper, typer=typer)
return _make_signature(func, wrapper, typer=typer)
return decorator
def get_person_vanilla(person: Person) -> Person:
from rich import print
print(person)
return person
@expand_pydantic_args()
def get_person(person: Person, thing: str = None) -> Person:
"""mydocstring"""
from rich import print
print(str(thing))
print(person)

View file

@ -1,9 +0,0 @@
# SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
#
# SPDX-License-Identifier: MIT
import sys
if __name__ == '__main__':
from .cli import {{python_package}}
sys.exit({{python_package}}())

View file

@ -1,62 +0,0 @@
import typer
from pydantic_typer import Person, expand_pydantic_args
from pydantic_typer.cli.common import verbose_callback
from pydantic_typer.cli.config import config_app
from pydantic_typer.cli.tui import tui_app
app = typer.Typer(
name="pydantic_typer",
help="A rich terminal report for coveragepy.",
)
app.add_typer(config_app)
app.add_typer(tui_app)
def version_callback(value: bool) -> None:
"""Callback function to print the version of the pydantic-typer package.
Args:
value (bool): Boolean value to determine if the version should be printed.
Raises:
typer.Exit: If the value is True, the version will be printed and the program will exit.
Example:
version_callback(True)
"""
if value:
from pydantic_typer.__about__ import __version__
typer.echo(f"{__version__}")
raise typer.Exit()
@app.callback()
def main(
version: bool = typer.Option(
None,
"--version",
callback=version_callback,
is_eager=True,
),
verbose: bool = typer.Option(
False,
callback=verbose_callback,
help="show the log messages",
),
) -> None:
return
@app.command()
@expand_pydantic_args
def get_person(person: Person) -> Person:
"""mydocstring"""
from rich import print
print(person)
if __name__ == "__main__":
typer.run(main)

View file

@ -1,6 +0,0 @@
from pydantic_typer.console import console
def verbose_callback(value: bool) -> None:
if value:
console.quiet = False

View file

@ -1,29 +0,0 @@
from rich.console import Console
import typer
from pydantic_typer.cli.common import verbose_callback
from pydantic_typer.config import config as configuration
config_app = typer.Typer()
@config_app.callback()
def config(
verbose: bool = typer.Option(
False,
callback=verbose_callback,
help="show the log messages",
),
):
"configuration cli"
@config_app.command()
def show(
verbose: bool = typer.Option(
False,
callback=verbose_callback,
help="show the log messages",
),
):
Console().print(configuration)

View file

@ -1,18 +0,0 @@
import typer
from pydantic_typer.cli.common import verbose_callback
from pydantic_typer.tui.app import run_app
tui_app = typer.Typer()
@tui_app.callback(invoke_without_command=True)
def i(
verbose: bool = typer.Option(
False,
callback=verbose_callback,
help="show the log messages",
),
):
"interactive tui"
run_app()

View file

@ -1,3 +0,0 @@
from pydantic_typer.standard_config import load
config = load("pydantic_typer")

View file

@ -1,4 +0,0 @@
from rich.console import Console
console = Console()
console.quiet = True

View file

@ -1,239 +0,0 @@
"""Standard Config.
A module to load tooling config from a users project space.
Inspired from frustrations that some tools have a tool.ini, .tool.ini,
setup.cfg, or pyproject.toml. Some allow for global configs, some don't. Some
properly follow the users home directory, others end up in a weird temp
directory. Windows home directory is only more confusing. Some will even
respect the users `$XDG_HOME` directory.
This file is for any project that can be configured in plain text such as `ini`
or `toml` and not requiring a .py file. Just name your tool and let users put
config where it makes sense to them, no need to figure out resolution order.
## Usage:
``` python
from standard_config import load
# Retrieve any overrides from the user
overrides = {'setting': True}
config = load('my_tool', overrides)
```
## Resolution Order
* First global file with a tool key
* First local file with a tool key
* Environment variables prefixed with `TOOL`
* Overrides
### Tool Specific Ini files
Ini file formats must include a `<tool>` key.
``` ini
[my_tool]
setting = True
```
### pyproject.toml
Toml files must include a `tool.<tool>` key
``` toml
[tool.my_tool]
setting = True
```
### setup.cfg
setup.cfg files must include a `tool:<tool>` key
``` ini
[tool:my_tool]
setting = True
```
### global files to consider
* <home>/tool.ini
* <home>/.tool
* <home>/.tool.ini
* <home>/.config/tool.ini
* <home>/.config/.tool
* <home>/.config/.tool.ini
### local files to consider
* <project_home>/tool.ini
* <project_home>/.tool
* <project_home>/.tool.ini
* <project_home>/pyproject.toml
* <project_home>/setup.cfg
"""
import os
from pathlib import Path
from typing import Dict, List, Union
import anyconfig
# path_spec_type = List[Dict[str, Union[Path, str, List[str\}\}\}\}
path_spec_type = List
def _get_global_path_specs(tool: str) -> path_spec_type:
"""
Generate a list of standard pathspecs for global config files.
Args:
tool (str): name of the tool to configure
"""
try:
home = Path(os.environ["XDG_HOME"])
except KeyError:
home = Path.home()
return [
{"path_specs": home / f"{tool}.ini", "ac_parser": "ini", "keys": [tool]},
{"path_specs": home / f".{tool}", "ac_parser": "ini", "keys": [tool]},
{"path_specs": home / f".{tool}.ini", "ac_parser": "ini", "keys": [tool]},
{
"path_specs": home / ".config" / f"{tool}.ini",
"ac_parser": "ini",
"keys": [tool],
},
{
"path_specs": home / ".config" / f".{tool}",
"ac_parser": "ini",
"keys": [tool],
},
{
"path_specs": home / ".config" / f".{tool}.ini",
"ac_parser": "ini",
"keys": [tool],
},
]
def _get_local_path_specs(tool: str, project_home: Union[str, Path]) -> path_spec_type:
"""
Generate a list of standard pathspecs for local, project directory config files.
Args:
tool (str): name of the tool to configure
"""
return [
{
"path_specs": Path(project_home) / f"{tool}.ini",
"ac_parser": "ini",
"keys": [tool],
},
{
"path_specs": Path(project_home) / f".{tool}",
"ac_parser": "ini",
"keys": [tool],
},
{
"path_specs": Path(project_home) / f".{tool}.ini",
"ac_parser": "ini",
"keys": [tool],
},
{
"path_specs": Path(project_home) / f"{tool}.yml",
"ac_parser": "yaml",
"keys": [tool],
},
{
"path_specs": Path(project_home) / f".{tool}.yml",
"ac_parser": "yaml",
"keys": [tool],
},
{
"path_specs": Path(project_home) / f"{tool}.toml",
"ac_parser": "toml",
"keys": [tool],
},
{
"path_specs": Path(project_home) / f".{tool}.toml",
"ac_parser": "toml",
"keys": [tool],
},
{
"path_specs": Path(project_home) / "pyproject.toml",
"ac_parser": "toml",
"keys": ["tool", tool],
},
{
"path_specs": Path(project_home) / "setup.cfg",
"ac_parser": "ini",
"keys": [f"tool.{tool}"],
},
]
def _get_attrs(attrs: list, config: Dict) -> Dict:
"""Get nested config data from a list of keys.
specifically written for pyproject.toml which needs to get `tool` then `<tool>`
"""
for attr in attrs:
config = config[attr]
return config
def _load_files(config_path_specs: path_spec_type) -> Dict:
"""Use anyconfig to load config files stopping at the first one that exists.
config_path_specs (list): a list of pathspecs and keys to load
"""
for file in config_path_specs:
if file["path_specs"].exists():
config = anyconfig.load(**file)
else:
# ignore missing files
continue
try:
return _get_attrs(file["keys"], config)
except KeyError:
# ignore incorrect keys
continue
return {}
def _load_env(tool: str) -> Dict:
"""Load config from environment variables.
Args:
tool (str): name of the tool to configure
"""
vars = [var for var in os.environ.keys() if var.startswith(tool.upper())]
return {
var.lower().strip(tool.lower()).strip("_").strip("-"): os.environ[var]
for var in vars
}
def load(tool: str, project_home: Union[Path, str] = ".", overrides: Dict = {}) -> Dict:
"""Load tool config from standard config files.
Resolution Order
* First global file with a tool key
* First local file with a tool key
* Environment variables prefixed with `TOOL`
* Overrides
Args:
tool (str): name of the tool to configure
"""
global_config = _load_files(_get_global_path_specs(tool))
local_config = _load_files(_get_local_path_specs(tool, project_home))
env_config = _load_env(tool)
return {**global_config, **local_config, **env_config, **overrides}

View file

@ -1,18 +0,0 @@
Screen {
align: center middle;
layers: main footer;
}
Sidebar {
height: 100vh;
width: auto;
min-width: 20;
background: $secondary-background-darken-2;
dock: left;
margin-right: 1;
layer: main;
}
Footer {
layer: footer;
}

View file

@ -1,62 +0,0 @@
from pathlib import Path
from textual.app import App, ComposeResult
from textual.containers import Container
from textual.css.query import NoMatches
from textual.widgets import Footer, Static
from pydantic_typer.config import config
config["tui"] = {}
config["tui"]["bindings"] = {}
class Sidebar(Static):
def compose(self) -> ComposeResult:
yield Container(
Static("sidebar"),
id="sidebar",
)
class Tui(App):
"""A Textual app to manage requests."""
CSS_PATH = Path("__file__").parent / "app.css"
BINDINGS = [tuple(b.values()) for b in config["tui"]["bindings"]]
def compose(self) -> ComposeResult:
"""Create child widgets for the app."""
yield Container(Static("hello world"))
yield Footer()
def action_toggle_dark(self) -> None:
"""An action to toggle dark mode."""
self.dark = not self.dark
def action_toggle_sidebar(self):
try:
self.query_one("PromptSidebar").remove()
except NoMatches:
self.mount(Sidebar())
def run_app():
import os
import sys
from textual.features import parse_features
dev = "--dev" in sys.argv
features = set(parse_features(os.environ.get("TEXTUAL", "")))
if dev:
features.add("debug")
features.add("devtools")
os.environ["TEXTUAL"] = ",".join(sorted(features))
app = Tui()
app.run()
if __name__ == "__main__":
run_app()

View file

@ -52,8 +52,9 @@ dependencies = [
"pytest",
"pytest-cov",
"pytest-mock",
"pytest-rich",
'polyfactory',
"ruff",
'pyannotate',
"black",
]
[tool.hatch.envs.default.scripts]
@ -89,9 +90,77 @@ exclude_lines = [
]
[tool.pytest.ini_options]
addopts = "-ra -q --rich"
asyncio_mode = "auto"
addopts = "-ra -q"
testpaths = ["tests"]
[tool.coverage_rich]
fail-under=80
[tool.ruff]
ignore = ["E501", "D211", "D212", "D213"]
target-version = "py37"
select = [
"F", # Pyflakes
"E", # Error
"W", # Warning
"C90", # mccabe
"I", # isort
"N", # pep8-naming
"D", # pydocstyle
"UP", # pyupgrade
"YTT", # flake8-2020
# "ANN", # flake8-annotations
"S", # flake8-bandit
"BLE", # flake8-blind-except
"FBT", # flake8-boolean-trap
"B", # flake8-bugbear
"A", # flake8-builtins
"COM", # flake8-commas
"C4", # flake8-comprehensions
"DTZ", # flake8-datetimez
"T10", # flake8-debugger
"DJ", # flake8-django
"EM", # flake8-errmsg
"EXE", # flake8-executable
"ISC", # flake8-implicit-str-concat
"ICN", # flake8-import-conventions
"G", # flake8-logging-format
"INP", # flake8-no-pep420
"PIE", # flake8-pie
"T20", # flake8-print
"PYI", # flake8-pyi
"PT", # flake8-pytest-style
"Q", # flake8-quotes
"RSE", # flake8-raise
"RET", # flake8-return
"SLF", # flake8-self
"SIM", # flake8-simplify
"TID", # flake8-tidy-imports
"TCH", # flake8-type-checking
"INT", # flake8-gettext
"ARG", # flake8-unused-arguments
"PTH", # flake8-use-pathlib
"ERA", # eradicate
"PD", # pandas-vet
"PGH", # pygrep-hooks
"PL", # Pylint
"PLC", # Convention
"PLE", # Error
"PLR", # Refactor
"PLW", # Warning
"TRY", # tryceratops
"NPY", # NumPy-specific rules
"RUF", # Ruff-specific rules
]
[tool.ruff.mccabe]
# Flag errors (`C901`) whenever the complexity level exceeds 5.
max-complexity = 13
[tool.ruff.pylint]
max-branches = 13
[tool.ruff.per-file-ignores]
'tests/**' = ["D100", "D101", "D102", "D103", "D104", "D105", "S101"]

View file

@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
#
# SPDX-License-Identifier: MIT
"""Tests.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""

123
tests/models.py Normal file
View file

@ -0,0 +1,123 @@
"""Models defines a set of classes for representing people and their hair.
Classes:
* `Alpha`: A class for representing an alpha value.
* `Color`: A class for representing a color.
* `Hair`: A class for representing hair.
* `Person`: A class for representing a person.
"""
from typing import Optional
from polyfactory.factories.pydantic_factory import ModelFactory
from pydantic import BaseModel, Field
class Alpha(BaseModel):
"""A class for representing an alpha value."""
a: int = Field(
...,
description="The alpha value.",
)
class Color(BaseModel):
"""A class for representing a color."""
r: int = Field(
...,
description="The red component of the color.",
)
g: int = Field(
...,
description="The green component of the color.",
)
b: int = Field(
...,
description="The blue component of the color.",
)
alpha: Alpha = Field(
...,
description="The alpha value of the color.",
)
class Hair(BaseModel):
"""A class for representing hair."""
color: Color = Field(
...,
description="The color of the hair.",
)
length: int = Field(
...,
description="The length of the hair.",
)
class Person(BaseModel):
"""A class for representing a person."""
name: str = Field(
...,
description="The name of the person.",
)
alias: Optional[str] = Field(
None,
description="An optional other name for the person.",
)
age: int = Field(
...,
description="The age of the person.",
)
email: Optional[str] = Field(
None,
description="An optional email address for the person.",
)
pet: str = Field(
"dog",
description="The person's pet.",
)
address: str = Field(
"123 Main St",
description="Where the person calls home.",
)
hair: Hair = Field(
...,
description="The person's hair.",
)
class AlphaFactory(ModelFactory[Alpha]):
"""A class for generating an alpha value."""
__model__ = Alpha
class ColorFactory(ModelFactory[Color]):
"""A class for generating a color."""
__model__ = Color
class HairFactory(ModelFactory[Hair]):
"""A class for generating hair."""
__model__ = Hair
class PersonFactory(ModelFactory[Person]):
"""A class for generating a person."""
__model__ = Person

164
tests/test_person.py Normal file
View file

@ -0,0 +1,164 @@
"""Example usage of expand_pydantic_args with the Person model.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
import inspect
import pytest
from pydantic_typer import expand_pydantic_args
from . import models
# this one is broken
# def test_no_pydantic() -> None:
# @expand_pydantic_args()
# def get_person(alpha) -> None:
# """Mydocstring."""
def test_single_signature() -> None:
@expand_pydantic_args()
def get_person(alpha: models.Alpha) -> None:
"""Mydocstring."""
return alpha
sig = inspect.signature(get_person)
params = sig.parameters
assert "a" in params
assert "alpha" not in params
@pytest.mark.parametrize(
"alpha",
models.AlphaFactory().batch(size=5),
)
def test_single_instance(alpha: models.Alpha) -> None:
@expand_pydantic_args()
def get_person(alpha: models.Alpha) -> None:
"""Mydocstring."""
return alpha
assert get_person(**alpha.dict()) == alpha
# this should maybe work
# assert get_person(models.Alpha(a=1)) == models.Alpha(a=1)
def test_one_nest_signature() -> None:
@expand_pydantic_args()
def get_person(color: models.Color) -> None:
"""Mydocstring."""
return color
sig = inspect.signature(get_person)
params = sig.parameters
assert "r" in params
assert "g" in params
assert "b" in params
assert "a" in params
assert "color" not in params
assert "alpha" not in params
@pytest.mark.parametrize(
"color",
models.ColorFactory().batch(size=5),
)
def test_one_nest_instance(color: models.Color) -> None:
@expand_pydantic_args()
def get_person(color: models.Color) -> None:
"""Mydocstring."""
return color
assert get_person(**color.dict(exclude={"alpha"}), **color.alpha.dict()) == color
def test_two_nest_signature() -> None:
@expand_pydantic_args()
def get_person(hair: models.Hair) -> None:
"""Mydocstring."""
return hair
sig = inspect.signature(get_person)
params = sig.parameters
assert "length" in params
assert "r" in params
assert "g" in params
assert "b" in params
assert "a" in params
assert "hair" not in params
assert "color" not in params
assert "alpha" not in params
@pytest.mark.parametrize(
"hair",
models.HairFactory().batch(size=5),
)
def test_two_nest_instance(hair: models.Hair) -> None:
@expand_pydantic_args()
def get_person(hair: models.Hair) -> None:
"""Mydocstring."""
return hair
assert (
get_person(
**hair.dict(exclude={"color"}),
**hair.color.dict(exclude={"alpha"}),
**hair.color.alpha.dict()
) ==
hair
)
def test_three_nest_signature() -> None:
@expand_pydantic_args()
def get_person(person: models.Person) -> None:
"""Mydocstring."""
return person
sig = inspect.signature(get_person)
params = sig.parameters
assert "name" in params
assert "alias" in params
assert "age" in params
assert "email" in params
assert "pet" in params
assert "address" in params
assert "length" in params
assert "r" in params
assert "g" in params
assert "b" in params
assert "a" in params
assert "person" not in params
assert "hair" not in params
assert "color" not in params
assert "alpha" not in params
@pytest.mark.parametrize(
"person",
models.PersonFactory().batch(size=5),
)
def test_three_nest_instance(person: models.Person) -> None:
@expand_pydantic_args()
def get_person(person: models.Person) -> None:
"""Mydocstring."""
return person
assert (
get_person(
**person.dict(exclude={"hair"}),
**person.hair.dict(exclude={"color"}),
**person.hair.color.dict(exclude={"alpha"}),
**person.hair.color.alpha.dict()
) ==
person
)