找回密码
 立即注册
首页 业界区 业界 【Agent】生成式隐式记忆 MemGen 源码解读

【Agent】生成式隐式记忆 MemGen 源码解读

里豳朝 2025-11-10 21:05:00
【Agent】生成式隐式记忆 MemGen 源码解读


目录

  • 【Agent】生成式隐式记忆 MemGen 源码解读

    • 0x00 概要
    • 0x01 背景
    • 0x02 源码解析

      • 2.1 模型

        • 2.1.1  核心特色
        • 2.1.2  网络结构
        • 2.1.3 代码
        • 2.1.4 插入阶段

          • forward
          • generate

            • 核心作用
            • 核心特色
            • 推理生成流程图



      • 2.2 Trigger

        • 2.2.1. 核心作用
        • 2.2.2. 核心特色
        • 2.2.3 网络架构
        • 2.2.4 代码

      • 2.3  MemGenWeaver

        • 2.3.1 核心作用
        • 2.3.2 核心特色
        • 2.3.3 网络架构
        • 2.3.4 代码


    • 0xFF 参考


0x00 概要

MemGen旨在构建一个动态、生成式的记忆框架,其核心由两个协同工作的轻量级模块构成:一个基于强化学习(RL)训练的记忆触发器(Memory Trigger)和一个记忆编织器(Memory Weaver)。
论文:MemGen: Weaving Generative Latent Memory for Self-Evolving Agents
链接:https://arxiv.org/abs/2509.24704
代码:https://github.com/KANABOON1/MemGen
0x01 背景

MemGen 提出动态生成式记忆框架,由记忆触发器与记忆编织器两个轻量模块协同构成,旨在突破现有智能体记忆范式的局限。
当前主流的记忆实现路径为:

  • 参数化记忆通过微调将经验编码进模型参数,虽能深度内化知识却易引发灾难性遗忘;
  • 基于检索的记忆将经验外化存储,虽规避了遗忘问题,但静态的一次性检索机制无法体现记忆与推理动态交互的认知特性。
这一现状引出两大核心问题:如何实现记忆与推理在每一步思考中的无缝耦合,以及如何让记忆从提取式升级为满足当前需求的生成式重构,而动态生成式隐式记忆正是应对这些挑战的第三种探索路径。
0x02 源码解析

MemGen项目旨在创建一个动态且自生成的记忆框架,该框架由两个协同工作的轻量级模块组成:一个基于强化学习训练的记忆触发器和一个记忆编织器。这一框架的核心思想是解决大型语言模型(LLM)智能体能力涌现时对“自进化”机制的探索需求,其中记忆扮演关键角色。
2.1 模型

LatentMemoryModel 是 MemGen 框架的核心实现,旨在构建动态生成式隐式记忆系统,解决传统记忆范式的局限性。通过整合推理器(Reasoner)、记忆编织器(Weaver)和记忆触发器(Trigger),实现记忆与推理过程的无缝耦合,让智能体在任务执行中动态生成、使用记忆,而非依赖静态检索或参数化存储。
2.1.1  核心特色

模型的核心特色如下:

  • 模块化协同设计:由推理器(核心推理)、编织器(生成潜在记忆)、触发器(控制记忆触发)三大模块构成,模块间通过投影层实现嵌入空间映射,结构清晰且解耦。
  • 动态记忆增强:在推理过程中自动识别分隔符位置作为记忆增强点,动态插入编织器生成的潜在记忆,突破静态记忆注入的局限,贴合人类认知中记忆与推理的动态交互特性。
  • 精度与效率优化:默认使用 bfloat16 精度,推理器采用 Flash Attention 2 提升计算效率;冻结推理器参数,仅训练编织器和触发器,实现参数高效学习。
  • 灵活配置与兼容性:支持自定义触发器模型、PEFT 微调配置、记忆增强次数等参数;自动处理 Tokenizer 缺失 pad token 的问题,标准化对话模板,提升跨场景兼容性。
  • 损失计算精准过滤:通过潜在记忆掩码排除记忆嵌入对应的位置,仅对原始输入位置计算损失,确保训练目标聚焦于核心任务性能,避免记忆生成过程干扰主任务学习。
2.1.2  网络结构

关键说明(核心设计亮点)

  • 三大模块协同逻辑

    • 推理器(Reasoner):核心推理组件,权重冻结以保留基础能力,仅通过潜在记忆调整解码路径。
    • 触发器(MemGenTrigger):动态判断记忆插入时机,输出二分类触发概率,决定是否调用编织器。
    • 编织器(MemGenWeaver):生成针对性潜在记忆,分提示词 / 推理两阶段设计,支持 PEFT 高效微调。

  • 核心流程闭环:输入 → 推理器生成原始嵌入 → 触发器 + 增强点选择模块确定插入位置 → 编织器生成潜在记忆 → 投影层适配维度 → 重组增强序列 → 推理器完成最终推理 → 过滤无效位置输出。
  • 关键技术细节

    • 跨模块投影:通过 reasoner_to_weaver 和 weaver_to_reasoner 解决推理器与编织器嵌入维度不匹配问题。
    • 动态记忆增强:按分隔符拆分序列,逐段插入记忆,避免长序列冗余,贴合人类 “思考 - 记忆” 交互模式。
    • 精度与效率:全流程采用 bfloat16 精度,推理器 / 编织器启用 Flash Attention 2,平衡性能与速度。

  • 训练与推理适配

    • 训练时:通过 labels 和 valid_logits 计算损失,仅优化编织器、触发器及投影层参数。
    • 推理时:无需 labels,自动完成 “触发判断 - 记忆生成 - 推理增强” 全流程,实现动态自进化。

具体网络结构如下
1.png

2.1.3 代码

