support `deps_path` and `params_path` in dataclass dependencies by PythonFZ · Pull Request #915 · zincware/ZnTrack · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

support deps_path and params_path in dataclass dependencies #915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions tests/files/dvc_config/dataclass_w_deps_params_path.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
stages:
NodeWithModel:
cmd: zntrack run test_dataclass_w_deps_params_path.NodeWithModel --name NodeWithModel
metrics:
- nodes/NodeWithModel/node-meta.json:
cache: true
params:
- NodeWithModel
- config.yaml: null
NodeWithModel_1:
cmd: zntrack run test_dataclass_w_deps_params_path.NodeWithModel --name NodeWithModel_1
metrics:
- nodes/NodeWithModel_1/node-meta.json:
cache: true
params:
- NodeWithModel_1
- config.yaml: null
NodeWithModel_2:
cmd: zntrack run test_dataclass_w_deps_params_path.NodeWithModel --name NodeWithModel_2
deps:
- file.txt
metrics:
- nodes/NodeWithModel_2/node-meta.json:
cache: true
params:
- NodeWithModel_2
NodeWithModel_3:
cmd: zntrack run test_dataclass_w_deps_params_path.NodeWithModel --name NodeWithModel_3
deps:
- file.txt
metrics:
- nodes/NodeWithModel_3/node-meta.json:
cache: true
params:
- NodeWithModel_3
NodeWithModel_4:
cmd: zntrack run test_dataclass_w_deps_params_path.NodeWithModel --name NodeWithModel_4
deps:
- file.txt
- file2.txt
metrics:
- nodes/NodeWithModel_4/node-meta.json:
cache: true
params:
- NodeWithModel_4
- config.yaml: null
- config2.yaml: null
NodeWithModel_5:
cmd: zntrack run test_dataclass_w_deps_params_path.NodeWithModel --name NodeWithModel_5
deps:
- file.txt
- file2.txt
metrics:
- nodes/NodeWithModel_5/node-meta.json:
cache: true
params:
- NodeWithModel_5
- config.yaml: null
- config2.yaml: null
30 changes: 30 additions & 0 deletions tests/files/params_config/dataclass_w_deps_params_path.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
NodeWithModel:
model:
_cls: test_dataclass_w_deps_params_path.ModelWithParamsPath
params:
a: 1
NodeWithModel_1:
model:
_cls: test_dataclass_w_deps_params_path.ModelWithParamsPath
params:
a: 1
NodeWithModel_2:
model:
_cls: test_dataclass_w_deps_params_path.ModelWithDepsPath
params:
a: 1
NodeWithModel_3:
model:
_cls: test_dataclass_w_deps_params_path.ModelWithDepsPath
params:
a: 1
NodeWithModel_4:
model:
_cls: test_dataclass_w_deps_params_path.ModelWithParamsAndDepsPath
params:
a: 1
NodeWithModel_5:
model:
- _cls: test_dataclass_w_deps_params_path.ModelWithParamsAndDepsPath
params:
a: 1
79 changes: 79 additions & 0 deletions tests/files/test_dataclass_w_deps_params_path.py
< 10000 td class="blob-num blob-num-addition empty-cell">
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import dataclasses
import json
import pathlib
from pathlib import Path

import yaml

import zntrack

# Absolute directory of this test module; the reference fixtures live in its
# zntrack_config/, dvc_config/ and params_config/ subdirectories.
CWD = Path(__file__).parent.resolve()


class Model:
    """Common (empty) base class for the dataclass models used as node deps."""


@dataclasses.dataclass
class ModelWithParamsPath(Model):
    """Dataclass model that references parameter file(s) through ``params_path``."""

    # Plain parameters written to params.yaml alongside the ``_cls`` marker.
    params: dict
    # One path or a list of paths; each ends up as a ``<path>: null`` params entry.
    config: str | pathlib.Path | list[str | pathlib.Path] = zntrack.params_path()


@dataclasses.dataclass
class ModelWithDepsPath(Model):
    """Dataclass model that references dependency file(s) through ``deps_path``."""

    # Plain parameters written to params.yaml alongside the ``_cls`` marker.
    params: dict
    # One path or a list of paths; each is listed under the stage's ``deps``.
    files: str | pathlib.Path | list[str | pathlib.Path] = zntrack.deps_path()


