This commit is contained in:
Waylon Walker 2023-04-28 20:00:40 -05:00
parent fec28df75f
commit 5ae6b86712
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
10 changed files with 312 additions and 39 deletions

6
examples/__init__.py Normal file
View file

@ -0,0 +1,6 @@
"""Example usage of expand_pydantic_args.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""

View file

@ -1,29 +1,94 @@
"""Models defines a set of classes for representing people and their hair.
Classes:
* `Alpha`: A class for representing an alpha value.
* `Color`: A class for representing a color.
* `Hair`: A class for representing hair.
* `Person`: A class for representing a person.
"""
from typing import Optional
from pydantic import BaseModel, Field
class Alpha(BaseModel):
a: int
"""A class for representing an alpha value."""
a: int = Field(
...,
description="The alpha value.",
)
class Color(BaseModel):
r: int
g: int
b: int
alpha: Alpha
"""A class for representing a color."""
r: int = Field(
...,
description="The red component of the color.",
)
g: int = Field(
...,
description="The green component of the color.",
)
b: int = Field(
...,
description="The blue component of the color.",
)
alpha: Alpha = Field(
...,
description="The alpha value of the color.",
)
class Hair(BaseModel):
color: Color
length: int
"""A class for representing hair."""
color: Color = Field(
...,
description="The color of the hair.",
)
length: int = Field(
...,
description="The length of the hair.",
)
class Person(BaseModel):
name: str
other_name: Optional[str] = None
age: int
email: Optional[str]
pet: str = "dog"
address: str = Field("123 Main St", description="Where the person calls home.")
hair: Hair
"""A class for representing a person."""
name: str = Field(
...,
description="The name of the person.",
)
other_name: Optional[str] = Field(
None,
description="An optional other name for the person.",
)
age: int = Field(
...,
description="The age of the person.",
)
email: Optional[str] = Field(
None,
description="An optional email address for the person.",
)
pet: str = Field(
"dog",
description="The person's pet.",
)
address: str = Field(
"123 Main St",
description="Where the person calls home.",
)
hair: Hair = Field(
...,
description="The person's hair.",
)

View file

@ -1,3 +1,9 @@
"""Example usage of expand_pydantic_args with the Person model.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
from pyannotate_runtime import collect_types
from examples.models import Person

View file

@ -1,3 +1,9 @@
"""Example usage of expand_pydantic_args with the Person model as a typer cli.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
import typer
from examples.models import Person
@ -11,6 +17,7 @@ app = typer.Typer(
@app.callback()
def main() -> None:
"""Set up typer."""
return

View file

@ -1,4 +1,9 @@
# SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
#
# SPDX-License-Identifier: MIT
"""About pydantic_typer.
Sets metadata about pydantic_typer.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
__version__ = "0.0.0.dev1"

View file

