Description
Hi, I have recently been working on the CuTe DSL examples.
In `hopper/dense_gemm_persistent.py`, the epilogue looks like this:
```python
num_prev_epi_tiles = tile_sched.num_tiles_executed * epi_tile_num
for epi_idx in cutlass.range_constexpr(epi_tile_num):
    # Copy from accumulators to D registers
    for epi_v in cutlass.range_constexpr(size_tRS_rD):
        tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v]

    # Type conversion
    acc_vec = tRS_rD.load()
    tRS_rD_out.store(acc_vec.to(self.c_dtype))

    # Copy from D registers to shared memory
    epi_buffer = (num_prev_epi_tiles + epi_idx) % cute.size(
        tRS_sD, mode=[3]
    )
    cute.copy(
        tiled_copy_r2s,
        tRS_rD_out,
        tRS_sD[(None, None, None, epi_buffer)],
    )

    cute.arch.fence_proxy(
        cute.arch.ProxyKind.async_shared,
        space=cute.arch.SharedSpace.shared_cta,
    )
    self.epilog_sync_barrier.arrive_and_wait()

    gmem_coord = epi_tile_layout.get_hier_coord(epi_idx)
    # Copy from shared memory to global memory
    if warp_idx == self.epi_store_warp_id:
        cute.copy(
            tma_atom_c,
            bSG_sD[(None, epi_buffer)],
            bSG_gD[(None, gmem_coord)],
        )
        tma_store_pipeline.producer_commit()
        tma_store_pipeline.producer_acquire()

    self.epilog_sync_barrier.arrive_and_wait()
```

Since the epilogue uses a multi-stage pipeline, `tma_store_pipeline.producer_acquire()` actually calls `cute.arch.cp_async_bulk_wait_group(num_stages-1)` under the hood. That wait only guarantees that all but the `num_stages - 1` most recently committed bulk async-groups of the calling thread have completed. This becomes a problem when I want another specialized warp to consume the results produced by the MMA warp groups: from that warp, I cannot tell whether the epilogue pipeline was initially empty or whether the store for a previous tile has already finished.
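To make the limitation concrete, here is a small host-side model in plain Python (not CuTe DSL; `BulkGroupModel` and `simulate_store_warp` are hypothetical names) of the store warp's bulk async-groups. It assumes that `producer_commit()` commits one group per sub-tile and that `producer_acquire()` behaves like a wait for all but the `num_stages - 1` newest groups, as described above. It also assumes each tile is numbered consecutively so that `num_prev_epi_tiles = tile * epi_tile_num`.

```python
from typing import Dict, List


class BulkGroupModel:
    """Host-side model of ONE thread's bulk async-groups (not CuTe DSL code).

    commit() closes a group holding the TMA stores issued since the previous
    commit. wait(n) mirrors `cp.async.bulk.wait_group n`: on return, every
    group except the n most recently committed ones is known to be complete.
    """

    def __init__(self) -> None:
        self.committed: List[int] = []  # group ids, in commit order
        self.num_known_done = 0  # length of the committed prefix known complete

    def commit(self, group_id: int) -> None:
        self.committed.append(group_id)

    def wait(self, n: int) -> List[int]:
        self.num_known_done = max(self.num_known_done, len(self.committed) - n)
        return self.committed[: self.num_known_done]


def simulate_store_warp(
    num_tiles: int, epi_tile_num: int, num_stages: int
) -> Dict[int, List[int]]:
    """Replays the store warp's producer_commit()/producer_acquire() sequence.

    Each sub-tile gets the global index num_prev_epi_tiles + epi_idx, the value
    the epilogue reduces modulo the stage count to pick epi_buffer. Returns, for
    each sub-tile, the sub-tile stores known complete right after its
    producer_acquire(), modelled as wait(num_stages - 1).
    """
    groups = BulkGroupModel()
    known_done: Dict[int, List[int]] = {}
    for tile in range(num_tiles):
        num_prev_epi_tiles = tile * epi_tile_num
        for epi_idx in range(epi_tile_num):
            subtile = num_prev_epi_tiles + epi_idx
            groups.commit(subtile)  # producer_commit() after the TMA store
            known_done[subtile] = groups.wait(num_stages - 1)  # producer_acquire()
    return known_done


done = simulate_store_warp(num_tiles=2, epi_tile_num=4, num_stages=4)
print(done[7])  # [0, 1, 2, 3, 4]: stores 5, 6 and 7 may still be in flight
```

After sub-tile `k`, the wait proves completion only up to store `k - num_stages + 1`, which is exactly the previous user of the shared-memory buffer written next, so buffer reuse is safe. Nothing is proven about the newest `num_stages - 1` stores, and because the groups belong to the issuing thread, a different warp cannot wait on them at all. (The model does not distinguish read completion from full write completion.)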
In addition, `tma_store_pipeline.producer_tail()` does ensure that all in-flight TMA store transactions have finished, but it is only called after all tiles are done, so the communication warp cannot work in parallel with the MMA warp groups.
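A back-of-the-envelope model shows what this costs. Assume, hypothetically, a fixed MMA time per tile and a fixed communication time per tile, and ignore TMA latency and contention for SM resources. The sketch below then compares a consumer that may only start after `producer_tail()` with one that is signalled per tile through `num_stages` handoff slots. `finish_times` is a made-up helper, not part of CUTLASS.

```python
from typing import Tuple


def finish_times(
    num_tiles: int, t_mma: float, t_comm: float, num_stages: int
) -> Tuple[float, float]:
    """Idealized finish time of the communication warp under two schemes.

    tail_only: the consumer starts only after producer_tail(), i.e. after all tiles.
    per_tile:  the consumer handles tile i as soon as it is signalled, and the
               producer hands off into a stage only after the consumer released it.
    """
    tail_only = num_tiles * t_mma + num_tiles * t_comm

    produced = [0.0] * num_tiles  # time tile i is handed to the consumer
    consumed = [0.0] * num_tiles  # time the consumer released tile i
    for i in range(num_tiles):
        mma_done = (produced[i - 1] if i > 0 else 0.0) + t_mma
        # Stage i % num_stages is free once tile i - num_stages was released.
        stage_free = consumed[i - num_stages] if i >= num_stages else 0.0
        produced[i] = max(mma_done, stage_free)
        prev_cons = consumed[i - 1] if i > 0 else 0.0
        consumed[i] = max(prev_cons, produced[i]) + t_comm
    return tail_only, consumed[-1]


print(finish_times(num_tiles=4, t_mma=3.0, t_comm=2.0, num_stages=2))  # (20.0, 14.0)
```

With per-tile signalling, the total approaches one MMA tile, plus `num_tiles - 1` times the slower of the two stages, plus one communication tile. The tail-only scheme always pays the full sum of both phases.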
Essentially, what I want to implement is something like this:
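```python
### Producer (Epilogue) ###
while work_tile.is_valid_tile:
    # Wait until the stage is empty
    producer.acquire()
    ## Epilogue, TMA store to GMEM
    cute.copy(tma_atom_c, bSG_sD, bSG_gD)
    # Signal the consumer that the tile is ready
    producer.commit()
    # Advance to the next tile (as in the persistent example)
    tile_sched.advance_to_next_work()
    work_tile = tile_sched.get_current_work()

### Consumer ###
while work_tile.is_valid_tile:
    # Wait until the tile is ready
    consumer.wait()
    ## Do something with the tile, such as communication
    # Release the tile
    consumer.release()
    tile_sched.advance_to_next_work()
    work_tile = tile_sched.get_current_work()
```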
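The semantics sketched above are the classic per-stage empty/full handshake. The following is a host-side analogue only, written with Python `threading` rather than CuTe DSL or GPU barriers. `FlagPipeline` and `run_pipeline` are hypothetical names, and the method names merely echo the pseudocode. Each stage has an "empty" flag that starts set and a "full" flag that starts clear.

```python
import threading
from typing import List


class FlagPipeline:
    """Host-side analogue of an N-stage producer/consumer handshake.

    Each stage has an "empty" flag (set at start) and a "full" flag (clear at
    start), modelled here with semaphores instead of GPU barriers.
    """

    def __init__(self, num_stages: int) -> None:
        self.num_stages = num_stages
        self.empty = [threading.Semaphore(1) for _ in range(num_stages)]
        self.full = [threading.Semaphore(0) for _ in range(num_stages)]
        self.slots: List[int] = [0] * num_stages

    def producer_acquire(self, stage: int) -> None:
        self.empty[stage].acquire()  # wait until the stage is empty

    def producer_commit(self, stage: int) -> None:
        self.full[stage].release()  # signal the consumer that the tile is ready

    def consumer_wait(self, stage: int) -> None:
        self.full[stage].acquire()  # wait until the tile is ready

    def consumer_release(self, stage: int) -> None:
        self.empty[stage].release()  # hand the stage back to the producer


def run_pipeline(num_tiles: int, num_stages: int) -> List[int]:
    pipe = FlagPipeline(num_stages)
    received: List[int] = []

    def producer() -> None:
        for tile in range(num_tiles):
            stage = tile % num_stages
            pipe.producer_acquire(stage)
            pipe.slots[stage] = tile * tile  # stand-in for the finished output tile
            pipe.producer_commit(stage)

    def consumer() -> None:
        for tile in range(num_tiles):
            stage = tile % num_stages
            pipe.consumer_wait(stage)
            received.append(pipe.slots[stage])  # stand-in for communication
            pipe.consumer_release(stage)

    workers = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    return received


print(run_pipeline(num_tiles=6, num_stages=2))  # [0, 1, 4, 9, 16, 25]
```

On the GPU, the subtle step is `producer.commit()`. For the consumer to read the tile from global memory, the commit would presumably have to happen only after the thread that issued the TMA store has waited for that store to fully complete (PTX `cp.async.bulk.wait_group` without the `.read` modifier), plus whatever fence the consumer's reads require. This is because bulk async-groups are tracked per issuing thread.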