diff --git a/CHANGELOG.md b/CHANGELOG.md
index 00a906c6..9a8aedb5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,13 @@ changelog does not include internal changes that do not affect the user.
 
 ## [Unreleased]
 
+### Changed
+
+- Non-differentiable aggregators and weightings (UPGrad, DualProj, PCGrad, GradVac, IMTLG,
+  GradDrop, ConFIG, CAGrad, NashMTL) no longer build a computation graph when called on tensors
+  that require gradients. Their forward pass is now wrapped in `torch.no_grad()`, so it is no
+  longer possible to differentiate through them (previously, this raised a `NonDifferentiableError`).
+
 ### Added
 
 - Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py
index 82eb84ca..05818399 100644
--- a/src/torchjd/aggregation/_cagrad.py
+++ b/src/torchjd/aggregation/_cagrad.py
@@ -2,6 +2,7 @@
 
 from torchjd.linalg import PSDMatrix
 
+from ._mixins import _NonDifferentiable
 from ._utils.check_dependencies import check_dependencies_are_installed
 from ._weighting_bases import _GramianWeighting
 
@@ -15,10 +16,10 @@
 from torchjd._linalg import normalize
 
 from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
 
 
-class CAGradWeighting(_GramianWeighting):
+# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
+class CAGradWeighting(_NonDifferentiable, _GramianWeighting):
     """
     :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights
     of :class:`~torchjd.aggregation.CAGrad`.
@@ -92,7 +93,7 @@ def norm_eps(self, value: float) -> None:
         self._norm_eps = value
 
 
-class CAGrad(GramianWeightedAggregator):
+class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
     """
     :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
     `Conflict-Averse Gradient Descent for Multi-task Learning
@@ -113,9 +114,6 @@ class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
     def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
         super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
 
-        # This prevents considering the computed weights as constant w.r.t. the matrix.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
     @property
     def c(self) -> float:
         return self.gramian_weighting.c
diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py
index 98ac0857..54e2c3dc 100644
--- a/src/torchjd/aggregation/_config.py
+++ b/src/torchjd/aggregation/_config.py
@@ -8,12 +8,13 @@
 from torchjd.linalg import Matrix
 
 from ._aggregator_bases import Aggregator
+from ._mixins import _NonDifferentiable
 from ._sum import SumWeighting
-from ._utils.non_differentiable import raise_non_differentiable_error
 from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
 
 
-class ConFIG(Aggregator):
+# Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context.
+class ConFIG(_NonDifferentiable, Aggregator):
     """
     :class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of
     `ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks
@@ -31,9 +32,6 @@ def __init__(self, pref_vector: Tensor | None = None) -> None:
         super().__init__()
         self.pref_vector = pref_vector
 
-        # This prevents computing gradients that can be very wrong.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
     def forward(self, matrix: Matrix, /) -> Tensor:
         weights = self.weighting(matrix)
         units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py
index 5ba3645c..e379f127 100644
--- a/src/torchjd/aggregation/_dualproj.py
+++ b/src/torchjd/aggregation/_dualproj.py
@@ -5,13 +5,14 @@
 
 from ._aggregator_bases import GramianWeightedAggregator
 from ._mean import MeanWeighting
+from ._mixins import _NonDifferentiable
 from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
-from ._utils.non_differentiable import raise_non_differentiable_error
 from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
 from ._weighting_bases import _GramianWeighting
 
 
-class DualProjWeighting(_GramianWeighting):
+# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
+class DualProjWeighting(_NonDifferentiable, _GramianWeighting):
     r"""
     :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights
     of :class:`~torchjd.aggregation.DualProj`.
@@ -77,7 +78,7 @@ def reg_eps(self, value: float) -> None:
         self._reg_eps = value
 
 
-class DualProj(GramianWeightedAggregator):
+class DualProj(_NonDifferentiable, GramianWeightedAggregator):
     r"""
     :class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input
     matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
@@ -109,9 +110,6 @@ def __init__(
             DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
         )
 
-        # This prevents considering the computed weights as constant w.r.t. the matrix.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
     @property
     def pref_vector(self) -> Tensor | None:
         return self.gramian_weighting.pref_vector
diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py
index c3354f57..31590ebf 100644
--- a/src/torchjd/aggregation/_graddrop.py
+++ b/src/torchjd/aggregation/_graddrop.py
@@ -6,14 +6,15 @@
 from torchjd.linalg import Matrix
 
 from ._aggregator_bases import Aggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
+from ._mixins import _NonDifferentiable
 
 
 def _identity(P: Tensor) -> Tensor:
     return P
 
 
-class GradDrop(Aggregator):
+# Non-differentiable: the sign-based random masking is not differentiable.
+class GradDrop(_NonDifferentiable, Aggregator):
     """
     :class:`~torchjd.aggregation.Aggregator` that applies the gradient combination steps from
     GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
@@ -31,9 +32,6 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
         self.f = f
         self.leak = leak
 
-        # This prevents computing gradients that can be very wrong.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
     def forward(self, matrix: Matrix, /) -> Tensor:
         self._check_matrix_has_enough_rows(matrix)
 
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py
index e5fcdd4b..61469407 100644
--- a/src/torchjd/aggregation/_gradvac.py
+++ b/src/torchjd/aggregation/_gradvac.py
@@ -5,15 +5,15 @@
 import torch
 from torch import Tensor
 
-from torchjd.aggregation._mixins import Stateful
+from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
 from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
 from ._weighting_bases import _GramianWeighting
 
 
-class GradVacWeighting(_GramianWeighting, Stateful):
+# Non-differentiable: weights are modified in-place during the gradient correction loop.
+class GradVacWeighting(_NonDifferentiable, Stateful, _GramianWeighting):
     r"""
     :class:`~torchjd.aggregation._mixins.Stateful`
     :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
@@ -128,7 +128,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
         self._state_key = key
 
 
-class GradVac(GramianWeightedAggregator, Stateful):
+class GradVac(_NonDifferentiable, Stateful, GramianWeightedAggregator):
     r"""
     :class:`~torchjd.aggregation._mixins.Stateful`
     :class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
@@ -167,7 +167,6 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
         weighting = GradVacWeighting(beta=beta, eps=eps)
         super().__init__(weighting)
         self._gradvac_weighting = weighting
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
 
     @property
     def beta(self) -> float:
diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py
index 21a7975f..47504e83 100644
--- a/src/torchjd/aggregation/_imtl_g.py
+++ b/src/torchjd/aggregation/_imtl_g.py
@@ -4,11 +4,12 @@
 from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
+from ._mixins import _NonDifferentiable
 from ._weighting_bases import _GramianWeighting
 
 
-class IMTLGWeighting(_GramianWeighting):
+# Non-differentiable: differentiating through pinv(gramian) would give incorrect gradients.
+class IMTLGWeighting(_NonDifferentiable, _GramianWeighting):
     """
     :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights
     of :class:`~torchjd.aggregation.IMTLG`.
@@ -24,7 +25,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
         return weights
 
 
-class IMTLG(GramianWeightedAggregator):
+class IMTLG(_NonDifferentiable, GramianWeightedAggregator):
     """
     :class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in
     `Towards Impartial Multi-task Learning `_.
@@ -36,6 +37,3 @@ class IMTLG(_NonDifferentiable, GramianWeightedAggregator):
 
     def __init__(self) -> None:
         super().__init__(IMTLGWeighting())
-
-        # This prevents computing gradients that can be very wrong.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py
index 8481feab..29bf5592 100644
--- a/src/torchjd/aggregation/_mixins.py
+++ b/src/torchjd/aggregation/_mixins.py
@@ -1,4 +1,8 @@
 from abc import ABC, abstractmethod
+from typing import Any
+
+import torch
+from torch import nn
 
 
 class Stateful(ABC):
@@ -7,3 +11,19 @@ class Stateful(ABC):
     @abstractmethod
     def reset(self) -> None:
         """Resets the internal state."""
+
+
+class _NonDifferentiable(nn.Module):
+    """
+    Mixin making an nn.Module non-differentiable, preventing autograd graph construction by
+    wrapping the call in :func:`torch.no_grad`.
+
+    .. warning::
+        This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance
+        list. Placing it after will silently have no effect, because :meth:`__call__` would be
+        resolved to :meth:`torch.nn.Module.__call__` before reaching this mixin.
+    """
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        with torch.no_grad():
+            return super().__call__(*args, **kwargs)
diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py
index 63271d63..99356fc9 100644
--- a/src/torchjd/aggregation/_nash_mtl.py
+++ b/src/torchjd/aggregation/_nash_mtl.py
@@ -1,7 +1,7 @@
 # Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
 # See NOTICES for the full license text.
 
-from torchjd.aggregation._mixins import Stateful
+from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
 
 from ._utils.check_dependencies import check_dependencies_are_installed
 from ._weighting_bases import _MatrixWeighting
@@ -15,10 +15,10 @@
 from torch import Tensor
 
 from ._aggregator_bases import WeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
 
 
-class _NashMTLWeighting(_MatrixWeighting, Stateful):
+# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
+class _NashMTLWeighting(_NonDifferentiable, Stateful, _MatrixWeighting):
     """
     :class:`~torchjd.aggregation._mixins.Stateful`
     :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
@@ -199,7 +199,7 @@ def reset(self) -> None:
         self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)
 
 
