Adding fsdp fp16 and bf16 hooks by aovladi · Pull Request #80557 · pytorch/pytorch
Adding fsdp fp16 and bf16 hooks #80557


Closed · aovladi wants to merge 4 commits into master from addons/fsdp_low_precision_hooks
Conversation

@aovladi commented Jun 29, 2022

Recently, register_comm_hook was introduced to FSDP; at the moment it supports only the NO_SHARD strategy and ships with a default all_reduce hook. This PR adds two lower-precision hooks alongside the existing default hook.
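For context, a minimal usage sketch of what this enables. Only register_comm_hook, the NO_SHARD restriction, default_hooks.py, and the hook/state names come from this PR; the import path, the LowPrecisionState constructor arguments, and the rest of the setup are assumptions for illustration.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Hypothetical import path; the PR only says the hooks live in default_hooks.py.
from torch.distributed.algorithms._comm_hooks import default_hooks

# Assumes the default process group has already been initialized (e.g. via torchrun).
model = FSDP(torch.nn.Linear(8, 4).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)

# State carries the process group, the gradient pre-/post-division factors, and
# (for the low-precision hooks) the dtype to cast gradients back to.
state = default_hooks.LowPrecisionState(dist.group.WORLD, parameter_type=torch.float32)

# Gradients are cast to float16, all-reduced, then cast back to parameter_type.
model.register_comm_hook(state, default_hooks.fp16_compress_hook)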

I've also made slight adjustments to the existing implementation of the all_reduce hook, including:

  • AllReduceState -> DefaultState. Motivation: AllReduceState is not specific to all_reduce; gradient pre- and post-division factors are also useful for other hooks that require pre- and post-division, e.g. fp16_hook and bf16_hook.
  • I've put all 3 hooks into default_hooks.py

Additionally, FSDP supports MixedPrecision, so in theory a user can specify MixedPrecision for gradients and also attach a lower-precision hook to the model. To avoid double-casting, I've added a couple of checks to fully_sharded_data_parallel, i.e. casting gradients to the lower precision and back is performed by the lower-precision hook only. As a next step, it would be nice to ensure that a user can't have both a lower-precision hook and MixedPrecision(reduce_dtype=<precision>) specified, but I am happy to discuss this and adjust the current implementation.

As a test, I create two models, one with a lower-precision hook and one with MixedPrecision(reduce_dtype=<precision>) specified, perform one forward/backward pass and optimizer step, and compare the resulting gradients.
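A sketch of that comparison, assuming fsdp_with_hook has the fp16 hook registered as in the sketch above and fsdp_with_mp wraps an identical copy of the module with MixedPrecision(reduce_dtype=torch.float16). The loss and the use of torch.allclose are illustrative, not the PR's exact test.

import torch
import torch.distributed as dist

def compare_grads(fsdp_with_hook, fsdp_with_mp):
    optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
    optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

    in_data = torch.rand(16, 8, device='cuda')
    for model, optim in ((fsdp_with_hook, optim_hook), (fsdp_with_mp, optim_mp)):
        model.train()
        # One forward/backward pass and optimizer step per model.
        model(in_data).sum().backward()
        optim.step()

    dist.barrier()
    # Gradients produced via the fp16 hook should match those produced via
    # MixedPrecision(reduce_dtype=torch.float16).
    for hook_param, mp_param in zip(fsdp_with_hook.parameters(), fsdp_with_mp.parameters()):
        assert torch.allclose(hook_param.grad, mp_param.grad)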

@facebook-github-bot (Contributor) commented Jun 29, 2022

✅ No Failures (0 Pending)

As of commit b9ae5a8 (more details on the Dr. CI page):

💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI.

@facebook-github-bot added the oncall: distributed label Jun 29, 2022
@aovladi marked this pull request as ready for review June 29, 2022 19:28
@rohan-varma (Member) left a comment

Great work! A couple of suggestions

optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

in_data = torch.rand(16, 8).cuda()
Member: nit: torch.rand(16, 8, device='cuda')

fsdp_with_hook.train()
fsdp_with_mp.train()
optim_mp.zero_grad()
optim_hook.zero_grad()
Member: I don't think zero_grad does anything here, since we haven't accumulated any grads yet?


dist.barrier()

for hook_params, mp_params in zip(fsdp_with_hook.parameters(), fsdp_with_mp.parameters()):
Member: nit: hook_param, mp_param since it is a single param

if state.gradient_postdivide_factor > 1:
grad.div_(state.gradient_postdivide_factor)

def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor):
Member: I think we can consolidate the fp16 and bf16 hooks, e.g.:

def reduced_precision_hook(prec, state, grad):
    grad.data = grad.data.to(prec)
    allreduce_hook(state, grad)
    _decompress(state, grad)

def fp16_compress_hook(state, grad):
    fp16_hook = functools.partial(reduced_precision_hook, torch.float16)
    return fp16_hook(state, grad)

and the same for bfloat16.
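A minimal sketch of that bfloat16 counterpart, assuming the reduced_precision_hook helper from the suggestion above is in scope; the names mirror the suggestion, not necessarily the final PR code.

import functools
import torch

def bf16_compress_hook(state, grad):
    # Same pattern as fp16_compress_hook, just binding torch.bfloat16 instead.
    bf16_hook = functools.partial(reduced_precision_hook, torch.bfloat16)
    return bf16_hook(state, grad)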

Author (aovladi): Thanks for the example!

process_group (ProcessGroup): The process group to be used for all-reduce.
world_size (int): The number of workers in a process group.
Determined based on a ``process_group``.
gradient_predivide_factor (float): A factor for gradients' pre-division.
Member: I guess these aren't actually args into DefaultState, so we can remove them from docs?
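In a sketch of that state class (an illustration, not the PR's exact code), only process_group is a constructor argument; world_size and the division factors are derived from it, which is what the reviewer is pointing out:

import torch.distributed as dist

class DefaultState:
    """Hook state: the process group plus gradient pre-/post-division factors,
    derived from the group rather than passed in (a sketch)."""

    def __init__(self, process_group):
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)
        # Assumed: split the 1/world_size gradient averaging into a pre-division
        # step (before the all-reduce, for overflow safety) and a post-division step.
        self.gradient_predivide_factor = self._get_gradient_predivide_factor(self.world_size)
        self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor

    @staticmethod
    def _get_gradient_predivide_factor(world_size):
        # Assumed heuristic: the largest power of two no bigger than sqrt(world_size).
        factor = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor *= 2
        return float(factor)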

if (
    not self._low_precision_hook_enabled() and
    self._mixed_precision_enabled_for_params() or self._mixed_precision_enabled_for_reduce()
Member: We should think about how to refactor this, e.g. if reduce_dtype is available, then just register a reduced-precision communication hook, but that requires a bit more fleshing out. Might be good to file a follow-up issue.
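A rough sketch of that follow-up idea (entirely hypothetical, not part of this PR): when the user sets MixedPrecision(reduce_dtype=...), FSDP could register the matching lower-precision hook itself instead of keeping two casting paths.

import torch

# Hypothetical helper; default_hooks, LowPrecisionState, and the hook names follow
# this PR, but the auto-registration logic is only a sketch of the reviewer's idea.
def _maybe_register_reduce_dtype_hook(fsdp_module, mixed_precision, process_group, default_hooks):
    if mixed_precision is None or mixed_precision.reduce_dtype is None:
        return
    hook = {
        torch.float16: default_hooks.fp16_compress_hook,
        torch.bfloat16: default_hooks.bf16_compress_hook,
    }.get(mixed_precision.reduce_dtype)
    if hook is None:
        return  # fall back to the existing mixed-precision reduction path
    state = default_hooks.LowPrecisionState(process_group, parameter_type=torch.float32)
    fsdp_module.register_comm_hook(state, hook)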

Author (aovladi): I will create a follow-up issue after this PR is landed.

Member: We check not self._low_precision_hook_enabled on L2980, so this check should always be true within this branch, right?

Author (aovladi): @rohan-varma right, thanks for catching this.

Author (aovladi): @rohan-varma, after looking into this: the first check (now on L3024) handles pre-casting gradients for all cases, whereas this line re-casts gradients back to full precision for NO_SHARD. It does not belong to the body of the if on L3023, so this check is still needed. Otherwise, the comm_hook will cast grads back and occupy memory for orig_param_grad_data, and after the comm_hook is done, without this check here, this code will attempt a second re-cast.

Member: Sounds good

def __init__(
self,
process_group,
parameter_type,
Member: let's just default this to torch.float32, which is the usual case, so the user doesn't have to specify it every time?
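That suggestion would look roughly like this (a sketch; the base-class call and attribute names are assumptions based on the DefaultState sketch earlier, not the PR's exact code):

import torch

class LowPrecisionState(DefaultState):
    """State for the fp16/bf16 hooks: adds the dtype to cast gradients back to
    after the reduction, defaulting to the usual full-precision case."""

    def __init__(self, process_group, parameter_type=torch.float32):
        super().__init__(process_group)
        self.parameter_type = parameter_type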

@aovladi (Author) commented Jul 12, 2022

@pytorchbot rebase

@pytorchmergebot (Collaborator): @pytorchbot successfully started a rebase job. Check the current status here.

@pytorchmergebot (Collaborator): Successfully rebased addons/fsdp_low_precision_hooks onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout addons/fsdp_low_precision_hooks && git pull --rebase).

@pytorchmergebot pytorchmergebot force-pushed the addons/fsdp_low_precision_hooks branch from 6e003f0 to c208bef Compare July 12, 2022 19:32
@aovladi (Author) commented Jul 14, 2022

@pytorchbot rebase

@pytorchmergebot (Collaborator): @pytorchbot successfully started a rebase job. Check the current status here.

@pytorchmergebot (Collaborator): Successfully rebased addons/fsdp_low_precision_hooks onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout addons/fsdp_low_precision_hooks && git pull --rebase).

@pytorchmergebot pytorchmergebot force-pushed the addons/fsdp_low_precision_hooks branch from 51e01a4 to 386d614 Compare July 14, 2022 03:44
@rohan-varma (Member) left a comment

LGTM, one non-minor comment, please take a look at it

state (DefaultState): State information, configures pre- and post-division factors
grad (torch.Tensor): A gradient for the local batch that needs to be communicated across ranks in a lower precision.
"""
bp16_hook = functools.partial(lower_precision_hook, torch.bfloat16)
Member: nit: bf16_hook

if (
    not self._low_precision_hook_enabled() and
    self._mixed_precision_enabled_for_params() or self._mixed_precision_enabled_for_reduce()
Member: We check not self._low_precision_hook_enabled on L2980, so this check should always be true within this branch, right?

@aovladi (Author) commented Jul 15, 2022

@pytorchbot rebase

@pytorchmergebot (Collaborator): @pytorchbot successfully started a rebase job. Check the current status here.

@pytorchmergebot (Collaborator): Successfully rebased addons/fsdp_low_precision_hooks onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout addons/fsdp_low_precision_hooks && git pull --rebase).

@pytorchmergebot pytorchmergebot force-pushed the addons/fsdp_low_precision_hooks branch from 386d614 to 65d2675 Compare July 15, 2022 23:30
@aovladi (Author) commented Jul 18, 2022

@pytorchbot rebase

@pytorchmergebot (Collaborator): @pytorchbot successfully started a rebase job. Check the current status here.

@pytorchmergebot (Collaborator): Successfully rebased addons/fsdp_low_precision_hooks onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout addons/fsdp_low_precision_hooks && git pull --rebase).

@pytorchmergebot pytorchmergebot force-pushed the addons/fsdp_low_precision_hooks branch from 1b0ad0e to b9ae5a8 Compare July 18, 2022 17:54
@aovladi (Author) commented Jul 18, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator): @pytorchbot successfully started a merge job. Check the current status here.

@github-actions (Contributor):
Hey @aovladi.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@aovladi (Author) commented Jul 19, 2022

@pytorchbot revert -m "broke distributed tests on trunk"

@pytorch-bot (bot) commented Jul 19, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@aovladi (Author) commented Jul 19, 2022

@pytorchbot revert -m "broke distributed tests on trunk" -c weird

@pytorchmergebot (Collaborator): @pytorchbot successfully started a revert job. Check the current status here.

@pytorchmergebot (Collaborator): @aovladi your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jul 19, 2022
This reverts commit f7d6828.

Reverted #80557 on behalf of https://github.com/aovladi due to broke distributed tests on trunk
@github-actions github-actions bot deleted the addons/fsdp_low_precision_hooks branch February 18, 2024 01:53
Labels: cla signed, Merged, oncall: distributed, Reverted

4 participants