From 7393ae47fdf824ad65d5035461dc391c0f4cc932 Mon Sep 17 00:00:00 2001 From: Igor Shilov Date: Fri, 28 Oct 2022 07:52:26 -0700 Subject: [PATCH] Functorch gradients: investigation and fix (#510) Summary: *The investigation part for this PR was done by alexandresablayrolles, thanks for figuring out the reason the tests were failing* ## Background Current implementation of functorch-based per sample gradients fails on modules which have both trainable non-recursive parameters and standard submodules, e.g. below ``` class LinearWithExtraParam(nn.Module): def __init__(self, in_features: int, out_features: int, hidden_dim: int = 8): super().__init__() self.fc = nn.Linear(in_features, hidden_dim) self.extra_param = nn.Parameter(torch.randn(hidden_dim, out_features)) def forward(self, x): x = self.fc(x) x = x.matmul(self.extra_param) return x ``` The reason is - functorch hook actually computes gradients for recursive submodules too. The problem is, normal hooks are also attached to these submodules. GradSampleModule then sees two grad_sample tensors, thinks it needs to accumulate and adds them up together ## Solution(s) There are essentially two ways we can fix this: either make functorch compute per sample gradients for non-recursive parameters only or don't attach normal hooks to submodules where the parent module is handled by functorch. This diff implements the latter option (reasoning below), for demo purposes the former option can be seen in https://github.com/pytorch/opacus/issues/531 For the pure code perspective the former option (let's call it "non-recursive functorch") is more appealing to me. It better fits the existing paradigm and matches normal hooks behaviour - all of the existing code only deals with the immediate non-recursive parameters. However, it doesn't make much sense from the efficiency perspective. "non-recursive functorch" would do all the work to compute per-sample gradients for its submodules, only for them to be filtered out at the very last stage. Alternative option (a.k.a. "functorch for subtrees") does involve a bit more convoluted This has a noticeable effect on performance. Below is the results of MNIST benchmarks with different configurations. I've tested this with different configurations, because at the end of the day, the impact on performance depends on how deep are subtrees * Standard model- our model from MNIST example, standard layers only (2 conv + 2 linear). No overhead expected, functorch doesn't kick in * Mid-level model - leaf nodes (two linear layers) have one extra param and are computed with functorch. Overhead: 2x Linear hook * Extreme model - root model have one extra param and needs to be handled by functorch. Overhead: 2x linear hook + 2x conv hook | Mode | non-recursive functorch | functorch for subtrees | |:-----------------------:|:------------------------:|:-----------------------:| | Standard model (CPU) | 138s | 136s | | Standard model (GPU) | 149s | 150s | | Mid-level model (CPU) | 157s | 150s | | Mid-level model (GPU) | 100s | 97s | | Extreme model (CPU) | 207s | 172s | | Extreme model (GPU) | 101s | 94s | Pull Request resolved: https://github.com/pytorch/opacus/pull/510 Reviewed By: alexandresablayrolles Differential Revision: D39579487 Pulled By: ffuuugor fbshipit-source-id: 1b089bd04ab110174a1f2ebb371380eb2ce76054 --- opacus/grad_sample/functorch.py | 2 +- opacus/grad_sample/grad_sample_module.py | 20 ++++- opacus/tests/privacy_engine_test.py | 85 ++++++++++++++++--- .../tests/privacy_engine_validation_test.py | 54 ++---------- opacus/tests/utils.py | 50 +++++++++++ opacus/utils/module_utils.py | 6 +- 6 files changed, 151 insertions(+), 66 deletions(-) create mode 100644 opacus/tests/utils.py diff --git a/opacus/grad_sample/functorch.py b/opacus/grad_sample/functorch.py index 97779506..ade37a1c 100644 --- a/opacus/grad_sample/functorch.py +++ b/opacus/grad_sample/functorch.py @@ -48,7 +48,7 @@ def ft_compute_per_sample_gradient(layer, activations, backprops): activations: the input to the layer backprops: the gradient of the loss w.r.t. outputs of the layer """ - parameters = list(layer.parameters()) + parameters = list(layer.parameters(recurse=True)) if not hasattr(layer, "ft_compute_sample_grad"): prepare_layer(layer) diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py index d2fb0987..3b2a226e 100644 --- a/opacus/grad_sample/grad_sample_module.py +++ b/opacus/grad_sample/grad_sample_module.py @@ -18,7 +18,7 @@ import logging import warnings from functools import partial -from typing import List, Tuple +from typing import Iterable, List, Tuple import torch import torch.nn as nn @@ -26,6 +26,7 @@ from opacus.grad_sample.gsm_base import AbstractGradSampleModule from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear from opacus.utils.module_utils import ( + has_trainable_params, requires_grad, trainable_modules, trainable_parameters, @@ -146,6 +147,21 @@ def __init__( def forward(self, *args, **kwargs): return self._module(*args, **kwargs) + def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]: + if has_trainable_params(module): + yield module + + # Don't recurse if module is handled by functorch + if ( + has_trainable_params(module) + and type(module) not in self.GRAD_SAMPLERS + and type(module) not in [DPRNN, DPLSTM, DPGRU] + ): + return + + for m in module.children(): + yield from self.iterate_submodules(m) + def add_hooks( self, *, @@ -177,7 +193,7 @@ def add_hooks( self._module.autograd_grad_sample_hooks = [] self.autograd_grad_sample_hooks = self._module.autograd_grad_sample_hooks - for _module_name, module in trainable_modules(self._module): + for module in self.iterate_submodules(self._module): # Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear` if type(module) in [DPRNN, DPLSTM, DPGRU]: continue diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py index 90af717a..ad2eb6f3 100644 --- a/opacus/tests/privacy_engine_test.py +++ b/opacus/tests/privacy_engine_test.py @@ -40,6 +40,18 @@ from torchvision import models, transforms from torchvision.datasets import FakeData +from .utils import CustomLinearModule, LinearWithExtraParam + + +def _is_functorch_available(): + try: + # flake8: noqa F401 + import functorch + + return True + except ImportError: + return False + def get_grad_sample_aggregated(tensor: torch.Tensor, loss_type: str = "mean"): if tensor.grad_sample is None: @@ -246,7 +258,7 @@ def _compare_to_vanilla( # vanilla gradient is nearly zero: will match even with clipping continue - atol = 1e-7 if max_steps == 1 else 1e-5 + atol = 1e-7 if max_steps == 1 else 1e-4 self.assertEqual( torch.allclose(vp, pp, atol=atol, rtol=1e-3), expected_match, @@ -265,10 +277,6 @@ def _compare_to_vanilla( do_noise=st.booleans(), use_closure=st.booleans(), max_steps=st.sampled_from([1, 4]), - # do_clip=st.just(False), - # do_noise=st.just(False), - # use_closure=st.just(False), - # max_steps=st.sampled_from([4]), ) @settings(deadline=None) def test_compare_to_vanilla( @@ -799,9 +807,7 @@ def _init_data(self): ) return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False) - def _init_model( - self, private=False, state_dict=None, model=None, **privacy_engine_kwargs - ): + def _init_model(self): return SampleConvNet() @@ -817,9 +823,7 @@ def _init_data(self): ) return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False) - def _init_model( - self, private=False, state_dict=None, model=None, **privacy_engine_kwargs - ): + def _init_model(self): m = SampleConvNet() for p in itertools.chain(m.conv1.parameters(), m.gnorm1.parameters()): p.requires_grad = False @@ -827,6 +831,13 @@ def _init_model( return m +@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version") +class PrivacyEngineConvNetFrozenTestFunctorch(PrivacyEngineConvNetFrozenTest): + def setUp(self): + super().setUp() + self.GRAD_SAMPLE_MODE = "functorch" + + @unittest.skipIf( torch.__version__ < API_CUTOFF_VERSION, "not supported in this torch version" ) @@ -840,6 +851,13 @@ def test_sample_grad_aggregation(self): pass +@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version") +class PrivacyEngineConvNetTestFunctorch(PrivacyEngineConvNetTest): + def setUp(self): + super().setUp() + self.GRAD_SAMPLE_MODE = "functorch" + + class SampleAttnNet(nn.Module): def __init__(self): super().__init__() @@ -919,6 +937,13 @@ def _init_model( return SampleAttnNet() +@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version") +class PrivacyEngineTextTestFunctorch(PrivacyEngineTextTest): + def setUp(self): + super().setUp() + self.GRAD_SAMPLE_MODE = "functorch" + + class SampleTiedWeights(nn.Module): def __init__(self, tie=True): super().__init__() @@ -958,7 +983,39 @@ def _init_data(self): ) return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False) - def _init_model( - self, private=False, state_dict=None, model=None, **privacy_engine_kwargs - ): + def _init_model(self): return SampleTiedWeights(tie=True) + + +@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version") +class PrivacyEngineTiedWeightsTestFunctorch(PrivacyEngineTiedWeightsTest): + def setUp(self): + super().setUp() + self.GRAD_SAMPLE_MODE = "functorch" + + +class ModelWithCustomLinear(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = CustomLinearModule(4, 8) + self.fc2 = LinearWithExtraParam(8, 4) + self.extra_param = nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = x.matmul(self.extra_param) + return x + + +@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version") +class PrivacyEngineCustomLayerTest(BasePrivacyEngineTest, unittest.TestCase): + def _init_data(self): + ds = TensorDataset( + torch.randn(self.DATA_SIZE, 4), + torch.randint(low=0, high=3, size=(self.DATA_SIZE,)), + ) + return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False) + + def _init_model(self): + return ModelWithCustomLinear() diff --git a/opacus/tests/privacy_engine_validation_test.py b/opacus/tests/privacy_engine_validation_test.py index 8548f73f..0ba061d8 100644 --- a/opacus/tests/privacy_engine_validation_test.py +++ b/opacus/tests/privacy_engine_validation_test.py @@ -1,58 +1,16 @@ import unittest import torch -import torch.nn as nn -import torch.nn.functional as F from opacus import PrivacyEngine from opacus.grad_sample.gsm_exp_weights import API_CUTOFF_VERSION from torch.utils.data import DataLoader - -class BasicSupportedModule(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2) - self.gn = nn.GroupNorm(num_groups=2, num_channels=8) - self.fc = nn.Linear(in_features=4, out_features=8) - self.ln = nn.LayerNorm([8, 8]) - - def forward(self, x): - x = self.conv(x) - x = self.gn(x) - x = self.fc(x) - x = self.ln(x) - return x - - -class CustomLinearModule(nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self._weight = nn.Parameter(torch.randn(out_features, in_features)) - self._bias = nn.Parameter(torch.randn(out_features)) - - def forward(self, x): - return F.linear(x, self._weight, self._bias) - - -class MatmulModule(nn.Module): - def __init__(self, input_features, output_features): - super().__init__() - self.weight = nn.Parameter(torch.randn(input_features, output_features)) - - def forward(self, x): - return torch.matmul(x, self.weight) - - -class LinearWithExtraParam(nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.fc = nn.Linear(in_features, out_features) - self.extra_param = nn.Parameter(torch.randn(out_features, 2)) - - def forward(self, x): - x = self.fc(x) - x = x.matmul(self.extra_param) - return x +from .utils import ( + BasicSupportedModule, + CustomLinearModule, + LinearWithExtraParam, + MatmulModule, +) class PrivacyEngineValidationTest(unittest.TestCase): diff --git a/opacus/tests/utils.py b/opacus/tests/utils.py new file mode 100644 index 00000000..36833977 --- /dev/null +++ b/opacus/tests/utils.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicSupportedModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2) + self.gn = nn.GroupNorm(num_groups=2, num_channels=8) + self.fc = nn.Linear(in_features=4, out_features=8) + self.ln = nn.LayerNorm([8, 8]) + + def forward(self, x): + x = self.conv(x) + x = self.gn(x) + x = self.fc(x) + x = self.ln(x) + return x + + +class CustomLinearModule(nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() + self._weight = nn.Parameter(torch.randn(out_features, in_features)) + self._bias = nn.Parameter(torch.randn(out_features)) + + def forward(self, x): + return F.linear(x, self._weight, self._bias) + + +class MatmulModule(nn.Module): + def __init__(self, input_features: int, output_features: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(input_features, output_features)) + + def forward(self, x): + return torch.matmul(x, self.weight) + + +class LinearWithExtraParam(nn.Module): + def __init__(self, in_features: int, out_features: int, hidden_dim: int = 8): + super().__init__() + self.fc = nn.Linear(in_features, hidden_dim) + self.extra_param = nn.Parameter(torch.randn(hidden_dim, out_features)) + + def forward(self, x): + x = self.fc(x) + x = x.matmul(self.extra_param) + return x diff --git a/opacus/utils/module_utils.py b/opacus/utils/module_utils.py index da2f6c9a..28146cef 100644 --- a/opacus/utils/module_utils.py +++ b/opacus/utils/module_utils.py @@ -31,7 +31,11 @@ logger.setLevel(level=logging.INFO) -def parametrized_modules(module: nn.Module) -> Iterable[nn.Module]: +def has_trainable_params(module: nn.Module) -> bool: + return any(p.requires_grad for p in module.parameters(recurse=False)) + + +def parametrized_modules(module: nn.Module) -> Iterable[Tuple[str, nn.Module]]: """ Recursively iterates over all submodules, returning those that have parameters (as opposed to "wrapper modules" that just organize modules).