LatentMemoryModel 的代码如下:
  1. @registry.register_model("latmem")
  2. class LatentMemoryModel(BaseModel):  # 定义了一个名为 LatentMemoryModel 的类,继承自 BaseModel
  3.     def __init__(
  4.         self,
  5.         reasoner_model_name: str,  # 推理模型名称
  6.         weaver_model_name: str,  # 记忆编织器模型名称
  7.         prompt_latents_len: int,  # 提示长度
  8.         inference_latents_len: int,  # 推理长度
  9.         weaver_peft_config: Optional[PeftConfig] = None,  # 记忆编织器配置,可选
  10.         trigger_model_name: str = None,  # 触发模型名称,可选
  11.         trigger_peft_config: Optional[PeftConfig] = None,  # 触发器配置,可选
  12.         max_prompt_aug_num: int = 1,  # 最大提示增强数量
  13.         max_inference_aug_num: int = 5,  # 最大推理增强数量
  14.     ):   
  15.         super().__init__()  # 调用父类构造函数
  16.         # 构建推理模型
  17.         self.model = AutoModelForCausalLM.from_pretrained(  # 从预训练模型加载推理模型
  18.             reasoner_model_name, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
  19.         self.tokenizer = AutoTokenizer.from_pretrained(reasoner_model_name)  # 加载入分词器
  20.         self.config = self.model.config  # 获取模型配置
  21.         
  22.         # 构建记忆编织器
  23.         self.weaver = MemGenWeaver(  # 初始化记忆编织器
  24.             weaver_model_name, prompt_latents_len, inference_latents_len, weaver_peft_config
  25.         )
  26.         
  27.         # 构建触发器
  28.         self.trigger = NanoTrigger()  # 默认触发器,始终返回 true
  29.         if trigger_model_name is not None:
  30.             self.trigger = MemGenTrigger(  # 如果指定了触发模型,则加载相应的触发器
  31.                 trigger_model_name, trigger_peft_config
  32.             )
  33.             logging.info(f"Use Trigger: {trigger_model_name}")  # 记录日志
  34.         
  35.         # 投影层,用于在推理模型和记忆编织器之间映射嵌入
  36.         # 将推理模型输入嵌入映射到记忆编织器输入嵌入
  37.         self.reasoner_to_weaver = nn.Linear(  # 线性层,从推理模型隐藏层到记忆编织器隐藏层
  38.             self.model.config.hidden_size, self.weaver.config.hidden_size, dtype=torch.bfloat16
  39.         )
  40.         # 将记忆编织器隐藏状态映射回推理模型输入嵌入
  41.         self.weaver_to_reasoner = nn.Linear(  # 线性层,从记忆编织器隐藏层到推理模型隐藏层
  42.             self.weaver.config.hidden_size, self.model.config.hidden_size, dtype=torch.bfloat16
  43.         )
  44.         
  45.         self.delimiters: List[str] = [",", ".", "\n"]  # 用于检测增强点的分隔符
  46.         self.max_prompt_aug_num = max_prompt_aug_num  # 提示后提示中插入潜在数量
  47.         self.max_inference_aug_num = max_inference_aug_num  # 指定分隔符后插入潜在数量
  48.         # 后处理
  49.         self._postprocess_models()  # 后处理模型
  50.         self.warnings_issued = {}  # 存储发出的警告
  51.         self.model_tags = None  # 存储模型标签
  52.         log_trainable_params(self)  # 记录可训练参数
  53.     def add_model_tags(self, tags: Union[list[str], str]) -> None:  # 添加模型标签
  54.         r"""
  55.         向模型添加自定义标签,这些标签将被推送到 Hugging Face Hub。不会覆盖模型中现有的标签。
  56.         参数:
  57.             tags (`Union[list[str], str]`):
  58.                 要添加到模型的标签
  59.         例子:
  60.         ```python
  61.         from transformers import AutoModel
  62.         model = AutoModel.from_pretrained("google-bert/bert-base-cased")
  63.         model.add_model_tags(["custom", "custom-bert"])
  64.         # 将模型推送到您的命名空间,名称为 "my-custom-bert"。
  65.         model.push_to_hub("my-custom-bert")
  66.         """
  67.         if isinstance(tags, str):
  68.             tags = [tags]
  69.         if self.model_tags is None:
  70.             self.model_tags = []
  71.         for tag in tags:
  72.             if tag not in self.model_tags:
  73.                 self.model_tags.append(tag)
  74.    
  75.     def _postprocess_models(self):
  76.         """
  77.         后处理记忆模型的组件:推理模型、记忆编织器、触发器和分词器。
  78.         步骤:
  79.             1. 冻结推理模型的所有参数(不更新梯度)。
  80.             2. 将所有模型转换为 bfloat16 以提高内存和计算效率。
  81.             3. 确保分词器有一个有效的填充符:
  82.                 - 如果缺少填充符,使用 EOS 符作为填充符。
  83.                 - 设置 `padding_side` 为 "left" 以兼容生成任务。
  84.             4. 标准化分词器的模板为 `CONVERSATION_TEMPLATE`。
  85.         """
  86.         # 默认冻结推理模型的所有参数
  87.         fix_model_parameters(self.model)
  88.         # 将所有子模型转换为 bfloat16
  89.         self.model = self.model.bfloat16()
  90.         self.weaver = self.weaver.bfloat16()
  91.         self.trigger = self.trigger.bfloat16()
  92.         # 确保分词器有一个填充符
  93.         if self.tokenizer.pad_token is None:
  94.             self.tokenizer.pad_token = self.tokenizer.eos_token
  95.             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
  96.             self.tokenizer.padding_side = "left"
  97.             logging.info(
  98.                 f"Tokenizer has no pad token. Using EOS token ({self.tokenizer.eos_token}) as pad token."
  99.             )
  100.         # 标准化分词器的模板
  101.         self.tokenizer.chat_template = CONVERSATION_TEMPLATE
复制代码
2.1.4 插入阶段

LatentMemoryModel 的两个关键函数 forward 和  generate 区别如下:

  • forward 函数

    • 训练时候计算损失,由训练循环自动调用。

  • generate 函数

    • 推理时候生成文本,由代码显式调用。

forward

forward 函数的主体如下:
  1.    
  2.     def _forward(
  3.         self,
  4.         input_ids: torch.Tensor,
  5.         attention_mask: torch.Tensor,
  6.         labels: torch.Tensor,   
  7.         **kwargs
  8.     ) -> torch.Tensor:
  9.         # 预处理输入
  10.         assert input_ids.shape == attention_mask.shape == labels.shape
  11.         
  12.         tokenizer = self.tokenizer
  13.         reasoner = self.model
  14.         weaver = self.weaver
  15.         delimiters = self.delimiters
  16.         max_augment_num = self.max_inference_aug_num  # 限制推理增强点的数量以避免过度增强
  17.         device = self.device
  18.         embeds_dtype = reasoner.get_input_embeddings().weight.dtype
  19.         B, _ = input_ids.shape
  20.         hidden_size = reasoner.config.hidden_size
  21.         # 选择增强索引
  22.         augmentation_indices = self._select_augment_points_after_delimiter(
  23.             input_ids, labels, delimiters, tokenizer, max_augment_num
  24.         )
  25.         
  26.         # 输入嵌入
  27.         inputs_embeds = reasoner.get_input_embeddings()(input_ids)
  28.                  
  29.         # 初始化开始索引和空张量以累积处理的段
  30.         current_start_idx = 0
  31.         current_inputs_embeds = torch.empty(B, 0, hidden_size).to(device, dtype=embeds_dtype)
  32.         current_attention_mask = torch.empty(B, 0).to(device, dtype=attention_mask.dtype)
  33.         current_latents_mask = torch.empty(B, 0).to(device, dtype=torch.bool)
  34.         # 遍历所选增强点
  35.         for aug_idx in augmentation_indices:
  36.             # 切片原始嵌入和注意力掩码
  37.             segment_inputs_embeds = inputs_embeds[:, current_start:aug_idx]
  38.             segment_attention_mask = attention_mask[:, current_start:aug_idx]
  39.             segment_latents_mask = torch.zeros(B, segment_inputs_embeds.size(1).to(device, dtype=torch.bool)
  40.             # 连接当前段到累积嵌入和掩码
  41.             current_inputs_embeds = torch.cat([current_inputs_embeds, segment_inputs_embeds], dim=1)
  42.             current_mask = torch.cat([current_mask, segment_attention_mask], dim=1)
  43.             current_position_ids = generate_position_ids(current_mask)
  44.             current_latents = torch.cat([current_latents, segment_latents], dim=1)
  45.             # 将推理模型嵌入映射到记忆编织器嵌入
  46.             weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds)
  47.             # 确定此点是否为提示(增强)的结束
  48.             is_prompt_end_aug = (labels[:, aug_idx] != -100).all() and (labels[:, aug_idx-1] == -100).all().item()
  49.             # 根据类型,使用记忆编织器增强提示或推理
  50.             if is_prompt_end_aug:
  51.                 weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(
  52.                     weaver_inputs, current_attention_mask, current_position_ids
  53.                 )
  54.             else:
  55.                 weaver_hidden_states, attn_mask, pos_ids = weaver.augment_inference(
  56.                     weaver_inputs, current_attention_mask, current_position_ids
  57.                 )
  58.             # 将记忆编织器隐藏状态映射回推理模型嵌入
  59.             latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)
  60.             # 更新累积嵌入和掩码与新增强段
  61.             current_inputs_embeds = torch.cat
复制代码
generate

核心作用

该 generate 方法是 MemGen 模型的推理核心,实现了动态记忆增强与序列生成的无缝融合。通过迭代生成新 token,每步自适应判断是否插入编织器生成的潜在记忆,让推理器在生成过程中实时利用动态记忆调整解码路径,最终输出增强后的序列(可选返回记忆增强位置掩码)。
核心特色


  • 双阶段记忆增强:先执行提示词阶段记忆增强(初始化全局记忆),再在迭代生成中动态触发推理阶段增强(补充实时记忆),适配不同生成阶段的记忆需求。
  • 自适应触发机制:通过 _should_augment 结合触发器决策,仅对需要记忆支持的序列执行增强,避免无意义的计算开销。
  • 维度对齐优化:非增强序列采用左填充(_left_pad)方式对齐增强序列维度,确保批次内所有序列格式统一,不影响批量生成效率。
  • 高效推理设计:

    • 禁用梯度计算(@torch.no_grad()),节省内存并加速推理;
    • 启用推理器缓存(use_cache=True),减少重复计算;
    • 仅在必要时输出隐藏状态,降低计算成本。

  • 灵活配置与可解释性:支持控制最大生成 token 数、采样策略等参数;可选返回 augmentation_pos 掩码,标记记忆插入位置,提升模型可解释性。
  • 鲁棒性保障:提前终止机制(所有序列生成 EOS 或达最大增强次数时终止),避免无效迭代;重构生成配置固定关键参数,确保生成稳定性。
推理生成流程图

潜在记忆插入的完整流程:

  • 初始化阶段:对输入提示进行增强,插入初始潜在记忆。
  • 生成循环:逐个生成token。
  • 条件检查:在每个步骤检查是否满足插入条件。
  • 决策判断:使用trigger模型决定是否插入潜在记忆。
  • 潜在记忆生成:通过weaver模型生成潜在记忆表示。
  • 嵌入连接:将潜在记忆嵌入连接到当前输入序列。
  • 继续生成:使用增强后的序列继续生成下一个token。
具体流程如下图所示:
2.png

代码如下:
  1. @torch.no_grad()  # 禁用梯度计算,适用于推理阶段,提升效率并节省内存
  2. def generate(
  3.     self,
  4.     input_ids: torch.Tensor,  # 输入token ID序列,形状[batch_size, prompt_len]
  5.     attention_mask: torch.Tensor,  # 注意力掩码,形状与input_ids一致
  6.     generation_config: GenerationConfig = None,  # 生成配置(如最大新token数、采样策略等)
  7.     return_augmentation_mask: bool = False,  # 是否返回记忆增强位置掩码
  8.     **kwargs
  9. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  10.     """
  11.     执行MemGen模型的推理生成流程:动态融合潜在记忆与推理器,生成增强后的输出序列。
  12.    
  13.     核心逻辑:
  14.     1. 初始化提示词阶段的记忆增强
  15.     2. 迭代生成新token,每步判断是否触发推理阶段记忆增强
  16.     3. 对需增强的序列插入编织器生成的潜在记忆,非增强序列左填充对齐维度
  17.     4. 生成完成后返回结果(可选返回增强位置掩码)
  18.     """
  19.     tokenizer = self.tokenizer
  20.     reasoner = self.model
  21.     weaver = self.weaver
  22.     trigger = self.trigger
  23.     delimiters = self.delimiters
  24.     max_augment_num = self.max_inference_aug_num  # 单序列最大推理阶段增强次数
  25.     invalid_token_id = -100  # 无效位置标记(用于增强位置掩码)
  26.     # 预处理输入:转移到模型所在设备
  27.     input_ids = input_ids.to(self.device)
  28.     attention_mask = attention_mask.to(self.device)
  29.     # 提取生成配置关键参数
  30.     max_new_tokens = generation_config.max_new_tokens  # 最大生成新token数
  31.     do_sample = generation_config.do_sample  # 是否启用采样生成
  32.     temperature = generation_config.temperature  # 采样温度(控制随机性)
  33.     pad_token_id = tokenizer.pad_token_id  # pad token ID
  34.     eos_token_id = tokenizer.eos_token_id  # 结束token ID
  35.     prompt_len = input_ids.size(1)  # 提示词长度
  36.     # 重构生成配置(固定必要参数,确保生成稳定性)
  37.     generation_config = GenerationConfig(
  38.         do_sample=do_sample,
  39.         temperature=temperature,
  40.         pad_token_id=pad_token_id,
  41.         eos_token_id=eos_token_id,
  42.         use_cache=True  # 启用缓存加速生成
  43.     )
  44.     # 将输入token ID转换为嵌入向量
  45.     inputs_embeds = reasoner.get_input_embeddings()(input_ids)
  46.     B, _, hidden_size = inputs_embeds.shape  # B=batch_size,hidden_size=推理器隐藏层维度
  47.     device = inputs_embeds.device  # 模型所在设备(CPU/GPU)
  48.     # 初始化生成过程中的关键张量
  49.     current_inputs_embeds = inputs_embeds  # 当前输入嵌入(含原始提示词+潜在记忆)
  50.     current_attention_mask = attention_mask  # 当前注意力掩码
  51.     current_position_ids = generate_position_ids(current_attention_mask)  # 当前位置ID
  52.     current_input_ids = input_ids  # 当前已生成的token ID序列
  53.    
  54.     # 提示词阶段记忆增强:生成并插入提示词专用潜在记忆
  55.     weaver_inputs_embeds = self.reasoner_to_weaver(current_inputs_embeds)  # 映射到编织器嵌入空间
  56.     weaver_hidden_states, attn_mask, pos_ids = weaver.augment_prompt(
  57.         weaver_inputs_embeds, current_attention_mask, current_position_ids
  58.     )
  59.     latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)  # 映射回推理器嵌入空间
  60.     # 拼接提示词与增强记忆
  61.     current_inputs_embeds = torch.cat([current_inputs_embeds, latent_inputs_embeds], dim=1)
  62.     current_attention_mask = torch.cat([current_attention_mask, attn_mask], dim=1)
  63.     current_position_ids = torch.cat([current_position_ids, pos_ids], dim=1)
  64.     # 生成循环初始化
  65.     sentence_augment_count = torch.zeros(B, dtype=torch.int, device=device)  # 各序列已增强次数
  66.     augmentation_pos = torch.full((B, max_new_tokens), fill_value=invalid_token_id, device=device)  # 增强位置掩码
  67.     inserted_embeds: List[List[torch.Tensor]] = [[] for _ in range(B)]  # 记录插入的潜在记忆(用于后处理)
  68.    
  69.     for i in range(max_new_tokens):
  70.         # 若所有序列均已生成EOS token,提前终止
  71.         if (current_input_ids[:, -1] == eos_token_id).all():
  72.             break   
  73.         # 若所有序列均已达到最大增强次数,一次性生成剩余token
  74.         if (sentence_augment_count >= max_augment_num).all():
  75.             # 调整剩余生成长度
  76.             generation_config.max_new_tokens = max_new_tokens - i
  77.             # 推理器生成剩余token
  78.             generated = reasoner.generate(
  79.                 inputs_embeds=current_inputs_embeds,
  80.                 attention_mask=current_attention_mask,
  81.                 generation_config=generation_config,
  82.             )
  83.             current_input_ids = torch.cat([current_input_ids, generated], dim=1)
  84.             break
  85.         # 推理器前向传播,获取当前步输出
  86.         outputs = reasoner(
  87.             inputs_embeds=current_inputs_embeds,
  88.             attention_mask=current_attention_mask,
  89.             position_ids=current_position_ids,
  90.             output_hidden_states=False,  # 推理阶段无需输出隐藏状态,提升效率
  91.         )
  92.         # 生成并追加一个新token,更新关键张量
  93.         current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids = self._append_one_step(
  94.             outputs, current_inputs_embeds, current_attention_mask, current_position_ids, current_input_ids, do_sample, temperature
  95.         )
  96.         # 若为最后一步生成,终止循环
  97.         if i == max_new_tokens - 1:  
  98.             break
  99.         # 判断当前批次中哪些序列需要进行推理阶段记忆增强
  100.         augment_decision = self._should_augment(
  101.             current_input_ids, current_attention_mask, sentence_augment_count=sentence_augment_count,
  102.             do_sample=do_sample, temperature=temperature  
  103.         )
  104.         augmentation_pos[:, i + 1] = augment_decision  # 记录增强位置(1=增强,0=不增强,-100=无效)
  105.         augment_indices = torch.where(augment_decision == 1)[0]  # 需增强的序列索引
  106.         # 对需增强的序列执行记忆增强,非增强序列左填充对齐维度
  107.         if len(augment_indices) > 0:
  108.             # 递增需增强序列的增强次数计数
  109.             sentence_augment_count[augment_indices] += 1
  110.             # 提取需增强序列的嵌入、掩码和位置ID
  111.             candidate_inputs_embeds = current_inputs_embeds[augment_indices]
  112.             candidate_attention_mask = current_attention_mask[augment_indices]
  113.             candidate_position_ids = current_position_ids[augment_indices]
  114.             
  115.             # 编织器生成推理阶段潜在记忆
  116.             weaver_inputs_embeds = self.reasoner_to_weaver(candidate_inputs_embeds)
  117.             weaver_hidden_states, attn_mask, _ = weaver.augment_inference(
  118.                 weaver_inputs_embeds, candidate_attention_mask, candidate_position_ids
  119.             )
  120.             latent_inputs_embeds = self.weaver_to_reasoner(weaver_hidden_states)  # 映射回推理器空间
  121.             
  122.             # 拼接原始嵌入与潜在记忆
  123.             candidate_inputs_embeds = torch.cat([candidate_inputs_embeds, latent_inputs_embeds], dim=1)
  124.             candidate_attention_mask = torch.cat([candidate_attention_mask, attn_mask], dim=1)
  125.             
  126.             # 构建合并张量(适配所有序列,包括增强和非增强)
  127.             new_len = candidate_inputs_embeds.size(1)  # 增强后序列长度
  128.             merged_inputs_embeds = torch.zeros((B, new_len, hidden_size), device=device, dtype=current_inputs_embeds.dtype)
  129.             merged_attention_mask = torch.zeros((B, new_len), device=device, dtype=current_attention_mask.dtype)
  130.             
  131.             # 填充增强序列
  132.             merged_inputs_embeds[augment_indices] = candidate_inputs_embeds
  133.             merged_attention_mask[augment_indices] = candidate_attention_mask
  134.             
  135.             # 填充非增强序列(左填充对齐长度)
  136.             non_augment_indices = torch.where(augment_decision != 1)[0]
  137.             if len(non_augment_indices) > 0:
  138.                 non_aug_inputs_embeds = current_inputs_embeds[non_augment_indices]
  139.                 non_aug_attention_mask = current_attention_mask[non_augment_indices]
  140.                 non_aug_inputs_embeds, non_aug_attention_mask, _ = self._left_pad(
  141.                     non_aug_inputs_embeds, non_aug_attention_mask, None, weaver.inference_latents_num
  142.                 )
  143.                 merged_inputs_embeds[non_augment_indices] = non_aug_inputs_embeds
  144.                 merged_attention_mask[non_augment_indices] = non_aug_attention_mask
  145.             
  146.             # 更新当前关键张量
  147.             current_inputs_embeds = merged_inputs_embeds
  148.             current_attention_mask = merged_attention_mask
  149.             current_position_ids = generate_position_ids(current_attention_mask)  # 重新生成位置ID
  150.             
  151.             # 记录插入的潜在记忆(用于后处理或可解释性分析)
  152.             for idx, embed in zip(augment_indices, latent_inputs_embeds):
  153.                 inserted_embeds[idx].append(embed.clone().detach().cpu())
  154.         
  155.         # 后处理:调整增强位置掩码长度与生成结果一致
  156.         new_generated_len = current_input_ids.size(1) - prompt_len
  157.         augmentation_pos = augmentation_pos[:, :new_generated_len]
  158.          
  159.         # 根据配置返回结果:仅生成序列 或 序列+增强位置掩码
  160.         if not return_augmentation_mask:
  161.             return current_input_ids
  162.         else:
  163.             return current_input_ids, augmentation_pos
复制代码
2.2 Trigger

2.2.1. 核心作用

该模块定义了 MemGen 框架中记忆触发器的核心接口与两种具体实现,核心作用是动态决策记忆增强的时机—— 即在推理过程中判断何时插入编织器生成的潜在记忆,实现记忆与推理的动态耦合,突破传统静态记忆注入的局限。
2.2.2. 核心特色


  • 抽象接口统一规范:Trigger抽象基类定义了触发器的核心接口,确保后续扩展新触发器时遵循统一标准,提升代码可扩展性。
  • 双实现适配不同场景:

    • NanoTrigger:极简实现,始终触发记忆增强,无需训练,适用于快速测试、基线对比或无需动态控制的简单场景。
    • MemGenTrigger:基于预训练 LLM 的智能触发器,通过二分类头适配决策任务,支持 PEFT 参数高效微调,能根据输入序列动态判断是否触发,适配复杂真实场景。

  • 高效适配与灵活扩展:

    • 采用 bfloat16 精度和 Flash Attention 2 优化计算效率;
    • 支持 PEFT 微调,在不冻结基础模型的前提下实现参数高效学习;
    • 替换 LLM 原始输出头为二分类头,精准适配 "是否插入记忆" 的决策需求。

  • 模块解耦设计:触发器决策独立于编织器模块,仅基于输入序列和数据分布做出判断,保证了模块间的低耦合和高内聚。
2.2.3 网络架构

网络架构图如下。
说明如下:

  • 模型支持PEFT参数高效微调(如LoRA),适配于Transformer Blocks层
  • 整体精度采用bfloat16,平衡计算效率与数值稳定性
  • 注意力计算通过Flash Attention 2优化,提升长序列处理速度
3.png

2.2.4 代码
  1. class Trigger(torch.nn.Module, ABC):
  2.     """
  3.     记忆触发器的抽象基类(Trigger)。
  4.     定义了触发器的核心接口,用于决定在推理过程中何时触发记忆增强(插入潜在记忆)。
  5.     所有具体触发器实现都需继承此类并实现forward方法。
  6.     """
  7.     def __init__(self):
  8.         super().__init__()  # 调用父类Module的初始化方法
  9.    
  10.     @abstractmethod
  11.     def forward(self, **kwargs) -> bool:
  12.         """
  13.         抽象前向传播方法:接收输入数据,返回是否触发记忆增强的决策。
  14.         子类必须实现此方法,定义具体的触发逻辑。
  15.         
  16.         Args:
  17.             **kwargs: 可变关键字参数,包含输入序列、注意力掩码等模型所需数据
  18.             
  19.         Returns:
  20.             bool: 触发决策(True表示触发记忆增强,False表示不触发)
  21.         """
  22.         ...
  23. class NanoTrigger(torch.nn.Module):
  24.     """
  25.     极简触发器(NanoTrigger):始终触发记忆增强的基础实现。
  26.     无需复杂逻辑,固定返回触发决策,适用于基础测试或无需动态控制的场景。
  27.     """
  28.     def __init__(self):
  29.         super().__init__()  
  30.         # 注册一个缓冲区张量,用于获取模型所在设备(无实际计算意义)
  31.         self.register_buffer("_device", torch.tensor(0.0))
  32.    
  33.     @property
  34.     def device(self):
  35.         """获取模型所在设备(CPU/GPU)"""
  36.         return self._device.device
  37.    
  38.     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> bool:
  39.         # 该"极简触发器"始终预测需要插入记忆
  40.         # 输出logits张量,其中插入决策(索引=1)的概率被设为1.0
  41.         # 适用于批次中的每个token位置
  42.         batch_size, seq_len = input_ids.shape
  43.         # 初始化logits张量:形状为[batch_size, seq_len, 2],2表示"不插入"(0)和"插入"(1)两类
  44.         logits = torch.zeros(batch_size, seq_len, 2, device=input_ids.device)
  45.         logits[..., 1] = 1.0  # 将所有位置的"插入"决策概率设为1.0
  46.         return logits
  47. class MemGenTrigger(torch.nn.Module):
  48.     """
  49.     MemGen框架的专用触发器模块(MemGenTrigger)。
  50.     - 输入:接收推理器模型当前解码序列的`inputs_embeds`(或input_ids)
  51.     - 输出:生成形状为[batch_size, seq_len, 2]的logits张量,
  52.       表示每个位置"不插入"(0)和"插入"(1)记忆的概率,用于动态决策记忆增强时机。
  53.     """
  54.     def __init__(
  55.         self,
  56.         pretrained_model_name_or_path: str,  # 预训练模型名称或路径(用于初始化触发器LLM)
  57.         peft_config: Optional[PeftConfig] = None  # PEFT配置(可选,用于参数高效微调)
  58.     ):
  59.         super().__init__()
  60.         
  61.         # 构建基础LLM模型(作为触发器的核心推理组件)
  62.         self.model = AutoModelForCausalLM.from_pretrained(
  63.             pretrained_model_name_or_path,
  64.             torch_dtype=torch.bfloat16,  # 使用bfloat16精度提升效率
  65.             attn_implementation="flash_attention_2"  # 启用Flash Attention 2优化注意力计算
  66.         )
  67.         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # 对应的Tokenizer
  68.         
  69.         # 对基础模型进行后处理(设置可训练、替换输出头)
  70.         self.model = self._postprocess(self.model)
  71.         # 若提供PEFT配置,应用参数高效微调
  72.         if peft_config is not None:
  73.             self.model = get_peft_model(self.model, peft_config)
  74.         
  75.         self.config = self.model.config  # 保存模型配置
  76.     @property
  77.     def device(self):
  78.         """获取模型所在设备(CPU/GPU)"""
  79.         return self.model.device
  80.    
  81.     def _postprocess(self, model: PreTrainedModel):
  82.         """
  83.         对基础模型进行后处理,适配触发器的二分类任务需求。
  84.         
  85.         Args:
  86.             model: 原始预训练LLM模型
  87.             
  88.         Returns:
  89.             处理后的模型(可训练、替换为二分类输出头)
  90.         """
  91.         # 设置所有模型参数为可训练
  92.         for parameter in model.parameters():
  93.             parameter.requires_grad = True
  94.         
  95.         # 将原始语言模型的输出头(lm_head)替换为二分类头
  96.         hidden_size = model.config.hidden_size  # 模型隐藏层维度
  97.         classification_head = nn.Linear(hidden_size, 2)  # 输出维度为2(不插入/插入)
  98.         model.lm_head = classification_head
  99.         
  100.         # 确保新的二分类头参数可训练
  101.         for param in model.lm_head.parameters():
  102.             param.requires_grad = True
  103.         return model
  104.     def forward(
  105.         self,
  106.         input_ids: Optional[torch.LongTensor] = None,  # 生成序列的token ID,形状[batch_size, seq_len]
  107.         attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,避免关注填充token
  108.         **kwargs: Unpack[TransformersKwargs],  # 传递给底层模型的额外参数
  109.     ) -> torch.Tensor:
  110.         """
  111.         序列生成的触发决策机制。
  112.         触发器基于已生成的`input_ids`做出决策,受数据分布影响,但独立于编织器模块。
  113.         Args:
  114.             input_ids (Optional[torch.LongTensor]): 生成序列的token ID张量
  115.             attention_mask (Optional[torch.Tensor]): 注意力掩码,默认None
  116.             **kwargs: 传递给底层模型的额外关键字参数
  117.         Returns:
  118.             torch.Tensor: Logits张量,形状为`(batch_size, seq_len, num_classes)`
  119.                         num_classes=2,分别对应"不插入"(索引0)和"插入"(索引1)的概率
  120.         """   
  121.         # 调用基础模型前向传播,返回二分类logits
  122.         return self.model(
  123.             input_ids=input_ids,
  124.             attention_mask=attention_mask,
  125.             **kwargs
  126.         ).logits
