
feat(aggregation): Add MoCo #676

Open

rkhosrowshahi wants to merge 1 commit into SimplexLab:main from rkhosrowshahi:feat/aggregation-moco

Conversation

Contributor

@rkhosrowshahi rkhosrowshahi commented May 9, 2026

Adds MoCo (ICLR 2023) as a MoCo aggregator and a MoCoWeighting (built on MatrixWeighting and Stateful), plus docs and tests. The weighting keeps a moving Jacobian estimate $$Y$$ and task weights $$\lambda$$, and returns $$\lambda^\top Y$$. Call reset between runs.
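
For reviewers, a rough sketch of the update rule, assembled from the description above and the snippets quoted in the review below (the construction of yy_t is my assumption; this is not the exact code of the PR):

import torch

def moco_step(y, lambd, matrix, step, beta, beta_sigma, gamma, gamma_sigma, rho):
    beta_step = beta / step**beta_sigma      # decaying step size for the Y estimate
    gamma_step = gamma / step**gamma_sigma   # decaying step size for the weights
    # Moving estimate of the Jacobian: Y <- Y - beta_step * (Y - J).
    y = y - beta_step * (y - matrix)
    # Weight update on the simplex (the PR uses a softmax; see the review below).
    yy_t = y @ y.T
    if rho != 0.0:
        yy_t = yy_t + rho * torch.eye(y.shape[0], device=y.device, dtype=y.dtype)
    lambd = torch.softmax(lambd - gamma_step * (yy_t @ lambd), dim=-1)
    return y, lambd, lambd @ y               # aggregated update is lambda^T Y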

Tests:

  • tests/unit/aggregation/test_moco.py (behavior, reset, getters/setters, edge cases, alignment with MoCoWeighting and y)
  • tests/unit/aggregation/test_values.py (fixed baselines for MoCo / MoCoWeighting)

@rkhosrowshahi
Contributor Author

Hi @ValerianRey,
Could you please review the code and assign labels?
Thanks!

@ValerianRey ValerianRey added the cc: feat and package: aggregation labels on May 9, 2026
@ValerianRey ValerianRey mentioned this pull request May 9, 2026
Comment on lines +114 to +118
if matrix.shape[0] == 0:
    self.reset()
    self._y = matrix.detach().clone()
    self._state_key = (matrix.shape[0], matrix.shape[1], matrix.device, matrix.dtype)
    return matrix.new_empty((0,))
Contributor

This should never happen (I think we can safely assume that matrix.shape[0] > 0 in the forward call to an aggregator). Maybe we want to specify that somewhere @ValerianRey

Contributor

@ValerianRey ValerianRey May 10, 2026

I agree, let's get rid of this. Matrices should obviously have at least 1 row; autojac can't even output Jacobians with 0 rows. The only way to hit this branch would be a user deliberately aggregating their own empty matrix with this specific aggregator, which will not happen.

beta_step = self.beta / (self.step**self.beta_sigma)
gamma_step = self.gamma / (self.step**self.gamma_sigma)

with torch.no_grad():
Contributor

This is a bit weird, why not have a no_grad around the whole function instead? Also, what is the current philosophy on grads of outputs of aggregators @ValerianRey? Should we unify this?

Contributor

There are just two cases:

  1. The aggregator/weighting is differentiable: we do not want to use any torch.no_grad(), so that the graph is correctly built and stored.

  2. The aggregator/weighting is non-differentiable (for example because some operations are done on numpy arrays, as for UPGrad, DualProj, and a few others, or for some other niche reason, as for PCGrad and, I think, GradVac): we want to raise an error when we try to backward through it (we do that already), and I even think we would like to prevent graph construction by wrapping the forward in a torch.no_grad() (which we don't do currently). Now that I think of it, if the forward is wrapped in a torch.no_grad(), we don't even need to raise an error when calling backward on a non-differentiable module: no graph will ever be created to begin with, so autograd will never try to backward through it.

So I think that we should improve on that:

  • Differentiable aggregators/weightings should inherit from a Differentiable mixin. It will not do anything, and it won't be public, but it will serve as internal documentation.
  • Non-differentiable aggregators/weightings should inherit from a NonDifferentiable mixin (similarly protected). It could wrap the forward pass in a torch.no_grad (and maybe still raise an error if we try to differentiate through it); see the sketch below.
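
To make the proposal concrete, a minimal sketch of what such mixins could look like (hypothetical names and structure, not the actual torchjd API):

import torch

class _Differentiable:
    """Hypothetical marker mixin: no behavior, internal documentation only."""

class _NonDifferentiable:
    """Hypothetical mixin: prevents autograd graph construction in forward."""

    def __call__(self, *args, **kwargs):
        # No graph is ever recorded, so backward through the output is
        # impossible by construction and no explicit error is needed.
        with torch.no_grad():
            return super().__call__(*args, **kwargs)

# Usage sketch: the mixin must come first in the MRO so that its __call__
# wraps the base class's forward, e.g. class MoCo(_NonDifferentiable, Aggregator).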

About MoCo, I don't know whether it's differentiable or not. I guess that changing the value of y in place will lead to a runtime error when we try to differentiate through it (the same reason we made PCGrad non-differentiable). Need to verify that. If it's not differentiable, I think we should make it inherit NonDifferentiable. If it is, let's remove the torch.no_grad and matrix.detach. In either case, let's remove those two things, because NonDifferentiable will handle this itself if we add it.
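
For context, a minimal generic PyTorch repro of that failure mode (plain PyTorch, not code from this PR):

import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()   # exp saves its output for the backward pass
y.add_(1.0)   # in-place modification bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)  # "... has been modified by an inplace operation ..."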

I'll ask Claude to work on that.

Contributor

Fixed in #677

Contributor

Fantastic. Upon merge of #677, this should be removed, and the aggregator/weighting pair should then inherit the NonDifferentiable mixin. @rkhosrowshahi Take a look at #677, as the mixin specifically needs to be inherited first.

@ValerianRey
Contributor

/opencode:Plan Please review this implementation and compare it to that of LibMTL (https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/MoCo.py).

@opencode-agent
Contributor

Based on my analysis of the TorchJD MoCo implementation and comparison with the LibMTL reference, here is my comprehensive review:

Comparison Summary

Key Differences

| Aspect | TorchJD implementation | LibMTL reference |
| --- | --- | --- |
| Gradient normalization | MISSING - uses the raw matrix directly | ✅ Normalizes by grad norm and scales by loss: grads[tn]/(grads[tn].norm()+1e-8)*losses[tn] |
| Architecture | MoCoWeighting + MoCo aggregator classes | Single MoCo class with backward() |
| State management | Stateful mixin with _state_key validation | init_param() method |
| Parameters | Constructor with property validation | Passed via kwargs to backward() |
| API | forward(matrix) returns aggregated vector | backward(losses) handles grad computation + aggregation |
| Empty matrix | Explicit handling (lines 114-118) | Not handled |

Critical Issue: Missing Gradient Normalization

The TorchJD implementation is missing the gradient normalization step from LibMTL (lines 43-44):

for tn in range(self.task_num):
    grads[tn] = grads[tn]/(grads[tn].norm()+1e-8)*losses[tn]

This normalization is critical to the MoCo algorithm as described in the paper. Without it:

  1. The moving estimate y may not properly track normalized gradient directions
  2. The task weight updates may be skewed by gradient magnitude differences
  3. The algorithm may not converge as expected

Recommendation: Add the normalization step or document why it's omitted (if the expectation is that the input matrix is already normalized).
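
For illustration, a hypothetical row-wise analogue for the TorchJD setting, where the input is a stacked Jacobian with one row per task (losses is assumed to be a 1-D tensor of per-task loss values, which the current aggregator API does not receive):

import torch

def normalize_rows(matrix: torch.Tensor, losses: torch.Tensor) -> torch.Tensor:
    # Mirrors grads[tn] / (grads[tn].norm() + 1e-8) * losses[tn] from LibMTL,
    # applied to every row of the stacked Jacobian at once.
    row_norms = matrix.norm(dim=1, keepdim=True)
    return matrix / (row_norms + 1e-8) * losses.unsqueeze(1)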

Addressing the Existing Review Comments

Per the existing PR reviews by @PierreQuinton and @ValerianRey:

  1. Lines 114-118 (empty matrix check): ✅ Agree with removal - matrices should always have ≥1 row in practice

  2. Lines 129-130 (torch.no_grad and detach):

    • Current: torch.no_grad() context + explicit detach() calls
    • LibMTL: Only torch.no_grad() wrapper
    • Recommendation: Follow ValerianRey's suggestion to create a NonDifferentiable mixin that wraps forward in torch.no_grad(), consistent with other non-differentiable aggregators (PCGrad, CAGrad, etc.)

Code Structure Assessment

Strengths:

  • Clean separation between MoCoWeighting and MoCo aggregator
  • Proper parameter validation with setters
  • Good test coverage for state management and edge cases
  • Proper state tracking with _state_key for device/dtype validation

Areas for Improvement:

  1. Add gradient normalization to match LibMTL/paper
  2. Remove empty matrix handling (lines 114-118)
  3. Consider NonDifferentiable mixin instead of explicit no_grad + detach
  4. Documentation should clarify if input matrix is expected to be pre-normalized

Verification Recommendation

Before merging, verify numerical equivalence with LibMTL on a simple test case to ensure the implementation produces identical outputs given the same inputs (accounting for the normalization difference).


@ValerianRey
Contributor

> Add gradient normalization to match LibMTL/paper

Let's not do that. We will add normalization as wrappers.

Contributor

@PierreQuinton PierreQuinton left a comment

A few things to change/clarify after reviewing this and the paper. The tests need a more thorough review.

if self.rho != 0.0:
    eye = torch.eye(matrix.shape[0], device=matrix.device, dtype=matrix.dtype)
    yy_t = yy_t + self.rho * eye
lambd = torch.softmax(lambd - gamma_step * (yy_t @ lambd), dim=-1)
Contributor

In the paper, the softmax is supposed to be a projection to the probability simplex. Why is this a softmax (it is on the probability simplex, but not a projection)?
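
For reference, a minimal sketch of the Euclidean projection onto the probability simplex via the standard sorting-based algorithm (Duchi et al., 2008); not part of this PR:

import torch

def project_onto_simplex(v: torch.Tensor) -> torch.Tensor:
    # Sort descending and find the largest k with u_k > (cumsum_k - 1) / k.
    u, _ = torch.sort(v, descending=True)
    cssv = torch.cumsum(u, dim=0) - 1.0
    ks = torch.arange(1, v.numel() + 1, device=v.device, dtype=v.dtype)
    rho = torch.nonzero(u > cssv / ks).max()
    theta = cssv[rho] / (rho + 1)
    # Shift and clip: the result is nonnegative and sums to 1.
    return torch.clamp(v - theta, min=0.0)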

gamma_step = self.gamma / (self.step**self.gamma_sigma)

with torch.no_grad():
    y = y - beta_step * (y - matrix.detach())
Contributor

The sizes of matrix/y are the critical part. This is slightly costly because it keeps an extra full copy of the Jacobian in memory. I would not use y but self._y instead, to garbage collect the previous value and save one full Jacobian. Actually, the most memory-efficient implementation would be:

Suggested change
y = y - beta_step * (y - matrix.detach())
self._y = (1 - beta_step) * self._y
self._y += beta_step * matrix.detach()

With this you have exactly two Jacobians stored.

Contributor

I think this will be complicated to adapt to the normalization wrapper, as we normalize y itself, not matrix. So I think the normalization step of equation 6 needs to be implemented here. Note that what opencode suggested above from LibMTL doesn't match what the paper says, so I'm not sure which one to use.

self._y = y
self._lambd = lambd

return lambd
Contributor

I think this is incorrect. In equation 10 of the paper, they use lambda @ y as the update, not lambda @ matrix. I think this makes MoCo not weighted (the row spans of matrix and y can be very different in general).

Contributor

My bad, I just saw below that while this is a weighting, the MoCo aggregator is not weighted. I wrongly assumed that it was.

        raise RuntimeError("The moving gradient estimate is not initialized yet.")
    return self._y

def _ensure_state(self, matrix: Matrix) -> None:
Contributor

Is the role of this function to initialize the state if there is none? If so, we should name it differently, maybe something in the direction of _conditionally_initialize_state (but probably improvable).

Comment on lines +201 to +202
if matrix.shape[0] == 0:
    return matrix.sum(dim=0)
Contributor

Can remove:

Suggested change
if matrix.shape[0] == 0:
    return matrix.sum(dim=0)

Comment on lines +205 to +206
if matrix.requires_grad:
    vector = vector + 0.0 * matrix.sum(dim=0)
Contributor

I would not do that. There is no reason to create a wrong differentiation graph for vector. Maybe we want to make this aggregator/weighting pair non-differentiable (because we don't want to keep all the Jacobians for all time steps). This can be done after #677 is merged.

Suggested change
if matrix.requires_grad:
    vector = vector + 0.0 * matrix.sum(dim=0)

@PierreQuinton
Contributor

/opencode:Plan Can you review the tests? Compare them to the tests of the aggregation codebase, try to match their quality, and reason about all the concepts that need to be tested.

@opencode-agent
Contributor

Based on my thorough review of the MoCo tests compared to the existing aggregation test patterns, here's my comprehensive assessment:

Test Quality Review: MoCo Implementation

Overall Assessment: Good Quality with Minor Improvements Needed

The MoCo tests are well-structured and follow most of the established patterns in the codebase. They provide good coverage of the key functionality.


What's Well Covered (Strengths)

| Aspect | Coverage | Notes |
| --- | --- | --- |
| Representation tests | ✅ | test_representations() follows the exact pattern used in GradVac, DualProj, etc. |
| State management | ✅ | test_reset_restores_first_step_behavior and test_weighting_reset_clears_state give excellent coverage of stateful behavior |
| Aggregator/weighting alignment | ✅ | test_weighting_matches_aggregator_state_update verifies the relationship between MoCo and MoCoWeighting |
| Parameter validation | ✅ | Getters, setters, and rejection of negative values all tested |
| Non-differentiability | ✅ | test_non_differentiable uses the standard assertion |
| Structure tests | ✅ | test_expected_structure parametrized across typical and scaled matrices |
| Fixed value baselines | ✅ | Added to test_values.py with expected outputs |

What's Missing (Gaps)

| Pattern | Status | Recommendation |
| --- | --- | --- |
| test_zero_rows_returns_zero_vector | ⚠️ Should be removed | Per the PR review comments by @PierreQuinton and @ValerianRey, empty-matrix handling (lines 114-118 in _moco.py) should be removed; this test should be removed accordingly. |
| typical_matrices_2_plus_rows | ❌ Missing | GradVac uses this for stateful tests, since 1-row matrices don't make sense for multi-step stateful aggregators. MoCo should filter similarly. |
| Permutation invariance | ❌ Not tested | Most aggregators test this (Mean, DualProj, UPGrad). Is MoCo permutation invariant? If yes, add a test. |
| Reproducibility with manual seed | ❌ Not tested | GradVac has test_reproducible_with_manual_seed; should MoCo have this too? |

Test Structure Comparison

MoCo vs. GradVac (most similar - both stateful):

| Test category | GradVac | MoCo | Gap |
| --- | --- | --- | --- |
| Representation | ✅ | ✅ | None |
| Zero rows/columns | ✅ | ✅ | None |
| Reset behavior | ✅ | ✅ | None |
| Reproducibility (seed) | ✅ | ❌ | MoCo missing |
| Structure (parametrized) | ✅ | ✅ | None |
| Non-differentiable | ✅ | ✅ | None |
| Property tests (permutation, etc.) | ❌ | ❌ | Both missing (acceptable for stateful) |
| Weighting/aggregator alignment | ✅ | ✅ | None |

MoCo vs. DualProj/UPGrad (property-based aggregators):

| Property test | DualProj | UPGrad | MoCo | Should MoCo have it? |
| --- | --- | --- | --- | --- |
| non_conflicting | ✅ | ✅ | ❌ | Test if applicable |
| permutation_invariant | ✅ | ✅ | ❌ | Test if applicable |
| linear_under_scaling | ✅ | ✅ | ❌ | Probably not (stateful) |
| strongly_stationary | ✅ | ✅ | ❌ | Probably not (stateful) |

Specific Issues

  1. Lines 29-31 (test_zero_rows_returns_zero_vector):
     • Tests behavior that should be removed per the PR review
     • Should be deleted when empty-matrix handling is removed from _moco.py
  2. Lines 40-45 (test_reset_restores_first_step_behavior):
     • Uses typical_matrices, which includes 1-row matrices
     • Stateful aggregators like GradVac use typical_matrices_2_plus_rows instead
     • MoCo should do the same (stateful aggregators need ≥2 rows to be meaningful across steps)
  3. Missing edge case: no test for the y getter raising RuntimeError when uninitialized (similar to the weighting test on lines 74-76)


Recommendations

High Priority:

  1. Remove test_zero_rows_returns_zero_vector (matches PR review to remove lines 114-118 from implementation)
  2. Change test_reset_restores_first_step_behavior to use typical_matrices_2_plus_rows instead of typical_matrices

Medium Priority:

  3. Add a test for the aggregator's y property raising RuntimeError when uninitialized (mirroring the weighting test on lines 74-76)
  4. Consider adding test_reproducible_with_manual_seed if MoCo has any stochastic behavior
  5. Investigate and test mathematical properties:
     • Is MoCo permutation invariant? (Most aggregators are)
     • Is MoCo non-conflicting? (Important MTL property)

Low Priority:

  6. Consider consolidating the parameter setter tests using parametrization (currently lines 128-177 are verbose but clear)


Code Quality Notes

  • Good: Clear test names, consistent with codebase patterns
  • Good: Proper use of fixtures and parametrization
  • Good: Tests both MoCo and MoCoWeighting thoroughly
  • Good: Uses assert_close from torch.testing appropriately
  • ⚠️ Minor: Some tests (like setters) are verbose but readable

Summary

The MoCo tests are production-ready with only minor adjustments needed:

  1. Remove the zero-rows test to align with the PR review
  2. Filter out 1-row matrices from reset tests (use typical_matrices_2_plus_rows)
  3. Consider adding permutation invariance test if applicable

The test quality matches or exceeds most existing aggregation tests, particularly in the area of state management testing.


@PierreQuinton
Contributor

From the opencode review, please do:

  • 1
  • 3
  • 4: It is not stochastic, but reproducibility would still be good to have: if we run it twice on the same sequence of matrices, with a reset in between, it should return the same results.

For 5, no need; the properties are not satisfied. (Permutation invariance could be redefined as permuting the whole sequence of inputs to the aggregator with the same permutation, in which case it would become permutation invariant @ValerianRey.)
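
A minimal sketch of the two requested tests (hypothetical names and import path, not the actual torchjd test suite):

import pytest
import torch
from torch.testing import assert_close

from torchjd.aggregation import MoCo  # assumed import path

def test_y_raises_when_uninitialized():
    aggregator = MoCo()  # assumed default-constructible
    with pytest.raises(RuntimeError):
        _ = aggregator.y

def test_reproducible_across_reset():
    torch.manual_seed(0)
    matrices = [torch.randn(3, 5) for _ in range(4)]
    aggregator = MoCo()
    first = [aggregator(m) for m in matrices]
    aggregator.reset()  # state is cleared, so the sequence should replay identically
    second = [aggregator(m) for m in matrices]
    for a, b in zip(first, second):
        assert_close(a, b)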

