# Source code for ewoksdask.bindings

# https://docs.dask.org/en/latest/scheduler-overview.html

import json
from typing import Any, Dict, List, Optional, Union

import dask
from dask.distributed import Client
from dask.threaded import get as multithreading_scheduler
from dask.multiprocessing import get as multiprocessing_scheduler
from dask import get as sequential_scheduler

from ewokscore import load_graph
from ewokscore import execute_graph_decorator
from ewokscore.inittask import instantiate_task
from ewokscore.inittask import add_dynamic_inputs
from ewokscore.graph.serialize import ewoks_jsonload_hook
from ewokscore.node import NodeIdType
from ewokscore.node import get_node_label
from ewokscore.graph import analysis
from ewokscore.graph import graph_io
from ewokscore.graph import TaskGraph
from ewokscore import events


def _execute_task(execute_options, *inputs):
    """Run a single workflow node inside the dask graph.

    :param execute_options: JSON string with the node description
        (``node_id``, ``node_attrs``, ``source_ids``, ``link_attrs`` and
        optionally ``varinfo``/``execinfo``).
    :param inputs: results of the predecessor tasks, in the same order
        as ``source_ids``.
    :returns: the transferable output data of the executed task.
    """
    options = json.loads(execute_options, object_pairs_hook=ewoks_jsonload_hook)
    target_id = options["node_id"]

    # Merge the outputs of all predecessors into the dynamic inputs
    # of this node, following the link attributes of each edge.
    dynamic_inputs = dict()
    predecessors = zip(options["source_ids"], inputs, options["link_attrs"])
    for source_id, source_results, link_attrs in predecessors:
        add_dynamic_inputs(
            dynamic_inputs,
            link_attrs,
            source_results,
            source_id=source_id,
            target_id=target_id,
        )

    task = instantiate_task(
        target_id,
        options["node_attrs"],
        inputs=dynamic_inputs,
        varinfo=options.get("varinfo"),
        execinfo=options.get("execinfo"),
    )
    task.execute()
    return task.get_output_transfer_data()


def _create_dask_graph(ewoksgraph, **execute_options) -> dict:
    """Convert an ewoks task graph into a dask task graph.

    Each node maps to a ``(_execute_task, <json options>, *source_ids)``
    tuple so dask resolves the predecessor results before executing it.
    """
    graph = ewoksgraph.graph
    daskgraph = dict()
    for target_id, node_attrs in graph.nodes.items():
        source_ids = tuple(analysis.node_predecessors(graph, target_id))
        link_attrs = tuple(graph[sid][target_id] for sid in source_ids)
        node_options = dict(
            execute_options,
            node_id=target_id,
            node_label=get_node_label(target_id, node_attrs),
            node_attrs=node_attrs,
            link_attrs=link_attrs,
            source_ids=source_ids,
        )
        # Note: the options are serialized to prevent dask
        #       from interpreting node names as task results
        daskgraph[target_id] = (_execute_task, json.dumps(node_options)) + source_ids
    return daskgraph


def _execute_dask_graph(
    daskgraph,
    node_ids: List[NodeIdType],
    scheduler: Union[dict, str, None, Client] = None,
    scheduler_options: Optional[dict] = None,
) -> Dict[NodeIdType, Any]:
    """Execute a dask graph with the requested scheduler.

    :param daskgraph: dask task graph as produced by :func:`_create_dask_graph`.
    :param node_ids: node identifiers whose results should be computed.
    :param scheduler: ``None`` (sequential), ``"multithreading"``,
        ``"multiprocessing"``, ``"cluster"`` (spawn a local ``Client``),
        a scheduler address string, or an existing ``Client`` instance.
    :param scheduler_options: keyword arguments for the selected scheduler.
    :returns: mapping from node id to task result.
    :raises ValueError: when the scheduler is not recognized.
    """
    if scheduler_options is None:
        scheduler_options = dict()
    if scheduler is None:
        results = sequential_scheduler(daskgraph, node_ids, **scheduler_options)
    elif scheduler == "multiprocessing":
        # num_workers: CPU_COUNT by default
        scheduler_options = dict(scheduler_options)
        # Use "spawn" by default to avoid fork-related issues
        context = scheduler_options.pop("context", "spawn")
        dask.config.set({"multiprocessing.context": context})
        results = multiprocessing_scheduler(daskgraph, node_ids, **scheduler_options)
    elif scheduler == "multithreading":
        # num_workers: CPU_COUNT by default
        results = multithreading_scheduler(daskgraph, node_ids, **scheduler_options)
    elif scheduler == "cluster":
        # n_workers: n workers with m threads
        with Client(**scheduler_options) as client:
            results = client.get(daskgraph, node_ids)
    elif isinstance(scheduler, str):
        # Connect to a running scheduler by address
        with Client(address=scheduler, **scheduler_options) as client:
            results = client.get(daskgraph, node_ids)
    elif isinstance(scheduler, Client):
        # Bug fix: the original referenced an undefined name `client`
        # (NameError); the caller-provided Client is `scheduler` itself.
        results = scheduler.get(daskgraph, node_ids)
    else:
        raise ValueError("Unknown scheduler")

    return dict(zip(node_ids, results))


def _execute_graph(
    ewoksgraph: TaskGraph,
    outputs: Optional[List[dict]] = None,
    merge_outputs: Optional[bool] = True,
    varinfo: Optional[dict] = None,
    execinfo: Optional[dict] = None,
    scheduler: Union[dict, str, None, Client] = None,
    scheduler_options: Optional[dict] = None,
) -> Dict[NodeIdType, Any]:
    """Execute an ewoks graph with dask and collect the requested outputs.

    :raises RuntimeError: when the graph is cyclic or has conditional
        links, neither of which dask supports.
    """
    with events.workflow_context(execinfo, workflow=ewoksgraph.graph) as execinfo:
        # Dask supports neither cycles nor conditional links
        if ewoksgraph.is_cyclic:
            raise RuntimeError("Dask can only execute DAGs")
        if ewoksgraph.has_conditional_links:
            raise RuntimeError("Dask cannot handle conditional links")

        daskgraph = _create_dask_graph(ewoksgraph, varinfo=varinfo, execinfo=execinfo)
        requested_outputs = graph_io.parse_outputs(ewoksgraph.graph, outputs)
        ordered_ids = list(analysis.topological_sort(ewoksgraph.graph))

        results = _execute_dask_graph(
            daskgraph,
            ordered_ids,
            scheduler=scheduler,
            scheduler_options=scheduler_options,
        )

        # Collect only the requested outputs from the per-node results
        output_values = dict()
        for node_id, task_outputs in results.items():
            graph_io.add_output_values(
                output_values,
                node_id,
                task_outputs,
                requested_outputs,
                merge_outputs=merge_outputs,
            )
        return output_values


@execute_graph_decorator(engine="dask")
def execute_graph(
    graph,
    inputs: Optional[List[dict]] = None,
    load_options: Optional[dict] = None,
    **execute_options,
):
    """Load a workflow graph and execute it with dask.

    :param graph: anything accepted by :func:`ewokscore.load_graph`.
    :param inputs: input values to inject into the graph nodes.
    :param load_options: keyword arguments for :func:`ewokscore.load_graph`.
    :param execute_options: forwarded to :func:`_execute_graph`
        (outputs, scheduler, scheduler_options, ...).
    :returns: the selected output values of the executed graph.
    """
    if load_options is None:
        load_options = dict()
    ewoksgraph = load_graph(graph, inputs=inputs, **load_options)
    return _execute_graph(ewoksgraph, **execute_options)