Source code for ewokscore.graph.graph_io
from typing import Dict, Iterator, List, Mapping, Optional, Union
import networkx
from .analysis import start_nodes
from .analysis import end_nodes
from ..node import NodeIdType
from ..node import get_node_label
from .. import missing_data
from ..task import Task
[docs]
def update_default_inputs(
    graph: networkx.DiGraph, inputs: Optional[List[dict]] = None
) -> None:
    """Merge graph inputs into the `default_inputs` attribute of the targeted nodes.

    The `inputs` are first resolved to node ids with `parse_inputs`. Each
    input item has the following keys:

    - name: input variable name
    - value: input variable value
    - id (optional): node id
    - label (optional): used when `id` is missing
    - task_identifier (optional): used when `id` is missing
    - all (optional): used when `id`, `label` and `task_identifier` are missing (`True`: all nodes, `False`: start nodes)

    For every resolved item, a default input of the node with the same name is
    updated in place; when there is none, the item is appended. The node
    attributes of `graph` are modified in place and nothing is returned.
    """
    for resolved in parse_inputs(graph, inputs):
        target_id = resolved.get("id")
        if target_id is None:
            continue
        attrs = graph.nodes[target_id]
        # Only the variable name and value are stored on the node.
        new_default = {"name": resolved["name"], "value": resolved["value"]}
        current_defaults = attrs.get("default_inputs")
        if not current_defaults:
            # Missing or empty: the node gets a fresh list.
            attrs["default_inputs"] = [new_default]
            continue
        same_name = next(
            (
                existing
                for existing in current_defaults
                if existing["name"] == new_default["name"]
            ),
            None,
        )
        if same_name is None:
            current_defaults.append(new_default)
        else:
            same_name.update(new_default)
[docs]
def parse_inputs(
    graph: networkx.DiGraph, inputs: Optional[List[dict]] = None
) -> List[dict]:
    """Expand graph inputs into one item per targeted node.

    Input items have the following keys:

    - name: input variable name
    - value: input variable value
    - id (optional): node id
    - label (optional): used when `id` is missing
    - task_identifier (optional): used when `id` is missing
    - all (optional): used when `id`, `label` and `task_identifier` are missing (`True`: all nodes, `False`: the nodes given by `start_nodes`)

    Returns a new list of dictionaries restricted to the keys `id`, `name`
    and `value`. An item without `id` yields one dictionary per selected node.

    Raises `ValueError` when an item lacks `name` or `value`.
    """
    resolved = []
    if not inputs:
        return resolved
    mandatory_keys = {"name", "value"}
    kept_keys = {"id", "name", "value"}
    for item in inputs:
        missing = mandatory_keys - item.keys()
        if missing:
            raise ValueError(f"missing keys in one of the graph inputs: {missing}")
        # The part of the item that is shared by every node it targets.
        base = {key: val for key, val in item.items() if key in kept_keys}
        if "id" in item:
            resolved.append(base)
            continue
        filters = {
            key: item[key] for key in ("label", "task_identifier") if key in item
        }
        if filters:
            targets = iter_node_ids(graph, **filters)
        elif item.get("all"):
            targets = graph.nodes
        else:
            targets = start_nodes(graph)
        resolved.extend({**base, "id": node_id} for node_id in targets)
    return resolved
