Revamp `Conjunction` by ValerianRey · Pull Request #379 · TorchJD/torchjd · GitHub

Revamp Conjunction #379


Closed
wants to merge 6 commits into from

ValerianRey (Contributor) commented May 31, 2025

This PR completely reimplements Conjunction to make it much more efficient.

First, we can notice that all transforms that the Conjunction holds map to the same output type (_B). We thus don't have to do several costly calls to _least_common_ancestor! We can simply call type(self.transforms[0](tensor_dict)) to get the class of TensorDict that we want the final output to be.

The [0] indexing, however, requires at least one transform, and until now we also allowed empty Conjunctions. We have two choices:

  • Make the output_type be EmptyTensorDict when the Conjunction is empty. This is a bit confusing for the type checker, because when the Conjunction is empty, _B is not very well-defined (in this case, it is in fact EmptyTensorDict, but mypy doesn't seem to infer this).
  • Stop allowing empty Conjunctions. We do not allow mtl_backward with no loss anyway, so there's no way for a user to ever need an empty Conjunction at the moment.

I chose the second option because it's much simpler to implement. If we ever really need empty Conjunctions (which I doubt, since we can always replace them with a trivial transform returning an EmptyTensorDict), we can always revisit this choice.
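A minimal sketch of the revamped Conjunction as described above, under simplifying assumptions: TensorDict and Gradients are hypothetical dict subclasses standing in for torchjd's classes (string keys replace Tensors so it runs without PyTorch), and transforms are plain callables. This is a reconstruction from the description, not the PR's actual code.

```python
from collections.abc import Callable, Sequence


class TensorDict(dict):
    """Hypothetical stand-in for torchjd's TensorDict."""


class Gradients(TensorDict):
    """Hypothetical TensorDict subclass."""


# Hypothetical alias: a transform maps a TensorDict to a TensorDict.
Transform = Callable[[TensorDict], TensorDict]


class Conjunction:
    """Applies every transform to the same input and unions their outputs."""

    def __init__(self, transforms: Sequence[Transform]):
        # Second choice above: empty conjunctions are rejected up front.
        if len(transforms) == 0:
            raise ValueError("Conjunction requires at least one transform.")
        self.transforms = list(transforms)

    def __call__(self, tensor_dict: TensorDict) -> TensorDict:
        first_output = self.transforms[0](tensor_dict)
        # All transforms share the output type _B, so the first output's class is reused.
        output_type = type(first_output)
        union: dict = dict(first_output)  # plain dict, so |= never mutates a TensorDict
        for transform in self.transforms[1:]:
            union |= transform(tensor_dict)
        return output_type(union)


conjunction = Conjunction([lambda td: Gradients({"a": 1}), lambda td: Gradients({"b": 2})])
output = conjunction(TensorDict())
print(type(output).__name__, dict(output))  # Gradients {'a': 1, 'b': 2}
```

Only the first output is inspected for its class; the remaining outputs are merged straight into the plain dict, and the TensorDict is built once at the end.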

Another implementation would have been:

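```python
from torch import Tensor


def __call__(self, tensor_dict):  # method of Conjunction
    tensor_dicts = [transform(tensor_dict) for transform in self.transforms]
    union: dict[Tensor, Tensor] = {}
    for td in tensor_dicts:
        union |= td  # merge into a plain dict; never mutate a TensorDict
    return type(tensor_dicts[0])(union)
```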

This is shorter, but I was worried that it could use a bit more memory, since it keeps all the transforms' outputs alive at the same time. In fact, since only references should be stored, this is probably not significant at all, so we could arguably use this implementation instead.

Lastly, this fixes another issue that we had in the previous implementation: TensorDicts are supposed to be immutable, but we called |= (the __ior__ method) on EmptyTensorDict. Now, we only call |= on union, which is not a TensorDict (but rather a simple dict[Tensor, Tensor]). The instantiation of the TensorDict is done only at the end, with return output_type(union).

This allows us to assign _raise_immutable_error to TensorDict.__ior__, as we should already have done.
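A minimal sketch of that immutability pattern, using a hypothetical TensorDict stand-in. Only _raise_immutable_error and __ior__ come from the description; also blocking __setitem__ and update is an assumption added for illustration.

```python
class TensorDict(dict):
    """Hypothetical stand-in: a dict that should never change after construction."""

    def _raise_immutable_error(self, *args, **kwargs):
        raise TypeError(f"{type(self).__name__} is immutable.")

    # Assigning the helper disables these mutating methods. dict.__init__ does not
    # go through them, so construction still works.
    __ior__ = _raise_immutable_error
    __setitem__ = _raise_immutable_error  # assumption: not mentioned in the PR
    update = _raise_immutable_error  # assumption: not mentioned in the PR


# Build the union in a plain dict, then create the TensorDict once at the end.
union: dict = {}
for part in (TensorDict({"a": 1}), TensorDict({"b": 2})):
    union |= part  # fine: union is a plain dict, not a TensorDict
result = TensorDict(union)

try:
    result |= {"c": 3}
except TypeError as error:
    print(error)  # TensorDict is immutable.
```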

  • Revamp Conjunction.__call__
  • Remove _least_common_ancestor
  • Disable __ior__ in TensorDict (for immutability)

* Remove _least_common_ancestor
* Disallow empty conjunctions

codecov bot commented May 31, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines Coverage Δ
src/torchjd/_autojac/_transform/_base.py 100.00% <100.00%> (ø)
src/torchjd/_autojac/_transform/_tensor_dict.py 100.00% <100.00%> (ø)

PierreQuinton (Contributor) commented:

> First, we can notice that all transforms that the Conjunction holds map to the same output type (_B). We thus don't have to do several costly calls to _least_common_ancestor! We can simply call type(self.transforms[0](tensor_dict)) to get the class of TensorDict that we want the final output to be.

I think that a Transform[EmptyTensorDict] is a Transform[_B], so if the first transform is a strict sub-type, then this is wrong. In my opinion, TensorDicts are weird now. It feels like we would like to remove the checks, but then TensorDicts would do nothing except tell the developer whether a composition (or stack, or conjunction) of Transforms is valid. So I think they are basically only useful to verify that:

  1. The transforms are correct, i.e. they output the expected TDs.
  2. The transforms are correctly composed (or stacked, or conjuncted).

It feels like TDs could be annotated types.
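A small sketch of the subtype concern raised in this comment, assuming hypothetical stand-ins where EmptyTensorDict subclasses both Gradients and Jacobians (so it is a valid output of a Transform[Gradients]). If the first transform happens to return an EmptyTensorDict, taking type(...) of its output picks the wrong class for the union.

```python
class TensorDict(dict):
    """Hypothetical stand-in for torchjd's TensorDict."""


class Gradients(TensorDict):
    pass


class Jacobians(TensorDict):
    pass


class EmptyTensorDict(Gradients, Jacobians):
    """Hypothetical: a subtype of every TensorDict class, meant to stay empty."""


def returns_empty(_: TensorDict) -> EmptyTensorDict:
    return EmptyTensorDict()


def returns_gradients(_: TensorDict) -> Gradients:
    return Gradients({"grad": 1})


# Both functions type-check as Transform[Gradients], since EmptyTensorDict is a Gradients.
outputs = [transform(TensorDict()) for transform in (returns_empty, returns_gradients)]
union: dict = {}
for output in outputs:
    union |= output
result = type(outputs[0])(union)  # uses EmptyTensorDict, not Gradients
print(type(result).__name__, dict(result))  # EmptyTensorDict {'grad': 1}
```

The result is an "empty" TensorDict that actually holds a key, which is exactly the kind of mistake the type-based checks are meant to prevent.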

ValerianRey (Contributor, Author) commented Jun 2, 2025

> I think that a Transform[EmptyTensorDict] is a Transform[_B], so if the first transform is a strict sub-type, then this is wrong.

Yes, my bad. This means this PR is not good enough as is.

> It feels like TDs could be annotated types.

I thought about this too, but we still need their check method for our tests (arguably, we could implement these checks differently). More importantly, typing.Annotated only lets us have one "parent" type, which wouldn't work with EmptyTensorDict, since it needs several parents. This could work once type intersections are added to Python, which may not happen for a long time.
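A small illustration of the typing.Annotated limitation, using only the standard library: Annotated attaches metadata to exactly one underlying type, and nesting Annotated flattens the metadata instead of producing a type with two parents. The aliases GradientsType and JacobiansType are hypothetical names chosen for this example.

```python
from typing import Annotated, get_args

# Each Annotated alias wraps a single underlying type plus metadata.
GradientsType = Annotated[dict, "gradients"]
JacobiansType = Annotated[dict, "jacobians"]
print(get_args(GradientsType))  # (<class 'dict'>, 'gradients')

# Combining both only yields dict with two metadata entries; static type checkers
# treat it as plain dict, not as something that is both a GradientsType and a JacobiansType.
Both = Annotated[GradientsType, "jacobians"]
print(get_args(Both))  # (<class 'dict'>, 'gradients', 'jacobians')
```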

About TensorDicts, I think #380 is good.

About Conjunction, I think we could go back to having some kind of least_common_ancestor function, but optimized to work directly on a list of objects rather than looking at objects two by two.

EDIT: this would work, but it would still be slower than #382 with practically no benefits.
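A minimal sketch of such a list-based function, assuming it receives the transforms' outputs directly. It walks the MRO of the first object's class once and returns the first class that every other object is an instance of. The name least_common_ancestor and the stand-in classes are hypothetical; with multiple inheritance, the answer depends on the order of the first object's MRO.

```python
from collections.abc import Sequence


class TensorDict(dict):
    """Hypothetical stand-in for torchjd's TensorDict."""


class Gradients(TensorDict):
    pass


class Jacobians(TensorDict):
    pass


class EmptyTensorDict(Gradients, Jacobians):
    pass


def least_common_ancestor(objects: Sequence[object]) -> type:
    """Return the first class in the first object's MRO that all objects share."""
    if not objects:
        raise ValueError("At least one object is required.")
    for candidate in type(objects[0]).__mro__:
        # Each candidate is checked against all objects at once, not pair by pair.
        if all(isinstance(obj, candidate) for obj in objects[1:]):
            return candidate
    return object  # Not reached in practice: every object is an instance of object.


print(least_common_ancestor([EmptyTensorDict(), Gradients({"g": 1})]).__name__)  # Gradients
print(least_common_ancestor([Gradients(), Jacobians()]).__name__)  # TensorDict
```

Unlike type(self.transforms[0](tensor_dict)), this returns Gradients when the first output is an EmptyTensorDict, at the cost of one isinstance check per candidate and object.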

ValerianRey mentioned this pull request on Jun 2, 2025
ValerianRey (Contributor, Author) commented:

Closing in favor of #382.

ValerianRey closed this on Jun 3, 2025
ValerianRey reopened this on Jun 3, 2025
ValerianRey (Contributor, Author) commented:

A fix that I came up with is to just return a TensorDict at the output of Conjunction and use Any as the return type hint.
It's like saying to mypy: "this TensorDict will be compatible with everything, don't worry". So in the end, this doesn't protect us at all from wrongly using the output of a Conjunction, but at least it preserves the genericity of transforms (which makes them clearer IMO).
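A minimal sketch of this workaround, assuming hypothetical TensorDict and Gradients stand-ins. At runtime the output is always the base TensorDict, and the Any annotation makes mypy accept it wherever any _B is expected, which is precisely the loss of safety this approach accepts.

```python
from typing import Any


class TensorDict(dict):
    """Hypothetical stand-in for torchjd's TensorDict."""


class Gradients(TensorDict):
    pass


class Conjunction:
    """Hypothetical sketch: always returns a base TensorDict, typed as Any."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, tensor_dict: TensorDict) -> Any:
        union: dict = {}
        for transform in self.transforms:
            union |= transform(tensor_dict)
        # The declared return type is Any, so mypy accepts this value wherever
        # a Gradients (or any other _B) is expected, even though it is not one.
        return TensorDict(union)


output = Conjunction([lambda td: Gradients({"a": 1})])(TensorDict())
print(type(output).__name__, isinstance(output, Gradients))  # TensorDict False
```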

I'm still not sure whether this is better than #382. It's a bit of a hack (which I really don't like), but #382 makes transforms a bit harder to use IMO (you have to read the docstrings and really understand them, rather than relying on the type checker and check_keys to ensure we're not doing something stupid).

ValerianRey (Contributor, Author) commented:

In the end, we do all of this for safety, so that it's easy to make combined transforms (because we're sure we cannot break anything). This hack would not ensure safety whenever we use a Conjunction, so this feeling of safety (which is what makes creating transforms easy) would disappear. So I think it's useless, and #382 is better. #382 is also more open to adding a more satisfying solution one day.

ValerianRey closed this on Jun 3, 2025
ValerianRey deleted the revamp-conjunction branch on June 3, 2025, 22:50