@ -1,14 +1,26 @@
"""pydantic_typer."""
"""pydantic_typer.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
import inspect
from functools import wraps
from typing import Any, Callable
from typing import Any, Callable, Dict, Optional
import typer
from pydantic.fields import ModelField
__all__ = ["typer"]
def _make_annotation(name, field, names, typer=False):
def _make_annotation(
name: str,
field: ModelField,
names: Dict[str, str],
*,
typer: bool = False,
) -> str:
panel_name = names.get(name)
next_name = panel_name
while next_name is not None:
@ -36,17 +48,24 @@ def _make_annotation(name, field, names, typer=False):
default = f' = typer.Option("{field.default}", help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")'
else:
default = f'="{field.default}"'
elif typer:
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
else:
if typer:
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
else:
default = ""
default = ""
if typer:
return f"{name}: {annotation}{default}"
return f"{name}: {annotation}{default}"
def _make_signature(func, wrapper, typer=False, more_args={}):
def _make_signature(
func: Callable,
wrapper: Callable,
*,
typer: bool = False,
more_args: Optional[Dict] = None,
):
if more_args is None:
more_args = {}
sig = inspect.signature(func)
names = {}
for name, param in sig.parameters.items():
@ -98,17 +117,27 @@ def {func.__name__}({aargs}{', ' if aargs else ''}{kwargs}):
'''{func.__doc__}'''
return wrapper({call_args})
"""
exec(new_func_str, locals(), globals())
exec(new_func_str, locals(), globals()) # noqa: S102
new_func = globals()[func.__name__]
sig = inspect.signature(new_func)
for name, param in sig.parameters.items():
for param in sig.parameters.values():
if hasattr(param.annotation, "__fields__"):
return _make_signature(new_func, wrapper, typer=typer, more_args=more_args)
return new_func
def _expand_param(param, kwargs, models=None):
def _expand_param(
param: inspect.Parameter,
kwargs: Dict[str, Any],
models: Optional[Dict[str, str]] = None,
) -> Any:
"""Further expands params with a Pydantic annotation, given a param.
Recursively creates an instance of any param.annotation that has __fields__
using the expanded kwargs.y:
using the expanded kwargs.
"""
models = {}
for field_name, field in param.annotation.__fields__.items():
if hasattr(field.annotation, "__fields__"):
@ -116,7 +145,12 @@ def _expand_param(param, kwargs, models=None):
return param.annotation(**kwargs, **models)
def _expand_kwargs(func, kwargs):
def _expand_kwargs(func: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Expand kwargs with Pydantic annotations given a function.
Inspects the arguments of the func and expands any of the kwargs with a
Pydantic annotation, to add its fields to the kwargs.
"""
sig = inspect.signature(func)
updated_kwargs = {}
for name, value in kwargs.items():
@ -136,11 +170,17 @@ def _expand_kwargs(func, kwargs):
return updated_kwargs
def expand_pydantic_args(typer: bool = False) -> Callable:
def expand_pydantic_args(*, typer: bool = False) -> Callable:
"""Expand Pydantic keyword arguments.
Decorator function to expand arguments of pydantic models to accept the
individual fields of Models.
"""
def decorator(func: Callable) -> Callable[..., Any]:
@wraps(func)
def wrapper(*args, **kwargs):
return func(**_expand_kwargs(func, kwargs))
return func(*args, **_expand_kwargs(func, kwargs))
return _make_signature(func, wrapper, typer=typer)

View file

@ -52,7 +52,6 @@ dependencies = [
"pytest",
"pytest-cov",
"pytest-mock",
"pytest-rich",
"ruff",
'pyannotate',
"black",
@ -90,15 +89,14 @@ exclude_lines = [
]
[tool.pytest.ini_options]
addopts = "-ra -q --rich"
asyncio_mode = "auto"
addopts = "-ra -q"
testpaths = ["tests"]
[tool.coverage_rich]
fail-under=80
[tool.ruff]
ignore = ["E501"]
ignore = ["E501", "D211", "D213"]
target-version = "py37"
@ -112,7 +110,7 @@ select = [
"D", # pydocstyle
"UP", # pyupgrade
"YTT", # flake8-2020
"ANN", # flake8-annotations
# "ANN", # flake8-annotations
"S", # flake8-bandit
"BLE", # flake8-blind-except
"FBT", # flake8-boolean-trap
@ -155,3 +153,9 @@ select = [
"NPY", # NumPy-specific rules
"RUF", # Ruff-specific rules
]
[tool.ruff.mccabe]
# Flag errors (`C901`) whenever the complexity level exceeds 5.
max-complexity = 13
[tool.ruff.pylint]
max-branches = 13

View file

@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
#
# SPDX-License-Identifier: MIT
"""Tests.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""

94
tests/models.py Normal file
View file

@ -0,0 +1,94 @@
"""Models defines a set of classes for representing people and their hair.
Classes:
* `Alpha`: A class for representing an alpha value.
* `Color`: A class for representing a color.
* `Hair`: A class for representing hair.
* `Person`: A class for representing a person.
"""
from typing import Optional
from pydantic import BaseModel, Field
class Alpha(BaseModel):
"""A class for representing an alpha value."""
a: int = Field(
...,
description="The alpha value.",
)
class Color(BaseModel):
"""A class for representing a color."""
r: int = Field(
...,
description="The red component of the color.",
)
g: int = Field(
...,
description="The green component of the color.",
)
b: int = Field(
...,
description="The blue component of the color.",
)
alpha: Alpha = Field(
...,
description="The alpha value of the color.",
)
class Hair(BaseModel):
"""A class for representing hair."""
color: Color = Field(
...,
description="The color of the hair.",
)
length: int = Field(
...,
description="The length of the hair.",
)
class Person(BaseModel):
"""A class for representing a person."""
name: str = Field(
...,
description="The name of the person.",
)
other_name: Optional[str] = Field(
None,
description="An optional other name for the person.",
)
age: int = Field(
...,
description="The age of the person.",
)
email: Optional[str] = Field(
None,
description="An optional email address for the person.",
)
pet: str = Field(
"dog",
description="The person's pet.",
)
address: str = Field(
"123 Main St",
description="Where the person calls home.",
)
hair: Hair = Field(
...,
description="The person's hair.",
)

43
tests/test_person.py Normal file
View file

@ -0,0 +1,43 @@
"""Example usage of expand_pydantic_args with the Person model.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
import inspect
from pydantic_typer import expand_pydantic_args
from . import models
# this one is broken
# def test_no_pydantic() -> None:
# @expand_pydantic_args()
# def get_person(alpha) -> None:
# """Mydocstring."""
# from rich import print
# print(str(thing))
# print(person)
# sig = inspect.signature(get_person)
# params = sig.parameters
# for field in models.Alpha.__fields__.values():
# assert field.name in params
def test_single_signature() -> None:
@expand_pydantic_args()
def get_person(alpha: models.Alpha) -> None:
"""Mydocstring."""
return alpha
sig = inspect.signature(get_person)
params = sig.parameters
for field in models.Alpha.__fields__.values():
assert field.name in params
assert get_person(a=1) == models.Alpha(a=1)