Skip to content

Conversation

@Priya2698
Copy link
Collaborator

No description provided.

@Priya2698
Copy link
Collaborator Author

!test

@github-actions
Copy link

Description

  • Replace manual split() calls with outer_split() API in three test functions

  • Remove redundant device mesh and splitting operations on output tensors

  • Simplify sharding logic by using higher-level outer_split method

  • Clean up test code for better maintainability and readability

Changes walkthrough

Relevant files
Enhancement
test_multidevice_sharding.cpp
Refactor manual sharding to use outer_split API                   

tests/cpp/test_multidevice_sharding.cpp

  • Replace split(-1, d, false) with outer_split(-1, d) in
    LoopShardedSplitReshapeIds test
  • Remove manual device mesh and splitting operations on output tensor
    tv1
  • Replace split(-2, d, false) with outer_split(-2, d) in
    LoopShardedMergeReshapeIds test
  • Remove manual device mesh and splitting operations on output tensor
    tv1
  • Replace split(0, d, false) with outer_split(0, d) in
    MultipleTransformReshape test
  • +3/-11   

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Correctness Validation

    The PR removes manual output tensor sharding operations for tv1 in two tests (LoopShardedSplitReshapeIds and LoopShardedMergeReshapeIds). This could potentially break correctness if the framework doesn't automatically handle output tensor sharding for reshape operations. Need to verify that the tests still pass and produce correct results without explicit output sharding.

    fusion->addInput(tv0);
    fusion->addOutput(tv1);
    
    auto mesh = DeviceMesh::createForNumDevices(d);
    
    tv0->setDeviceMesh(mesh);
    tv0->outer_split(-1, d);
    tv0->axis(-2)->parallelize(ParallelType::DIDx);
    
    API Behavior Change

    The PR replaces split(-1, d, /*inner_split=*/false) with outer_split(-1, d) and similar changes in other tests. Need to confirm that outer_split provides identical behavior to the previous manual sharding approach, particularly regarding how the split dimension is handled and parallelized.

      tv0->outer_split(-1, d);
      tv0->axis(-2)->parallelize(ParallelType::DIDx);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor inp = at::randn({b, s, d * h * e}, tensor_options_);
      at::Tensor sharded_inp = shardTensor1D(inp, -1, mesh);
    
      at::Tensor nvf_out =
          executor_cache.runFusionWithInputs({sharded_inp})[0].as<at::Tensor>();
      testValidate(
          executor_cache.fusion(),
          {nvf_out},
          {sharded_inp},
          {sharded_inp.view({b, s, h, e})},
          __LINE__,
          __FILE__);
    }
    
    TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const int d = communicator_->size();
      const int64_t b = 2, s = 3, h = 8, e = 4;
    
      TensorView* tv0 = makeContigConcreteTensor({b, s, d * h, e});
      TensorView* tv1 = reshape(tv0, {b, s, d * h, e}, {b, s, d * h * e});
    
      fusion->addInput(tv0);
      fusion->addOutput(tv1);
    
      auto mesh = DeviceMesh::createForNumDevices(d);
      tv0->setDeviceMesh(mesh);
      tv0->outer_split(-2, d);
      tv0->axis(-3)->parallelize(ParallelType::DIDx);

    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Feb 12, 2026

    Greptile Overview

    Greptile Summary

    Cleaned up three reshape-related multidevice tests by replacing the verbose split(axis, factor, /*inner_split=*/false) calls with the more concise outer_split(axis, factor) helper method and removed redundant manual sharding configuration on output tensors (tv1), which now relies on automatic sharding propagation through reshape operations.

    Confidence Score: 5/5

    • This PR is safe to merge with minimal risk
    • The changes are purely refactoring: outer_split() is a direct wrapper around split(..., inner_split=false) confirmed in interface_nodes.h, and removing manual sharding on output tensors simplifies the tests by leveraging automatic sharding propagation, which is the intended behavior for reshape operations
    • No files require special attention

    Important Files Changed

    Filename Overview
    tests/cpp/test_multidevice_sharding.cpp Replaced split(..., inner_split=false) with cleaner outer_split() API and removed redundant manual sharding on reshaped output tensors, relying on automatic sharding propagation

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    1 file reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant