allow other non pydantic arguments

This commit is contained in:
Waylon Walker 2023-04-27 21:38:46 -05:00
parent d120815f5a
commit ff522f24a6
No known key found for this signature in database
GPG key ID: 66E2BF2B4190EFE4

View file

@ -38,8 +38,8 @@ class Person(BaseModel):
hair: Hair
def make_annotation(name, field, names):
panel_name = names[name]
def make_annotation(name, field, names, typer=False):
panel_name = names.get(name)
next_name = panel_name
while next_name is not None:
next_name = names.get(next_name)
@ -52,22 +52,27 @@ def make_annotation(name, field, names):
else str(field.annotation)
)
if field.default is None and not field.required:
default = "None"
default = f' = typer.Option(None, help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")'
if field.default is None and not getattr(field, "required", False):
if typer:
default = f' = typer.Option(None, help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")'
else:
default = "=None"
elif field.default is not None:
default = f'"{field.default}"'
default = f' = typer.Option("{field.default}", help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")'
if typer:
default = f' = typer.Option("{field.default}", help="{field.field_info.description or ""}", rich_help_panel="{panel_name}")'
else:
default = f'="{field.default}"'
else:
default = ""
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
# if not typer
# return f"{name}: {annotation}{default}"
if typer:
default = f' = typer.Option(..., help="{field.field_info.description or ""}", rich_help_panel="{panel_name}", prompt=True)'
else:
default = ""
if typer:
return f"{name}: {annotation}{default}"
return f"{name}: {annotation}{default}"
def make_signature(func, wrapper, more_args={}):
def make_signature(func, wrapper, typer=False, more_args={}):
sig = inspect.signature(func)
names = {}
for name, param in sig.parameters.items():
@ -75,6 +80,8 @@ def make_signature(func, wrapper, more_args={}):
more_args = {**more_args, **param.annotation.__fields__}
for field in param.annotation.__fields__:
names[field] = param.annotation.__name__
else:
more_args[name] = param
while any(
[hasattr(param.annotation, "__fields__") for name, param in more_args.items()]
@ -100,7 +107,13 @@ def make_signature(func, wrapper, more_args={}):
) + f"\nalso accepts {more_args.keys()} in place of person model"
# fields = Person.__fields__
raw_args = [
make_annotation(name, field, names) for name, field in more_args.items()
make_annotation(
name,
field,
names=names,
typer=typer,
)
for name, field in more_args.items()
]
aargs = ", ".join([arg for arg in raw_args if "=" not in arg])
kwargs = ", ".join([arg for arg in raw_args if "=" in arg])
@ -136,7 +149,7 @@ def {func.__name__}({aargs}{', ' if aargs else ''}{kwargs}):
sig = inspect.signature(new_func)
for name, param in sig.parameters.items():
if hasattr(param.annotation, "__fields__"):
return make_signature(new_func, wrapper, more_args=more_args)
return make_signature(new_func, wrapper, typer=typer, more_args=more_args)
return new_func
@ -165,26 +178,20 @@ def _expand_kwargs(func, kwargs):
elif hasattr(param.annotation, "__fields__"):
updated_kwargs[name] = _expand_param(param, kwargs)
# its something else so pass it
else:
updated_kwargs[name] = kwargs[name]
# else:
# updated_kwargs[name] = kwargs[name]
return updated_kwargs
def expand_pydantic_args(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
def expand_pydantic_args(typer: bool = False) -> Callable:
def decorator(func: Callable) -> Callable[..., any]:
@wraps(func)
def wrapper(*args, **kwargs):
return func(**_expand_kwargs(func, kwargs))
return func(**_expand_kwargs(func, kwargs))
return make_signature(func, wrapper, typer=typer)
return make_signature(func, wrapper)
# @expand_pydantic_args
# def get_person(person: Person) -> Person:
# """mydocstring"""
# from rich import print
# print(person)
return decorator
def get_person_vanilla(person: Person) -> Person:
@ -194,13 +201,11 @@ def get_person_vanilla(person: Person) -> Person:
return person
@expand_pydantic_args
@expand_pydantic_args()
def get_person(person: Person, thing: str = None) -> Person:
"""mydocstring"""
from rich import print
print(str(thing))
print(person)
# return person
# get_person(name="me", age=1)