fix: change `is_optional` with `has_default` by ZanSara · Pull Request #155 · deepset-ai/canals · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Nov 28, 2023. It is now read-only.

fix: change is_optional with has_default #155

Merged
merged 4 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions canals/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __call__(cls, *args, **kwargs):
param: InputSocket(
name=param,
type=run_signature.parameters[param].annotation,
has_default=run_signature.parameters[param].default != inspect.Parameter.empty,
)
for param in list(run_signature.parameters)[1:] # First is 'self' and it doesn't matter.
}
Expand Down
5 changes: 2 additions & 3 deletions canals/component/sockets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import get_origin, get_args, List, Type, Union
from typing import get_args, List, Type
import logging
from dataclasses import dataclass, field

Expand All @@ -15,12 +15,11 @@
class InputSocket:
name: str
type: Type
is_optional: bool = field(init=False)
has_default: bool = False
is_variadic: bool = field(init=False)
sender: List[str] = field(default_factory=list)

def __post_init__(self):
self.is_optional = get_origin(self.type) is Union and type(None) in get_args(self.type)
try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == CANALS_VARIADIC_ANNOTATION
Expand Down
2 changes: 1 addition & 1 deletion canals/pipeline/descriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def describe_pipeline_inputs(graph: networkx.MultiDiGraph):
Returns a dictionary with the input names and types that this pipeline accepts.
"""
inputs = {
comp: {socket.name: {"type": socket.type, "is_optional": socket.is_optional} for socket in data}
comp: {socket.name: {"type": socket.type, "has_default": socket.has_default} for socket in data}
for comp, data in find_pipeline_inputs(graph).items()
if data
}
Expand Down
2 changes: 1 addition & 1 deletion canals/pipeline/draw/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]
graph.add_node("input")
for node, in_sockets in find_pipeline_inputs(graph).items():
for in_socket in in_sockets:
if not in_socket.sender and not in_socket.is_optional:
if not in_socket.sender and not in_socket.has_default:
# If this socket has no sender it could be a socket that receives input
# directly when running the Pipeline. We can't know that for sure, in doubt
# we draw it as receiving input directly.
Expand Down
6 changes: 3 additions & 3 deletions canals/pipeline/draw/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,19 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
with `mermaid` codeblocks and it will be automatically rendered.
"""
sockets = {
opt_sockets = {
comp: "".join(
[
f"<li>{name} ({_type_name(socket.type)})</li>"
for name, socket in data.get("input_sockets", {}).items()
if socket.is_optional and not socket.sender
if (socket.has_default and not socket.sender) or socket.is_variadic
]
)
for comp, data in graph.nodes(data=True)
}
optional_inputs = {
comp: f"<br><br>Optional inputs:<ul style='text-align:left;'>{sockets}</ul>" if sockets else ""
for comp, sockets in sockets.items()
for comp, sockets in opt_sockets.items()
}

states = {
Expand Down
6 changes: 3 additions & 3 deletions canals/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def inputs(self):
Returns a dictionary with the input names and types that this pipeline accepts.
"""
inputs = {
comp: {socket.name: {"type": socket.type, "is_optional": socket.is_optional} for socket in data}
comp: {socket.name: {"type": socket.type, "has_default": socket.has_default} for socket in data}
for comp, data in find_pipeline_inputs(self.graph).items()
if data
}
Expand Down Expand Up @@ -586,10 +586,10 @@ def _calculate_action( # pylint: disable=too-many-locals, too-many-return-state
# All components inputs, whether they're connected, default or pipeline inputs
input_sockets: Dict[str, InputSocket] = self.graph.nodes[name]["input_sockets"].keys()
optional_input_sockets = {
socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if socket.is_optional
socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if not socket.has_default
}
mandatory_input_sockets = {
socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if not socket.is_optional
socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if socket.has_default
}

# Components that are in the inputs buffer and have no inputs assigned are considered skipped
Expand Down
2 changes: 1 addition & 1 deletion canals/pipeline/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _validate_input_sockets_are_connected(graph: networkx.MultiDiGraph, input_va
or not socket.name in inputs_for_node.keys()
or inputs_for_node.get(socket.name, None) is None
)
if missing_input_value and not socket.is_optional:
if missing_input_value and not socket.has_default:
all_inputs = describe_pipeline_inputs_as_string(graph)
raise ValueError(f"Missing input: {node}.{socket.name}\n\n{all_inputs}")

Expand Down
13 changes: 0 additions & 13 deletions test/component/test_sockets.py

This file was deleted.

1 change: 0 additions & 1 deletion test/pipeline/integration/test_complex_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from pprint import pprint
import logging

from canals.pipeline import Pipeline
Expand Down
14 changes: 7 additions & 7 deletions test/pipeline/unit/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_from_dict():
assert add_two["instance"].add == 2
assert add_two["input_sockets"] == {
"value": InputSocket(name="value", type=int),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=Optional[int], has_default=True),
}
assert add_two["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand All @@ -122,7 +122,7 @@ def test_from_dict():
assert add_default["instance"].add == 1
assert add_default["input_sockets"] == {
"value": InputSocket(name="value", type=int, sender=["double"]),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=Optional[int], has_default=True),
}
assert add_default["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand Down Expand Up @@ -202,7 +202,7 @@ def test_from_dict_with_components_instances():
assert add_two_data["instance"].add == 2
assert add_two_data["input_sockets"] == {
"value": InputSocket(name="value", type=int),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=Optional[int], has_default=True),
}
assert add_two_data["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand All @@ -215,7 +215,7 @@ def test_from_dict_with_components_instances():
assert add_default_data["instance"].add == 1
assert add_default_data["input_sockets"] == {
"value": InputSocket(name="value", type=int, sender=["double"]),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=Optional[int], has_default=True),
}
assert add_default_data["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand Down Expand Up @@ -355,7 +355,7 @@ def test_describe_input_some_components_with_no_inputs():
p.add_component("c", C())
p.connect("a.x", "c.x")
p.connect("b.y", "c.y")
assert p.inputs() == {"b": {"y": {"type": int, "is_optional": False}}}
assert p.inputs() == {"b": {"y": {"type": int, "has_default": False}}}


def test_describe_input_all_components_have_inputs():
Expand All @@ -369,6 +369,6 @@ def test_describe_input_all_components_have_inputs():
p.connect("a.x", "c.x")
p.connect("b.y", "c.y")
assert p.inputs() == {
"a": {"x": {"type": Optional[int], "is_optional": True}},
"b": {"y": {"type": int, "is_optional": False}},
"a": {"x": {"type": Optional[int], "has_default": False}},
"b": {"y": {"type": int, "has_default": False}},
}
0