This commit is contained in:
Waylon Walker 2023-04-28 20:00:40 -05:00
parent fec28df75f
commit 5ae6b86712
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4
10 changed files with 312 additions and 39 deletions

View file

@ -1,4 +1,9 @@
# SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
#
# SPDX-License-Identifier: MIT
"""About pydantic_typer.
Sets metadata about pydantic_typer.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
__version__ = "0.0.0.dev1"

View file

@ -1,14 +1,26 @@
"""pydantic_typer."""
"""pydantic_typer.
SPDX-FileCopyrightText: 2023-present Waylon S. Walker <waylon@waylonwalker.com>
SPDX-License-Identifier: MIT
"""
import inspect
from functools import wraps
from typing import Any, Callable
from typing import Any, Callable, Dict, Optional
import typer
from pydantic.fields import ModelField
__all__ = ["typer"]
def _make_annotation(name, field, names, typer=False):
def _make_annotation(
name: str,
field: ModelField,
names: Dict[str, str],
*,
typer: bool = False,
) -> str:
panel_name = names.get(name)
next_name = panel_name
while next_name is not None:
@ -36,17 +48,24 @@ def _make_annotation(name, field, names, typer=False):
default = f' = typer.Option("{field.default}", help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")'
else:
default = f'="{field.default}"'
elif typer:
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
else:
if typer:
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
else:
default = ""
default = ""
if typer:
return f"{name}: {annotation}{default}"
return f"{name}: {annotation}{default}"
def _make_signature(func, wrapper, typer=False, more_args={}):
def _make_signature(
func: Callable,
wrapper: Callable,
*,
typer: bool = False,
more_args: Optional[Dict] = None,
):
if more_args is None:
more_args = {}
sig = inspect.signature(func)
names = {}
for name, param in sig.parameters.items():
@ -98,17 +117,27 @@ def {func.__name__}({aargs}{', ' if aargs else ''}{kwargs}):
'''{func.__doc__}'''
return wrapper({call_args})
"""
exec(new_func_str, locals(), globals())
exec(new_func_str, locals(), globals()) # noqa: S102
new_func = globals()[func.__name__]
sig = inspect.signature(new_func)
for name, param in sig.parameters.items():
for param in sig.parameters.values():
if hasattr(param.annotation, "__fields__"):
return _make_signature(new_func, wrapper, typer=typer, more_args=more_args)
return new_func
def _expand_param(param, kwargs, models=None):
def _expand_param(
param: inspect.Parameter,
kwargs: Dict[str, Any],
models: Optional[Dict[str, str]] = None,
) -> Any:
"""Further expands params with a Pydantic annotation, given a param.
Recursively creates an instance of any param.annotation that has __fields__
using the expanded kwargs.y:
using the expanded kwargs.
"""
models = {}
for field_name, field in param.annotation.__fields__.items():
if hasattr(field.annotation, "__fields__"):
@ -116,7 +145,12 @@ def _expand_param(param, kwargs, models=None):
return param.annotation(**kwargs, **models)
def _expand_kwargs(func, kwargs):
def _expand_kwargs(func: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Expand kwargs with Pydantic annotations given a function.
Inspects the arguments of the func and expands any of the kwargs with a
Pydantic annotation, to add its fields to the kwargs.
"""
sig = inspect.signature(func)
updated_kwargs = {}
for name, value in kwargs.items():
@ -136,11 +170,17 @@ def _expand_kwargs(func, kwargs):
return updated_kwargs
def expand_pydantic_args(typer: bool = False) -> Callable:
def expand_pydantic_args(*, typer: bool = False) -> Callable:
"""Expand Pydantic keyword arguments.
Decorator function to expand arguments of pydantic models to accept the
individual fields of Models.
"""
def decorator(func: Callable) -> Callable[..., Any]:
@wraps(func)
def wrapper(*args, **kwargs):
return func(**_expand_kwargs(func, kwargs))
return func(*args, **_expand_kwargs(func, kwargs))
return _make_signature(func, wrapper, typer=typer)