refactor(aggregation): Add _NonDifferentiable mixin #677
Open
ValerianRey wants to merge 4 commits into main from
Conversation
…mixin

Non-differentiable aggregators and weightings previously registered a full_backward_pre_hook to raise NonDifferentiableError. This only caught the problem after a graph had already been built through the module. The new _NonDifferentiable mixin (in _mixins.py) wraps __call__ in torch.no_grad(), so no graph is ever constructed in the first place, making the hook and NonDifferentiableError entirely redundant. The mixin is applied to both the aggregator and its paired weighting class for each non-differentiable method (UPGrad, DualProj, PCGrad, GradVac, IMTLG, GradDrop, ConFIG, CAGrad, NashMTL). The test helper assert_non_differentiable is updated to assert that the output has no grad_fn rather than catching NonDifferentiableError.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
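A minimal sketch of what such a mixin could look like (the actual implementation in `_mixins.py` may differ; this only assumes the standard `torch.nn.Module` API):

```python
import torch
from torch import nn


class _NonDifferentiable(nn.Module):
    """Sketch: run the module's call under torch.no_grad()."""

    def __call__(self, *args, **kwargs):
        # nn.Module.__call__ dispatches to forward(); wrapping it in
        # no_grad() means autograd never records any operation here.
        with torch.no_grad():
            return super().__call__(*args, **kwargs)


# Hypothetical usage, mirroring the mixin-first ordering the PR requires:
class Mean(_NonDifferentiable, nn.Module):
    def forward(self, matrix: torch.Tensor) -> torch.Tensor:
        return matrix.mean(dim=0)


out = Mean()(torch.randn(3, 5, requires_grad=True))
assert out.grad_fn is None  # no graph was constructed
```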
Mixins should be listed before the main base class so they are resolved first in the MRO. Move Stateful before GramianWeightedAggregator / WeightedAggregator / _MatrixWeighting / _GramianWeighting in GradVac, GradVacWeighting, NashMTL, and _NashMTLWeighting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
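Illustratively (class names taken from the commit message; the actual base lists in the codebase may differ), the reordering looks like this:

```python
# Before: Stateful listed after the main base class, so any method both
# classes define is resolved to WeightedAggregator first in the MRO.
class GradVac(WeightedAggregator, Stateful): ...

# After: Stateful listed first, so its overrides take precedence.
class GradVac(Stateful, WeightedAggregator): ...
```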
ValerianRey
commented
May 10, 2026
…-differentiable

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
PierreQuinton
approved these changes
May 10, 2026
Contributor
PierreQuinton
left a comment
Very cool. If you are certain this is run, then I'm happy with this. BTW, since this needs to come first among the classes we inherit from, we cannot have two such mixins. If we ever do, we'll need to think about it.
```rst
.. warning::
    This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance
    list. Placing it after will silently have no effect, because :meth:`__call__` would be
    resolved to :class:`torch.nn.Module` before reaching this mixin.
```
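A standalone illustration of why the ordering matters (plain Python, no torch needed; all names here are made up for the demo):

```python
class Base:
    def __call__(self):
        return "Base.__call__"


class Mixin:
    def __call__(self):
        return "Mixin wraps " + super().__call__()


class MixinFirst(Mixin, Base):  # MRO: MixinFirst -> Mixin -> Base
    pass


class BaseFirst(Base, Mixin):  # MRO: BaseFirst -> Base -> Mixin
    pass


print(MixinFirst()())  # "Mixin wraps Base.__call__"
print(BaseFirst()())   # "Base.__call__" -- the mixin is silently bypassed,
                       # because Base.__call__ never calls super().__call__()
```

Both classes have the same set of bases; only the C3 linearization order differs, which is why the warning insists on mixin-first ordering.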
Contributor
Is this Python-specific MRO behavior for resolving diamonds? I guess the reason is to prevent diamonds?
Contributor
Are you certain this is run, BTW? Did you try to put a raise or something?
Summary
- Add a `_NonDifferentiable` mixin (in `_mixins.py`) that wraps `__call__` in `torch.no_grad()`, preventing autograd graph construction entirely
- Remove `_utils/non_differentiable.py` (`NonDifferentiableError` + `raise_non_differentiable_error`) and all `register_full_backward_pre_hook` calls, which are now redundant
- Update `assert_non_differentiable` in tests to check that the output has no `grad_fn` (the graph was never built) instead of catching `NonDifferentiableError` on backward (see the sketch below)
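A plausible shape for the updated test helper (a hypothetical reconstruction; the real helper's signature may differ):

```python
import torch


def assert_non_differentiable(output: torch.Tensor) -> None:
    # Since __call__ now runs under torch.no_grad(), the output must be
    # detached from any graph: no grad_fn and no requires_grad flag.
    assert output.grad_fn is None
    assert not output.requires_grad
```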
Motivation

The old approach registered a `full_backward_pre_hook` to raise `NonDifferentiableError`. This only caught the problem after a graph had already been built through the module: calling a non-differentiable aggregator on a `requires_grad=True` tensor would silently produce a result with `grad_fn = BackwardHookFunctionBackward`. The new approach is both stricter (no graph is ever built) and simpler (no error class, no hook registration).
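For example, with the mixin in place the stricter behavior can be observed directly (assuming the public aggregation API; the exact import path is an assumption):

```python
import torch
from torchjd.aggregation import UPGrad  # assumed import path

jacobian = torch.randn(4, 10, requires_grad=True)
aggregated = UPGrad()(jacobian)

# No autograd graph was built through the aggregator:
assert aggregated.grad_fn is None
assert not aggregated.requires_grad
```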
Design notes

- `_NonDifferentiable` inherits from `nn.Module`. This avoids a `cast` and makes the inheritance constraint explicit: it only makes sense for modules.
- The mixin must be listed before the `nn.Module` base class to take effect (documented with a warning in the docstring). If placed after, `nn.Module.__call__` is resolved first and the mixin is silently bypassed.
- `_NonDifferentiable` is applied to both aggregators and their paired weightings, so the invariant holds whether a class is used via the `Aggregator` interface (autojac) or the `Weighting` interface (autogram).

🤖 Generated with Claude Code