
Conversation

@Priya2698 Priya2698 (Collaborator) commented Feb 6, 2026

[Screenshot: 2026-02-09 at 1:24 PM]

The broadcast version is very slow so I am not comparing timings until we integrate this with multicast


github-actions bot commented Feb 6, 2026

Review updated until commit be4f66b

Description

  • Implement StreamBroadcast communication type for broadcast-based allgather in host for-loop

  • Add support for DIDx to Stream parallel type conversion using ring allgather pattern

  • Update communication lowering to accept root parameter for stream broadcasts

  • Add test coverage for column-parallel linear forward with StreamBroadcast

Changes walkthrough

Relevant files:

Enhancement

csrc/host_ir/lower_to_communication.cpp: Implement StreamBroadcast communication lowering (+60/-5)
  • Added lowerToStreamBroadcast function to create StreamBroadcast communications
  • Enhanced getCommunicationInfo to detect DIDx->Stream transitions for ring allgather
  • Updated getCommunicationLayout and convertSingleOpToCommunication to handle StreamBroadcast
  • Added root parameter to convertSingleOpToCommunication for stream broadcast decomposition

csrc/host_ir/lowering.cpp: Update lowering to support stream broadcast root parameter (+4/-2)
  • Pass innermost loop index as root parameter to convertSingleOpToCommunication
  • Skip sharding validation for StreamBroadcast communications

csrc/host_ir/ops.cpp: Improve error messaging in shardByStream (+4/-2)
  • Enhanced error message in shardByStream for better debugging

csrc/host_ir/pass/convert_op_to_communication.cpp: Update communication conversion pass for root parameter (+4/-1)
  • Updated convertSingleOpToCommunication call to include root parameter

csrc/multidevice/communication.cpp: Add StreamBroadcast communication type support (+6/-0)
  • Added StreamBroadcast to CommunicationType enum and output operator
  • Updated hasRoot and isReduction functions to include StreamBroadcast
  • Modified postSingleCommunication to handle StreamBroadcast using broadcast logic

csrc/host_ir/lower_to_communication.h: Update function signature for root parameter support (+6/-0)
  • Updated convertSingleOpToCommunication signature to include optional root parameter
  • Added documentation explaining root parameter usage

csrc/multidevice/communication.h: Add StreamBroadcast to communication type enum (+6/-1)
  • Added StreamBroadcast to CommunicationType enum
  • Updated documentation to explain how StreamBroadcast differs from Broadcast

Tests

tests/python/multidevice/test_overlap.py: Add tests for StreamBroadcast in column-parallel linear (+114/-0)
  • Added column_parallel_linear_forward function demonstrating StreamBroadcast usage
  • Added test_column_parallel_linear_forward to verify StreamBroadcast functionality
  • Added benchmark test for performance evaluation

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Performance Concern

    The PR description mentions "The broadcast version is very slow so I am not comparing timings until we integrate this with multicast". This suggests the current StreamBroadcast implementation may have performance issues. The reviewer should validate that this implementation provides expected performance benefits or at least doesn't introduce significant regressions compared to existing approaches.

    void lowerToStreamBroadcast(
        TensorView* input_tv,
        TensorView* output_tv,
        const CommunicatorBackend backend,
        std::vector<Expr*>& comms,
        Val* root) {
      const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
      const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
      NVF_ERROR_EQ(
          sender_mesh,
          receiver_mesh,
          "StreamBroadcast sender and receiver meshes must be the same. Given ",
          sender_mesh,
          " and ",
          receiver_mesh);
      Team team = receiver_mesh.vector();
      comms.push_back(IrBuilder::create<Communication>(
          CommunicationType::StreamBroadcast,
          output_tv,
          input_tv,
          team,
          root,
          c10d::ReduceOp::RedOpType::UNUSED,
          backend));
    }
    Incomplete Implementation

    There's a TODO comment at line 425: "TODO: Lower to SendRecv if swizzle is present." This suggests the StreamBroadcast implementation is incomplete and may not handle all intended cases. The reviewer should verify if this limitation affects the correctness or completeness of the implementation.

    if (p_loop_did && !c_loop_did) {
      IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
      // Check if we are going from DIDx -> Stream, which is a ring allgather.
      // This can be executed as a broadcast or send recvs, which is decided
      // by the presence of a swizzle in the stream id definition.
      // TODO: Lower to SendRecv if swizzle is present.
      if (c_stream_id != nullptr) {
        IterDomain* c_stream_logical_id =
            getLogicalFromLoopId(consumer, c_stream_id);
        if (c_stream_logical_id == p2c_map.at(p_logical_id)) {
          NVF_CHECK(
              same_mesh,
              "Broadcast based allgather in stream parallel requires same "
              "mesh.");
          fill_communication_info(
              CommunicationType::StreamBroadcast,
              p_logical_id,
              c_stream_logical_id);
          continue;
        }
      }
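    For intuition, the pattern this code detects (a ring allgather executed as one rooted broadcast per host-loop iteration) can be simulated in plain Python. This sketch is illustrative only; the function name and shard layout are invented, and no NCCL or nvFuser APIs are involved:

    ```python
    def broadcast_based_allgather(shards):
        """Simulate an allgather as a loop of rooted broadcasts.

        shards[r] is the shard owned by device r. Each iteration of the
        outer loop models one StreamBroadcast whose root is the host
        for-loop index: device `root` sends its shard to every device.
        """
        d = len(shards)
        # Each device starts with an output buffer of d empty slots.
        outputs = [[None] * d for _ in range(d)]
        for root in range(d):  # host for-loop; the index doubles as broadcast root
            for device in range(d):
                outputs[device][root] = shards[root]
        return outputs


    # After d broadcasts, every device holds all shards in rank order.
    gathered = broadcast_based_allgather([[0], [1], [2], [3]])
    assert all(buf == [[0], [1], [2], [3]] for buf in gathered)
    ```

    Each iteration corresponds to one StreamBroadcast with the loop index as root, which matches the test's expectation of exactly d ncclDevKernel_Broadcast launches.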
    Test Coverage

    The new test_column_parallel_linear_forward test validates the StreamBroadcast functionality but only checks for the presence of broadcast events, not the actual correctness of the computation. The reviewer should ensure the test provides sufficient validation of the StreamBroadcast behavior and consider adding more comprehensive correctness tests.

    @pytest.mark.mpi
    def test_column_parallel_linear_forward(multidevice_test):
        # This is a port of CollectiveBasedOverlapTest.ColumnAndSequenceParallelLinear_Forward.
        # The difference is we are using broadcast based overlapping instead of send/recv.
        h, t = 2, 24
        d = multidevice_test.size
        if (h * 4) % d != 0:
            pytest.skip(
                f"Row-parallel linear requires {h * 4} to be divisible by world size {d}."
            )
        if t % d != 0:
            pytest.skip(
                f"Column-parallel linear requires {t} to be divisible by world size {d}."
            )
    
        fd = column_parallel_linear_forward(h, d)
    
        inp_ref = torch.testing.make_tensor(t, h, dtype=torch.int32, device="cpu").to(
            torch.bfloat16
        )
        weight_ref = torch.testing.make_tensor(
            4 * h, h, dtype=torch.int32, device="cpu"
        ).to(torch.bfloat16)
    
        inp = multidevice_test.shard_tensor(inp_ref, fd.fusion.inputs()[0])
        weight = multidevice_test.shard_tensor(weight_ref, fd.fusion.inputs()[1])
    
        out_ref = torch.nn.functional.linear(inp_ref.cuda(), weight)
    
        with torch.profiler.profile(record_shapes=True) as prof:
            (out,) = fd.execute([inp, weight], _enable_options=["host_ir_lowering"])
        torch.testing.assert_close(out, out_ref)
        broadcast_events = [
            event for event in prof.events() if "ncclDevKernel_Broadcast" in event.name
        ]
        assert len(broadcast_events) == d

@Priya2698 Priya2698 marked this pull request as ready for review February 9, 2026 21:10
@Priya2698 Priya2698 requested a review from wujingyue February 9, 2026 21:11

Priya2698 (Author):

!test


    greptile-apps bot commented Feb 9, 2026

    Greptile Overview

    Greptile Summary

    Implements broadcast-based allgather for DIDx→Stream resharding in host for-loops by introducing a new StreamBroadcast communication type. The implementation adds detection logic in getCommunicationInfo to identify when a producer is sharded on DIDx and consumer is sharded on Stream with matching logical IDs, then lowers this to a StreamBroadcast communication using the host loop index as the broadcast root.

    • Adds lowerToStreamBroadcast function to create StreamBroadcast communication IR with the for-loop index as root
    • Updates getCommunicationInfo to detect DIDx→Stream pattern and return StreamBroadcast communication type
    • Modifies convertSingleOpToCommunication signature to accept optional root parameter (required for StreamBroadcast)
    • Skips allocation sharding check for StreamBroadcast in lowerSegment since it reuses existing allocation pattern
    • Routes StreamBroadcast through existing postBroadcast handler in runtime communication code
    • Adds comprehensive tests including functional validation and benchmark for column-parallel linear forward pass

    Per the PR description, performance comparison is deferred until multicast integration is complete.

    Confidence Score: 4/5

    • This PR is safe to merge with proper testing of the new StreamBroadcast code path
    • The implementation is well-structured and follows existing patterns in the codebase. Previous compilation and pytest collection issues appear resolved. The main risk is that StreamBroadcast only works within host for-loops (will error otherwise), which is intentional per developer comments but requires careful usage. Performance characteristics are acknowledged as suboptimal until multicast integration.
    • No files require special attention - the implementation follows established patterns and includes appropriate error handling

    Important Files Changed

    Filename Overview
    csrc/host_ir/lower_to_communication.cpp Adds lowerToStreamBroadcast function and detection logic in getCommunicationInfo to handle DIDx→Stream resharding as broadcast-based allgather
    csrc/host_ir/lower_to_communication.h Adds optional root parameter to convertSingleOpToCommunication with default nullptr, documented for StreamBroadcast use case
    csrc/host_ir/lowering.cpp Passes loop index as root to convertSingleOpToCommunication and skips allocation sharding check for StreamBroadcast
    tests/python/multidevice/test_overlap.py Adds helper function and two tests (functional + benchmark) for column-parallel linear forward with broadcast-based allgather

    Sequence Diagram

    sequenceDiagram
        participant Lowering as lowerSegment
        participant Convert as convertSingleOpToCommunication
        participant Info as getCommunicationInfo
        participant Lower as lowerToStreamBroadcast
        participant Comm as Communication IR
        
        Lowering->>Convert: Call with expr, device_id, loop->index()
        Convert->>Info: getCommunicationInfo(expr)
        Info->>Info: Detect DIDx→Stream resharding
        Info->>Info: Check if c_stream_logical_id == p2c_map[p_logical_id]
        Info-->>Convert: Return StreamBroadcast type
        Convert->>Convert: Check root != nullptr for StreamBroadcast
        Convert->>Lower: lowerToStreamBroadcast(input_tv, output_tv, backend, comms, root)
        Lower->>Comm: Create Communication(StreamBroadcast, team, root)
        Comm-->>Convert: Return communication expr
        Convert-->>Lowering: Return vector of comms
        Lowering->>Lowering: Skip allocation sharding check for StreamBroadcast
    

    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 2 comments

    Priya2698 (Author):

    !test

    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 3 comments


    greptile-apps bot commented Feb 9, 2026

    Additional Comments (1)

    csrc/multidevice/communication.cpp
    Root validation rejects non-const

    Communication::validate only enforces the root/type contract when root() is a const integral scalar. For StreamBroadcast, the root is the host loop index (non-const), so the hasRoot(type()) contract is never checked, and invalid roots (e.g., non-integral, or negative at runtime) can slip through. This can lead to runtime failures when postBroadcast interprets the root.

    Consider extending validation to require root() be Index dtype for StreamBroadcast/rooted collectives even when not constant, and/or add runtime checks where the root is consumed.
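    One way to read the suggestion: validate the evaluated root at the point where it is consumed. A Python sketch of such a check follows; resolve_root and its signature are hypothetical and mirror the intent rather than the actual C++ validation path:

    ```python
    def resolve_root(root_val, team):
        """Hypothetical runtime validation for a non-const broadcast root.

        `root_val` is the evaluated host loop index and `team` is the list
        of device ranks in the communication. The checks reject exactly the
        cases the review flags: wrong dtype and out-of-range values.
        """
        if not isinstance(root_val, int) or isinstance(root_val, bool):
            raise TypeError(f"broadcast root must be an integer, got {root_val!r}")
        if not 0 <= root_val < len(team):
            raise ValueError(
                f"broadcast root {root_val} out of range for team of size {len(team)}"
            )
        return team[root_val]


    # A valid root resolves to the corresponding team rank; invalid ones fail loudly.
    assert resolve_root(1, [4, 5, 6, 7]) == 5
    ```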

    Priya2698 (Author):

    !test

    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 2 comments

    Comment on lines 54 to 58
    std::vector<Expr*> convertSingleOpToCommunication(
        Expr* c,
        DeviceIdxType my_device_idx,
        Val* host_loop_index = nullptr,
        const CommunicatorBackend backend = CommunicatorBackend::kNccl);
    greptile-apps bot:

    StreamBroadcast not handled

    convertSingleOpToCommunication now defaults host_loop_index to nullptr, but it can return CommunicationType::StreamBroadcast from getCommunicationInfo and then hard-errors if host_loop_index == nullptr (lower_to_communication.cpp:657-663). This makes the existing ConvertOpToCommunication pass (which calls it at top-level with nullptr) crash for any fusion that produces this DIDx→Stream resharding outside the host-loop lowering path.


    greptile-apps bot commented Feb 9, 2026

    Additional Comments (1)

    csrc/multidevice/communication.cpp
    Non-constant root accepted

    Communication::validate only checks root/type consistency when root() is a const integral scalar (communication.cpp:238-246). For StreamBroadcast, the root is intentionally a non-const Val* (host loop index), so this validation becomes a no-op: invalid roots (e.g., negative at runtime, wrong dtype) won’t be rejected here but later code assumes a valid rank/root. If StreamBroadcast relies on runtime root, it still needs a type/dtype/range validation path for non-const roots (at least DataType::Index and non-negative).

    @wujingyue wujingyue (Collaborator) left a comment

    It's great to see this work functionally!

           type == CommunicationType::SendRecv ||
    -      type == CommunicationType::AllToAll) {
    +      type == CommunicationType::AllToAll ||
    +      type == CommunicationType::StreamBroadcast) {
    wujingyue:

    I understood the motivation but can this be consolidated into the same Broadcast?

    Priya2698 (Author):

    I kept it separate so I don't need to check for the StreamParallel Type in lowerToBroadcast when deciding the root. Posting the communication uses a common function.
    I also wanted to first integrate SendRecv based decomposition and then reconsider the design based on what is needed for both these comms.

           "Destination allocation should be sharded on stream after "
           "shardAllocationAsLoop: ",
    -      destination);
    +      destination->domain()->toString(0, /*loop_only=*/false));
    wujingyue:

    I guess destination is still worth printing in addition to the domain?

       TensorView* in = communication->in();
       TensorView* out = communication->out();
    -  if (haveDifferentShardings(
    +  if (communication->type() != CommunicationType::StreamBroadcast &&
    wujingyue:

    While I understood the motivation and that the tests pass, I'm thinking how to make this cleaner.

    Is it possible to frame this as an optimization? For example, if in can be sharded on Stream in the same way as communication, insert a shardByStream.

    Priya2698 (Author):

    Yeah, I do think this should be merged into shardByStream or some other logic.
    For now, I kept it simple since I am not sure what it will look like with the Collective Permute representation (a composite Communication, P2P comms corresponding to SendRecv, etc.), so I took the verbose approach as an interim step.

    Let me see what I can do in this PR itself.


       // This ignores device dimensions on reduction axis.
    -  auto producer_pt_to_did =
    +  auto producer_pt_to_id =
    wujingyue:

    Suggested change:
    -  auto producer_pt_to_id =
    +  const std::unordered_map<ParallelType, IterDomain*>& producer_pt_to_id =

       auto producer_pt_to_id =
           mapDeviceAndStreamParallelTypeToId(producer->getLoopDomain());
    -  auto consumer_pt_to_did =
    +  auto consumer_pt_to_id =
    wujingyue:
    ditto

    std::vector<Expr*> convertSingleOpToCommunication(
        Expr* e,
        DeviceIdxType my_device_idx,
        Val* host_loop_index,
    wujingyue:

    Suggested change:
    -    Val* host_loop_index,
    +    Val* root,

    Some communications (e.g. broadcast, reduce, gather, and scatter) are rooted. So far, we've been deciding the root according to device meshes. However, this use makes a case for passing in the root from the lowering process.

    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, no comments
