Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,11 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack = 16 * packed_size; // fp4:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
int KPack =
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);

ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));

Expand Down
2 changes: 1 addition & 1 deletion include/ck/utility/amd_xdlops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 ; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
0,
Expand Down
1 change: 1 addition & 0 deletions include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/pk_fp6.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
Expand Down
11 changes: 10 additions & 1 deletion include/ck_tile/core/arch/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1417,7 +1417,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");

using rtn_type = thread_buffer<int8_t, N>;
Expand Down Expand Up @@ -1457,6 +1457,15 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,

return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 12)
{
auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));

return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
Expand Down
54 changes: 48 additions & 6 deletions include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,23 @@ llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");

// Raw declaration bound, through the asm label, to the symbol
// "llvm.amdgcn.raw.buffer.store.v3i32". It takes the payload as an int32x3_t;
// the trailing underscore distinguishes it from the dwordx3_union-taking
// wrapper llvm_amdgcn_raw_buffer_store_i32x3, which is the entry point used by
// the 12-byte store path.
// NOTE(review): parameter meaning is inferred from the names and from how the
// 12-byte store path fills them (buffer resource, per-thread offset, per-wave
// offset, coherence bits) -- confirm against the LLVM AMDGPU intrinsic docs.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x3_(int32x3_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v3i32");

// Adapter for 3-dword buffer stores: copies the three 32-bit lanes of a
// dwordx3_union into an int32x3_t and forwards every other argument unchanged
// to llvm_amdgcn_raw_buffer_store_i32x3_.
//
// vdata   - payload, read through its as_i32[0..2] view
// rsrc    - forwarded as-is
// voffset - forwarded as-is
// soffset - forwarded as-is
// glc_slc - forwarded as-is
//
// NOTE(review): this is a function *definition* marked CK_TILE_DEVICE_EXTERN,
// whereas other definitions in this header (e.g. amd_buffer_load_impl) use
// CK_TILE_DEVICE. If CK_TILE_DEVICE_EXTERN carries no `inline`, including the
// header from several translation units may produce duplicate definitions --
// confirm against the macro definitions.
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x3(
    dwordx3_union vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
{
    int32x3_t payload;
    for(int lane = 0; lane < 3; ++lane)
    {
        payload[lane] = vdata.as_i32[lane];
    }
    llvm_amdgcn_raw_buffer_store_i32x3_(payload, rsrc, voffset, soffset, glc_slc);
}

CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
Expand Down Expand Up @@ -1285,7 +1302,7 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
// N is the element count of the returned thread_buffer<int8_t, N>; only the
// widths listed here have a matching load branch below. All alternatives are
// joined with logical `||` (the previous `N == 12 | N == 16` used a bitwise OR
// by mistake -- same truth value, but inconsistent and warning-prone).
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
              "wrong! not implemented");

using rtn_type = thread_buffer<int8_t, N>;
Expand Down Expand Up @@ -1325,6 +1342,18 @@ amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource,

return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 12)
{
auto tmp = llvm_amdgcn_raw_buffer_load_i32x3(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
dwordx3_union ret;
ret.as_i32[0] = tmp[0];
ret.as_i32[1] = tmp[1];
ret.as_i32[2] = tmp[2];
return bit_cast<rtn_type>(ret);
}
else if constexpr(N == 16)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
Expand Down Expand Up @@ -1406,15 +1435,17 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
(std::is_same<T, e8m0_bexp_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_fp4_raw_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))) ||
(std::is_same<T, pk_fp6x16_t>::value && (N == 1)),
"wrong! not implemented");

using rtn_type = thread_buffer<T, N>;
Expand Down Expand Up @@ -1745,7 +1776,7 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");

if constexpr(N == 1)
Expand Down Expand Up @@ -1781,6 +1812,14 @@ CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 12)
{
llvm_amdgcn_raw_buffer_store_i32x3(bit_cast<dwordx3_union>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 16)
{
llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
Expand Down Expand Up @@ -1854,10 +1893,13 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 12 || N == 16)) ||
(std::is_same<T, uint16_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(std::is_same<T, uint8_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
std::is_same<T, pk_fp6x16_t>::value && (N == 1),
"wrong! not implemented");

if constexpr(std::is_same<T, float>::value) // fp32
Expand Down
108 changes: 108 additions & 0 deletions include/ck_tile/core/numeric/pk_fp6.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cmath>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/mxfp_convert.hpp"

namespace ck_tile {
/// Packed container holding `pk_size` 6-bit codes back to back in an array of
/// 32-bit words.
///
/// Element `i` occupies bits [6*i, 6*i + 6) of the concatenated word array
/// (least-significant bit first), so an element may straddle two neighbouring
/// words. `pk_size * 6` must be a multiple of 32 so that the elements fill the
/// storage exactly.
///
/// None of the accessors range-check the element index; callers must pass
/// `0 <= i < packed_size`.
template <index_t pk_size>
struct pk_fp6_t
{
    static constexpr index_t num_bits_elem = 6;  // width of one packed element
    using element_type                     = uint32_t; // element storage fundamental type
    static constexpr index_t packed_size   = pk_size;  // number of 6-bit elements
    static constexpr index_t num_bits_vec_elem =
        sizeof(element_type) * 8; // 32-bit uint for storage
    static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0,
                  "Packed elements must fit exactly into the element storage.");
    static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem;

    element_type data_[vector_size]; // packed data
    using type = pk_fp6_t<packed_size>;

    /// Leaves the storage uninitialized (the body is intentionally empty).
    /// Use `pk_fp6_t(0)` when a known-zero starting value is required.
    CK_TILE_HOST_DEVICE pk_fp6_t() {}

    /// Fills every 32-bit storage word (not every 6-bit element) with `value`.
    CK_TILE_HOST_DEVICE explicit pk_fp6_t(int value)
    {
        // index_t keeps the comparison with vector_size same-signed.
        for(index_t i = 0; i < vector_size; ++i)
        {
            data_[i] = static_cast<element_type>(value);
        }
    }

    /// Writes the low 6 bits of `x` into element slot `i`.
    ///
    /// The slot is cleared before the new bits are inserted, so the result does
    /// not depend on what the slot held before (the earlier OR-only version
    /// produced `old | new` when a slot was written twice or when the object
    /// had not been zeroed first).
    CK_TILE_HOST_DEVICE void pack(const uint32_t x, const index_t i)
    {
        constexpr uint32_t elem_mask = (1u << num_bits_elem) - 1u; // 0x3F

        const uint32_t bits  = x & elem_mask;
        const int bit_pos    = i * num_bits_elem;
        const int arr_index  = bit_pos / num_bits_vec_elem;
        const int bit_offset = bit_pos % num_bits_vec_elem;
        // number of element bits that spill into the next storage word
        const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;

        // replace the part of the element that lives in the current word
        data_[arr_index] =
            (data_[arr_index] & ~(elem_mask << bit_offset)) | (bits << bit_offset);

        // if it crosses into the next block, replace the remaining high bits there
        if(overhang > 0 && (arr_index + 1) < vector_size)
        {
            const int low_bits = num_bits_elem - overhang; // bits already stored above
            data_[arr_index + 1] =
                (data_[arr_index + 1] & ~(elem_mask >> low_bits)) | (bits >> low_bits);
        }
    }

    /// Reads element slot `i` out of `pk` (any object exposing a compatible
    /// `data_` array) and returns its 6-bit code in the low bits of the result.
    template <typename T>
    CK_TILE_HOST_DEVICE static uint32_t unpack(const T& pk, const index_t i)
    {
        const int bit_pos    = i * num_bits_elem;
        const int arr_idx    = bit_pos / num_bits_vec_elem;
        const int bit_offset = bit_pos % num_bits_vec_elem;
        const int overhang   = bit_offset + num_bits_elem - num_bits_vec_elem;

        uint32_t bits = pk.data_[arr_idx] >> bit_offset;
        if(overhang > 0 && (arr_idx + 1) < vector_size)
        {
            // splice in the high bits stored at the bottom of the next word
            bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang);
        }

        return bits & 0x3F;
    }

    /// Reads element slot `i` of this object.
    CK_TILE_HOST_DEVICE uint32_t unpack(const index_t i) const { return unpack(*this, i); }

    /// Decodes one 6-bit code as sign(1) | exponent(2) | mantissa(3):
    ///   exponent != 0 : (+/-) 2^(exponent - 1) * (1 + mantissa / 8)
    ///   exponent == 0 : (+/-) mantissa / 8
    /// Bits above the low six are ignored. The function reads no object state,
    /// so it is static; calling it through an instance still works.
    CK_TILE_HOST_DEVICE static float fp6_e2m3_to_float(uint32_t fp6_bits)
    {
        fp6_bits &= 0x3F;

        const uint32_t sign     = (fp6_bits >> 5) & 0x1; // bit 5
        const uint32_t exponent = (fp6_bits >> 3) & 0x3; // bits 4-3
        const uint32_t mantissa = fp6_bits & 0x7;        // bits 2-0

        float result;
        if(exponent != 0)
        {
            // 2^(exponent - 1) is 1, 2 or 4: an exact shift replaces std::pow
            const float scale = static_cast<float>(1u << (exponent - 1u));
            result            = scale * (1.0f + mantissa / 8.0f);
        }
        else
        {
            // zero and subnormals share one formula (mantissa == 0 yields 0)
            result = mantissa / 8.0f;
        }
        return sign == 1 ? -result : result;
    }
};

