---
license: apache-2.0
tags:
- kernels
- sae
---
# Flex SAE Kernels
[arXiv:2505.24473](https://arxiv.org/abs/2505.24473)
Fused Triton implementations of the TopK and HierarchicalTopK sparse autoencoder (SAE) decoder losses described in *Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy*.
**This work has been accepted to [EMNLP 2025](https://2025.emnlp.org/).**
## What is released?
- Fast TopK kernel for SAE (a slightly modified version of the xformers implementation): `torch-ext/flex_sae/topk_kernels.py`
- Fast HierarchicalTopK kernels (see our [paper](https://arxiv.org/abs/2505.24473)): `torch-ext/flex_sae/hierarchical_kernels.py`. A sketch of the objective follows this list.
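Conceptually, the HierarchicalTopK objective averages the reconstruction error over every prefix of the $K$ selected latents, so one SAE stays faithful at every sparsity budget up to $K$. In the notation of the Quickstart below, and matching the torch reference in `example.py`, the per-example loss is roughly

$$
\mathcal{L}_{\text{hier}} = \frac{1}{K D} \sum_{k=1}^{K} \Big\lVert\, b + \sum_{j=1}^{k} v_j\, W_{i_j} - x \,\Big\rVert_2^2 ,
$$

averaged over the batch, where $W \in \mathbb{R}^{F \times D}$ is the decoder weight, $b \in \mathbb{R}^{D}$ the decoder bias, $x \in \mathbb{R}^{D}$ the target, and $(i_j, v_j)$ the indices and values of the selected latents; the plain TopK loss corresponds to keeping only the $k = K$ term.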
## Quickstart
The kernels can be loaded from the Hugging Face Hub and expose the following signature:
```python
import torch
from kernels import get_kernel

flex = get_kernel('t-tech/flex-sae')
top_k_kernel = flex.triton_topk_sae_loss
hierarchical_top_k_kernel = flex.triton_hierarchical_sae_loss

# B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim
loss: torch.Tensor = top_k_kernel(
    indices,  # torch.Tensor, [B, K]
    weight,   # torch.Tensor, [F, D]
    vals,     # torch.Tensor, [B, K]
    bias,     # torch.Tensor, [D]
    target,   # torch.Tensor, [B, D]
)

loss: torch.Tensor = hierarchical_top_k_kernel(
    indices,  # torch.Tensor, [B, K]
    weight,   # torch.Tensor, [F, D]
    vals,     # torch.Tensor, [B, K]
    bias,     # torch.Tensor, [D]
    target,   # torch.Tensor, [B, D]
)
```
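For orientation, the sketch below shows a plain PyTorch function computing what `triton_topk_sae_loss` is expected to compute, i.e. the standard TopK SAE reconstruction MSE. It mirrors the hierarchical reference in `example.py` and is an illustration of the interface, not the fused kernel itself.

```python
import torch


def topk_sae_loss_reference(
    indices: torch.Tensor,  # [B, K] int64 indices into the dictionary
    weight: torch.Tensor,   # [F, D] decoder weight
    vals: torch.Tensor,     # [B, K] activations of the selected latents
    bias: torch.Tensor,     # [D]    decoder bias
    target: torch.Tensor,   # [B, D] reconstruction target
) -> torch.Tensor:
    # Gather the decoder rows of the K active latents: [B, K, D].
    emb = weight[indices].to(torch.float32)
    # Reconstruct with all K latents at once: bias + sum_k v_k * W_{i_k}.
    recon = bias.to(torch.float32) + (emb * vals.unsqueeze(-1).to(torch.float32)).sum(dim=1)
    # Mean squared error over the batch and hidden dimensions.
    return (recon - target.to(torch.float32)).pow(2).mean()
```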
## Overview
- `torch-ext/flex_sae/` contains the Triton kernels alongside torch reference implementations.
- `tests/` hosts CUDA-backed property tests that ensure numerical parity across dtypes and kernels.
- `build.toml` and `flake.nix` integrate the project with [Hugging Face kernel-builder](https://github.com/huggingface/kernel-builder).
The Triton kernels target CUDA GPUs and focus on reducing the latency gap between TopK and HierarchicalTopK decoders while keeping memory usage flat.
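The memory point can be made concrete: the torch-compiled reference used in the example below materializes the cumulative reconstruction tensor of shape `[B, K, D]`, so its peak memory grows with $K$, which matches the HierarchicalTopK rows in the Performance table. A purely illustrative way to get a flat memory profile in plain PyTorch (assuming the same loss definition as the reference, and at the cost of a slow Python-level loop) is to stream over the $K$ prefixes:

```python
import torch


def hierarchical_sae_loss_streaming(
    indices: torch.Tensor,  # [B, K]
    weight: torch.Tensor,   # [F, D]
    vals: torch.Tensor,     # [B, K]
    bias: torch.Tensor,     # [D]
    target: torch.Tensor,   # [B, D]
) -> torch.Tensor:
    """Illustrative only: accumulate prefix reconstructions one k at a time,
    so the [B, K, D] intermediate is never materialized."""
    K = vals.shape[1]
    target_f = target.to(torch.float32)
    recon = bias.to(torch.float32).expand_as(target_f).clone()  # running prefix reconstruction [B, D]
    loss = torch.zeros((), device=target.device, dtype=torch.float32)
    for k in range(K):
        # Add the k-th latent's contribution to the running reconstruction.
        recon = recon + vals[:, k : k + 1].to(torch.float32) * weight[indices[:, k]].to(torch.float32)
        loss = loss + (recon - target_f).pow(2).mean()
    return loss / K
```

The fused Triton kernel achieves the same flat memory footprint without the Python-level loop, which is what the Performance section below measures.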
## Example
You can find example usage in [example.py](https://huggingface.co/t-tech/flex-sae/blob/main/example.py).
```python
# /// script
# dependencies = [
# "torch",
# "numpy",
# "kernels",
# ]
# ///
import torch
import numpy as np
from kernels import get_kernel
flex = get_kernel("t-tech/flex-sae")  # fast fused kernels from the Hub
@torch.compile(fullgraph=True)
def hierarchical_sae_loss(
    indices: torch.Tensor,  # [B, K]
    weight: torch.Tensor,   # [F, D]
    vals: torch.Tensor,     # [B, K]
    bias: torch.Tensor,     # [D]
    target: torch.Tensor,   # [B, D]
) -> torch.Tensor:
    emb = weight[indices].to(torch.float32)  # [B, K, D]
    recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
    diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
    loss = diff.pow(2).mean()
    return loss
B = 2048
K = 256
F = 1024 * 128
D = 1024
WARMUP = 5
NUM_ITER = 100
dtype = torch.float32
vals = None
decoder = None
bias = None
target = None
indices = None
def init_parameters():
    global vals, decoder, bias, target, indices
    vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
    decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
    bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
    target = torch.randn(B, D, dtype=dtype, device="cuda")
    indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")


timing_kernel = []
timing_vanilla = []
torch.cuda.reset_peak_memory_stats()
loss_kernel_list = torch.zeros((NUM_ITER,))
loss_vanilla_list = torch.zeros((NUM_ITER,))


def zero_grad():
    vals.grad = None
    decoder.grad = None
    bias.grad = None
    torch.cuda.empty_cache()
for i in range(NUM_ITER + WARMUP):
    init_parameters()

    start_kernel = torch.cuda.Event(enable_timing=True)
    end_kernel = torch.cuda.Event(enable_timing=True)
    start_vanilla = torch.cuda.Event(enable_timing=True)
    end_vanilla = torch.cuda.Event(enable_timing=True)

    start_kernel.record()
    loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_kernel.backward()
    end_kernel.record()

    zero_grad()

    start_vanilla.record()
    loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_vanilla.backward()
    end_vanilla.record()

    if i >= WARMUP:
        torch.cuda.synchronize()
        timing_kernel.append(start_kernel.elapsed_time(end_kernel))
        timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
        loss_kernel_list[i - WARMUP] = loss_kernel.detach()
        loss_vanilla_list[i - WARMUP] = loss_vanilla.detach()

    zero_grad()

if torch.allclose(loss_kernel, loss_vanilla):
    print("✅ Outputs are close! Everything is good! 🎉")
else:
    print("❌ Outputs mismatch... ⚠️")

print(f"Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
print(f"Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
print(f"Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
```
Run it with `uv run https://huggingface.co/t-tech/flex-sae/resolve/main/example.py`.
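Beyond the benchmark, a minimal sketch of using the kernel inside an SAE training step might look like the following. The encoder, TopK selection, sizes, and optimizer here are generic placeholders chosen for illustration (they are not part of this repository); only the call to `flex.triton_hierarchical_sae_loss` follows the interface documented above.

```python
import torch
from kernels import get_kernel

flex = get_kernel("t-tech/flex-sae")

B, D, F, K = 1024, 1024, 65_536, 64  # illustrative sizes

# Generic SAE parameters (placeholders, not part of this repo).
encoder = torch.nn.Linear(D, F, device="cuda")
decoder = torch.nn.Parameter(torch.randn(F, D, device="cuda") * D ** -0.5)
bias = torch.nn.Parameter(torch.zeros(D, device="cuda"))
opt = torch.optim.Adam([*encoder.parameters(), decoder, bias], lr=1e-4)

x = torch.randn(B, D, device="cuda")    # activations to reconstruct
acts = torch.relu(encoder(x))           # [B, F] dense latent activations
vals, indices = acts.topk(K, dim=-1)    # keep the K largest latents per example

loss = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, x)
loss.backward()
opt.step()
opt.zero_grad()
```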
## Performance
Benchmarks were collected on a workload with dictionary size $F = 65{,}536$, embedding dimension $D = 2304$, and sparsity budgets $K \in \{32, 64, 128\}$. Latency is reported as time per training step (milliseconds) and memory as peak device usage (GiB).
| Decoder backend | K=32 (ms / GiB) | K=64 (ms / GiB) | K=128 (ms / GiB) |
| --- | --- | --- | --- |
| **Pure torch-compiled** | | | |
| TopK | 8.787 / 2.92 | 11.746 / 2.92 | 18.877 / 2.93 |
| HierarchicalTopK | 12.824 / 6.29 | 23.379 / 10.79 | 43.851 / 19.80 |
| **Triton kernels** | | | |
| TopK | 5.576 / 2.92 | 6.339 / 2.92 | 7.961 / 2.93 |
| HierarchicalTopK | **6.696 / 2.92** | **7.995 / 2.92** | **10.609 / 2.93** |
Across the evaluated sparsity budgets, the fused Triton HierarchicalTopK kernel matches the TopK kernel's memory footprint while remaining consistently faster than the torch-compiled reference implementation.
## License & Attribution
- All files except `torch-ext/flex_sae/topk_kernels.py` are released under the [Apache License 2.0](LICENSE).
- `torch-ext/flex_sae/topk_kernels.py` includes code adapted from Facebook Research's [memory](https://github.com/facebookresearch/memory) project, originally published under the Creative Commons Attribution-NonCommercial 4.0 International License. That component therefore remains available for non-commercial use only; see [NOTICE](NOTICE) for details.
## Citation
```bibtex
@misc{balagansky2025trainsparseautoencodermultiple,
title={Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy},
author={Nikita Balagansky and Yaroslav Aksenov and Daniil Laptev and Vadim Kurochkin and Gleb Gerasimov and Nikita Koryagin and Daniil Gavrilov},
year={2025},
eprint={2505.24473},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2505.24473},
}
```