🚀 Feature
Support training differentially private models with mixed precision.
Motivation
Mixed precision training is a commonly used technique for training large-scale deep learning models, for example, fine-tuning LLMs. But when one tries to integrate this kind of training loop with Opacus, an error occurs:
RuntimeError: expected scalar type BFloat16 but found Float
This is due to the dtype mismatch when computing per-sample gradients.
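For illustration only (not part of the original report), here is a minimal sketch of the kind of mixed-precision loop that can hit this mismatch when wrapped by Opacus. The model, optimizer, and data below are placeholders, and the snippet assumes a CUDA device with bfloat16 support; the exact failure point depends on the Opacus and PyTorch versions.

# Hypothetical repro sketch; model/optimizer/data are illustrative placeholders.
import torch
import torch.nn as nn
from opacus import PrivacyEngine

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(32, 16), torch.randint(0, 4, (32,))),
    batch_size=8,
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

criterion = nn.CrossEntropyLoss()
for x, y in data_loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    # Forward pass under autocast: the activations captured by Opacus's hooks
    # are bfloat16, while the backprops stay float32.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = criterion(model(x), y)
    # Backward pass: per-sample gradient computation mixes dtypes, which can
    # raise "RuntimeError: expected scalar type BFloat16 but found Float".
    loss.backward()
    optimizer.step()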
Pitch
During mixed precision training, some activations will be in half precision (e.g., bfloat16) but backprops are in full precision.
This could be fixed by upcasting the tensors to float32 to ensure that their dtypes are consistent.
Here is an example of the adjustment for the linear layer:
diff --git a/opacus/grad_sample/linear.py b/opacus/grad_sample/linear.py
index 1b30f94..e22ab0f 100644
--- a/opacus/grad_sample/linear.py
+++ b/opacus/grad_sample/linear.py
@@ -41,10 +41,10 @@ def compute_linear_grad_sample(
     activations = activations[0]
     ret = {}
     if layer.weight.requires_grad:
-        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
+        gs = torch.einsum("n...i,n...j->nij", backprops.float(), activations.float())
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
+        ret[layer.bias] = torch.einsum("n...k->nk", backprops.float())
     return ret
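As an additional illustration (not from the original issue), the same upcast could likely be applied from user code without patching the Opacus source, using the register_grad_sampler decorator. This assumes that re-registering a sampler for nn.Linear overrides the built-in one.

# Hedged sketch: override the built-in nn.Linear grad sampler from user code.
# Assumes register_grad_sampler replaces the existing registry entry for nn.Linear.
from typing import Dict, List

import torch
import torch.nn as nn
from opacus.grad_sample import register_grad_sampler


@register_grad_sampler(nn.Linear)
def compute_linear_grad_sample_fp32(
    layer: nn.Linear, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    activations = activations[0]
    ret = {}
    if layer.weight.requires_grad:
        # Upcast both operands so the einsum sees a single dtype.
        ret[layer.weight] = torch.einsum(
            "n...i,n...j->nij", backprops.float(), activations.float()
        )
    if layer.bias is not None and layer.bias.requires_grad:
        ret[layer.bias] = torch.einsum("n...k->nk", backprops.float())
    return ret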
Alternatives
N/A
Additional context
This is the workaround our team applied in our project (here), but it would be great if this feature could be officially supported by Opacus in the near future.
Thanks @shuangwu5 for contributing to Opacus. Yeah, mixed-precision training is on our radar, but right now we do not have a clear plan for when we will support it. For official support, we need to investigate and solve the potential stability issues introduced by mixed-precision computation.