6 changes: 3 additions & 3 deletions include/infinicore/ops/flash_attention.hpp
@@ -5,8 +5,8 @@

namespace infinicore::op {

- INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, Tensor, Tensor, Tensor, std::size_t, float, bool);
+ INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool);

- Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal);
- void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal);
+ Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
+ void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);
} // namespace infinicore::op
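For context, a minimal sketch (not part of this diff) of how the updated infinicore::op signature would be called. It assumes q, k, v, and total_kv_len are already-constructed device Tensors and that Shape supports vector-style access; only the two op:: calls correspond to the declarations above.

#include <cmath>

#include <infinicore/ops/flash_attention.hpp>

using namespace infinicore;

// Hypothetical caller: q/k/v hold query/key/value data and total_kv_len is
// now a Tensor describing the KV lengths rather than a single std::size_t.
Tensor run_flash_attention(const Tensor &q, const Tensor &k, const Tensor &v,
                           const Tensor &total_kv_len) {
    // Same default as the Python wrapper: scale = 1 / sqrt(head_dim).
    float scale = 1.0f / std::sqrt(static_cast<float>(q->shape().back()));

    // Allocating variant: creates the output tensor internally.
    Tensor out = op::flash_attention(q, k, v, total_kv_len, scale, /*is_causal=*/true);

    // In-place variant: writes into an existing output tensor (shown only for illustration).
    op::flash_attention_(out, q, k, v, total_kv_len, scale, /*is_causal=*/true);
    return out;
}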
3 changes: 2 additions & 1 deletion include/infiniop/ops/flash_attention.h
@@ -12,7 +12,7 @@ __C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
- std::size_t total_kv_len,
+ infiniopTensorDescriptor_t total_kv_len,
float scale,
char is_causal);

@@ -28,6 +28,7 @@ __C __export infiniStatus_t infiniopFlashAttention(
const void *q,
const void *k,
const void *v,
+ const void *total_kv_len,
void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
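Likewise, a rough sketch (not part of this diff) of the low-level C-API call sequence with the new parameter. The leading parameters of the create call, the descriptor typedef, and the status constant are assumed from the surrounding infiniop conventions and are not shown in the hunks above; total_kv_len is now described by a tensor descriptor at creation time and passed as a device pointer at execution time.

#include <infiniop/ops/flash_attention.h>

// Hypothetical glue function: every handle, descriptor, and buffer is assumed
// to have been created with the corresponding infiniop factories, which this
// diff does not show. Names marked "assumed" are not confirmed by the hunks above.
infiniStatus_t run_flash_attention(
    infiniopHandle_t handle,                       // assumed handle type
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t total_kv_len_desc,  // new: descriptor, not std::size_t
    void *out, const void *q, const void *k, const void *v,
    const void *total_kv_len,                      // new: device buffer of KV lengths
    void *workspace, size_t workspace_size,
    float scale, char is_causal, void *stream) {
    infiniopFlashAttentionDescriptor_t desc;       // assumed typedef name
    infiniStatus_t status = infiniopCreateFlashAttentionDescriptor(
        handle, &desc, out_desc, q_desc, k_desc, v_desc,
        total_kv_len_desc, scale, is_causal);
    if (status != INFINI_STATUS_SUCCESS) {         // assumed status constant
        return status;
    }
    status = infiniopFlashAttention(
        desc, workspace, workspace_size,
        out, q, k, v, total_kv_len, stream);
    infiniopDestroyFlashAttentionDescriptor(desc);
    return status;
}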
8 changes: 5 additions & 3 deletions python/infinicore/nn/functional/__init__.py
@@ -1,5 +1,6 @@
from .causal_softmax import causal_softmax
from .embedding import embedding
+ from .flash_attention import flash_attention
from .linear import linear
from .random_sample import random_sample
from .rms_norm import rms_norm
@@ -10,13 +11,14 @@

__all__ = [
"causal_softmax",
"embedding",
"flash_attention",
"linear",
"random_sample",
"rms_norm",
"rope",
"scaled_dot_product_attention",
"silu",
"swiglu",
"linear",
"embedding",
"rope",
"RopeAlgo",
]
34 changes: 34 additions & 0 deletions python/infinicore/nn/functional/flash_attention.py
@@ -0,0 +1,34 @@
import math

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def flash_attention(
query,
key,
value,
total_kv_len,
attn_mask=None,
dropout_p=0,
is_causal=False,
scale=None,
enable_gqa=False,
):
assert attn_mask is None and dropout_p == 0 and not enable_gqa

emb_dim = query.shape[-1]

if scale is None:
scale = 1 / math.sqrt(emb_dim)

return Tensor(
_infinicore.flash_attention(
query._underlying,
key._underlying,
value._underlying,
total_kv_len._underlying,
scale,
is_causal,
)
)
@@ -14,6 +14,8 @@ def scaled_dot_product_attention(
scale=None,
enable_gqa=False,
):
+ raise NotImplementedError("Scaled Dot Product Attention is not yet supported.")
+
assert attn_mask is None and dropout_p == 0 and not enable_gqa

emb_dim = query.shape[-1]
8 changes: 4 additions & 4 deletions src/infinicore/ops/flash_attention/flash_attention.cc
@@ -6,24 +6,24 @@ namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention);

- FlashAttention::FlashAttention(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+ FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
out, q, k, v, total_kv_len, scale, is_causal);
}

- void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+ void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal);
}

- Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+ Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
Shape shape = q->shape();
auto out = Tensor::empty(shape, q->dtype(), q->device());
flash_attention_(out, q, k, v, total_kv_len, scale, is_causal);
return out;
}

- void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+ void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal);
}
} // namespace infinicore::op
11 changes: 5 additions & 6 deletions src/infinicore/ops/flash_attention/flash_attention_infiniop.cc
@@ -11,18 +11,17 @@ INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100);

struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
- graph::GraphTensor workspace, out, q, k, v;
- std::size_t total_kv_len;
+ graph::GraphTensor workspace, out, q, k, v, total_kv_len;
float scale;
bool is_causal;
};

- void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) {
+ void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal);

INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, FlashAttention,
- seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len, scale, is_causal);
+ seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal);

INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor);

@@ -33,7 +32,7 @@ void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, f
graph::GraphTensor(q),
graph::GraphTensor(k),
graph::GraphTensor(v),
- total_kv_len, scale, is_causal};
+ graph::GraphTensor(total_kv_len), scale, is_causal};

return planned;
}
@@ -43,7 +42,7 @@ void run(void *planned_meta) {

INFINICORE_CHECK_ERROR(infiniopFlashAttention(
planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
- planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), context::getStream()));
+ planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream()));
}

void cleanup(void **planned_meta_ptr) {
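For orientation, a small sketch (illustrative only, not part of this diff) of the lifecycle these backend hooks implement; in practice they are driven by the graph-op dispatcher rather than called directly, and total_kv_len now travels through plan() as a GraphTensor so run() can forward its device pointer.

// Illustrative only: plan/run/cleanup are the file-local hooks shown above.
void example_lifecycle(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v,
                       const Tensor &total_kv_len, float scale, bool is_causal) {
    // plan(): caches the infiniop descriptor (keyed on all arguments, including
    // total_kv_len's descriptor) and captures the tensors as GraphTensors.
    void *meta = plan(out, q, k, v, total_kv_len, scale, is_causal);

    // run(): calls infiniopFlashAttention with the captured device pointers,
    // now including planned->total_kv_len->data().
    run(meta);

    // cleanup(): releases the PlannedMeta.
    cleanup(&meta);
}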
6 changes: 3 additions & 3 deletions src/infiniop/ops/flash_attention/ninetoothed/build.py
@@ -1,6 +1,6 @@
import ninetoothed
- from ntops.kernels import scaled_dot_product_attention
- from ntops.kernels.scaled_dot_product_attention import CausalVariant
+ from . import flash_attention
+ from .flash_attention import CausalVariant

import infiniop.ninetoothed.build

@@ -27,7 +27,7 @@ def build():
}

infiniop.ninetoothed.build.build(
- scaled_dot_product_attention.premake,
+ flash_attention.premake,
constexpr_param_grid,
caller="cuda",
op_name="flash_attention",
15 changes: 11 additions & 4 deletions src/infiniop/ops/flash_attention/ninetoothed/descriptor.h
@@ -17,7 +17,7 @@ class Descriptor final : public InfiniopDescriptor {
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
- std::size_t total_kv_len,
+ infiniopTensorDescriptor_t total_kv_len,
double scale,
char is_causal) : InfiniopDescriptor{handle->device, handle->device_id},
_query_shape{q_desc->shape()},
@@ -26,12 +26,12 @@
_key_strides{k_desc->strides()},
_value_shape{v_desc->shape()},
_value_strides{v_desc->strides()},
+ _total_kv_shape{total_kv_len->shape()},
+ _total_kv_strides{total_kv_len->strides()},
_output_strides{out_desc->strides()},
_dtype{q_desc->dtype()},
_scale{scale},
_is_causal{is_causal} {
- _key_shape[_key_shape.size() - 2] = total_kv_len;
- _value_shape[_key_shape.size() - 2] = total_kv_len;
}

~Descriptor() = default;
@@ -46,13 +46,15 @@
const void *q,
const void *k,
const void *v,
+ const void *total_kv_len,
void *stream) const {
uint64_t empty_shape[4];
int64_t empty_strides[4];

auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
+ auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};

NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
NineToothedTensor is_causal;
@@ -75,6 +77,7 @@
query,
key,
value,
+ total_kv_length,
attn_mask,
is_causal,
scale,
@@ -101,7 +104,7 @@
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
- std::size_t total_kv_len,
+ infiniopTensorDescriptor_t total_kv_len,
double scale,
char is_causal) {
*desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
@@ -126,6 +129,10 @@

std::vector<Stride> _value_strides;

+ std::vector<Size> _total_kv_shape;

+ std::vector<Stride> _total_kv_strides;

std::vector<Stride> _output_strides;

infiniDtype_t _dtype;