diff --git a/include/infinicore/ops/flash_attention.hpp b/include/infinicore/ops/flash_attention.hpp
index 33084aa45..24e33cfb6 100644
--- a/include/infinicore/ops/flash_attention.hpp
+++ b/include/infinicore/ops/flash_attention.hpp
@@ -5,8 +5,8 @@ namespace infinicore::op {
 
-INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, Tensor, Tensor, Tensor, std::size_t, float, bool);
+INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool);
 
-Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal);
-void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal);
+Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
+void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
 
 } // namespace infinicore::op
diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h
index 2bcb9fe77..5ea71335b 100644
--- a/include/infiniop/ops/flash_attention.h
+++ b/include/infiniop/ops/flash_attention.h
@@ -12,7 +12,7 @@ __C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
     infiniopTensorDescriptor_t q_desc,
     infiniopTensorDescriptor_t k_desc,
     infiniopTensorDescriptor_t v_desc,
-    std::size_t total_kv_len,
+    infiniopTensorDescriptor_t total_kv_len,
     float scale,
     char is_causal);
 
@@ -28,6 +28,7 @@ __C __export infiniStatus_t infiniopFlashAttention(
     const void *q,
     const void *k,
     const void *v,
+    const void *total_kv_len,
     void *stream);
 
 __C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
index f8c7d6ef0..d34490365 100644
--- a/python/infinicore/nn/functional/__init__.py
+++ b/python/infinicore/nn/functional/__init__.py
@@ -1,5 +1,6 @@
 from .causal_softmax import causal_softmax
 from .embedding import embedding
+from .flash_attention import flash_attention
 from .linear import linear
 from .random_sample import random_sample
 from .rms_norm import rms_norm
@@ -10,13 +11,14 @@
 
 __all__ = [
     "causal_softmax",
+    "embedding",
+    "flash_attention",
+    "linear",
     "random_sample",
     "rms_norm",
+    "rope",
     "scaled_dot_product_attention",
     "silu",
     "swiglu",
-    "linear",
-    "embedding",
-    "rope",
     "RopeAlgo",
 ]
diff --git a/python/infinicore/nn/functional/flash_attention.py b/python/infinicore/nn/functional/flash_attention.py
new file mode 100644
index 000000000..8f42e865f
--- /dev/null
+++ b/python/infinicore/nn/functional/flash_attention.py
@@ -0,0 +1,34 @@
+import math
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def flash_attention(
+    query,
+    key,
+    value,
+    total_kv_len,
+    attn_mask=None,
+    dropout_p=0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+):
+    assert attn_mask is None and dropout_p == 0 and not enable_gqa
+
+    emb_dim = query.shape[-1]
+
+    if scale is None:
+        scale = 1 / math.sqrt(emb_dim)
+
+    return Tensor(
+        _infinicore.flash_attention(
+            query._underlying,
+            key._underlying,
+            value._underlying,
+            total_kv_len._underlying,
+            scale,
+            is_causal,
+        )
+    )
diff --git a/python/infinicore/nn/functional/scaled_dot_product_attention.py b/python/infinicore/nn/functional/scaled_dot_product_attention.py
index cc43e890f..0b780e562 100644
--- a/python/infinicore/nn/functional/scaled_dot_product_attention.py
+++ b/python/infinicore/nn/functional/scaled_dot_product_attention.py
@@ -14,6 +14,8 @@ def scaled_dot_product_attention(
     scale=None,
     enable_gqa=False,
 ):
+    raise NotImplementedError("Scaled Dot Product Attention is not yet supported.")
+
     assert attn_mask is None and dropout_p == 0 and not enable_gqa
 
     emb_dim = query.shape[-1]
diff --git a/src/infinicore/ops/flash_attention/flash_attention.cc b/src/infinicore/ops/flash_attention/flash_attention.cc
index 92a854710..70fa8125d 100644
--- a/src/infinicore/ops/flash_attention/flash_attention.cc
+++ b/src/infinicore/ops/flash_attention/flash_attention.cc
@@ -6,24 +6,24 @@ namespace infinicore::op {
 
 INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention);
 
-FlashAttention::FlashAttention(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
     INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
     INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, q, k, v, total_kv_len, scale, is_causal);
 }
 
-void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal);
 }
 
-Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
     Shape shape = q->shape();
     auto out = Tensor::empty(shape, q->dtype(), q->device());
     flash_attention_(out, q, k, v, total_kv_len, scale, is_causal);
     return out;
 }
 
-void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
     FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal);
 }
 
 } // namespace infinicore::op
diff --git a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
index b714744f0..f5207f0ee 100644
--- a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
+++ b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
@@ -11,18 +11,17 @@ INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100);
 
 struct PlannedMeta {
     std::shared_ptr descriptor;
-    graph::GraphTensor workspace, out, q, k, v;
-    std::size_t total_kv_len;
+    graph::GraphTensor workspace, out, q, k, v, total_kv_len;
     float scale;
     bool is_causal;
 };
 
-void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
     size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal);
 
     INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
         Descriptor, descriptor, FlashAttention,
-        seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len, scale, is_causal);
+        seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal);
 
     INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor);
 
@@ -33,7 +32,7 @@ void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, f
         graph::GraphTensor(q),
         graph::GraphTensor(k),
         graph::GraphTensor(v),
-        total_kv_len, scale, is_causal};
+        graph::GraphTensor(total_kv_len), scale, is_causal};
 
     return planned;
 }
@@ -43,7 +42,7 @@ void run(void *planned_meta) {
     INFINICORE_CHECK_ERROR(infiniopFlashAttention(
         planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
-        planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), context::getStream()));
+        planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream()));
 }
 
 void cleanup(void **planned_meta_ptr) {
diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py
index dfcce6910..abb75a1b0 100644
--- a/src/infiniop/ops/flash_attention/ninetoothed/build.py
+++ b/src/infiniop/ops/flash_attention/ninetoothed/build.py
@@ -1,6 +1,6 @@
 import ninetoothed
-from ntops.kernels import scaled_dot_product_attention
-from ntops.kernels.scaled_dot_product_attention import CausalVariant
+from . import flash_attention
+from .flash_attention import CausalVariant
 
 import infiniop.ninetoothed.build
 
@@ -27,7 +27,7 @@ def build():
     }
 
     infiniop.ninetoothed.build.build(
-        scaled_dot_product_attention.premake,
+        flash_attention.premake,
         constexpr_param_grid,
         caller="cuda",
         op_name="flash_attention",
diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
index f39d9d045..d47a347e1 100644
--- a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
+++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
@@ -17,7 +17,7 @@ class Descriptor final : public InfiniopDescriptor {
         infiniopTensorDescriptor_t q_desc,
         infiniopTensorDescriptor_t k_desc,
         infiniopTensorDescriptor_t v_desc,
-        std::size_t total_kv_len,
+        infiniopTensorDescriptor_t total_kv_len,
         double scale,
         char is_causal)
         : InfiniopDescriptor{handle->device, handle->device_id},
          _query_shape{q_desc->shape()},
@@ -26,12 +26,12 @@ class Descriptor final : public InfiniopDescriptor {
          _key_strides{k_desc->strides()},
          _value_shape{v_desc->shape()},
          _value_strides{v_desc->strides()},
+          _total_kv_shape{total_kv_len->shape()},
+          _total_kv_strides{total_kv_len->strides()},
          _output_strides{out_desc->strides()},
          _dtype{q_desc->dtype()},
          _scale{scale},
          _is_causal{is_causal} {
-        _key_shape[_key_shape.size() - 2] = total_kv_len;
-        _value_shape[_key_shape.size() - 2] = total_kv_len;
     }
 
     ~Descriptor() = default;
@@ -46,6 +46,7 @@ class Descriptor final : public InfiniopDescriptor {
         const void *q,
         const void *k,
         const void *v,
+        const void *total_kv_len,
         void *stream) const {
         uint64_t empty_shape[4];
         int64_t empty_strides[4];
@@ -53,6 +54,7 @@
         auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
         auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
         auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
+        auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
 
         NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
         NineToothedTensor is_causal;
@@ -75,6 +77,7 @@
            query,
            key,
            value,
+            total_kv_length,
            attn_mask,
            is_causal,
            scale,
@@ -101,7 +104,7 @@ class Descriptor final : public InfiniopDescriptor {
         infiniopTensorDescriptor_t q_desc,
         infiniopTensorDescriptor_t k_desc,
         infiniopTensorDescriptor_t v_desc,
-        std::size_t total_kv_len,
+        infiniopTensorDescriptor_t total_kv_len,
         double scale,
         char is_causal) {
         *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
@@ -126,6 +129,10 @@ class Descriptor final : public InfiniopDescriptor {
 
     std::vector _value_strides;
 
+    std::vector _total_kv_shape;
+
+    std::vector _total_kv_strides;
+
     std::vector _output_strides;
 
     infiniDtype_t _dtype;
diff --git a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py
new file mode 100644
index 000000000..965408408
--- /dev/null
+++ b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py
@@ -0,0 +1,281 @@
+import enum
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+BLOCK_SIZE_M = ninetoothed.block_size()
+BLOCK_SIZE_N = ninetoothed.block_size()
+
+
+class CausalVariant(enum.IntEnum):
+    """Please refer to ``_."""
+
+    UPPER_LEFT = enum.auto()
+
+    LOWER_RIGHT = enum.auto()
+
+
+def arrangement(
+    query,
+    key,
+    value,
+    total_kv_len,
+    present_key,
+    present_value,
+    present_key_slot,
+    present_value_slot,
+    attn_mask,
+    is_causal,
+    scale,
+    output,
+    with_attn_mask,
+    causal_variant,
+    with_kv_cache,
+    block_size_m=None,
+    block_size_n=None,
+):
+    def arrange_query_or_output(input):
+        arranged = input.tile((1, 1, block_size_m, -1)).tile(
+            (1, query.shape[-3] // key.shape[-3], 1, 1)
+        )
+        arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    def arrange_key_or_value(input):
+        arranged = (
+            input.tile((1, 1, block_size_n, -1))
+            .tile((1, 1, -1, -1))
+            .expand((-1, -1, query_arranged.shape[-2], -1))
+        )
+        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    def arrange_total_kv_len(input, shape):
+        arranged = input.tile((1,))
+        arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape)
+        return arranged
+
+    def arrange_present_key_or_present_value(input):
+        arranged = input.tile((1, 1, block_size_m, block_size_n))
+        arranged.dtype = arranged.dtype.squeeze((0, 1))
+
+        return arranged
+
+    def arrange_attn_mask(input):
+        arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1))
+        arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
+        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
+
+        return arranged
+
+    if block_size_m is None:
+        block_size_m = BLOCK_SIZE_M
+
+    if block_size_n is None:
+        block_size_n = BLOCK_SIZE_N
+
+    query_arranged = arrange_query_or_output(query)
+    key_arranged = arrange_key_or_value(key)
+    value_arranged = arrange_key_or_value(value)
+    total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape)
+    present_key_arranged = arrange_present_key_or_present_value(present_key)
+    present_value_arranged = arrange_present_key_or_present_value(present_value)
+    present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
+    present_value_slot_arranged = arrange_present_key_or_present_value(
+        present_value_slot
+    )
+    attn_mask_arranged = arrange_attn_mask(attn_mask)
+    is_causal_arranged = is_causal
+    scale_arranged = scale
+    output_arranged = arrange_query_or_output(output)
+    with_attn_mask_arranged = with_attn_mask
+    causal_variant_arranged = causal_variant
+
+    if with_kv_cache:
+        return (
+            query_arranged,
+            key_arranged,
+            value_arranged,
+            total_kv_len_arranged,
+            present_key_arranged,
+            present_value_arranged,
+            present_key_slot_arranged,
+            present_value_slot_arranged,
+            attn_mask_arranged,
+            is_causal_arranged,
+            scale_arranged,
+            output_arranged,
+            with_attn_mask_arranged,
+            causal_variant_arranged,
+        )
+
+    return (
+        query_arranged,
+        key_arranged,
+        value_arranged,
+        total_kv_len_arranged,
+        attn_mask_arranged,
+        is_causal_arranged,
+        scale_arranged,
+        output_arranged,
+        with_attn_mask_arranged,
+        causal_variant_arranged,
+    )
+
+
+def application_with_kv_cache(
+    query,
+    key,
+    value,
+    total_kv_len,
+    present_key,
+    present_value,
+    present_key_slot,
+    present_value_slot,
+    attn_mask,
+    is_causal,
+    scale,
+    output,
+    with_attn_mask,
+    causal_variant,
+):
+    present_key_slot = present_key  # noqa: F841
+    present_value_slot = present_value  # noqa: F841
+
+    application_without_kv_cache(
+        query,
+        key,
+        value,
+        total_kv_len,
+        attn_mask,
+        is_causal,
+        scale,
+        output,
+        with_attn_mask,
+        causal_variant,
+    )
+
+
+def application_without_kv_cache(
+    query,
+    key,
+    value,
+    total_kv_len,
+    attn_mask,
+    is_causal,
+    scale,
+    output,
+    with_attn_mask,
+    causal_variant,
+):
+    actual_kv_len = total_kv_len[0]
+
+    for i in range(query.shape[0]):
+        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
+
+        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
+        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
+        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
+
+        for j in range(min(key.shape[0], actual_kv_len)):
+            qk = ntl.dot(query_i, ntl.trans(key[j]))
+
+            key_pos = key[j].offsets(-2)
+            qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf"))
+
+            if with_attn_mask:
+                qk += attn_mask[j]
+
+            if is_causal:
+                query_pos = query[i].offsets(-2)
+
+                if causal_variant == 2:
+                    mask = (
+                        query_pos[:, None] + actual_kv_len - query.source.shape[-2]
+                        >= key_pos[None, :]
+                    )
+                else:
+                    mask = query_pos[:, None] >= key_pos[None, :]
+
+                qk = ntl.where(mask, qk, float("-inf"))
+
+            next_max = ntl.maximum(max, ntl.max(qk, 1))
+            stable_qk = ntl.exp2(qk - next_max[:, None])
+
+            alpha = ntl.exp2(max - next_max)
+            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
+            max = next_max
+            lse = lse * alpha + ntl.sum(stable_qk, 1)
+
+        acc /= lse[:, None]
+        output[i] = acc  # noqa: F841
+
+
+def premake(
+    with_kv_cache,
+    emb_dim=None,
+    is_causal=None,
+    with_attn_mask=None,
+    causal_variant=None,
+    dtype=None,
+    block_size_m=None,
+    block_size_n=None,
+):
+    arrangement_ = functools.partial(
+        arrangement,
+        with_kv_cache=with_kv_cache,
+        block_size_m=block_size_m,
+        block_size_n=block_size_n,
+    )
+
+    query, key, value, attn_mask, output = (
+        Tensor(
+            4,
+            dtype=dtype,
+            shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}),
+        )
+        for _ in range(5)
+    )
+    total_kv_len = Tensor(1, dtype=ninetoothed.int32)
+    present_key, present_value, present_key_slot, present_value_slot = (
+        Tensor(4, dtype=dtype) for _ in range(4)
+    )
+    scale = Tensor(0, dtype=ninetoothed.float64)
+    is_causal = Tensor(0, constexpr=True, value=is_causal)
+    with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask)
+    causal_variant = Tensor(0, constexpr=True, value=causal_variant)
+
+    if emb_dim is not None:
+        for tensor in (query, key, value, attn_mask, output):
+            tensor.shape = tensor.shape[:-1] + (emb_dim,)
+
+    if with_kv_cache:
+        application = application_with_kv_cache
+    else:
+        application = application_without_kv_cache
+
+    tensors = (
+        query,
+        key,
+        value,
+        total_kv_len,
+        present_key,
+        present_value,
+        present_key_slot,
+        present_value_slot,
+        attn_mask,
+        is_causal,
+        scale,
+        output,
+        with_attn_mask,
+        causal_variant,
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc
index 6ce530fd3..ddccf9836 100644
--- a/src/infiniop/ops/flash_attention/operator.cc
+++ b/src/infiniop/ops/flash_attention/operator.cc
@@ -5,11 +5,9 @@
 #ifdef ENABLE_CPU_API
 // #include "cpu/flash_attention_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
-#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+#if defined(ENABLE_NINETOOTHED)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API)
 #include "ninetoothed/descriptor.h"
-#else
-// #include "nvidia/flash_attention_nvidia.cuh"
 #endif
 #endif
 
@@ -20,7 +18,7 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
     infiniopTensorDescriptor_t q_desc,
     infiniopTensorDescriptor_t k_desc,
     infiniopTensorDescriptor_t v_desc,
-    std::size_t total_kv_len,
+    infiniopTensorDescriptor_t total_kv_len,
     float scale,
     char is_causal) {
 
@@ -39,14 +37,9 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
 
     switch (handle->device) {
 
-#ifdef ENABLE_CPU_API
-        // CREATE(INFINI_DEVICE_CPU, cpu);
-#endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
-#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+#if defined(ENABLE_NINETOOTHED)
+#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
-#else
-        // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #endif
    default:
@@ -66,14 +59,10 @@ __C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
         return INFINI_STATUS_SUCCESS;
 
     switch (desc->device_type) {
-#ifdef ENABLE_CPU_API
-        // GET_SIZE(INFINI_DEVICE_CPU, cpu);
-#endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
-#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+
+#if defined(ENABLE_NINETOOTHED)
+#if defined(ENABLE_NVIDIA_API)
        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
-#else
-        // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #endif
    default:
@@ -91,23 +80,19 @@ __C infiniStatus_t infiniopFlashAttention(
     const void *q,
     const void *k,
     const void *v,
+    const void *total_kv_len,
     void *stream) {
 
 #define CALCULATE(CASE, NAMESPACE) \
     case CASE: \
         return reinterpret_cast(desc) \
-            ->calculate(workspace, workspace_size, out, q, k, v, stream);
+            ->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream);
 
     switch (desc->device_type) {
-#ifdef ENABLE_CPU_API
-        // CALCULATE(INFINI_DEVICE_CPU, cpu);
-#endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
-#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+#if defined(ENABLE_NINETOOTHED)
+#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
-#else
-        // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #endif
    default:
@@ -126,14 +111,10 @@ __C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
         return INFINI_STATUS_SUCCESS;
 
     switch (desc->device_type) {
-#ifdef ENABLE_CPU_API
-        // DESTROY(INFINI_DEVICE_CPU, cpu);
-#endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
-#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API)
+
+#if defined(ENABLE_NINETOOTHED)
+#if defined(ENABLE_NVIDIA_API)
        DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed);
-#else
-        // DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #endif
    default:
diff --git a/test/infinicore/ops/flash_attention.py b/test/infinicore/ops/flash_attention.py
new file mode 100644
index 000000000..320a7f13d
--- /dev/null
+++ b/test/infinicore/ops/flash_attention.py
@@ -0,0 +1,113 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework import (
+    BaseOperatorTest,
+    TensorSpec,
+    TensorInitializer,
+    TestCase,
+    GenericTestRunner,
+)
+
+# Test cases format: (q_shape, k_shape, v_shape, attn_mask_or_None, dropout_p, is_causal)
+# q/k/v have shape (batch, num_heads, seq_len, head_dim); the kv length is sliced along dim 2.
+
+_TEST_CASES_DATA = [
+    ((1, 1, 2, 16), (1, 1, 8, 16), (1, 1, 8, 16), None, 0.0, False),
+    ((1, 2, 8, 16), (1, 2, 16, 16), (1, 2, 16, 16), None, 0.0, False),
+    ((1, 1, 4, 32), (1, 1, 32, 32), (1, 1, 32, 32), None, 0.0, False),
+    ((1, 8, 4, 16), (1, 8, 64, 16), (1, 8, 64, 16), None, 0.0, True),
+    ((1, 8, 4, 16), (1, 8, 64, 16), (1, 8, 64, 16), None, 0.0, False),
+]
+
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.float32: {"atol": 1e-3, "rtol": 1e-3},
+}
+_TENSOR_DTYPES = [infinicore.float16, infinicore.float32]
+
+
+def parse_test_cases():
+    import random
+
+    cases = []
+    for q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal in _TEST_CASES_DATA:
+        for dtype in _TENSOR_DTYPES:
+            tol = _TOLERANCE_MAP[dtype]
+            q_spec = TensorSpec.from_tensor(q_shape, None, dtype)
+            k_spec = TensorSpec.from_tensor(k_shape, None, dtype)
+            v_spec = TensorSpec.from_tensor(v_shape, None, dtype)
+
+            len_shape = (q_shape[0],)
+            total_len = random.randint(1, k_shape[2])
+            total_kv_len_spec = TensorSpec.from_tensor(
+                len_shape,
+                None,
+                infinicore.int64,
+                init_mode=TensorInitializer.RANDINT,
+                low=total_len,
+                high=total_len + 1,
+            )
+
+            kwargs = {
+                "attn_mask": attn_mask,
+                "dropout_p": dropout_p,
+                "is_causal": is_causal,
+            }
+            # remove None keys
+            kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+            cases.append(
+                TestCase(
+                    inputs=[q_spec, k_spec, v_spec, total_kv_len_spec, total_len],
+                    kwargs=kwargs,
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tol,
+                    description="Flash Attention",
+                )
+            )
+
+    return cases
+
+
+def torch_flash_attn(q, k, v, total_kv_len, cheat, **kwargs):
+    k_slice = k[:, :, :cheat, :]
+    v_slice = v[:, :, :cheat, :]
+    return torch.nn.functional.scaled_dot_product_attention(
+        q, k_slice, v_slice, **kwargs
+    )
+
+
+def infini_flash_attn(q, k, v, total_kv_len, cheat, **kwargs):
+    return infinicore.nn.functional.flash_attention(q, k, v, total_kv_len, **kwargs)
+
+
+class OpTest(BaseOperatorTest):
+    """FlashAttention operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("FlashAttention")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, *args, **kwargs):
+        return torch_flash_attn(*args, **kwargs)
+
+    def infinicore_operator(self, *args, **kwargs):
+        return infini_flash_attn(*args, **kwargs)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
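
The reference path in the new test (torch_flash_attn above) pins down what the tensor-valued total_kv_len argument means: only the first total_kv_len positions along the key/value sequence dimension take part in attention, and anything beyond that is ignored. A minimal PyTorch sketch of that semantics, assuming the (batch, num_heads, seq_len, head_dim) layout used by the test and a per-batch length tensor; the helper name flash_attention_reference is illustrative and not part of this patch:

    # Hypothetical reference helper (illustrative, not part of this patch).
    # Mirrors torch_flash_attn in the test: attend only to the first
    # total_kv_len[b] key/value positions of batch b.
    import math

    import torch


    def flash_attention_reference(q, k, v, total_kv_len, is_causal=False, scale=None):
        # q, k, v: (batch, num_heads, seq_len, head_dim); total_kv_len: (batch,) integer tensor.
        if scale is None:
            scale = 1 / math.sqrt(q.shape[-1])
        outputs = []
        for b in range(q.shape[0]):
            kv_len = int(total_kv_len[b])
            outputs.append(
                torch.nn.functional.scaled_dot_product_attention(
                    q[b : b + 1],
                    k[b : b + 1, :, :kv_len, :],
                    v[b : b + 1, :, :kv_len, :],
                    is_causal=is_causal,
                    scale=scale,
                )
            )
        return torch.cat(outputs, dim=0)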
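
For readers unfamiliar with the streaming-softmax update in application_without_kv_cache: each key/value block only rescales the running accumulator instead of materialising the full score matrix, and the kernel folds log2(e) into the scale so the inner loop can use exp2. A self-contained sketch of the same update in plain PyTorch, for a single (seq_q, head_dim) query block and pre-split key/value blocks; this is an assumed illustration, not the kernel itself:

    # Illustrative sketch (assumed helper, not the ninetoothed kernel).
    import torch


    def streaming_softmax_attention(q, k_blocks, v_blocks, scale):
        # Fold log2(e) into the scale so the loop can use exp2, as the kernel does.
        q = q * (scale * 1.4426950408889634)
        acc = torch.zeros(q.shape[0], v_blocks[0].shape[-1])   # running numerator
        lse = torch.zeros(q.shape[0])                          # running denominator
        running_max = torch.full((q.shape[0],), float("-inf"))
        for k_blk, v_blk in zip(k_blocks, v_blocks):
            qk = q @ k_blk.T                                   # (seq_q, block_n) scores
            next_max = torch.maximum(running_max, qk.max(dim=1).values)
            stable_qk = torch.exp2(qk - next_max[:, None])     # numerically stable scores
            alpha = torch.exp2(running_max - next_max)         # rescales previous blocks
            acc = acc * alpha[:, None] + stable_qk @ v_blk
            lse = lse * alpha + stable_qk.sum(dim=1)
            running_max = next_max
        return acc / lse[:, None]


    # Example: blocks of 16 along the key/value sequence dimension.
    # out = streaming_softmax_attention(q, k.split(16), v.split(16), scale=q.shape[-1] ** -0.5)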