from typing import List, Dict
import pytest
from ewokscore.task import Task, TaskInputError
from .utils import execute_task
from ..bindings import owwidgets
from ..bindings import ow_build_opts
class TaskForTesting(
Task, input_names=["a", "b", "recorded_calls", "failures"], output_names=["sum"]
):
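    """Task that records its run() calls and can raise a scripted failure."""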
def run(self):
self.inputs.recorded_calls.append("run")
exception = self.inputs.failures.get("run")
if exception:
raise exception
self.outputs.sum = self.inputs.a + self.inputs.b
class PatchCalls:
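    """Mixin that records Orange widget callback invocations, with scripted failures.

    Each patched callback first records its name, then delegates to super(),
    and only afterwards raises the scripted exception (if any), so the real
    callback always executes.
    """
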
def __init__(
self, *args, recorded_calls: List[str], failures: Dict[str, Exception], **kw
) -> None:
self.__recorded_calls = recorded_calls
self.__failures = failures
super().__init__(*args, **kw)
def task_output_changed(self) -> None:
self.__recorded_calls.append("task_output_changed")
super().task_output_changed()
exception = self.__failures.get("task_output_changed")
if exception:
raise exception
def trigger_downstream(self) -> None:
self.__recorded_calls.append("trigger_downstream")
super().trigger_downstream()
exception = self.__failures.get("trigger_downstream")
if exception:
raise exception
def clear_downstream(self) -> None:
self.__recorded_calls.append("clear_downstream")
super().clear_downstream()
exception = self.__failures.get("clear_downstream")
if exception:
raise exception
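
# NOTE: the rendered listing omits the definitions of the four widget classes
# referenced in _WIDGET_CLASSES below. The following is a minimal sketch of
# what they presumably look like, assuming the OWEwoksWidget base classes
# exported by ewoksorange.bindings.owwidgets and the ewokstaskclass /
# **ow_build_opts class keywords used by ewoksorange widgets; it is a
# hedged reconstruction, not the verbatim source.
class NoThreadTestWidget(
    PatchCalls,
    owwidgets.OWEwoksWidgetNoThread,
    **ow_build_opts,
    ewokstaskclass=TaskForTesting,
):
    name = "NoThreadTestWidget"  # hypothetical widget name


class OneThreadTestWidget(
    PatchCalls,
    owwidgets.OWEwoksWidgetOneThread,
    **ow_build_opts,
    ewokstaskclass=TaskForTesting,
):
    name = "OneThreadTestWidget"  # hypothetical widget name


class OneThreadPerRunTestWidget(
    PatchCalls,
    owwidgets.OWEwoksWidgetOneThreadPerRun,
    **ow_build_opts,
    ewokstaskclass=TaskForTesting,
):
    name = "OneThreadPerRunTestWidget"  # hypothetical widget name


class TaskStackTestWidget(
    PatchCalls,
    owwidgets.OWEwoksWidgetWithTaskStack,
    **ow_build_opts,
    ewokstaskclass=TaskForTesting,
):
    name = "TaskStackTestWidget"  # hypothetical widget name
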
_TASK_CLASSES = [TaskForTesting]
_WIDGET_CLASSES = [
NoThreadTestWidget,
OneThreadTestWidget,
OneThreadPerRunTestWidget,
TaskStackTestWidget,
]
class _TestException(Exception):
pass
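
# Expected callback order for widget classes: "run", then "trigger_downstream"
# (on success) or "clear_downstream" (on failure), then "task_output_changed".
# Plain Task classes only record "run" because the Orange propagation
# callbacks do not exist there.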
@pytest.mark.parametrize("task_cls", _TASK_CLASSES + _WIDGET_CLASSES)
def test_task_success(task_cls):
recorded_calls = list()
failures = dict()
result = _execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert result == {"sum": 3}
if issubclass(task_cls, Task):
assert recorded_calls == ["run"]
else:
assert recorded_calls == ["run", "trigger_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _TASK_CLASSES + _WIDGET_CLASSES)
def test_task_init_failure(task_cls):
recorded_calls = list()
failures = dict()
with pytest.raises(TaskInputError):
_execute_task(task_cls, recorded_calls, failures, a=1)
if issubclass(task_cls, Task):
assert recorded_calls == []
else:
assert recorded_calls == ["clear_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _TASK_CLASSES + _WIDGET_CLASSES)
def test_task_run_failure(task_cls):
recorded_calls = list()
failures = {"run": _TestException("error in task")}
if issubclass(task_cls, Task):
with pytest.raises(RuntimeError) as exc_info:
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
exception = exc_info.value.__cause__
assert isinstance(exception, _TestException)
assert str(exception) == "error in task"
assert recorded_calls == ["run"]
else:
with pytest.raises(_TestException, match="error in task") as exc_info:
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert recorded_calls == ["run", "clear_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _WIDGET_CLASSES)
def test_success_with_output_changed_failure(task_cls):
recorded_calls = list()
failures = {
"task_output_changed": _TestException("error in widget: output callback")
}
with pytest.raises(_TestException, match="error in widget: output callback"):
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert recorded_calls == ["run", "trigger_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _WIDGET_CLASSES)
def test_failure_with_output_changed_failure(task_cls):
recorded_calls = list()
failures = {
"run": _TestException("error in task"),
"task_output_changed": _TestException("error in widget: output callback"),
}
with pytest.raises(_TestException, match="error in task"):
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert recorded_calls == ["run", "clear_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _WIDGET_CLASSES)
def test_success_with_propagation_failure(task_cls):
recorded_calls = list()
failures = {
"trigger_downstream": _TestException("error in widget: success propagation")
}
with pytest.raises(_TestException, match="error in widget: success propagation"):
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert recorded_calls == ["run", "trigger_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _WIDGET_CLASSES)
def test_failure_with_propagation_failure(task_cls):
recorded_calls = list()
failures = {
"run": _TestException("error in task"),
"clear_downstream": _TestException("error in widget: failure propagation"),
}
with pytest.raises(_TestException, match="error in task"):
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert recorded_calls == ["run", "clear_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _WIDGET_CLASSES)
def test_success_with_propagation_and_output_changed_failure(task_cls):
recorded_calls = list()
failures = {
"trigger_downstream": _TestException("error in widget: success propagation"),
"task_output_changed": _TestException("error in widget: output callback"),
}
with pytest.raises(_TestException, match="error in widget: output callback"):
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert recorded_calls == ["run", "trigger_downstream", "task_output_changed"]
@pytest.mark.parametrize("task_cls", _WIDGET_CLASSES)
def test_failure_with_propagation_and_output_changed_failure(task_cls):
recorded_calls = list()
failures = {
"run": _TestException("error in task"),
"clear_downstream": _TestException("error in widget: success propagation"),
"task_output_changed": _TestException("error in widget: output callback"),
}
with pytest.raises(_TestException, match="error in task"):
_execute_task(task_cls, recorded_calls, failures, a=1, b=2)
assert recorded_calls == ["run", "clear_downstream", "task_output_changed"]
def _execute_task(
task_cls, recorded_calls: List[str], failures: Dict[str, Exception], **params
) -> dict:
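    # "recorded_calls" and "failures" are deliberately passed twice: as task
    # inputs (read by TaskForTesting.run) and as extra keyword arguments,
    # which execute_task presumably forwards to the widget constructor so the
    # PatchCalls mixin can pick them up.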
return execute_task(
task_cls,
inputs={**params, "recorded_calls": recorded_calls, "failures": failures},
timeout=10,
recorded_calls=recorded_calls,
failures=failures,
)