import sys
import pkgutil
import inspect
import logging
from fnmatch import fnmatch
from types import FunctionType, ModuleType
from typing import Generator, Optional, List, Dict, Tuple, Union
if sys.version_info < (3, 9):
from importlib_metadata import entry_points as _entry_points
def iter_entry_points(group: str):
return _entry_points(group=group)
elif sys.version_info < (3, 10):
from importlib.metadata import entry_points as _entry_points
def iter_entry_points(group: str):
return _entry_points().get(group, [])
else:
from importlib.metadata import entry_points as _entry_points
[docs]
def iter_entry_points(group: str):
return _entry_points(group=group)
from ewoksutils.import_utils import qualname
from ewoksutils.import_utils import import_module
from .task import Task
TaskDict = Dict[str, Union[str, List[str]]]
logger = logging.getLogger(__name__)
[docs]
def discover_tasks_from_modules(
*module_names: str,
task_type="class",
reload: bool = False,
raise_import_failure: bool = True,
) -> List[TaskDict]:
return list(
iter_discover_tasks_from_modules(
*module_names,
task_type=task_type,
reload=reload,
raise_import_failure=raise_import_failure,
)
)
[docs]
def iter_discover_tasks_from_modules(
*module_names: str,
task_type="class",
reload: bool = False,
raise_import_failure: bool = True,
) -> Generator[TaskDict, None, None]:
if "" not in sys.path:
# This happens when the python process was launched
# through a python console script
sys.path.append("")
if task_type == "method":
yield from _iter_method_tasks(
*module_names, reload=reload, raise_import_failure=raise_import_failure
)
elif task_type == "ppfmethod":
yield from _iter_ppfmethod_tasks(
*module_names, reload=reload, raise_import_failure=raise_import_failure
)
elif task_type == "class":
for module_name in module_names:
_safe_import_module(
module_name, reload=reload, raise_import_failure=raise_import_failure
)
yield from _iter_registered_tasks(*module_names)
else:
raise ValueError("Class type does not support discovery")
def _iter_registered_tasks(*filter_modules: str) -> Generator[TaskDict, None, None]:
"""Yields all task classes registered in the current process."""
for cls in Task.get_subclasses():
module = cls.__module__
if filter_modules and not any(
module.startswith(prefix) for prefix in filter_modules
):
continue
task_identifier = cls.class_registry_name()
category = task_identifier.split(".")[0]
yield {
"task_type": "class",
"task_identifier": task_identifier,
"required_input_names": list(cls.required_input_names()),
"optional_input_names": list(cls.optional_input_names()),
"output_names": list(cls.output_names()),
"category": category,
"description": cls.__doc__,
}
def _iter_method_tasks(
*module_names: str,
reload: bool = False,
raise_import_failure: bool = False,
) -> Generator[TaskDict, None, None]:
"""Yields all task methods from the provided module_names. The module_names will be will
imported for discovery.
"""
for module_name in module_names:
mod = _safe_import_module(
module_name, reload=reload, raise_import_failure=raise_import_failure
)
if mod is None:
continue
for method_name, method_qn in inspect.getmembers(mod, inspect.isfunction):
if method_name.startswith("_"):
continue
yield {
"task_type": "method",
**_common_method_task_fields(method_name, method_qn, mod),
}
def _iter_ppfmethod_tasks(
*module_names: str,
reload: bool = False,
raise_import_failure: bool = False,
) -> Generator[TaskDict, None, None]:
"""Yields all task ppfmethods from the provided module_names. The module_names will be will
imported for discovery.
The difference with regular methods is that ppfmethods are expected to be called `run`. Other method names will be ignored.
"""
for module_name in module_names:
mod = _safe_import_module(
module_name, reload=reload, raise_import_failure=raise_import_failure
)
if mod is None:
continue
for method_name, method_qn in inspect.getmembers(mod, inspect.isfunction):
if method_name != "run":
continue
yield {
"task_type": "ppfmethod",
**_common_method_task_fields(method_name, method_qn, mod),
}
[docs]
def iter_discover_all_tasks(
reload: bool = False, raise_import_failure: bool = False
) -> Generator[TaskDict, None, None]:
visited = set()
for task_type in ("class", "ppfmethod", "method"):
group = "ewoks.tasks." + task_type
for entrypoint in iter_entry_points(group):
module_pattern = entrypoint.name
if module_pattern is visited:
continue
visited.add(module_pattern)
for module_name in _iter_modules_from_pattern(
module_pattern, reload=reload, raise_import_failure=raise_import_failure
):
yield from iter_discover_tasks_from_modules(
module_name,
task_type=task_type,
reload=reload,
raise_import_failure=raise_import_failure,
)
[docs]
def discover_all_tasks(
reload: bool = False, raise_import_failure: bool = False
) -> List[TaskDict]:
return list(
iter_discover_all_tasks(
reload=reload, raise_import_failure=raise_import_failure
)
)
def _iter_modules_from_pattern(
module_pattern: str, reload: bool = False, raise_import_failure: bool = False
) -> Generator[str, None, None]:
if "*" not in module_pattern:
yield module_pattern
return
ndots = module_pattern.count(".")
parts = module_pattern.split(".")
pkg = _safe_import_module(
parts[0], reload=reload, raise_import_failure=raise_import_failure
)
if pkg is None:
return
if raise_import_failure:
def onerror(module_name):
raise
else:
onerror = _onerror
for pkginfo in pkgutil.walk_packages(
pkg.__path__, pkg.__name__ + ".", onerror=onerror
):
if pkginfo.name.count(".") == ndots and fnmatch(pkginfo.name, module_pattern):
yield pkginfo.name
def _safe_import_module(
module_name: str, reload: bool = False, raise_import_failure: bool = False
) -> Optional[ModuleType]:
try:
return import_module(module_name, reload=reload)
except Exception as e:
if raise_import_failure:
raise
_onerror(module_name, exception=e)
def _onerror(module_name, exception: Optional[Exception] = None):
if exception is None:
exception = sys.exc_info()[1]
logger.error(f"Module '{module_name}' cannot be imported: {exception}")
def _method_arguments(method) -> Tuple[List[str], List[str]]:
sig = inspect.signature(method)
required_input_names = list()
optional_input_names = list()
for name, param in sig.parameters.items():
required = param.default is inspect._empty
if param.kind == param.VAR_POSITIONAL:
continue
if param.kind == param.VAR_KEYWORD:
continue
if required:
required_input_names.append(name)
else:
optional_input_names.append(name)
return required_input_names, optional_input_names
def _common_method_task_fields(
method_name: str, method_qn: FunctionType, mod: ModuleType
) -> Dict[str, Union[str, List[str]]]:
task_identifier = qualname(method_qn)
method = getattr(mod, method_name)
required_input_names, optional_input_names = _method_arguments(method)
return {
"task_identifier": qualname(method_qn),
"required_input_names": required_input_names,
"optional_input_names": optional_input_names,
"output_names": ["return_value"],
"category": task_identifier.split(".")[0],
"description": method.__doc__,
}