8000 Make to/from_dict optional by masci · Pull Request #107 · deepset-ai/canals · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Nov 28, 2023. It is now read-only.

Make to/from_dict optional #107

Merged
merged 7 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions canals/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,6 @@ def run(self, **kwargs) -> Dict[str, Any]:
or with `component.set_output_types()` if dynamic.
"""

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
"""

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Component":
"""
Deserializes the component from a dictionary.
"""


class _Component:
"""
Expand Down Expand Up @@ -234,16 +223,6 @@ def _component(self, class_):
raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
run_signature = inspect.signature(class_.run)

if not hasattr(class_, "to_dict"):
raise ComponentError(
f"{class_.__name__} must have a 'to_dict()' method. See the docs for more information."
)

if not hasattr(class_, "from_dict"):
raise ComponentError(
f"{class_.__name__} must have a 'from_dict()' method. See the docs for more information."
)

# Create the input sockets
class_.run.__canals_input__ = {
param: {
Expand Down
4 changes: 4 additions & 0 deletions canals/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ class ComponentDeserializationError(Exception):

class DeserializationError(Exception):
    """
    Error type for failures while rebuilding an object from its dictionary form.

    NOTE(review): the raise sites are not visible here; this class is imported
    by `canals.serialization`, so it is presumably raised by the default
    deserialization helper there - confirm.
    """


class SerializationError(Exception):
    """
    Raised when an object cannot be converted to its dictionary form, for
    example when the value of an init parameter cannot be recovered from
    the instance.
    """
3 changes: 2 additions & 1 deletion canals/pipeline/draw/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from canals.errors import PipelineDrawingError
from canals.utils import _type_name
from canals.serialization import component_to_dict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,7 +70,7 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
for name, comp in graph.nodes(data="instance"):
if name in ["input", "output"]:
continue
data = comp.to_dict()
data = component_to_dict(comp)
params = [f"{k}={json.dumps(v)}" for k, v in data.get("init_parameters", {}).items()]
init_params[name] = ",<br>".join(params)
states = "\n".join(
Expand Down
9 changes: 7 additions & 2 deletions canals/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from canals.pipeline.validation import _validate_pipeline_input
from canals.pipeline.connections import _parse_connection_name, _find_unambiguous_connection
from canals.utils import _type_name
from canals.serialization import component_to_dict, component_from_dict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,7 +85,10 @@ def to_dict(self) -> Dict[str, Any]:
Returns this Pipeline instance as a dictionary.
This is meant to be an intermediate representation but it can be also used to save a pipeline to file.
"""
components = {name: instance.to_dict() for name, instance in self.graph.nodes(data="instance")}
components = {}
for name, instance in self.graph.nodes(data="instance"):
components[name] = component_to_dict(instance)

connections = []
for sender, receiver, sockets in self.graph.edges:
(sender_socket, receiver_socket) = sockets.split("/")
Expand Down Expand Up @@ -152,7 +156,8 @@ def from_dict(cls, data: Dict[str, Any], **kwargs) -> "Pipeline":
if component_data["type"] not in component.registry:
raise PipelineError(f"Component '{component_data['type']}' not imported.")
# Create a new one
instance = component.registry[component_data["type"]].from_dict(component_data)
component_class = component.registry[component_data["type"]]
instance = component_from_dict(component_class, component_data)
pipe.add_component(name=name, instance=instance)

for connection in data.get("connections", []):
Expand Down
43 changes: 42 additions & 1 deletion canals/serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,50 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Type, Dict, Any

from canals.errors import DeserializationError
from canals.errors import DeserializationError, SerializationError


