Adding fsdp fp16 and bf16 hooks #80557
Conversation
Dr. CI: ✅ No failures (0 pending) as of commit b9ae5a8. 💚 Looks good so far! There are no failures yet.
Great work! A couple of suggestions
```
optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

in_data = torch.rand(16, 8).cuda()
```
nit: `torch.rand(16, 8, device='cuda')`
```
fsdp_with_hook.train()
fsdp_with_mp.train()
optim_mp.zero_grad()
optim_hook.zero_grad()
```
I don't think zero_grad does anything here, right? We haven't accumulated any grads yet.
```
dist.barrier()

for hook_params, mp_params in zip(fsdp_with_hook.parameters(), fsdp_with_mp.parameters()):
```
nit: `hook_param`, `mp_param`, since each is a single param.
```
if state.gradient_postdivide_factor > 1:
    grad.div_(state.gradient_postdivide_factor)

def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor):
```
I think we can consolidate the fp16 and bf16 hooks, for example:

```python
def reduced_precision_hook(prec, state, grad):
    grad.data = grad.data.to(prec)
    allreduce_hook(state, grad)
    _decompress(state, grad)


def fp16_compress_hook(state, grad):
    fp16_hook = functools.partial(reduced_precision_hook, torch.float16)
    return fp16_hook(state, grad)
```

and the same for bfloat16.
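For illustration, the bfloat16 counterpart would then be just another partial. This is a sketch assuming the `reduced_precision_hook`, `allreduce_hook`, and `_decompress` helpers from the suggestion above, not necessarily the final API:

```python
import functools

import torch


def bf16_compress_hook(state, grad):
    # Same pattern as fp16_compress_hook above, just targeting bfloat16.
    bf16_hook = functools.partial(reduced_precision_hook, torch.bfloat16)
    return bf16_hook(state, grad)
```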
Thanks for the example!
```
process_group (ProcessGroup): The process group to be used for all-reduce.
world_size (int): The number of workers in a process group.
    Determined based on a ``process_group``.
gradient_predivide_factor (float): A factor for gradients' pre-division.
```
I guess these aren't actually args to `DefaultState`, so we can remove them from the docs?
```
if (
    self._mixed_precision_enabled_for_params() or self._mixed_precision_enabled_for_reduce()
    not self._low_precision_hook_enabled() and
```
We should think about how to refactor this, e.g. if reduce_dtype is available, just register a reduced-precision communication hook, but that requires a bit more fleshing out. Might be good to file a follow-up issue.
I will create a follow-up issue after this PR lands.
We check `not self._low_precision_hook_enabled` on L2980, so this check should always be true within this branch, right?
@rohan-varma right, thanks for catching this
@rohan-varma after looking into this: the first check (now on L3024) is for pre-casting gradients in all cases. This line is for re-casting gradients back to full precision for `NO_SHARD`. It does not belong to the body of the `if` on L3023, so this check is still needed. Otherwise `comm_hook` will cast the grads back (and occupy memory for `orig_param_grad_data`), and after `comm_hook` is done, without this check here, we would attempt a second re-cast.
Sounds good
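To spell out the condition being discussed, here is a standalone restatement as a hypothetical helper (for illustration only; FSDP itself evaluates these flags inline, and the names mirror the diff context above):

```python
def should_recast_grad_to_full_precision(
    low_precision_hook_enabled: bool,
    mixed_precision_for_params: bool,
    mixed_precision_for_reduce: bool,
) -> bool:
    # Re-cast gradients back to full precision only when FSDP's own
    # MixedPrecision path performed the down-cast. A registered
    # low-precision comm hook decompresses its own gradients, so
    # re-casting here as well would do the work twice and allocate
    # an extra full-precision copy.
    return not low_precision_hook_enabled and (
        mixed_precision_for_params or mixed_precision_for_reduce
    )
```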
```
def __init__(
    self,
    process_group,
    parameter_type,
```
Let's just default this to `torch.float32`, which is the usual case, so the user doesn't have to specify it every time?
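A minimal sketch of what that defaulted signature could look like. This is a simplified, standalone stand-in for illustration; the real state class in this PR also carries the process-group bookkeeping and division factors from `DefaultState`:

```python
import torch


class LowPrecisionState:
    """Illustrative stand-in for the PR's low-precision hook state."""

    def __init__(self, process_group, parameter_type=torch.float32):
        self.process_group = process_group
        # Precision that gradients are cast back to after communication;
        # defaults to torch.float32 since that is the usual parameter dtype.
        self.parameter_type = parameter_type
```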
@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased 6e003f0 to c208bef.
@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased 51e01a4 to 386d614.
LGTM, one non-minor comment, please take a look at it.
```
state (DefaultState): State information, configures pre- and post-division factors
grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
"""
bp16_hook = functools.partial(lower_precision_hook, torch.bfloat16)
```
nit: bf16_hook
@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased 386d614 to 65d2675.
@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased 1b0ad0e to b9ae5a8.
@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.

Hey @aovladi.
@pytorchbot revert -m "broke distributed tests on trunk"

❌ 🤖 pytorchbot command failed.

@pytorchbot revert -m "broke distributed tests on trunk" -c weird

@pytorchbot successfully started a revert job. Check the current status here.

@aovladi your PR has been successfully reverted.
This reverts commit f7d6828. Reverted #80557 on behalf of https://github.com/aovladi due to broke distributed tests on trunk
Recently, `register_comm_hook` was introduced to `FSDP`. At the moment it supports only the `NO_SHARD` strategy and ships with a default `all_reduce` hook. This PR adds two lower-precision hooks alongside the existing default hook.

I've also made slight adjustments to the existing implementation of the `all_reduce` hook in `default_hooks.py`, including renaming `AllReduceState` to `DefaultState`. Motivation: `AllReduceState` is not specific to `all_reduce`; gradients' pre- and post-division factors are also useful for other hooks that require pre- and post-division, e.g. the fp16 and bf16 hooks.

Additionally, `FSDP` supports `MixedPrecision`, so in principle a user could specify `MixedPrecision` for gradients and also attach a lower-precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, so that casting to the lower precision and back is performed by the lower-precision hook only. As a next step, I think it would be nice to ensure that a user can't have both a lower-precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust the current implementation.

As a test, I create two models, one with a lower-precision hook and one with `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward pass and optimizer step, and compare gradients.
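A hedged sketch of that parity test is below. It assumes a process group is already initialized and that the hook and state names match this PR; treat the exact import paths, the `register_comm_hook` signature, and the model shapes as assumptions rather than a reference to the actual test in the PR.

```python
import copy

import torch
import torch.distributed as dist
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy


def check_low_precision_hook_matches_mixed_precision():
    net = torch.nn.Linear(8, 4).cuda()

    # Model 1: NO_SHARD FSDP with an explicit fp16 communication hook.
    fsdp_with_hook = FSDP(
        copy.deepcopy(net), sharding_strategy=ShardingStrategy.NO_SHARD
    )
    state = default_hooks.LowPrecisionState(process_group=dist.group.WORLD)
    fsdp_with_hook.register_comm_hook(state, default_hooks.fp16_compress_hook)

    # Model 2: same module, but gradients reduced in fp16 via MixedPrecision.
    fsdp_with_mp = FSDP(
        copy.deepcopy(net),
        sharding_strategy=ShardingStrategy.NO_SHARD,
        mixed_precision=MixedPrecision(reduce_dtype=torch.float16),
    )

    optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
    optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

    in_data = torch.rand(16, 8, device="cuda")
    for model, optim in ((fsdp_with_hook, optim_hook), (fsdp_with_mp, optim_mp)):
        model.train()
        model(in_data).sum().backward()
        optim.step()

    # After one step, parameters (and hence the applied gradients) should
    # agree, since both models reduced gradients in fp16. Tolerances may
    # need loosening depending on the precision involved.
    for hook_param, mp_param in zip(
        fsdp_with_hook.parameters(), fsdp_with_mp.parameters()
    ):
        assert torch.allclose(hook_param, mp_param)
```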