Reduce sched_yield overhead with time-based spinning#2017
Reduce sched_yield overhead with time-based spinning #2017 — hodgesds wants to merge 1 commit into NVIDIA:master from hodgesds' branch.
Conversation
Under system saturation, sched_yield has multi-millisecond tail latency (up to 4ms at 100% CPU). Replace immediate sched_yield with time-based spinning: - Spin with CPU pause instruction for configurable duration (default 1µs) - Only yield after spin timeout, then reset timer - Configurable via environment variables Changes: - utils.h: Add ncclCpuRelax() cross-platform CPU pause helper - proxy.cc: Time-based spin in freeOps pool wait loop - proxy.cc: Time-based spin in progress loop idle path - doca_gpunetio.cpp: Replace yield with pause in service mainloop Environment variables: - NCCL_PROXY_SPIN_TIME_NS: freeOps wait spin duration (default 1000) - NCCL_PROXY_PROGRESS_SPIN_TIME_NS: progress loop spin duration (default 1000) - Set to 0 to restore original always-yield behavior The pause instruction (~43 cycles on x86) allows hyperthreads to run while avoiding syscall overhead. Signed-off-by: Daniel Hodges <hodgesd@meta.com>
NCCL sched_yield Optimization BenchmarksTest EnvironmentCPU
NUMA LayoutGPU Topology
Software
1. Per-Operation OverheadMeasured cost of individual operations:
2. sched_yield Under CPU Load
Under full CPU load,
3. Wake-up Latency ComparisonTime from flag set to consumer thread noticing (cycles):
Timed spinning provides consistent latency with bounded max. 4. Spin Instruction ComparisonComparing different spin-wait approaches:
5. NCCL AllReduce Benchmark8x H100 GPUs, 64MB buffers, 500 iterations:
Performance is nearly identical on NVLink single host without additional load
The optimization benefits are expected on:
6. Yield Contention BenchmarkSimulating NCCL proxy free-ops pool contention pattern (64 threads, 10 seconds):
This user-space simulation shows modest improvement. Real NCCL workloads with heavier proxy-thread contention are expected to benefit more.

Configuration — environment variables for tuning:

# Spin duration for freeOps pool wait (default: 1000ns)
export NCCL_PROXY_SPIN_TIME_NS=1000
# Spin duration for progress loop idle (default: 1000ns)
export NCCL_PROXY_PROGRESS_SPIN_TIME_NS=1000
# Disable spinning (original behavior)
export NCCL_PROXY_SPIN_TIME_NS=0
export NCCL_PROXY_PROGRESS_SPIN_TIME_NS=0

Benchmark Source Code — Latency Microbenchmark (latency_bench.cpp): measures per-operation overhead and wake-up latency.