复制代码
2.3  MemGenWeaver

2.3.1 核心作用

MemGenWeaver 是 MemGen 框架的核心组件之一,负责生成动态潜在记忆并将其与推理器的输入序列融合,从而实现记忆与推理过程的无缝交织。它通过可学习的潜在记忆查询向量,在提示词阶段和推理阶段分别生成针对性的记忆表示,引导推理器调整解码路径,提升智能体的动态决策能力。
2.3.2 核心特色


  • 双阶段记忆生成:区分提示词阶段(augment_prompt)和推理阶段(augment_inference),使用各自独立的可学习潜在记忆查询向量,适配不同阶段的记忆需求,增强记忆生成的针对性。
  • 灵活的潜在记忆融合:通过_augment方法统一实现潜在记忆与输入序列的融合,包括嵌入拼接、注意力掩码扩展和位置 ID 计算,确保记忆与原始输入在语义空间和时序上的一致性。
  • 高效的模型设计:

    • 基于预训练 LLM 构建,支持 PEFT 参数高效微调,在保留基础能力的同时降低训练成本;
    • 采用 bfloat16 精度和 Flash Attention 2 优化,提升计算效率和内存利用率。

  • 动态记忆编织机制:生成的潜在记忆并非静态检索结果,而是基于当前输入序列动态生成的隐藏状态,能够捕捉实时上下文信息,实现 “生成式记忆” 的核心特性。
  • 模块化与可扩展性:与推理器、触发器解耦,通过标准化接口交互;潜在记忆的数量可通过参数灵活配置,适配不同任务对记忆容量的需求。
2.3.3 网络架构

网络架构图如下。
说明如下:

  • 核心组件:

    • 可学习潜在记忆向量:分阶段设计(P=提示词阶段数量,I=推理阶段数量),支持动态生成记忆
    • 预训练LLM:作为记忆生成核心,默认启用bfloat16精度和Flash Attention 2优化
    • 序列融合层:确保输入与记忆在语义、掩码、时序上的一致性

  • 核心流程:

    • 输入 → 选择对应阶段的潜在记忆 → 融合序列 → LLM生成隐藏状态 → 提取潜在记忆输出
    • 支持PEFT参数高效微调(如LoRA),适配于Transformer Blocks层

  • 输出用途:

    • 生成的潜在记忆将通过投影层映射到推理器的嵌入空间,与原始输入融合以引导解码

4.png

2.3.4 代码

两个关键变量如下:

  • prompt_query_latents。

    • 作用:增强模型在处理prompt时候的表现。 在模型处理完原始提示之后会被注入到序列中,为模型提供额外的上下文信息。
    • 使用场景:在 augment_prompt 方法中使用,在生成阶段的开始阶段使用一次。

  • inference_query_latents。

    • 作用:在生成过程中动态增强模型的推理能力。可以在生成过程中的多个点被注入,以提供实时上下文增强。
    • 使用场景:在 augment_inference 方法中使用,在生成阶段中多次被使用。通常在遇到特定分隔符(逗号,句号等)后触发插入。

这两个变量都通过_augment 方法获得(获取学习到的潜在向量,并将其附加到输入嵌入中)。其流程如下:

  • 将潜在变量附加到当前输入嵌入序列的末尾。
  • 更新注意力掩码和位置ID,以考虑新增的潜在向量。
  • 将增强后的序列通过Weaver模型处理。
  • 提供于潜在向量位置对应的隐状态作为增强表示。