@dataclasses.dataclass
class ModelWithParamsAndDepsPath(Model):
    """Dataclass model combining a ``params_path`` field and a ``deps_path`` field."""

    # Plain parameters written to params.yaml alongside the ``_cls`` marker.
    params: dict
    # Parameter file(s): single path or list of paths.
    config: str | pathlib.Path | list[str | pathlib.Path] = zntrack.params_path()
    # Dependency file(s): single path or list of paths.
    files: str | pathlib.Path | list[str | pathlib.Path] = zntrack.deps_path()


class NodeWithModel(zntrack.Node):
    """zntrack node whose ``model`` dependency is one dataclass model or a list of them."""

    # Declared via zntrack.deps(); the test passes both a single model and a list.
    model: Model | list[Model] = zntrack.deps()


def test_node_with_dc_model_params_deps(proj_path):
    """Build nodes depending on path-carrying dataclass models and compare outputs.

    Six nodes are created, in this order:

    * a params_path model given a str path;
    * a params_path model given a Path;
    * a deps_path model given a str path;
    * a deps_path model given a Path;
    * a model with both kinds of field, used directly;
    * the same combined model wrapped in a list.

    The generated zntrack.json, dvc.yaml and params.yaml are then checked
    against the stored reference fixtures.
    """
    project = zntrack.Project()

    params_str = ModelWithParamsPath(params={"a": 1}, config="config.yaml")
    params_pathlib = ModelWithParamsPath(params={"a": 1}, config=Path("config.yaml"))
    deps_str = ModelWithDepsPath(params={"a": 1}, files="file.txt")
    deps_pathlib = ModelWithDepsPath(params={"a": 1}, files=Path("file.txt"))
    both = ModelWithParamsAndDepsPath(
        params={"a": 1},
        config=["config.yaml", Path("config2.yaml")],
        files=["file.txt", Path("file2.txt")],
    )

    # Creation order fixes the node names (NodeWithModel, NodeWithModel_1, ...),
    # which the reference fixtures rely on.
    models = [params_str, params_pathlib, deps_str, deps_pathlib, both, [both]]
    with project:
        for model in models:
            NodeWithModel(model=model)

    project.build()

    expected_zntrack = json.loads(
        (CWD / "zntrack_config" / "dataclass_w_deps_params_path.json").read_text()
    )
    assert expected_zntrack == json.loads((proj_path / "zntrack.json").read_text())

    expected_dvc = yaml.safe_load(
        (CWD / "dvc_config" / "dataclass_w_deps_params_path.yaml").read_text()
    )
    assert expected_dvc == yaml.safe_load((proj_path / "dvc.yaml").read_text())

    # params.yaml is compared verbatim (text, not parsed YAML).
    expected_params = (
        CWD / "params_config" / "dataclass_w_deps_params_path.yaml"
    ).read_text()
    assert expected_params == (proj_path / "params.yaml").read_text()
82 changes: 82 additions & 0 deletions tests/files/zntrack_config/dataclass_w_deps_params_path.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
{
"NodeWithModel": {
"nwd": {
"_type": "pathlib.Path",
"value": "nodes/NodeWithModel"
},
"model": {
"_type": "@dataclasses.dataclass",
"value": {
"module": "test_dataclass_w_deps_params_path",
"cls": "ModelWithParamsPath"
}
}
},
"NodeWithModel_1": {
"nwd": {
"_type": "pathlib.Path",
"value": "nodes/NodeWithModel_1"
},
"model": {
"_type": "@dataclasses.dataclass",
"value": {
"module": "test_dataclass_w_deps_params_path",
"cls": "ModelWithParamsPath"
}
}
},
"NodeWithModel_2": {
"nwd": {
"_type": "pathlib.Path",
"value": "nodes/NodeWithModel_2"
},
"model": {
"_type": "@dataclasses.dataclass",
"value": {
"module": "test_dataclass_w_deps_params_path",
"cls": "ModelWithDepsPath"
}
}
},
"NodeWithModel_3": {
"nwd": {
"_type": "pathlib.Path",
"value": "nodes/NodeWithModel_3"
},
"model": {
"_type": "@dataclasses.dataclass",
"value": {
"module": "test_dataclass_w_deps_params_path",
"cls": "ModelWithDepsPath"
}
}
},
"NodeWithModel_4": {
"nwd": {
"_type": "pathlib.Path",
"value": "nodes/NodeWithModel_4"
},
"model": {
"_type": "@dataclasses.dataclass",
"value": {
"module": "test_dataclass_w_deps_params_path",
"cls": "ModelWithParamsAndDepsPath"
}
}
  },
  "NodeWithModel_5": {
    "nwd": {
      "_type": "pathlib.Path",
      "value": "nodes/NodeWithModel_5"
},
"model": [
{
"_type": "@dataclasses.dataclass",
"value": {
"module": "test_dataclass_w_deps_params_path",
"cls": "ModelWithParamsAndDepsPath"
}
}
]
}
}
56 changes: 44 additions & 12 deletions zntrack/plugins/dvc_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import typing as t

