# Primus-Turbo

Primus-Turbo is a high-performance acceleration library dedicated to large-scale model training on AMD GPUs. Built and optimized for the AMD ROCm platform, it covers the full training stack, including core compute operators (GEMM, Attention, GroupedGEMM), communication primitives, optimizer modules, low-precision computation (FP8), and compute-communication overlap kernels.

With high performance, full feature coverage, and developer friendliness as its guiding principles, Primus-Turbo is designed to fully unleash the potential of AMD GPUs for large-scale training workloads, offering a robust and complete acceleration foundation for next-generation AI systems.

Note: JAX support is under active development. Optimizer support is planned but not yet available.

## 🚀 What's New

## 🧩 Primus Product Matrix

| Module | Role | Key Features |
| --- | --- | --- |
| Primus-LM | E2E training framework | - Supports multiple training backends (Megatron, TorchTitan, etc.)<br>- Provides high-performance, scalable distributed training<br>- Deeply integrates with Primus-Turbo and Primus-SaFE |
| Primus-Turbo | High-performance operators & modules | - Supports core training operators and modules (FlashAttention, GEMM, GroupedGEMM, DeepEP, etc.)<br>- Integrates multiple high-performance backends (e.g., CK, hipBLASLt, AITER)<br>- High performance and easy to integrate |
| Primus-SaFE | Stability & platform layer | - Cluster sanity checks and benchmarking<br>- Topology-aware Kubernetes scheduling<br>- Fault tolerance<br>- Stability enhancements |

## 📦 Quick Start

### Requirements

#### Software

- ROCm >= 6.4
- Python >= 3.10
- PyTorch >= 2.6.0 (with ROCm support)
- rocSHMEM (optional, required for experimental DeepEP). Please refer to our DeepEP Installation Guide for instructions.
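
To confirm these requirements in an existing environment, a quick check (a minimal sketch; `torch.version.hip` is populated only on ROCm builds of PyTorch):

```python
import sys
import torch

# Verify the software requirements listed above.
print("Python :", sys.version.split()[0])  # expect >= 3.10
print("PyTorch:", torch.__version__)       # expect >= 2.6.0 (ROCm build)
print("ROCm   :", torch.version.hip)       # None on non-ROCm builds; expect >= 6.4
```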

#### Hardware

| Architecture | Supported GPUs |
| --- | --- |
| GFX942 | ✅ MI300X, ✅ MI325X |
| GFX950 | ✅ MI350X, ✅ MI355X |

See AMD GPU Architecture to find the architecture for your GPU.
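
If you are unsure what your GPU reports, the architecture can also be read from Python (a sketch assuming a ROCm build of PyTorch, which exposes `gcnArchName` on device properties):

```python
import torch

# Print each visible GPU with its GFX architecture string
# (e.g. "gfx942" for MI300X/MI325X, "gfx950" for MI350X/MI355X).
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"GPU {i}: {props.name} -> {props.gcnArchName}")
```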

### 1. Installation

#### Docker (Recommended)

Use the pre-built AMD ROCm image from Docker Hub:

```bash
# PyTorch Ecosystem
docker pull rocm/primus:v25.10

# JAX Ecosystem
docker pull rocm/jax-training:maxtext-v25.9
```

#### Install from Source

```bash
git clone https://github.com/AMD-AGI/Primus-Turbo.git --recursive
cd Primus-Turbo

pip3 install -r requirements.txt
pip3 install --no-build-isolation .

# (Optional) Set GPU_ARCHS environment variable to specify target AMD GPU architectures.
GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation .
```
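
Once installed, an import check confirms the compiled extension loads (a minimal sketch; `turbo.ops` is the operator namespace used in the minimal example below):

```python
# The import fails if the compiled extension is missing or was built
# for the wrong GPU architecture.
import primus_turbo.pytorch as turbo

print(turbo.__file__)  # install location
print([n for n in dir(turbo.ops) if not n.startswith("_")][:10])  # sample of exposed ops
```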

### 2. Development

For contributors, use editable mode (`-e`) so that code changes take effect immediately without reinstalling.

```bash
git clone https://github.com/AMD-AGI/Primus-Turbo.git --recursive
cd Primus-Turbo

pip3 install -r requirements.txt
pip3 install --no-build-isolation -e . -v

# (Optional) Set GPU_ARCHS environment variable to specify target AMD GPU architectures.
GPU_ARCHS="gfx942;gfx950" pip3 install --no-build-isolation -e . -v

# (Optional) Set PRIMUS_TURBO_FRAMEWORK to compile for a specific framework.
# Supported values: PYTORCH (default), JAX.
# For example, to compile for JAX:
PRIMUS_TURBO_FRAMEWORK="JAX" pip3 install --no-build-isolation -e . -v
```

### 3. Testing

**Option 1: Single-process mode (slow but simple)**

```bash
pytest tests/pytorch/    # run all PyTorch tests
pytest tests/jax/        # run all JAX tests
```

**Option 2: Multi-process mode (faster)**

```bash
# PyTorch tests
pytest tests/pytorch/ -n 8        # single-GPU tests (parallel)
pytest tests/pytorch/ --dist-only # multi-GPU tests

# JAX tests
pytest tests/jax/ -n 8            # single-GPU tests (parallel)
pytest tests/jax/ --dist-only     # multi-GPU tests
```

### 4. Packaging

```bash
pip3 install -r requirements.txt
python3 -m build --wheel --no-isolation
pip3 install --extra-index-url https://test.pypi.org/simple ./dist/primus_turbo-XXX.whl
```

### 5. Minimal Example

```python
import torch
import primus_turbo.pytorch as turbo

dtype = torch.bfloat16
device = "cuda:0"

# BF16 GEMM: (128 x 256) @ (256 x 512) -> (128 x 512)
a = torch.randn((128, 256), dtype=dtype, device=device)
b = torch.randn((256, 512), dtype=dtype, device=device)
c = turbo.ops.gemm(a, b)

print(c)
print(c.shape)
```
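
As a quick correctness check, the result can be compared against PyTorch's reference matmul (a minimal sketch; the tolerances are illustrative values for bfloat16, not official guidance):

```python
import torch
import primus_turbo.pytorch as turbo

a = torch.randn((128, 256), dtype=torch.bfloat16, device="cuda:0")
b = torch.randn((256, 512), dtype=torch.bfloat16, device="cuda:0")

# Compare the Turbo GEMM against the PyTorch reference implementation.
out = turbo.ops.gemm(a, b)
ref = a @ b
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)  # illustrative bf16 tolerances
print("turbo.ops.gemm matches torch.matmul within tolerance")
```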

## 💡 Example

See Examples for usage examples.

## 📊 Performance

See Benchmarks for detailed performance results and comparisons.

๐Ÿ“ Roadmap

See the Primus-Turbo Roadmap H2 2025.

## 📜 License

Primus-Turbo is licensed under the MIT License.

© 2025 Advanced Micro Devices, Inc. All rights reserved.
