diff --git a/research/disk_optimizer/KFprivacy_engine.py b/research/disk_optimizer/KFprivacy_engine.py
index c890c224..d70e8267 100644
--- a/research/disk_optimizer/KFprivacy_engine.py
+++ b/research/disk_optimizer/KFprivacy_engine.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Union
 
 from opacus.optimizers import DPOptimizer
diff --git a/research/disk_optimizer/ReadMe.md b/research/disk_optimizer/ReadMe.md
index 7c25bda1..8cf2bfcf 100644
--- a/research/disk_optimizer/ReadMe.md
+++ b/research/disk_optimizer/ReadMe.md
@@ -3,7 +3,9 @@
 ## Introduction
 
 This part of the code introduces a new component to the optimizer named DiSK. The code uses a simplified Kalman filter to improve the privatized gradient estimate. Specifically, the privatized minibatch gradient is replaced with:
-$$\mathbb{g_{t+\frac{1}{2}}} = \frac{1}{B}\sum_{\xi \in \mathcal{B}_t} \mathrm{clip}_C\left(\frac{1-\kappa}{\kappa\gamma}\nabla f(x_t + \gamma(x_t-x_{t-1});\xi) + \Big(1- \frac{1-\kappa}{\kappa\gamma}\Big)\nabla f(x_t;\xi)\right) + w_t$$
+
+$$\mathbb{g}_{t+\frac{1}{2}}(\xi) = \frac{1-\kappa}{\kappa\gamma}\nabla f(x_t + \gamma(x_t-x_{t-1});\xi) + \Big(1- \frac{1-\kappa}{\kappa\gamma}\Big)\nabla f(x_t;\xi)$$
+$$\mathbb{g}_{t+\frac{1}{2}} = \frac{1}{B}\sum_{\xi \in \mathcal{B}_t} \mathrm{clip}_C\left(\mathbb{g}_{t+\frac{1}{2}}(\xi)\right) + w_t$$
 $$g_{t}= (1-\kappa)g_{t-1} + \kappa g_{t+\frac{1}{2}}$$
 
 A detailed description of the algorithm can be found [here](https://arxiv.org/abs/2410.03883).
diff --git a/research/disk_optimizer/optimizers/KFadaclipoptimizer.py b/research/disk_optimizer/optimizers/KFadaclipoptimizer.py
index b8721903..bc30eb07 100644
--- a/research/disk_optimizer/optimizers/KFadaclipoptimizer.py
+++ b/research/disk_optimizer/optimizers/KFadaclipoptimizer.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 import logging
@@ -33,6 +47,7 @@ def __init__(
         secure_mode: bool = False,
         kappa: float = 0.7,
         gamma: float = 0.5,
+        **kwargs,
     ):
         if gamma == 0 or abs(gamma - (1 - kappa) / kappa) < 1e-3:
             gamma = (1 - kappa) / kappa
@@ -58,6 +73,7 @@ def __init__(
             max_clipbound=max_clipbound,
             min_clipbound=min_clipbound,
             unclipped_num_std=unclipped_num_std,
+            **kwargs,
         )
         self.kappa = kappa
         self.gamma = gamma
@@ -79,7 +95,7 @@ def step(self, closure=required) -> Optional[float]:
                     first_step = True
                     state["kf_d_t"] = torch.zeros_like(p.data).to(p.data)
                     state["kf_m_t"] = grad.clone().to(p.data)
-                state["kf_m_t"].lerp_(grad, weight=self.kappa)
+                state["kf_m_t"].lerp_(grad, weight=1 - self.kappa)
                 p.grad = state["kf_m_t"].clone().to(p.data)
                 state["kf_d_t"] = -p.data.clone().to(p.data)
                 if first_step:
diff --git a/research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py b/research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py
index 662e3612..fe52b5f7 100644
--- a/research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py
+++ b/research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 from functools import partial
@@ -29,6 +43,7 @@ def __init__(
         secure_mode: bool = False,
         kappa: float = 0.7,
         gamma: float = 0.5,
+        **kwargs,
     ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -43,6 +58,7 @@ def __init__(
             secure_mode=secure_mode,
             kappa=kappa,
             gamma=gamma,
+            **kwargs,
         )
 
 
@@ -64,6 +80,7 @@ def __init__(
         secure_mode: bool = False,
         kappa: float = 0.7,
         gamma: float = 0.5,
+        **kwargs,
     ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -79,6 +96,7 @@ def __init__(
             secure_mode=secure_mode,
             kappa=kappa,
             gamma=gamma,
+            **kwargs,
         )
 
         self._register_hooks()
diff --git a/research/disk_optimizer/optimizers/KFddpoptimizer.py b/research/disk_optimizer/optimizers/KFddpoptimizer.py
index 00c3bbcd..4ff58c2d 100644
--- a/research/disk_optimizer/optimizers/KFddpoptimizer.py
+++ b/research/disk_optimizer/optimizers/KFddpoptimizer.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 import logging
@@ -32,6 +46,7 @@ def __init__(
         secure_mode: bool = False,
         kappa=0.7,
         gamma=0.5,
+        **kwargs,
     ):
         super().__init__(
             optimizer,
@@ -43,6 +58,7 @@ def __init__(
             secure_mode=secure_mode,
             kappa=kappa,
             gamma=gamma,
+            **kwargs,
         )
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -80,7 +96,7 @@ def step(self, closure=required) -> Optional[float]:
                     first_step = True
                     state["kf_d_t"] = torch.zeros_like(p.data).to(p.data)
                     state["kf_m_t"] = grad.clone().to(p.data)
-                state["kf_m_t"].lerp_(grad, weight=self.kappa)
+                state["kf_m_t"].lerp_(grad, weight=1 - self.kappa)
                 p.grad = state["kf_m_t"].clone().to(p.data)
                 state["kf_d_t"] = -p.data.clone().to(p.data)
                 if first_step:
diff --git a/research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py b/research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py
index f3d61b06..771d4ec0 100644
--- a/research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py
+++ b/research/disk_optimizer/optimizers/KFddpoptimizer_fast_gradient_clipping.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 import logging
@@ -32,6 +46,7 @@ def __init__(
         secure_mode: bool = False,
         kappa=0.7,
         gamma=0.5,
+        **kwargs,
     ):
         super().__init__(
             optimizer,
@@ -43,6 +58,7 @@ def __init__(
             secure_mode=secure_mode,
             kappa=kappa,
             gamma=gamma,
+            **kwargs,
         )
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -80,7 +96,7 @@ def step(self, closure=required) -> Optional[float]:
                     first_step = True
                     state["kf_d_t"] = torch.zeros_like(p.data).to(p.data)
                     state["kf_m_t"] = grad.clone().to(p.data)
-                state["kf_m_t"].lerp_(grad, weight=self.kappa)
+                state["kf_m_t"].lerp_(grad, weight=1 - self.kappa)
                 p.grad = state["kf_m_t"].clone().to(p.data)
                 state["kf_d_t"] = -p.data.clone().to(p.data)
                 if first_step:
diff --git a/research/disk_optimizer/optimizers/KFoptimizer.py b/research/disk_optimizer/optimizers/KFoptimizer.py
index 30597653..3cd2622d 100644
--- a/research/disk_optimizer/optimizers/KFoptimizer.py
+++ b/research/disk_optimizer/optimizers/KFoptimizer.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 import logging
@@ -27,6 +41,7 @@ def __init__(
         secure_mode: bool = False,
         kappa=0.7,
         gamma=0.5,
+        **kwargs,
     ):
         if gamma == 0 or abs(gamma - (1 - kappa) / kappa) < 1e-3:
             gamma = (1 - kappa) / kappa
@@ -47,6 +62,7 @@ def __init__(
             loss_reduction=loss_reduction,
             generator=generator,
             secure_mode=secure_mode,
+            **kwargs,
         )
         self.kappa = kappa
         self.gamma = gamma
@@ -131,7 +147,7 @@ def step(self, closure=required) -> Optional[float]:
                     first_step = True
                     state["kf_d_t"] = torch.zeros_like(p.data).to(p.data)
                     state["kf_m_t"] = grad.clone().to(p.data)
-                state["kf_m_t"].lerp_(grad, weight=self.kappa)
+                state["kf_m_t"].lerp_(grad, weight=1 - self.kappa)
                 p.grad = state["kf_m_t"].clone().to(p.data)
                 state["kf_d_t"] = -p.data.clone().to(p.data)
                 if first_step:
diff --git a/research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py b/research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py
index 3b105418..a7bcdb3b 100644
--- a/research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py
+++ b/research/disk_optimizer/optimizers/KFoptimizer_fast_gradient_clipping.py
@@ -1,9 +1,26 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 import logging
 from typing import Optional
 
 import torch
+from opacus.optimizers.optimizer_fast_gradient_clipping import (
+    DPOptimizerFastGradientClipping,
+)
 from torch.optim import Optimizer
 from torch.optim.optimizer import required
 
@@ -13,8 +30,7 @@
 logger = logging.getLogger(__name__)
 logger.disabled = True
 
-
-class KF_DPOptimizerFastGradientClipping(KF_DPOptimizer):
+class KF_DPOptimizerFastGradientClipping(DPOptimizerFastGradientClipping, KF_DPOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
@@ -27,8 +43,9 @@ def __init__(
         secure_mode: bool = False,
         kappa=0.7,
         gamma=0.5,
+        **kwargs,
     ):
-        super().__init__(
+        super(KF_DPOptimizer).__init__(
             optimizer,
             noise_multiplier=noise_multiplier,
             max_grad_norm=max_grad_norm,
@@ -38,6 +55,7 @@ def __init__(
             secure_mode=secure_mode,
             kappa=kappa,
             gamma=gamma,
+            **kwargs,
         )
 
     def _compute_one_closure(self, closure=required):
diff --git a/research/disk_optimizer/optimizers/KFperlayeroptimizer.py b/research/disk_optimizer/optimizers/KFperlayeroptimizer.py
index fa48a246..5f41c8f7 100644
--- a/research/disk_optimizer/optimizers/KFperlayeroptimizer.py
+++ b/research/disk_optimizer/optimizers/KFperlayeroptimizer.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from __future__ import annotations
 
 import logging
@@ -28,6 +42,7 @@ def __init__(
         secure_mode: bool = False,
         kappa=0.7,
         gamma=0.5,
+        **kwargs,
     ):
         assert len(max_grad_norm) == len(params(optimizer))
         self.max_grad_norms = max_grad_norm
@@ -42,6 +57,7 @@ def __init__(
             secure_mode=secure_mode,
             kappa=kappa,
             gamma=gamma,
+            **kwargs,
         )
 
     def clip_and_accumulate(self):
diff --git a/research/disk_optimizer/optimizers/__init__.py b/research/disk_optimizer/optimizers/__init__.py
index cbf4f04f..4af052d8 100644
--- a/research/disk_optimizer/optimizers/__init__.py
+++ b/research/disk_optimizer/optimizers/__init__.py
@@ -1,3 +1,17 @@
+# Copyright (c) Xinwei Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from opacus.optimizers import (
     AdaClipDPOptimizer,
     DistributedDPOptimizer,
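Note (not part of the patch): the ReadMe hunk above states the DiSK update as two equations. Below is a minimal Python sketch of that update under stated assumptions; `grad_fn`, `clip_and_noise`, and `disk_step` are hypothetical helpers invented for illustration, and per-sample clipping is collapsed into a single post-average step, so this is not the Opacus implementation.

# Illustrative sketch of the DiSK update from the ReadMe formulas above.
# Assumption: grad_fn(params) returns the average minibatch gradient at `params`.
import torch


def clip_and_noise(g: torch.Tensor, max_norm: float, noise_std: float) -> torch.Tensor:
    # Clip to norm C, then add Gaussian noise w_t (done per sample in the real code).
    g = g * torch.clamp(max_norm / (g.norm() + 1e-12), max=1.0)
    return g + noise_std * torch.randn_like(g)


def disk_step(x_t, x_prev, g_prev, grad_fn, kappa=0.7, gamma=0.5, max_norm=1.0, noise_std=1.0):
    c = (1 - kappa) / (kappa * gamma)
    # g_{t+1/2}: combine the gradient at the extrapolated point with the current gradient.
    g_half = c * grad_fn(x_t + gamma * (x_t - x_prev)) + (1 - c) * grad_fn(x_t)
    g_half = clip_and_noise(g_half, max_norm, noise_std)
    # Kalman-style filtering: g_t = (1 - kappa) * g_{t-1} + kappa * g_{t+1/2}.
    return (1 - kappa) * g_prev + kappa * g_half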