import znflow
import znflow.handler
import znflow.utils
import znjson

from zntrack import config, converter
Expand All @@ -33,6 +31,21 @@
from zntrack.utils.node_wd import NWDReplaceHandler, nwd


def _dataclass_to_dict(object) -> dict:
    """Serialize a dataclass instance for ``params.yaml``.

    Top-level fields tagged ``params_path`` or ``deps_path`` are left out; those
    paths are registered as ``params``/``deps`` entries of the dvc.yaml stage
    instead. A ``_cls`` entry of the form ``"<module>.<ClassName>"`` is added so
    the instance can be recreated later from params.yaml.

    Args:
        object: the dataclass instance to serialize. The name shadows the
            builtin ``object``; it is kept for call-site compatibility.

    Returns:
        dict: ``dataclasses.asdict(object)`` without the path fields, plus ``_cls``.
    """
    path_field_types = (FieldTypes.PARAMS_PATH, FieldTypes.DEPS_PATH)
    skipped_names = {
        dc_field.name
        for dc_field in dataclasses.fields(object)
        if dc_field.metadata.get(FIELD_TYPE) in path_field_types
    }
    # Only top-level fields are filtered out. asdict() recurses into nested
    # dataclasses, so path fields of a nested dataclass would still be included.
    serialized = {
        key: value
        for key, value in dataclasses.asdict(object).items()
        if key not in skipped_names
    }
    serialized["_cls"] = f"{module_handler(object)}.{object.__class__.__name__}"
    return serialized


@dataclasses.dataclass
class DVCPlugin(ZnTrackPlugin):
def getter(self, field: dataclasses.Field) -> t.Any:
Expand Down Expand Up @@ -78,11 +91,7 @@ def convert_to_params_yaml(self) -> dict | object:
# to the params.yaml file to be later used
# by the DataclassContainer to recreate the
# instance with the correct parameters.
dc_params = dataclasses.asdict(val)
dc_params["_cls"] = (
f"{module_handler(val)}.{val.__class__.__name__}"
)
new_content.append(dc_params)
new_content.append(_dataclass_to_dict(val))
elif isinstance(
val, (znflow.Connection, znflow.CombinedConnections)
):
Expand All @@ -97,11 +106,7 @@ def convert_to_params_yaml(self) -> dict | object:
elif dataclasses.is_dataclass(content) and not isinstance(
content, (Node, znflow.Connection, znflow.CombinedConnections)
):
dc_params = dataclasses.asdict(content)
dc_params["_cls"] = (
f"{module_handler(content)}.{content.__class__.__name__}"
)
data[field.name] = dc_params
data[field.name] = _dataclass_to_dict(content)
elif isinstance(content, (znflow.Connection, znflow.CombinedConnections)):
pass
else:
Expand Down Expand Up @@ -253,6 +258,33 @@ def convert_to_dvc_yaml(self) -> dict | object:
)
)
elif dataclasses.is_dataclass(con) and not isinstance(con, Node):
8B50 for field in dataclasses.fields(con):
if field.metadata.get(FIELD_TYPE) == FieldTypes.PARAMS_PATH:
# add the path to the params_path
content = nwd_handler(
get_attr_always_list(con, field.name),
nwd=self.node.nwd,
)
content = [
{pathlib.Path(x).as_posix(): None}
for x in content
if x is not None
]
if len(content) > 0:
stages.setdefault(FieldTypes.PARAMS.value, []).extend(
content
)
if field.metadata.get(FIELD_TYPE) == FieldTypes.DEPS_PATH:
content = [
pathlib.Path(c).as_posix()
for c in get_attr_always_list(con, field.name)
if c is not None
]
if len(content) > 0:
stages.setdefault(FieldTypes.DEPS.value, []).extend(
content
)

# add node name to params.yaml
stages.setdefault(FieldTypes.PARAMS.value, []).append(
self.node.name
Expand Down
Loading
0