From 282a66c01857d80169ebe083ae2f34a98def4918 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Mon, 23 Oct 2023 17:15:12 +0200 Subject: [PATCH] change Optional to IsOptional --- canals/component/sockets.py | 18 +++++++++++++++--- canals/component/types.py | 12 ++++++++++-- sample_components/add_value.py | 5 ++--- sample_components/greet.py | 4 ++-- sample_components/merge_loop.py | 5 +++-- sample_components/sum.py | 5 +++-- sample_components/threshold.py | 5 ++--- test/component/test_sockets.py | 8 +++++++- test/pipeline/unit/test_pipeline.py | 9 +++++---- .../unit/test_validation_pipeline_io.py | 11 ++++++----- 10 files changed, 55 insertions(+), 27 deletions(-) diff --git a/canals/component/sockets.py b/canals/component/sockets.py index 6fb2854d..89e6324f 100644 --- a/canals/component/sockets.py +++ b/canals/component/sockets.py @@ -1,11 +1,11 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import get_origin, get_args, List, Type, Union +from typing import get_args, List, Type import logging from dataclasses import dataclass, field -from canals.component.types import CANALS_VARIADIC_ANNOTATION +from canals.component.types import CANALS_VARIADIC_ANNOTATION, CANALS_OPTIONAL_ANNOTATION logger = logging.getLogger(__name__) @@ -20,7 +20,19 @@ class InputSocket: sender: List[str] = field(default_factory=list) def __post_init__(self): - self.is_optional = get_origin(self.type) is Union and type(None) in get_args(self.type) + try: + # __metadata__ is a tuple + self.is_optional = self.type.__metadata__[0] == CANALS_OPTIONAL_ANNOTATION + except AttributeError: + self.is_optional = False + if self.is_optional: + # We need to "unpack" the type inside the IsOptional annotation, + # otherwise the pipeline connection api will try to match + # `Annotated[type, CANALS_OPTIONAL_ANNOTATION]`. 
+ # Note IsOptional is expressed as an annotation of one single type, + # so the first item returned by get_args is always the wrapped type. + self.type = get_args(self.type)[0] + try: # __metadata__ is a tuple self.is_variadic = self.type.__metadata__[0] == CANALS_VARIADIC_ANNOTATION diff --git a/canals/component/types.py b/canals/component/types.py index dba8b7f1..95ab3d8e 100644 --- a/canals/component/types.py +++ b/canals/component/types.py @@ -1,9 +1,11 @@ -from typing import TypeVar +from typing import TypeVar, Optional from typing_extensions import TypeAlias, Annotated # Python 3.8 compatibility CANALS_VARIADIC_ANNOTATION = "__canals__variadic_t" +CANALS_OPTIONAL_ANNOTATION = "__canals__optional_t" -# # Generic type variable used in the Variadic container + +# # Generic type variable used in the Variadic and IsOptional containers T = TypeVar("T") @@ -12,3 +14,9 @@ # type so it can be used in the `InputSocket` creation where we # check that its annotation equals to CANALS_VARIADIC_ANNOTATION Variadic: TypeAlias = Annotated[T, CANALS_VARIADIC_ANNOTATION] + +# IsOptional is a custom annotation type we use to mark input types. +# This type doesn't do anything else than "marking" the contained +# type so it can be used in the `InputSocket` creation where we +# check that its annotation is equal to CANALS_OPTIONAL_ANNOTATION. 
+IsOptional: TypeAlias = Annotated[Optional[T], CANALS_OPTIONAL_ANNOTATION] diff --git a/sample_components/add_value.py b/sample_components/add_value.py index d9370841..dd7ef238 100644 --- a/sample_components/add_value.py +++ b/sample_components/add_value.py @@ -1,9 +1,8 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Optional - from canals import component +from canals.component.types import IsOptional @component @@ -16,7 +15,7 @@ def __init__(self, add: int = 1): self.add = add @component.output_types(result=int) - def run(self, value: int, add: Optional[int] = None): + def run(self, value: int, add: IsOptional[int] = None): """ Adds two values together. """ diff --git a/sample_components/greet.py b/sample_components/greet.py index b3c2afeb..1c17967f 100644 --- a/sample_components/greet.py +++ b/sample_components/greet.py @@ -1,10 +1,10 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Optional import logging from canals import component +from canals.component.types import IsOptional logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def __init__( self.log_level = log_level @component.output_types(value=int) - def run(self, value: int, message: Optional[str] = None, log_level: Optional[str] = None): + def run(self, value: int, message: IsOptional[str] = None, log_level: IsOptional[str] = None): """ Logs a greeting message without affecting the value passing on the connection. 
""" diff --git a/sample_components/merge_loop.py b/sample_components/merge_loop.py index e84a5aac..6acf91cd 100644 --- a/sample_components/merge_loop.py +++ b/sample_components/merge_loop.py @@ -1,18 +1,19 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import List, Any, Optional, Dict +from typing import List, Any, Dict import sys from canals import component from canals.errors import DeserializationError from canals.serialization import default_to_dict +from canals.component.types import IsOptional @component class MergeLoop: def __init__(self, expected_type: Any, inputs: List[str]): - component.set_input_types(self, **{input_name: Optional[expected_type] for input_name in inputs}) + component.set_input_types(self, **{input_name: IsOptional[expected_type] for input_name in inputs}) # type: ignore component.set_output_types(self, value=expected_type) if expected_type.__module__ == "builtins": diff --git a/sample_components/sum.py b/sample_components/sum.py index c6adb0d9..9e6f3ce6 100644 --- a/sample_components/sum.py +++ b/sample_components/sum.py @@ -1,16 +1,17 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, List +from typing import List from canals import component +from canals.component.types import IsOptional @component class Sum: def __init__(self, inputs: List[str]): self.inputs = inputs - component.set_input_types(self, **{input_name: Optional[int] for input_name in inputs}) + component.set_input_types(self, **{input_name: IsOptional[int] for input_name in inputs}) # type: ignore @component.output_types(total=int) def run(self, **kwargs): diff --git a/sample_components/threshold.py b/sample_components/threshold.py index d27f9ee7..daf17f57 100644 --- a/sample_components/threshold.py +++ b/sample_components/threshold.py @@ -1,9 +1,8 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: 
Apache-2.0 -from typing import Optional - from canals import component +from canals.component.types import IsOptional @component @@ -22,7 +21,7 @@ def __init__(self, threshold: int = 10): self.threshold = threshold @component.output_types(above=int, below=int) - def run(self, value: int, threshold: Optional[int] = None): + def run(self, value: int, threshold: IsOptional[int] = None): """ Redirects the value, unchanged, along a different connection whether the value is above or below the given threshold. diff --git a/test/component/test_sockets.py b/test/component/test_sockets.py index 72d61114..daf4bb2d 100644 --- a/test/component/test_sockets.py +++ b/test/component/test_sockets.py @@ -1,6 +1,7 @@ from typing import Optional from canals.component import InputSocket +from canals.component.types import IsOptional def test_is_not_optional(): @@ -8,6 +9,11 @@ def test_is_not_optional(): assert s.is_optional is False -def test_is_optional(): +def test_is_optional_python_type(): s = InputSocket("test_name", Optional[int]) + assert not s.is_optional + + +def test_is_optional_canals_type(): + s = InputSocket("test_name", IsOptional[int]) assert s.is_optional diff --git a/test/pipeline/unit/test_pipeline.py b/test/pipeline/unit/test_pipeline.py index debf1101..88b1f4a1 100644 --- a/test/pipeline/unit/test_pipeline.py +++ b/test/pipeline/unit/test_pipeline.py @@ -11,6 +11,7 @@ from canals.errors import PipelineMaxLoops, PipelineError, PipelineRuntimeError from sample_components import AddFixedValue, Threshold, MergeLoop, Double from canals.testing.factory import component_class +from canals.component.types import IsOptional logging.basicConfig(level=logging.DEBUG) @@ -110,7 +111,7 @@ def test_from_dict(): assert add_two["instance"].add == 2 assert add_two["input_sockets"] == { "value": InputSocket(name="value", type=int), - "add": InputSocket(name="add", type=Optional[int]), + "add": InputSocket(name="add", type=IsOptional[int]), } assert add_two["output_sockets"] == { 
"result": OutputSocket(name="result", type=int), @@ -122,7 +123,7 @@ def test_from_dict(): assert add_default["instance"].add == 1 assert add_default["input_sockets"] == { "value": InputSocket(name="value", type=int, sender=["double"]), - "add": InputSocket(name="add", type=Optional[int]), + "add": InputSocket(name="add", type=IsOptional[int]), } assert add_default["output_sockets"] == { "result": OutputSocket(name="result", type=int), @@ -202,7 +203,7 @@ def test_from_dict_with_components_instances(): assert add_two_data["instance"].add == 2 assert add_two_data["input_sockets"] == { "value": InputSocket(name="value", type=int), - "add": InputSocket(name="add", type=Optional[int]), + "add": InputSocket(name="add", type=IsOptional[int]), } assert add_two_data["output_sockets"] == { "result": OutputSocket(name="result", type=int), @@ -215,7 +216,7 @@ def test_from_dict_with_components_instances(): assert add_default_data["instance"].add == 1 assert add_default_data["input_sockets"] == { "value": InputSocket(name="value", type=int, sender=["double"]), - "add": InputSocket(name="add", type=Optional[int]), + "add": InputSocket(name="add", type=IsOptional[int]), } assert add_default_data["output_sockets"] == { "result": OutputSocket(name="result", type=int), diff --git a/test/pipeline/unit/test_validation_pipeline_io.py b/test/pipeline/unit/test_validation_pipeline_io.py index e525902c..641876e5 100644 --- a/test/pipeline/unit/test_validation_pipeline_io.py +++ b/test/pipeline/unit/test_validation_pipeline_io.py @@ -8,6 +8,7 @@ from canals.component.sockets import InputSocket, OutputSocket from canals.pipeline.validation import _find_pipeline_inputs, _find_pipeline_outputs from sample_components import Double, AddFixedValue, Sum, Parity +from canals.component.types import IsOptional def test_find_pipeline_input_no_input(): @@ -41,7 +42,7 @@ def test_find_pipeline_input_two_inputs_same_component(): assert _find_pipeline_inputs(pipe.graph) == { "comp1": [ 
InputSocket(name="value", type=int), - InputSocket(name="add", type=Optional[int]), + InputSocket(name="add", type=IsOptional[int]), ], "comp2": [], } @@ -58,7 +59,7 @@ def test_find_pipeline_input_some_inputs_different_components(): assert _find_pipeline_inputs(pipe.graph) == { "comp1": [ InputSocket(name="value", type=int), - InputSocket(name="add", type=Optional[int]), + InputSocket(name="add", type=IsOptional[int]), ], "comp2": [InputSocket(name="value", type=int)], "comp3": [], @@ -74,12 +75,12 @@ def test_find_pipeline_variable_input_nodes_in_the_pipeline(): assert _find_pipeline_inputs(pipe.graph) == { "comp1": [ InputSocket(name="value", type=int), - InputSocket(name="add", type=Optional[int]), + InputSocket(name="add", type=IsOptional[int]), ], "comp2": [InputSocket(name="value", type=int)], "comp3": [ - InputSocket(name="in_1", type=Optional[int]), - InputSocket(name="in_2", type=Optional[int]), + InputSocket(name="in_1", type=IsOptional[int]), + InputSocket(name="in_2", type=IsOptional[int]), ], }