diff --git a/CHANGELOG.md b/CHANGELOG.md
index 00a906c6..9a8aedb5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,13 @@ changelog does not include internal changes that do not affect the user.
## [Unreleased]
+### Changed
+
+- Non-differentiable aggregators and weightings (UPGrad, DualProj, PCGrad, GradVac, IMTLG,
+  GradDrop, ConFIG, CAGrad, NashMTL) no longer build a computation graph when called on tensors
+  that require gradients. Their forward pass is now wrapped in `torch.no_grad()`, so their output
+  does not require grad and differentiating through them is no longer possible (previously, the
+  backward pass raised a `NonDifferentiableError`).
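+
+  For instance, a minimal sketch of the new behavior (using `UPGrad` as a stand-in for any of
+  the aggregators listed above):
+
+  ```python
+  import torch
+
+  from torchjd.aggregation import UPGrad
+
+  matrix = torch.randn(3, 5, requires_grad=True)
+  vector = UPGrad()(matrix)  # forward pass runs under torch.no_grad()
+  assert not vector.requires_grad
+  ```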
+
### Added
- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are
diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py
index 82eb84ca..05818399 100644
--- a/src/torchjd/aggregation/_cagrad.py
+++ b/src/torchjd/aggregation/_cagrad.py
@@ -2,6 +2,7 @@
from torchjd.linalg import PSDMatrix
+from ._mixins import _NonDifferentiable
from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import _GramianWeighting
@@ -15,10 +16,10 @@
from torchjd._linalg import normalize
from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
-class CAGradWeighting(_GramianWeighting):
+# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
+class CAGradWeighting(_NonDifferentiable, _GramianWeighting):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.CAGrad`.
@@ -92,7 +93,7 @@ def norm_eps(self, value: float) -> None:
self._norm_eps = value
-class CAGrad(GramianWeightedAggregator):
+class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
@@ -113,9 +114,6 @@ class CAGrad(GramianWeightedAggregator):
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
- # This prevents considering the computed weights as constant w.r.t. the matrix.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
@property
def c(self) -> float:
return self.gramian_weighting.c
diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py
index 98ac0857..54e2c3dc 100644
--- a/src/torchjd/aggregation/_config.py
+++ b/src/torchjd/aggregation/_config.py
@@ -8,12 +8,13 @@
from torchjd.linalg import Matrix
from ._aggregator_bases import Aggregator
+from ._mixins import _NonDifferentiable
from ._sum import SumWeighting
-from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
-class ConFIG(Aggregator):
+# Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context.
+class ConFIG(_NonDifferentiable, Aggregator):
"""
:class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG:
Towards Conflict-free Training of Physics Informed Neural Networks
@@ -31,9 +32,6 @@ def __init__(self, pref_vector: Tensor | None = None) -> None:
super().__init__()
self.pref_vector = pref_vector
- # This prevents computing gradients that can be very wrong.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
def forward(self, matrix: Matrix, /) -> Tensor:
weights = self.weighting(matrix)
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py
index 5ba3645c..e379f127 100644
--- a/src/torchjd/aggregation/_dualproj.py
+++ b/src/torchjd/aggregation/_dualproj.py
@@ -5,13 +5,14 @@
from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
+from ._mixins import _NonDifferentiable
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
-from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting
-class DualProjWeighting(_GramianWeighting):
+# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
+class DualProjWeighting(_NonDifferentiable, _GramianWeighting):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.DualProj`.
@@ -77,7 +78,7 @@ def reg_eps(self, value: float) -> None:
self._reg_eps = value
-class DualProj(GramianWeightedAggregator):
+class DualProj(_NonDifferentiable, GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input
matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
@@ -109,9 +110,6 @@ def __init__(
DualProjWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
)
- # This prevents considering the computed weights as constant w.r.t. the matrix.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector
diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py
index c3354f57..31590ebf 100644
--- a/src/torchjd/aggregation/_graddrop.py
+++ b/src/torchjd/aggregation/_graddrop.py
@@ -6,14 +6,15 @@
from torchjd.linalg import Matrix
from ._aggregator_bases import Aggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
+from ._mixins import _NonDifferentiable
def _identity(P: Tensor) -> Tensor:
return P
-class GradDrop(Aggregator):
+# Non-differentiable: the sign-based random masking is not differentiable.
+class GradDrop(_NonDifferentiable, Aggregator):
"""
:class:`~torchjd.aggregation.Aggregator` that applies the gradient combination
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
@@ -31,9 +32,6 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
self.f = f
self.leak = leak
- # This prevents computing gradients that can be very wrong.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
def forward(self, matrix: Matrix, /) -> Tensor:
self._check_matrix_has_enough_rows(matrix)
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py
index e5fcdd4b..61469407 100644
--- a/src/torchjd/aggregation/_gradvac.py
+++ b/src/torchjd/aggregation/_gradvac.py
@@ -5,15 +5,15 @@
import torch
from torch import Tensor
-from torchjd.aggregation._mixins import Stateful
+from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from torchjd.linalg import PSDMatrix
from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import _GramianWeighting
-class GradVacWeighting(_GramianWeighting, Stateful):
+# Non-differentiable: weights are modified in-place during the gradient correction loop.
+class GradVacWeighting(_NonDifferentiable, Stateful, _GramianWeighting):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
@@ -128,7 +128,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
self._state_key = key
-class GradVac(GramianWeightedAggregator, Stateful):
+class GradVac(_NonDifferentiable, Stateful, GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
@@ -167,7 +167,6 @@ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
weighting = GradVacWeighting(beta=beta, eps=eps)
super().__init__(weighting)
self._gradvac_weighting = weighting
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
@property
def beta(self) -> float:
diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py
index 21a7975f..47504e83 100644
--- a/src/torchjd/aggregation/_imtl_g.py
+++ b/src/torchjd/aggregation/_imtl_g.py
@@ -4,11 +4,12 @@
from torchjd.linalg import PSDMatrix
from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
+from ._mixins import _NonDifferentiable
from ._weighting_bases import _GramianWeighting
-class IMTLGWeighting(_GramianWeighting):
+# Non-differentiable: differentiating through pinv(gramian) would give incorrect gradients.
+class IMTLGWeighting(_NonDifferentiable, _GramianWeighting):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.IMTLG`.
@@ -24,7 +25,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights
-class IMTLG(GramianWeightedAggregator):
+class IMTLG(_NonDifferentiable, GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in
`Towards Impartial Multi-task Learning `_.
@@ -36,6 +37,3 @@ class IMTLG(GramianWeightedAggregator):
def __init__(self) -> None:
super().__init__(IMTLGWeighting())
-
- # This prevents computing gradients that can be very wrong.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py
index 8481feab..29bf5592 100644
--- a/src/torchjd/aggregation/_mixins.py
+++ b/src/torchjd/aggregation/_mixins.py
@@ -1,4 +1,8 @@
from abc import ABC, abstractmethod
+from typing import Any
+
+import torch
+from torch import nn
class Stateful(ABC):
@@ -7,3 +11,19 @@ class Stateful(ABC):
@abstractmethod
def reset(self) -> None:
"""Resets the internal state."""
+
+
+class _NonDifferentiable(nn.Module):
+ """
+ Mixin making an ``nn.Module`` non-differentiable by wrapping its call in :func:`torch.no_grad`,
+ thereby preventing autograd graph construction.
+
+ .. warning::
+ This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance
+ list. Placing it after would silently have no effect, because the MRO would resolve
+ :meth:`__call__` to :meth:`torch.nn.Module.__call__` before reaching this mixin's override.
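+
+    Example of correct ordering (illustrative; ``MyWeighting`` is a hypothetical subclass)::
+
+        class MyWeighting(_NonDifferentiable, _GramianWeighting):  # mixin listed first
+            ...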
+ """
+
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ with torch.no_grad():
+ return super().__call__(*args, **kwargs)
diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py
index 63271d63..99356fc9 100644
--- a/src/torchjd/aggregation/_nash_mtl.py
+++ b/src/torchjd/aggregation/_nash_mtl.py
@@ -1,7 +1,7 @@
# Partly adapted from https://github.com/AvivNavon/nash-mtl — MIT License, Copyright (c) 2022 Aviv Navon.
# See NOTICES for the full license text.
-from torchjd.aggregation._mixins import Stateful
+from torchjd.aggregation._mixins import Stateful, _NonDifferentiable
from ._utils.check_dependencies import check_dependencies_are_installed
from ._weighting_bases import _MatrixWeighting
@@ -15,10 +15,10 @@
from torch import Tensor
from ._aggregator_bases import WeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
-class _NashMTLWeighting(_MatrixWeighting, Stateful):
+# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
+class _NashMTLWeighting(_NonDifferentiable, Stateful, _MatrixWeighting):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
@@ -199,7 +199,7 @@ def reset(self) -> None:
self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)
-class NashMTL(WeightedAggregator, Stateful):
+class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
@@ -253,9 +253,6 @@ def __init__(
),
)
- # This prevents considering the computed weights as constant w.r.t. the matrix.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
@property
def n_tasks(self) -> int:
return self.weighting.n_tasks
diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py
index a796179b..ffce10d3 100644
--- a/src/torchjd/aggregation/_pcgrad.py
+++ b/src/torchjd/aggregation/_pcgrad.py
@@ -6,11 +6,12 @@
from torchjd.linalg import PSDMatrix
from ._aggregator_bases import GramianWeightedAggregator
-from ._utils.non_differentiable import raise_non_differentiable_error
+from ._mixins import _NonDifferentiable
from ._weighting_bases import _GramianWeighting
-class PCGradWeighting(_GramianWeighting):
+# Non-differentiable: weights are modified in-place during the gradient projection loop.
+class PCGradWeighting(_NonDifferentiable, _GramianWeighting):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.PCGrad`.
@@ -46,7 +47,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights.to(device)
-class PCGrad(GramianWeightedAggregator):
+class PCGrad(_NonDifferentiable, GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Gradient Surgery for Multi-Task Learning `_.
@@ -56,6 +57,3 @@ class PCGrad(GramianWeightedAggregator):
def __init__(self) -> None:
super().__init__(PCGradWeighting())
-
- # This prevents running into a RuntimeError due to modifying stored tensors in place.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py
index e039a4ea..c1e4807e 100644
--- a/src/torchjd/aggregation/_upgrad.py
+++ b/src/torchjd/aggregation/_upgrad.py
@@ -6,13 +6,14 @@
from ._aggregator_bases import GramianWeightedAggregator
from ._mean import MeanWeighting
+from ._mixins import _NonDifferentiable
from ._utils.dual_cone import SUPPORTED_SOLVER, project_weights
-from ._utils.non_differentiable import raise_non_differentiable_error
from ._utils.pref_vector import pref_vector_to_str_suffix, pref_vector_to_weighting
from ._weighting_bases import _GramianWeighting
-class UPGradWeighting(_GramianWeighting):
+# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
+class UPGradWeighting(_NonDifferentiable, _GramianWeighting):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.UPGrad`.
@@ -80,7 +81,7 @@ def reg_eps(self, value: float) -> None:
self._reg_eps = value
-class UPGrad(GramianWeightedAggregator):
+class UPGrad(_NonDifferentiable, GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input
matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed
@@ -112,9 +113,6 @@ def __init__(
UPGradWeighting(pref_vector, norm_eps=norm_eps, reg_eps=reg_eps, solver=solver),
)
- # This prevents considering the computed weights as constant w.r.t. the matrix.
- self.register_full_backward_pre_hook(raise_non_differentiable_error)
-
@property
def pref_vector(self) -> Tensor | None:
return self.gramian_weighting.pref_vector
diff --git a/src/torchjd/aggregation/_utils/non_differentiable.py b/src/torchjd/aggregation/_utils/non_differentiable.py
deleted file mode 100644
index c5fb1ffc..00000000
--- a/src/torchjd/aggregation/_utils/non_differentiable.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from torch import Tensor, nn
-
-
-class NonDifferentiableError(RuntimeError):
- def __init__(self, module: nn.Module) -> None:
- super().__init__(f"Trying to differentiate through {module}, which is not differentiable.")
-
-
-def raise_non_differentiable_error(module: nn.Module, _: tuple[Tensor, ...] | Tensor) -> None:
- raise NonDifferentiableError(module)
diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py
index 4b85bf09..a7e418b6 100644
--- a/tests/unit/aggregation/_asserts.py
+++ b/tests/unit/aggregation/_asserts.py
@@ -1,11 +1,9 @@
import torch
-from pytest import raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import rand_, randperm_
from torchjd.aggregation import Aggregator
-from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError
def assert_expected_structure(aggregator: Aggregator, matrix: Tensor) -> None:
@@ -103,10 +101,10 @@ def assert_strongly_stationary(
def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None:
"""
- Tests empirically that a given non-differentiable `Aggregator` correctly raises a
- NonDifferentiableError whenever we try to backward through it.
+ Tests that a non-differentiable `Aggregator` does not build a computation graph, even when the
+ input requires gradients.
"""
- vector = aggregator(matrix)
- with raises(NonDifferentiableError):
- vector.backward(torch.ones_like(vector))
+ matrix_with_grad = matrix.clone().requires_grad_(True)
+ vector = aggregator(matrix_with_grad)
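+ # Under torch.no_grad(), no grad_fn is recorded, so the output is detached from the input.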
+ assert not vector.requires_grad