using pk_fp6x16_t = pk_fp6_t<16>;
using pk_fp6x32_t = pk_fp6_t<32>;
template <>
struct numeric_traits<pk_fp6x16_t>
{
static constexpr int PackedSize = 16;
};
} // namespace ck_tile
1 change: 1 addition & 0 deletions include/ck_tile/core/numeric/type_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
} // namespace ck_tile

#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/numeric/pk_fp6.hpp"

namespace ck_tile {

Expand Down
21 changes: 21 additions & 0 deletions include/ck_tile/core/numeric/vector_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};

// Plain aggregate of three `int` values. It serves as both `value_type` and
// `type` of the ext_vector<int8_t, 12> specialization.
struct int32x3_t
{
int data[3];
};

template <>
struct ext_vector<int8_t, 12>
{
static constexpr index_t N = 12;
using value_type = int32x3_t;
using type = int32x3_t;
};

template <index_t N_>
struct ext_vector<pk_fp6x16_t, N_>
{
static constexpr index_t N = N_;
using value_type = pk_fp6x16_t;
using type = pk_fp6x16_t; // this is danguous
};

template <typename T_, index_t N_>
struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
{
Expand Down
12 changes: 12 additions & 0 deletions include/ck_tile/core/tensor/buffer_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ struct buffer_view<address_space_enum::lds,
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// int8 on thread buffer
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 12>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
Expand Down Expand Up @@ -1033,6 +1034,11 @@ struct buffer_view<address_space_enum::lds,
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 12>>)
{
*c_style_pointer_cast<dwordx3_union*>(&p_data_[i]) =
*c_style_pointer_cast<const dwordx3_union*>(&x);
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
Expand Down Expand Up @@ -1075,6 +1081,12 @@ struct buffer_view<address_space_enum::lds,
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
else
{
    // Unsupported T/X combination. The condition is written in terms of X
    // so that it is type-dependent: a literal `static_assert(false, ...)`
    // in a discarded `if constexpr` branch is rejected outright by
    // compilers that predate CWG2518, even when this branch is never
    // instantiated.
    static_assert(!std::is_same_v<X, X>,
                  "wrong! not implemented for this combination, please add "
                  "implementation");
}
}
}
else
Expand Down
Loading