-class NashMTL(WeightedAggregator, Stateful):
+class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator):
     """
     :class:`~torchjd.aggregation._mixins.Stateful`
     :class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
@@ -253,9 +253,6 @@ def __init__(
             ),
         )
 
-        # This prevents considering the computed weights as constant w.r.t. the matrix.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
     @property
     def n_tasks(self) -> int:
         return self.weighting.n_tasks
diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py
index a796179b..ffce10d3 100644
--- a/src/torchjd/aggregation/_pcgrad.py
+++ b/src/torchjd/aggregation/_pcgrad.py
@@ -6,11 +6,12 @@
 from torchjd.linalg import PSDMatrix
 
 from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
+from ._mixins import _NonDifferentiable
 from ._weighting_bases import _GramianWeighting
 
 
-class PCGradWeighting(_GramianWeighting):
+# Non-differentiable: weights are modified in-place during the gradient projection loop.
+class PCGradWeighting(_NonDifferentiable, _GramianWeighting):
     """
     :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights
     of :class:`~torchjd.aggregation.PCGrad`.
@@ -46,7 +47,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
         return weights.to(device)
 
 
-class PCGrad(GramianWeightedAggregator):
+class PCGrad(_NonDifferentiable, GramianWeightedAggregator):
     """
     :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
     `Gradient Surgery for Multi-Task Learning `_.
@@ -56,6 +57,3 @@ class PCGrad(_NonDifferentiable, GramianWeightedAggregator):
 
     def __init__(self) -> None:
         super().__init__(PCGradWeighting())
-
-        # This prevents running into a RuntimeError due to modifying stored tensors in place.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py
index e039a4ea..c1e4807e 100644
--- a/src/torchjd/aggregation/_upgrad.py
+++ b/src/torchjd/aggregation/_upgrad.py
@@ -6,13 +6,14 @@
 
 from ._aggregator_bases import GramianWeightedAggregator
 from ._mean import MeanWeighting
+from ._mixins import _NonDifferentiable
 from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
-from ._utils.non_differentiable import raise_non_differentiable_error
 from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
 from ._weighting_bases import _GramianWeighting
 
 
-class UPGradWeighting(_GramianWeighting):
+# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
+class UPGradWeighting(_NonDifferentiable, _GramianWeighting):
     r"""
     :class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the weights
     of :class:`~torchjd.aggregation.UPGrad`.
@@ -80,7 +81,7 @@ def reg_eps(self, value: float) -> None:
         self._reg_eps = value
 
 
-class UPGrad(GramianWeightedAggregator):
+class UPGrad(_NonDifferentiable, GramianWeightedAggregator):
     r"""
     :class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input
     matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed
@@ -112,9 +113,6 @@ def __init__(
             UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
         )
 
-        # This prevents considering the computed weights as constant w.r.t. the matrix.
-        self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
     @property
     def pref_vector(self) -> Tensor | None:
         return self.gramian_weighting.pref_vector
diff --git a/src/torchjd/aggregation/_utils/non_differentiable.py b/src/torchjd/aggregation/_utils/non_differentiable.py
deleted file mode 100644
index c5fb1ffc..00000000
--- a/src/torchjd/aggregation/_utils/non_differentiable.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from torch import Tensor, nn
-
-
-class NonDifferentiableError(RuntimeError):
-    def __init__(self, module: nn.Module) -> None:
-        super().__init__(f"Trying to differentiate through {module}, which is not differentiable.")
-
-
-def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...] | Tensor) -> None:
-    raise NonDifferentiableError(module)
diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py
index 4b85bf09..a7e418b6 100644
--- a/tests/unit/aggregation/_asserts.py
+++ b/tests/unit/aggregation/_asserts.py
@@ -1,11 +1,9 @@
 import torch
-from pytest import raises
 from torch import Tensor
 from torch.testing import assert_close
 from utils.tensors import rand_, randperm_
 
 from torchjd.aggregation import Aggregator
-from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError
 
 
 def assert_expected_structure(aggregator: Aggregator, matrix: Tensor) -> None:
@@ -103,10 +101,10 @@ def assert_strongly_stationary(
 
 def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None:
     """
-    Tests empirically that a given non-differentiable `Aggregator` correctly raises a
-    NonDifferentiableError whenever we try to backward through it.
+    Tests that a non-differentiable `Aggregator` does not build a computation graph, even when the
+    input requires gradients.
     """
 
-    vector = aggregator(matrix)
-    with raises(NonDifferentiableError):
-        vector.backward(torch.ones_like(vector))
+    matrix_with_grad = matrix.clone().requires_grad_(True)
+    vector = aggregator(matrix_with_grad)
+    assert not vector.requires_grad