掘金 人工智能 11月11日 06:56
MemGen: 动态生成式隐式记忆框架解析
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

MemGen 提出了一种新颖的动态生成式隐式记忆框架,旨在解决现有智能体记忆范式的局限性。该框架由轻量级的记忆触发器和记忆编织器协同工作,实现了记忆与推理过程的无缝耦合。与参数化记忆易遗忘或基于检索记忆的静态特性不同,MemGen 能够动态生成和重构记忆,以满足当前任务需求。其核心在于 LatentMemoryModel,它整合了推理器、编织器和触发器,支持模块化设计、动态记忆增强、精度与效率优化,以及灵活的配置。源码解析揭示了其在训练和推理阶段的精妙流程,通过精确的损失计算和高效的生成策略,实现了智能体的“自进化”。

💡 **动态生成式记忆框架**:MemGen 提出了第三种记忆探索路径,与参数化记忆(易遗忘)和基于检索的记忆(静态)不同,它能够在每一步思考中动态生成和重构记忆,实现记忆与推理的无缝耦合,从而构建一个动态、生成式的记忆框架。

🧩 **模块化协同设计**:LatentMemoryModel 核心框架由推理器(Reasoner)、记忆编织器(Weaver)和记忆触发器(Trigger)三大模块构成。它们通过投影层实现嵌入空间的映射,结构清晰且解耦,协同工作以实现动态记忆增强。

🚀 **高效训练与推理**:该框架采用 bfloat16 精度,推理器使用 Flash Attention 2 提升计算效率。训练时冻结推理器参数,仅训练编织器和触发器,实现参数高效学习。推理时,通过 @torch.no_grad() 和 use_cache=True 进一步优化效率,并支持可选的记忆增强位置掩码输出,增强可解释性。

🔄 **动态记忆增强机制**:MemGen 在推理过程中,能够识别分隔符位置作为记忆增强点,动态插入编织器生成的潜在记忆,突破了静态记忆注入的局限,更贴合人类认知中记忆与推理的动态交互特性。这在 `forward` 和 `generate` 函数中得到了体现,前者用于训练计算损失,后者用于推理生成文本。

🎯 **精准损失计算与灵活配置**:在训练阶段,通过潜在记忆掩码排除记忆嵌入对应的位置,仅对原始输入位置计算损失,确保训练目标聚焦于核心任务性能,避免记忆生成过程干扰主任务学习。同时,框架支持自定义触发器模型、PEFT 微调配置等参数,提升了其灵活性和跨场景兼容性。

【Agent】生成式隐式记忆 MemGen 源码解读

[toc]

0x00 概要

MemGen旨在构建一个动态、生成式的记忆框架,其核心由两个协同工作的轻量级模块构成:一个基于强化学习(RL)训练的记忆触发器(Memory Trigger)和一个记忆编织器(Memory Weaver)。

论文:MemGen: Weaving Generative Latent Memory for Self-Evolving Agents

链接:arxiv.org/abs/2509.24…

代码:github.com/KANABOON1/M…

0x01 背景

MemGen 提出动态生成式记忆框架,由记忆触发器与记忆编织器两个轻量模块协同构成,旨在突破现有智能体记忆范式的局限。

当前主流的记忆实现路径为:

这一现状引出两大核心问题:如何实现记忆与推理在每一步思考中的无缝耦合,以及如何让记忆从提取式升级为满足当前需求的生成式重构,而动态生成式隐式记忆正是应对这些挑战的第三种探索路径。

0x02 源码解析

MemGen项目旨在创建一个动态且自生成的记忆框架,该框架由两个协同工作的轻量级模块组成:一个基于强化学习训练的记忆触发器和一个记忆编织器。这一框架的核心思想是解决大型语言模型(LLM)智能体能力涌现时对“自进化”机制的探索需求,其中记忆扮演关键角色。

2.1 模型

LatentMemoryModel 是 MemGen 框架的核心实现,旨在构建动态生成式隐式记忆系统,解决传统记忆范式的局限性。通过整合推理器(Reasoner)、记忆编织器(Weaver)和记忆触发器(Trigger),实现记忆与推理过程的无缝耦合,让智能体在任务执行中动态生成、使用记忆,而非依赖静态检索或参数化存储。

2.1.1 核心特色

模型的核心特色如下:

2.1.2 网络结构

