AXS-6: Adaptive eXponent Sharing
Efficient 6-bit format with 21% memory savings over FP8 and higher precision.
Pitch

AXS-6 introduces a 6-bit training format with significant advantages over FP8: 21% lower memory use and four times the mantissa precision. It includes a custom Triton kernel for fast quantization and does not require complex scaling techniques, making it well suited to modern AI training.

Description

AXS-6 (Adaptive eXponent Sharing) is a 6-bit numerical format for low-precision neural network training. It uses 21% less memory than FP8 while offering 4x the mantissa precision.

Overview of AXS-6

AXS-6 stores one shared exponent per block of 32 values, so each value's own bits go entirely to its sign and mantissa. This design improves training efficiency and lets training converge with the simple straight-through estimator (STE), without the complex scaling adaptations that traditional low-precision formats typically require.

┌───────────────────────────────────────────────────────┐
│  Standard FP8 E4M3 (per value)                        │
│  [sign:1][exponent:4][mantissa:3] = 8 bits            │
│  Effective: 8 bits/value, 3-bit precision             │
│                                                       │
│  AXS-6 Block (shared across 32 values)                │
│  [shared_exponent:8][config:2]  ← 10 bits, once       │
│  [sign:1][mantissa:5] × 32      ← 6 bits each         │
│                                                       │
│  Effective: 6.31 bits/value, 5-bit precision          │
└───────────────────────────────────────────────────────┘
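To make the block layout concrete, here is a minimal pure-Python sketch of shared-exponent block quantization. It assumes the shared exponent is derived from the block's largest magnitude and that each value keeps a sign bit plus a 5-bit magnitude on a uniform grid. The 2-bit config field and the NF5 warp table are not modeled, and the function names (`quantize_block`, `dequantize_block`, `fake_quantize`) are hypothetical, not the AXS-6 API.

```python
import math

BLOCK_SIZE = 32     # values sharing one exponent (from the diagram)
MANTISSA_BITS = 5   # magnitude bits per value; plus 1 sign bit = 6 bits

def quantize_block(block):
    """Encode one block as (shared_exp, [(sign, mantissa), ...]).

    Simplified sketch: uniform 5-bit grid; the 2-bit config field and the
    NF5 warp table are not modeled. A real encoder would store shared_exp
    as an 8-bit field; here it is a plain Python int.
    """
    max_abs = max(abs(v) for v in block)
    if max_abs == 0.0:
        return 0, [(0, 0)] * len(block)
    # frexp: max_abs = f * 2**shared_exp with 0.5 <= f < 1, so max_abs < 2**shared_exp
    _, shared_exp = math.frexp(max_abs)
    step = 2.0 ** (shared_exp - MANTISSA_BITS)   # grid spacing for this block
    max_code = 2 ** MANTISSA_BITS - 1            # 31, largest 5-bit magnitude
    codes = []
    for v in block:
        sign = 1 if v < 0 else 0
        mantissa = min(round(abs(v) / step), max_code)
        codes.append((sign, mantissa))
    return shared_exp, codes

def dequantize_block(shared_exp, codes):
    step = 2.0 ** (shared_exp - MANTISSA_BITS)
    return [(-1.0 if sign else 1.0) * mantissa * step for sign, mantissa in codes]

def fake_quantize(values, block_size=BLOCK_SIZE):
    """Quantize and immediately dequantize, one block at a time."""
    out = []
    for start in range(0, len(values), block_size):
        shared_exp, codes = quantize_block(values[start:start + block_size])
        out.extend(dequantize_block(shared_exp, codes))
    return out

print(fake_quantize([0.5, -0.25, 0.1, 1.5]))  # [0.5, -0.25, 0.125, 1.5]
```

In the example, the largest magnitude 1.5 sets a grid step of 1/16 for the whole block, so 0.1 rounds to 2/16 = 0.125 while the other values land exactly on the grid.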

Key Advantages of AXS-6

  • Lower Memory Use: each value costs 6.31 bits on average (the 10-bit block header amortized over 32 values), about 21% less than FP8.
  • Higher Precision: 5 mantissa bits per value versus 3 in FP8 E4M3, which reduces quantization noise in gradients.
  • Reliable Convergence: training converges with a plain straight-through estimator (STE), which makes the format robust for software-level (fake) quantization.
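The memory and precision figures follow directly from the field widths in the block diagram. A short Python check of that arithmetic:

```python
# Storage cost of one AXS-6 block, using the field widths from the diagram
BLOCK_SIZE = 32
HEADER_BITS = 8 + 2        # shared_exponent:8 + config:2, stored once per block
PER_VALUE_BITS = 1 + 5     # sign:1 + mantissa:5, stored for every value

bits_per_value = (HEADER_BITS + BLOCK_SIZE * PER_VALUE_BITS) / BLOCK_SIZE
savings_vs_fp8 = 1 - bits_per_value / 8
# 5-bit vs 3-bit mantissa: 32 vs 8 representable mantissa levels
precision_ratio = 2 ** 5 / 2 ** 3

print(f"{bits_per_value:.4f} bits/value")             # 6.3125 bits/value
print(f"{savings_vs_fp8:.1%} smaller than FP8")        # 21.1% smaller than FP8
print(f"{precision_ratio:.0f}x more mantissa levels")  # 4x more mantissa levels
```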

Technical Innovations

AXS-6 includes two implementation optimizations:

  • Fused NF5 Warp Table: speeds up quantization with faster lookup operations.
  • Custom Triton Kernel: fuses multiple operations into single GPU passes, making fake quantization significantly faster than the equivalent standard PyTorch operations.

Performance Metrics

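Benchmark results by format (Bits = storage bits per value; ms/step = milliseconds per training step):

Format             Bits   ms/step   Final Loss   Perplexity   Converges?
FP32               32     9.22      0.0533       1.05         Yes
AXS-6 (Triton)     6.31   11.70     0.0537       1.06         Yes
NF4                4      18.57     6.8432       937          No
FP4 E2M1           4      24.74     6.8645       958          No
FP8 E4M3 (naive)   8      30.01     7.1399       1261         No

AXS-6 matches the FP32 final loss to within 0.0004 and is the fastest of the quantized formats, though it is slower per step than unquantized FP32 (11.70 vs 9.22 ms). The 4-bit formats and naive FP8 fail to converge.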
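The Perplexity column is consistent with exp(Final Loss), the usual definition when the loss is mean cross-entropy in nats. That relationship is inferred from the numbers rather than stated by the benchmark; the snippet below recomputes it from the table's values, allowing about 0.5% slack because the small perplexities are rounded to two decimals.

```python
import math

# (format, final loss, reported perplexity), copied from the results table
RESULTS = [
    ("FP32", 0.0533, 1.05),
    ("AXS-6 (Triton)", 0.0537, 1.06),
    ("NF4", 6.8432, 937),
    ("FP4 E2M1", 6.8645, 958),
    ("FP8 E4M3 (naive)", 7.1399, 1261),
]

for name, loss, reported_ppl in RESULTS:
    ppl = math.exp(loss)
    # Reported values are rounded, so compare with ~0.5% relative tolerance
    ok = math.isclose(ppl, reported_ppl, rel_tol=5e-3)
    print(f"{name:18s} exp(loss) = {ppl:9.2f}  reported = {reported_ppl}  {'ok' if ok else 'MISMATCH'}")
```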

Ease of Use

AXS-6 is designed to slot into existing workflows, with drop-in replacements for standard PyTorch layers:

  • AXSLinearUnified for nn.Linear
  • AXSLayerNormUnified for nn.LayerNorm
  • AXSEmbeddingUnified for nn.Embedding
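One common way to apply drop-in replacements to an existing model is to walk its module tree and swap matching layers. The sketch below is generic PyTorch: `swap_layers`, `StandInLinear`, and `to_stand_in` are hypothetical names, and `StandInLinear` only stands in for a layer such as `AXSLinearUnified`, whose constructor arguments are not documented here and would need to be checked against the README.

```python
import torch
from torch import nn

def swap_layers(module: nn.Module, factories: dict) -> nn.Module:
    """Recursively replace children whose exact type is a key in `factories`.

    `factories` maps a layer type (e.g. nn.Linear) to a callable that takes
    the old layer and returns its replacement.
    """
    for name, child in list(module.named_children()):
        factory = factories.get(type(child))
        if factory is not None:
            setattr(module, name, factory(child))  # re-registers under the same name
        else:
            swap_layers(child, factories)          # descend into containers
    return module

# Hypothetical stand-in for a replacement layer such as AXSLinearUnified.
class StandInLinear(nn.Linear):
    pass

def to_stand_in(old: nn.Linear) -> nn.Linear:
    new = StandInLinear(old.in_features, old.out_features, bias=old.bias is not None)
    new.load_state_dict(old.state_dict())  # keep the trained weights
    return new

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Sequential(nn.Linear(32, 4)))
w_before = model[0].weight.detach().clone()

swap_layers(model, {nn.Linear: to_stand_in})
out = model(torch.randn(2, 16))
print(type(model[0]).__name__, type(model[2][0]).__name__)  # StandInLinear StandInLinear
print(torch.equal(model[0].weight, w_before))             # True
```

Copying the state dict preserves pretrained weights; matching on the exact type keeps already-swapped layers from being replaced again.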

Quick Start Example

The following snippet fake-quantizes a tensor with AXS-6:

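```python
import torch
from axs.unified import fused_fake_quantize

x = torch.randn(128, 256)

# Fast training path (no intermediate allocation): x_q has the same shape
# as x, with values snapped to the AXS-6 grid (fake quantization)
x_q = fused_fake_quantize(x, block_size=32)
```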

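The pitch states that AXS-6 converges with a plain straight-through estimator (STE). Whether `fused_fake_quantize` applies STE internally is not stated here, so the sketch below shows the standard STE pattern around a hypothetical stand-in quantizer (`blockwise_round`, a uniform shared-exponent grid, not the AXS-6 Triton kernel): the forward pass sees quantized values and the backward pass treats quantization as the identity.

```python
import torch

def blockwise_round(x: torch.Tensor, block_size: int = 32, mantissa_bits: int = 5) -> torch.Tensor:
    """Hypothetical stand-in quantizer (not the AXS-6 kernel): one power-of-two
    scale per block of `block_size` values, sign + `mantissa_bits`-bit magnitude.
    Assumes x.numel() is a multiple of block_size."""
    blocks = x.reshape(-1, block_size)
    max_abs = blocks.abs().amax(dim=1, keepdim=True).clamp_min(torch.finfo(x.dtype).tiny)
    shared_exp = torch.floor(torch.log2(max_abs)) + 1   # 2**shared_exp > max_abs
    step = torch.exp2(shared_exp - mantissa_bits)        # grid spacing per block
    max_code = 2 ** mantissa_bits - 1
    mags = torch.clamp(torch.round(blocks.abs() / step), max=max_code) * step
    return (torch.sign(blocks) * mags).reshape(x.shape)

def ste_fake_quantize(x: torch.Tensor) -> torch.Tensor:
    # Forward value: the quantized tensor. Backward: the rounding residual is
    # detached, so d(output)/d(x) is the identity (straight-through).
    return x + (blockwise_round(x) - x).detach()

x = torch.randn(4, 64, requires_grad=True)
y = ste_fake_quantize(x)
y.sum().backward()
print(torch.all(x.grad == 1))  # tensor(True)
```

Rounding has zero derivative almost everywhere, so without the detached residual no gradient would reach the full-precision weights; STE sidesteps that by passing the upstream gradient through unchanged.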
For complete integration and setup guidance, check the detailed README.

AXS-6 paves the way toward efficient, low-memory deep learning while maintaining high precision and performance.
