feat!: change `Optional` to `IsOptional` by ZanSara · Pull Request #142 · deepset-ai/canals · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Nov 28, 2023. It is now read-only.

feat!: change Optional to IsOptional #142

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions canals/component/sockets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import get_origin, get_args, List, Type, Union
from typing import get_args, List, Type
import logging
from dataclasses import dataclass, field

from canals.component.types import CANALS_VARIADIC_ANNOTATION
from canals.component.types import CANALS_VARIADIC_ANNOTATION, CANALS_OPTIONAL_ANNOTATION


logger = logging.getLogger(__name__)
Expand All @@ -20,7 +20,19 @@ class InputSocket:
sender: List[str] = field(default_factory=list)

def __post_init__(self):
self.is_optional = get_origin(self.type) is Union and type(None) in get_args(self.type)
try:
# __metadata__ is a tuple
self.is_optional = self.type.__metadata__[0] == CANALS_OPTIONAL_ANNOTATION
except AttributeError:
self.is_optional = False
if self.is_optional:
# We need to "unpack" the type inside the IsOptional annotation,
# otherwise the pipeline connection api will try to match
# `Annotated[type, CANALS_OPTIONAL_ANNOTATION]`.
# Note IsOptional is expressed as an annotation of one single type,
# so the return value of get_args will always be a one-item tuple.
self.type = get_args(self.type)[0]

try:
# __metadata__ is a tuple
self.is_variadic = self.type.__metadata__[0] == CANALS_VARIADIC_ANNOTATION
Expand Down
12 changes: 10 additions & 2 deletions canals/component/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import TypeVar
from typing import TypeVar, Optional
from typing_extensions import TypeAlias, Annotated # Python 3.8 compatibility

CANALS_VARIADIC_ANNOTATION = "__canals__variadic_t"
CANALS_OPTIONAL_ANNOTATION = "__canals__optional_t"

# # Generic type variable used in the Variadic container

# # Generic type variable used in the Variadic and IsOptional containers
T = TypeVar("T")


Expand All @@ -12,3 +14,9 @@
# type so it can be used in the `InputSocket` creation where we
# check that its annotation equals to CANALS_VARIADIC_ANNOTATION
Variadic: TypeAlias = Annotated[T, CANALS_VARIADIC_ANNOTATION]

# IsOptional is a custom annotation type we use to mark optional input types.
# It does two things: it wraps the contained type in `Optional[T]` (so `None`
# is an accepted value), and it tags the result with CANALS_OPTIONAL_ANNOTATION.
# `InputSocket` creation checks that the annotation metadata equals
# CANALS_OPTIONAL_ANNOTATION to flag the socket as optional. It then keeps only
# the inner `Optional[T]` as the socket type.
# Note: a plain `Optional[T]` without this marker is NOT treated as optional.
IsOptional: TypeAlias = Annotated[Optional[T], CANALS_OPTIONAL_ANNOTATION]
5 changes: 2 additions & 3 deletions sample_components/add_value.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

from canals import component
from canals.component.types import IsOptional

@component
Expand All @@ -16,7 +15,7 @@ def __init__(self, add: int = 1):
self.add = add

@component.output_types(result=int)
def run(self, value: int, add: Optional[int] = None):
def run(self, value: int, add: IsOptional[int] = None):
"""
Adds two values together.
"""
Expand Down
4 changes: 2 additions & 2 deletions sample_components/greet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import logging

from canals import component
from canals.component.types import IsOptional