关键说明(核心设计亮点)

    三大模块协同逻辑
      推理器(Reasoner):核心推理组件,权重冻结以保留基础能力,仅通过潜在记忆调整解码路径。触发器(MemGenTrigger):动态判断记忆插入时机,输出二分类触发概率,决定是否调用编织器。编织器(MemGenWeaver):生成针对性潜在记忆,分提示词 / 推理两阶段设计,支持 PEFT 高效微调。
    核心流程闭环:输入 → 推理器生成原始嵌入 → 触发器 + 增强点选择模块确定插入位置 → 编织器生成潜在记忆 → 投影层适配维度 → 重组增强序列 → 推理器完成最终推理 → 过滤无效位置输出。关键技术细节
      跨模块投影:通过 reasoner_to_weaverweaver_to_reasoner 解决推理器与编织器嵌入维度不匹配问题。动态记忆增强:按分隔符拆分序列,逐段插入记忆,避免长序列冗余,贴合人类 “思考 - 记忆” 交互模式。精度与效率:全流程采用 bfloat16 精度,推理器 / 编织器启用 Flash Attention 2,平衡性能与速度。
    训练与推理适配
      训练时:通过 labelsvalid_logits 计算损失,仅优化编织器、触发器及投影层参数。推理时:无需 labels,自动完成 “触发判断 - 记忆生成 - 推理增强” 全流程,实现动态自进化。

具体网络结构如下

2.1.3 代码

LatentMemoryModel 的代码如下:

@registry.register_model("latmem")class LatentMemoryModel(BaseModel):  # 定义了一个名为 LatentMemoryModel 的类,继承自 BaseModel    def __init__(        self,         reasoner_model_name: str,  # 推理模型名称        weaver_model_name: str,  # 记忆编织器模型名称        prompt_latents_len: int,  # 提示长度        inference_latents_len: int,  # 推理长度        weaver_peft_config: Optional[PeftConfig] = None,  # 记忆编织器配置,可选        trigger_model_name: str = None,  # 触发模型名称,可选        trigger_peft_config: Optional[PeftConfig] = None,  # 触发器配置,可选        max_prompt_aug_num: int = 1,  # 最大提示增强数量        max_inference_aug_num: int = 5,  # 最大推理增强数量    ):           super().__init__()  # 调用父类构造函数        # 构建推理模型        self.model = AutoModelForCausalLM.from_pretrained(  # 从预训练模型加载推理模型            reasoner_model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")        self.tokenizer = AutoTokenizer.from_pretrained(reasoner_model_name)  # 加载入分词器        self.config = self.model.config  # 获取模型配置                # 构建记忆编织器        self.weaver = MemGenWeaver(  # 初始化记忆编织器            weaver_model_name, prompt_latents_len, inference_latents_len, weaver_peft_config        )                # 构建触发器        self.trigger = NanoTrigger()  # 默认触发器,始终返回 true        if trigger_model_name is not None:            self.trigger = MemGenTrigger(  # 如果指定了触发模型,则加载相应的触发器                trigger_model_name, trigger_peft_config            )            logging.info(f"Use Trigger: {trigger_model_name}")  # 记录日志                # 投影层,用于在推理模型和记忆编织器之间映射嵌入        # 将推理模型输入嵌入映射到记忆编织器输入嵌入        self.reasoner_to_weaver = nn.Linear(  # 线性层,从推理模型隐藏层到记忆编织器隐藏层            self.model.config.hidden_size, self.weaver.config.hidden_size, dtype=torch.bfloat16        )        # 将记忆编织器隐藏状态映射回推理模型输入嵌入        self.weaver_to_reasoner = nn.Linear(  # 线性层,从记忆编织器隐藏层到推理模型隐藏层            self.weaver.config.hidden_size, self.model.config.hidden_size, dtype=torch.bfloat16        )                self.delimiters: List[str] = [",", ".", "\n"]  # 用于检测增强点的分隔符        self.max_prompt_aug_num = max_prompt_aug_num  # 提示后提示中插入潜在数量        self.max_inference_aug_num = max_inference_aug_num  # 指定分隔符后插入潜在数量        # 后处理        self._postprocess_models()  # 后处理模型        self.warnings_issued = {}  # 存储发出的警告        self.model_tags = None  # 存储模型标签        log_trainable_params(self)  # 记录可训练参数    def add_model_tags(self, tags: Union[list[str], str]) -> None:  # 添加模型标签        r"""        向模型添加自定义标签,这些标签将被推送到 Hugging Face Hub。不会覆盖模型中现有的标签。        参数:            tags (`Union[list[str], str]`):                要添加到模型的标签        例子:        ```python        from transformers import AutoModel        model = AutoModel.from_pretrained("google-bert/bert-base-cased")        model.add_model_tags(["custom", "custom-bert"])        # 将模型推送到您的命名空间,名称为 "my-custom-bert"。        model.push_to_hub("my-custom-bert")        """        if isinstance(tags, str):            tags = [tags]        if self.model_tags is None:            self.model_tags = []        for tag in tags:            if tag not in self.model_tags:                self.model_tags.append(tag)        def _postprocess_models(self):        """        后处理记忆模型的组件:推理模型、记忆编织器、触发器和分词器。        步骤:            1. 冻结推理模型的所有参数(不更新梯度)。            2. 将所有模型转换为 bfloat16 以提高内存和计算效率。            3. 确保分词器有一个有效的填充符:                - 如果缺少填充符,使用 EOS 符作为填充符。                - 设置 `padding_side` 为 "left" 以兼容生成任务。            4. 标准化分词器的模板为 `CONVERSATION_TEMPLATE`。        """        # 默认冻结推理模型的所有参数        fix_model_parameters(self.model)        # 将所有子模型转换为 bfloat16        self.model = self.model.bfloat16()        self.weaver = self.weaver.bfloat16()        self.trigger = self.trigger.bfloat16()        # 确保分词器有一个填充符        if self.tokenizer.pad_token is None:            self.tokenizer.pad_token = self.tokenizer.eos_token            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id            self.tokenizer.padding_side = "left"            logging.info(                f"Tokenizer has no pad token. Using EOS token ({self.tokenizer.eos_token}) as pad token."            )        # 标准化分词器的模板        self.tokenizer.chat_template = CONVERSATION_TEMPLATE
2.1.4 插入阶段

LatentMemoryModel 的两个关键函数 forward 和 generate 区别如下:

forward

forward 函数的主体如下:

        def _forward(        self,         input_ids: torch.Tensor,        attention_mask: torch.Tensor,        labels: torch.Tensor,           **kwargs    ) -> torch.Tensor:        # 预处理输入        assert input_ids.shape == attention_mask.shape == labels.shape                tokenizer = self.tokenizer        reasoner = self.model        weaver = self.weaver        delimiters = self.delimiters        max_augment_num = self.max_inference_aug_num  # 限制推理增强点的数量以避免过度增强        device = self.device        embeds_dtype = reasoner.get_input_embeddings().weight.dtype        B, _ = input_ids.shape        hidden_size = reasoner.config.hidden_size        # 选择增强索引        augmentation_indices = self._select_augment_points_after_delimiter(            input_ids, labels, delimiters, tokenizer, max_augment_num        )                # 输入嵌入        inputs_embeds = reasoner.get_input_embeddings()(input_ids)                         # 初始化开始索引和空张量以累积处理的段        current_start_idx = 0        current_inputs_embeds = torch.empty(B, 0, hidden_size).to(device, dtype=embeds_dtype)        current_attention_mask = torch.empty(B, 0).to(device, dtype=attention_mask.dtype)        current_latents_mask = torch.empty(B, 0).to(device, dtype=torch.bool)        # 遍历所选增强点        for aug_idx in augmentation_indices:            # 切片原始嵌入和注意力掩码            segment_inputs_embeds = inputs_embeds[:, current_start:aug_idx]            segment_attention_mask = attention_mask[:, current_start:aug_idx]            segment_latents_mask = torch.zeros(B, segment_inputs_embeds.size(1).to(device, dtype=torch.bool)            # 连接当前段到累积嵌入和掩码            current_inputs_embeds = torch.cat([current_inputs_embeds, segment_inputs_embeds], dim=1)            current_mask = torch.cat([current_mask, segment_attention_mask], dim=1)            current_position_ids = generate_position_ids(current_mask)            current_latents = torch.cat([current_latents, segment_latents], dim=1)            # 将推理模型嵌入映射到记忆编织器嵌入            weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds)            # 确定此点是否为提示(增强)的结束            is_prompt_end_aug = (labels[:, aug_idx] != -100).all() and (labels[:, aug_idx-1] == -100).all().item()            # 根据类型,使用记忆编织器增强提示或推理            if is_prompt_end_aug:                weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(                    weaver_inputs, current_attention_mask, current_position_ids                )            else:                weaver_hidden_states, attn_mask, pos_ids = weaver.augment_inference(                    weaver_inputs, current_attention_mask, current_position_ids                )             # 将记忆编织器隐藏状态映射回推理模型嵌入            latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)            # 更新累积嵌入和掩码与新增强段            current_inputs_embeds = torch.cat
generate
核心作用

generate 方法是 MemGen 模型的推理核心,实现了动态记忆增强与序列生成的无缝融合。通过迭代生成新 token,每步自适应判断是否插入编织器生成的潜在记忆,让推理器在生成过程中实时利用动态记忆调整解码路径,最终输出增强后的序列(可选返回记忆增强位置掩码)。

核心特色
推理生成流程图

潜在记忆插入的完整流程:

具体流程如下图所示:

代码如下:

@torch.no_grad()  # 禁用梯度计算,适用于推理阶段,提升效率并节省内存def generate(    self,     input_ids: torch.Tensor,  # 输入token ID序列,形状[batch_size, prompt_len]    attention_mask: torch.Tensor,  # 注意力掩码,形状与input_ids一致    generation_config: GenerationConfig = None,  # 生成配置(如最大新token数、采样策略等)    return_augmentation_mask: bool = False,  # 是否返回记忆增强位置掩码    **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:     """    执行MemGen模型的推理生成流程:动态融合潜在记忆与推理器,生成增强后的输出序列。        核心逻辑:    1. 初始化提示词阶段的记忆增强    2. 迭代生成新token,每步判断是否触发推理阶段记忆增强    3. 对需增强的序列插入编织器生成的潜在记忆,非增强序列左填充对齐维度    4. 生成完成后返回结果(可选返回增强位置掩码)    """    tokenizer = self.tokenizer    reasoner = self.model    weaver = self.weaver    trigger = self.trigger    delimiters = self.delimiters    max_augment_num = self.max_inference_aug_num  # 单序列最大推理阶段增强次数    invalid_token_id = -100  # 无效位置标记(用于增强位置掩码)    # 预处理输入:转移到模型所在设备    input_ids = input_ids.to(self.device)    attention_mask = attention_mask.to(self.device)    # 提取生成配置关键参数    max_new_tokens = generation_config.max_new_tokens  # 最大生成新token数    do_sample = generation_config.do_sample  # 是否启用采样生成    temperature = generation_config.temperature  # 采样温度(控制随机性)    pad_token_id = tokenizer.pad_token_id  # pad token ID    eos_token_id = tokenizer.eos_token_id  # 结束token ID    prompt_len = input_ids.size(1)  # 提示词长度    # 重构生成配置(固定必要参数,确保生成稳定性)    generation_config = GenerationConfig(        do_sample=do_sample,        temperature=temperature,        pad_token_id=pad_token_id,        eos_token_id=eos_token_id,        use_cache=True  # 启用缓存加速生成    )    # 将输入token ID转换为嵌入向量    inputs_embeds = reasoner.get_input_embeddings()(input_ids)    B, _, hidden_size = inputs_embeds.shape  # B=batch_size,hidden_size=推理器隐藏层维度    device = inputs_embeds.device  # 模型所在设备(CPU/GPU)    # 初始化生成过程中的关键张量    current_inputs_embeds = inputs_embeds  # 当前输入嵌入(含原始提示词+潜在记忆)    current_attention_mask = attention_mask  # 当前注意力掩码    current_position_ids = generate_position_ids(current_attention_mask)  # 当前位置ID    current_input_ids = input_ids  # 当前已生成的token ID序列        # 提示词阶段记忆增强:生成并插入提示词专用潜在记忆    weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds)  # 映射到编织器嵌入空间    weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(        weaver_inputs_embeds, current_attention_mask, current_position_ids    )    latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)  # 映射回推理器嵌入空间    # 拼接提示词与增强记忆    current_inputs_embeds = torch.cat([current_inputs_embeds, latent_inputs_embeds], dim=1)    current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1)    current_position_ids = torch.cat([current_position_ids, pos_ids], dim=1)    # 生成循环初始化    sentence_augment_count = torch.zeros(B, dtype=torch.int, device=device)  # 各序列已增强次数    augmentation_pos = torch.full((B, max_new_tokens), fill_value=invalid_token_id, device=device)  # 增强位置掩码    inserted_embeds: List[List[torch.Tensor]] = [[] for _ in range(B)]  # 记录插入的潜在记忆(用于后处理)        for i in range(max_new_tokens):        # 若所有序列均已生成EOS token,提前终止        if (current_input_ids[:, -1] == eos_token_id).all():            break           # 若所有序列均已达到最大增强次数,一次性生成剩余token        if (sentence_augment_count >= max_augment_num).all():            # 调整剩余生成长度            generation_config.max_new_tokens = max_new_tokens - i            # 推理器生成剩余token            generated = reasoner.generate(                inputs_embeds=current_inputs_embeds,                attention_mask=current_attention_mask,                generation_config=generation_config,            )            current_input_ids = torch.cat([current_input_ids, generated], dim=1)            break        # 推理器前向传播,获取当前步输出        outputs = reasoner(            inputs_embeds=current_inputs_embeds,            attention_mask=current_attention_mask,            position_ids=current_position_ids,            output_hidden_states=False,  # 推理阶段无需输出隐藏状态,提升效率        )        # 生成并追加一个新token,更新关键张量        current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids = self._append_one_step(            outputs, current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids, do_sample, temperature        )         # 若为最后一步生成,终止循环        if i == max_new_tokens - 1:              break         # 判断当前批次中哪些序列需要进行推理阶段记忆增强        augment_decision = self._should_augment(            current_input_ids, current_attention_mask, sentence_augment_count=sentence_augment_count,             do_sample=do_sample, temperature=temperature          )        augmentation_pos[:, i + 1] = augment_decision  # 记录增强位置(1=增强,0=不增强,-100=无效)        augment_indices = torch.where(augment_decision == 1)[0]  # 需增强的序列索引        # 对需增强的序列执行记忆增强,非增强序列左填充对齐维度        if len(augment_indices) > 0:            # 递增需增强序列的增强次数计数            sentence_augment_count[augment_indices] += 1            # 提取需增强序列的嵌入、掩码和位置ID            candidate_inputs_embeds = current_inputs_embeds[augment_indices]            candidate_attention_mask = current_attention_mask[augment_indices]            candidate_position_ids = current_position_ids[augment_indices]                        # 编织器生成推理阶段潜在记忆            weaver_inputs_embeds = self.reasoner_to_weaver(candidate_inputs_embeds)            weaver_hidden_states, attn_mask, _ = weaver.augment_inference(                weaver_inputs_embeds, candidate_attention_mask, candidate_position_ids            )            latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)  # 映射回推理器空间                        # 拼接原始嵌入与潜在记忆            candidate_inputs_embeds = torch.cat([candidate_inputs_embeds, latent_inputs_embeds], dim=1)            candidate_attention_mask = torch.cat([candidate_attention_mask, attn_mask], dim=1)                        # 构建合并张量(适配所有序列,包括增强和非增强)            new_len = candidate_inputs_embeds.size(1)  # 增强后序列长度            merged_inputs_embeds = torch.zeros((B, new_len, hidden_size), device=device, dtype=current_inputs_embeds.dtype)            merged_attention_mask = torch.zeros((B, new_len), device=device, dtype=current_attention_mask.dtype)                        # 填充增强序列            merged_inputs_embeds[augment_indices] = candidate_inputs_embeds            merged_attention_mask[augment_indices] = candidate_attention_mask                        # 填充非增强序列(左填充对齐长度)            non_augment_indices = torch.where(augment_decision != 1)[0]            if len(non_augment_indices) > 0:                non_aug_inputs_embeds = current_inputs_embeds[non_augment_indices]                non_aug_attention_mask = current_attention_mask[non_augment_indices]                non_aug_inputs_embeds, non_aug_attention_mask, _ = self._left_pad(                    non_aug_inputs_embeds, non_aug_attention_mask, None, weaver.inference_latents_num                )                merged_inputs_embeds[non_augment_indices] = non_aug_inputs_embeds                merged_attention_mask[non_augment_indices] = non_aug_attention_mask                        # 更新当前关键张量            current_inputs_embeds = merged_inputs_embeds            current_attention_mask = merged_attention_mask            current_position_ids = generate_position_ids(current_attention_mask)  # 重新生成位置ID                        # 记录插入的潜在记忆(用于后处理或可解释性分析)            for idx, embed in zip(augment_indices, latent_inputs_embeds):                inserted_embeds[idx].append(embed.clone().detach().cpu())                # 后处理:调整增强位置掩码长度与生成结果一致        new_generated_len = current_input_ids.size(1) - prompt_len        augmentation_pos = augmentation_pos[:, :new_generated_len]                 # 根据配置返回结果:仅生成序列 或 序列+增强位置掩码        if not return_augmentation_mask:            return current_input_ids        else:            return current_input_ids, augmentation_pos

2.2 Trigger

2.2.1. 核心作用

该模块定义了 MemGen 框架中记忆触发器的核心接口与两种具体实现,核心作用是动态决策记忆增强的时机—— 即在推理过程中判断何时插入编织器生成的潜在记忆,实现记忆与推理的动态耦合,突破传统静态记忆注入的局限。

2.2.2. 核心特色
2.2.3 网络架构

网络架构图如下。

说明如下:

    模型支持PEFT参数高效微调(如LoRA),适配于Transformer Blocks层整体精度采用bfloat16,平衡计算效率与数值稳定性注意力计算通过Flash Attention 2优化,提升长序列处理速度

2.2.4 代码
class Trigger(torch.nn.Module, ABC):    """    记忆触发器的抽象基类(Trigger)。    定义了触发器的核心接口,用于决定在推理过程中何时触发记忆增强(插入潜在记忆)。    所有具体触发器实现都需继承此类并实现forward方法。    """    def __init__(self):        super().__init__()  # 调用父类Module的初始化方法        @abstractmethod    def forward(self, **kwargs) -> bool:        """        抽象前向传播方法:接收输入数据,返回是否触发记忆增强的决策。        子类必须实现此方法,定义具体的触发逻辑。                Args:            **kwargs: 可变关键字参数,包含输入序列、注意力掩码等模型所需数据                    Returns:            bool: 触发决策(True表示触发记忆增强,False表示不触发)        """        ...class NanoTrigger(torch.nn.Module):    """    极简触发器(NanoTrigger):始终触发记忆增强的基础实现。    无需复杂逻辑,固定返回触发决策,适用于基础测试或无需动态控制的场景。    """    def __init__(self):        super().__init__()          # 注册一个缓冲区张量,用于获取模型所在设备(无实际计算意义)        self.register_buffer("_device", torch.tensor(0.0))        @property    def device(self):        """获取模型所在设备(CPU/GPU)"""        return self._device.device        def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> bool:        # 该"极简触发器"始终预测需要插入记忆        # 输出logits张量,其中插入决策(索引=1)的概率被设为1.0        # 适用于批次中的每个token位置        batch_size, seq_len = input_ids.shape        # 初始化logits张量:形状为[batch_size, seq_len, 2],2表示"不插入"(0)和"插入"(1)两类        logits = torch.zeros(batch_size, seq_len, 2, device=input_ids.device)        logits[..., 1] = 1.0  # 将所有位置的"插入"决策概率设为1.0        return logitsclass MemGenTrigger(torch.nn.Module):    """    MemGen框架的专用触发器模块(MemGenTrigger)。    - 输入:接收推理器模型当前解码序列的`inputs_embeds`(或input_ids)    - 输出:生成形状为[batch_size, seq_len, 2]的logits张量,      表示每个位置"不插入"(0)和"插入"(1)记忆的概率,用于动态决策记忆增强时机。    """    def __init__(        self,         pretrained_model_name_or_path: str,  # 预训练模型名称或路径(用于初始化触发器LLM)        peft_config: Optional[PeftConfig] = None  # PEFT配置(可选,用于参数高效微调)    ):        super().__init__()                # 构建基础LLM模型(作为触发器的核心推理组件)        self.model = AutoModelForCausalLM.from_pretrained(            pretrained_model_name_or_path,             torch_dtype=torch.bfloat16,  # 使用bfloat16精度提升效率            attn_implementation="flash_attention_2"  # 启用Flash Attention 2优化注意力计算        )        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # 对应的Tokenizer                # 对基础模型进行后处理(设置可训练、替换输出头)        self.model = self._postprocess(self.model)        # 若提供PEFT配置,应用参数高效微调        if peft_config is not None:            self.model = get_peft_model(self.model, peft_config)                self.config = self.model.config  # 保存模型配置    @property    def device(self):        """获取模型所在设备(CPU/GPU)"""        return self.model.device        def _postprocess(self, model: PreTrainedModel):        """        对基础模型进行后处理,适配触发器的二分类任务需求。                Args:            model: 原始预训练LLM模型                    Returns:            处理后的模型(可训练、替换为二分类输出头)        """        # 设置所有模型参数为可训练        for parameter in model.parameters():            parameter.requires_grad = True                # 将原始语言模型的输出头(lm_head)替换为二分类头        hidden_size = model.config.hidden_size  # 模型隐藏层维度        classification_head = nn.Linear(hidden_size, 2)  # 输出维度为2(不插入/插入)        model.lm_head = classification_head                # 确保新的二分类头参数可训练        for param in model.lm_head.parameters():            param.requires_grad = True        return model    def forward(        self,         input_ids: Optional[torch.LongTensor] = None,  # 生成序列的token ID,形状[batch_size, seq_len]        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,避免关注填充token        **kwargs: Unpack[TransformersKwargs],  # 传递给底层模型的额外参数    ) -> torch.Tensor:        """        序列生成的触发决策机制。        触发器基于已生成的`input_ids`做出决策,受数据分布影响,但独立于编织器模块。        Args:            input_ids (Optional[torch.LongTensor]): 生成序列的token ID张量            attention_mask (Optional[torch.Tensor]): 注意力掩码,默认None            **kwargs: 传递给底层模型的额外关键字参数        Returns:            torch.Tensor: Logits张量,形状为`(batch_size, seq_len, num_classes)`                        num_classes=2,分别对应"不插入"(索引0)和"插入"(索引1)的概率        """           # 调用基础模型前向传播,返回二分类logits        return self.model(            input_ids=input_ids,             attention_mask=attention_mask,             **kwargs        ).logits

2.3 MemGenWeaver

2.3.1 核心作用

MemGenWeaver 是 MemGen 框架的核心组件之一,负责生成动态潜在记忆并将其与推理器的输入序列融合,从而实现记忆与推理过程的无缝交织。它通过可学习的潜在记忆查询向量,在提示词阶段和推理阶段分别生成针对性的记忆表示,引导推理器调整解码路径,提升智能体的动态决策能力。

2.3.2 核心特色
2.3.3 网络架构

网络架构图如下。

说明如下:

    核心组件:

      可学习潜在记忆向量:分阶段设计(P=提示词阶段数量,I=推理阶段数量),支持动态生成记忆预训练LLM:作为记忆生成核心,默认启用bfloat16精度和Flash Attention 2优化序列融合层:确保输入与记忆在语义、掩码、时序上的一致性

    核心流程:

      输入 → 选择对应阶段的潜在记忆 → 融合序列 → LLM生成隐藏状态 → 提取潜在记忆输出支持PEFT参数高效微调(如LoRA),适配于Transformer Blocks层

    输出用途:

      生成的潜在记忆将通过投影层映射到推理器的嵌入空间,与原始输入融合以引导解码

2.3.4 代码

两个关键变量如下:

这两个变量都通过_augment 方法获得(获取学习到的潜在向量,并将其附加到输入嵌入中)。其流程如下:

判断是否插入是通过函数 _should_augment 完成的。

class MemGenWeaver(torch.nn.Module):    """    MemGen模型的编织器模块(MemGenWeaver)。    - 输入:接收接收来自推理器模型当前当前解码序列的`inputs_embeds`(输入嵌入入)    - 输出:生成长度为K的隐藏状态序列,这些状态将与原始`inputs_embeds`拼接,以改变推理器的解码路径    """    def __init__(        self,         pretrained_model_name_or_path: str,  # 预训练模型的名称或路径        prompt_latents_num: int,    # 提示词阶段生成的潜在记忆数量        inference_latents_num: int, # 推理阶段生成的潜在记忆数量        peft_config: Optional[PeftConfig] = None  # PEFT配置(可选)    ):        super().__init__()                # 基础模型初始化        self.model = AutoModelForCausalLM.from_pretrained(            pretrained_model_name_or_path,            torch_dtype=torch.bfloat16,  # 使用bfloat16精度以提高效率            attn_implementation="flash_attention_2"  # 启用Flash Attentionention 2优化        )        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # 对应的分词器        # 若提供PEFT配置,则应用参数高效微调        if peft_config is not None:            self.model = get_peft_model(self.model, peft_config)                self.config = self.model.config  # 保存模型配置                # 提示词阶段的潜在记忆查询向量(可学习参数)        self.prompt_query_latents = nn.Parameter(            torch.randn(prompt_latents_num, self.config.hidden_size),  # 形状:[prompt_latents_num, hidden_size]            requires_grad=True  # 允许反向传播更新        )        # 推理阶段的潜在记忆查询向量(可学习参数)        self.inference_query_latents = nn.Parameter(            torch.randn(inference_latents_num, self.config.hidden_size),  # 形状:[inference_latents_num, hidden_size]            requires_grad=True  # 允许反向传播更新        )        @property    def prompt_latents_num(self) -> int:        """返回提示词阶段的潜在记忆数量"""        return self.prompt_query_latents.size(0)    @property    def inference_latents_num(self) -> int:        """返回推理阶段的潜在记忆数量"""        return self.inference_query_latents.size(0)    @property    def device(self):        """返回模型所在的设备(CPU/GPU)"""        return self.model.device    def _augment(        self,         latents: torch.Tensor,                # 潜在记忆查询向量,形状:[latents_num, hidden_size]        inputs_embeds: torch.Tensor,          # 输入嵌入,形状:[batch_size, seq_len, hidden_size]        attention_mask: torch.Tensor,         # 注意力掩码,形状:[batch_size, seq_len]        position_ids: torch.Tensor            # 位置ID,形状:[batch_size, seq_len]    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:        """        通用的潜在记忆增强方法:将潜在记忆与输入序列融合,生成增强后的隐藏状态。                参数:            latents: 潜在记忆查询向量            inputs_embeds: 输入序列的嵌入表示            attention_mask: 输入序列的注意力掩码            position_ids: 输入序列的位置ID                返回:            三元组 (latents_hidden_states, latents_mask, latents_position_ids)            - latents_hidden_states: 生成的潜在记忆隐藏状态,形状:[batch_size, latents_num, hidden_size]            - latents_mask: 潜在记忆的注意力掩码,形状:[batch_size, latents_num]            - latents_position_ids: 潜在记忆的位置ID,形状:[batch_size, latents_num]        """        batch_size = attention_mask.shape[0]  # 获取批次大小        latents_num = latents.size(0)         # 获取潜在记忆数量                # 扩展潜在记忆维度以匹配批次大小:[1, latents_num, hidden_size] → [batch_size, latents_num, hidden_size]        latents = latents.unsqueeze(0).repeat(batch_size, 1, 1)                # 将潜在记忆嵌入与输入嵌入拼接:[batch_size, seq_len + latents_num, hidden_size]        inputs_embeds = torch.cat([inputs_embeds, latents], dim=1)        # 构建潜在记忆的注意力掩码(全为1,表示有效)并与输入掩码拼接        latents_mask = torch.ones(latents.shape[:-1], dtype=attention_mask.dtype, device=attention_mask.device)        attention_mask = torch.cat([attention_mask, latents_mask], dim=1)  # 形状:[batch_size, seq_len + latents_num]                # 生成潜在记忆的位置ID(在输入序列最后位置的基础上递增)        last_position_ids = position_ids.max(dim=1)[0]  # 获取输入序列的最大位置ID        latents_relative_positions = torch.arange(latents_num, device=attention_mask.device)  # 潜在记忆的相对位置        # 计算绝对位置:输入序列最大位置 + 相对位置 + 1(避免重叠)        latents_position_ids = last_position_ids.unsqueeze(1) + latents_relative_positions + 1        # 拼接位置ID:[batch_size, seq_len + latents_num]        position_ids = torch.cat([position_ids.long(), latents_position_ids.long()], dim=1)         # 验证拼接后的维度是否一致        assert inputs_embeds.shape[:2] == attention_mask.shape == position_ids.shape        # 模型前向传播,获取隐藏状态        outputs = self.model(            inputs_embeds=inputs_embeds,            attention_mask=attention_mask,            position_ids=position_ids,              output_hidden_states=True,  # 输出所有层的隐藏状态        )        # 取最后一层的隐藏状态,并提取潜在记忆部分(序列末尾的latents_num个位置)        hidden_states = outputs.hidden_states[-1]        latents_hidden_states = hidden_states[:, -latents_num:, :]        return latents_hidden_states, latents_mask, latents_position_ids    def augment_prompt(        self,         inputs_embeds: torch.Tensor,         attention_mask: torch.Tensor,         position_ids: torch.Tensor    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:        """        提示词阶段的潜在记忆增强:使用提示词专用的潜在记忆查询向量。                参数与返回值同_augment方法        """        return self._augment(            latents=self.prompt_query_latents,            inputs_embeds=inputs_embeds,            attention_mask=attention_mask,            position_ids=position_ids        )    def augment_inference(        self,         inputs_embeds: torch.Tensor,         attention_mask: torch.Tensor,         position_ids: torch.Tensor    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:        """        推理阶段的潜在记忆增强:使用推理专用的潜在记忆查询向量。                参数与返回值同_augment方法        """        return self._augment(            latents=self.inference_query_latents,            inputs_embeds=inputs_embeds,            attention_mask=attention_mask,            position_ids=position_ids        )

0xFF 参考

最新成果!Agent记忆的第三种可能:生成式隐式记忆

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

MemGen 生成式记忆 隐式记忆 强化学习 LLM智能体 自进化 源码解析 LatentMemoryModel Generative Memory Implicit Memory Reinforcement Learning LLM Agents Self-Evolution Source Code Analysis
相关文章