feat(aggregation): Add CRMOGMWeighting #669

Open
KhusPatel4450 wants to merge 12 commits into SimplexLab:main from KhusPatel4450:feat/cr-mogm-weighting

Conversation

@KhusPatel4450

  • Adds CRMOGMWeighting, a stateful Weighting modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022)
  • Wraps any existing Weighting and smooths its output with an EMA: λ_k = α·λ_{k-1} + (1 − α)·λ̂_k (see the sketch after this list)
  • Generic over the input type so it composes correctly with both WeightedAggregator and GramianWeightedAggregator
  • Stateful via the Stateful mixin; reset() restores uniform initial weights
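
To make the recurrence concrete, here is a minimal standalone sketch of the smoothing logic described above (illustrative only; the EMAWeights class and its interface are made up and are not the torchjd implementation):

import torch
from torch import Tensor

class EMAWeights:
    """Illustrative helper: smooths a stream of weight vectors with an EMA."""

    def __init__(self, alpha: float) -> None:
        self.alpha = alpha
        self._lambda: Tensor | None = None  # state, lazily initialised on the first call

    def step(self, lambda_hat: Tensor) -> Tensor:
        if self._lambda is None:
            # Start from uniform weights 1/m once m is known from the first raw weights.
            self._lambda = torch.full_like(lambda_hat, 1.0 / lambda_hat.shape[0])
        # lambda_k = alpha * lambda_{k-1} + (1 - alpha) * lambda_hat_k
        self._lambda = self.alpha * self._lambda + (1 - self.alpha) * lambda_hat
        return self._lambda

    def reset(self) -> None:
        self._lambda = None  # the next step() starts again from uniform weights

With alpha = 0 the wrapper passes the raw weights through unchanged, and with alpha = 1 the weights stay frozen at their initial value.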

Tests:

  • uv run pytest tests/unit/aggregation/test_cr_mogm.py -v, 92 tests covering EMA recurrence, alpha boundaries, reset, structural checks on both aggregator paths
  • uv run pytest tests/unit -q, full regression (2889 passed)
  • uv run ty check src/torchjd/aggregation/_cr_mogm.py, passes
  • Sphinx doctest, 94 tests, 0 failures

@KhusPatel4450 KhusPatel4450 changed the title Add CRMOGMWeighting from NeurIPS 2022 (Aggregation Feature) feat(aggregation): Add CRMOGMWeighting from NeurIPS 2022 May 7, 2026
@ValerianRey ValerianRey changed the title feat(aggregation): Add CRMOGMWeighting from NeurIPS 2022 feat(aggregation): Add CRMOGMWeighting May 7, 2026
@ValerianRey ValerianRey added the cc: feat (Conventional commit type for new features.) and package: aggregation labels May 7, 2026
@ValerianRey
Contributor

Thanks a lot for the PR! I'm gonna review soon! In the meantime, you can try to get the CI to pass.

@ValerianRey ValerianRey mentioned this pull request May 7, 2026
@KhusPatel4450
Author

Hello,

Happy to say, all checks have passed! Glad to have gotten my first PR in as well. Looking forward to feedback.

Comment thread docs/source/docs/aggregation/cr_mogm.rst
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread tests/unit/aggregation/test_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread tests/unit/aggregation/test_cr_mogm.py
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
"""

def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
Contributor


I think the initialization of lambda is debatable.

For now, we have 1/m all the time.

Maybe sometimes a user wants to provide their own starting weights (btw they don't need to be in the simplex; even though it's stated like that in the paper, I think that's a mistake).

So we could have an initial_weights parameter, of type Tensor | None, so that the user can provide their weights, or we use 1/m if they don't.

The alternative would be to have still type Tensor | None, but if the user gives None, we use lambda_0 = lambda_1_hat.

This means that the first weights output by the CRMOGMWeighting will be lambda_1 = lambda_1_hat * alpha + (1 - alpha) * lambda_1_hat = lambda_1_hat.

I don't know which option we should go for. @PierreQuinton, we may need your insight on this.

Contributor


The weights definitely can leave the simplex (maybe not for MGDA, but for most other weightings it is not the case), so I would not limit ourselves to the simplex.

I don't know about the second question, but I would use the default value they use in the paper, which seems to be what the implementation currently does.

Contributor


In the paper, there is no suggested value for lambda_0.

If we set the default to be 1/m, then there is no way to say that we want the default to be lambda_0 = lambda_hat_1.

So I would go for:

  • Add an initial_weights parameter of type Tensor | None.
  • If it is None, initialize it to lambda_hat_1 at the first iteration.
  • At init, save the initial_weights parameter in a variable so that we can reset lambda to it when calling reset().
  • If people want to use 1/m, they can still do it by manually providing this tensor of weights (see the sketch below).
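
A short usage sketch of that last bullet, assuming the proposed parameter ends up being named initial_weights (this is a proposal from this thread, not part of the merged API):

import torch

from torchjd.aggregation import CRMOGMWeighting, UPGradWeighting

m = 3  # number of objectives
uniform = torch.full((m,), 1.0 / m)  # explicit 1/m starting weights

# Hypothetical call: initial_weights is the parameter proposed above, not an existing argument.
weighting = CRMOGMWeighting(UPGradWeighting(), initial_weights=uniform)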

Comment thread CHANGELOG.md Outdated
@KhusPatel4450
Author

All of the points have been addressed:

  • Reset propagation: reset() now calls self.weighting.reset() if the wrapped weighting is Stateful (see the sketch after this list).

  • device/dtype/m from weighting output: forward() now calls self.weighting(stat) first and reads everything from lambda_hat, not from stat.

  • Removed repr and the corresponding test_representations test.

  • Removed _state_key: _ensure_state now checks shape/dtype/device directly off _lambda.

  • Added test for reset() propagation using GradVacWeighting as the inner stateful weighting.
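
A hedged sketch of what the reset propagation from the first bullet looks like (the Stateful stand-in and the EMAWrapper class below are illustrative only, not the torchjd code):

class Stateful:  # stand-in for torchjd's Stateful mixin, for illustration only
    def reset(self) -> None: ...

class EMAWrapper(Stateful):  # hypothetical wrapper; only the reset path is shown
    def __init__(self, weighting: object) -> None:
        self.weighting = weighting
        self._lambda = None  # EMA state

    def reset(self) -> None:
        self._lambda = None  # clear the EMA state
        if isinstance(self.weighting, Stateful):
            self.weighting.reset()  # propagate to the wrapped weighting, e.g. GradVacWeighting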

Still open:

  • Initial weight strategy (uniform 1/m vs first weighting output)

  • Type checking failure

Comment thread src/torchjd/aggregation/_cr_mogm.py
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread CHANGELOG.md Outdated
@KhusPatel4450
Author

Hello, I updated the code with the changes that were requested in these two commits; the second commit has the simplified version and the raise on shape change in CRMOGMWeighting._ensure_state.

@ValerianRey

This comment was marked as resolved.

ValerianRey added a commit that referenced this pull request May 9, 2026
- Introduces a new public `torchjd.linalg` package exposing `Matrix` and
`PSDMatrix` (the rest of `_linalg` stays protected)
- Makes `MatrixWeighting` and `GramianWeighting` protected. These
classes are still used to specify the docstring of the `__call__`
methods of the aggregators, but the user only sees those aggregators as
`Weighting[Matrix]` and `Weighting[PSDMatrix]`, respectively. The
`MatrixWeighting` and `GramianWeighting` classes really just bring
updated docstrings, that's all.
- Makes the public type of the gramian_weighting of
GramianWeightedAggregator be Weighting[PSDMatrix] instead of
GramianWeighting, so that #669 can work. Similarly, the weighting of
WeightedAggregator becomes Weighting[Matrix].
- Expands docstrings on `Matrix` and `PSDMatrix` with Jacobian and
Gramian examples; adds Sphinx documentation under a new **linalg**
section in the API Reference
gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also
Contributor


Suggested change
when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also
to restart the smoothing from the starting weights. Note that calling :meth:`reset` will also

self._alpha = value

def reset(self) -> None:
"""Clears the EMA state so the next forward starts from uniform weights."""
Contributor

@ValerianRey ValerianRey May 9, 2026


Suggested change
"""Clears the EMA state so the next forward starts from uniform weights."""
"""
Clears the EMA state to the originally provided initial weights.
Also resets the wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`.
"""

Comment on lines +42 to +46
from torchjd.aggregation import MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
GramianWeightedAggregator, WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting
Contributor


Suggested change
from torchjd.aggregation import MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
GramianWeightedAggregator, WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting
from torchjd.aggregation import (
    CRMOGMWeighting,
    GramianWeightedAggregator,
    MeanWeighting,
    UPGradWeighting,
    WeightedAggregator,
)


matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

Contributor


Suggested change
Note that ``MeanWeighting`` is used just for the sake of the example: the exponential moving
average of constant weights will always be equal to the weights themselves, so wrapping by
``CRMOGMWeighting`` will have no effect.

Comment on lines +64 to +65
rate decays. Update ``alpha`` between forward calls via the public attribute on the
wrapping aggregator:
Contributor


Suggested change
rate decays. Update ``alpha`` between forward calls via the public attribute on the
wrapping aggregator:
rate decays. Update ``alpha`` between forward calls via the setter.
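
A toy sketch of the kind of schedule this docstring excerpt describes (the learning-rate decay and variable names are made up for illustration; only the commented assignment mirrors the documented pattern):

initial_lr = 1e-3
for step in range(100):
    current_lr = initial_lr * (0.99 ** step)      # hypothetical learning-rate decay
    alpha = 1.0 - current_lr / initial_lr         # approaches 1 as the learning rate shrinks
    # aggregator.gramian_weighting.alpha = alpha  # set through the public setter each step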

aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
"""

def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
Contributor


I think a default value of 0.9 would be more appropriate. If the user doesn't ever update this value, it needs to be rather high.

@ValerianRey
Contributor

/opencode:Plan Please review this with a lot of precision.

@opencode-agent
Contributor

Comprehensive Code Review: CRMOGMWeighting PR

Based on my thorough analysis of the code, review comments, and project conventions, here are my findings:

🔴 Critical Issues (Must Fix)

1. Docstring/Default Value Mismatch (Line 56-59, 76)

The docstring states the default alpha is 0.9 following "the usual EMA convention (analogous to Adam's β₁)":

# Line 56-59 in docstring:
"The default of ``0.9`` follows the usual EMA convention (analogous to Adam's :math:`\\beta_1`)."

# Line 76 in __init__:
def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:  # WRONG DEFAULT

Fix: Change default from 0.1 to 0.9 as per ValerianRey's review comment and the documented intent.

2. Docstring Misleading Description of m Change Behavior (Lines 32-34)

The docstring says the state "is reset automatically when m changes":

# Lines 32-34:
"The state :math:`\\lambda_{k-1}` is initialised lazily on the first
forward call once :math:`m` is known and is reset automatically when ``m`` changes."

But the actual implementation in _ensure_state (lines 109-117) raises a ValueError when m changes, not an automatic reset:

elif self._lambda.shape[0] != m:
    raise ValueError(
        f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
        f"`reset()` before changing the number of objectives."
    )

Fix: Update the docstring to accurately describe the error-raising behavior.


🟡 Minor Issues (Should Fix)

3. Example Code Indentation in Docstring (Lines 42-49)

The example code in the docstring uses inconsistent indentation. The imports are on separate lines but could be cleaner:

# Current:
from torchjd.aggregation import MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator, WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting

4. Missing Import in Example (Lines 42-49)

The docstring example should show users how to properly import CRMOGMWeighting from the public API:

from torchjd.aggregation import CRMOGMWeighting  # It's now in __all__

5. CHANGELOG Formatting

The CHANGELOG entry (line 13-16) has a URL that might make the line too long. According to PierreQuinton's review, "This might have become a too long line."


What's Done Correctly

  1. Reset Propagation: Correctly implemented - reset() calls self.weighting.reset() when the wrapped weighting is Stateful (lines 95-97).

  2. Type Safety: Properly uses generic TypeVar("_T") to allow wrapping both MatrixWeighting and GramianWeighting.

  3. Device/Dtype Handling: Correctly reads device, dtype, and m from lambda_hat rather than from stat (line 100-102).

  4. Alpha Validation: Proper getter/setter with bounds checking [0, 1] (lines 82-90).

  5. State Management: Uses self._lambda: Tensor | None = None pattern consistent with GradVacWeighting.

  6. Test Coverage: Comprehensive tests covering:

    • EMA recurrence
    • Alpha boundaries
    • Reset functionality
    • Reset propagation to wrapped stateful weightings
    • Structure checks on both aggregator paths
    • Zero columns handling

📋 Specific Line-by-Line Issues

  • Line 76: default alpha=0.1 should be 0.9 (Critical)
  • Lines 32-34: docstring says "reset automatically" but code raises ValueError (Critical)
  • Lines 42-46: import example should use the public API import (Minor)
  • Lines 56-59: docstring says the default is 0.9 but the code has 0.1 (Critical)

🧪 Test Review

The test file (test_cr_mogm.py) is well-written with good coverage:

  • test_reset_restores_first_step_behavior: ✓ Verifies EMA reset works
  • test_reset_propagates_to_stateful_weighting: ✓ Verifies reset propagation using GradVacWeighting
  • test_changing_m_raises: ✓ Verifies ValueError on m change
  • test_alpha_setter_accepts_valid: ✓ Boundary testing
  • test_alpha_setter_rejects_out_of_range: ✓ Invalid input testing
  • test_alpha_zero_reduces_to_bare_weighting: ✓ Edge case
  • test_alpha_one_freezes_weights: ✓ Edge case
  • test_ema_is_applied: ✓ Mathematical correctness
  • test_zero_columns: ✓ Edge case handling

All tests follow the project's patterns from _asserts.py and _inputs.py.


🔧 Recommended Fixes

  1. Fix default alpha value:
def __init__(self, weighting: Weighting[_T], alpha: float = 0.9) -> None:  # Changed from 0.1
  2. Fix docstring description of m change:
# Change from:
"The state :math:`\\lambda_{k-1}` is initialised lazily on the first
forward call once :math:`m` is known and is reset automatically when ``m`` changes."

# To:
"The state :math:`\\lambda_{k-1}` is initialised lazily on the first
forward call once :math:`m` is known. If ``m`` changes between calls, a
:exc:`ValueError` is raised; call :meth:`reset` to handle task count changes."
  3. Update example import in docstring:
from torchjd.aggregation import (
    CRMOGMWeighting,
    GramianWeightedAggregator,
    MeanWeighting,
    UPGradWeighting,
    WeightedAggregator,
)
  4. Verify CHANGELOG line length - split long line if needed.

Approval Status: CONDITIONAL

This PR is ready to merge with minor fixes:

  1. Change default alpha from 0.1 to 0.9
  2. Fix the docstring description of m change behavior
  3. Optionally update the example import for clarity

The implementation is correct, well-tested, and follows project conventions. The typing is properly handled after #673 was merged.

New session - 2026-05-10T00:47:55.664Z
opencode session  |  github run
