diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
index c86f28c6..8e23b9b3 100644
--- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
+++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -47,12 +47,21 @@ def create_norm_sample(
     """

     if param.requires_grad:
-        param._norm_sample = torch.zeros(
-            torch.Size([max_batch_len, 1]),
-            device=grad_sample.device,
-            dtype=grad_sample.dtype,
-        )
-        param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)
+        if (
+            max_batch_len == 0
+        ):  # To handle the case of empty batch that may arise from Poisson sampling
+            param._norm_sample = torch.tensor(
+                [], device=grad_sample.device, dtype=grad_sample.dtype
+            )
+        else:
+            param._norm_sample = torch.zeros(
+                torch.Size([max_batch_len, 1]),
+                device=grad_sample.device,
+                dtype=grad_sample.dtype,
+            )
+            param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(
+                2, dim=-1
+            )


 class GradSampleModuleFastGradientClipping(GradSampleModule):
@@ -110,7 +119,7 @@ def __init__(
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping

-    def get_coeff(self) -> torch.Tensor:
+    def get_clipping_coef(self) -> torch.Tensor:
         """Get per-example gradient scaling factor for clipping."""
         norm_sample = self.get_norm_sample()
         return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
@@ -175,6 +184,7 @@ def capture_backprops_hook(
             return

         backprops = forward_output[0].detach()
+
         activations, backprops = self.rearrange_grad_samples(
             module=module,
             backprops=backprops,
@@ -216,7 +226,6 @@ def capture_backprops_hook(
                         max_batch_len=module.max_batch_len,
                     )
                     del p.grad_sample
-
         if len(module.activations) == 0:
             if hasattr(module, "max_batch_len"):
                 del module.max_batch_len
diff --git a/opacus/optimizers/__init__.py b/opacus/optimizers/__init__.py
index 5867e127..88f79a8d 100644
--- a/opacus/optimizers/__init__.py
+++ b/opacus/optimizers/__init__.py
@@ -39,12 +39,17 @@


 def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None):
-    if clipping == "flat" and distributed is False:
+    if grad_sample_mode == "ghost":
+        if clipping == "flat" and distributed is False:
+            return DPOptimizerFastGradientClipping
+        elif clipping == "flat" and distributed is True:
+            return DistributedDPOptimizerFastGradientClipping
+        else:
+            raise ValueError(
+                f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
+            )
+    elif clipping == "flat" and distributed is False:
         return DPOptimizer
-    elif clipping == "ghost" and distributed is False:
-        return DPOptimizerFastGradientClipping
-    elif clipping == "ghost" and distributed is True:
-        return DistributedDPOptimizerFastGradientClipping
     elif clipping == "flat" and distributed is True:
         return DistributedDPOptimizer
     elif clipping == "per_layer" and distributed is False:
diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index e0789a0e..1af891c4 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -30,6 +30,7 @@
 )
 from opacus.optimizers import DPOptimizer, get_optimizer_class
 from opacus.schedulers import _GradClipScheduler, _NoiseScheduler
+from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping
 from opacus.validators.module_validator import ModuleValidator
 from torch import nn, optim
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -277,6 +278,7 @@ def make_private(
         *,
         module: nn.Module,
         optimizer: optim.Optimizer,
+        criterion=nn.CrossEntropyLoss(),  # Added default for backward compatibility
         data_loader: DataLoader,
         noise_multiplier: float,
         max_grad_norm: Union[float, List[float]],
@@ -400,6 +402,11 @@ def make_private(
         optimizer.attach_step_hook(
             self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
         )
+        if grad_sample_mode == "ghost":
+            criterion = DPLossFastGradientClipping(
+                module, optimizer, criterion, loss_reduction
+            )
+            return module, optimizer, criterion, data_loader

         return module, optimizer, data_loader

@@ -408,6 +415,7 @@ def make_private_with_epsilon(
         *,
         module: nn.Module,
         optimizer: optim.Optimizer,
+        criterion=nn.CrossEntropyLoss(),  # Added default for backward compatibility
         data_loader: DataLoader,
         target_epsilon: float,
         target_delta: float,
@@ -487,6 +495,7 @@ def make_private_with_epsilon(
             module=module,
             optimizer=optimizer,
             data_loader=data_loader,
+            criterion=criterion,
             noise_multiplier=get_noise_multiplier(
                 target_epsilon=target_epsilon,
                 target_delta=target_delta,
diff --git a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
index 394db742..8029c8f4 100644
--- a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
+++ b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
@@ -22,7 +22,7 @@
 from hypothesis import given, settings
 from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
 from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
-from opacus.utils.fast_gradient_clipping_utils import double_backward
+from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping
 from opacus.utils.per_sample_gradients_utils import clone_module
 from torch.utils.data import DataLoader, Dataset

@@ -146,7 +146,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data))
+        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
         loss_normal.backward()
         all_norms_normal = torch.stack(
             [
@@ -165,7 +165,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         first_loss.backward(retain_graph=True)

         optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_coeff()
+        coeff = self.grad_sample_module.get_clipping_coef()
         second_loss_per_sample = coeff * first_loss_per_sample
         second_loss = torch.sum(second_loss_per_sample)
         self.grad_sample_module.disable_hooks()
@@ -190,7 +190,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
     @settings(deadline=1000000)
     def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         """
-        Tests if gradients are the same between standard (opacus) and fast gradient clipping, using double_backward function"
+        Tests if gradients are the same between standard (opacus) and fast gradient clipping
         """

         noise_multiplier = 0.0
@@ -200,7 +200,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         self.dim = dim
         self.setUp_data_sequantial(self.size, self.length, self.dim)
         max_grad_norm = 1.0
-        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        self.criterion = torch.nn.CrossEntropyLoss()

         sample_module = SampleModule()
         self.model_normal = GradSampleModule(clone_module(sample_module))
@@ -226,10 +226,14 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
             expected_batch_size=batch_size,
         )

+        criterion_gc = DPLossFastGradientClipping(
+            self.grad_sample_module, optimizer_gc, self.criterion
+        )
+
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data))
+        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
         loss_normal.backward()
         optimizer_normal.step()

@@ -240,8 +244,9 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):

         output_gc = self.grad_sample_module(input_data)

-        first_loss_per_sample = self.criterion(output_gc, target_data)
-        double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
+        # double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)

         all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
         flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
diff --git a/opacus/tests/multigpu_gradcheck.py b/opacus/tests/multigpu_gradcheck.py
index f63b15de..6242d8e1 100644
--- a/opacus/tests/multigpu_gradcheck.py
+++ b/opacus/tests/multigpu_gradcheck.py
@@ -101,7 +101,7 @@ def run_ghost_clipping_test(
     loss_per_sample = loss_fn(outputs, y)
     torch.mean(loss_per_sample).backward(retain_graph=True)
     optimizer.zero_grad()
-    rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+    rescaled_loss_per_sample = ddp_model.get_clipping_coef() * loss_per_sample
     rescaled_loss = torch.sum(rescaled_loss_per_sample)
     ddp_model.disable_hooks()
     rescaled_loss.backward()
diff --git a/opacus/utils/fast_gradient_clipping_utils.py b/opacus/utils/fast_gradient_clipping_utils.py
index 367f2495..290341e1 100644
--- a/opacus/utils/fast_gradient_clipping_utils.py
+++ b/opacus/utils/fast_gradient_clipping_utils.py
@@ -20,28 +20,99 @@
 from opacus.optimizers import DPOptimizerFastGradientClipping


-def double_backward(
-    module: GradSampleModuleFastGradientClipping,
-    optimizer: DPOptimizerFastGradientClipping,
-    loss_per_sample: torch.Tensor,
-) -> None:
+class DPTensorFastGradientClipping:
     """
-    Packages the training loop for Fast Gradient and Ghost Clipping. It does the two backward passes, as well as the loss rescaling and hook operations in between.
-    This function also works with DistributedDPOptimizer.
+    Packages the training loop for Fast Gradient and Ghost Clipping into loss.backward().
+    """
+
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        loss_per_sample: torch.Tensor,
+        loss_reduction: str = "mean",
+    ):
+        """
+
+        Args:
+            module: the module to train
+            optimizer: the optimizer used to train the module
+            loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+
+        """
+
+        self.module = module
+        self.optimizer = optimizer
+        self.loss_per_sample = loss_per_sample
+        self.loss_reduction = loss_reduction
+
+    def item(self):
+        if self.loss_reduction == "mean":
+            return torch.mean(self.loss_per_sample).detach().item()
+        elif self.loss_reduction == "sum":
+            return torch.sum(self.loss_per_sample).detach().item()

-    Args:
-        module: The DP gradient sample module to train
-        optimizer: The DP optimizer used to train the module
-        loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+    def backward(self):
+        """
+        Repurposes loss.backward() to perform two backward passes, as well as the loss rescaling and hook operations in between.
+        """

-    Returns:
-        None
+        if self.loss_reduction == "mean":
+            reduced_loss = torch.mean(self.loss_per_sample, dim=0)
+        elif self.loss_reduction == "sum":
+            reduced_loss = torch.sum(self.loss_per_sample, dim=0)
+        else:
+            raise ValueError(
+                f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
+            )
+        reduced_loss.backward(retain_graph=True)
+        self.optimizer.zero_grad()
+        coeff = self.module.get_clipping_coef()
+        second_loss_per_sample = coeff * self.loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        self.module.disable_hooks()
+        second_loss.backward()
+        self.module.enable_hooks()
+
+
+class DPLossFastGradientClipping:
+    """
+    Wrapper on the loss function to be used with Fast Gradient and Ghost Clipping. It computes the per-sample loss and wraps it in DPTensorFastGradientClipping.
     """

-    torch.mean(loss_per_sample).backward(retain_graph=True)
-    optimizer.zero_grad()
-    rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
-    rescaled_loss = torch.sum(rescaled_loss_per_sample)
-    module.disable_hooks()
-    rescaled_loss.backward()
-    module.enable_hooks()
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        criterion,
+        loss_reduction: str = "mean",
+    ):
+        assert loss_reduction in [
+            "mean",
+            "sum",
+        ], "loss_reduction should be either 'mean' or 'sum'"
+        assert (
+            loss_reduction
+            == criterion.reduction
+            == module.loss_reduction
+            == optimizer.loss_reduction
+        ), "loss_reduction should be the same across GradSampleModule, Optimizer, and Criterion"
+
+        self.optimizer = optimizer
+        self.module = module
+        self.criterion = criterion
+        self.loss_reduction = loss_reduction
+        self.criterion.reduction = "none"
+
+    def __call__(self, input, target) -> DPTensorFastGradientClipping:
+        """
+        Redefines the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping
+        """
+
+        loss_per_sample = self.criterion(
+            input,
+            target,
+        )
+        return DPTensorFastGradientClipping(
+            self.module, self.optimizer, loss_per_sample, self.loss_reduction
+        )
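Example usage of the criterion-based ghost clipping API introduced above. This is a minimal sketch: the model, dataset, and hyperparameter values are illustrative placeholders, not part of the patch; only the make_private(..., criterion=..., grad_sample_mode="ghost") call and the loss.backward() contract come from the diff.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Toy model and data, used only to exercise the API.
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
data_loader = DataLoader(dataset, batch_size=8)
criterion = nn.CrossEntropyLoss()  # reduction ("mean" by default) must match module and optimizer

privacy_engine = PrivacyEngine()
# With grad_sample_mode="ghost", make_private now also returns the criterion,
# wrapped in DPLossFastGradientClipping.
model, optimizer, criterion, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    criterion=criterion,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

for x, y in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)  # returns a DPTensorFastGradientClipping
    loss.backward()  # runs both backward passes and the loss rescaling in between
    optimizer.step()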