
[QST] How to build a pipeline between epilogue and communication warp groups? #2900

@Gin-Sin

Hi, I have recently been working through the CuTe DSL examples.

In hopper/dense_gemm_persistent.py, the epilogue looks like this:

                num_prev_epi_tiles = tile_sched.num_tiles_executed * epi_tile_num
                for epi_idx in cutlass.range_constexpr(epi_tile_num):
                    # Copy from accumulators to D registers
                    for epi_v in cutlass.range_constexpr(size_tRS_rD):
                        tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]

                    # Type conversion
                    acc_vec = tRS_rD.load()
                    tRS_rD_out.store(acc_vec.to(self.c_dtype))

                    # Copy from D registers to shared memory
                    epi_buffer = (num_prev_epi_tiles + epi_idx) % cute.size(
                        tRS_sD, mode=[3]
                    )
                    cute.copy(
                        tiled_copy_r2s,
                        tRS_rD_out,
                        tRS_sD[(None, None, None, epi_buffer)],
                    )

                    cute.arch.fence_proxy(
                        cute.arch.ProxyKind.async_shared,
                        space=cute.arch.SharedSpace.shared_cta,
                    )
                    self.epilog_sync_barrier.arrive_and_wait()

                    gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
                    
                    # Copy from shared memory to global memory
                    if warp_idx == self.epi_store_warp_id:
                        cute.copy(
                            tma_atom_c,
                            bSG_sD[(None, epi_buffer)],
                            bSG_gD[(None, gmem_coord)],
                        )
                        tma_store_pipeline.producer_commit()
                        tma_store_pipeline.producer_acquire()

                    self.epilog_sync_barrier.arrive_and_wait()

Since the example uses a multi-stage epilogue pipeline, tma_store_pipeline.producer_acquire() actually calls cute.arch.cp_async_bulk_wait_group(num_stages-1) under the hood. This becomes complicated when I want another specialized warp to consume the results produced by the MMA warp groups: when producer_acquire() returns, I cannot tell whether the epilogue pipeline was still initially empty or a previous tile's store has actually finished.
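To make the ambiguity concrete, here is a minimal host-side Python model of the bookkeeping. It is not CuTe DSL code. It assumes that a wait on N bulk groups returns once at most N of the most recently committed groups are still pending, and it tracks only what is guaranteed to be complete (real hardware may finish earlier). `BulkGroupModel` and `guaranteed_done_per_step` are hypothetical names used only for illustration.

```python
from collections import deque


class BulkGroupModel:
    """Toy model of commit / wait_group(N) bookkeeping (illustration only)."""

    def __init__(self):
        self.pending = deque()   # committed groups not yet known to be done
        self.completed = []      # groups guaranteed done after some wait

    def commit(self, tag):
        # Stand-in for committing one TMA store as a bulk group.
        self.pending.append(tag)

    def wait_group(self, n):
        # Assumed semantics: return once at most n groups are still pending.
        # The oldest groups are the ones that get retired.
        while len(self.pending) > n:
            self.completed.append(self.pending.popleft())
        return list(self.completed)


def guaranteed_done_per_step(num_epi_tiles, num_stages):
    """For each epilogue step, list the tiles guaranteed stored after the wait."""
    model = BulkGroupModel()
    history = []
    for tile in range(num_epi_tiles):
        model.commit(tile)                                 # like producer_commit()
        history.append(model.wait_group(num_stages - 1))   # like producer_acquire()
    return history


if __name__ == "__main__":
    for step, done in enumerate(guaranteed_done_per_step(6, 4)):
        print(f"after step {step}: guaranteed done = {done}")
```

With 4 stages, the first three waits in this model guarantee nothing, and the wait after step k (for k >= 3) only guarantees tiles 0 through k-3. The wait itself therefore never says that the tile just committed is ready for a consumer.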

In addition, tma_store_pipeline.producer_tail() does ensure that all in-flight TMA store transactions have finished, but it is only called after all tiles are done, which means the communication warp cannot run in parallel with the MMA warp groups.

Essentially, what I want to implement is something like this:

### Producer(Epilogue) ###
while work_tile.is_valid_tile:
    # Wait until the flag is empty
    producer.acquire()

    ## Epilogue, TMA store to GMEM
    cute.copy(tma_atom_c, bSG_sD, bSG_gD)
   
    # Signal to the consumer that the tile is ready.
    producer.commit()

### Consumer ###
while work_tile.is_valid_tile:
    # Wait until the tile is ready.
    consumer.wait()

    ## Do something with the tile, such as communications

    # Release the tile
    consumer.release()
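The handshake I have in mind can be sketched on the host with plain Python threads. This is only a model of the intended acquire/commit/wait/release ordering, not CuTe DSL code: the semaphores stand in for per-stage empty/full barriers, and `run_handshake`, `slots`, and the stage count are hypothetical. On the GPU, the commit step would additionally have to come after the TMA store is known to be complete, which is exactly the part I do not know how to express.

```python
import threading


def run_handshake(num_tiles=5, num_stages=2):
    """Model a producer (epilogue) and consumer (communication) handshake."""
    # One "empty" and one "full" flag per stage; all stages start empty.
    empty = [threading.Semaphore(1) for _ in range(num_stages)]
    full = [threading.Semaphore(0) for _ in range(num_stages)]
    slots = [None] * num_stages   # stand-in for the output tiles in GMEM
    received = []

    def producer():
        for tile in range(num_tiles):
            stage = tile % num_stages
            empty[stage].acquire()   # producer.acquire(): wait until empty
            slots[stage] = tile      # stand-in for epilogue + finished store
            full[stage].release()    # producer.commit(): tile is ready

    def consumer():
        for tile in range(num_tiles):
            stage = tile % num_stages
            full[stage].acquire()          # consumer.wait(): wait until ready
            received.append(slots[stage])  # stand-in for the communication
            empty[stage].release()         # consumer.release(): free the stage

    threads = [threading.Thread(target=producer),
               threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return received


if __name__ == "__main__":
    print(run_handshake())
```

In this model the consumer sees every tile exactly once and in order, and the producer can run up to num_stages tiles ahead, which is the overlap I want between the epilogue and the communication warp.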