/*
* Microbenchmark to measure:
* 1. Per-operation overhead: pause, sched_yield, clock_gettime
* 2. Wake-up latency: time from flag set to consumer noticing
*
* Build:
* g++ -O2 -pthread -o latency_bench latency_bench.cpp
*
* Run:
* ./latency_bench
*
*/
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <thread>
#include <vector>
#include <sched.h>
#include <time.h>
#include <x86intrin.h> // For __rdtsc
// Number of iterations for averaging
static constexpr int WARMUP_ITERS = 1000;     // priming iterations, discarded from stats
static constexpr int MEASURE_ITERS = 100000;  // samples collected per per-op benchmark
static constexpr int LATENCY_ITERS = 10000;   // samples collected per wake-up-latency run
// CPU pause instruction
// Architecture-specific "polite spin" hint: x86 `pause` throttles the
// spinning logical core (freeing pipeline resources for a sibling
// hyperthread); AArch64 `yield` is the analogous hint. On any other
// target this degrades to a pure compiler barrier with no hardware hint.
// NOTE(review): pause latency varies widely per micro-architecture
// (reviewers report ~10x longer on Intel Skylake than prior generations),
// so spin budgets measured in "number of pauses" are not portable —
// time-based budgets, as used here, are more robust.
static inline void cpu_pause() {
#if defined(__x86_64__) || defined(__i386__)
  __asm__ __volatile__("pause" ::: "memory");
#elif defined(__aarch64__)
  __asm__ __volatile__("yield" ::: "memory");
#else
  __asm__ __volatile__("" ::: "memory");
#endif
}
// Get TSC cycles
// Thin wrapper over the compiler intrinsic; not serializing, so adjacent
// instructions may reorder around it — acceptable for these statistics.
static inline uint64_t rdtsc() {
  return __rdtsc();
}
// Get monotonic time in nanoseconds since an arbitrary epoch, via
// clock_gettime(CLOCK_MONOTONIC).
static inline uint64_t clock_nano() {
  struct timespec now;
  clock_gettime(CLOCK_MONOTONIC, &now);
  return now.tv_nsec + static_cast<uint64_t>(now.tv_sec) * 1000000000ULL;
}
// Measure CPU frequency
// Estimates the TSC rate in GHz by counting TSC cycles across a ~100ms
// CLOCK_MONOTONIC window. Used to convert cycle counts to nanoseconds.
double measure_cpu_freq_ghz() {
  const uint64_t ns_begin = clock_nano();
  const uint64_t tsc_begin = rdtsc();
  // Busy-wait for ~100ms; pausing keeps a sibling hyperthread usable.
  while (clock_nano() - ns_begin < 100000000ULL) {
    cpu_pause();
  }
  const uint64_t ns_end = clock_nano();
  const uint64_t tsc_end = rdtsc();
  // Cycles per nanosecond is numerically equal to GHz.
  return static_cast<double>(tsc_end - tsc_begin) /
         static_cast<double>(ns_end - ns_begin);
}
// Summary statistics (in TSC cycles) for one set of samples.
struct OpStats {
  double min_cycles;
  double max_cycles;
  double avg_cycles;
  double p50_cycles;
  double p99_cycles;
};

// Sorts `samples` in place and returns min/max/avg/p50/p99.
// Fix: the original dereferenced front()/back() unconditionally, which
// is undefined behavior on an empty vector; an empty input now yields
// all-zero stats instead.
OpStats compute_stats(std::vector<uint64_t>& samples) {
  OpStats stats{};  // zero-initialized; returned as-is for empty input
  if (samples.empty()) return stats;
  std::sort(samples.begin(), samples.end());
  stats.min_cycles = static_cast<double>(samples.front());
  stats.max_cycles = static_cast<double>(samples.back());
  uint64_t sum = 0;
  for (auto s : samples) sum += s;
  stats.avg_cycles = static_cast<double>(sum) / samples.size();
  // Nearest-rank style percentiles by direct indexing into sorted data.
  stats.p50_cycles = static_cast<double>(samples[samples.size() / 2]);
  stats.p99_cycles = static_cast<double>(samples[samples.size() * 99 / 100]);
  return stats;
}
void measure_pause() {
printf("\n=== Measuring pause instruction ===\n");
std::vector<uint64_t> samples;
samples.reserve(MEASURE_ITERS);
for (int i = 0; i < WARMUP_ITERS; i++) {
cpu_pause();
}
for (int i = 0; i < MEASURE_ITERS; i++) {
uint64_t c0 = rdtsc();
cpu_pause();
uint64_t c1 = rdtsc();
samples.push_back(c1 - c0);
}
OpStats stats = compute_stats(samples);
printf(" min: %.0f cycles\n", stats.min_cycles);
printf(" avg: %.0f cycles\n", stats.avg_cycles);
printf(" p50: %.0f cycles\n", stats.p50_cycles);
printf(" p99: %.0f cycles\n", stats.p99_cycles);
printf(" max: %.0f cycles\n", stats.max_cycles);
}
void measure_clock_gettime() {
printf("\n=== Measuring clock_gettime(CLOCK_MONOTONIC) ===\n");
std::vector<uint64_t> samples;
samples.reserve(MEASURE_ITERS);
for (int i = 0; i < WARMUP_ITERS; i++) {
clock_nano();
}
for (int i = 0; i < MEASURE_ITERS; i++) {
uint64_t c0 = rdtsc();
clock_nano();
uint64_t c1 = rdtsc();
samples.push_back(c1 - c0);
}
OpStats stats = compute_stats(samples);
printf(" min: %.0f cycles\n", stats.min_cycles);
printf(" avg: %.0f cycles\n", stats.avg_cycles);
printf(" p50: %.0f cycles\n", stats.p50_cycles);
printf(" p99: %.0f cycles\n", stats.p99_cycles);
printf(" max: %.0f cycles\n", stats.max_cycles);
}
void measure_sched_yield() {
printf("\n=== Measuring sched_yield() ===\n");
std::vector<uint64_t> samples;
samples.reserve(MEASURE_ITERS);
for (int i = 0; i < WARMUP_ITERS; i++) {
sched_yield();
}
for (int i = 0; i < MEASURE_ITERS; i++) {
uint64_t c0 = rdtsc();
sched_yield();
uint64_t c1 = rdtsc();
samples.push_back(c1 - c0);
}
OpStats stats = compute_stats(samples);
printf(" min: %.0f cycles\n", stats.min_cycles);
printf(" avg: %.0f cycles\n", stats.avg_cycles);
printf(" p50: %.0f cycles\n", stats.p50_cycles);
printf(" p99: %.0f cycles\n", stats.p99_cycles);
printf(" max: %.0f cycles\n", stats.max_cycles);
}
// Producer/consumer handshake state for the wake-up-latency benchmark.
// alignas(64) keeps the struct on its own cache line to limit false
// sharing with neighboring stack data.
struct alignas(64) SharedState {
  std::atomic<uint64_t> flag{0};   // producer publishes its rdtsc timestamp here; 0 = "not set"
  std::atomic<bool> ready{false};  // producer arms a round; consumer clears it when finished
  std::atomic<bool> done{false};   // shutdown signal for the consumer thread
  uint64_t consumer_noticed;       // rdtsc at the moment the consumer observed flag != 0
                                   // (plain field: ordered by the ready release/acquire pair)
  char padding[64];                // pad so trailing neighbors land on a different line
};
// Wait strategies under comparison: pure pause-spin, immediate
// sched_yield (old NCCL behavior), and timed spin (pause for a time
// budget, then yield and reset the timer — the proposed NCCL behavior).
enum class WaitMode { SPIN_PAUSE, SPIN_YIELD, SPIN_TIMED };
// Consumer side of the wake-up-latency benchmark.
// Repeatedly: wait for the producer to arm a round (ready == true),
// then wait for the flag using the selected strategy, record the rdtsc
// at which the flag was noticed, and hand the round back by clearing
// `ready`. Exits when `done` is set.
void consumer_thread(SharedState* state, WaitMode mode, int64_t spin_time_ns) {
  while (!state->done.load(std::memory_order_relaxed)) {
    // Wait for the producer to start a round.
    while (!state->ready.load(std::memory_order_acquire)) {
      cpu_pause();
    }
    uint64_t flag_value = 0;  // NOTE(review): assigned but never read after the loops
    if (mode == WaitMode::SPIN_PAUSE) {
      // Strategy 1: pure spin with the pause hint — lowest latency,
      // burns a core while waiting.
      while ((flag_value = state->flag.load(std::memory_order_acquire)) == 0) {
        cpu_pause();
      }
    } else if (mode == WaitMode::SPIN_YIELD) {
      // Strategy 2: yield to the scheduler on every failed check —
      // models the original NCCL proxy behavior.
      while ((flag_value = state->flag.load(std::memory_order_acquire)) == 0) {
        sched_yield();
      }
    } else {
      // Strategy 3 (timed spin): pause until `spin_time_ns` elapses,
      // then yield once and restart the spin window — models the
      // behavior proposed in this PR.
      uint64_t t0 = clock_nano();
      while ((flag_value = state->flag.load(std::memory_order_acquire)) == 0) {
        if (clock_nano() - t0 < static_cast<uint64_t>(spin_time_ns)) {
          cpu_pause();
        } else {
          sched_yield();
          t0 = clock_nano();
        }
      }
    }
    // Record when we saw the flag; the release store on `ready` below
    // makes this visible to the producer's acquire load.
    state->consumer_noticed = rdtsc();
    state->ready.store(false, std::memory_order_release);
  }
}
// Producer side of the wake-up-latency benchmark.
// Spawns a consumer using the given wait strategy, then for each round:
// arm `ready`, wait a variable delay, publish the current rdtsc in
// `flag`, and wait for the consumer to clear `ready`. The latency is
// consumer_noticed - published timestamp, in TSC cycles.
void measure_wakeup_latency(WaitMode mode, const char* mode_name, int64_t spin_time_ns = 0) {
  printf("\n=== Measuring wake-up latency: %s ===\n", mode_name);
  SharedState state;
  std::thread consumer(consumer_thread, &state, mode, spin_time_ns);
  std::vector<uint64_t> latencies;
  latencies.reserve(LATENCY_ITERS);
  // Warm-up rounds (results discarded): settle thread placement and caches.
  for (int i = 0; i < 100; i++) {
    state.ready.store(true, std::memory_order_release);
    for (int j = 0; j < 100; j++) cpu_pause();
    uint64_t t0 = rdtsc();
    state.flag.store(t0, std::memory_order_release);
    // Wait for the consumer to finish the round (it clears `ready`).
    while (state.ready.load(std::memory_order_acquire)) {
      cpu_pause();
    }
    state.flag.store(0, std::memory_order_relaxed);
  }
  // Measured rounds.
  for (int i = 0; i < LATENCY_ITERS; i++) {
    state.ready.store(true, std::memory_order_release);
    // Vary the pre-publish delay (0..999 pauses) so the consumer is
    // caught at different points of its wait loop, not in lockstep.
    int delay = (i * 7) % 1000;
    for (int j = 0; j < delay; j++) cpu_pause();
    uint64_t t0 = rdtsc();
    state.flag.store(t0, std::memory_order_release);
    while (state.ready.load(std::memory_order_acquire)) {
      cpu_pause();
    }
    // Safe to read consumer_noticed here: ordered by the consumer's
    // release store to `ready` and our acquire load above.
    uint64_t latency = state.consumer_noticed - t0;
    latencies.push_back(latency);
    state.flag.store(0, std::memory_order_relaxed);
  }
  // Shutdown: set done, then unblock the consumer from whichever wait
  // (ready or flag) it may currently be parked in.
  state.done.store(true, std::memory_order_release);
  state.ready.store(true, std::memory_order_release);
  state.flag.store(1, std::memory_order_release);
  consumer.join();
  OpStats stats = compute_stats(latencies);
  printf(" min: %.0f cycles\n", stats.min_cycles);
  printf(" avg: %.0f cycles\n", stats.avg_cycles);
  printf(" p50: %.0f cycles\n", stats.p50_cycles);
  printf(" p99: %.0f cycles\n", stats.p99_cycles);
  printf(" max: %.0f cycles\n", stats.max_cycles);
}
// Global on/off switch for the CPU-burner threads below.
std::atomic<bool> g_stress_running{true};
// Busy-loop worker used to saturate CPUs during the under-load test.
// `volatile` keeps the compiler from deleting the arithmetic as dead code.
void stress_worker() {
  volatile uint64_t x = 0;
  while (g_stress_running.load(std::memory_order_relaxed)) {
    x = x * 7 + 13;
  }
}
void measure_sched_yield_under_load(int num_stress_threads) {
printf("\n=== Measuring sched_yield() under load (%d stress threads) ===\n",
num_stress_threads);
std::vector<std::thread> stress_threads;
g_stress_running.store(true);
for (int i = 0; i < num_stress_threads; i++) {
stress_threads.emplace_back(stress_worker);
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::vector<uint64_t> samples;
samples.reserve(MEASURE_ITERS);
for (int i = 0; i < MEASURE_ITERS; i++) {
uint64_t c0 = rdtsc();
sched_yield();
uint64_t c1 = rdtsc();
samples.push_back(c1 - c0);
}
g_stress_running.store(false);
for (auto& t : stress_threads) {
t.join();
}
OpStats stats = compute_stats(samples);
printf(" min: %.0f cycles\n", stats.min_cycles);
printf(" avg: %.0f cycles\n", stats.avg_cycles);
printf(" p50: %.0f cycles\n", stats.p50_cycles);
printf(" p99: %.0f cycles\n", stats.p99_cycles);
printf(" max: %.0f cycles\n", stats.max_cycles);
}
// Entry point: calibrate the TSC, then run the three benchmark parts.
// argc/argv are accepted but unused.
int main(int argc, char* argv[]) {
  printf("Latency Microbenchmark\n");
  printf("======================\n");
  double cpu_ghz = measure_cpu_freq_ghz();
  printf("\nCPU frequency: %.2f GHz\n", cpu_ghz);
  printf("(1 cycle = %.2f ns)\n", 1.0 / cpu_ghz);
  printf("\n\n### PART 1: Per-operation overhead ###\n");
  measure_pause();
  measure_clock_gettime();
  measure_sched_yield();
  printf("\n\n### PART 2: Wake-up latency ###\n");
  measure_wakeup_latency(WaitMode::SPIN_PAUSE, "spin with pause");
  measure_wakeup_latency(WaitMode::SPIN_YIELD, "immediate yield (old NCCL)");
  measure_wakeup_latency(WaitMode::SPIN_TIMED, "timed spin 1us (new NCCL)", 1000);
  measure_wakeup_latency(WaitMode::SPIN_TIMED, "timed spin 5us", 5000);
  printf("\n\n### PART 3: sched_yield under CPU load ###\n");
  // hardware_concurrency() may return 0 if unknown; on that (unlikely)
  // platform all three runs would use 0 stress threads.
  int num_cpus = std::thread::hardware_concurrency();
  printf("System has %d CPUs\n", num_cpus);
  // Half-, fully-, and over-subscribed runs.
  measure_sched_yield_under_load(num_cpus / 2);
  measure_sched_yield_under_load(num_cpus);
  measure_sched_yield_under_load(num_cpus * 2);
  printf("\n\nDone.\n");
  return 0;
}NCCL AllReduce Benchmark (nccl_yield_bench.cu)Tests actual NCCL collective operations. /*
* NCCL AllReduce Benchmark to reproduce sched_yield contention
*
* Build (from NCCL root):
* nvcc -O2 -o nccl_yield_bench nccl_yield_bench.cu \
* -I./build/include -L./build/lib -lnccl -lcudart -lpthread
*
* Run:
* LD_LIBRARY_PATH=./build/lib ./nccl_yield_bench
*
* Profile:
* perf stat -e syscalls:sys_enter_sched_yield \
* LD_LIBRARY_PATH=./build/lib ./nccl_yield_bench
*/
#include <cuda_runtime.h>
#include <nccl.h>

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>  // strcmp/strncmp used for argv parsing in main()
#include <thread>
#include <vector>
// Abort-on-error wrapper for CUDA runtime calls: prints the failing
// file/line and the CUDA error string, then exits.
#define CUDACHECK(cmd) do { \
  cudaError_t err = cmd; \
  if (err != cudaSuccess) { \
    fprintf(stderr, "CUDA error %s:%d '%s'\n", \
            __FILE__, __LINE__, cudaGetErrorString(err)); \
    exit(EXIT_FAILURE); \
  } \
} while(0)
// Abort-on-error wrapper for NCCL calls, mirroring CUDACHECK.
#define NCCLCHECK(cmd) do { \
  ncclResult_t res = cmd; \
  if (res != ncclSuccess) { \
    fprintf(stderr, "NCCL error %s:%d '%s'\n", \
            __FILE__, __LINE__, ncclGetErrorString(res)); \
    exit(EXIT_FAILURE); \
  } \
} while(0)
// Benchmark parameters, overridable from the command line (see main).
struct BenchConfig {
  size_t buffer_size = 64 * 1024 * 1024; // 64MB per GPU
  int num_iterations = 1000;             // timed AllReduce iterations
  int warmup_iterations = 100;           // untimed priming iterations
  bool sync_each_iter = true;            // stream-sync after every iteration (--no-sync disables)
};
// Runs a single-process, multi-GPU AllReduce benchmark: one NCCL
// communicator per local GPU, warmup, a timed loop, and a bandwidth
// report. Errors abort via CUDACHECK/NCCLCHECK.
void run_benchmark(int num_gpus, const BenchConfig& config) {
  printf("Running NCCL benchmark with %d GPUs\n", num_gpus);
  printf("  Buffer size: %zu MB\n", config.buffer_size / (1024 * 1024));
  printf("  Iterations: %d\n", config.num_iterations);
  printf("  Sync each iteration: %s\n", config.sync_each_iter ? "yes" : "no");
  // Report the spin-tuning env var under test so runs are self-describing.
  const char* spin_env = getenv("NCCL_PROXY_SPIN_TIME_NS");
  if (spin_env) {
    printf("  NCCL_PROXY_SPIN_TIME_NS: %s\n", spin_env);
  } else {
    printf("  NCCL_PROXY_SPIN_TIME_NS: (default)\n");
  }
  printf("\n");
  // Per-GPU resources (never freed on the abort paths inside the CHECK
  // macros; acceptable for a benchmark that exits on error).
  std::vector<ncclComm_t> comms(num_gpus);
  std::vector<cudaStream_t> streams(num_gpus);
  std::vector<float*> send_buffers(num_gpus);
  std::vector<float*> recv_buffers(num_gpus);
  ncclUniqueId id;
  NCCLCHECK(ncclGetUniqueId(&id));
  // Allocate streams and device buffers on each GPU; distinct memset
  // patterns per rank so the reduction has non-trivial input.
  for (int i = 0; i < num_gpus; i++) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaStreamCreate(&streams[i]));
    CUDACHECK(cudaMalloc(&send_buffers[i], config.buffer_size));
    CUDACHECK(cudaMalloc(&recv_buffers[i], config.buffer_size));
    CUDACHECK(cudaMemset(send_buffers[i], i + 1, config.buffer_size));
  }
  // Single-process multi-GPU init requires grouping the per-rank
  // ncclCommInitRank calls.
  NCCLCHECK(ncclGroupStart());
  for (int i = 0; i < num_gpus; i++) {
    CUDACHECK(cudaSetDevice(i));
    NCCLCHECK(ncclCommInitRank(&comms[i], num_gpus, id, i));
  }
  NCCLCHECK(ncclGroupEnd());
  size_t count = config.buffer_size / sizeof(float);
  printf("Warming up...\n");
  for (int iter = 0; iter < config.warmup_iterations; iter++) {
    NCCLCHECK(ncclGroupStart());
    for (int i = 0; i < num_gpus; i++) {
      CUDACHECK(cudaSetDevice(i));
      NCCLCHECK(ncclAllReduce(send_buffers[i], recv_buffers[i], count,
                              ncclFloat, ncclSum, comms[i], streams[i]));
    }
    NCCLCHECK(ncclGroupEnd());
    if (config.sync_each_iter) {
      for (int i = 0; i < num_gpus; i++) {
        CUDACHECK(cudaSetDevice(i));
        CUDACHECK(cudaStreamSynchronize(streams[i]));
      }
    }
  }
  // Full device sync so warmup work cannot bleed into the timed region.
  for (int i = 0; i < num_gpus; i++) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaDeviceSynchronize());
  }
  printf("Running benchmark...\n");
  auto start = std::chrono::high_resolution_clock::now();
  for (int iter = 0; iter < config.num_iterations; iter++) {
    NCCLCHECK(ncclGroupStart());
    for (int i = 0; i < num_gpus; i++) {
      CUDACHECK(cudaSetDevice(i));
      NCCLCHECK(ncclAllReduce(send_buffers[i], recv_buffers[i], count,
                              ncclFloat, ncclSum, comms[i], streams[i]));
    }
    NCCLCHECK(ncclGroupEnd());
    if (config.sync_each_iter) {
      for (int i = 0; i < num_gpus; i++) {
        CUDACHECK(cudaSetDevice(i));
        CUDACHECK(cudaStreamSynchronize(streams[i]));
      }
    }
  }
  for (int i = 0; i < num_gpus; i++) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaDeviceSynchronize());
  }
  auto end = std::chrono::high_resolution_clock::now();
  double elapsed_ms = std::chrono::duration<double, std::milli>(end - start).count();
  // 2(n-1)/n * bytes is the standard ring-AllReduce bus-bandwidth
  // convention (bytes moved per rank per iteration).
  // NOTE(review): `algo_bw` holds bytes, not a bandwidth — the name is
  // misleading; the division by elapsed time happens two lines below.
  double algo_bw = 2.0 * (num_gpus - 1.0) / num_gpus * config.buffer_size;
  double total_bytes = algo_bw * config.num_iterations;
  // GiB/s (1024^3 divisor).
  double bw_gbps = (total_bytes / (elapsed_ms / 1000.0)) / (1024.0 * 1024.0 * 1024.0);
  printf("\nResults:\n");
  printf("  Total time: %.2f ms\n", elapsed_ms);
  printf("  Time per iteration: %.3f ms\n", elapsed_ms / config.num_iterations);
  printf("  Throughput: %.2f iterations/sec\n", config.num_iterations / (elapsed_ms / 1000.0));
  printf("  Algorithm bandwidth: %.2f GB/s\n", bw_gbps);
  // Teardown in reverse order of setup.
  for (int i = 0; i < num_gpus; i++) {
    CUDACHECK(cudaSetDevice(i));
    NCCLCHECK(ncclCommDestroy(comms[i]));
    CUDACHECK(cudaStreamDestroy(streams[i]));
    CUDACHECK(cudaFree(send_buffers[i]));
    CUDACHECK(cudaFree(recv_buffers[i]));
  }
}
// Entry point: enumerate GPUs, parse command-line overrides, and run
// the benchmark across every visible device.
// NOTE(review): strncmp/strcmp below require <cstring>, which is not in
// the include list above — builds that don't pull it in transitively
// will fail; add the include.
int main(int argc, char* argv[]) {
  int num_gpus;
  CUDACHECK(cudaGetDeviceCount(&num_gpus));
  if (num_gpus < 1) {
    fprintf(stderr, "No GPUs found\n");
    return 1;
  }
  printf("Found %d GPUs\n", num_gpus);
  for (int i = 0; i < num_gpus; i++) {
    cudaDeviceProp prop;
    CUDACHECK(cudaGetDeviceProperties(&prop, i));
    printf("  GPU %d: %s\n", i, prop.name);
  }
  printf("\n");
  BenchConfig config;
  // Minimal hand-rolled flag parsing; unknown arguments are ignored.
  // No validation of numeric values (e.g. --size=0 passes through).
  for (int i = 1; i < argc; i++) {
    if (strncmp(argv[i], "--size=", 7) == 0) {
      // Value is in MB; converted to bytes here.
      config.buffer_size = atoll(argv[i] + 7) * 1024 * 1024;
    } else if (strncmp(argv[i], "--iters=", 8) == 0) {
      config.num_iterations = atoi(argv[i] + 8);
    } else if (strncmp(argv[i], "--warmup=", 9) == 0) {
      config.warmup_iterations = atoi(argv[i] + 9);
    } else if (strcmp(argv[i], "--no-sync") == 0) {
      config.sync_each_iter = false;
    } else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) {
      printf("Usage: %s [options]\n", argv[0]);
      printf("Options:\n");
      printf("  --size=N     Buffer size in MB (default: 64)\n");
      printf("  --iters=N    Number of iterations (default: 1000)\n");
      printf("  --warmup=N   Warmup iterations (default: 100)\n");
      printf("  --no-sync    Don't sync between iterations\n");
      return 0;
    }
  }
  run_benchmark(num_gpus, config);
  return 0;
} |
|
bpftrace script for tracking yield latency: and output: |
| } | ||
| pthread_rwlock_unlock(&service->service_lock); | ||
| sched_yield(); | ||
| cpuRelax(); |
There was a problem hiding this comment.
fyi, the pause instruction has variable latency per micro-architecture. For example, the Intel Skylake has 10x more latency per pause instruction compared to the previous generation. A more robust approach might be to calibrate the pause latency during init and insert an approximate known timed delay.
xerothermic
left a comment
There was a problem hiding this comment.
Watch out for the variable pause latency, otherwise looks good.
Under system saturation, sched_yield has multi-millisecond tail latency (up to 4ms at 100% CPU).
Replace immediate sched_yield with time-based spinning:
Changes:
Environment variables:
The pause instruction (~43 cycles on x86) allows hyperthreads to run while avoiding syscall overhead.