---
license: apache-2.0
tags:
  - kernels
  - sae
---
# Flex SAE Kernels

[![ArXiv](https://img.shields.io/badge/arXiv-2505.24473-b31b1b.svg)](https://arxiv.org/abs/2505.24473)

Fused Triton implementations of the TopK and HierarchicalTopK sparse autoencoder (SAE) decoder losses described in *Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy*.
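
As a rough sketch of the two objectives (notation ours, consistent with the plain-torch reference used in the example below; see the paper for the exact formulation): writing $\hat{x}_b^{(k)} = \mathrm{bias} + \sum_{j \le k} v_{b,j}\, W_{f_{b,j}}$ for the reconstruction that uses only the first $k$ of the $K$ selected features,

$$
\mathcal{L}_{\text{TopK}} = \frac{1}{BD} \sum_{b=1}^{B} \big\lVert \hat{x}_b^{(K)} - x_b \big\rVert_2^2,
\qquad
\mathcal{L}_{\text{HierarchicalTopK}} = \frac{1}{BKD} \sum_{b=1}^{B} \sum_{k=1}^{K} \big\lVert \hat{x}_b^{(k)} - x_b \big\rVert_2^2,
$$

so the hierarchical loss averages the reconstruction error over every prefix budget $k \le K$, which is what allows a single SAE to be used across multiple sparsity budgets.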

**This work has been accepted to [EMNLP 2025](https://2025.emnlp.org/).**

## What is released?

- Fast TopK kernel for SAEs (a slightly modified version from xformers): `torch-ext/flex_sae/topk_kernels.py`
- Fast HierarchicalTopK kernels (see our [paper](https://arxiv.org/abs/2505.24473)): `torch-ext/flex_sae/hierarchical_kernels.py`


## Quickstart

The kernels can be loaded from the Hub and share the following signature:
```python
from kernels import get_kernel

flex = get_kernel('t-tech/flex-sae')

top_k_kernel = flex.triton_topk_sae_loss
hierarchical_top_k_kernel = flex.triton_hierarchical_sae_loss

# B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim
# Both kernels share the same signature and return a scalar loss tensor:
#
# loss: torch.Tensor = top_k_kernel(
#     indices,  # torch.Tensor [B, K]
#     weight,   # torch.Tensor [F, D]
#     vals,     # torch.Tensor [B, K]
#     bias,     # torch.Tensor [D]
#     target,   # torch.Tensor [B, D]
# )
#
# loss: torch.Tensor = hierarchical_top_k_kernel(
#     indices,  # torch.Tensor [B, K]
#     weight,   # torch.Tensor [F, D]
#     vals,     # torch.Tensor [B, K]
#     bias,     # torch.Tensor [D]
#     target,   # torch.Tensor [B, D]
# )
```
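
For concreteness, a minimal end-to-end call might look like the sketch below (shapes are arbitrary and chosen only for illustration; a CUDA device is assumed):
```python
import torch
from kernels import get_kernel

flex = get_kernel("t-tech/flex-sae")

B, K, F, D = 8, 16, 4096, 256  # arbitrary illustrative sizes
indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
weight = torch.randn(F, D, device="cuda", requires_grad=True)
vals = torch.randn(B, K, device="cuda").abs().requires_grad_()
bias = torch.randn(D, device="cuda", requires_grad=True)
target = torch.randn(B, D, device="cuda")

loss = flex.triton_hierarchical_sae_loss(indices, weight, vals, bias, target)
loss.backward()  # gradients flow to weight, vals, and bias
```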

## Overview
- `torch-ext/flex_sae/` contains the Triton kernels alongside torch reference implementations.
- `tests/` hosts CUDA-backed property tests that ensure numerical parity across dtypes and kernels (a minimal parity-check sketch is shown below).
- `build.toml` and `flake.nix` integrate the project with [Hugging Face kernel-builder](https://github.com/huggingface/kernel-builder).

The Triton kernels target CUDA GPUs and focus on reducing the latency gap between TopK and HierarchicalTopK decoders while keeping memory usage flat.
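
As a quick sanity check in the spirit of those property tests, a kernel can be compared against a plain-torch reference. The sketch below assumes the standard TopK reconstruction loss (mean squared error of the full $K$-term decode); the actual tests live in `tests/`:
```python
import torch
from kernels import get_kernel

flex = get_kernel("t-tech/flex-sae")


def topk_sae_loss_reference(indices, weight, vals, bias, target):
    # Plain-torch TopK decoder loss: decode from the K selected rows, then take the MSE.
    emb = weight[indices].to(torch.float32)  # [B, K, D]
    recon = bias.to(torch.float32) + (emb * vals.to(torch.float32).unsqueeze(-1)).sum(dim=1)  # [B, D]
    return (recon - target.to(torch.float32)).pow(2).mean()


B, K, F, D = 4, 8, 1024, 64  # small illustrative sizes
indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")
weight = torch.randn(F, D, device="cuda")
vals = torch.randn(B, K, device="cuda").abs()
bias = torch.randn(D, device="cuda")
target = torch.randn(B, D, device="cuda")

torch.testing.assert_close(
    flex.triton_topk_sae_loss(indices, weight, vals, bias, target),
    topk_sae_loss_reference(indices, weight, vals, bias, target),
    rtol=1e-4,
    atol=1e-5,
)
```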

## Example

You can find example usage in [example.py](https://huggingface.co/t-tech/flex-sae/blob/main/example.py). 
```python
# /// script
# dependencies = [
#   "torch",
#   "numpy",
#   "kernels",
# ]
# ///

import torch
import numpy as np
from kernels import get_kernel

flex = get_kernel("t-tech/flex-sae")  # Fast kernels

@torch.compile(fullgraph=True)
def hierarchical_sae_loss(
    indices: torch.Tensor,  # [B, K]
    weight: torch.Tensor,  # [F, D]
    vals: torch.Tensor,  # [B, K]
    bias: torch.Tensor,  # [D]
    target: torch.Tensor,  # [B, D]
) -> torch.Tensor:
    emb = weight[indices].to(torch.float32)  # [B, K, D]
    recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1)
    diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1)
    loss = diff.pow(2).mean()
    return loss


B = 2048
K = 256
F = 1024 * 128
D = 1024
WARMUP = 5
NUM_ITER = 100
dtype = torch.float32

vals = None
decoder = None
bias = None
target = None
indices = None


def init_parameters():
    global vals, decoder, bias, target, indices
    vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_()
    decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True)
    bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True)
    target = torch.randn(B, D, dtype=dtype, device="cuda")
    indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda")


timing_kernel = []
timing_vanilla = []
torch.cuda.reset_peak_memory_stats()
loss_kernel_list = torch.zeros((100,))
loss_vanilla_list = torch.zeros((100,))


def zero_grad():
    vals.grad = None
    decoder.grad = None
    bias.grad = None
    torch.cuda.empty_cache()


for i in range(NUM_ITER + WARMUP):
    init_parameters()
    start_kernel = torch.cuda.Event(enable_timing=True)
    end_kernel = torch.cuda.Event(enable_timing=True)
    start_vanilla = torch.cuda.Event(enable_timing=True)
    end_vanilla = torch.cuda.Event(enable_timing=True)

    start_kernel.record()
    loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_kernel.backward()
    end_kernel.record()

    zero_grad()
    start_vanilla.record()
    loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
    loss_vanilla.backward()
    end_vanilla.record()
    if i >= WARMUP:
        torch.cuda.synchronize()
        timing_kernel.append(start_kernel.elapsed_time(end_kernel))
        timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
        loss_kernel_list[i-WARMUP] = loss_kernel.detach()
        loss_vanilla_list[i-WARMUP] = loss_vanilla.detach()
    zero_grad()

if torch.allclose(loss_kernel, loss_vanilla):
    print("✅ Outputs are close! Everything is good! 🎉")
else:
    print("❌ Outputs mismatch... ⚠️🤔")


print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
```

Run it with `uv run https://huggingface.co/t-tech/flex-sae/resolve/main/example.py`.

## Performance
Benchmarks were collected on a workload with dictionary size $F = 65{,}536$, embedding dimension $D = 2304$, and sparsity budgets $K \in \{32, 64, 128\}$. Latency is reported as time per training step (milliseconds) and memory as peak device usage (GiB).

| Decoder backend | K=32 (ms / GiB) | K=64 (ms / GiB) | K=128 (ms / GiB) |
| --- | --- | --- | --- |
| **Pure torch-compiled** | | | |
| TopK | 8.787 / 2.92 | 11.746 / 2.92 | 18.877 / 2.93 |
| HierarchicalTopK | 12.824 / 6.29 | 23.379 / 10.79 | 43.851 / 19.80 |
| **Triton kernels** | | | |
| TopK | 5.576 / 2.92 | 6.339 / 2.92 | 7.961 / 2.93 |
| HierarchicalTopK | **6.696 / 2.92** | **7.995 / 2.92** | **10.609 / 2.93** |

Across the evaluated sparsity budgets, the fused Triton HierarchicalTopK kernel matches the TopK kernels' memory use while remaining consistently faster than the reference torch implementation.
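
A back-of-the-envelope reading of the memory column (our estimate, not a figure from the paper): the torch-compiled HierarchicalTopK reference materializes the cumulative reconstructions (`recon_cum` in the example above) in fp32, so its activation memory scales roughly as

$$
\underbrace{B \cdot K \cdot D}_{\text{shape of }\texttt{recon\_cum}} \times 4\,\text{bytes} \;\propto\; K,
$$

which matches the near-doubling of that row as $K$ goes from 32 to 64 to 128, while the fused kernel never materializes this tensor and stays at the TopK footprint.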

## License & Attribution
- All files except `torch-ext/flex_sae/topk_kernels.py` are released under the [Apache License 2.0](LICENSE).
- `torch-ext/flex_sae/topk_kernels.py` includes code adapted from Facebook Research's [memory](https://github.com/facebookresearch/memory) project, originally published under the Creative Commons Attribution-NonCommercial 4.0 International License. That component therefore remains available for non-commercial use only; see [NOTICE](NOTICE) for details.

## Citation
```bibtex
@misc{balagansky2025trainsparseautoencodermultiple,
      title={Train One Sparse Autoencoder Across Multiple Sparsity Budgets to Preserve Interpretability and Accuracy},
      author={Nikita Balagansky and Yaroslav Aksenov and Daniil Laptev and Vadim Kurochkin and Gleb Gerasimov and Nikita Koryagin and Daniil Gavrilov},
      year={2025},
      eprint={2505.24473},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.24473},
}
```