@@ -71,7 +71,6 @@
*
* **Architecture:**
* - Uses TensorDescriptorUtils for stride-aware descriptor creation
* - Custom RunGemm implementation with descriptor-based tensor views
* - Reuses GemmPipeline and EpiloguePipeline for computation
* - Split-K support via UniversalGemmKernel utilities
*/
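For context, a minimal sketch of the kind of stride-aware descriptor the first architecture bullet refers to. The names M_total, K_total, stride_M, stride_K and a_ptr are placeholders, and the real logic lives in TensorDescriptorUtils rather than in the snippet below:

// Sketch only: build a 2D descriptor over the flattened M x K slice of A, carrying
// arbitrary element strides, then wrap the raw pointer in a global-memory tensor view.
// M_total, K_total, stride_M, stride_K and a_ptr are hypothetical inputs.
const auto a_grid_desc_m_k = ck_tile::make_naive_tensor_descriptor(
    ck_tile::make_tuple(M_total, K_total),     // flattened lengths
    ck_tile::make_tuple(stride_M, stride_K));  // element strides, need not be contiguous
const auto a_tensor_view = ck_tile::make_tensor_view<ck_tile::address_space_enum::global>(
    a_ptr, a_grid_desc_m_k);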
@@ -375,104 +374,6 @@ struct BatchedContractionKernel
TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
}

/// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride
/// support
///
/// @details This function performs the core GEMM computation using tensor descriptors to handle
/// arbitrary multi-dimensional stride patterns. It creates tensor views from
/// pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
/// and executes the GemmPipeline and EpiloguePipeline.
///
/// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied)
/// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied)
/// @param ds_ptr Array of pointers to auxiliary D tensor data
/// @param e_ptr Pointer to output tensor E data (after batch offset applied)
/// @param smem_ptr Pointer to shared memory for tile operations
/// @param kargs Kernel arguments containing tensor descriptors and dimension information
/// @param k_size Size of K dimension for this split (for split-K support)
/// @param i_m Starting M index for this block's tile
/// @param i_n Starting N index for this block's tile
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m,
const index_t i_n)
{
// Create tensor views from descriptors (supports arbitrary stride patterns)
auto a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
auto b_tensor_view =
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
auto e_tensor_view =
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);

// Pad views for boundary handling and optimization (like UniversalGemmKernel)
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});

auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});

auto e_pad_view = pad_tensor_view(
e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});

// Create tile windows from PADDED views
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});

auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});

auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});

// Calculate number of K loops
const index_t num_loop =
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));

// Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;

const auto& c_block_tile = GemmPipeline{}(
a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);

// Create D windows from descriptors (for each D tensor)
auto ds_block_windows = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);

auto d_tensor_view =
make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);

return make_tile_window(d_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
},
number<NumDTensor>{});

// Run Epilogue Pipeline with descriptor-based D windows
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
}
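The pad views above guard only the K edge of A and B and the N edge of E, and only when the corresponding GemmPipeline::kPadK / kPadN flags are set; conceptually, padding rounds the guarded extent up to the next multiple of the block tile so the pipeline can always issue full tile loads. A standalone arithmetic sketch of that rounding (not the library's pad_tensor_view implementation):

#include <cstdint>

// Round a dimension up to the next multiple of the block tile size.
constexpr std::int64_t pad_to_tile(std::int64_t length, std::int64_t tile)
{
    return ((length + tile - 1) / tile) * tile;
}

static_assert(pad_to_tile(1000, 64) == 1024); // a 1000-wide K dim behaves as 1024
static_assert(pad_to_tile(1024, 64) == 1024); // already aligned extents are unchanged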

CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
{
@@ -671,18 +572,28 @@ struct BatchedContractionKernel
i_splitk);

// Apply K-split offsets and run descriptor-based RunGemm
const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];

RunGemm(a_ptr_split,
b_ptr_split,
ds_batch_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset.splitted_k,
i_m,
i_n);
const std::array<const ADataType*, number<1>{}> a_ptr_split = {
a_ptr + splitk_batch_offset.as_k_split_offset[0]};
const std::array<const BDataType*, number<1>{}> b_ptr_split = {
b_ptr + splitk_batch_offset.bs_k_split_offset[0]};

const std::array<typename KernelArgs::AGridDesc_M_K_, number<1>{}> a_grid_desc = {
kargs.a_grid_desc_m_k};
const std::array<typename KernelArgs::BGridDesc_N_K_, number<1>{}> b_grid_desc = {
kargs.b_grid_desc_n_k};

UniversalGemmKernel::RunGemmDesc(a_ptr_split,
b_ptr_split,
ds_batch_ptr,
e_ptr,
smem_ptr,
splitk_batch_offset,
i_m,
i_n,
a_grid_desc,
b_grid_desc,
kargs.ds_grid_desc_m_n,
kargs.e_grid_desc_m_n);
}
};

124 changes: 104 additions & 20 deletions include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
@@ -936,6 +936,84 @@ struct UniversalGemmKernel
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
}

// Version of RunGemm that takes pre-built tensor descriptors.
// FIXME: Currently templated on the XsList container types so that both arrays and tuples are
// accepted for convenience; this enforces neither matching sizes nor matching element types
// (as fixed-size arrays would).
template <typename AsList,
typename BsList,
typename DsList,
typename AGridDescs,
typename BGridDescs,
typename DGridDescs,
typename EGridDesc,
bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunGemmDesc(const AsList& as_ptr,
const BsList& bs_ptr,
const DsList& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_0,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n,
const AGridDescs& as_desc,
const BGridDescs& bs_desc,
const DGridDescs& ds_desc,
const EGridDesc& e_desc)
{
// Create tensor views from descriptors (supports arbitrary stride patterns)
const auto& as_tensor_view = generate_tuple(
[&](auto i) {
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
},
number<NumATensor>{});

const auto& bs_tensor_view = generate_tuple(
[&](auto i) {
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
},
number<NumBTensor>{});

const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
return make_tensor_view<address_space_enum::global>(
static_cast<const DiDataType*>(ds_ptr[i]), ds_desc[i]);
},
number<NumDTensor>{});

auto e_tensor_view =
make_tensor_view<address_space_enum::global>(static_cast<EDataType*>(e_ptr), e_desc);

const auto& gemm_tensors_views_tuple =
make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);

const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensors_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(I0);
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);

const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);

if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);

EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
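One lightweight way to address the size mismatch flagged in the FIXME above would be a compile-time check at the top of RunGemmDesc; this is only a sketch and assumes the pointer containers are std::array or std::tuple (so that std::tuple_size applies), leaving element types still unchecked:

// Sketch: require one entry per logical A/B/D tensor (needs <array>/<tuple> for std::tuple_size_v).
static_assert(std::tuple_size_v<AsList> == NumATensor &&
                  std::tuple_size_v<BsList> == NumBTensor &&
                  std::tuple_size_v<DsList> == NumDTensor,
              "RunGemmDesc expects exactly one pointer per A, B and D tensor");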

/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
@@ -961,32 +1039,38 @@
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);

const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

const index_t num_loop =
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(I0);
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
// FIXME: Refactor to generate descriptors and views separately, then rework signatures
// FIXME: pointers need to be extracted as well
// FIXME: Fails (at least) 1024x1024x256_splitk2 and 1024x1024x256_splitk4 in
// test_gemm_tile_engine_fp16_rcr_quick_coverage_config_compv3_cshuffle_intrawave_False_False_False_False_32x64x16_2x2x1_16x16x16

const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);

if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
auto as_desc = generate_tuple(
[&](auto i) { return gemm_tensor_views_tuple.at(I0)[i].get_tensor_descriptor(); },
number<NumATensor>{});
auto bs_desc = generate_tuple(
[&](auto i) { return gemm_tensor_views_tuple.at(I1)[i].get_tensor_descriptor(); },
number<NumBTensor>{});
auto ds_desc = generate_tuple(
[&](auto i) { return gemm_tensor_views_tuple.at(I2)[i].get_tensor_descriptor(); },
number<NumDTensor>{});
auto e_desc = gemm_tensor_views_tuple.at(I3).get_tensor_descriptor();

EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
RunGemmDesc(as_ptr,
bs_ptr,
ds_ptr,
e_ptr,
smem_ptr_0,
splitk_batch_offset,
block_idx_m,
block_idx_n,
as_desc,
bs_desc,
ds_desc,
e_desc);
}

/**