AXS-6 (Adaptive eXponent Sharing) is a 6-bit numerical format designed to make low-precision neural-network training more efficient. Compared with FP8, it reduces memory use by roughly 21% while providing four times the mantissa precision, ships with a custom Triton kernel for fast fake quantization, and trains without the complex scaling techniques that low-precision formats usually require.
Overview of AXS-6
The AXS-6 format shares a single exponent across each block of 32 values, which frees the remaining bits for mantissa precision. This design improves training efficiency while keeping gradient handling simple: the format trains with a plain straight-through estimator (STE), without the complex adaptations that traditional low-precision formats typically require.
```
┌───────────────────────────────────────────────────────────┐
│ Standard FP8 E4M3 (per value)                              │
│ [sign:1][exponent:4][mantissa:3] = 8 bits, 3-bit mantissa  │
│                                                            │
│ AXS-6 Block (shared across 32 values)                      │
│ [shared_exponent:8][config:2]  ← 10 bits, once per block   │
│ [sign:1][mantissa:5] × 32      ← 6 bits each               │
│                                                            │
│ Effective: 6.31 bits/value, 5-bit mantissa precision       │
└───────────────────────────────────────────────────────────┘
```
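To make the layout concrete, here is a minimal reference sketch of block quantization with a shared exponent. It is illustrative only: the function name is made up, and it uses a plain power-of-two scale with a linear 5-bit grid rather than the NF5 lookup table or the 2-bit config field of the real format.

```python
import torch

def axs6_fake_quantize_ref(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """Toy reference for the block layout above; NOT the library's implementation."""
    assert x.numel() % block_size == 0, "pad or reshape so blocks divide evenly"
    flat = x.reshape(-1, block_size)                      # one row per block of 32 values
    amax = flat.abs().amax(dim=1, keepdim=True)           # per-block max magnitude
    # Shared exponent: smallest power of two that covers the block's largest value.
    exp = torch.ceil(torch.log2(amax.clamp(min=1e-38)))
    scale = torch.exp2(exp) / 31.0                        # signed 5-bit mantissa -> 31 steps
    q = torch.clamp(torch.round(flat / scale), -31, 31)   # quantize to sign + 5-bit magnitude
    return (q * scale).reshape(x.shape)                   # dequantize back to the input shape
```

With a 10-bit block header amortized over 32 values, the effective storage cost works out to 6 + 10/32 ≈ 6.31 bits per value, matching the figure in the diagram.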
Key Advantages of AXS-6
- Reduced Memory Use: each value costs an effective 6.31 bits, roughly 21% less than FP8's 8 bits.
- Higher Precision: 5 bits of mantissa (versus 3 for FP8 E4M3) reduce quantization noise in gradients.
- Reliable Convergence: training converges with a plain straight-through estimator (STE), even with purely software-level (fake) quantization; a minimal sketch follows this list.
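The straight-through estimator itself is a small trick: use the quantized values in the forward pass, but let gradients flow through as if no quantization had happened. A minimal, library-agnostic sketch (the function name is illustrative):

```python
import torch

def ste_quantize(x: torch.Tensor, fake_quantize) -> torch.Tensor:
    """Forward: quantized values. Backward: identity gradient (straight-through)."""
    x_q = fake_quantize(x)           # any quantize->dequantize function, e.g. the sketch above
    return x + (x_q - x).detach()    # evaluates to x_q, but autograd only sees x
```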
Technical Innovations
AXS-6's performance rests on two implementation details:
- A fused NF5 warp table that speeds up the quantization step through fast lookup operations.
- A custom Triton kernel that fuses multiple operations into a single GPU pass, making the fake-quantize step significantly faster than the equivalent chain of standard PyTorch ops (a sketch of the idea follows this list).
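As an illustration of the fusion idea only (this is not the library's kernel, all names are invented, and a plain linear 5-bit grid stands in for the NF5 table), a block-wise fake-quantize can be expressed as one Triton kernel that loads each block once, computes its scale, quantizes, dequantizes, and stores the result in a single pass:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _fused_fq_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # One program handles one block of values that share a scale. BLOCK must be a power of two.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # Block-wise scale from the max magnitude (a stand-in for the shared exponent).
    amax = tl.max(tl.abs(x), axis=0)
    scale = tl.where(amax > 0, amax / 31.0, 1.0)
    # Round to nearest into a signed 5-bit range, then dequantize in place.
    q = x / scale
    q = tl.where(q >= 0, q + 0.5, q - 0.5).to(tl.int32).to(tl.float32)
    q = tl.minimum(tl.maximum(q, -31.0), 31.0)
    tl.store(out_ptr + offs, q * scale, mask=mask)

def fused_fake_quantize_sketch(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    x = x.contiguous()
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, block_size),)
    _fused_fq_kernel[grid](x.view(-1), out.view(-1), n, BLOCK=block_size)
    return out
```

Because everything happens in one pass over the data, no intermediate tensors are materialized, which is where the speedup over a chain of eager PyTorch operations comes from.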
Performance Metrics
Training benchmarks compare AXS-6 against FP32 and other low-precision formats:
| Format | Bits | ms/step | Final Loss | Perplexity | Converges? |
|---|---|---|---|---|---|
| FP32 | 32 | 9.22 | 0.0533 | 1.05 | Yes |
| AXS-6 (Triton) | 6.31 | 11.70 | 0.0537 | 1.06 | Yes |
| NF4 | 4 | 18.57 | 6.8432 | 937 | No |
| FP4 E2M1 | 4 | 24.74 | 6.8645 | 958 | No |
| FP8 E4M3 (naive) | 8 | 30.01 | 7.1399 | 1261 | No |
Ease of Use
AXS-6 is designed for seamless integration into existing workflows, with drop-in replacements for standard layers:
- AXSLinearUnified for nn.Linear
- AXSLayerNormUnified for nn.LayerNorm
- AXSEmbeddingUnified for nn.Embedding
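For illustration, swapping the quantized layers into a small model might look like the following. The class name comes from the list above, but the import path and constructor signature shown here are assumptions (the quick-start module axs.unified is used as a guess), so the README remains the authority on the real API.

```python
import torch
import torch.nn as nn
from axs.unified import AXSLinearUnified  # import path assumed; see the README

class TinyMLP(nn.Module):
    def __init__(self, d_in: int = 256, d_hidden: int = 512, d_out: int = 10):
        super().__init__()
        # Assumed to accept the same (in_features, out_features) arguments as nn.Linear.
        self.fc1 = AXSLinearUnified(d_in, d_hidden)
        self.fc2 = AXSLinearUnified(d_hidden, d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))
```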
Quick Start Example
The following Python snippet shows how to fake-quantize a tensor with AXS-6:
```python
import torch
from axs.unified import fused_fake_quantize

x = torch.randn(128, 256)

# Fast training path (no intermediate allocation)
x_q = fused_fake_quantize(x, block_size=32)
```
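Since fused_fake_quantize is a quantize-dequantize (fake-quantize) operation, the result should keep the input's shape and dtype; a quick sanity check (the expected values below are an assumption, not library-documented output):

```python
print(x_q.shape, x_q.dtype)            # expected: torch.Size([128, 256]) torch.float32
print((x - x_q).abs().max().item())    # worst-case per-element quantization error
```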
For complete integration and setup guidance, check the detailed README.
AXS-6 paves the way toward efficient, low-memory deep learning while maintaining precision and performance.