diff --git a/profiler/include/profiler/gpu_verification.hpp b/profiler/include/profiler/gpu_verification.hpp
new file mode 100644
index 0000000000..808dc58c2f
--- /dev/null
+++ b/profiler/include/profiler/gpu_verification.hpp
@@ -0,0 +1,313 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
+
+#include <hip/hip_runtime.h>
+
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/type_convert.hpp"
+#include "ck/utility/type.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/hip_check_error.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+namespace ck {
+namespace profiler {
+
+// Compute the relative tolerance for GPU verification.
+// Matches the logic of ck::utils::get_relative_threshold, but handles all types.
+template <typename ComputeDataType, typename OutDataType, typename AccDataType>
+inline float compute_relative_tolerance(const int number_of_accumulations = 1)
+{
+    using F16  = ck::half_t;
+    using BF16 = ck::bhalf_t;
+    using F32  = float;
+    using I8   = int8_t;
+    using I16  = int16_t;
+    using I32  = int32_t;
+
+    // Integer results must match exactly, so the tolerance is 0
+    if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I16> ||
+                 std::is_same_v<OutDataType, I32> || std::is_same_v<AccDataType, I32>)
+    {
+        return 0.0f;
+    }
+    // For types supported by get_relative_threshold, use it
+    else if constexpr((std::is_same_v<ComputeDataType, F16> ||
+                       std::is_same_v<ComputeDataType, BF16> ||
+                       std::is_same_v<ComputeDataType, F32>) &&
+                      (std::is_same_v<OutDataType, F16> || std::is_same_v<OutDataType, BF16> ||
+                       std::is_same_v<OutDataType, F32>) &&
+                      (std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> ||
+                       std::is_same_v<AccDataType, F32>))
+    {
+        return static_cast<float>(
+            ck::utils::get_relative_threshold<ComputeDataType, OutDataType, AccDataType>(
+                number_of_accumulations));
+    }
+    // For types not covered above (FP8, BF8, etc.), fall back to fixed tolerances
+    // chosen by output type
+    else
+    {
+        if constexpr(std::is_same_v<OutDataType, F32>)
+        {
+            return 1e-3f;
+        }
+        else if constexpr(std::is_same_v<OutDataType, F16>)
+        {
+            return 1e-1f;
+        }
+        else
+        {
+            // FP8/BF8 and other low-precision types: conservative tolerance
+            return 1e-1f;
+        }
+    }
+}
+
+// GPU verification kernel: compares the device result against the reference using
+// relative and absolute tolerances. Leaves *passed at 1 if all elements match
+// within tolerance, and sets it to 0 otherwise.
+template <typename T>
+__global__ void gpu_verify_kernel(const T* __restrict__ device_result,
+                                  const T* __restrict__ reference_result,
+                                  float rtol,
+                                  float atol,
+                                  long long size,
+                                  int* passed)
+{
+    // Grid-stride loop to handle any tensor size; index math is done in 64 bits
+    // so tensors with more than 2^31 elements are addressed correctly
+    long long idx    = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
+    long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
+
+    for(long long i = idx; i < size; i += stride)
+    {
+        // Convert to float for comparison
+        float dev_val = type_convert<float>(device_result[i]);
+        float ref_val = type_convert<float>(reference_result[i]);
+
+        // Compute absolute difference
+        float abs_diff = fabsf(dev_val - ref_val);
+
+        // Check tolerance (matches CPU check_err logic: err > atol + rtol * abs(ref))
+        if(abs_diff > atol + rtol * fabsf(ref_val))
+        {
+            atomicMin(passed, 0); // Mark as failed
+            return;               // This thread stops at its first failure
+        }
+    }
+}
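+
+// Usage sketch for the explicit-tolerance wrapper declared below (illustrative
+// only; the buffer names and the tolerance values here are hypothetical):
+//
+//   DeviceMem out_dev(bytes), out_ref(bytes);
+//   // ... run the kernel under test and the GPU reference ...
+//   bool ok = ck::profiler::gpu_verify<ck::half_t>(out_dev.GetDeviceBuffer(),
+//                                                  out_ref.GetDeviceBuffer(),
+//                                                  /*rtol=*/1e-3f,
+//                                                  /*atol=*/1e-5f,
+//                                                  out_elements);
+//
+// With rtol = 1e-3 and atol = 1e-5, a reference value of 2.0 tolerates
+// |dev - ref| up to 1e-5 + 1e-3 * 2.0 = 2.01e-3, mirroring ck::utils::check_err.
+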
+// Host-side wrapper for GPU verification with explicit tolerances.
+// Returns true if verification passed, false otherwise.
+template <typename T>
+bool gpu_verify(const void* device_result,
+                const void* reference_result,
+                float rtol,
+                float atol,
+                std::size_t size,
+                hipStream_t stream = nullptr)
+{
+    if(size == 0)
+    {
+        return true;
+    }
+
+    // Allocate the result flag on the device
+    int* passed_dev;
+    hip_check_error(hipMalloc(&passed_dev, sizeof(int)));
+
+    // Initialize to passed (1)
+    int passed_host = 1;
+    hip_check_error(hipMemcpy(passed_dev, &passed_host, sizeof(int), hipMemcpyHostToDevice));
+
+    // Launch the kernel. The grid is capped at 65535 blocks, a conservative limit
+    // every HIP device supports; the grid-stride loop handles any tensor size
+    // regardless of grid dimensions.
+    constexpr int block_size = 256;
+    const int grid_size =
+        static_cast<int>(std::min<std::size_t>(65535, (size + block_size - 1) / block_size));
+
+    gpu_verify_kernel<T>
+        <<<grid_size, block_size, 0, stream>>>(static_cast<const T*>(device_result),
+                                               static_cast<const T*>(reference_result),
+                                               rtol,
+                                               atol,
+                                               static_cast<long long>(size),
+                                               passed_dev);
+
+    hip_check_error(hipGetLastError());
+
+    // Synchronize the stream to ensure kernel completion before reading the result
+    hip_check_error(hipStreamSynchronize(stream));
+
+    // Get the result
+    hip_check_error(hipMemcpy(&passed_host, passed_dev, sizeof(int), hipMemcpyDeviceToHost));
+
+    // Free device memory
+    hip_check_error(hipFree(passed_dev));
+
+    return passed_host == 1;
+}
+
+// Forward declaration of gpu_reduce_max
+template <typename T>
+float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream = nullptr);
+
+// Host-side wrapper for GPU verification with automatic tolerance computation.
+// Computes the max reference value on the GPU, derives the tolerances, then verifies.
+// Returns true if verification passed, false otherwise.
+template <typename T, typename ComputeDataType, typename AccDataType>
+bool gpu_verify(const void* device_result,
+                const void* reference_result,
+                int number_of_accumulations,
+                std::size_t size,
+                hipStream_t stream = nullptr)
+{
+    // Compute the max absolute value on the GPU (only 4 bytes transferred!)
+    double max_abs_value =
+        static_cast<double>(gpu_reduce_max<T>(reference_result, size, stream));
+
+    // Compute the tolerances based on the data types and the accumulation count
+    float rtol = compute_relative_tolerance<ComputeDataType, T, AccDataType>(
+        number_of_accumulations);
+
+    float atol = 0.0f;
+    // Only compute an absolute tolerance for types supported by get_absolute_threshold
+    using F16  = ck::half_t;
+    using BF16 = ck::bhalf_t;
+    using F32  = float;
+
+    if constexpr((std::is_same_v<ComputeDataType, F16> ||
+                  std::is_same_v<ComputeDataType, BF16> ||
+                  std::is_same_v<ComputeDataType, F32>) &&
+                 (std::is_same_v<T, F16> || std::is_same_v<T, BF16> ||
+                  std::is_same_v<T, F32>) &&
+                 (std::is_same_v<AccDataType, F16> || std::is_same_v<AccDataType, BF16> ||
+                  std::is_same_v<AccDataType, F32>))
+    {
+        atol = static_cast<float>(
+            ck::utils::get_absolute_threshold<ComputeDataType, T, AccDataType>(
+                max_abs_value, number_of_accumulations));
+    }
+
+    // Call the explicit-tolerance version
+    return gpu_verify<T>(device_result, reference_result, rtol, atol, size, stream);
+}
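+
+// Usage sketch for the automatic-tolerance overload (illustrative; the conv
+// dimensions and the buffer names are hypothetical). For an F16 forward conv
+// whose F32 accumulation runs over C * Y * X elements per output point:
+//
+//   bool ok = ck::profiler::gpu_verify<ck::half_t, ck::half_t, float>(
+//       out_dev.GetDeviceBuffer(),
+//       out_ref.GetDeviceBuffer(),
+//       /*number_of_accumulations=*/C * Y * X,
+//       out_elements);
+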
+// Helper for atomic float max (compare-and-swap on the bit pattern)
+__device__ __forceinline__ float atomicMaxFloat(float* address, float val)
+{
+    int* address_as_int = reinterpret_cast<int*>(address);
+    int old             = *address_as_int;
+    int assumed;
+
+    do
+    {
+        assumed = old;
+        old     = atomicCAS(
+            address_as_int, assumed, __float_as_int(fmaxf(val, __int_as_float(assumed))));
+    } while(assumed != old);
+
+    return __int_as_float(old);
+}
+
+// GPU reduction kernel for computing max(abs(data)).
+// This is an internal kernel called only by the gpu_reduce_max() wrapper.
+//
+// Assumption: Block size is 256
+template <typename T>
+__global__ void
+gpu_reduce_max_kernel(const T* __restrict__ data, long long size, float* __restrict__ max_val)
+{
+    constexpr int block_size = 256;
+    __shared__ float shared_max[block_size];
+
+    long long idx    = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
+    long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
+
+    float local_max = 0.0f;
+
+    for(long long i = idx; i < size; i += stride)
+    {
+        float val = fabsf(type_convert<float>(data[i]));
+        local_max = fmaxf(local_max, val);
+    }
+
+    shared_max[threadIdx.x] = local_max;
+    __syncthreads();
+
+    // Block-level reduction in shared memory: 256 -> 128 -> 64
+    for(unsigned int s = block_size / 2; s > 32; s >>= 1)
+    {
+        if(threadIdx.x < s)
+        {
+            shared_max[threadIdx.x] = fmaxf(shared_max[threadIdx.x], shared_max[threadIdx.x + s]);
+        }
+        __syncthreads();
+    }
+
+    // Warp-level reduction: 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1.
+    // Threads 0..31 execute within a single wavefront, so no block-wide
+    // synchronization is needed here.
+    if(threadIdx.x < 32)
+    {
+        volatile float* smem = shared_max;
+        smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 32]);
+        smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 16]);
+        smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 8]);
+        smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 4]);
+        smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 2]);
+        smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + 1]);
+    }
+
+    // The two-phase reduction pattern minimizes atomic contention:
+    //   1. Each block reduces to shared memory (above).
+    //   2. A single thread per block updates the global max (below).
+    // This limits atomic operations to O(grid_size) rather than O(total_threads).
+    if(threadIdx.x == 0)
+    {
+        atomicMaxFloat(max_val, shared_max[0]);
+    }
+}
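+
+// Usage sketch (illustrative; the buffer name and the accumulation count k are
+// hypothetical): the scalar this kernel produces typically feeds the absolute-
+// tolerance formula on the host:
+//
+//   float max_abs = ck::profiler::gpu_reduce_max<float>(ref_dev.GetDeviceBuffer(), n);
+//   auto atol = ck::utils::get_absolute_threshold<float, float, float>(max_abs, k);
+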
+// Host-side wrapper for the GPU max reduction.
+// Computes max(abs(data)) over the buffer and returns it as float.
+// Only 4 bytes (the final max value) are transferred instead of the entire tensor.
+template <typename T>
+float gpu_reduce_max(const void* device_buffer, std::size_t size, hipStream_t stream)
+{
+    if(size == 0)
+    {
+        return 0.0f;
+    }
+
+    // Allocate device memory for the result
+    float* max_dev;
+    hip_check_error(hipMalloc(&max_dev, sizeof(float)));
+
+    // Initialize to zero
+    float init_val = 0.0f;
+    hip_check_error(hipMemcpy(max_dev, &init_val, sizeof(float), hipMemcpyHostToDevice));
+
+    // Launch the reduction kernel.
+    // Cap the grid at 1024 blocks to balance occupancy against atomic contention;
+    // for very large tensors (>256M elements) the grid-stride loop covers the rest.
+    constexpr int block_size = 256;
+    const int grid_size =
+        static_cast<int>(std::min<std::size_t>(1024, (size + block_size - 1) / block_size));
+
+    gpu_reduce_max_kernel<T><<<grid_size, block_size, 0, stream>>>(
+        static_cast<const T*>(device_buffer), static_cast<long long>(size), max_dev);
+
+    hip_check_error(hipGetLastError());
+
+    // Synchronize the stream so the result is valid on the host (a null stream
+    // synchronizes the default stream)
+    hip_check_error(hipStreamSynchronize(stream));
+
+    // Copy the result to the host (only 4 bytes!)
+    float max_host;
+    hip_check_error(hipMemcpy(&max_host, max_dev, sizeof(float), hipMemcpyDeviceToHost));
+
+    // Free device memory
+    hip_check_error(hipFree(max_dev));
+
+    return max_host;
+}
+
+} // namespace profiler
+} // namespace ck
diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
index 67d082d07b..d74cf57649 100644
--- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp
@@ -20,6 +20,7 @@
 #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
 #include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
 #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"
+#include "profiler/gpu_verification.hpp"
 
 namespace ck {
 namespace profiler {
@@ -89,14 +90,15 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
     out_device_buf.ToDevice(out.mData.data());
     wei_device_buf.ToDevice(wei.mData.data());
 
+    // Allocate the GPU reference buffer (used only if do_verification == 2)
+    DeviceMem gpu_ref_in_buf(
+        do_verification == 2 ? sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize() : 0);
+
     float max_accumulated_value = 0;
     if(do_verification == 2)
     {
-        // Use GPU reference for verification
-        std::cout << "Using GPU reference for verification" << std::endl;
-
-        // Allocate GPU reference output buffer
-        DeviceMem gpu_ref_in_buf(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize());
+        // Use GPU reference with GPU verification
+        std::cout << "Using GPU reference with GPU verification" << std::endl;
 
         // Call GPU reference with ConvParam directly
         ref::naive_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
             static_cast<InDataType*>(gpu_ref_in_buf.GetDeviceBuffer()),
             static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
             static_cast<const OutDataType*>(out_device_buf.GetDeviceBuffer()),
             conv_param,
             in_element_op,
             wei_element_op,
             out_element_op);
+
+        // Max of the reference, reduced on the GPU (feeds the CPU fallback tolerances)
+        max_accumulated_value = ck::profiler::gpu_reduce_max<InDataType>(
+            gpu_ref_in_buf.GetDeviceBuffer(), in_host.mDesc.GetElementSpaceSize());
     }
     else if(do_verification == 1)
     {
@@ -204,8 +206,96 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
             best_split_k = split_k_for_run;
         }
 
-    if(do_verification)
+    // Synchronize before verification to ensure the kernel has completed
+    if(do_verification > 0 && !time_kernel)
+    {
+        hip_check_error(hipStreamSynchronize(nullptr));
+    }
+
+    if(do_verification == 2)
+    {
+        // GPU verification path
+        using ComputeType_ =
+            std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType), OutDataType, WeiDataType>;
+        using ComputeType =
+            std::conditional_t<sizeof(ComputeType_) < sizeof(InDataType), ComputeType_, InDataType>;
+        using AccDataType =
+            std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
+
+        // Calculate the number of accumulations, accounting for split_k
+        const int num_accums = static_cast<int>(conv_param.K_ / split_k_for_run);
+
+        // Additional tolerance for split_k accumulation if needed
+        int total_accums = num_accums;
+        if(split_k_for_run > 1)
+        {
+            total_accums = std::max(num_accums, static_cast<int>(split_k_for_run));
+        }
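+
+        // Worked example (hypothetical sizes): K_ = 128 with split_k_for_run = 4
+        // gives num_accums = 128 / 4 = 32 and total_accums = max(32, 4) = 32; the
+        // looser of the per-split and cross-split tolerances ends up winning.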
+
+        // Perform GPU verification (the max value is computed internally on the GPU)
+        const std::size_t tensor_size = in_device.mDesc.GetElementSpaceSize();
+        bool gpu_passed = ck::profiler::gpu_verify<InDataType, ComputeType, AccDataType>(
+            in_device_buf.GetDeviceBuffer(),
+            gpu_ref_in_buf.GetDeviceBuffer(),
+            total_accums,
+            tensor_size);
+
+        if(!gpu_passed)
+        {
+            // GPU verification failed; fall back to the CPU for detailed diagnostics
+            std::cout << "GPU verification failed, running CPU verification for details..."
+                      << std::endl;
+
+            // Copy both buffers to the host
+            in_device_buf.FromDevice(in_device.mData.data());
+            gpu_ref_in_buf.FromDevice(in_host.mData.data());
+
+            // Recalculate the tolerances for CPU verification with the original logic
+            auto rtol =
+                ck::utils::get_relative_threshold<ComputeType, InDataType, AccDataType>(
+                    num_accums);
+            auto atol =
+                ck::utils::get_absolute_threshold<ComputeType, InDataType, AccDataType>(
+                    max_accumulated_value / split_k_for_run, num_accums);
+
+            if(split_k_for_run > 1)
+            {
+                auto rtol_split_k =
+                    ck::utils::get_relative_threshold<ComputeType, InDataType, AccDataType>(
+                        split_k_for_run);
+                auto atol_split_k =
+                    ck::utils::get_absolute_threshold<ComputeType, InDataType, AccDataType>(
+                        max_accumulated_value, split_k_for_run);
+                rtol = std::max(rtol, rtol_split_k);
+                atol = std::max(atol, atol_split_k);
+            }
+
+            // Run CPU verification for detailed error messages
+            ck::utils::check_err(in_device, in_host, "Error: Incorrect results!", rtol, atol);
+            pass = false;
+
+            std::cout << "Relative error threshold: " << rtol
+                      << " Absolute error threshold: " << atol << std::endl;
+
+            if(do_log)
+            {
+                LogRangeAsType<float>(std::cout << "output : ", out.mData, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "weight: ", wei.mData, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "in_host  : ", in_host.mData, ",")
+                    << std::endl;
+                LogRangeAsType<float>(std::cout << "in_device: ", in_device.mData, ",")
+                    << std::endl;
+            }
+        }
+    }
+    else if(do_verification == 1)
     {
+        // CPU verification path (original behavior)
         in_device_buf.FromDevice(in_device.mData.data());
 
         using ComputeType_ =
             std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType), OutDataType, WeiDataType>;
diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
--- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
@@ ... @@
+#include "profiler/gpu_verification.hpp"
@@ ... @@
+    // Allocate the GPU reference buffer (used only if do_verification == 2)
+    DeviceMem gpu_ref_wei_buf(
+        do_verification == 2
+            ? sizeof(WeiDataType) * weight_host_result.mDesc.GetElementSpaceSize()
+            : 0);
+
     if(do_verification == 2)
     {
+        // Use GPU reference with GPU verification
+        std::cout << "Using GPU reference with GPU verification" << std::endl;
+
         // Call GPU reference with ConvParam directly
         ref::naive_conv_bwd_weight<NDimSpatial, InDataType, WeiDataType, OutDataType>(
-            static_cast<const InDataType*>(in_ref_buf.GetDeviceBuffer()),
-            static_cast<WeiDataType*>(wei_ref_buf.GetDeviceBuffer()),
-            static_cast<const OutDataType*>(out_ref_buf.GetDeviceBuffer()),
+            static_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
+            static_cast<WeiDataType*>(gpu_ref_wei_buf.GetDeviceBuffer()),
+            static_cast<const OutDataType*>(out_device_buf.GetDeviceBuffer()),
             conv_param,
             in_element_op,
             wei_element_op,
             out_element_op);
-
-        // Copy result back to host
-        wei_ref_buf.FromDevice(weight_host_result.mData.data());
+
+        // Max of the reference, reduced on the GPU (feeds the CPU fallback tolerances)
+        max_accumulated_value = ck::profiler::gpu_reduce_max<WeiDataType>(
+            gpu_ref_wei_buf.GetDeviceBuffer(), weight_host_result.mDesc.GetElementSpaceSize());
     }
-
-    max_accumulated_value =
-        *std::max_element(weight_host_result.mData.begin(), weight_host_result.mData.end());
 
     using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
                                                                               InLayout,
                                                                               WeiLayout,
                                                                               OutLayout,
                                                                               InDataType,
                                                                               WeiDataType,
                                                                               OutDataType,
                                                                               InElementOp,
                                                                               WeiElementOp,
                                                                               OutElementOp>;
@@ ... @@
+            // Synchronize before verification to ensure the kernel has completed
+            if(do_verification > 0 && !time_kernel)
+            {
+                hip_check_error(hipStreamSynchronize(nullptr));
+            }
+
+            if(do_verification == 2)
+            {
+                // GPU verification path
+                using ComputeType = std::conditional_t<sizeof(InDataType) < sizeof(OutDataType),
+                                                       InDataType,
+                                                       OutDataType>;
+                using AccDataType =
+                    std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
+
+                // Calculate the number of accumulations, accounting for split_k
+                const int num_accums =
+                    static_cast<int>(output.GetElementSize() / conv_param.K_ / split_k_value);
+
+                // Additional tolerance for split_k accumulation if needed
+                int total_accums = num_accums;
+                if(split_k_value > 1)
+                {
+                    total_accums = std::max(num_accums, static_cast<int>(split_k_value));
+                }
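+
+                // Worked example (hypothetical sizes): an output of 2^20 elements
+                // with K_ = 64 and split_k_value = 4 gives
+                // num_accums = 2^20 / 64 / 4 = 4096, so total_accums = max(4096, 4) = 4096.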
+
+                // Perform GPU verification (the max value is computed internally on the GPU)
+                const std::size_t tensor_size =
+                    weight_device_result.mDesc.GetElementSpaceSize();
+                bool gpu_passed =
+                    ck::profiler::gpu_verify<WeiDataType, ComputeType, AccDataType>(
+                        wei_device_buf.GetDeviceBuffer(),
+                        gpu_ref_wei_buf.GetDeviceBuffer(),
+                        total_accums,
+                        tensor_size);
+
+                if(!gpu_passed)
+                {
+                    // GPU verification failed; fall back to the CPU for detailed diagnostics
+                    std::cout
+                        << "GPU verification failed, running CPU verification for details..."
+                        << std::endl;
+
+                    // Copy both buffers to the host
+                    wei_device_buf.FromDevice(weight_device_result.mData.data());
+                    gpu_ref_wei_buf.FromDevice(weight_host_result.mData.data());
+
+                    // Recalculate the tolerances for CPU verification with the original logic
+                    const index_t num_accums_full    = output.GetElementSize() / conv_param.K_;
+                    const index_t num_accums_split_k = split_k_value;
+                    auto rtol =
+                        ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
+                            num_accums_full / num_accums_split_k);
+                    auto atol =
+                        ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
+                            max_accumulated_value / num_accums_split_k,
+                            num_accums_full / num_accums_split_k);
+
+                    if(split_k_value > 1)
+                    {
+                        auto rtol_split_k =
+                            ck::utils::get_relative_threshold<ComputeType,
+                                                              WeiDataType,
+                                                              AccDataType>(num_accums_split_k);
+                        auto atol_split_k =
+                            ck::utils::get_absolute_threshold<ComputeType,
+                                                              WeiDataType,
+                                                              AccDataType>(
+                                max_accumulated_value, num_accums_split_k);
+                        rtol = std::max(rtol, rtol_split_k);
+                        atol = std::max(atol, atol_split_k);
+                    }
+
+                    // Run CPU verification for detailed error messages
+                    ck::utils::check_err(weight_device_result,
+                                         weight_host_result,
+                                         "Error: Incorrect results!",
+                                         rtol,
+                                         atol);
+                    all_pass = false;
+
+                    std::cout << "Relative error threshold: " << rtol
+                              << " Absolute error threshold: " << atol << std::endl;
+                    std::cout << "Fail info: splitK: " << split_k_value << " "
+                              << op_ptr->GetTypeString() << std::endl;
+
+                    if(do_log)
+                    {
+                        LogRangeAsType<float>(std::cout << "output : ", output.mData, ",")
+                            << std::endl;
+                        LogRangeAsType<float>(
+                            std::cout << "weight (device): ", weight_device_result.mData, ",")
+                            << std::endl;
+                        LogRangeAsType<float>(
+                            std::cout << "weight (host): ", weight_host_result.mData, ",")
+                            << std::endl;
+                        LogRangeAsType<float>(std::cout << "input: ", input.mData, ",")
+                            << std::endl;
+                    }
+                }
+            }
+            else if(do_verification == 1)
             {
+                // CPU verification path (original behavior)
                 wei_device_buf.FromDevice(weight_device_result.mData.data());
 
                 using ComputeType =
diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
index 586f9aa4ac..874d1e115c 100644
--- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
+++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp
@@ -23,6 +23,7 @@
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
 #include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
+#include "profiler/gpu_verification.hpp"
 
 namespace ck {
 namespace profiler {
@@ -113,14 +114,15 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
     in_device_buf.ToDevice(input.mData.data());
     wei_device_buf.ToDevice(weight.mData.data());
 
+    // Allocate the GPU reference buffer (used only if do_verification == 2)
+    DeviceMem gpu_ref_out_buf(
+        do_verification == 2 ? sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()
+                             : 0);
+
     // run reference op
     if(do_verification == 2)
     {
-        // Use GPU reference for verification
-        std::cout << "Using GPU reference for verification" << std::endl;
-
-        // Allocate GPU reference output buffer
-        DeviceMem gpu_ref_out_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
+        // Use GPU reference with GPU verification
+        std::cout << "Using GPU reference with GPU verification" << std::endl;
 
         // Call GPU reference with ConvParam directly
         ref::naive_conv_fwd<NDimSpatial, InDataType, WeiDataType, OutDataType>(
             static_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
             static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
             static_cast<OutDataType*>(gpu_ref_out_buf.GetDeviceBuffer()),
             conv_param,
             in_element_op,
             wei_element_op,
             out_element_op);
     }
@@ ... @@
+    // Synchronize before verification to ensure the kernel has completed
+    if(do_verification > 0 && !time_kernel)
+    {
+        hip_check_error(hipStreamSynchronize(nullptr));
+    }
+
+    if(do_verification == 2)
+    {
+        // GPU verification path
+        // Calculate the number of accumulations (C * filter spatial dimensions)
+        std::size_t filter_spatial_size = 1;
+        for(auto len : conv_param.filter_spatial_lengths_)
+        {
+            filter_spatial_size *= len;
+        }
+        const int num_accums = static_cast<int>(conv_param.C_ * filter_spatial_size);
+
+        // Perform GPU verification (the max value is computed internally on the GPU)
+        const std::size_t tensor_size = device_output.mDesc.GetElementSpaceSize();
+        bool gpu_passed = ck::profiler::gpu_verify<OutDataType, InDataType, float>(
+            out_device_buf.GetDeviceBuffer(),
+            gpu_ref_out_buf.GetDeviceBuffer(),
+            num_accums,
+            tensor_size);
+
+        if(!gpu_passed)
+        {
+            // GPU verification failed; fall back to the CPU for detailed diagnostics
+            std::cout << "GPU verification failed, running CPU verification for details..."
+                      << std::endl;
+
+            // Copy both buffers to the host
+            out_device_buf.FromDevice(device_output.mData.data());
+            gpu_ref_out_buf.FromDevice(host_output.mData.data());
+
+            // Run CPU verification for detailed error messages
+            ck::utils::check_err(device_output, host_output);
+            pass = false;
+
+            if(do_log)
+            {
+                LogRangeAsType<float>(std::cout << "input : ", input.mData, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "weight: ", weight.mData, ",") << std::endl;
+                LogRangeAsType<float>(std::cout << "host_output  : ", host_output.mData, ",")
+                    << std::endl;
+                LogRangeAsType<float>(std::cout << "device_output: ", device_output.mData, ",")
+                    << std::endl;
+            }
+        }
+    }
+    else if(do_verification == 1)
     {
+        // CPU verification path (original behavior)
         out_device_buf.FromDevice(device_output.mData.data());
 
         pass = pass & ck::utils::check_err(device_output, host_output);
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 7521aebc74..46bb606765 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -319,3 +319,4 @@ add_subdirectory(position_embedding)
 add_subdirectory(scatter_gather)
 add_subdirectory(gpu_reference)
 add_subdirectory(util)
+add_subdirectory(gpu_verification)
diff --git a/test/gpu_verification/CMakeLists.txt b/test/gpu_verification/CMakeLists.txt
new file mode 100644
index 0000000000..76c2bff8d4
--- /dev/null
+++ b/test/gpu_verification/CMakeLists.txt
@@ -0,0 +1,11 @@
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
+# GPU verification unit tests
+add_gtest_executable(test_gpu_verification test_gpu_verification.cpp)
+
+target_link_libraries(test_gpu_verification
+    PRIVATE
+        utility
+        device_other_operations
+)
diff --git a/test/gpu_verification/test_gpu_verification.cpp b/test/gpu_verification/test_gpu_verification.cpp
new file mode 100644
index 0000000000..977475f064
--- /dev/null
+++ b/test/gpu_verification/test_gpu_verification.cpp
@@ -0,0 +1,736 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
+#include "profiler/gpu_verification.hpp"
+
+using namespace ck::profiler;
+using ck::ref::SimpleDeviceMem;
+
+// Test fixture for GPU verification tests
+class GPUVerificationTest : public ::testing::Test
+{
+    protected:
+    // Random number generator, seeded once per test for reproducibility
+    std::mt19937 rng_;
+
+    void SetUp() override
+    {
+        // Ensure HIP is initialized
+        hipDeviceProp_t prop;
+        [[maybe_unused]] hipError_t err = hipGetDeviceProperties(&prop, 0);
+
+        // Fixed seed for reproducibility; can be overridden with the
+        // CK_TEST_SEED environment variable
+        unsigned int seed = 12345;
+        if(const char* env_seed = std::getenv("CK_TEST_SEED"))
+        {
+            seed = std::stoul(env_seed);
+        }
+        rng_.seed(seed);
+    }
+
+    void TearDown() override
+    {
+        // Cleanup handled automatically
+    }
+
+    // Upload host data into a fresh device buffer backed by SimpleDeviceMem
+    template <typename T>
+    std::unique_ptr<SimpleDeviceMem> CreateDeviceBuffer(const std::vector<T>& host_data)
+    {
+        auto device_buf = std::make_unique<SimpleDeviceMem>(host_data.size() * sizeof(T));
+        HIP_CHECK_ERROR(hipMemcpy(device_buf->GetDeviceBuffer(),
+                                  host_data.data(),
+                                  host_data.size() * sizeof(T),
+                                  hipMemcpyHostToDevice));
+        return device_buf;
+    }
+
+    // CPU reference for the GPU max(abs) reduction
+    template <typename T>
+    float ComputeCPUMaxAbs(const std::vector<T>& data)
+    {
+        if(data.empty())
+            return 0.0f;
+
+        float max_val = 0.0f;
+        for(const auto& val : data)
+        {
+            float abs_val = std::abs(ck::type_convert<float>(val));
+            max_val       = std::max(max_val, abs_val);
+        }
+        return max_val;
+    }
+
+    // Generate uniformly distributed random data
+    template <typename T>
+    std::vector<T> GenerateRandomData(size_t size, float min_val = -10.0f, float max_val = 10.0f)
+    {
+        std::vector<T> data(size);
+
+        // Uses the fixture's RNG (rng_), seeded in SetUp() with a fixed seed or
+        // the CK_TEST_SEED environment variable, so runs are reproducible
+        if constexpr(std::is_integral<T>::value)
+        {
+            std::uniform_int_distribution<int> dis(static_cast<int>(min_val),
+                                                   static_cast<int>(max_val));
+            for(auto& val : data)
+                val = static_cast<T>(dis(rng_));
+        }
+        else
+        {
+            std::uniform_real_distribution<float> dis(min_val, max_val);
+            for(auto& val : data)
+                val = ck::type_convert<T>(dis(rng_));
+        }
+        return data;
+    }
+};
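+
+// Reproducing a run locally with a specific seed (illustrative shell invocation;
+// the binary location depends on the build tree):
+//
+//   CK_TEST_SEED=42 ./bin/test_gpu_verification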
+
+// ============================================================================
+// Basic Functionality Tests
+// ============================================================================
+
+TEST_F(GPUVerificationTest, FP32_ExactMatch_ShouldPass)
+{
+    constexpr size_t size   = 1024;
+    std::vector<float> data = GenerateRandomData<float>(size);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    // Identical data should pass with zero tolerance
+    bool result = gpu_verify<float>(device_buf1->GetDeviceBuffer(),
+                                    device_buf2->GetDeviceBuffer(),
+                                    0.0f, // rtol
+                                    0.0f, // atol
+                                    size);
+
+    EXPECT_TRUE(result) << "Identical FP32 tensors should pass verification";
+}
+
+TEST_F(GPUVerificationTest, FP32_Different_ShouldFail)
+{
+    constexpr size_t size    = 1024;
+    std::vector<float> data1 = GenerateRandomData<float>(size);
+    std::vector<float> data2 = GenerateRandomData<float>(size);
+
+    auto device_buf1 = CreateDeviceBuffer(data1);
+    auto device_buf2 = CreateDeviceBuffer(data2);
+
+    // Different random data should fail with zero tolerance
+    bool result = gpu_verify<float>(device_buf1->GetDeviceBuffer(),
+                                    device_buf2->GetDeviceBuffer(),
+                                    0.0f, // rtol
+                                    0.0f, // atol
+                                    size);
+
+    EXPECT_FALSE(result) << "Different FP32 tensors should fail with zero tolerance";
+}
+
+TEST_F(GPUVerificationTest, FP32_WithinTolerance_ShouldPass)
+{
+    constexpr size_t size = 1024;
+    std::vector<float> data1(size, 1.0f);
+    std::vector<float> data2(size, 1.01f);
+
+    auto device_buf1 = CreateDeviceBuffer(data1);
+    auto device_buf2 = CreateDeviceBuffer(data2);
+
+    // A 1% relative difference should pass with a 2% tolerance
+    bool result = gpu_verify<float>(device_buf1->GetDeviceBuffer(),
+                                    device_buf2->GetDeviceBuffer(),
+                                    0.02f, // rtol
+                                    0.02f, // atol
+                                    size);
+
+    EXPECT_TRUE(result) << "Data within tolerance should pass";
+}
+
+TEST_F(GPUVerificationTest, FP32_OutsideTolerance_ShouldFail)
+{
+    constexpr size_t size = 1024;
+    std::vector<float> data1(size, 1.0f);
+    std::vector<float> data2(size, 1.1f);
+
+    auto device_buf1 = CreateDeviceBuffer(data1);
+    auto device_buf2 = CreateDeviceBuffer(data2);
+
+    // A 10% relative difference should fail with a 1% tolerance
+    bool result = gpu_verify<float>(device_buf1->GetDeviceBuffer(),
+                                    device_buf2->GetDeviceBuffer(),
+                                    0.01f, // rtol
+                                    0.01f, // atol
+                                    size);
+
+    EXPECT_FALSE(result) << "Data outside tolerance should fail";
+}
+
+// ============================================================================
+// Data Type Coverage Tests
+// ============================================================================
+
+TEST_F(GPUVerificationTest, FP16_ExactMatch_ShouldPass)
+{
+    constexpr size_t size        = 1024;
+    std::vector<ck::half_t> data = GenerateRandomData<ck::half_t>(size);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<ck::half_t>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Identical FP16 tensors should pass verification";
+}
+
+TEST_F(GPUVerificationTest, BF16_ExactMatch_ShouldPass)
+{
+    constexpr size_t size         = 1024;
+    std::vector<ck::bhalf_t> data = GenerateRandomData<ck::bhalf_t>(size);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<ck::bhalf_t>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Identical BF16 tensors should pass verification";
+}
+
+TEST_F(GPUVerificationTest, INT8_ExactMatch_ShouldPass)
+{
+    constexpr size_t size    = 1024;
+    std::vector<int8_t> data = GenerateRandomData<int8_t>(size, int8_t{-100}, int8_t{100});
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<int8_t>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Identical INT8 tensors should pass verification";
+}
+
+TEST_F(GPUVerificationTest, INT16_ExactMatch_ShouldPass)
+{
+    constexpr size_t size     = 1024;
+    std::vector<int16_t> data = GenerateRandomData<int16_t>(size, int16_t{-1000}, int16_t{1000});
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<int16_t>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Identical INT16 tensors should pass verification";
+}
+
+TEST_F(GPUVerificationTest, INT32_ExactMatch_ShouldPass)
+{
+    constexpr size_t size     = 1024;
+    std::vector<int32_t> data = GenerateRandomData<int32_t>(size, -10000, 10000);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<int32_t>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Identical INT32 tensors should pass verification";
+}
+
+// ============================================================================
+// Tolerance Validation Tests
+// ============================================================================
+
+TEST_F(GPUVerificationTest, RelativeTolerance_ScalesWithReferenceValue)
+{
+    constexpr size_t size = 100;
+    std::vector<float> reference(size);
+    std::vector<float> result(size);
+
+    // The relative tolerance must scale with the reference magnitude:
+    // for reference = 100 and result = 101, the relative error is 1%
+    for(size_t i = 0; i < size; ++i)
+    {
+        reference[i] = 100.0f;
+        result[i]    = 101.0f;
+    }
+
+    auto device_ref = CreateDeviceBuffer(reference);
+    auto device_res = CreateDeviceBuffer(result);
+
+    // Should pass with a 2% relative tolerance
+    bool pass = gpu_verify<float>(device_res->GetDeviceBuffer(),
+                                  device_ref->GetDeviceBuffer(),
+                                  0.02f, // rtol
+                                  0.0f,  // atol
+                                  size);
+
+    EXPECT_TRUE(pass) << "Should pass with sufficient relative tolerance";
+
+    // Should fail with a 0.5% relative tolerance
+    bool fail = gpu_verify<float>(device_res->GetDeviceBuffer(),
+                                  device_ref->GetDeviceBuffer(),
+                                  0.005f, // rtol
+                                  0.0f,   // atol
+                                  size);
+
+    EXPECT_FALSE(fail) << "Should fail with insufficient relative tolerance";
+}
+
+TEST_F(GPUVerificationTest, AbsoluteTolerance_CriticalForSmallValues)
+{
+    constexpr size_t size = 100;
+    std::vector<float> reference(size, 0.0f);
+    std::vector<float> result(size, 0.001f);
+
+    auto device_ref = CreateDeviceBuffer(reference);
+    auto device_res = CreateDeviceBuffer(result);
+
+    // For values near zero the relative tolerance does not help; the absolute
+    // tolerance has to cover the difference
+    bool pass = gpu_verify<float>(device_res->GetDeviceBuffer(),
+                                  device_ref->GetDeviceBuffer(),
+                                  0.0f,   // rtol
+                                  0.002f, // atol (larger than the difference)
+                                  size);
+
+    EXPECT_TRUE(pass) << "Should pass with sufficient absolute tolerance";
+
+    bool fail = gpu_verify<float>(device_res->GetDeviceBuffer(),
+                                  device_ref->GetDeviceBuffer(),
+                                  0.0f,    // rtol
+                                  0.0005f, // atol (smaller than the difference)
+                                  size);
+
+    EXPECT_FALSE(fail) << "Should fail with insufficient absolute tolerance";
+}
+
+TEST_F(GPUVerificationTest, AutomaticToleranceComputation_FP32)
+{
+    constexpr size_t size   = 1024;
+    std::vector<float> data = GenerateRandomData<float>(size);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    // Use automatic tolerance computation (the three-template-parameter overload)
+    bool result = gpu_verify<float, float, float>(device_buf1->GetDeviceBuffer(),
+                                                  device_buf2->GetDeviceBuffer(),
+                                                  1, // number_of_accumulations
+                                                  size);
+
+    EXPECT_TRUE(result) << "Identical data should pass with automatic tolerances";
+}
+
+TEST_F(GPUVerificationTest, AutomaticToleranceComputation_FP16)
+{
+    constexpr size_t size        = 1024;
+    std::vector<ck::half_t> data = GenerateRandomData<ck::half_t>(size);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<ck::half_t, ck::half_t, float>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 1, size);
+
+    EXPECT_TRUE(result) << "Identical FP16 data should pass with automatic tolerances";
+}
+
+TEST_F(GPUVerificationTest, ToleranceScalesWithAccumulations)
+{
+    // Document that the tolerance grows with the number of accumulations
+    constexpr size_t size = 100;
+    std::vector<float> reference(size, 1.0f);
+    std::vector<float> result(size);
+
+    // Create a result with a small accumulated error
+    for(size_t i = 0; i < size; ++i)
+    {
+        result[i] = 1.0f + 1e-6f; // Small error
+    }
+
+    auto device_ref = CreateDeviceBuffer(reference);
+    auto device_res = CreateDeviceBuffer(result);
+
+    // With more accumulations the tolerance is larger, so this should pass
+    bool result_many_accums = gpu_verify<float, float, float>(device_res->GetDeviceBuffer(),
+                                                              device_ref->GetDeviceBuffer(),
+                                                              1000, // Many accumulations
+                                                              size);
+
+    // With fewer accumulations the tolerance is tighter
+    bool result_few_accums = gpu_verify<float, float, float>(device_res->GetDeviceBuffer(),
+                                                             device_ref->GetDeviceBuffer(),
+                                                             1, // Few accumulations
+                                                             size);
+
+    // The exact outcome depends on the error magnitude and the tolerance formulas;
+    // this test documents the expected behavior
+    EXPECT_TRUE(result_many_accums || result_few_accums)
+        << "At least one configuration should pass for small errors";
+}
+
+// ============================================================================
+// Edge Case Tests
+// ============================================================================
+
+TEST_F(GPUVerificationTest, SingleElement_ExactMatch)
+{
+    constexpr size_t size = 1;
+    std::vector<float> data{42.0f};
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<float>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Single-element exact match should pass";
+}
+
+TEST_F(GPUVerificationTest, LargeTensor_Performance)
+{
+    constexpr size_t size = 10 * 1024 * 1024; // 10M elements
+    std::vector<float> data(size, 1.0f);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<float>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Large tensor verification should complete successfully";
+}
+
+TEST_F(GPUVerificationTest, VeryLargeValues_NearTypeLimit)
+{
+    constexpr size_t size = 100;
+    float large_val       = 1e36f; // Close to the FP32 limit without overflowing
+    std::vector<float> data(size, large_val);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<float>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Very large values should be handled correctly";
+}
+
+TEST_F(GPUVerificationTest, VerySmallValues_NearZero)
+{
+    constexpr size_t size = 100;
+    float small_val       = 1e-36f; // Very small but not denormal
+    std::vector<float> data(size, small_val);
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<float>(device_buf1->GetDeviceBuffer(),
+                                    device_buf2->GetDeviceBuffer(),
+                                    0.0f,
+                                    1e-38f, // Very small absolute tolerance
+                                    size);
+
+    EXPECT_TRUE(result) << "Very small values should be handled correctly";
+}
+
+TEST_F(GPUVerificationTest, MixedPositiveNegative_Values)
+{
+    constexpr size_t size = 100;
+    std::vector<float> data(size);
+    for(size_t i = 0; i < size; ++i)
+    {
+        data[i] = (i % 2 == 0) ? static_cast<float>(i) : -static_cast<float>(i);
+    }
+
+    auto device_buf1 = CreateDeviceBuffer(data);
+    auto device_buf2 = CreateDeviceBuffer(data);
+
+    bool result = gpu_verify<float>(
+        device_buf1->GetDeviceBuffer(), device_buf2->GetDeviceBuffer(), 0.0f, 0.0f, size);
+
+    EXPECT_TRUE(result) << "Mixed positive/negative values should work correctly";
+}
+
+// ============================================================================
+// GPU Max Reduction Tests
+// ============================================================================
+
+TEST_F(GPUVerificationTest, GPUReduceMax_FP32_Correctness)
+{
+    constexpr size_t size   = 1024;
+    std::vector<float> data = GenerateRandomData<float>(size);
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float cpu_max = ComputeCPUMaxAbs(data);
+    float gpu_max = gpu_reduce_max<float>(device_buf->GetDeviceBuffer(), size);
+
+    EXPECT_FLOAT_EQ(cpu_max, gpu_max) << "GPU max reduction should match CPU for FP32";
+}
+
+TEST_F(GPUVerificationTest, GPUReduceMax_FP16_Correctness)
+{
+    constexpr size_t size        = 1024;
+    std::vector<ck::half_t> data = GenerateRandomData<ck::half_t>(size);
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float cpu_max = ComputeCPUMaxAbs(data);
+    float gpu_max = gpu_reduce_max<ck::half_t>(device_buf->GetDeviceBuffer(), size);
+
+    // FP16 may show small differences between host and device conversion
+    EXPECT_NEAR(cpu_max, gpu_max, 1e-3f)
+        << "GPU max reduction should match CPU for FP16 within precision";
+}
+
+TEST_F(GPUVerificationTest, GPUReduceMax_BF16_Correctness)
+{
+    constexpr size_t size         = 1024;
+    std::vector<ck::bhalf_t> data = GenerateRandomData<ck::bhalf_t>(size);
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float cpu_max = ComputeCPUMaxAbs(data);
+    float gpu_max = gpu_reduce_max<ck::bhalf_t>(device_buf->GetDeviceBuffer(), size);
+
+    // BF16 has lower precision
+    EXPECT_NEAR(cpu_max, gpu_max, 1e-2f)
+        << "GPU max reduction should match CPU for BF16 within precision";
+}
+
+TEST_F(GPUVerificationTest, GPUReduceMax_INT8_Correctness)
+{
+    constexpr size_t size    = 1024;
+    std::vector<int8_t> data = GenerateRandomData<int8_t>(size, int8_t{-100}, int8_t{100});
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float cpu_max = ComputeCPUMaxAbs(data);
+    float gpu_max = gpu_reduce_max<int8_t>(device_buf->GetDeviceBuffer(), size);
+
+    EXPECT_FLOAT_EQ(cpu_max, gpu_max) << "GPU max reduction should match CPU for INT8";
+}
+
+TEST_F(GPUVerificationTest, GPUReduceMax_SingleElement)
+{
+    constexpr size_t size = 1;
+    std::vector<float> data{-42.5f};
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float gpu_max = gpu_reduce_max<float>(device_buf->GetDeviceBuffer(), size);
+
+    EXPECT_FLOAT_EQ(42.5f, gpu_max) << "Max of a single element should be its absolute value";
+}
+
+TEST_F(GPUVerificationTest, GPUReduceMax_LargeBuffer)
+{
+    constexpr size_t size   = 10 * 1024 * 1024; // 10M elements
+    std::vector<float> data = GenerateRandomData<float>(size);
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float cpu_max = ComputeCPUMaxAbs(data);
+    float gpu_max = gpu_reduce_max<float>(device_buf->GetDeviceBuffer(), size);
+
+    EXPECT_FLOAT_EQ(cpu_max, gpu_max) << "GPU max reduction should handle large buffers correctly";
+}
+
+TEST_F(GPUVerificationTest, GPUReduceMax_AllNegative)
+{
+    constexpr size_t size = 100;
+    std::vector<float> data(size);
+    for(size_t i = 0; i < size; ++i)
+    {
+        data[i] = -static_cast<float>(i + 1);
+    }
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float cpu_max = ComputeCPUMaxAbs(data);
+    float gpu_max = gpu_reduce_max<float>(device_buf->GetDeviceBuffer(), size);
+
+    EXPECT_FLOAT_EQ(cpu_max, gpu_max)
+        << "GPU max reduction should reduce absolute values of all-negative data";
+}
+
+TEST_F(GPUVerificationTest, GPUReduceMax_MixedPositiveNegative)
+{
+    constexpr size_t size = 100;
+    std::vector<float> data(size);
+    for(size_t i = 0; i < size; ++i)
+    {
+        data[i] = (i % 2 == 0) ? static_cast<float>(i) : -static_cast<float>(i);
+    }
+
+    auto device_buf = CreateDeviceBuffer(data);
+
+    float cpu_max = ComputeCPUMaxAbs(data);
+    float gpu_max = gpu_reduce_max<float>(device_buf->GetDeviceBuffer(), size);
+
+    EXPECT_FLOAT_EQ(cpu_max, gpu_max) << "GPU max reduction should handle mixed signs correctly";
+}
+
+// ============================================================================
+// Tolerance Computation Tests
+// ============================================================================
+
+TEST_F(GPUVerificationTest, ComputeRelativeTolerance_IntegerTypes_ReturnsZero)
+{
+    // Integer types must match exactly, so their relative tolerance is zero
+    float rtol_int8  = compute_relative_tolerance<int8_t, int8_t, int32_t>();
+    float rtol_int16 = compute_relative_tolerance<int16_t, int16_t, int32_t>();
+    float rtol_int32 = compute_relative_tolerance<int32_t, int32_t, int32_t>();
+
+    EXPECT_FLOAT_EQ(0.0f, rtol_int8) << "INT8 should have zero relative tolerance";
+    EXPECT_FLOAT_EQ(0.0f, rtol_int16) << "INT16 should have zero relative tolerance";
+    EXPECT_FLOAT_EQ(0.0f, rtol_int32) << "INT32 should have zero relative tolerance";
+}
+
+TEST_F(GPUVerificationTest, ComputeRelativeTolerance_FP32_NonZero)
+{
+    // FP32 should have a non-zero relative tolerance
+    float rtol = compute_relative_tolerance<float, float, float>();
+
+    EXPECT_GT(rtol, 0.0f) << "FP32 should have non-zero relative tolerance";
+    EXPECT_LT(rtol, 1.0f) << "FP32 tolerance should be reasonable (< 1.0)";
+}
+
+TEST_F(GPUVerificationTest, ComputeRelativeTolerance_FP16_NonZero)
+{
+    // FP16 should have a non-zero relative tolerance
+    float rtol = compute_relative_tolerance<ck::half_t, ck::half_t, float>();
+
+    EXPECT_GT(rtol, 0.0f) << "FP16 should have non-zero relative tolerance";
+    EXPECT_LT(rtol, 1.0f) << "FP16 tolerance should be reasonable (< 1.0)";
+}
+
+TEST_F(GPUVerificationTest, ComputeRelativeTolerance_BF16_NonZero)
+{
+    // BF16 should have a non-zero relative tolerance
+    float rtol = compute_relative_tolerance<ck::bhalf_t, ck::bhalf_t, float>();
+
+    EXPECT_GT(rtol, 0.0f) << "BF16 should have non-zero relative tolerance";
+    EXPECT_LT(rtol, 1.0f) << "BF16 tolerance should be reasonable (< 1.0)";
+}
+
+TEST_F(GPUVerificationTest, ComputeRelativeTolerance_ScalesWithAccumulations)
+{
+    // The tolerance should grow (or at least not shrink) with more accumulations
+    float rtol_1    = compute_relative_tolerance<ck::half_t, ck::half_t, float>(1);
+    float rtol_10   = compute_relative_tolerance<ck::half_t, ck::half_t, float>(10);
+    float rtol_100  = compute_relative_tolerance<ck::half_t, ck::half_t, float>(100);
+    float rtol_1000 = compute_relative_tolerance<ck::half_t, ck::half_t, float>(1000);
+
+    EXPECT_GE(rtol_10, rtol_1) << "10 accums should have >= tolerance than 1";
+    EXPECT_GE(rtol_100, rtol_10) << "100 accums should have >= tolerance than 10";
+    EXPECT_GE(rtol_1000, rtol_100) << "1000 accums should have >= tolerance than 100";
+}
+
+TEST_F(GPUVerificationTest, ComputeRelativeTolerance_MixedPrecision)
+{
+    // Mixed-precision scenarios common in ML: F16 compute with F32 output
+    float rtol_fp16_fp32 = compute_relative_tolerance<ck::half_t, float, float>();
+    float rtol_fp32_fp32 = compute_relative_tolerance<float, float, float>();
+
+    // Both should have a non-zero tolerance; whether mixed precision needs a larger
+    // tolerance than pure FP32 is implementation-dependent, so only positivity is
+    // asserted here
+    EXPECT_GT(rtol_fp16_fp32, 0.0f) << "Mixed precision should have non-zero tolerance";
+    EXPECT_GT(rtol_fp32_fp32, 0.0f);
+}
+
+// ============================================================================
+// Integration Tests (End-to-End)
+// ============================================================================
+
+TEST_F(GPUVerificationTest, EndToEnd_ConvolutionLikeWorkload_FP32)
+{
+    // Simulate verification of a convolution output
+    constexpr size_t size = 256 * 256; // Realistic output size
+    std::vector<float> kernel_output    = GenerateRandomData<float>(size);
+    std::vector<float> reference_output = kernel_output; // Start identical
+
+    // Inject small numerical errors like real kernels might produce
+    for(size_t i = 0; i < size; i += 100)
+    {
+        reference_output[i] += 1e-5f;
+    }
+
+    auto device_kernel = CreateDeviceBuffer(kernel_output);
+    auto device_ref    = CreateDeviceBuffer(reference_output);
+
+    // Should pass with the automatic tolerance for FP32 compute
+    bool result = gpu_verify<float, float, float>(device_kernel->GetDeviceBuffer(),
+                                                  device_ref->GetDeviceBuffer(),
+                                                  1000, // Typical accumulation count in conv
+                                                  size);
+
+    EXPECT_TRUE(result) << "Realistic convolution output should pass verification";
+}
+
+TEST_F(GPUVerificationTest, EndToEnd_ConvolutionLikeWorkload_FP16)
+{
+    // FP16 computation scenario
+    constexpr size_t size = 128 * 128;
+    std::vector<ck::half_t> kernel_output    = GenerateRandomData<ck::half_t>(size);
+    std::vector<ck::half_t> reference_output = kernel_output;
+
+    // Inject errors within FP16 precision
+    for(size_t i = 0; i < size; i += 50)
+    {
+        float val           = ck::type_convert<float>(reference_output[i]);
+        reference_output[i] = ck::type_convert<ck::half_t>(val + 1e-3f);
+    }
+
+    auto device_kernel = CreateDeviceBuffer(kernel_output);
+    auto device_ref    = CreateDeviceBuffer(reference_output);
+
+    bool result = gpu_verify<ck::half_t, ck::half_t, float>(
+        device_kernel->GetDeviceBuffer(), device_ref->GetDeviceBuffer(), 1000, size);
+
+    EXPECT_TRUE(result) << "FP16 convolution output should pass verification";
+}
+
+TEST_F(GPUVerificationTest, EndToEnd_DetectsActualErrors)
+{
+    // Verify that the system catches real errors
+    constexpr size_t size = 1024;
+    std::vector<float> kernel_output    = GenerateRandomData<float>(size);
+    std::vector<float> reference_output = GenerateRandomData<float>(size); // Completely different
+
+    auto device_kernel = CreateDeviceBuffer(kernel_output);
+    auto device_ref    = CreateDeviceBuffer(reference_output);
+
+    // Should fail when the data is truly different
+    bool result = gpu_verify<float, float, float>(
+        device_kernel->GetDeviceBuffer(), device_ref->GetDeviceBuffer(), 1, size);
+
+    EXPECT_FALSE(result) << "The system should detect actual errors";
+}
+
+int main(int argc, char** argv)
+{
+    testing::InitGoogleTest(&argc, argv);
+    return RUN_ALL_TESTS();
+}