Support mixed precision training #740

Closed · shuangwu5 opened this issue Mar 6, 2025 · 1 comment

Comments

@shuangwu5

🚀 Feature

Support training differential privacy models with mixed precision.

Motivation

Mixed precision training is a commonly used technique for training large-scale deep learning models, for example when fine-tuning LLMs. But when one tries to integrate such a training loop with Opacus, an error occurs:

RuntimeError: expected scalar type BFloat16 but found Float

This is due to the dtype mismatch when computing per-sample gradients.
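
For context, here is a minimal sketch (not taken from the issue) of the kind of autocast training loop that runs into this error; the model, data, and hyperparameters are illustrative placeholders.

import torch
import torch.nn as nn
from opacus import PrivacyEngine

# Toy two-layer model and synthetic data, for illustration only.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 16), torch.randint(0, 2, (64,))
)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

for x, y in data_loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    # Autocast runs parts of the forward pass in bfloat16, so the activations
    # captured by Opacus' grad-sample hooks and the corresponding backprops can
    # end up with mismatched dtypes, raising the RuntimeError above on backward.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()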

Pitch

During mixed precision training, some activations will be in half precision (e.g., bfloat16) while the backprops are in full precision.
This can be fixed by upcasting the tensors to float32 so that their dtypes are consistent.
Here is an example of the adjustment for the linear layer:

diff --git a/opacus/grad_sample/linear.py b/opacus/grad_sample/linear.py
index 1b30f94..e22ab0f 100644
--- a/opacus/grad_sample/linear.py
+++ b/opacus/grad_sample/linear.py
@@ -41,10 +41,10 @@ def compute_linear_grad_sample(
     activations = activations[0]
     ret = {}
     if layer.weight.requires_grad:
-        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
+        gs = torch.einsum("n...i,n...j->nij", backprops.float(), activations.float())
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
+        ret[layer.bias] = torch.einsum("n...k->nk", backprops.float())
     return ret
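
As a less invasive variant of the same fix, one could re-register a float32 grad sampler for nn.Linear through Opacus' register_grad_sampler decorator instead of patching the library source. This is only a sketch; it assumes that re-registering a sampler for nn.Linear overrides the built-in one.

from typing import Dict, List

import torch
import torch.nn as nn
from opacus.grad_sample import register_grad_sampler


@register_grad_sampler(nn.Linear)
def compute_linear_grad_sample_fp32(
    layer: nn.Linear, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    # Same logic as opacus/grad_sample/linear.py, but with both operands
    # upcast to float32 so that the einsum dtypes always match.
    activations = activations[0].float()
    backprops = backprops.float()
    ret = {}
    if layer.weight.requires_grad:
        ret[layer.weight] = torch.einsum("n...i,n...j->nij", backprops, activations)
    if layer.bias is not None and layer.bias.requires_grad:
        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
    return ret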

Alternatives

N/A

Additional context

This is the workaround our team has applied in our project (here).
But it would be great if the feature could be officially supported by Opacus in the near future.

@HuanyuZhang
Contributor

Thanks to @shuangwu5 for contributing to Opacus. Yeah, mixed-precision training is on our radar, but right now we do not have a clear plan for when we will support it. For official support, we need to investigate and resolve the potential numerical stability issues introduced by mixed-precision computation.
