# Broadcast-based allgather in host for-loop #5925
base: main
## Conversation
Review updated until commit be4f66b

**Description**

**Relevant files**
- Enhancement
- Tests
**PR Reviewer Guide**

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Performance Concern

!test
**Greptile Overview**

**Greptile Summary**

Implements broadcast-based allgather for DIDx→Stream resharding in host for-loops by introducing a new `CommunicationType::StreamBroadcast`.
Per the PR description, performance comparison is deferred until multicast integration is complete.

Confidence Score: 4/5
**Important Files Changed**
**Sequence Diagram**

```mermaid
sequenceDiagram
    participant Lowering as lowerSegment
    participant Convert as convertSingleOpToCommunication
    participant Info as getCommunicationInfo
    participant Lower as lowerToStreamBroadcast
    participant Comm as Communication IR
    Lowering->>Convert: Call with expr, device_id, loop->index()
    Convert->>Info: getCommunicationInfo(expr)
    Info->>Info: Detect DIDx→Stream resharding
    Info->>Info: Check if c_stream_logical_id == p2c_map[p_logical_id]
    Info-->>Convert: Return StreamBroadcast type
    Convert->>Convert: Check root != nullptr for StreamBroadcast
    Convert->>Lower: lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root)
    Lower->>Comm: Create Communication(StreamBroadcast, team, root)
    Comm-->>Convert: Return communication expr
    Convert-->>Lowering: Return vector of comms
    Lowering->>Lowering: Skip allocation sharding check for StreamBroadcast
```
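To make the diagram concrete, here is a minimal sketch of the dispatch it describes. The function and type names (`getCommunicationInfo`, `lowerToStreamBroadcast`) come from this PR's discussion, but the exact signatures and the `CommunicationInfo` field names below are assumptions, not the PR's actual code:

```cpp
// Sketch only: approximates the StreamBroadcast branch described above.
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* e,
    DeviceIdxType my_device_idx,
    Val* host_loop_index,
    const CommunicatorBackend backend) {
  std::vector<Expr*> comms;
  const CommunicationInfo info = getCommunicationInfo(e);

  if (info.type == CommunicationType::StreamBroadcast) {
    // The root of each per-iteration broadcast is the host for-loop index,
    // so lowering outside a host loop (root == nullptr) is rejected.
    NVF_ERROR(
        host_loop_index != nullptr,
        "StreamBroadcast requires a host-loop index to act as the root: ",
        e->toString());
    auto* input_tv = e->inputs().at(0)->as<TensorView>();
    auto* output_tv = e->outputs().at(0)->as<TensorView>();
    lowerToStreamBroadcast(
        input_tv, output_tv, backend, comms, /*root=*/host_loop_index);
    return comms;
  }

  // ... existing lowering for Allgather, Broadcast, SendRecv, etc. ...
  return comms;
}
```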
8 files reviewed, 2 comments
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
!test
8 files reviewed, 3 comments
Additional Comments (1)
Consider extending validation to require …
!test
8 files reviewed, 2 comments
```cpp
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* c,
    DeviceIdxType my_device_idx,
    Val* host_loop_index = nullptr,
    const CommunicatorBackend backend = CommunicatorBackend::kNccl);
```
**StreamBroadcast not handled**

`convertSingleOpToCommunication` now defaults `host_loop_index` to `nullptr`, but it can return `CommunicationType::StreamBroadcast` from `getCommunicationInfo` and then hard-errors if `host_loop_index == nullptr` (lower_to_communication.cpp:657-663). This makes the existing `ConvertOpToCommunication` pass (which calls it at top level with `nullptr`) crash for any fusion that produces this DIDx→Stream resharding outside the host-loop lowering path.
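One way to avoid the top-level crash, sketched under the assumption that the pass can detect the situation before lowering (the guard placement and message are illustrative, not the PR's actual fix):

```cpp
// Hypothetical early guard in the ConvertOpToCommunication pass: surface a
// clear error (or fall back to a non-stream decomposition) instead of
// hard-erroring deep inside convertSingleOpToCommunication.
if (getCommunicationInfo(e).type == CommunicationType::StreamBroadcast &&
    host_loop_index == nullptr) {
  NVF_THROW(
      "DIDx->Stream resharding requires host for-loop lowering; "
      "it cannot be converted at top level: ",
      e->toString());
}
```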
Additional Comments (1)
**wujingyue** left a comment
It's great to see this work functionally!
```diff
  type == CommunicationType::SendRecv ||
- type == CommunicationType::AllToAll) {
+ type == CommunicationType::AllToAll ||
+ type == CommunicationType::StreamBroadcast) {
```
I understand the motivation, but can this be consolidated into the same `Broadcast`?
I kept it separate so I don't need to check for the Stream parallel type in `lowerToBroadcast` when deciding the root. Posting the communication uses a common function.

I also wanted to first integrate SendRecv-based decomposition and then reconsider the design based on what is needed for both of these comms.
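For reference, a minimal sketch of what keeping the two paths separate might look like; the `Communication` constructor arguments here are assumptions based on this thread, not the actual code:

```cpp
// Sketch only. The root comes from the caller (the host-loop index), so
// lowerToBroadcast never has to inspect ParallelType::Stream, and posting
// the resulting Communication can still go through the common path.
void lowerToStreamBroadcast(
    TensorView* input_tv,
    TensorView* output_tv,
    const CommunicatorBackend backend,
    std::vector<Expr*>& comms,
    Val* root) {
  const Team team = input_tv->getDeviceMesh().vector();
  comms.push_back(IrBuilder::create<Communication>(
      CommunicationType::StreamBroadcast,
      output_tv,
      input_tv,
      team,
      root,
      backend));
}
```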
| "Destination allocation should be sharded on stream after " | ||
| "shardAllocationAsLoop: ", | ||
| destination); | ||
| destination->domain()->toString(0, /*loop_only=*/false)); |
I guess `destination` is still worth printing in addition to the domain?
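Concretely, the suggestion would make the failing check print both, e.g. (`isShardedOnStream` is a hypothetical predicate standing in for the real condition):

```cpp
// Sketch: include both the TensorView and its full domain in the message.
NVF_ERROR(
    isShardedOnStream(destination),  // hypothetical predicate
    "Destination allocation should be sharded on stream after "
    "shardAllocationAsLoop: ",
    destination,
    ", domain: ",
    destination->domain()->toString(0, /*loop_only=*/false));
```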
```diff
  TensorView* in = communication->in();
  TensorView* out = communication->out();
- if (haveDifferentShardings(
+ if (communication->type() != CommunicationType::StreamBroadcast &&
```
While I understand the motivation and see that the tests pass, I'm thinking about how to make this cleaner.

Is it possible to frame this as an optimization? For example, if `in` can be sharded on Stream in the same way as `communication`, insert a `shardByStream`.
Yeah, I do think this should be merged into `shardByStream` or some other logic.

For now, I kept it simple since I am not sure what it will look like with the Collective Permute representation (a composite Communication, P2P comms corresponding to SendRecv, etc.), so I took the verbose approach as an interim step.

Let me see what I can do in this PR itself.
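A possible shape for that reframing, sketched with a made-up predicate (`canShardOnStreamLike` is not a real helper; `shardByStream` is referenced in this thread but its signature here is assumed):

```cpp
// Hypothetical optimization framing: rather than exempting StreamBroadcast
// from the sharding check, shard `in` on Stream whenever it can be sharded
// the same way as the communication, so the existing check still holds.
TensorView* in = communication->in();
TensorView* out = communication->out();
if (canShardOnStreamLike(in, communication)) {  // hypothetical predicate
  // shardByStream is mentioned in this thread; the signature is assumed.
  in = shardByStream(in, communication);
}
if (haveDifferentShardings(in, out)) {
  // ... existing handling for genuinely resharding communications ...
}
```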
```diff
  // This ignores device dimensions on reduction axis.
- auto producer_pt_to_did =
+ auto producer_pt_to_id =
```
```diff
- auto producer_pt_to_id =
+ const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =
```
```diff
  auto producer_pt_to_id =
      mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
- auto consumer_pt_to_did =
+ auto consumer_pt_to_id =
```
ditto
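As a side note on the suggested spelling (generic C++, not this PR's code): writing out the type documents what the map holds, and binding the result to a `const` reference is safe even if `mapDeviceAndStreamParallelTypeToId` returns by value, because the temporary's lifetime is extended to match the reference:

```cpp
#include <unordered_map>

// Stand-ins for the real nvFuser types, only to make the note concrete.
enum class ParallelType { DIDx, Stream };
struct IterDomain {};

std::unordered_map<ParallelType, IterDomain*> makeMap() {
  return {{ParallelType::DIDx, nullptr}};
}

int main() {
  // Lifetime extension: the temporary returned by makeMap() lives as long
  // as the const reference, so this is well-defined.
  const std::unordered_map<ParallelType, IterDomain*>& pt_to_id = makeMap();
  return static_cast<int>(pt_to_id.size()) - 1;  // 0 on success
}
```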
```cpp
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* e,
    DeviceIdxType my_device_idx,
    Val* host_loop_index,
```
```diff
- Val* host_loop_index,
+ Val* root,
```
Some communications (e.g. broadcast, reduce, gather, and scatter) are rooted. So far, we've been deciding the root according to device meshes. However, this use makes a case for passing the root in from the lowering process.
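To illustrate the contrast (names and signatures below are assumptions, not nvFuser's actual helpers):

```cpp
// Mesh-derived root, as used so far for rooted communications: the root
// is a fixed device taken from the rooted tensor's device mesh.
DeviceIdxType rootFromMesh(TensorView* rooted_tv) {
  return rooted_tv->getDeviceMesh().vector().front();
}

// Lowering-supplied root, as in this PR: inside a host for-loop the root
// varies per iteration, so it must be passed in as a Val* (the loop index)
// rather than computed from the mesh.
void lowerRooted(Val* root, std::vector<Expr*>& comms) {
  NVF_ERROR(root != nullptr, "Rooted communication needs a root.");
  // ... create the Communication with the supplied root ...
}
```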
8 files reviewed, no comments
The broadcast version is very slow, so I am not comparing timings until we integrate this with multicast.