logger = logging.getLogger(__name__)
Expand All @@ -31,7 +31,7 @@ def __init__(
self.log_level = log_level

@component.output_types(value=int)
def run(self, value: int, message: Optional[str] = None, log_level: Optional[str] = None):
def run(self, value: int, message: IsOptional[str] = None, log_level: IsOptional[str] = None):
"""
Logs a greeting message without affecting the value passing on the connection.
"""
Expand Down
5 changes: 3 additions & 2 deletions sample_components/merge_loop.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import List, Any, Optional, Dict
from typing import List, Any, Dict
import sys

from canals import component
from canals.errors import DeserializationError
from canals.serialization import default_to_dict
from canals.component.types import IsOptional


@component
class MergeLoop:
def __init__(self, expected_type: Any, inputs: List[str]):
component.set_input_types(self, **{input_name: Optional[expected_type] for input_name in inputs})
component.set_input_types(self, **{input_name: IsOptional[expected_type] for input_name in inputs}) # type: ignore
component.set_output_types(self, value=expected_type)

if expected_type.__module__ == "builtins":
Expand Down
5 changes: 3 additions & 2 deletions sample_components/sum.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, List
from typing import List

from canals import component
from canals.component.types import IsOptional


@component
class Sum:
def __init__(self, inputs: List[str]):
self.inputs = inputs
component.set_input_types(self, **{input_name: Optional[int] for input_name in inputs})
component.set_input_types(self, **{input_name: IsOptional[int] for input_name in inputs}) # type: ignore

@component.output_types(total=int)
def run(self, **kwargs):
Expand Down
5 changes: 2 additions & 3 deletions sample_components/threshold.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

from canals import component
from canals.component.types import IsOptional


@component
Expand All @@ -22,7 +21,7 @@ def __init__(self, threshold: int = 10):
self.threshold = threshold

@component.output_types(above=int, below=int)
def run(self, value: int, threshold: Optional[int] = None):
def run(self, value: int, threshold: IsOptional[int] = None):
"""
Redirects the value, unchanged, along a different connection whether the value is above
or below the given threshold.
Expand Down
8 changes: 7 additions & 1 deletion test/component/test_sockets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from typing import Optional

from canals.component import InputSocket
from canals.component.types import IsOptional


def test_is_not_optional():
    """A socket declared with a plain, un-annotated type must not be optional."""
    plain_socket = InputSocket("test_name", int)
    assert plain_socket.is_optional is False


def test_is_optional():
def test_is_optional_python_type():
s = InputSocket("test_name", Optional[int])
assert not s.is_optional


def test_is_optional_canals_type():
    """A socket declared with the canals `IsOptional` marker is detected as optional."""
    marked_socket = InputSocket("test_name", IsOptional[int])
    assert marked_socket.is_optional
9 changes: 5 additions & 4 deletions test/pipeline/unit/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from canals.errors import PipelineMaxLoops, PipelineError, PipelineRuntimeError
from sample_components import AddFixedValue, Threshold, MergeLoop, Double
from canals.testing.factory import component_class
from canals.component.types import IsOptional

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -110,7 +111,7 @@ def test_from_dict():
assert add_two["instance"].add == 2
assert add_two["input_sockets"] == {
"value": InputSocket(name="value", type=int),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=IsOptional[int]),
}
assert add_two["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand All @@ -122,7 +123,7 @@ def test_from_dict():
assert add_default["instance"].add == 1
assert add_default["input_sockets"] == {
"value": InputSocket(name="value", type=int, sender=["double"]),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=IsOptional[int]),
}
assert add_default["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand Down Expand Up @@ -202,7 +203,7 @@ def test_from_dict_with_components_instances():
assert add_two_data["instance"].add == 2
assert add_two_data["input_sockets"] == {
"value": InputSocket(name="value", type=int),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=IsOptional[int]),
}
assert add_two_data["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand All @@ -215,7 +216,7 @@ def test_from_dict_with_components_instances():
assert add_default_data["instance"].add == 1
assert add_default_data["input_sockets"] == {
"value": InputSocket(name="value", type=int, sender=["double"]),
"add": InputSocket(name="add", type=Optional[int]),
"add": InputSocket(name="add", type=IsOptional[int]),
}
assert add_default_data["output_sockets"] == {
"result": OutputSocket(name="result", type=int),
Expand Down
11 changes: 6 additions & 5 deletions test/pipeline/unit/test_validation_pipeline_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from canals.component.sockets import InputSocket, OutputSocket
from canals.pipeline.validation import _find_pipeline_inputs, _find_pipeline_outputs
from sample_components import Double, AddFixedValue, Sum, Parity
from canals.component.types import IsOptional


def test_find_pipeline_input_no_input():
Expand Down Expand Up @@ -41,7 +42,7 @@ def test_find_pipeline_input_two_inputs_same_component():
assert _find_pipeline_inputs(pipe.graph) == {
"comp1": [
InputSocket(name="value", type=int),
InputSocket(name="add", type=Optional[int]),
InputSocket(name="add", type=IsOptional[int]),
],
"comp2": [],
}
Expand All @@ -58,7 +59,7 @@ def test_find_pipeline_input_some_inputs_different_components():
assert _find_pipeline_inputs(pipe.graph) == {
"comp1": [
InputSocket(name="value", type=int),
InputSocket(name="add", type=Optional[int]),
InputSocket(name="add", type=IsOptional[int]),
],
"comp2": [InputSocket(name="value", type=int)],
"comp3": [],
Expand All @@ -74,12 +75,12 @@ def test_find_pipeline_variable_input_nodes_in_the_pipeline():
assert _find_pipeline_inputs(pipe.graph) == {
"comp1": [
InputSocket(name="value", type=int),
InputSocket(name="add", type=Optional[int]),
InputSocket(name="add", type=IsOptional[int]),
],
"comp2": [InputSocket(name="value", type=int)],
"comp3": [
InputSocket(name="in_1", type=Optional[int]),
InputSocket(name="in_2", type=Optional[int]),
InputSocket(name="in_1", type=IsOptional[int]),
InputSocket(name="in_2", type=IsOptional[int]),
],
}

Expand Down