[docs]
def parse_outputs(
    graph: networkx.DiGraph, outputs: Optional[List[dict]] = None
) -> List[dict]:
    """Expand graph outputs into one item per targeted node.

    Output items have the following keys:

    - name (optional): output variable name (all outputs when missing)
    - new_name (optional): optional renaming when `name` is defined
    - id (optional): node id
    - label (optional): used when `id` is missing
    - task_identifier (optional): used when `id` is missing
    - all (optional): used when `id`, `label` and `task_identifier` are missing (`True`: all nodes, `False`: the nodes given by `end_nodes`)

    When `outputs` is `None`, a single `{"all": False}` item is used.

    Returns a new list of dictionaries restricted to the keys `id`, `name`
    and `new_name`. An item without `id` yields one dictionary per selected
    node.
    """
    if outputs is None:
        outputs = [{"all": False}]
    kept_keys = {"id", "name", "new_name"}
    resolved = []
    for item in outputs:
        # The part of the item that is shared by every node it targets.
        base = {key: val for key, val in item.items() if key in kept_keys}
        if "id" in item:
            resolved.append(base)
            continue
        filters = {
            key: item[key] for key in ("label", "task_identifier") if key in item
        }
        if filters:
            targets = iter_node_ids(graph, **filters)
        elif item.get("all"):
            targets = graph.nodes
        else:
            targets = end_nodes(graph)
        resolved.extend({**base, "id": node_id} for node_id in targets)
    return resolved
[docs]
def iter_node_ids(
    graph: networkx.DiGraph,
    label: Optional[str] = None,
    task_identifier: Optional[str] = None,
) -> Iterator[NodeIdType]:
    """Yield the ids of the nodes matching `label` AND `task_identifier`.

    A criterion that is `None` is not applied. When both are `None` nothing
    is yielded. The label of a node is obtained with `get_node_label`; the
    `task_identifier` criterion matches when the node's `task_identifier`
    attribute ends with the given string.
    """
    filter_on_label = label is not None
    filter_on_task = task_identifier is not None
    if not (filter_on_label or filter_on_task):
        # Without any criterion no node is selected.
        return
    for node_id, node_attrs in graph.nodes.items():
        if filter_on_label and label != get_node_label(node_id, node_attrs):
            continue
        if filter_on_task:
            node_task = node_attrs.get("task_identifier")
            # Suffix match; nodes without a task identifier never match.
            if not (node_task and node_task.endswith(task_identifier)):
                continue
        yield node_id
[docs]
def extract_output_values(
    node_id: NodeIdType, task_or_outputs: Union[Task, Mapping], outputs: List[dict]
) -> Optional[dict]:
    """Select the output values of one node according to the output items.

    Only the items whose `id` equals `node_id` are considered. The keys read
    from those items are:

    - id: node id
    - name (optional): output variable name (all outputs when missing)
    - new_name (optional): optional renaming when name is defined

    `task_or_outputs` is either a `Task`, whose values are obtained with
    `get_output_values()`, or a mapping of output values.

    Returns `None` when no item refers to `node_id`, otherwise a dictionary of
    the selected values. A named output that does not exist is returned as
    `missing_data.MISSING_DATA`.
    """
    # The outputs of a Task are only fetched when an item selects this node.
    source_values = None if isinstance(task_or_outputs, Task) else task_or_outputs
    selected = None
    for item in outputs:
        if item.get("id") != node_id:
            continue
        if source_values is None:
            source_values = task_or_outputs.get_output_values()
        if selected is None:
            selected = {}
        var_name = item.get("name")
        if not var_name:
            # No variable name: take every output of the node.
            selected.update(source_values)
            continue
        target_name = item.get("new_name", var_name)
        selected[target_name] = source_values.get(
            var_name, missing_data.MISSING_DATA
        )
    return selected
[docs]
def add_output_values(
    output_values: dict,
    node_id: NodeIdType,
    task_or_outputs: Union[Task, Dict],
    outputs: List[dict],
    merge_outputs: Optional[bool] = True,
) -> None:
    """Add the selected outputs of one node to `output_values` in place.

    The selection is done by `extract_output_values` with the same `node_id`,
    `task_or_outputs` and `outputs`. Nothing is added when it returns `None`.

    When `merge_outputs` is true, the selected values are merged into
    `output_values`; otherwise they are stored as one dictionary under the
    key `node_id`.
    """
    selected = extract_output_values(node_id, task_or_outputs, outputs)
    if selected is None:
        return
    if merge_outputs:
        output_values.update(selected)
    else:
        output_values[node_id] = selected