Revamp Conjunction
#379
* Remove `_least_common_ancestor`
* Disallow empty conjunctions
Codecov Report: All modified and coverable lines are covered by tests ✅
I think that a
Yes, my bad. This makes this PR not good enough.
I thought about this too, but we still need their check method for our tests (arguably we could implement these checks differently) and, more importantly, we can only have one "parent" with

About TensorDicts, I think #380 is good.

About Conjunction, I think we could go back to having some kind of

EDIT: this would work, but it would still be slower than #382 with practically no benefits.
Closing in favor of #382.
A fix that I came up with is to just return a `TensorDict` at the output of `Conjunction` and use `Any` as the return type hint. Still not sure whether this is better than #382 or not. It's a bit of a hack (which I really don't like), but #382 makes transforms a bit harder to use IMO (gotta read the docstrings and really understand them rather than relying on the type checker and `check_keys` to ensure we're not doing something stupid).
In the end we do all of this for safety, so that it's easy to make combined transforms (because we're sure we cannot break anything). The `Any` workaround would not ensure safety whenever we use a `Conjunction`, so this feeling of safety (which is what makes the creation of transforms easy) would disappear. So I think it's useless, and #382 is better. #382 is also more open to adding another, more satisfying solution one day.
This PR completely reimplements `Conjunction` to make it much more efficient.

First, we can notice that all transforms that the `Conjunction` holds map to the same output type (`_B`). We thus don't have to do several costly calls to `_least_common_ancestor`! We can simply call `type(self.transforms[0](tensor_dict))` to get the class of TensorDict that we want the final output to be.

This (the `[0]` indexing), however, would require having at least one transform. Until now, we also allowed empty conjunctions. We have two choices:

1. Let `output_type` be `EmptyTensorDict` when the `Conjunction` is empty. This is a bit confusing for the type checker, because when the `Conjunction` is empty, `_B` is not very well-defined (in this case, it is, in fact, `EmptyTensorDict`, but mypy doesn't seem to infer this).
2. Disallow empty `Conjunctions`. We do not allow `mtl_backward` with no loss anyway, so there's no way for a user to ever need an empty `Conjunction` at the moment.

I selected the second choice, because it's much simpler to implement. If we ever really need empty `Conjunctions` (which I doubt we will, because we can always replace them with a trivial transform returning the `EmptyTensorDict`), we can always go back on this choice.

Another implementation would have been:
This is shorter, but I'm scared it could use a bit more memory (in fact, since only references should be stored, this is probably not significant at all, so we could arguably use this implementation instead).
Lastly, this fixes another issue that we had in the previous implementation: `TensorDict`s are supposed to be immutable, but we called `|=` (the `__ior__` method) on `EmptyTensorDict`. Now, we only call `|=` on `union`, which is not a `TensorDict` (but rather a simple `dict[Tensor, Tensor]`). The instantiation of the `TensorDict` is done only at the end, with `return output_type(union)`.

This allows us to assign `_raise_immutable_error` to `TensorDict.__ior__`, as we should already have done.