I got the "triangle updates incoming" test passing in this PR. Below are the key issues identified, the workarounds, and their current status.

1. Sharding Propagation Rework

Details: I've implemented a temporary fix in #5890 (3D sharding for triangle updates), but this is not a long-term solution. Sharding propagation needs a fundamental rework to establish better defaults and cleaner logic.
2. Multi-Dimensional Sharding & getCommunicationInfo

Details: The commit updates getCommunicationInfo to support multi-dimensional sharding. It reuses haveDifferentShardings to identify inconsistencies between input and output TensorView objects. The commit needs cleanup and further test verification before it can be merged.
Technical Debt: Per #3987 (Extend IdModel to map DIDs for certain patterns), haveDifferentShardings is currently bottlenecked by the expensive ExpressionSimplifier. We need to transition it to be IdModel-based in a future iteration.
3. Misaligned Memory Access in Transpose Kernels

Details: A generated transpose kernel is hitting misaligned memory access errors. The failure occurs in the transposition between the local Einsum and the downstream ReduceScatter. For context, this transposition was introduced by ReorderShardedAxisPass to ensure the scattered axis of the ReduceScatter is allocated outermost.
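For intuition, here is a small, hypothetical PyTorch illustration (shapes and names are made up, not taken from this PR) of why the scattered axis needs to be outermost: collectives such as torch.distributed.reduce_scatter_tensor hand each rank a contiguous chunk along the leading dimension.

```python
import torch

world_size = 4
x = torch.randn(8, world_size * 16)          # want to scatter along the inner dim

# The collective gives each rank a contiguous chunk along the outermost
# dimension, so the scattered axis is moved out front first (the
# transposition discussed above).
x_outer = x.transpose(0, 1).contiguous()     # [world_size * 16, 8]
per_rank = x_outer.chunk(world_size, dim=0)  # each rank's contiguous [16, 8] slice
assert all(c.is_contiguous() for c in per_rank)
```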
4. High memory usage

Details: The current naive AllGather preceding the Einsum is functional, but it consumes too much memory for AlphaFold3 workloads because of their long sequence lengths.
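As a rough, illustrative calculation (hypothetical AlphaFold3-like pair shapes, not measurements from this PR), the gathered pair tensor alone is already several GiB per rank:

```python
# Illustrative shapes only: a bf16 pair representation with a long sequence.
seq_len, channels, bytes_per_elem = 4096, 128, 2
gathered_bytes = seq_len * seq_len * channels * bytes_per_elem
print(f"{gathered_bytes / 2**30:.1f} GiB materialized per rank before the Einsum")  # ~4.0 GiB
```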
Proposed Fix: We need to implement stream-parallelization to enable either:

- Ring-based AllGather (with Swizzle; a minimal schedule sketch follows below), or
- Broadcast-based communication (without Swizzle). AFAICT, fast broadcast requires multicasting and therefore symmetric memory.

cc @DejunL
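To make the ring option concrete, here is a minimal single-process sketch of the schedule, with made-up shapes and a plain Python loop standing in for per-rank streams and P2P transfers; it illustrates the idea, not the planned nvFuser implementation.

```python
import torch

# Hypothetical shapes: A is sharded along K, each "rank" holds one [m, kb] block.
world_size, m, kb, n = 4, 8, 16, 8
a_shards = [torch.randn(m, kb) for _ in range(world_size)]
w = torch.randn(world_size * kb, n)          # replicated weight

outs = [torch.zeros(m, n) for _ in range(world_size)]
blocks = list(a_shards)                      # block currently resident on each rank
owners = list(range(world_size))             # which shard of A that block is

for _ in range(world_size):
    for r in range(world_size):              # in reality, ranks run concurrently on streams
        o = owners[r]
        outs[r] += blocks[r] @ w[o * kb:(o + 1) * kb]   # consume the resident block
    # Rotate blocks one hop around the ring (stand-in for overlapped P2P send/recv);
    # only one block is resident per rank instead of the fully gathered tensor.
    blocks = [blocks[(r - 1) % world_size] for r in range(world_size)]
    owners = [owners[(r - 1) % world_size] for r in range(world_size)]

expected = torch.cat(a_shards, dim=1) @ w
torch.testing.assert_close(outs[0], expected, rtol=1e-4, atol=1e-4)  # every rank ends with the full product
```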
The changes add stricter error checking with NVF_ERROR_EQ and NVF_THROW, but the new error handling in getCommunicationInfo could be too aggressive. The code now throws "Not sharded on this parallel type" errors, which might break legitimate use cases where some parallel types aren't used. This needs validation against the existing test suites.
NVF_THROW("Not sharded on this parallel type: ", pt);
The new logic that uses haveDifferentShardings() to filter parallel types before processing could introduce subtle bugs. The previous code processed all parallel types, while the new code skips any parallel type on which the shardings do not differ. This behavior change should be thoroughly tested with various sharding configurations.
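// Skip parallel types on which the producer and consumer are already sharded
// the same way; only differing shardings require communication.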
if (!haveDifferentShardings(producer, consumer, {pt})) {
continue;
}
While the AlphaFold3 test is comprehensive, it only tests successful execution without validating correctness against a reference implementation. The test should include torch.testing.assert_close() comparisons to ensure the 3D sharding produces mathematically correct results, especially given the complexity of triangle updates and the mentioned transpose kernel issues.
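As a concrete, hypothetical sketch of that suggestion, where reference_fn and sharded_out are illustrative stand-ins rather than the actual test's API, the comparison could look like:

```python
import torch

def reference_fn(z, w):
    # Single-device stand-in for the triangle-update math that is being sharded.
    return torch.einsum("bikc,bjkc->bijc", z, z) @ w

z = torch.randn(2, 32, 32, 16)
w = torch.randn(16, 16)

expected = reference_fn(z, w)
sharded_out = reference_fn(z, w)  # placeholder: would be the gathered multi-GPU output

# Elementwise check with tolerances loose enough for bf16/TF32 kernels.
torch.testing.assert_close(sharded_out, expected, rtol=1.6e-2, atol=1e-2)
```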