feat(aggregation): Add CRMOGMWeighting #669

KhusPatel4450 wants to merge 12 commits into SimplexLab:main
Conversation
Thanks a lot for the PR! I'm gonna review soon! In the meantime, you can try to get the CI to pass.
Hello! Happy to say, all checks have passed! Glad to have gotten my first PR in as well. Looking forward to feedback.
```python
aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
"""

def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
```
I think the initialization of lambda is debatable.
For now, we have 1/m all the time.
Maybe sometimes a user wants to provide their own starting weights (by the way, they don't need to be in the simplex; even though it's stated like that in the paper, I think it's a mistake).
So we could have an initial_weights parameter, of type Tensor | None, so that the user can provide their weights, or we use 1/m if they don't.
The alternative would be to still have type Tensor | None, but if the user gives None, we use lambda_0 = lambda_1_hat.
This means that the first weights output by the CRMOGMWeighting will be lambda_1 = lambda_1_hat * alpha + (1 - alpha) * lambda_1_hat = lambda_1_hat.
I don't know which option we should go for. @PierreQuinton maybe need your insight on this.
The weights definitely can leave the simplex (maybe not for MGDA, but for most other weightings it is not the case), so I would not limit ourselves to the simplex.
I don't know about the second question, but I would use the default value they use in the paper, which seems to be what is currently implemented.
In the paper, there is no suggested value for lambda_0.
If we set the default to be 1/m, then there is no way to say that we want the default to be lambda_0 = lambda_hat_1.
So I would go for the following (see the sketch after this list):
- Add an initial_weights param, of type Tensor | None.
- If None, initialize it to lambda_hat_1 at the first iteration.
- At init, save the initial_weights param in a variable so that we can reset lambda to it when calling reset.
- If people want to use 1/m, they can still do it by manually providing this tensor of weights.
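To pin the proposal down, here is a rough, hypothetical sketch of what that could look like; the class and attribute names (`_lambda`, `_initial_weights`, the `weighting` attribute) are illustrative, not torchjd's actual implementation:

```python
from typing import Generic, TypeVar

from torch import Tensor

_T = TypeVar("_T")


class CRMOGMWeightingSketch(Generic[_T]):
    """Illustrative sketch of the initial_weights proposal, not the real class."""

    def __init__(self, weighting, alpha: float = 0.1,
                 initial_weights: Tensor | None = None) -> None:
        self.weighting = weighting
        self.alpha = alpha
        self._initial_weights = initial_weights  # saved so reset() can restore it
        self._lambda = initial_weights  # None means: use lambda_hat_1 lazily

    def __call__(self, matrix: _T) -> Tensor:
        lambda_hat = self.weighting(matrix)
        if self._lambda is None:
            # lambda_0 = lambda_hat_1, so the first output is exactly lambda_hat_1.
            self._lambda = lambda_hat
        self._lambda = self.alpha * lambda_hat + (1 - self.alpha) * self._lambda
        return self._lambda

    def reset(self) -> None:
        self._lambda = self._initial_weights  # the user-provided start, or None


# Users who want 1/m can still pass it explicitly, e.g.
# CRMOGMWeightingSketch(inner, initial_weights=torch.full((m,), 1 / m)).
```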
All the things addressed:
Still open:
Hello, I updated the code with the changes that were requested in these two commits; the 2nd commit has the simplified version and the raise on shape change in CRMOGMWeighting._ensure_state.
- Introduces a new public `torchjd.linalg` package exposing `Matrix` and `PSDMatrix` (the rest of `_linalg` stays protected)
- Makes `MatrixWeighting` and `GramianWeighting` protected. These classes are still used to specify the docstring of the `__call__` methods of the aggregators, but the user only sees those aggregators as `Weighting[Matrix]` and `Weighting[PSDMatrix]`, respectively. The `MatrixWeighting` and `GramianWeighting` classes really just bring updated docstrings, that's all.
- Makes the public type of the `gramian_weighting` of `GramianWeightedAggregator` be `Weighting[PSDMatrix]` instead of `GramianWeighting`, so that #669 can work. Similarly, the `weighting` of `WeightedAggregator` becomes `Weighting[Matrix]`.
- Expands docstrings on `Matrix` and `PSDMatrix` with Jacobian and Gramian examples; adds Sphinx documentation under a new **linalg** section in the API Reference
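For illustration, a hedged sketch of the public surface after this change (the names come from the description above; the attribute comments reflect the stated typing, not verified signatures):

```python
from torchjd.aggregation import MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator,
    WeightedAggregator,
)
from torchjd.linalg import Matrix, PSDMatrix  # new public package

matrix_aggregator = WeightedAggregator(MeanWeighting())
gramian_aggregator = GramianWeightedAggregator(UPGradWeighting())

# Publicly, the wrapped weightings are now generic interfaces, which is what
# lets a wrapper like CRMOGMWeighting slot in (the point of #669):
#   matrix_aggregator.weighting           -> Weighting[Matrix]
#   gramian_aggregator.gramian_weighting  -> Weighting[PSDMatrix]
```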
```
gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also
```
```diff
-when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also
+to restart the smoothing from the starting weights. Note that calling :meth:`reset` will also
```
```python
self._alpha = value

def reset(self) -> None:
    """Clears the EMA state so the next forward starts from uniform weights."""
```
| """Clears the EMA state so the next forward starts from uniform weights.""" | |
| """ | |
| Clears the EMA state to the originally provided initial weights. | |
| Also resets the wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`. | |
| """ |
```python
from torchjd.aggregation import MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator, WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting
```
```diff
-from torchjd.aggregation import MeanWeighting, UPGradWeighting
-from torchjd.aggregation._aggregator_bases import (
-    GramianWeightedAggregator, WeightedAggregator,
-)
-from torchjd.aggregation._cr_mogm import CRMOGMWeighting
+from torchjd.aggregation import (
+    CRMOGMWeighting,
+    GramianWeightedAggregator,
+    MeanWeighting,
+    UPGradWeighting,
+    WeightedAggregator,
+)
```
```python
matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))
```
```diff
+Note that ``MeanWeighting`` is used just for the sake of the example: the exponential moving
+average of constant weights will always be equal to the weights themselves, so wrapping by
+``CRMOGMWeighting`` will have no effect.
```
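The claim is a one-step induction: if the inner weighting always returns the same w (so lambda_hat_k = w for all k) and lambda_0 = lambda_hat_1 = w as discussed earlier, then

$$\lambda_k = \alpha\,\hat{\lambda}_k + (1-\alpha)\,\lambda_{k-1} = \alpha w + (1-\alpha)\,w = w \quad \text{for all } k \ge 1,$$

so wrapping ``MeanWeighting`` is indeed a no-op.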
```
rate decays. Update ``alpha`` between forward calls via the public attribute on the
wrapping aggregator:
```
```diff
-rate decays. Update ``alpha`` between forward calls via the public attribute on the
-wrapping aggregator:
+rate decays. Update ``alpha`` between forward calls via the setter.
```
```python
aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
"""

def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:
```
I think a default value of 0.9 would be more appropriate. If the user doesn't ever update this value, it needs to be rather high.
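Putting the docstring's suggestion into context, a minimal training-loop sketch (the alpha schedule mirrors the docstring; the toy model, optimizer, and elided torchjd backward call are illustrative assumptions):

```python
import torch
from torchjd.aggregation import (
    CRMOGMWeighting,
    GramianWeightedAggregator,
    UPGradWeighting,
)

model = torch.nn.Linear(10, 2)  # toy model with two objectives
initial_lr = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting(), alpha=0.9))

for step in range(100):
    # ... compute the per-objective losses and backpropagate their aggregated
    # Jacobian with `aggregator` (torchjd-specific, elided here) ...
    optimizer.step()
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    # Re-derive alpha from the learning-rate decay, as the docstring suggests:
    aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr
```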
/opencode:Plan Please review this with a lot of precision.
Comprehensive Code Review: CRMOGMWeighting PR

Based on my thorough analysis of the code, review comments, and project conventions, here are my findings:

🔴 Critical Issues (Must Fix)

1. Docstring/Default Value Mismatch (Lines 56-59, 76)

The docstring states the default is ``0.9`` while the code uses ``0.1``:

```python
# Lines 56-59 in docstring:
"The default of ``0.9`` follows the usual EMA convention (analogous to Adam's :math:`\\beta_1`)."

# Line 76 in __init__:
def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None:  # WRONG DEFAULT
```

Fix: Change the default from ``0.1`` to ``0.9``.

2. Docstring Misleadingly Describes the ``m``-Change Behavior (Lines 32-34)

The docstring says the state "is reset automatically when ``m`` changes":

```python
# Lines 32-34:
"The state :math:`\\lambda_{k-1}` is initialised lazily on the first
forward call once :math:`m` is known and is reset automatically when ``m`` changes."
```

But the actual implementation in `CRMOGMWeighting._ensure_state` raises instead:

```python
elif self._lambda.shape[0] != m:
    raise ValueError(
        f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
        f"`reset()` before changing the number of objectives."
    )
```

Fix: Update the docstring to accurately describe the error-raising behavior.

🟡 Minor Issues (Should Fix)

3. Example Code Indentation in Docstring (Lines 42-49)

The example code in the docstring uses inconsistent indentation. The imports are on separate lines but could be cleaner:

```python
# Current:
from torchjd.aggregation import MeanWeighting, UPGradWeighting
from torchjd.aggregation._aggregator_bases import (
    GramianWeightedAggregator, WeightedAggregator,
)
from torchjd.aggregation._cr_mogm import CRMOGMWeighting
```

4. Missing Import in Example (Lines 42-49)

The docstring example should show users how to properly import:

```python
from torchjd.aggregation import CRMOGMWeighting  # It's now in __all__
```

5. CHANGELOG Formatting

The CHANGELOG entry (lines 13-16) has a URL that might make the line too long. According to PierreQuinton's review, "This might have become a too long line."

✅ What's Done Correctly
📋 Specific Line-by-Line Issues
🧪 Test Review

The test file (…)

All tests follow the project's patterns from …

🔧 Recommended Fixes
1. Change the default ``alpha``:

```python
def __init__(self, weighting: Weighting[_T], alpha: float = 0.9) -> None:  # Changed from 0.1
```

2. Update the docstring:

```python
# Change from:
"The state :math:`\\lambda_{k-1}` is initialised lazily on the first
forward call once :math:`m` is known and is reset automatically when ``m`` changes."

# To:
"The state :math:`\\lambda_{k-1}` is initialised lazily on the first
forward call once :math:`m` is known. If ``m`` changes between calls, a
:exc:`ValueError` is raised; call :meth:`reset` to handle task count changes."
```

3. Clean up the docstring example imports:

```python
from torchjd.aggregation import (
    CRMOGMWeighting,
    GramianWeightedAggregator,
    MeanWeighting,
    UPGradWeighting,
    WeightedAggregator,
)
```
✅ Approval Status: CONDITIONAL

This PR is ready to merge with minor fixes:
The implementation is correct, well-tested, and follows project conventions. The typing is properly handled after #673 was merged.

Tests: