LessWrong, July 20, 2024
BatchTopK: A Simple Improvement for TopK-SAEs
Published on July 20, 2024 2:20 AM GMT

Work done in Neel Nanda’s stream of MATS 6.0.

Epistemic status: Tried this on a single sweep and seems to work well, but it might definitely be a fluke of something particular to our implementation or experimental set-up. As there are also some theoretical reasons to expect this technique to work (adaptive sparsity), it seems probable that for many TopK SAE set-ups it could be a good idea to also try BatchTopK. As we’re not planning to investigate this much further and it might be useful to others, we’re just sharing what we’ve found so far. 

TL;DR: Instead of taking the TopK feature activations per token during training, taking the Top(K × batch_size) activations over the whole batch seems to improve SAE performance. During inference, this activation can be replaced with a single global threshold for all features.

Introduction

Sparse autoencoders (SAEs) have emerged as a promising tool for interpreting the internal representations of large language models. By learning to reconstruct activations using only a small number of features, SAEs can extract monosemantic concepts from the representations inside transformer models. Recently, OpenAI published a paper exploring the use of TopK activation functions in SAEs. This approach directly enforces sparsity by only keeping the K largest activations per sample. 

While effective, TopK forces every token to use exactly K features, which is likely suboptimal. We came up with a simple modification that solves this and seems to improve performance.

BatchTopK

Standard TopK SAEs apply the TopK operation independently to each sample in a batch. For a target sparsity of K, this means exactly K features are activated for every sample.
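As a minimal illustration (a NumPy sketch of the activation function only, not the authors' implementation), the standard per-sample TopK activation can be written as:

```python
import numpy as np

def topk_per_sample(acts, k):
    # Standard TopK SAE activation: keep the k largest activations
    # in each row (sample) and zero out the rest, so every sample
    # activates exactly k features.
    out = np.zeros_like(acts)
    idx = np.argpartition(acts, -k, axis=1)[:, -k:]  # per-row top-k indices
    rows = np.arange(acts.shape[0])[:, None]
    out[rows, idx] = acts[rows, idx]
    return out
```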

BatchTopK instead applies the TopK operation across the entire flattened batch:

    1. Flatten all feature activations across the batch
    2. Take the top (K × batch_size) activations
    3. Reshape back to the original batch shape

This allows more flexibility in how many features activate per sample, while still maintaining an average of K active features across the batch.
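The three steps above can be sketched in NumPy as follows (our own minimal version, not the training code used for the experiments):

```python
import numpy as np

def batch_topk(acts, k):
    # BatchTopK activation: keep the k * batch_size largest activations
    # across the whole flattened batch, then reshape back. Per-sample
    # sparsity can vary, but averages to k active features per sample.
    batch_size = acts.shape[0]
    n_keep = k * batch_size
    flat = acts.reshape(-1)
    out = np.zeros_like(flat)
    idx = np.argpartition(flat, -n_keep)[-n_keep:]
    out[idx] = flat[idx]
    return out.reshape(acts.shape)
```

Note that with k = 2 and a batch of two samples, one sample may end up with three active features and the other with one — exactly the adaptive sparsity discussed below.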

Experimental Set-Up

For both the TopK and the BatchTopK SAEs we train a sweep with the following hyperparameters:

As in the OpenAI paper, the input is normalized before being fed into the SAE and calculating the reconstruction loss. We also use the same auxiliary loss function for dead features (features that haven't activated for 5 batches), which computes a reconstruction loss on the residual using the top 512 dead features per sample and is multiplied by a factor of 1/32.
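To make the auxiliary loss concrete, here is a hedged NumPy sketch of the mechanism described above; the names (`aux_dead_feature_loss`, `dead_decoder`) are our own illustrative choices, not from the paper:

```python
import numpy as np

def aux_dead_feature_loss(residual, dead_acts, dead_decoder, k_aux=512, coeff=1/32):
    # Keep the top-k_aux dead-feature activations per sample
    k = min(k_aux, dead_acts.shape[1])
    kept = np.zeros_like(dead_acts)
    idx = np.argpartition(dead_acts, -k, axis=1)[:, -k:]
    rows = np.arange(dead_acts.shape[0])[:, None]
    kept[rows, idx] = dead_acts[rows, idx]
    # Reconstruct the main reconstruction's residual using only dead
    # features, take the MSE, and scale it down by coeff (1/32)
    recon = kept @ dead_decoder
    return coeff * float(np.mean((residual - recon) ** 2))
```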

Results

For a fixed number of active features (L0=32) the BatchTopK SAE has a lower normalized MSE than the TopK SAE and less downstream loss degradation across different dictionary sizes. Similarly, for fixed dictionary size (12288) BatchTopK outperforms TopK for different values of k.

BatchTopK achieves a better NMSE and CE compared to standard TopK across different dictionary sizes, for a fixed number of active features of 32 (Left). BatchTopK outperforms standard TopK for different values of K, with a fixed dictionary size of 12288 (Right).

Our main hypothesis is that the improved performance is due to adaptive sparsity: some samples contain more highly activating features than others. Let's have a look at the distribution of the number of active features per sample for the BatchTopK model.

Distribution of the number of active features per sample for a BatchTopK model. The peak on the left likely corresponds to BOS tokens, demonstrating BatchTopK's adaptive sparsity.

The BatchTopK model indeed makes use of the possibility to use different sparsities for different inputs. We suspect that the unusual peak on the left side consists of the feature activations on BOS tokens, given that its frequency is very close to 1 in 128, which is the sequence length. This serves as a great example of why BatchTopK might outperform TopK: at the BOS token, a sequence contains very little information yet, but the TopK SAE still activates 32 features. The BatchTopK model "saves" these activations so that it can use more features on tokens that are more information-dense.

Inference with BatchTopK

BatchTopK seems to work well as a training method, but might not be ideal to use during inference. Generally, it is a bit icky if during inference the activations of the features depend on whatever else there is present in your batch. Also, the SAE is trained on batches with mixed activations from many different sequences, whereas during inference the features in the batches (or individual sequences) will be correlated in all kinds of ways.

Instead, we can estimate a threshold T, the average over batches of the minimum nonzero activation value:

T = E_B[ min { z_{i,j} : z_{i,j} > 0 } ]

where z_{i,j} is the jth feature activation of the ith sample in a batch B. Now we can simply use this threshold during inference and set all feature activations below it to zero. Interestingly, the architecture is now equivalent to a ProLU or JumpReLU (published today!), but with a single global threshold for all features rather than a per-feature threshold, and trained in a very different fashion.
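A sketch of this inference scheme (our own naming, assuming the sparse post-BatchTopK training activations are available as a list of arrays):

```python
import numpy as np

def estimate_threshold(kept_batches):
    # T = average over training batches of the smallest nonzero
    # activation that BatchTopK kept in that batch.
    return float(np.mean([b[b > 0].min() for b in kept_batches]))

def threshold_inference(acts, T):
    # Inference-time replacement for BatchTopK: a JumpReLU-style
    # activation with one global threshold T shared by all features.
    return np.where(acts >= T, acts, 0.0)
```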

Performance comparison of TopK, original BatchTopK, and BatchTopK with an estimated threshold. Using a threshold during inference slightly improves BatchTopK's performance.

Using the threshold, the performance of BatchTopK actually improves a bit further. This can be explained by the fact that, without the threshold, BatchTopK effectively relies on a noisier version of the same threshold.

We also checked whether we can apply the BatchTopK activation at inference to a model trained with TopK, and vice versa, but this doesn't seem to work. This shows that the way the activation function shapes the training process is an important factor, rather than BatchTopK simply selecting higher activations in general.

Applying BatchTopK activation to a TopK-trained model and vice versa results in poor performance, highlighting the importance of the activation function during training

Limitations and Future Work

As stated in the epistemic status, given the limited experiments we have run, we are not very confident in how well this result will generalize to other models, larger dictionary sizes, different hyperparameters, etc. We encourage others to experiment with this approach, validate the results, and explore further refinements. To this end, we are sharing the training code that we used to run these experiments.

Here are some ideas to further improve upon this work:

- Use a target quantile instead of the TopK activation, estimating the quantile with a running average.
- Investigate the influence of batch size on BatchTopK.
- Compare BatchTopK against other SAE architectures, such as GatedSAEs and JumpReLU SAEs.

Thanks to Joseph Bloom for helpful comments on the experiments.


