3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -7,10 +7,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Added
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
* Added StreamingLLM-style attention sink support for FMHA FWD, including the qr_ks_vs, qr_async, and splitkv pipelines.
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added gpt-oss-style attention sink support for FMHA FWD, including the qr_ks_vs, qr_async, qr_async_trload, and splitkv pipelines (the sink softmax is sketched just below this list).
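
For context, the attention sink mechanism referenced by these entries can be sketched as follows; the notation is ours, not from the repository. Each head h owns one sink logit s_h that joins every query row's softmax normalization, and the sink's own probability mass is then discarded:

```latex
% Sketch: softmax over score row x_{i*} with a per-head sink logit s_h.
% The sink only enlarges the denominator; its own column is dropped.
\[
p_{ij} = \frac{e^{x_{ij} - m_i}}{\sum_{k} e^{x_{ik} - m_i} + e^{s_h - m_i}},
\qquad
m_i = \max\Bigl(\max_{k} x_{ik},\; s_h\Bigr).
\]
```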

### Changed

5 changes: 4 additions & 1 deletion example/ck_tile/01_fmha/example_fmha_fwd.cpp
@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
"Comma-separated list of length 'b'. If empty, no override.")
.insert("init_sink", "0", "value to init the output tensor sink value for validation");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
int init_sink_value = arg_parser.get_int("init_sink");

ck_tile::stream_config stream_config{nullptr,
true,
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
init_method,
seed,
do_validation,
init_sink_value,
stream_config,
json);
}
29 changes: 21 additions & 8 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -230,6 +230,7 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
    const void* sink_ptr = nullptr;        // Per-head attention sink logits (one per Q head);
                                           // nullptr disables attention sinks.

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
    const void* sink_ptr; // per-head attention sink logits; nullptr disables attention sinks

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
    const void* sink_ptr; // per-head attention sink logits; nullptr disables attention sinks

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr

const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
    const void* sink_ptr; // per-head attention sink logits; nullptr disables attention sinks

ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
@@ -519,6 +523,7 @@ struct fmha_batch_prefill_args
// 1) +
// kargs.kv_last_page_lens[b]
const void* seqstart_q_ptr;
    const void* sink_ptr; // per-head attention sink logits; nullptr disables attention sinks

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -627,7 +632,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
args.cu_seqlen_k_ptr,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -677,7 +683,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
args.cu_seqlen_k_ptr,
args.sink_ptr);
}
}();

@@ -837,7 +844,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q);
args.min_seqlen_q,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -882,7 +890,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
}();

@@ -949,7 +958,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -997,7 +1007,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
}();

@@ -1164,7 +1175,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -1220,7 +1232,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.sink_ptr);
}
}();
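
A minimal caller-side sketch of how the new `sink_ptr` field is intended to be filled, based on the runner changes further below; `use_sinks` and `sink_buf` are illustrative placeholders, and all other required `fmha_fwd_args` fields are elided:

```cpp
// Sketch only: enable attention sinks by pointing sink_ptr at nhead floats on the device.
fmha_fwd_args args{};
// ... q_ptr / k_ptr / v_ptr / o_ptr, sizes, strides, etc. filled as before ...
args.sink_ptr = use_sinks ? sink_buf.GetDeviceBuffer() // one sink logit per head
                          : nullptr;                   // nullptr keeps the standard softmax
```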

78 changes: 69 additions & 9 deletions example/ck_tile/01_fmha/fmha_fwd_runner.hpp
@@ -184,6 +184,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::string init_method,
uint32_t seed,
int do_validation,
int init_sink_value,
const ck_tile::stream_config& stream_config,
std::optional<std::string> json = std::nullopt)
{
@@ -527,6 +528,7 @@ fwd_result fmha_fwd_run(mode_enum mode,

ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
ck_tile::HostTensor<KDataType> k_host(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
@@ -609,6 +611,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
bias_host);
    }
    else if(init_method == "ni")
{
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
@@ -695,10 +698,15 @@ fwd_result fmha_fwd_run(mode_enum mode,

iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);

if(init_sink_value != 0)
{
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 100.f, next_seed()}(
sink_host);
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
@@ -743,6 +751,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
sink_buf.ToDevice(sink_host.data());
knew_buf.ToDevice(knew_host.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
@@ -971,7 +980,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();

if(init_sink_value != 0)
args.sink_ptr = sink_buf.GetDeviceBuffer();
else
args.sink_ptr = nullptr;
args.batch = batch;
args.seqlen_q = shape_seqlen_q; // unused in group mode
args.hdim_q = hdim_q;
@@ -1675,19 +1687,67 @@ fwd_result fmha_fwd_run(mode_enum mode,
mask.type == mask_enum::mask_top_left));
}
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
if(lse)
if(init_sink_value != 0)
{
ck_tile::
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
// Create extended tensor with sink token
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
{nhead, real_seqlen_q, real_seqlen_k + 1});

// Copy original attention scores and append sink values
for(auto i_h = 0; i_h < nhead; i_h++)
{
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
{
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
}
// Append sink token at the end of each row
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = scale_s_host * sink_host(i_h);
}
}

// Compute softmax on extended tensor
ck_tile::HostTensor<PDataType> p_extended(
{nhead, real_seqlen_q, real_seqlen_k + 1});

if(lse)
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
}
else
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_with_sinks_ref, p_extended, p_compute_element_func);
}

// Extract only the original columns (exclude sink token column)
p_host_ref.ForEach(
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
}
else
{
ck_tile::
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
// No sink tokens - compute softmax directly
if(lse)
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
}
else
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_host_ref, p_host_ref, p_compute_element_func);
}
}
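
For readers tracing the reference path above, the same per-row computation can be restated compactly with the standard library alone; the function and variable names here are ours, not part of ck_tile. The sink logit takes part in the running maximum and the denominator, and its probability is then dropped, which is what the s_with_sinks_ref / p_extended detour implements:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Softmax over one row of attention scores with an extra per-head sink logit.
// The sink contributes to the max and the denominator, but its own probability
// is discarded, so the returned row sums to less than 1 whenever sink > -inf.
std::vector<float> softmax_with_sink(const std::vector<float>& scores, float sink)
{
    float m = sink;
    for(float s : scores)
        m = std::max(m, s);

    float denom = std::exp(sink - m); // the sink's share of the denominator
    std::vector<float> p(scores.size());
    for(std::size_t j = 0; j < scores.size(); ++j)
    {
        p[j] = std::exp(scores[j] - m);
        denom += p[j];
    }
    for(float& v : p)
        v /= denom;
    return p;
}

int main()
{
    const std::vector<float> row{1.0f, 2.0f, 3.0f};
    for(float v : softmax_with_sink(row, /*sink=*/2.5f))
        std::printf("%f\n", v);
    return 0;
}
```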

if(p_drop > 0)
{
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
9 changes: 9 additions & 0 deletions example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
@@ -84,3 +84,12 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
# 1 1 1 1 1 1 1 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1

$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1

$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
18 changes: 14 additions & 4 deletions include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
@@ -77,6 +77,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
const void* k_ptr;
const void* v_ptr;
void* o_ptr;
        const void* sink_ptr; // per-head attention sink logits; nullptr disables attention sinks

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -332,12 +333,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
seqlen_q,
-1,
hdim_q,
@@ -485,12 +488,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
drop_seed_offset,
const void* sink_ptr = nullptr)
{
Kargs kargs{{q_ptr,
k_ptr,
v_ptr,
o_ptr,
sink_ptr,
-1, // seqlen will be updated by another pointer
-1, //
hdim_q,
@@ -701,6 +706,9 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
long_index_t batch_offset_o = 0;

const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
        // Per-head sink logit. When sink_ptr is null we substitute -inf, so exp(sink - m)
        // is exactly zero, the denominator is unchanged, and the standard softmax is recovered.
        const float sink_value = kargs.sink_ptr != nullptr
                                     ? *(static_cast<const float*>(kargs.sink_ptr) + i_nhead)
                                     : static_cast<float>(-numeric<half_t>::infinity());
#if 0 // we assume page_block_size=1 for now
const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
#endif
@@ -1111,7 +1119,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
dropout);
dropout,
sink_value);
}
else
{
kargs.kv_page_indices,
kargs.stride_k,
kargs.stride_v,
dropout);
dropout,
sink_value);
}
}();
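
A quick, standalone check (outside the kernel) of why the nullptr fallback above is a no-op: with a sink of negative infinity, its exponential is exactly zero, so the softmax denominator, and therefore the output, matches the ordinary softmax.

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

int main()
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    // exp(-inf) == 0: an absent sink adds nothing to the softmax denominator,
    // mirroring the sink_value fallback in FmhaBatchPrefillWithPagedKVCacheKernel.
    std::printf("exp(-inf) = %f\n", std::exp(neg_inf)); // prints 0.000000
    return 0;
}
```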