def component_to_dict(obj: Any) -> Dict[str, Any]:
    """
    The marshaller used by the Pipeline. If a `to_dict` method is present in the
    component instance, that will be used instead of the default method.

    The default method inspects the signature of `obj.__init__` and, for every
    named parameter, reads the attribute of `obj` with the same name. If no such
    attribute exists the parameter's default value is used instead.

    :param obj: the component instance to serialize.
    :returns: the result of `obj.to_dict()` when available, otherwise the result
        of `default_to_dict` called with the collected init parameters.
    :raises SerializationError: if an init parameter has neither a matching
        attribute on `obj` nor a default value.
    """
    if hasattr(obj, "to_dict"):
        return obj.to_dict()

    # `*args` / `**kwargs` (e.g. from the inherited `object.__init__` of a class
    # without its own constructor) have no single value that could be passed
    # back by name, and never have a default, so they are left out.
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    init_parameters = {}
    for name, param in inspect.signature(obj.__init__).parameters.items():
        if param.kind in variadic_kinds:
            continue
        try:
            # This only works if the Component constructor assigns the init
            # parameter to an instance variable or property with the same name
            param_value = getattr(obj, name)
        except AttributeError as e:
            # If the parameter doesn't have a default value, raise an error
            if param.default is param.empty:
                raise SerializationError(
                    f"Cannot determine the value of the init parameter '{name}'. "
                    f"You can fix this error by assigning 'self.{name} = {name}' or adding a "
                    f"custom serialization method 'to_dict' to the class {obj.__class__.__name__}"
                ) from e
            # In case the init parameter was not assigned, we use the default value
            param_value = param.default
        init_parameters[name] = param_value

    return default_to_dict(obj, **init_parameters)


def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
    """
    The unmarshaller used by the Pipeline. If a `from_dict` method is present in the
    component class, that will be used instead of the default method.

    :param cls: the component class to build an instance of.
    :param data: the dictionary representation of the component.
    :returns: the result of `cls.from_dict(data)` when the class defines it,
        otherwise the result of `default_from_dict(cls, data)`.
    """
    if hasattr(cls, "from_dict"):
        return cls.from_dict(data)

    return default_from_dict(cls, data)


def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,5 @@ disable = [
"line-too-long",
"missing-class-docstring",
"missing-module-docstring",
"too-few-public-methods",
]
10 changes: 1 addition & 9 deletions sample_components/greet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Dict, Any
from typing import Optional
import logging

from canals import component
from canals.serialization import default_to_dict, default_from_dict


logger = logging.getLogger(__name__)
Expand All @@ -31,13 +30,6 @@ def __init__(
self.message = message
self.log_level = log_level

def to_dict(self) -> Dict[str, Any]: # pylint: disable=missing-function-docstring
return default_to_dict(self, message=self.message, log_level=self.log_level)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Greet": # pylint: disable=missing-function-docstring
return default_from_dict(cls, data)

@component.output_types(value=int)
def run(self, value: int, message: Optional[str] = None, log_level: Optional[str] = None):
"""
Expand Down
9 changes: 5 additions & 4 deletions test/sample_components/test_greet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import logging

from sample_components import Greet
from canals.serialization import component_to_dict, component_from_dict


def test_to_dict():
component = Greet()
res = component.to_dict()
res = component_to_dict(component)
assert res == {
"type": "Greet",
"init_parameters": {
Expand All @@ -20,7 +21,7 @@ def test_to_dict():

def test_to_dict_with_custom_parameters():
component = Greet(message="My message", log_level="ERROR")
res = component.to_dict()
res = component_to_dict(component)
assert res == {
"type": "Greet",
"init_parameters": {
Expand All @@ -35,7 +36,7 @@ def test_from_dict():
"type": "Greet",
"init_parameters": {},
}
component = Greet.from_dict(data)
component = component_from_dict(Greet, data)
assert component.message == "\nGreeting component says: Hi! The value is {value}\n"
assert component.log_level == "INFO"

Expand All @@ -45,7 +46,7 @@ def test_from_with_custom_parameters():
"type": "Greet",
"init_parameters": {"message": "My message", "log_level": "ERROR"},
}
component = Greet.from_dict(data)
component = component_from_dict(Greet, data)
assert component.message == "My message"
assert component.log_level == "ERROR"

Expand Down
0