Revamp `Conjunction` by ValerianRey · Pull Request #379 · TorchJD/torchjd
Closed · wants to merge 6 commits
24 changes: 13 additions & 11 deletions src/torchjd/_autojac/_transform/_base.py
@@ -2,11 +2,11 @@

 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import Generic
+from typing import Any, Generic

 from torch import Tensor

-from ._tensor_dict import _A, _B, _C, EmptyTensorDict, _least_common_ancestor
+from ._tensor_dict import _A, _B, _C, TensorDict


 class RequirementError(ValueError):
@@ -81,7 +81,8 @@ class Conjunction(Transform[_B, _C]):
     Transform applying several transforms to the same input, and combining the results (by union)
     into a single TensorDict.

-    :param transforms: The transforms to apply. Their outputs should have disjoint sets of keys.
+    :param transforms: The non-empty sequence of transforms to apply. Their outputs should have
+        disjoint sets of keys.
     """

     def __init__(self, transforms: Sequence[Transform[_B, _C]]):
@@ -97,14 +98,15 @@ def __str__(self) -> str:
         strings.append(s)
         return "(" + " | ".join(strings) + ")"

-    def __call__(self, tensor_dict: _B) -> _C:
-        tensor_dicts = [transform(tensor_dict) for transform in self.transforms]
-        output_type: type[_B] = EmptyTensorDict
-        output: _B = EmptyTensorDict()
-        for tensor_dict in tensor_dicts:
-            output_type = _least_common_ancestor(output_type, type(tensor_dict))
-            output |= tensor_dict
-        return output_type(output)
+    def __call__(self, tensor_dict: _B) -> Any:
+        # Technically, we should infer what _C is (it might not just be type(self.transforms[0]))
+        # and instantiate _C(union), but this is too costly so we just return a TensorDict and use
+        # the `Any` return type hint to indicate to the type checker to be lenient here.
+
+        union: dict[Tensor, Tensor] = {}
+        for transform in self.transforms:
+            union |= transform(tensor_dict)
+        return TensorDict(union)

     def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
         output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
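For readers skimming the diff, the gist of the new `Conjunction.__call__` is a plain union of the per-transform outputs. The sketch below illustrates that behaviour outside of torchjd; `conjoin` and the toy transforms are hypothetical stand-ins, not the library's API.

```python
# Minimal sketch of the union-based combination used by the new `__call__`.
# `conjoin`, `t1`, and `t2` are illustrative only (not torchjd code).
from collections.abc import Callable, Sequence

import torch
from torch import Tensor

TensorMap = dict[Tensor, Tensor]


def conjoin(
    transforms: Sequence[Callable[[TensorMap], TensorMap]], inputs: TensorMap
) -> TensorMap:
    union: TensorMap = {}
    for transform in transforms:
        # `|=` merges each output into the running union. If two transforms
        # produced the same key, the later one would silently overwrite the
        # earlier one; hence the "disjoint sets of keys" requirement.
        union |= transform(inputs)
    return union


a, b = torch.ones(2), torch.ones(3)


def t1(d: TensorMap) -> TensorMap:
    return {a: 2 * d[a]}  # writes only key `a`


def t2(d: TensorMap) -> TensorMap:
    return {b: 3 * d[b]}  # writes only key `b`


assert set(conjoin([t1, t2], {a: a, b: b}).keys()) == {a, b}
```

Compared to the old implementation, there is no longer any `_least_common_ancestor` bookkeeping: the result is always wrapped in a plain `TensorDict`, with `Any` keeping the type checker out of the way.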
13 changes: 2 additions & 11 deletions src/torchjd/_autojac/_transform/_tensor_dict.py
@@ -1,4 +1,4 @@
-from typing import TypeVar, cast
+from typing import TypeVar

 from torch import Tensor

@@ -37,6 +37,7 @@ def _check_key_value_pair(key: Tensor, value: Tensor) -> None:
     def _raise_immutable_error(self, *args, **kwargs) -> None:
         raise TypeError(f"{self.__class__.__name__} is immutable.")

+    __ior__ = _raise_immutable_error  # type: ignore[assignment]
     __setitem__ = _raise_immutable_error
     __delitem__ = _raise_immutable_error
     clear = _raise_immutable_error
@@ -128,16 +129,6 @@ def __init__(self, tensor_dict: dict[Tensor, Tensor] | None = None):
         super().__init__({})


-def _least_common_ancestor(first: type[TensorDict], second: type[TensorDict]) -> type[TensorDict]:
-    first_mro = first.mro()[:-1]  # removes `object` from `mro`.
-    output = TensorDict
-    for candidate_type in first_mro:
-        if issubclass(second, candidate_type):
-            output = cast(type[TensorDict], candidate_type)
-            break
-    return output
-
-
 def _check_values_have_unique_first_dim(tensor_dict: dict[Tensor, Tensor]) -> None:
     first_dims = [value.shape[0] for value in tensor_dict.values()]
     if len(set(first_dims)) > 1:
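The one-line addition to the immutability block above closes a real gap: on a `dict` subclass, `d |= other` dispatches to `dict.__ior__`, which updates the mapping in place without going through `__setitem__`, so blocking `__setitem__` alone does not make the object immutable. The previous `Conjunction.__call__` relied on exactly that in-place `|=` on an `EmptyTensorDict`, which is presumably why `__ior__` could only be blocked once that code path was rewritten to build a plain `dict` first. A minimal sketch (the `FrozenDict` class is hypothetical, not torchjd code):

```python
# Hypothetical FrozenDict illustrating why `__ior__` must be blocked explicitly.
class FrozenDict(dict):
    def _raise_immutable_error(self, *args, **kwargs) -> None:
        raise TypeError(f"{self.__class__.__name__} is immutable.")

    __setitem__ = _raise_immutable_error
    __delitem__ = _raise_immutable_error
    # Without the next line, `d |= {...}` would still mutate `d` in place,
    # because dict.__ior__ bypasses the overridden __setitem__.
    __ior__ = _raise_immutable_error  # type: ignore[assignment]


d = FrozenDict({"a": 1})
try:
    d |= {"b": 2}
except TypeError as exc:
    print(exc)  # FrozenDict is immutable.
```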
15 changes: 2 additions & 13 deletions tests/unit/autojac/_transform/test_base.py
@@ -4,8 +4,8 @@
 from pytest import raises
 from torch import Tensor

-from torchjd._autojac._transform._base import Conjunction, RequirementError, Transform
-from torchjd._autojac._transform._tensor_dict import _B, _C, TensorDict
+from torchjd._autojac._transform._base import RequirementError, Transform
+from torchjd._autojac._transform._tensor_dict import _B, _C


 class FakeTransform(Transform[_B, _C]):
@@ -102,17 +102,6 @@ def test_conjunct_check_keys_2():
     (t1 | t2 | t3).check_keys(set())


-def test_empty_conjunction():
-    """
-    Tests that it is possible to take the conjunction of no transform. This should return an empty
-    dictionary.
-    """
-
-    conjunction = Conjunction([])
-
-    assert len(conjunction(TensorDict({}))) == 0
-
-
 def test_str():
     """
     Tests that the __str__ method works correctly even for transform involving compositions and
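The deleted `test_empty_conjunction` captured the old guarantee that a conjunction of zero transforms returns an empty dictionary. The revamped docstring now asks for a non-empty sequence, so that guarantee is dropped from the tested contract rather than adapted. Mechanically, folding `|=` over zero outputs would still leave the accumulator empty, as this trivial standalone sketch (not torchjd code) shows:

```python
# Folding `|=` over an empty sequence of outputs leaves the union empty,
# which is what the removed test used to assert for Conjunction([]).
outputs: list[dict[str, int]] = []  # stand-in for "no transforms"
union: dict[str, int] = {}
for output in outputs:
    union |= output
assert union == {}
```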
20 changes: 0 additions & 20 deletions tests/unit/autojac/_transform/test_tensor_dict.py
@@ -13,7 +13,6 @@
     Jacobians,
     TensorDict,
 )
-from torchjd._autojac._transform._tensor_dict import _least_common_ancestor

 _key_shapes = [[], [1], [2, 3]]

@@ -80,25 +79,6 @@ def test_jacobian_matrices(value_shapes: list[list[int]], expectation: ExceptionContext):
     _assert_class_checks_properly(JacobianMatrices, value_shapes, expectation)


-@mark.parametrize(
-    ["first", "second", "result"],
-    [
-        (EmptyTensorDict, EmptyTensorDict, EmptyTensorDict),
-        (EmptyTensorDict, Jacobians, Jacobians),
-        (Jacobians, EmptyTensorDict, Jacobians),
-        (Jacobians, Jacobians, Jacobians),
-        (EmptyTensorDict, Gradients, Gradients),
-        (EmptyTensorDict, GradientVectors, GradientVectors),
-        (EmptyTensorDict, JacobianMatrices, JacobianMatrices),
-        (GradientVectors, JacobianMatrices, TensorDict),
-    ],
-)
-def test_least_common_ancestor(
-    first: type[TensorDict], second: type[TensorDict], result: type[TensorDict]
-):
-    assert _least_common_ancestor(first, second) == result
-
-
 def _assert_class_checks_properly(
     class_: type[TensorDict], value_shapes: list[list[int]], expectation: ExceptionContext
 ):
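The deleted parametrized test documented how `_least_common_ancestor` resolved the most specific common supertype of two `TensorDict` subclasses by walking the first type's MRO. For reference, here is a standalone sketch of that lookup on a toy hierarchy; `Base`, `Left`, `Right`, and `Leaf` are illustrative stand-ins, not torchjd's classes.

```python
from typing import cast


# Toy hierarchy standing in for the TensorDict subclasses used by the old test.
class Base: ...


class Left(Base): ...


class Right(Base): ...


class Leaf(Left, Right): ...


def least_common_ancestor(first: type[Base], second: type[Base]) -> type[Base]:
    # Walk `first`'s MRO (minus `object`) and return the first class that
    # `second` also derives from: the most specific common ancestor.
    for candidate_type in first.mro()[:-1]:
        if issubclass(second, candidate_type):
            return cast(type[Base], candidate_type)
    return Base


assert least_common_ancestor(Leaf, Left) is Left   # Left is in Leaf's MRO
assert least_common_ancestor(Left, Right) is Base  # siblings meet at Base
assert least_common_ancestor(Leaf, Leaf) is Leaf   # identical types
```

The removed table follows the same pattern: sibling subclasses such as `GradientVectors` and `JacobianMatrices` resolve to their shared `TensorDict` base, while pairs involving `EmptyTensorDict` resolve to the other member of the pair.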