判断是否插入是通过函数 _should_augment 完成的。

  • 检查当前生成的文本是否是特殊字符(逗号等)
  • 使用触发模型(trigger model)进一步判断是否应该增强。
  • 考虑最大增强次数限制。
  1. class MemGenWeaver(torch.nn.Module):
  2.     """
  3.     MemGen模型的编织器模块(MemGenWeaver)。
  4.     - 输入:接收接收来自推理器模型当前当前解码序列的`inputs_embeds`(输入嵌入入)
  5.     - 输出:生成长度为K的隐藏状态序列,
  6. 这些状态将与原始`inputs_embeds`拼接,以改变推理器的解码路径
  7.     """
  8.     def __init__(
  9.         self,
  10.         pretrained_model_name_or_path: str,  # 预训练模型的名称或路径
  11.         prompt_latents_num: int,    # 提示词阶段生成的潜在记忆数量
  12.         inference_latents_num: int, # 推理阶段生成的潜在记忆数量
  13.         peft_config: Optional[PeftConfig] = None  # PEFT配置(可选)
  14.     ):
  15.         super().__init__()
  16.         
  17.         # 基础模型初始化
  18.         self.model = AutoModelForCausalLM.from_pretrained(
  19.             pretrained_model_name_or_path,
  20.             torch_dtype=torch.bfloat16,  # 使用bfloat16精度以提高效率
  21.             attn_implementation="flash_attention_2"  # 启用Flash Attentionention 2优化
  22.         )
  23.         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # 对应的分词器
  24.         # 若提供PEFT配置,则应用参数高效微调
  25.         if peft_config is not None:
  26.             self.model = get_peft_model(self.model, peft_config)
  27.         
  28.         self.config = self.model.config  # 保存模型配置
  29.         
  30.         # 提示词阶段的潜在记忆查询向量(可学习参数)
  31.         self.prompt_query_latents = nn.Parameter(
  32.             torch.randn(prompt_latents_num, self.config.hidden_size),  # 形状:[prompt_latents_num, hidden_size]
  33.             requires_grad=True  # 允许反向传播更新
  34.         )
  35.         # 推理阶段的潜在记忆查询向量(可学习参数)
  36.         self.inference_query_latents = nn.Parameter(
  37.             torch.randn(inference_latents_num, self.config.hidden_size),  # 形状:[inference_latents_num, hidden_size]
  38.             requires_grad=True  # 允许反向传播更新
  39.         )
  40.    
  41.     @property
  42.     def prompt_latents_num(self) -> int:
  43.         """返回提示词阶段的潜在记忆数量"""
  44.         return self.prompt_query_latents.size(0)
  45.     @property
  46.     def inference_latents_num(self) -> int:
  47.         """返回推理阶段的潜在记忆数量"""
  48.         return self.inference_query_latents.size(0)
  49.     @property
  50.     def device(self):
  51.         """返回模型所在的设备(CPU/GPU)"""
  52.         return self.model.device
  53.     def _augment(
  54.         self,
  55.         latents: torch.Tensor,                # 潜在记忆查询向量,形状:[latents_num, hidden_size]
  56.         inputs_embeds: torch.Tensor,          # 输入嵌入,形状:[batch_size, seq_len, hidden_size]
  57.         attention_mask: torch.Tensor,         # 注意力掩码,形状:[batch_size, seq_len]
  58.         position_ids: torch.Tensor            # 位置ID,形状:[batch_size, seq_len]
  59.     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  60.         """
  61.         通用的潜在记忆增强方法:将潜在记忆与输入序列融合,生成增强后的隐藏状态。
  62.         
  63.         参数:
  64.             latents: 潜在记忆查询向量
  65.             inputs_embeds: 输入序列的嵌入表示
  66.             attention_mask: 输入序列的注意力掩码
  67.             position_ids: 输入序列的位置ID
  68.         
  69.         返回:
  70.             三元组 (latents_hidden_states, latents_mask, latents_position_ids)
  71.             - latents_hidden_states: 生成的潜在记忆隐藏状态,形状:[batch_size, latents_num, hidden_size]
  72.             - latents_mask: 潜在记忆的注意力掩码,形状:[batch_size, latents_num]
  73.             - latents_position_ids: 潜在记忆的位置ID,形状:[batch_size, latents_num]
  74.         """
  75.         batch_size = attention_mask.shape[0]  # 获取批次大小
  76.         latents_num = latents.size(0)         # 获取潜在记忆数量
  77.         
  78.         # 扩展潜在记忆维度以匹配批次大小:[1, latents_num, hidden_size] → [batch_size, latents_num, hidden_size]
  79.         latents = latents.unsqueeze(0).repeat(batch_size, 1, 1)
  80.         
  81.         # 将潜在记忆嵌入与输入嵌入拼接:[batch_size, seq_len + latents_num, hidden_size]
  82.         inputs_embeds = torch.cat([inputs_embeds, latents], dim=1)
  83.         # 构建潜在记忆的注意力掩码(全为1,表示有效)并与输入掩码拼接
  84.         latents_mask = torch.ones(latents.shape[:-1], dtype=attention_mask.dtype, device=attention_mask.device)
  85.         attention_mask = torch.cat([attention_mask, latents_mask], dim=1)  # 形状:[batch_size, seq_len + latents_num]
  86.         
  87.         # 生成潜在记忆的位置ID(在输入序列最后位置的基础上递增)
  88.         last_position_ids = position_ids.max(dim=1)[0]  # 获取输入序列的最大位置ID
  89.         latents_relative_positions = torch.arange(latents_num, device=attention_mask.device)  # 潜在记忆的相对位置
  90.         # 计算绝对位置:输入序列最大位置 + 相对位置 + 1(避免重叠)
  91.         latents_position_ids = last_position_ids.unsqueeze(1) + latents_relative_positions + 1
  92.         # 拼接位置ID:[batch_size, seq_len + latents_num]
  93.         position_ids = torch.cat([position_ids.long(), latents_position_ids.long()], dim=1)
  94.         # 验证拼接后的维度是否一致
  95.         assert inputs_embeds.shape[:2] == attention_mask.shape == position_ids.shape
  96.         # 模型前向传播,获取隐藏状态
  97.         outputs = self.model(
  98.             inputs_embeds=inputs_embeds,
  99.             attention_mask=attention_mask,
  100.             position_ids=position_ids,  
  101.             output_hidden_states=True,  # 输出所有层的隐藏状态
  102.         )
  103.         # 取最后一层的隐藏状态,并提取潜在记忆部分(序列末尾的latents_num个位置)
  104.         hidden_states = outputs.hidden_states[-1]
  105.         latents_hidden_states = hidden_states[:, -latents_num:, :]
  106.         return latents_hidden_states, latents_mask, latents_position_ids
  107.     def augment_prompt(
  108.         self,
  109.         inputs_embeds: torch.Tensor,
  110.         attention_mask: torch.Tensor,
  111.         position_ids: torch.Tensor
  112.     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  113.         """
  114.         提示词阶段的潜在记忆增强:使用提示词专用的潜在记忆查询向量。
  115.         
  116.         参数与返回值同_augment方法
  117.         """
  118.         return self._augment(
  119.             latents=self.prompt_query_latents,
  120.             inputs_embeds=inputs_embeds,
  121.             attention_mask=attention_mask,
  122.             position_ids=position_ids
  123.         )
  124.     def augment_inference(
  125.         self,
  126.         inputs_embeds: torch.Tensor,
  127.         attention_mask: torch.Tensor,
  128.         position_ids: torch.Tensor
  129.     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  130.         """
  131.         推理阶段的潜在记忆增强:使用推理专用的潜在记忆查询向量。
  132.         
  133.         参数与返回值同_augment方法
  134.         """
  135.         return self._augment(
  136.             latents=self.inference_query_latents,
  137.             inputs_embeds=inputs_embeds,
  138.             attention_mask=attention_mask,
  139.             position_ids=position_ids
  140.         )
复制代码
0xFF 参考

最新成果!Agent记忆的第三种可能:生成式隐式记忆

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册