diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a9b25b062a..14d91ad195d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". -* Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. +* Added streamingllm sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added FP8 KV cache support for FMHA batch prefill. +* Added gptoss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines. ### Changed diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 6f2616cae56..3d729a272de 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[]) .insert("kv_eff_lens", "", "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" - "Comma-separated list of length 'b'. If empty, no override."); + "Comma-separated list of length 'b'. If empty, no override.") + .insert("init_sink", "1", "value to init the output tensor sink value for validation"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); uint32_t seed = arg_parser.get_uint32("seed"); + int init_sink_value = arg_parser.get_int("init_sink"); ck_tile::stream_config stream_config{nullptr, true, @@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser) init_method, seed, do_validation, + init_sink_value, stream_config, json); } diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ba55d6d722a..ba0615d4a79 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -230,6 +230,7 @@ struct fmha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. (Used with padding) + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) + const void* sink_ptr; ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -519,6 +523,7 @@ struct fmha_batch_prefill_args // 1) + // kargs.kv_last_page_lens[b] const void* seqstart_q_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -627,7 +632,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_k_ptr); + args.cu_seqlen_k_ptr, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -677,7 +683,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.s_randval, args.drop_seed_offset, args.cu_seqlen_q_ptr, - args.cu_seqlen_k_ptr); + args.cu_seqlen_k_ptr, + args.sink_ptr); } }(); @@ -837,7 +844,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.window_size_right, args.sink_size, args.mask_type, - args.min_seqlen_q); + args.min_seqlen_q, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -882,7 +890,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } }(); @@ -949,7 +958,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -997,7 +1007,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.window_size_left, args.window_size_right, args.sink_size, - args.mask_type); + args.mask_type, + args.sink_ptr); } }(); @@ -1164,7 +1175,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.sink_ptr); } else { // create batch mode kernel arguments @@ -1220,7 +1232,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.sink_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 536fcb06922..332a078db7b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -184,6 +184,7 @@ fwd_result fmha_fwd_run(mode_enum mode, std::string init_method, uint32_t seed, int do_validation, + int init_sink_value, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { @@ -527,6 +528,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor sink_host({nhead}); ck_tile::HostTensor k_host( 0 < page_block_size ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) @@ -609,6 +611,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}( bias_host); } + else if(init_method == "ni") { ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -695,10 +698,15 @@ fwd_result fmha_fwd_run(mode_enum mode, iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); - + if(init_sink_value != 0) + { + ck_tile::FillUniformDistributionIntegerValue{30.f, 60.f, next_seed()}( + sink_host); + } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); @@ -743,6 +751,7 @@ fwd_result fmha_fwd_run(mode_enum mode, q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); + sink_buf.ToDevice(sink_host.data()); knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); @@ -971,7 +980,10 @@ fwd_result fmha_fwd_run(mode_enum mode, args.q_ptr = q_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer(); args.v_ptr = v_buf.GetDeviceBuffer(); - + if(init_sink_value != 0) + args.sink_ptr = sink_buf.GetDeviceBuffer(); + else + args.sink_ptr = nullptr; args.batch = batch; args.seqlen_q = shape_seqlen_q; // unused in group mode args.hdim_q = hdim_q; @@ -1675,19 +1687,67 @@ fwd_result fmha_fwd_run(mode_enum mode, mask.type == mask_enum::mask_top_left)); } const ck_tile::HostTensor masked_s_host_ref = s_host_ref; - if(lse) + if(init_sink_value != 0) { - ck_tile:: - reference_batched_softmax( - s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + // Create extended tensor with sink token + ck_tile::HostTensor s_with_sinks_ref( + {nhead, real_seqlen_q, real_seqlen_k + 1}); + + // Copy original attention scores and append sink values + for(auto i_h = 0; i_h < nhead; i_h++) + { + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c); + } + // Append sink token at the end of each row + s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h); + } + } + + // Compute softmax on extended tensor + ck_tile::HostTensor p_extended( + {nhead, real_seqlen_q, real_seqlen_k + 1}); + + if(lse) + { + ck_tile::reference_batched_softmax( + s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( + s_with_sinks_ref, p_extended, p_compute_element_func); + } + + // Extract only the original columns (exclude sink token column) + p_host_ref.ForEach( + [&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); }); } else { - ck_tile:: - reference_batched_softmax( + // No sink tokens - compute softmax directly + if(lse) + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( s_host_ref, p_host_ref, p_compute_element_func); + } } - if(p_drop > 0) { ck_tile::HostTensor randval_host_ref( diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh index 664c8254181..746ff8c0e1e 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -84,3 +84,12 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l # 1 1 1 1 1 1 1 1 1 1 # l=2/r=0(br) l=2/r=0/s=2(br) +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0 + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=2 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 \ No newline at end of file diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 73b6a329d18..9796991f37b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -77,6 +77,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const void* k_ptr; const void* v_ptr; void* o_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -332,12 +333,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel float p_drop, bool s_randval, std::variant, std::pair> - drop_seed_offset) + drop_seed_offset, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, seqlen_q, -1, hdim_q, @@ -485,12 +488,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel float p_drop, bool s_randval, std::variant, std::pair> - drop_seed_offset) + drop_seed_offset, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, -1, // seqlen will be updated by another pointer -1, // hdim_q, @@ -701,6 +706,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel long_index_t batch_offset_o = 0; const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); #if 0 // we assume page_block_size=1 for now const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; #endif @@ -1111,7 +1120,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.kv_page_indices, kargs.stride_k, kargs.stride_v, - dropout); + dropout, + sink_value); } else { @@ -1131,7 +1141,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.kv_page_indices, kargs.stride_k, kargs.stride_v, - dropout); + dropout, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 4dd99a6ea96..b117b8fbea8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -89,6 +89,7 @@ struct FmhaFwdKernel const void* k_ptr; const void* v_ptr; void* o_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -343,12 +344,14 @@ struct FmhaFwdKernel std::variant, std::pair> drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, seqlen_q, seqlen_k, hdim_q, @@ -490,7 +493,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -539,7 +543,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -591,7 +596,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -640,7 +646,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } template @@ -688,12 +695,14 @@ struct FmhaFwdKernel std::variant, std::pair> drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, -1, // seqlen will be updated by another pointer -1, // hdim_q, @@ -833,7 +842,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -878,7 +888,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -926,7 +937,8 @@ struct FmhaFwdKernel bool s_randval, const std::tuple& drop_seed_offset, const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr) + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -971,7 +983,8 @@ struct FmhaFwdKernel s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), cu_seqlen_q_ptr, - cu_seqlen_k_ptr); + cu_seqlen_k_ptr, + sink_ptr); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1093,10 +1106,8 @@ struct FmhaFwdKernel { // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); @@ -1107,6 +1118,10 @@ struct FmhaFwdKernel long_index_t batch_offset_randval = 0; long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { @@ -1525,7 +1540,6 @@ struct FmhaFwdKernel }(); BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { @@ -1564,7 +1578,8 @@ struct FmhaFwdKernel variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_value); } else { @@ -1581,7 +1596,8 @@ struct FmhaFwdKernel variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_value); } }(); @@ -1621,6 +1637,10 @@ struct FmhaFwdKernel constexpr bool PrefillCase = FmhaPipeline::kM0 > 64; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; @@ -2273,6 +2293,7 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + sink_value, smem_ptrk0, smem_ptrk1, smem_ptrv0, @@ -2289,7 +2310,8 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, - smem_ptr); + smem_ptr, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index b75b35fc1e8..f078f19dc69 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -123,6 +123,7 @@ struct FmhaFwdPagedKVKernel const void* k_ptr; const void* v_ptr; void* o_ptr; + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -328,12 +329,14 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, seqlen_q, seqlen_k, hdim_q, @@ -457,7 +460,8 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { return MakeKargsImpl(q_ptr, k_ptr, @@ -500,7 +504,8 @@ struct FmhaFwdPagedKVKernel window_size_left, window_size_right, sink_size, - mask_type); + mask_type, + sink_ptr); } template @@ -543,12 +548,14 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q) + ck_tile::index_t min_seqlen_q, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, o_ptr, + sink_ptr, -1, // seqlen will be updated by another pointer -1, // hdim_q, @@ -669,7 +676,8 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t window_size_right, ck_tile::index_t sink_size, ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q) + ck_tile::index_t min_seqlen_q, + const void* sink_ptr = nullptr) { return MakeKargsImpl(q_ptr, k_ptr, @@ -709,7 +717,8 @@ struct FmhaFwdPagedKVKernel window_size_right, sink_size, mask_type, - min_seqlen_q); + min_seqlen_q, + sink_ptr); } CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches) @@ -898,7 +907,6 @@ struct FmhaFwdPagedKVKernel // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); @@ -909,6 +917,10 @@ struct FmhaFwdPagedKVKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; index_t kv_l2p_offset = 0; + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { @@ -1348,7 +1360,8 @@ struct FmhaFwdPagedKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } else { @@ -1366,7 +1379,8 @@ struct FmhaFwdPagedKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index bd5cddb5260..25a8ce9c683 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -124,6 +124,7 @@ struct FmhaFwdSplitKVKernel const void* v_ptr; void* lse_acc_ptr; void* o_acc_ptr; + const void* sink_ptr; ck_tile::index_t batch; @@ -327,13 +328,15 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, lse_acc_ptr, o_acc_ptr, + sink_ptr, batch, seqlen_q, seqlen_k, @@ -455,13 +458,15 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t sink_size, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + const void* sink_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, v_ptr, lse_acc_ptr, o_acc_ptr, + sink_ptr, batch, -1, // seqlen_q will be updated by another pointer -1, // seqlen_k will be updated by another pointer @@ -530,7 +535,6 @@ struct FmhaFwdSplitKVKernel { kargs.init_logits_soft_cap(logits_soft_cap); } - return kargs; } @@ -615,6 +619,10 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_o_acc = 0; index_t kv_l2p_offset = 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache + const float sink_value = + kargs.sink_ptr != nullptr + ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s + : static_cast(-numeric::infinity()); if constexpr(kIsGroupMode) { @@ -698,7 +706,6 @@ struct FmhaFwdSplitKVKernel kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; } } - // for simplicity, batch stride we just modify the pointer const index_t i_nhead_k = (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk); @@ -1082,7 +1089,8 @@ struct FmhaFwdSplitKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } else { @@ -1102,7 +1110,8 @@ struct FmhaFwdSplitKVKernel variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_value); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 2102fe768f1..38f5c2e4559 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -92,6 +92,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; + static constexpr auto LOG2E = log2e_v; #endif static constexpr index_t kBlockPerCu = []() { @@ -196,7 +197,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { static_assert( std::is_same_v> && @@ -282,8 +284,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(__builtin_isinf_sign(sink_v) >= 0) + { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * LOG2E); +#else + set_tile(m, sink_v); +#endif + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); @@ -302,7 +316,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -887,7 +908,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -913,7 +935,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_idx, stride_k, stride_v, - dropout); + dropout, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index d55d0d93427..e471b8ddc49 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -163,7 +163,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + const float sink_v) const { static_assert( std::is_same_v> && @@ -227,8 +228,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(m, sink_v); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { if constexpr(kHasSink) @@ -258,7 +267,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -788,7 +804,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -812,7 +829,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 944d49a8aad..720247e8cf1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -164,7 +164,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { static_assert( std::is_same_v> && @@ -254,8 +255,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { @@ -285,7 +294,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } if(get_thread_local_1d_id() < kM0) { @@ -299,7 +315,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS return o_acc; } } - + if(i_split > 0) + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); + if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + } const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; // make sure the first tile is completely located in page-block (page-block size should be @@ -879,7 +904,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -905,7 +931,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 26a4cc905c7..c04dbee01ee 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -163,7 +163,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { static_assert( std::is_same_v> && @@ -227,8 +228,20 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) + { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else + set_tile(m, sink_v); +#endif + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { @@ -260,7 +273,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); @@ -272,6 +292,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } } + if(i_split > 0) + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); + if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + } const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; // make sure the first tile is completely located in page-block (page-block size should be @@ -797,7 +827,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -823,7 +854,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS variant_params, block_indices, kv_l2p_offset, - smem_ptr); + smem_ptr, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index fe825a370a0..b8e02038cf1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -166,7 +166,8 @@ struct BlockFmhaPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { static_assert( std::is_same_v> && @@ -230,8 +231,20 @@ struct BlockFmhaPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(__builtin_isinf_sign(sink_v) >= 0) + { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else + set_tile(m, sink_v); +#endif + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_window.get_window_origin(); @@ -265,7 +278,14 @@ struct BlockFmhaPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -786,7 +806,8 @@ struct BlockFmhaPipelineQRKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -809,7 +830,8 @@ struct BlockFmhaPipelineQRKSVS variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index f57b89cf9dd..bddb6db2cbd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -87,6 +87,7 @@ struct BlockFmhaPipelineQRKSVSAsync #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; + static constexpr auto LOG2E = log2e_v; #endif static constexpr index_t kBlockPerCu = []() { @@ -188,7 +189,8 @@ struct BlockFmhaPipelineQRKSVSAsync const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { static_assert( std::is_same_v> && @@ -274,8 +276,20 @@ struct BlockFmhaPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(__builtin_isinf_sign(sink_v) >= 0) + { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * LOG2E); +#else + set_tile(m, sink_v); +#endif + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); @@ -309,7 +323,14 @@ struct BlockFmhaPipelineQRKSVSAsync auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -880,7 +901,8 @@ struct BlockFmhaPipelineQRKSVSAsync const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { return operator()(q_dram_block_window_tmp, identity{}, @@ -903,7 +925,8 @@ struct BlockFmhaPipelineQRKSVSAsync variant_params, block_indices, smem_ptr, - dropout); + dropout, + sink_v); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 26662dafeb9..95e20d95277 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -148,7 +148,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload FmhaMask mask, PositionEncoding position_encoding, float scale_s, - void* smem_ptr) const + void* smem_ptr, + float sink_v) const { static_assert( std::is_same_v> && @@ -193,8 +194,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(__builtin_isinf_sign(sink_v) >= 0) + { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else + set_tile(m, sink_v); +#endif + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_block_window_tmp.get_window_origin(); const auto [logical_seqlen_k_start, logical_seqlen_k_end] = @@ -212,7 +225,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } store_tile(lse_acc_dram_window_tmp, lse_acc); } @@ -649,6 +669,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload FmhaMask mask, PositionEncoding position_encoding, float scale_s, + float sink_v, void* __restrict__ smem_ptrk0, void* __restrict__ smem_ptrk1, void* __restrict__ smem_ptrv0, @@ -698,8 +719,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); - clear_tile(l); + if(__builtin_isinf_sign(sink_v) >= 0) + { +#if CK_TILE_FMHA_FWD_FAST_EXP2 + set_tile(m, sink_v * C_LOG2E); +#else + set_tile(m, sink_v); +#endif + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } const auto q_origin = q_dram_block_window_tmp.get_window_origin(); const auto [logical_seqlen_k_start, logical_seqlen_k_end] = @@ -717,7 +750,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, -numeric::infinity()); + if(__builtin_isinf_sign(sink_v) >= 0) + { + set_tile(lse_acc, SMPLComputeDataType{sink_v}); + } + else + { + set_tile(lse_acc, -numeric::infinity()); + } store_tile(lse_acc_dram_window_tmp, lse_acc); } diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index b81fa88aa22..c59ee7a67d8 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -120,8 +120,8 @@ const ck_tile::stream_config stream_config{ 1, // rotating_count_ }; -#define COMMON_ARGS \ - init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \ +#define COMMON_ARGS \ + init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, 0, \ stream_config auto EnableTestIf(bool condition) @@ -255,6 +255,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, // init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -299,6 +300,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, // init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); } @@ -342,6 +344,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail) init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 0, + 1, // init_sink stream_config); ASSERT_EQ(result, fwd_result::invalid_args); }