Fast Gradient Clipping and Ghost Clipping by EnayatUllah · Pull Request #656 · pytorch/opacus · GitHub

Fast Gradient Clipping and Ghost Clipping #656


Closed
wants to merge 1 commit into from
12 changes: 11 additions & 1 deletion opacus/grad_sample/__init__.py
@@ -18,22 +18,32 @@
from .dp_rnn import compute_rnn_linear_grad_sample # noqa
from .embedding import compute_embedding_grad_sample # noqa
from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample
from .grad_sample_module_fast_gradient_clipping import ( # noqa
GradSampleModuleFastGradientClipping,
)
from .group_norm import compute_group_norm_grad_sample # noqa
from .gsm_base import AbstractGradSampleModule
from .gsm_exp_weights import GradSampleModuleExpandedWeights
from .gsm_no_op import GradSampleModuleNoOp
from .instance_norm import compute_instance_norm_grad_sample # noqa
from .layer_norm import compute_layer_norm_grad_sample # noqa
from .linear import compute_linear_grad_sample # noqa
from .utils import get_gsm_class, register_grad_sampler, wrap_model
from .utils import (
get_gsm_class,
register_grad_sampler,
register_norm_sampler,
wrap_model,
)


__all__ = [
"GradSampleModule",
"GradSampleModuleFastGradientClipping",
"GradSampleModuleExpandedWeights",
"GradSampleModuleNoOp",
"AbstractGradSampleModule",
"register_grad_sampler",
"register_norm_sampler",
"create_or_accumulate_grad_sample",
"wrap_model",
"get_gsm_class",
5 changes: 2 additions & 3 deletions opacus/grad_sample/grad_sample_module.py
@@ -34,6 +34,7 @@


logger = logging.getLogger(__name__)
logger.disabled = True


def create_or_accumulate_grad_sample(
@@ -465,10 +466,8 @@ def validate(
errors.extend(
[
NotImplementedError(
f"Model contains a trainable layer "
f"Model contains a trainable layer with buffers"
f"that Opacus doesn't currently support({m_name}:{m}). "
f"Please implement and register grad sampler for this layer. "
f"(See opacus.grad_sample.utils.register_grad_sampler)"
)
for m_name, m in trainable_modules(module)
# With functorch, all modules are trainable
222 changes: 222 additions & 0 deletions opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -0,0 +1,222 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import List

import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
from opacus.grad_sample.grad_sample_module import (
GradSampleModule,
create_or_accumulate_grad_sample,
promote_current_grad_sample,
)
from opacus.utils.module_utils import requires_grad, trainable_parameters


logger = logging.getLogger(__name__)
logger.disabled = True


def create_norm_sample(
*, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int
) -> None:
"""
Creates a ``_norm_sample`` attribute in the given parameter

Args:
param: Parameter to which ``_norm_sample`` will be added
grad_sample: Per-sample gradients tensor. Must be of the same
shape as ``param`` with extra batch dimension
max_batch_len: Maximum number of examples in the batch; used to size the
zero-initialized ``_norm_sample`` buffer
"""

if param.requires_grad:
param._norm_sample = torch.zeros(
torch.Size([max_batch_len, 1]),
device=grad_sample.device,
dtype=grad_sample.dtype,
)
param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)


class GradSampleModuleFastGradientClipping(GradSampleModule):
"""
Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping

Computes norms of gradients without gradient instantiation
"""

NORM_SAMPLERS = {}

def __init__(
self,
m: nn.Module,
*,
batch_first=True,
loss_reduction="mean",
strict: bool = True,
force_functorch=False,
max_grad_norm=1,
use_ghost_clipping=True,
):
"""

Args:
m: nn.Module to be wrapped
batch_first: Flag to indicate if the input tensor to the corresponding module
has the first dimension representing the batch. If set to True, dimensions on
input tensor are expected be ``[batch_size, ...]``, otherwise
``[K, batch_size, ...]``
loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values "sum" or "mean"
max_grad_norm: The value at which gradients are to be clipped.
strict: If set to True, the input module will be validated to make sure that
it does not have buffers in all its submodules.
force_functorch: If set to ``True``, will use functorch to compute
all per sample gradients. Otherwise, functorch will be used only
for layers without registered grad sampler methods.
use_ghost_clipping: If set to ``True``, Ghost Clipping
will be used for clipping gradients of supported layers. If ``False``, Fast
Gradient Clipping will be used for all layers.

Raises:
NotImplementedError
If ``strict`` is set to ``True`` and module ``m`` (or any of its
submodules) doesn't have a registered grad sampler function.
"""

super().__init__(
m,
batch_first=batch_first,
loss_reduction=loss_reduction,
)
self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
self.max_grad_norm = max_grad_norm
self.use_ghost_clipping = use_ghost_clipping

def get_coeff(self) -> torch.Tensor:
"""Get per-example gradient scaling factor for clipping."""
norm_sample = self.get_norm_sample()
return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)

def get_norm_sample(self) -> torch.Tensor:
"""Get per-example gradient norms."""
norm_sample = torch.stack(
[param._norm_sample for param in self.trainable_parameters], dim=0
).norm(2, dim=0)
return norm_sample

def capture_activations_hook(
self,
module: nn.Module,
forward_input: List[torch.Tensor],
_forward_output: torch.Tensor,
):
if (
not requires_grad(module)
or not module.training
or not torch.is_grad_enabled()
or not self.hooks_enabled
):
return

if not hasattr(module, "activations"):
module.activations = []
module.activations.append([t.detach() for t in forward_input]) # pyre-ignore

for _, p in trainable_parameters(module):
p._forward_counter += 1
if (
self.use_ghost_clipping
and p._forward_counter > 1
and type(module) in self.NORM_SAMPLERS
):
raise NotImplementedError(
"Parameter tying is not supported with Ghost Clipping"
)

def capture_backprops_hook(
self,
module: nn.Module,
_forward_input: torch.Tensor,
forward_output: torch.Tensor,
loss_reduction: str,
batch_first: bool,
):
"""
Computes norms of per sample gradient given the current backprops and activations
stored by the associated forward hook. Computed per sample gradient norms are
stored in ``norm_sample`` field in each parameter.

Args:
module: nn.Module,
_forward_input: torch.Tensor,
forward_output: torch.Tensor,
loss_reduction: str,
batch_first: bool,
"""
if not self.hooks_enabled:
return

backprops = forward_output[0].detach()
activations, backprops = self.rearrange_grad_samples(
module=module,
backprops=backprops,
loss_reduction=loss_reduction,
batch_first=batch_first,
)

if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS:
norm_sampler_fn = self.NORM_SAMPLERS[type(module)]
norm_samples = norm_sampler_fn(module, activations, backprops)

for param, ns in norm_samples.items():
if param.requires_grad:
param._norm_sample = ns
param._forward_counter -= 1

else:
if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
else:
grad_sampler_fn = ft_compute_per_sample_gradient

grad_samples = grad_sampler_fn(module, activations, backprops)
for param, gs in grad_samples.items():
create_or_accumulate_grad_sample(
param=param, grad_sample=gs, max_batch_len=module.max_batch_len
)
del grad_samples
# Detect end of current batch processing and switch accumulation
# mode from sum to stacking. Used for RNNs and tied parameters
# (See #417 for details)
for _, p in trainable_parameters(module):
p._forward_counter -= 1
if p._forward_counter == 0:
promote_current_grad_sample(p)
create_norm_sample(
param=p,
grad_sample=p.grad_sample,
max_batch_len=module.max_batch_len,
)
del p.grad_sample

if len(module.activations) == 0:
if hasattr(module, "max_batch_len"):
del module.max_batch_len
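For orientation, the wrapper above is meant to drive a two-pass clipping loop: a first backward pass only populates the per-sample gradient norms through the hooks, get_coeff() turns them into per-example clipping factors min(max_grad_norm / (||g_i|| + 1e-6), 1), and a second backward pass on the reweighted loss accumulates the clipped gradient sum in p.grad without ever materializing per-sample gradients. A minimal sketch of that loop, assuming a reduction="none" criterion and a retained graph for the double backward (this usage code is not part of the diff):

import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping

# Wrap a toy model; the base class installs the forward/backward hooks.
model = GradSampleModuleFastGradientClipping(
    nn.Linear(16, 4), max_grad_norm=1.0, use_ghost_clipping=True
)
criterion = nn.CrossEntropyLoss(reduction="none")  # keep per-example losses
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))

per_sample_loss = criterion(model(x), y)

# Pass 1: populate param._norm_sample for every trainable parameter.
per_sample_loss.mean().backward(retain_graph=True)
coeff = model.get_coeff()  # shape (8,), one clipping factor per example

# Pass 2: reweight each example's loss by its factor; with hooks disabled this
# simply accumulates the clipped gradient sum in p.grad.
model.zero_grad()
model.disable_hooks()
(per_sample_loss * coeff).sum().backward()
model.enable_hooks()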
46 changes: 45 additions & 1 deletion opacus/grad_sample/linear.py
@@ -13,13 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List

import torch
import torch.nn as nn
from opt_einsum import contract

from .utils import register_grad_sampler
from .utils import register_grad_sampler, register_norm_sampler


logger = logging.getLogger(__name__)
logging.disabled = False


@register_grad_sampler(nn.Linear)
@@ -42,3 +47,42 @@ def compute_linear_grad_sample(
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = contract("n...k->nk", backprops)
return ret


@register_norm_sampler(nn.Linear)
def compute_linear_norm_sample(
layer: nn.Linear, activations: List[torch.Tensor], backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
"""
Computes per sample gradient norms for ``nn.Linear`` layer

Args:
layer: Layer
activations: Activations
backprops: Backpropagations
"""
activations = activations[0]
ret = {}

if backprops.dim() == 2:
if layer.weight.requires_grad:
g = contract("n...i,n...i->n", backprops, backprops)
a = contract("n...j,n...j->n", activations, activations)
ret[layer.weight] = torch.sqrt((g * a).flatten())
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = torch.sqrt(
contract("n...i,n...i->n", backprops, backprops).flatten()
)
elif backprops.dim() == 3:
if layer.weight.requires_grad:

ggT = contract("nik,njk->nij", backprops, backprops) # batchwise g g^T
aaT = contract("nik,njk->nij", activations, activations) # batchwise a a^T
ga = contract("n...i,n...i->n", ggT, aaT).clamp(min=0)

ret[layer.weight] = torch.sqrt(ga)
if layer.bias is not None and layer.bias.requires_grad:
ggT = contract("nik,njk->nij", backprops, backprops)
gg = contract("n...i,n...i->n", ggT, ggT).clamp(min=0)
ret[layer.bias] = torch.sqrt(gg)
return ret
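The 3-D branch of compute_linear_norm_sample uses the ghost-clipping identity: the squared Frobenius norm of the per-sample weight gradient sum_t g_t a_t^T equals the Frobenius inner product of the Gram matrices a a^T and g g^T, so the norm is formed from two (batch, T, T) matrices instead of the full (batch, out, in) gradient. A small self-contained check of that identity on random data (not part of the diff; the einsum strings mirror the contractions above):

import torch

torch.manual_seed(0)
n, t, d_in, d_out = 4, 3, 5, 2
activations = torch.randn(n, t, d_in)  # layer inputs a
backprops = torch.randn(n, t, d_out)   # gradients w.r.t. layer outputs g

# Explicit per-sample weight gradients: sum_t g_t a_t^T, shape (n, d_out, d_in).
grad_sample = torch.einsum("nti,ntj->nij", backprops, activations)
explicit_norm = grad_sample.reshape(n, -1).norm(2, dim=-1)

# Ghost norm: Frobenius inner product of the two batchwise Gram matrices.
aaT = torch.einsum("nik,njk->nij", activations, activations)  # (n, t, t)
ggT = torch.einsum("nik,njk->nij", backprops, backprops)      # (n, t, t)
ghost_norm = torch.einsum("nij,nij->n", ggT, aaT).clamp(min=0).sqrt()

assert torch.allclose(explicit_norm, ghost_norm, atol=1e-5)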
34 changes: 33 additions & 1 deletion opacus/grad_sample/utils.py
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
# !/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +18,9 @@
import torch.nn as nn

from .grad_sample_module import GradSampleModule
from .grad_sample_module_fast_gradient_clipping import (
GradSampleModuleFastGradientClipping,
)
from .gsm_base import AbstractGradSampleModule
from .gsm_exp_weights import GradSampleModuleExpandedWeights
from .gsm_no_op import GradSampleModuleNoOp
@@ -46,6 +49,33 @@ def decorator(f):
)
for target_class in target_classes:
GradSampleModule.GRAD_SAMPLERS[target_class] = f
GradSampleModuleFastGradientClipping.GRAD_SAMPLERS[target_class] = f
return f

return decorator


def register_norm_sampler(
target_class_or_classes: Union[Type[nn.Module], Sequence[Type[nn.Module]]]
):
"""
Registers the decorated function as the ``norm_sampler`` of ``target_class_or_classes``, which is
the function that will be invoked every time you want to compute a per-sample gradient norm
of ``target_class_or_classes``. The signature of every norm_sampler is always the same:

>>> @register_norm_sampler(MyCustomModel)
... def compute_grad_norm_sample(module, activations, backprops):
... pass
"""

def decorator(f):
target_classes = (
target_class_or_classes
if isinstance(target_class_or_classes, Sequence)
else [target_class_or_classes]
)
for target_class in target_classes:
GradSampleModuleFastGradientClipping.NORM_SAMPLERS[target_class] = f
return f

return decorator
@@ -70,6 +100,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
return GradSampleModule
elif grad_sample_mode == "ew":
return GradSampleModuleExpandedWeights
elif grad_sample_mode == "ghost":
return GradSampleModuleFastGradientClipping
elif grad_sample_mode == "no_op":
return GradSampleModuleNoOp
else:
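Putting the registry changes together, a minimal sketch of selecting the new "ghost" mode (not part of the diff; it assumes wrap_model forwards extra keyword arguments to the class returned by get_gsm_class):

import torch.nn as nn
from opacus.grad_sample import (
    GradSampleModuleFastGradientClipping,
    get_gsm_class,
    wrap_model,
)

# The new "ghost" mode maps to the fast gradient / ghost clipping wrapper.
assert get_gsm_class("ghost") is GradSampleModuleFastGradientClipping

# Assumed pass-through of kwargs to GradSampleModuleFastGradientClipping.__init__.
model = wrap_model(
    nn.Linear(16, 4),
    grad_sample_mode="ghost",
    max_grad_norm=1.0,
    use_ghost_clipping=True,
)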