feat(aggregation): Add CRMOGMWeighting #669
base: main
Changes from all commits
New documentation page (@@ -0,0 +1,15 @@):

```rst
:hide-toc:

CR-MOGM
=======

.. autoclass:: torchjd.aggregation.CRMOGMWeighting
    :members: __call__, reset

.. note::
    The usage example in the docstring above imports ``WeightedAggregator`` /
    ``GramianWeightedAggregator`` from ``torchjd.aggregation._aggregator_bases``, which is a
    private module. These two aggregator base classes are not currently part of the public
    ``torchjd.aggregation`` namespace, so this private-module import is the only path that
    works today. Promoting them to the public namespace is a separate decision left to the
    maintainers.
```
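The note leaves that promotion to the maintainers. For concreteness only, a minimal sketch of what the re-export could look like; this edit is hypothetical and not part of this PR:

```python
# torchjd/aggregation/__init__.py (hypothetical edit, not part of this PR)
from ._aggregator_bases import GramianWeightedAggregator, WeightedAggregator

# Extend the module's existing export list, assuming it defines __all__:
__all__ += ["GramianWeightedAggregator", "WeightedAggregator"]
```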
New module (@@ -0,0 +1,117 @@), ``torchjd/aggregation/_cr_mogm.py`` judging by the import path used in the docs and tests:

```python
from __future__ import annotations

from typing import TypeVar

import torch
from torch import Tensor

from torchjd.aggregation._mixins import Stateful

from ._weighting_bases import Weighting

_T = TypeVar("_T", contravariant=True, bound=Tensor)


class CRMOGMWeighting(Weighting[_T], Stateful):
    r"""
    :class:`~torchjd.aggregation._mixins.Stateful`
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
    :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
    produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
    modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and
    Beyond <https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf>`_
    (NeurIPS 2022).

    Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step
    :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are:

    .. math::

        \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k

    with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top
    \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first
    forward call once :math:`m` is known; changing ``m`` afterwards raises a ``ValueError``
    unless :meth:`reset` is called first.

    Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a
    ``MatrixWeighting`` or a ``GramianWeighting``. A corresponding
    :class:`~torchjd.aggregation.Aggregator` can be created by composing it with the
    appropriate aggregator base:

    .. code-block:: python

        from torchjd.aggregation import MeanWeighting, UPGradWeighting
        from torchjd.aggregation._aggregator_bases import (
            GramianWeightedAggregator,
            WeightedAggregator,
        )
        from torchjd.aggregation._cr_mogm import CRMOGMWeighting

        matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
        gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

    This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
    when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will
    also reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`.

    :param weighting: The wrapped weighting whose output is smoothed.
    :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing
        (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes
        the weights at their initial uniform value. The default of ``0.9`` follows the usual
        EMA convention (analogous to Adam's :math:`\beta_1`).

    .. note::
        ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a
        schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning
        rate decays. Update ``alpha`` between forward calls via the public attribute on the
        wrapping aggregator:

        .. code-block:: python

            # With WeightedAggregator
            aggregator.weighting.alpha = 1 - current_lr / initial_lr

            # With GramianWeightedAggregator
            aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
    """

    def __init__(self, weighting: Weighting[_T], alpha: float = 0.9) -> None:
```
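For reference, unrolling the docstring's recurrence (a standard EMA identity, not stated in the PR) makes the role of the initialisation explicit:

```latex
\lambda_k = \alpha^{k}\,\lambda_0 + (1-\alpha)\sum_{j=1}^{k} \alpha^{\,k-j}\,\hat{\lambda}_j
```

With the default alpha = 0.9, the uniform start still contributes a factor of 0.9^k after k steps, which is why the choice of lambda_0 debated in the review thread below is not entirely cosmetic.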
Review discussion on the `__init__` definition above:

Contributor: I think the initialization of lambda is debatable. For now, we have 1/m all the time. Maybe sometimes a user wants to provide their own starting weights (by the way, they don't need to be in the simplex; even though it's stated like that in the paper, I think that's a mistake). So we could have an `initial_weights` parameter, of type `Tensor | None`, so that the user can provide their weights, or we use 1/m if they don't. The alternative would be to still have type `Tensor | None`, but if the user gives `None`, we use lambda_0 = lambda_hat_1. This means that the first weights output by the CRMOGMWeighting would be lambda_1 = alpha * lambda_hat_1 + (1 - alpha) * lambda_hat_1 = lambda_hat_1. I don't know which option we should go for. @PierreQuinton, we may need your insight on this.

Contributor: The weights definitely can leave the simplex (maybe not for MGDA, but for most other weightings it is not the case), so I would not limit ourselves to the simplex. I don't know about the second question, but I would use the default value they use in the paper, which seems to be what the implementation currently does.

Contributor: In the paper, there is no suggested value for lambda_0. If we set the default to be 1/m, then there is no way to say that we want the default to be lambda_0 = lambda_hat_1. So I would go for:

Contributor: I think a default value of maybe 0.9 would be more appropriate. If the user doesn't ever update this value, it needs to be rather high.
The module continues below the reviewed `__init__` line:

```python
        super().__init__()
        self.weighting = weighting
        self.alpha = alpha
        self._lambda: Tensor | None = None

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, value: float) -> None:
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}.")
        self._alpha = value

    def reset(self) -> None:
        """Clears the EMA state so the next forward starts from uniform weights."""
        if isinstance(self.weighting, Stateful):
            self.weighting.reset()
        self._lambda = None

    def forward(self, stat: _T, /) -> Tensor:
        lambda_hat = self.weighting(stat)
        lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device)
        lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat
        self._lambda = lambda_k.detach()
        return lambda_k

    def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor:
        if self._lambda is None:
            self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device)
        elif self._lambda.shape[0] != m:
            raise ValueError(
                f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
                f"`reset()` before changing the number of objectives."
            )
        return self._lambda
```
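To make the intended workflow concrete, here is a short usage sketch of the smoothed gramian path in a training loop, combining the docstring's example with its alpha-scheduling note. The loop, sizes, and learning-rate decay are placeholders, and the imports rely on the same private modules the docstring itself uses:

```python
import torch

from torchjd.aggregation import UPGradWeighting
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator
from torchjd.aggregation._cr_mogm import CRMOGMWeighting

m, n = 3, 10       # placeholder sizes: m objectives, n parameters
initial_lr = 1e-3  # placeholder learning-rate schedule
weighting = CRMOGMWeighting(UPGradWeighting(), alpha=0.9)
aggregator = GramianWeightedAggregator(weighting)

for step in range(5):
    jacobian = torch.randn(m, n)       # stand-in for a real task Jacobian
    aggregated = aggregator(jacobian)  # weights are EMA-smoothed across calls
    current_lr = initial_lr * 0.9**step
    # Schedule from the docstring's note: alpha_k grows toward 1 as the LR decays.
    weighting.alpha = 1 - current_lr / initial_lr

weighting.reset()  # restart the smoothing from uniform weights
```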
New test module (@@ -0,0 +1,165 @@):

```python
from pytest import mark, raises
from torch import Tensor
from torch.testing import assert_close
from utils.tensors import randn_, tensor_

from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator,
    WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting

from ._asserts import assert_expected_structure
from ._inputs import scaled_matrices, typical_matrices

# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in
# scaled_matrices, so the gramian-path structural test only uses typical_matrices.
matrix_pairs = [
    (WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m)
    for m in typical_matrices + scaled_matrices
]
gramian_pairs = [
    (GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices
]


@mark.parametrize(["aggregator", "matrix"], matrix_pairs)
def test_expected_structure_matrix_weighting(
    aggregator: WeightedAggregator, matrix: Tensor
) -> None:
    assert_expected_structure(aggregator, matrix)


@mark.parametrize(["aggregator", "matrix"], gramian_pairs)
def test_expected_structure_gramian_weighting(
    aggregator: GramianWeightedAggregator, matrix: Tensor
) -> None:
    assert_expected_structure(aggregator, matrix)


def test_reset_restores_first_step_behavior() -> None:
    """
    Use ``UPGradWeighting`` so the weights actually depend on the input: with
    ``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test
    would be trivial.
    """
    J = randn_((3, 8))
    G = J @ J.T
    W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5)
    first = W(G)
    W(G)
    W.reset()
    assert_close(first, W(G))


def test_reset_propagates_to_stateful_weighting() -> None:
    """
    Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is
    :class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal
    state is cleared after ``reset()``.
    """
    inner = GradVacWeighting()
    W = CRMOGMWeighting(inner, alpha=0.5)
    J = randn_((3, 8))
    W(J @ J.T)
    assert inner._phi_t is not None
    W.reset()
    assert inner._phi_t is None


def test_changing_m_raises() -> None:
    """Verify that changing the number of objectives after the first call raises a ValueError."""
    W = CRMOGMWeighting(MeanWeighting())
    W(randn_((3, 8)) @ randn_((3, 8)).T)
    with raises(ValueError, match="number of objectives"):
        W(randn_((2, 8)) @ randn_((2, 8)).T)


def test_alpha_setter_accepts_valid() -> None:
    W = CRMOGMWeighting(MeanWeighting())
    W.alpha = 0.0
    assert W.alpha == 0.0
    W.alpha = 0.5
    assert W.alpha == 0.5
    W.alpha = 1.0
    assert W.alpha == 1.0


def test_alpha_setter_rejects_out_of_range() -> None:
    W = CRMOGMWeighting(MeanWeighting())
    with raises(ValueError, match="alpha"):
        W.alpha = -0.1
    with raises(ValueError, match="alpha"):
        W.alpha = 1.1


def test_alpha_zero_reduces_to_bare_weighting() -> None:
    """
    With ``alpha=0`` the previous state is always multiplied by zero, so the smoothed weights
    equal the bare weighting's output on every call, not just the first.
    """
    J = randn_((3, 8))
    G = J @ J.T
    bare = UPGradWeighting()
    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0)

    expected = bare(G)
    assert_close(smoothed(G), expected)
    assert_close(smoothed(G), expected)


def test_alpha_one_freezes_weights() -> None:
    """
    With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at
    their initial uniform value forever. Note: the equality with uniform weights is a
    consequence of the uniform initialisation, not a general property of CR-MOGM.
    """
    J = randn_((3, 8))
    m = J.shape[0]
    W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0)
    uniform = tensor_([1.0 / m] * m)

    assert_close(W(J @ J.T), uniform)
    assert_close(W(J @ J.T), uniform)


def test_ema_is_applied() -> None:
    """Run two steps with ``alpha=0.9`` and check the EMA recurrence by hand."""
    alpha = 0.9
    J1 = randn_((3, 8))
    J2 = randn_((3, 8))
    G1 = J1 @ J1.T
    G2 = J2 @ J2.T
    m = J1.shape[0]

    bare = UPGradWeighting()
    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha)

    lambda_hat_1 = bare(G1)
    lambda_hat_2 = bare(G2)
    uniform = tensor_([1.0 / m] * m)

    expected_1 = alpha * uniform + (1.0 - alpha) * lambda_hat_1
    expected_2 = alpha * expected_1 + (1.0 - alpha) * lambda_hat_2

    assert_close(smoothed(G1), expected_1)
    assert_close(smoothed(G2), expected_2)


def test_zero_columns() -> None:
    """
    A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. Zero-row
    inputs are intentionally not tested: ``MeanWeighting`` does ``1/m`` in Python and would
    raise ``ZeroDivisionError`` at ``m=0``, which is the wrapped weighting's responsibility.
    """
    aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
    out = aggregator(tensor_([]).reshape(2, 0))
    assert out.shape == (0,)
```
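One behaviour the suite does not exercise is the alpha schedule recommended in the module's note, where ``alpha`` changes between calls. A sketch of such a test, reusing the suite's helpers; this test is not part of the PR, and the expected values simply re-derive the recurrence with the updated alpha:

```python
def test_alpha_can_change_between_calls() -> None:
    """
    Hypothetical extra test (not in this PR): verify the EMA recurrence when ``alpha`` is
    updated between calls, as the docstring's scheduling note recommends.
    """
    J1, J2 = randn_((3, 8)), randn_((3, 8))
    G1, G2 = J1 @ J1.T, J2 @ J2.T
    m = J1.shape[0]

    bare = UPGradWeighting()
    smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.2)
    uniform = tensor_([1.0 / m] * m)

    expected_1 = 0.2 * uniform + 0.8 * bare(G1)
    assert_close(smoothed(G1), expected_1)

    smoothed.alpha = 0.7  # schedule step: increase alpha between calls
    expected_2 = 0.7 * expected_1 + 0.3 * bare(G2)
    assert_close(smoothed(G2), expected_2)
```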