智源社区 09月04日
“金鱼损失”:让大模型“记性差一点”更聪明
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

训练大模型时,一个新提出的“金鱼损失”(Goldfish Loss)方法,旨在解决模型过度记忆训练数据的问题。该方法借鉴金鱼“记性差”的特点,在损失计算时随机剔除部分token,阻止模型死记硬背。实验表明,使用金鱼损失的LLaMA-2模型,记忆化内容显著减少,但下游任务性能几乎不受影响。与Dropout不同,金鱼损失通过哈希掩码确保对同一内容的遮蔽位置一致,从根本上阻止模型复现完整训练文本。尽管可能需要更多数据补偿,但此方法有效提升了模型的泛化能力。

💡 **创新“金鱼损失”机制,缓解模型记忆化问题:** 研究团队提出“金鱼损失”(Goldfish Loss)方法,旨在解决大型语言模型(LLMs)在训练过程中过度记忆并复现训练数据的问题。该方法的核心思想是让模型在学习时“选择性遗忘”,模仿金鱼的短期记忆特性,通过在损失函数计算时随机剔除一小部分token,从而阻止模型死记硬背训练内容,但仍能学习语言规律。

🔑 **区分于Dropout,从根源上阻止模型复现:** 与Dropout等正则化方法不同,金鱼损失通过一种基于哈希(hashing)的掩码策略,确保当模型遇到相同的文本片段时,被剔除(不参与损失计算)的token位置是保持一致的。这从根本上阻止了模型通过累积多次“遗忘”来拼凑出完整训练文本的可能性,实现了更有效的防记忆化。

🚀 **实验验证有效性:** 在极端和标准训练场景下进行的实验表明,使用金鱼损失训练的LLaMA-2模型,在记忆化内容方面显著优于标准训练模型,几乎不复现训练数据。同时,模型的下游任务性能几乎不受影响,证明了该方法在提高模型泛化能力的同时,并未损害其语言生成能力。

⏳ **潜在的计算效率考量:** 虽然金鱼损失在防止记忆化方面表现出色,但由于模型需要通过更多数据来补偿被剔除token的缺失,这可能导致在训练过程中计算效率的下降。研究人员采用了静态掩码和局部化哈希掩码等策略来优化实现。

训练大模型时,有时让它“记性差一点”,反而更聪明!

大语言模型如果不加约束,很容易把训练数据原封不动地复刻出来。为解决这个问题,来自马里兰大学、图宾根大学和马普所的研究团队提出了一个新方法——金鱼损失(Goldfish Loss)

顾名思义,金鱼损失就是让模型像金鱼一样,不去死记每一个细节,而是在损失函数计算时随机剔除一小部分token。

由此,模型不再逐字记住训练集内容,但仍能学会语言规律。

实验显示,LLaMA-2在使用金鱼损失后:

    记忆化内容显著减少:模型不再复现训练数据下游任务性能几乎不受影响:仍然能流畅生成文本

用网友的精辟评论概括就是:dropout,但损失函数!

在梯度计算中随机屏蔽部分token

金鱼损失的核心理念非常简单,就是在模型训练过程中随机剔除一部分训练文本中的tokens,使其不参与损失计算。

这样一来,当模型在推理阶段遇到这些位置时,就只能“猜测”,而不是逐字逐句复现训练数据的完整序列。

此外,为了保证被剔除token的一致性,研究人员设计了一种基于哈希(hashing)的掩码策略。

那么,这和同样是防止模型背会的正则化方法有什么不同呢?

Dropout这样的正则化方法为例,它通过在训练时“加噪声”来防止模型过度依赖某些参数,从而提高模型举一反三的能力。

但这样做的问题在于:如果只是随机丢token,那么,每次看到同一段落时,丢掉的地方不一样,模型累计几次就能拼凑出完整段落。

所以,说到底,模型还是靠死记硬背,记住了答案。

相比之下,金鱼损失则用哈希掩码确保每次遇到同一段落,掩盖位置都一样,这就从根本上阻止了模型复现完整训练文本。

接下来,我们来看金鱼损失具体是怎么做的。

在传统的next-token prediction中,模型以序列中的下一个真实token作为目标,输出预测分布,并基于该分布计算交叉熵损失。

在金鱼损失下,模型虽然也在前向传播中预测序列里下一个 token。但在计算损失时,会以一定的概率将某些位置的token从损失计算里“抹掉”。

也就是说,有些真实的下一个token不会作为目标来训练。

在这里,研究人员采用了简单的静态掩码(static mask),剔除每序列中的第4个token。

更进一步,为了确保模型不会从其他地方学到被掩码的数据(例如不同的文档会在不同的网页中反复出现),研究团队还提出了一种局部化哈希掩码(localized hashed mask),使得当相同的前h个token出现时,掩盖模式是相同的(可重复)。

实验测试与结果

为了验证金鱼损失确实能防止记忆化,研究团队设计了两种实验场景:

一种是极端场景,通过对少量样本进行多个训练周期(即重复)来强烈促使记忆化;

另一种是标准场景,模拟现实模型训练中使用的批次处理方式 。

同时,为了评估模型的记忆化程度,研究采用了以下指标:

    RougeL得分:该指标衡量最长公共(非连续)子序列的长度 。得分为1.0表示完美记忆 。

    精确匹配率(Exact Match):该指标衡量正确预测的序列占真实序列的百分比.

实验表明,在极端场景下,标准训练导致模型逐字记忆了100篇文章中的84篇,而金鱼损失没有记忆任何文章

(注:实验让LLaMA-2-7B在《哈利·波特》第一章或100篇维基百科文档上进一步训练了100个epoch)

此外,在标准训练场景下,金鱼损失也明显减少了模型逐字复现训练语料库中目标序列的情况。

但这里可能有个直觉式的反应——如果让模型“随机漏学”一些token,它的能力会不会也随之降低呢?

对此,研究人员进行了测试:研究表明,金鱼损失模型、标准损失模型和对照模型之间的总体性能没有系统性差异。

需要注意的是,金鱼损失的核心在于忽略部分token的梯度计算。因此,为了学到足够的语言模式,模型必须通过更多数据来补偿这些空缺,这可能导致计算效率的下降。

参考链接

[1]https://arxiv.org/pdf/2406.10209

一键三连「点赞」「转发」「小心心」

欢迎在评论区留下你的想法!

—  —

专属AI产品从业者的实名社群,只聊AI产品最落地的真问题  扫码添加小助手,发送「姓名+公司+职位」申请入群~
进群后,你将直接获得:
 👉 最新最专业的AI产品信息及分析 🔍 
 👉 不定期发放的热门产品内测码 🔥
 👉 内部专属内容与专业讨论 👂

🌟 点亮星标 🌟

科技前沿进展每日见

内容中包含的图片若涉及版权问题,请及时与我们联系删除

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

大模型 LLM 金鱼损失 Goldfish Loss 模型训练 防止记忆化 正则化 AI 机器学习
相关文章