96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
from typing import List, Optional
|
|
import textwrap
|
|
import inspect
|
|
|
|
from pydantic import BaseModel
|
|
|
|
def _optional_import_(
|
|
module: str,
|
|
name: str = None,
|
|
group: str = None,
|
|
package="learn_sql_model",
|
|
):
|
|
"""
|
|
lazily throws import errors only then the optional import is used, and
|
|
includes a group install command for the user to install all dependencies
|
|
for the requested feature.
|
|
"""
|
|
import importlib
|
|
|
|
try:
|
|
module = importlib.import_module(module)
|
|
return module if name is None else getattr(module, name)
|
|
except ImportError as e:
|
|
msg = textwrap.dedent(
|
|
f"""
|
|
"pip install '{package}[{group}]'" package to make use of this feature
|
|
Alternatively "pip install '{package}[all]'" package to install all optional dependencies
|
|
"""
|
|
)
|
|
import_error = e
|
|
|
|
class _failed_import:
|
|
"""
|
|
Lazily throw an import error. Errors should be thrown whether the
|
|
user tries to call the module, get an attubute from the module, or
|
|
getitem from the module.
|
|
|
|
"""
|
|
|
|
def _failed_import(self, *args):
|
|
raise ImportError(msg) from import_error
|
|
|
|
def __call__(self, *args):
|
|
"""
|
|
Throw error if the user tries to call the module i.e
|
|
_optional_import_('dummy')()
|
|
"""
|
|
self._failed_import(*args)
|
|
|
|
def __getattr__(self, name):
|
|
"""
|
|
Throw error if the user tries to get an attribute from the
|
|
module i.e _optional_import_('dummy').dummy.
|
|
"""
|
|
if name == "_failed_import":
|
|
return object.__getattribute__(self, name)
|
|
self._failed_import()
|
|
|
|
def __getitem__(self, name):
|
|
"""
|
|
Throw error if the user tries to get an item from the module
|
|
i.e _optional_import_('dummy')['dummy']
|
|
"""
|
|
self._failed_import()
|
|
|
|
return _failed_import()
|
|
|
|
|
|
# def optional(fields: Optional[List[str]]=None, required: Optional[List[str]]=None):
|
|
# def decorator(cls):
|
|
# def wrapper(*args, **kwargs):
|
|
# if fields is None:
|
|
# fields = cls.__fields__
|
|
# if required is None:
|
|
# required = []
|
|
#
|
|
# for field in fields:
|
|
# if field not in required:
|
|
# cls.__fields__[field].required = False
|
|
# return _cls
|
|
# return wrapper
|
|
# return decorator
|
|
#
|
|
#
|
|
def optional(*fields):
|
|
def dec(_cls):
|
|
for field in fields:
|
|
_cls.__fields__[field].required = False
|
|
return _cls
|
|
|
|
if fields and inspect.isclass(fields[0]) and issubclass(fields[0], BaseModel):
|
|
cls = fields[0]
|
|
fields = cls.__fields__
|
|
return dec(cls)
|
|
return dec
|
|
|