当前位置: 首页 > news >正文

GenerationMixin:_sample方法(GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH)

_sample

_sample是一个用于文本生成的函数。该函数使用 多项式采样(multinomial sampling) 来生成序列,可用于文本解码器、文本到文本、语音到文本以及视觉到文本的模型。

函数定义

def _sample(self,input_ids: torch.LongTensor,logits_processor: LogitsProcessorList,stopping_criteria: StoppingCriteriaList,generation_config: GenerationConfig,synced_gpus: bool,streamer: Optional["BaseStreamer"],**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:# 函数主体

参数详解

  1. input_ids (torch.LongTensor)

    • 形状为 (batch_size, sequence_length) 的输入张量。
    • 用作生成的提示序列。
  2. logits_processor (LogitsProcessorList)

    • LogitsProcessorList 的实例。
    • 包含了多个 LogitsProcessor,用于在每个生成步骤中修改模型的输出 logits,以实现诸如温度缩放、重复惩罚等功能。
  3. stopping_criteria (StoppingCriteriaList)

    • StoppingCriteriaList 的实例。
    • 包含了多个 StoppingCriteria,用于判断生成过程是否应该停止,例如达到最大长度或生成了结束符号等。
  4. generation_config (GenerationConfig)

    • 生成配置,包含了生成过程中的各种参数,如最大长度、是否进行采样等。
  5. synced_gpus (bool)

    • 是否在多 GPU 环境下同步生成循环,避免在某些 GPU 生成完成后出现死锁。
  6. streamer (BaseStreamer, 可选)

    • 流处理器对象,用于在生成过程中实时处理生成的序列。
  7. model_kwargs

    • 其他模型特定的关键字参数,将传递给模型的 forward 函数。
    • 如果模型是编码器-解码器模型,model_kwargs 应该包括 encoder_outputs

返回值

  • GenerateDecoderOnlyOutput、GenerateEncoderDecoderOutput 或 torch.LongTensor
    • 根据配置,返回包含生成的序列和其他信息的对象,或者仅返回生成的 token 序列。

函数流程详解

1. 初始化变量
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
  • pad_token_id:用于填充生成完成后的序列。
  • 输出控制参数
    • output_attentions:是否返回注意力权重。
    • output_hidden_states:是否返回隐藏状态。
    • output_scores:是否返回处理后的得分。
    • output_logits:是否返回原始 logits。
  • return_dict_in_generate:是否在生成中返回字典形式的结果。
  • max_length:生成的最大长度。
  • has_eos_stopping_criteria:是否存在基于结束符的停止条件。
  • do_sample:是否进行采样,True 表示采样,False 表示贪心搜索。
2. 初始化保存生成过程中的信息的元组
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  • 根据配置,初始化用于保存生成过程中各项信息的元组。
3. 处理编码器-解码器模型的编码器输出
if return_dict_in_generate and self.config.is_encoder_decoder:encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else Noneencoder_hidden_states = (model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None)
  • 如果模型是编码器-解码器模型,并且需要返回注意力权重或隐藏状态,则从 encoder_outputs 中获取。
4. 初始化生成过程的变量
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
  • batch_size:批次大小。
  • cur_len:当前序列长度。
  • this_peer_finished:在多 GPU 环境下,当前设备是否已完成生成。
  • unfinished_sequences:用于跟踪哪些序列尚未生成完成。
5. 获取初始缓存位置
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
  • 准备缓存机制,用于加速生成过程。
  • _get_initial_cache_position
6. 准备模型的前向调用函数
model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), Cache):is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cacheis_compileable = is_compileable and not self.generation_config.disable_compileif is_compileable and (self.device.type == "cuda" or generation_config.compile_config._compile_all_devices):os.environ["TOKENIZERS_PARALLELISM"] = "0"model_forward = self.get_compiled_call(generation_config.compile_config)
  • 如果缓存可编译且设备支持,则使用编译后的模型前向函数,以提高性能。
7. 进入生成循环
is_prefill = True
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):# 循环体
  • 检查是否存在未完成的序列且未达到最大长度,进入生成循环。
8. 准备模型输入
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
  • 调用 prepare_inputs_for_generation 准备模型输入。
  • 根据配置,添加控制输出的参数。
  • prepare_inputs_for_generation
9. 执行模型前向传播
if is_prefill:outputs = self(**model_inputs, return_dict=True)is_prefill = False
else:outputs = model_forward(**model_inputs, return_dict=True)
  • 第一次迭代使用默认的模型调用,以正确初始化缓存。
  • 后续迭代可能使用编译后的模型前向函数。
10. 更新模型参数
model_kwargs = self._update_model_kwargs_for_generation(outputs,model_kwargs,is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:continue
  • 更新 model_kwargs,包含新的缓存 past_key_values 等。
  • 如果在多 GPU 环境中当前设备已完成生成,则跳过后续步骤。
  • _update_model_kwargs_for_generation
11. 处理模型输出的 logits
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)

这行代码的目的是从 outputs.logits 中提取最后一个时间步(即最后一列)的所有令牌的预测得分,并进行复制和转换为浮点类型,存储在 next_token_logits 变量中。

代码解释

  1. outputs.logits[:, -1, :]:

    • outputs.logits 是一个三维张量,通常代表模型输出的预测得分(logits),形状一般为 [batch_size, sequence_length, vocab_size]
    • [:, -1, :] 是一个切片操作,表示选择所有批次的最后一个时间步的所有令牌得分。具体来说:
      • : 切片表示选择所有批次(第一维度)。
      • -1 表示选择最后一个时间步(第二维度)。
      • : 切片表示选择所有令牌的得分(第三维度)。
  2. .clone():

    • 创建一个新张量,它是 outputs.logits[:, -1, :] 的副本。
    • 使用 clone() 是为了避免原始张量的意外改变影响到 next_token_logits
  3. .float():

    • 将克隆后的张量转换为浮点类型,以确保后续计算的精度。

举例说明

假设 outputs.logits 是如下的一个三维张量,形状为 [2, 4, 5](批次大小为2,序列长度为4,词汇大小为5):

outputs.logits = [[[0.1, 0.2, 0.3, 0.4, 0.5],[0.6, 0.7, 0.8, 0.9, 1.0],[1.1, 1.2, 1.3, 1.4, 1.5],[1.6, 1.7, 1.8, 1.9, 2.0],],[[0.2, 0.3, 0.4, 0.5, 0.6],[0.7, 0.8, 0.9, 1.0, 1.1],[1.2, 1.3, 1.4, 1.5, 1.6],[1.7, 1.8, 1.9, 2.0, 2.1],],
]

根据代码,我们执行以下步骤:

  1. 提取最后一个时间步的所有令牌得分
    next_token_logits = [[1.6, 1.7, 1.8, 1.9, 2.0],[1.7, 1.8, 1.9, 2.0, 2.1]
    ]
    

描述到切片操作后,我们从 outputs.logits 张量中获取每个批次的最后一个时间步(第四个时间步)的令牌得分。

  1. 克隆并转换为浮点类型
    • clone() 操作创建新张量,且 .float() 将张量元素转为浮点型(假设原始元素非浮点类型)。

最终 next_token_logits 将包含如下张量:

next_token_logits (float):
[[1.6, 1.7, 1.8, 1.9, 2.0],[1.7, 1.8, 1.9, 2.0, 2.1]
]

这用于下一步的处理,例如为模型生成下一个令牌选择提供概率分布(logits)。

12. 对 logits 进行预处理
next_token_scores = logits_processor(input_ids, next_token_logits)

这行代码的目的是通过 logits_processornext_token_logits 进行处理,以生成 next_token_scoreslogits_processor 是一个函数或包含一系列函数的对象,它对模型输出的 logits 进行某种自定义变换(如归一化、过滤或修改等)。我们来详细解释,并举例说明。

代码解释

  1. logits_processor:

    • logits_processor 是一个函数或 callable 对象,用于对模型输出的 logits 进行处理。处理逻辑可以是许多策略中的一种,比如:
      • 降低或惩罚某些特定令牌的分数。
      • 应用特定的归一化或过滤。
      • 结合一些上下文信息对 logits 进行调整。
    • 它接受两个参数:input_idsnext_token_logits
  2. input_ids:

    • input_ids 是模型当前接收到的输入序列,通常用于提供上下文或约束以在 logits_processor 中进行处理。
  3. next_token_logits:

    • 这些是从模型输出中提取并准备处理的原始 logit 值。
  4. next_token_scores:

    • logits_processor 返回处理后的结果,通常表示对每个可能的下一个令牌的更新得分。这个输出用于影响对下一个令牌的选择过程。

举例说明

假设有如下的 input_idsnext_token_logits

  • input_ids:表示当前序列,例如 [101, 1045, 2572, 1037, 102],其中每个数字都是代表某个词的唯一标识。
  • next_token_logits:例如为 [2.0, 1.5, 3.0, 0.5, -1.0],其中每个值是对应给定词汇表中一个可能的下一个词的预测得分(logit)。

假设 logits_processor 的简单处理逻辑是将所有负数值设置为 -inf,来避免选择这些令牌,如此进行如下调整:

  • 原始 next_token_logits[2.0, 1.5, 3.0, 0.5, -1.0]
  • 处理后的 next_token_scores[2.0, 1.5, 3.0, 0.5, -inf]

在这个例子中,logits_processor 的机制是简单的限制,导致原来的负分数被抑制,使得生成更可能选择正得分的令牌。

处理后的 next_token_scores 可以用于进一步的决策,例如通过 softmax 转换为概率分布,以便从词汇表中选出下一个最有可能的令牌。

13. 根据需要保存生成过程中的信息
if return_dict_in_generate:if output_scores:scores += (next_token_scores,)if output_logits:raw_logits += (next_token_logits,)if output_attentions:decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,))if self.config.is_encoder_decoder:cross_attentions += (outputs.cross_attentions,)if output_hidden_states:decoder_hidden_states += ((outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,))
  • 根据配置,保存处理过的得分、原始 logits、注意力权重和隐藏状态。
14. 选择下一个 token
if do_sample:probs = nn.functional.softmax(next_token_scores, dim=-1)next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:next_tokens = torch.argmax(next_token_scores, dim=-1)

这段代码实现了从给定的 next_token_scores 中选择下一个令牌(token)。根据标志 do_sample 的值不同,它提供两种选择策略:采样和贪心选择。让我们逐步解释这些代码,并举例说明。

代码解释

  1. do_sample:

    • 这是一个布尔值标志,用于决定是否进行采样选择。
    • 如果 do_sampleTrue,则使用概率采样来选择令牌。
    • 如果 do_sampleFalse,则使用贪心策略,选择具有最高得分的令牌。
  2. probs = nn.functional.softmax(next_token_scores, dim=-1):

    • do_sampleTrue时,这行代码将 next_token_scores 转换为概率分布。
    • softmax 函数确保所有得分转换为非负值且总和为1,使其成为概率分布。
  3. torch.multinomial(probs, num_samples=1).squeeze(1):

    • multinomial 用于从概率分布 probs 中进行采样,返回一个样本。
    • num_samples=1 说明每次只选择一个令牌。
    • squeeze(1) 用于去掉维度为1的多余维度,使返回结果变为基本的1D张量。
  4. torch.argmax(next_token_scores, dim=-1):

    • do_sampleFalse时,选择得分最高的令牌。
    • argmax 返回具有最大得分的索引,即最有可能的下一个令牌。

举例说明

假设 next_token_scores 是如下的一维张量,表示模型生成下一个令牌时的logits:

next_token_scores = [2.0, 1.5, 3.0, 0.5, -inf]
  1. do_sample=True 的情况下

    • 应用 softmax
      probs = softmax([2.0, 1.5, 3.0, 0.5, -inf]) = [0.2138, 0.1421, 0.5820, 0.0621, 0.0]
      
    • multinomialprobs 中采样可能得到 next_tokens
      next_tokens = [2]  # 假设采样到索引2(3.0的概率最大,最有可能)
      
  2. do_sample=False 的情况下

    • 使用 argmax
      next_tokens = [2]  # 因为索引2对应的得分3.0最大
      

在两种策略中,next_tokens 存储的都是选定的下一个令牌的索引,依据其选择策略可以应用于生成或预测序列的新部分。采样策略尤其适用于生成更有多样性的序列,而贪心策略则保证每一步都选择当前最优决策。

15. 处理已完成的序列
if has_eos_stopping_criteria:next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

这段代码处理了在生成文本时,当遇到结束标记(End of Sentence, EOS)后,确保接下来的令牌是填充标记(padding token)。它通过使用布尔掩码 unfinished_sequences 来实现这一点。让我们分步解释一下,并举例说明。

代码解释

  1. next_tokens * unfinished_sequences

    • next_tokens 是当前选择的下一个令牌。
    • unfinished_sequences 是一个布尔张量,指示哪些序列尚未完成。
    • 乘法操作会保持未完成的序列的 next_tokens 值。
  2. pad_token_id * (1 - unfinished_sequences)

    • pad_token_id 是填充标记的ID。
    • (1 - unfinished_sequences) 会将布尔值反转,这样完成的序列会变成1,未完成的序列变成0。
    • 乘法操作会将已完成的序列的下一令牌设为 pad_token_id
  3. 结合运算

    • next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
      这条语句最终会确保,对于未完成的序列,保留其原始的 next_tokens ;对于已完成的序列,将 next_tokens 设置为 pad_token_id

举例说明

假设有以下变量:

  • next_tokens 是当前选择的令牌,假设其值为 [5, 6, 7]
  • unfinished_sequences 是一个布尔张量,表示哪些序列未完成。例如 [1, 0, 1] (1 表示未完成,0 表示完成)。
  • pad_token_id 是填充标记的ID,例如为 0

示例数据:

next_tokens = [5, 6, 7]
unfinished_sequences = [1, 0, 1]
pad_token_id = 0

计算步骤:

  1. 部分1:next_tokens * unfinished_sequences

    • [5, 6, 7] * [1, 0, 1]
    = [5 * 1, 6 * 0, 7 * 1]
    = [5, 0, 7]
    
  2. 部分2:pad_token_id * (1 - unfinished_sequences)

    • pad_token_id = 0
    • 1 - unfinished_sequences = [0, 1, 0]
    • 0 * [0, 1, 0]
    = [0 * 0, 0 * 1, 0 * 0]
    = [0, 0, 0]
    
  3. 结合结果

    • [5, 0, 7] + [0, 0, 0]
    = [5 + 0, 0 + 0, 7 + 0]
    = [5, 0, 7]
    

最终得到的 next_tokens[5, 0, 7]。在这个例子中,未完成的序列(第1和第3项)保留了原来的令牌值,而已完成的序列(第2项)的下一令牌被设置为填充标记 0

这个逻辑确保在生成序列过程中,一旦某个序列到达结束标记(EOS),其后生成的所有令牌都会变成填充标记,以防止再度继续生成。

16. 更新输入序列
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

这段代码用于将新生成的令牌 next_tokens 添加到当前的 input_ids,更新为下一步模型输入。这是文本生成过程中重要的一步,确保每次生成的输出都成为下一次生成的输入。让我们详细解释这段代码,并举例说明。

代码解释

  1. next_tokens[:, None]:

    • next_tokens 是当前生成的令牌的张量。
    • [:, None] 用作添加一个新的维度,达到扩展张量维度的效果,使其从一维变为二维。具体来说:
      • 如果 next_tokens 是形状 (batch_size,),那么 [:, None] 将其转变为 (batch_size, 1)
    • 这样处理让 next_tokens 可以按列而不是按行进行连接。
  2. torch.cat([...], dim=-1):

    • torch.cat 函数用于沿着指定维度进行张量连接。
    • dim=-1 指定连接操作沿着最后一个维度进行,也就是在这里,可以理解为列连接。
    • next_tokens 添加到 input_ids 的每一批次的最后。

举例说明

假设 input_idsnext_tokens 如下:

  • input_ids 是之前生成的文本的标识,假如形状是 [batch_size, sequence_length]
input_ids = [[101, 102],[201, 202]]
  • next_tokens 是本次生成的新令牌:
next_tokens = [103, 203]

操作步骤:

  1. 应用 [:, None]
    • next_tokens[:, None] 使 next_tokens 从一维变为二维:
next_tokens = [[103],[203]]
  1. 连接操作
    • torch.cat([input_ids, next_tokens[:, None]], dim=-1)next_tokens 的每批次令牌添加到 input_ids 的最后:
input_ids = [[101, 102, 103],[201, 202, 203]]
16. 更新状态
if streamer is not None:streamer.put(next_tokens.cpu())
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1

这段代码用于更新未完成序列的状态,并通过 this_peer_finished 标识当前批次内的序列是否全部完成。下面就逐步解释这段代码,并提供一个例子进行说明。

代码解释

  1. unfinished_sequences:

    • 这是一个布尔张量,表示哪些序列仍然未完成。通常其元素为1表示未完成,0表示已完成。
  2. stopping_criteria(input_ids, scores):

    • stopping_criteria 是一个函数,用于判断是否有序列已经达到停止条件。
    • 接受当前的 input_idsscores(通常是得分或概率)作为参数。
    • 返回一个布尔张量,标识哪些序列达到了停止条件。
  3. ~stopping_criteria(input_ids, scores):

    • ~ 是按位取反操作,将 stopping_criteria 的结果反转。
    • 想达到停止条件的序列标记为 0,未达到停止条件的标记为 1
  4. unfinished_sequences & ~stopping_criteria(input_ids, scores):

    • 按位与操作更新未完成的序列数。
    • 如果一个序列已经达到停止条件,则对应的 unfinished_sequences 将设置为0。
    • 如果序列仍未完成,则保持为1。
  5. this_peer_finished = unfinished_sequences.max() == 0:

    • 检查 unfinished_sequences 中的所有值是否为0。
    • 如果unfinished_sequences.max()为0,表示所有序列都完成了。
    • this_peer_finished 为真时,所有序列在当前批次内都完成了。

举例说明

假设有以下情境:

  • unfinished_sequences 是初始的未完成序列状态:
unfinished_sequences = [1, 1, 1]
  • stopping_criteria(input_ids, scores) 返回一个布尔张量:
stopping_criteria_result = [0, 1, 0]

这是stopping_criteria的输出,表示第二个序列达到停止条件。

操作步骤:

  1. 取反操作
~stopping_criteria_result = [1, 0, 1]
  1. 按位与更新
unfinished_sequences & ~stopping_criteria_result = [1 & 1, 1 & 0, 1 & 1] = [1, 0, 1]

表示第一个和最后一个序列仍未完成,第二个序列已完成。

  1. 检查是否完成
  • unfinished_sequences.max() == 0 返回False,因为unfinished_sequences中有非零元素。
17. 清理变量防止内存泄漏
del outputs
  • 删除 outputs,防止在下一次迭代中占用不必要的内存。
18. 结束循环
  • 当所有序列都完成生成,或达到最大长度时,退出循环。
19. 结束流处理
if streamer is not None:streamer.end()
  • 如果使用了 streamer,调用其 end() 方法,表示生成结束。
20. 返回生成结果
if return_dict_in_generate:if self.config.is_encoder_decoder:return GenerateEncoderDecoderOutput(sequences=input_ids,scores=scores,logits=raw_logits,encoder_attentions=encoder_attentions,encoder_hidden_states=encoder_hidden_states,decoder_attentions=decoder_attentions,cross_attentions=cross_attentions,decoder_hidden_states=decoder_hidden_states,past_key_values=model_kwargs.get("past_key_values"),)else:return GenerateDecoderOnlyOutput(sequences=input_ids,scores=scores,logits=raw_logits,attentions=decoder_attentions,hidden_states=decoder_hidden_states,past_key_values=model_kwargs.get("past_key_values"),)
else:return input_ids
  • 根据配置,返回包含生成序列和其他信息的对象,或仅返回生成的序列。

_sample内部的while循环:

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):# 循环体

这个循环的执行次数取决于以下因素:

  1. 当前序列长度 (cur_len):生成开始时输入序列的长度。
  2. 最大生成长度 (max_length):生成的序列允许的最大长度。
  3. 停止条件 (stopping_criteria):定义了何时停止生成,例如生成了结束符号(EOS token)或达到最大长度。
  4. 模型在每一步生成的输出:每个时间步生成的token是否满足停止条件。

因此,while循环的执行次数并不是固定的,它可能会从一次到多次,具体次数取决于上述因素的组合。

循环为何会执行多次?

生成文本的过程是逐步进行的,在每个时间步,模型基于当前的输入序列预测下一个token,并将其添加到序列中。while循环控制这个生成过程,直到满足停止条件为止。

主要原因包括:

  • 序列长度的限制:如果生成的序列尚未达到max_length,并且有序列尚未完成,则循环继续。
  • 停止条件的影响:如果定义了停止条件(例如生成了EOS token),一旦满足条件的序列就被标记为完成,但其他未完成的序列需要继续生成。
  • 批量生成的同步:在批量生成(batch generation)中,需要确保批次中的所有序列都完成生成,循环会一直执行,直到所有序列都完成或达到最大长度。
循环执行次数的详细解释
1. 循环初始化
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
  • batch_size:批次大小,即一次生成的序列数量。
  • cur_len:当前序列的长度。
  • this_peer_finished:用于多GPU情况下,指示当前GPU是否完成生成。
  • unfinished_sequences:一个形状为(batch_size,)的张量,用于跟踪哪些序列尚未完成。
2. 循环条件
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):# 循环体
  • self._has_unfinished_sequences():这是一个函数,用于判断是否还有未完成的序列需要继续生成。
  • 循环继续的条件
    • 当前未达到最大长度 (cur_len < max_length)。
    • 存在未完成的序列 (unfinished_sequences 中存在值为1的元素)。
3. 循环内部

在循环内部,模型会逐步生成新的token,并更新相关的状态。

  • 模型前向传播:根据当前的input_ids,模型预测下一个token的概率分布(logits)。

  • 处理logits:通过logits_processor对logits进行处理,例如应用温度、top-k或top-p采样等。

  • 选择下一个token

    • 如果是采样 (do_sample=True),则根据概率分布进行采样。
    • 如果是贪心搜索 (do_sample=False),则选择概率最高的token。
  • 更新序列:将新生成的token添加到input_ids

  • 检查停止条件

    • 使用stopping_criteria检查哪些序列已满足停止条件。
    • 更新unfinished_sequences,将已完成的序列标记为0。
  • 更新循环变量:增加cur_len的值,继续下一次迭代。

实例说明

假设以下设定:

  • 批次大小 (batch_size):2
  • 初始序列长度 (cur_len):5
  • 最大长度 (max_length):10
  • 停止条件:生成了EOS token(假设其ID为eos_token_id = 2

初始状态:

input_ids = [[101, 10, 23, 45, 67],   # 序列1[101, 11, 34, 56, 78]    # 序列2
]
unfinished_sequences = [1, 1]  # 两个序列都未完成
cur_len = 5
第一次循环
  • 模型生成下一个token
    • 序列1生成token 89
    • 序列2生成token 90
  • 更新input_ids
input_ids = [[101, 10, 23, 45, 67, 89],[101, 11, 34, 56, 78, 90]
]
cur_len = 6
  • 检查停止条件
    • 没有序列生成EOS token。
    • unfinished_sequences保持为 [1, 1]
第二次循环
  • 模型生成下一个token
    • 序列1生成token 2(EOS token)
    • 序列2生成token 91
  • 更新input_ids
input_ids = [[101, 10, 23, 45, 67, 89, 2],[101, 11, 34, 56, 78, 90, 91]
]
cur_len = 7
  • 检查停止条件
    • 序列1生成了EOS token,将其标记为完成。
    • 更新unfinished_sequences[0, 1]
第三次循环
  • 模型生成下一个token
    • 序列1已完成,不再生成,填充pad_token_id
    • 序列2生成token 2(EOS token)
  • 更新input_ids
input_ids = [[101, 10, 23, 45, 67, 89, 2, pad_token_id],[101, 11, 34, 56, 78, 90, 91, 2]
]
cur_len = 8
  • 检查停止条件
    • 序列2生成了EOS token,将其标记为完成。
    • 更新unfinished_sequences[0, 0]
循环结束
  • 所有序列都完成了生成,unfinished_sequences全为0。
  • 循环条件不再满足,退出循环。
总结循环次数
  • 循环次数:3次
  • 最终序列长度:8(因为 cur_len 从5增加到8)
为什么会循环这么多次?
  • 序列长度增长:每次循环,cur_len增加1。
  • 停止条件未满足:只要有序列未生成EOS token,unfinished_sequences中就有1,循环继续。
  • 需要等待所有序列完成:在批量生成中,即使某些序列提前完成,循环也会继续,直到所有序列都完成或达到最大长度。
另一种情况的实例

假设模型在默认情况下不太可能生成EOS token,或者没有设置停止条件。

设定:

  • 批次大小 (batch_size):2
  • 初始序列长度 (cur_len):5
  • 最大长度 (max_length):10
  • 未设置停止条件(或模型未生成EOS token)

循环过程:

  • cur_len=5开始,每次循环cur_len增加1
  • 没有序列会被标记为完成unfinished_sequences始终为 [1, 1]
  • 循环持续到cur_len达到max_length
  • 循环次数:5次(从cur_len=5到cur_len=10)
总结
  • 循环次数依赖于序列是否完成:只要有未完成的序列,且未达到最大长度,循环就会继续。
  • 最大循环次数max_length - cur_len_initial,即可能的最大生成步数。
  • 最小循环次数:如果初始序列已经满足停止条件,循环可能不会执行。

_get_initial_cache_position

函数功能概述

_get_initial_cache_position 是一个用于计算 cache_position 的函数。cache_position 用于跟踪生成过程中每个位置的 token 编号,特别是在使用缓存(past_key_values)的情况下,它确保生成的新 token 的位置编号与缓存中的位置正确对应。

在生成文本的过程中,Transformer 模型可以使用缓存的 key 和 value(past_key_values),以加速生成。这些缓存通常来自于模型之前的计算,即已经生成的序列。为了正确处理缓存,需要调整新生成的 token 的位置编号,这就是 cache_position 的作用。

代码详解
def _get_initial_cache_position(self, input_ids, model_kwargs):"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:cache_position = (torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1)else:cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1past_length = 0if model_kwargs.get("past_key_values") is not None:cache = model_kwargs["past_key_values"]past_length = 0if not isinstance(cache, Cache):past_length = cache[0][0].shape[2]elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:past_length = cache.get_seq_length()# TODO: this is not torch.compile-friendly, find a work-around. If the cache is not empty,# end-to-end compilation will yield bad results because `cache_position` will be incorrect.if not is_torchdynamo_compiling():cache_position = cache_position[past_length:]model_kwargs["cache_position"] = cache_positionreturn model_kwargs
步骤1:计算 cache_position

首先,根据是否存在嵌入向量,以及模型的类型(解码器或编码器-解码器),计算 cache_position

  • 如果存在 inputs_embeds 且模型是解码器模型(不是编码器-解码器)

    if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
    
    • model_kwargs["inputs_embeds"] 的形状通常为 (batch_size, sequence_length, embed_dim)
    • model_kwargs["inputs_embeds"][0, :, 0] 取出第一个样本的所有时间步的第一个嵌入向量分量,形状为 (sequence_length,)
    • torch.ones_like(..., dtype=torch.int64) 创建一个与形状相同的全1张量,类型为 int64
    • .cumsum(0) 计算累积和,得到 [1, 2, 3, ..., sequence_length]
    • 减去1,得到 [0, 1, 2, ..., sequence_length - 1]
  • 如果存在 decoder_inputs_embeds 且模型是编码器-解码器模型

    elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:cache_position = (torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1)
    
    • 类似地,获取解码器的嵌入,计算位置编号。
  • 否则,使用 input_ids

    else:cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
    
    • input_ids[0, :] 取第一个样本的所有 token。
    • 通过创建全1张量,计算累积和并减1,得到位置编号。

总结cache_position 是一个长度为 sequence_length 的张量,表示每个位置的编号,从 0 开始。

步骤2:处理 past_key_values(缓存)
past_length = 0
if model_kwargs.get("past_key_values") is not None:cache = model_kwargs["past_key_values"]past_length = 0if not isinstance(cache, Cache):past_length = cache[0][0].shape[2]elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:past_length = cache.get_seq_length()# TODO: this is not torch.compile-friendly, find a work-around. If the cache is not empty,# end-to-end compilation will yield bad results because `cache_position` will be incorrect.if not is_torchdynamo_compiling():cache_position = cache_position[past_length:]
  • 首先,初始化 past_length = 0

  • 如果存在 past_key_values(缓存)

    • 获取缓存的长度 past_length

      • 如果缓存不是 Cache 类的实例,那么:

        past_length = cache[0][0].shape[2]
        
        • cache 是一个列表,包含每一层的 (key, value) 对。
        • cache[0][0] 是第一层的 key,其形状通常为 (batch_size, num_heads, past_length, head_dim)
        • cache[0][0].shape[2] 即为 past_length
      • 如果缓存是 Cache 类的实例,且具有 get_seq_length 方法:

        past_length = cache.get_seq_length()
        
        • 直接获取缓存的序列长度。
    • 调整 cache_position

      • 如果不存在编译(is_torchdynamo_compiling()False),则切片 cache_position,去掉已经缓存的部分:

        cache_position = cache_position[past_length:]
        
        • 这样,cache_position 只包含新生成的 token 的位置编号。
步骤3:更新 model_kwargs 并返回
model_kwargs["cache_position"] = cache_position
return model_kwargs
  • 将计算得到的 cache_position 添加到 model_kwargs 中,以供后续生成过程中使用。

示例说明 cache_position 的变化过程

示例1:无缓存的情况

假设:

  • input_ids

    input_ids = torch.tensor([[101, 102, 103],  # 样本1[201, 202, 203]   # 样本2
    ])  # 形状为 (batch_size=2, sequence_length=3)
    
  • model_kwargs = {}

  • 模型为解码器模型(非编码器-解码器),且未提供 inputs_embeds

步骤1:计算 cache_position

cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
  • input_ids[0, :][101, 102, 103]
  • torch.ones_like(...) 得到 [1, 1, 1]
  • .cumsum(0) 计算累积和,得到 [1, 2, 3]
  • 减去1,得到 [0, 1, 2]

得到:

cache_position = torch.tensor([0, 1, 2])  # 位置编号为 0, 1, 2

步骤2:不存在 past_key_values(缓存),无需调整 cache_position

步骤3:更新 model_kwargs 并返回

model_kwargs["cache_position"] = cache_position
  • model_kwargs 现在包含:

    model_kwargs = {"cache_position": torch.tensor([0, 1, 2])
    }
    
示例2:存在缓存的情况

假设:

  • input_ids 与之前相同。
  • model_kwargs 包含 past_key_values,其中缓存长度 past_length = 2

假设 past_key_values

缓存中保存了前两个位置的 keyvalue,因此 past_length = 2

步骤1:计算 cache_position

与之前相同,得到:

cache_position = torch.tensor([0, 1, 2])

步骤2:调整 cache_position

  • 由于 past_length = 2,需要从位置编号中移除前两个位置:

    cache_position = cache_position[past_length:]
    
    cache_position = cache_position[2:]  # 取索引为 2 及之后的元素
    
  • 得到:

    cache_position = torch.tensor([2])  # 仅剩下位置编号 2
    

步骤3:更新 model_kwargs 并返回

model_kwargs["cache_position"] = cache_position
  • model_kwargs 现在包含:

    model_kwargs = {"past_key_values": ...,    # 缓存内容(省略)"cache_position": torch.tensor([2])
    }
    

说明:

  • 缓存中已包含了位置 0 和 1 的 keyvalue,因此新的生成只需要处理位置 2 的 token。
  • cache_position[0, 1, 2] 变为 [2]
cache_position 的作用
  • 在 Transformer 模型中,位置编码对于捕获序列顺序信息非常重要
  • 使用缓存时,我们需要知道新生成的 token 应该对应哪些位置,以确保模型正确地生成下一个 token。
  • 如果不调整 cache_position,模型可能会错误地将新 token 视为之前的位置,从而导致生成错误。
总结
  • 函数 _get_initial_cache_position 的作用是根据 input_idspast_key_values,计算生成过程中新 token 的 位置编号 cache_position
  • 当没有缓存时,cache_position 从 0 开始,依次递增。
  • 当存在缓存时,需要考虑缓存的长度,调整 cache_position,以对应新添加的 token。
  • 通过上述示例可以看到,cache_position 根据是否存在缓存,以及缓存的长度,发生了相应的变化。

prepare_inputs_for_generation

函数概述

prepare_inputs_for_generation 函数的主要作用是为生成过程准备模型的输入。它根据当前的输入和模型的配置,处理和调整 input_idspast_key_values(缓存)、attention_mask 以及其他相关输入,以确保模型在生成过程中正确运作。

这个函数涉及到了以下几个关键步骤:

  1. 处理缓存位置 cache_position:根据是否存在缓存以及是否支持 Cache 类,计算或调整 cache_position
  2. 根据缓存调整输入序列:如果存在缓存,只需要处理未缓存的部分,因此对 input_idsinputs_embeds 等进行切片。
  3. 准备模型的基础输入:根据模型类型(解码器或编码器-解码器),设置正确的输入键,如 input_idsdecoder_input_ids
  4. 创建缺失的 position_ids:如果需要,为输入序列生成位置编码 position_ids
  5. 调整其他与输入长度相关的输入:如 token_type_ids 等,确保它们与输入序列的长度一致。
  6. 创建 4D 注意力掩码(可选):如果使用了 StaticCache,为提高性能,预先创建固定形状的因果掩码。
  7. 传递其他未初始化的关键字参数:如 use_cache 等。
  8. 移除生成过程中不需要的输入:如 labels

逐步详解

1. 初始化和处理 cache_position
model_inputs = {}
if self._supports_cache_class:model_inputs["cache_position"] = cache_position
elif cache_position is None:past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
  • model_inputs:用于存储准备好的模型输入。

  • 处理 cache_position

    • 如果模型支持 Cacheself._supports_cache_classTrue),则将 cache_position 添加到 model_inputs 中。
    • 如果模型不支持 Cache 类且 cache_positionNone,则根据 past_key_values 来计算 cache_position
      • past_length:缓存的长度,即 past_key_values 中缓存的序列长度。
      • cache_position:从 past_length 开始,到当前 input_ids 的序列长度,创建一个递增的序列,用于表示未处理的 token 的位置。
2. 根据缓存调整输入序列
if past_key_values is not None:model_inputs["past_key_values"] = past_key_valuesif inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]elif (inputs_embeds is not None  # Exception 1or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3):input_ids = input_ids[:, -cache_position.shape[0] :]elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)input_ids = input_ids[:, cache_position]
  • 目的:当存在缓存时,只需要处理未缓存的部分,所以需要根据 cache_positioninput_idsinputs_embeds 进行切片。

  • 异常情况

    • 异常1:当传入 inputs_embeds 时,input_ids 可能缺少元素,需要处理 inputs_embeds
    • 异常2:某些生成方法对 input_ids 进行了特殊的切片,这里不需要再处理。
    • 异常3:在同步 GPU 时,cache_position 可能超出 input_ids 的范围,此时需要特殊处理。
    • 异常4:如果传入了 inputs_embeds,并且 input_ids 长度为 0,则需要对 inputs_embeds 进行切片。
3. 准备模型的基础输入
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
if not self.config.is_encoder_decoder:if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:model_inputs[input_ids_key] = Nonemodel_inputs["inputs_embeds"] = inputs_embedselse:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)model_inputs["inputs_embeds"] = None
else:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
  • 根据模型类型,确定输入的键:

    • 解码器模型input_ids_key"input_ids"
    • 编码器-解码器模型input_ids_key"decoder_input_ids"
  • 准备输入

    • 如果传入了 inputs_embeds 且长度匹配
      • 仅在初始生成步骤使用 inputs_embeds
      • inputs_embeds 添加到 model_inputs,并将对应的 input_ids 设为 None
    • 否则
      • input_ids 克隆后添加到 model_inputs
      • 设置 inputs_embedsNone
4. 创建缺失的 position_ids
encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
attention_mask = (kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
)
attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
if (attention_mask is not Noneand kwargs.get(position_ids_key) is Noneand position_ids_key in set(inspect.signature(self.forward).parameters.keys())
):position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)kwargs[position_ids_key] = position_ids  # placed in kwargs for further processing (see below)
  • 创建 position_ids
    • 如果没有提供 position_ids,则根据 attention_mask 计算。
    • 计算方式:对 attention_mask 进行累加 (cumsum),得到位置索引。
    • 对于被填充的位置(attention_mask == 0),将 position_ids 设置为 1
5. 调整其他与输入长度相关的输入
for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:model_input = kwargs.get(model_input_name)if model_input is not None:if past_key_values is not None:current_input_length = (model_inputs["inputs_embeds"].shape[1]if model_inputs.get("inputs_embeds") is not Noneelse model_inputs[input_ids_key].shape[1])model_input = model_input[:, -current_input_length:]model_input = model_input.clone(memory_format=torch.contiguous_format)model_inputs[model_input_name] = model_input
  • 目的:确保其他需要与 input_ids 长度相同的输入(如 position_idstoken_type_ids)的长度匹配。

  • 处理方式

    • 如果存在缓存,且对应的输入存在,则对其进行切片,仅保留当前需要处理的部分。
6. 创建 4D 注意力掩码(可选)
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:if model_inputs["inputs_embeds"] is not None:batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shapedevice = model_inputs["inputs_embeds"].deviceelse:batch_size, sequence_length = model_inputs[input_ids_key].shapedevice = model_inputs[input_ids_key].device# 获取创建 4D 掩码的函数base_model = getattr(self, self.base_model_prefix, None)if base_model is None:causal_mask_creation_function = getattr(self, "_prepare_4d_causal_attention_mask_with_cache_position", None)else:causal_mask_creation_function = getattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None)# 如果存在该函数,则创建 4D 掩码if causal_mask_creation_function is not None:attention_mask = causal_mask_creation_function(attention_mask,sequence_length=sequence_length,target_length=past_key_values.get_max_cache_shape(),dtype=self.dtype,device=device,cache_position=cache_position,batch_size=batch_size,config=self.config,past_key_values=past_key_values,)
  • 目的:当使用了 StaticCache 时,为了提高性能,可以预先创建固定形状的 4D 注意力掩码。

  • 处理方式

    • 获取创建因果掩码的函数 _prepare_4d_causal_attention_mask_with_cache_position
    • 如果存在,则调用该函数创建 4D 注意力掩码。
7. 传递其他未初始化的关键字参数
for key, value in kwargs.items():if key not in model_inputs:model_inputs[key] = value
  • 目的:将所有未处理的关键字参数(如 use_cache)传递给模型输入。
8. 移除生成过程中不需要的输入
model_inputs.pop("labels", None)
  • 目的:在生成过程中,不需要 labels,因此将其从 model_inputs 中移除。
返回处理好的模型输入
return model_inputs

示例

场景设定
  • 模型类型:解码器模型(非编码器-解码器)。

  • 批次大小(batch_size):2

  • 初始输入 input_ids

    input_ids = torch.tensor([[101, 102, 103, 104, 105],  # 样本 1[201, 202, 203, 204, 205]   # 样本 2
    ])  # 形状为 (2, 5)
    
  • 注意力掩码 attention_mask

    attention_mask = torch.tensor([[1, 1, 1, 1, 1],  # 样本 1[1, 1, 1, 1, 1]   # 样本 2
    ])  # 形状为 (2, 5)
    
  • 缓存 past_key_values:由 DynamicCache() 创建的空 Cache 对象,无已缓存的序列。

  • 模型支持 Cache_supports_cache_class = True

  • 未传入 inputs_embeds

第一次调用 prepare_inputs_for_generation
调用参数
model_inputs = model.prepare_inputs_for_generation(input_ids=input_ids,past_key_values=past_key_values,attention_mask=attention_mask,cache_position=torch.arange(0, input_ids.shape[1], dtype=torch.long).to(input_ids.device),**model_kwargs
)
  • cache_position:由于模型支持 Cache 类,我们需要提供 cache_position。这里,我们直接使用 torch.arange 从 0 到 input_ids 的序列长度生成位置索引:

    cache_position = torch.arange(0, input_ids.shape[1], dtype=torch.long).to(input_ids.device)
    # 对于本例,cache_position = [0, 1, 2, 3, 4]
    
函数执行步骤
  1. 初始化 model_inputs 并处理 cache_position

    model_inputs = {}
    if self._supports_cache_class:model_inputs["cache_position"] = cache_position
    
    • cache_position 添加到 model_inputs
  2. 处理缓存相关的输入

    if past_key_values is not None:model_inputs["past_key_values"] = past_key_values# 由于 inputs_embeds 为 None,且不存在特殊情况,直接进入默认情况if input_ids.shape[1] != cache_position.shape[0]:input_ids = input_ids[:, cache_position]
    
    • 检查 input_ids 的长度是否与 cache_position 的长度一致。

      • 由于 input_ids.shape[1] = 5cache_position.shape[0] = 5,长度一致,无需切片。
    • past_key_values 添加到 model_inputs

  3. 准备基础模型输入

    input_ids_key = "input_ids"  # 因为是解码器模型if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:model_inputs[input_ids_key] = Nonemodel_inputs["inputs_embeds"] = inputs_embeds
    else:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)model_inputs["inputs_embeds"] = None
    
    • input_ids 克隆后添加到 model_inputs,并确保内存连续。
    • 设置 inputs_embedsNone
  4. 创建缺失的 position_ids

    encoder_attention_mask = None  # 因为是解码器模型
    attention_mask_key = "attention_mask"
    position_ids_key = "position_ids"if (attention_mask is not Noneand kwargs.get(position_ids_key) is Noneand position_ids_key in inspect.signature(self.forward).parameters
    ):position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)kwargs[position_ids_key] = position_ids
    
    • 计算 position_ids

      position_ids = torch.tensor([[0, 1, 2, 3, 4],  # 样本 1[0, 1, 2, 3, 4]   # 样本 2
      ], dtype=torch.long)
      
    • position_ids 添加到 kwargs

  5. 调整与输入长度相关的其他输入

    for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:model_input = kwargs.get(model_input_name)if model_input is not None:if past_key_values is not None:current_input_length = model_inputs[input_ids_key].shape[1]  # 5model_input = model_input[:, -current_input_length:]model_input = model_input.clone(memory_format=torch.contiguous_format)model_inputs[model_input_name] = model_input
    
    • 对于 position_ids,长度与 input_ids 一致,无需截断。
  6. 添加 attention_mask

    if attention_mask is not None:model_inputs[attention_mask_key] = attention_mask
    
    • attention_mask 添加到 model_inputs
  7. 传递未初始化的其他关键字参数

    • 如果有其他参数(如 use_cache),也会被添加到 model_inputs
  8. 移除 labels

    model_inputs.pop("labels", None)
    
    • 移除不需要的参数。
第一次调用后的 model_inputs
model_inputs = {"cache_position": torch.tensor([0, 1, 2, 3, 4], dtype=torch.long),"past_key_values": past_key_values,  # 空的 DynamicCache()"input_ids": input_ids.clone(memory_format=torch.contiguous_format),"inputs_embeds": None,"position_ids": torch.tensor([[0, 1, 2, 3, 4],[0, 1, 2, 3, 4]], dtype=torch.long),"attention_mask": attention_mask
}
模型生成新 token 后的第二次调用

假设模型在第一次生成后,生成了一个新 token,past_key_values 被更新。同时,cache_position 需要更新,input_ids 也需要更新。

更新的输入
  • 新的 input_ids(新生成的 token):

    new_tokens = torch.tensor([[106],  # 样本 1[206]   # 样本 2
    ])  # 形状为 (2, 1)
    
  • 更新后的 input_ids:只保留新生成的 token,因为过去的序列已在 past_key_values 中缓存

    input_ids = new_tokens  # 形状为 (2, 1)
    
  • 更新后的 attention_mask:针对新 token

    attention_mask = torch.ones_like(input_ids, dtype=torch.long)  # 形状为 (2, 1)
    
  • 更新的 past_key_values:包含了先前的缓存

  • 更新的 cache_position:增加新的位置索引

    cache_position = torch.tensor([5], dtype=torch.long)
    
第二次调用 prepare_inputs_for_generation
model_inputs = model.prepare_inputs_for_generation(input_ids=input_ids,past_key_values=past_key_values,attention_mask=attention_mask,cache_position=cache_position,**model_kwargs
)
函数执行步骤
  1. 初始化 model_inputs 并处理 cache_position

    model_inputs = {}
    if self._supports_cache_class:model_inputs["cache_position"] = cache_position
    
  2. 处理缓存相关的输入

    if past_key_values is not None:model_inputs["past_key_values"] = past_key_values# 检查输入长度与 cache_position 长度if input_ids.shape[1] != cache_position.shape[0]:input_ids = input_ids[:, -cache_position.shape[0]:]
    
    • input_ids.shape[1] = 1cache_position.shape[0] = 1,长度一致,无需切片。
  3. 准备基础模型输入

    input_ids_key = "input_ids"if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:model_inputs[input_ids_key] = Nonemodel_inputs["inputs_embeds"] = inputs_embeds
    else:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)model_inputs["inputs_embeds"] = None
    
    • input_ids 添加到 model_inputs
  4. 创建缺失的 position_ids

    encoder_attention_mask = None
    attention_mask_key = "attention_mask"
    position_ids_key = "position_ids"if (attention_mask is not Noneand kwargs.get(position_ids_key) is Noneand position_ids_key in inspect.signature(self.forward).parameters
    ):position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)kwargs[position_ids_key] = position_ids
    
    • 计算新的 position_ids

      position_ids = torch.tensor([[0],  # 样本 1[0]
      ], dtype=torch.long)
      
    • 由于这是新生成的 token,在全新的序列中位置索引为 0。

  5. 调整与输入长度相关的其他输入

    for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:model_input = kwargs.get(model_input_name)if model_input is not None:if past_key_values is not None:current_input_length = model_inputs[input_ids_key].shape[1]  # 1model_input = model_input[:, -current_input_length:]model_input = model_input.clone(memory_format=torch.contiguous_format)model_inputs[model_input_name] = model_input
    
    • 对于 position_ids,已是正确的长度。
  6. 添加 attention_mask

    if attention_mask is not None:model_inputs[attention_mask_key] = attention_mask
    
    • attention_mask 添加到 model_inputs
  7. 传递未初始化的其他关键字参数

    • 如果有其他参数,也会被添加到 model_inputs
  8. 移除 labels

    • 移除不需要的参数。
第二次调用后的 model_inputs
model_inputs = {"cache_position": torch.tensor([5], dtype=torch.long),  # 新的位置索引"past_key_values": past_key_values,  # 更新后的缓存"input_ids": input_ids.clone(memory_format=torch.contiguous_format),  # 形状为 (2, 1)"inputs_embeds": None,"position_ids": torch.tensor([[0],  # 样本 1[0]], dtype=torch.long),"attention_mask": attention_mask
}
总结

在第二次调用 prepare_inputs_for_generation 时,函数根据新的 cache_positionpast_key_values,对 input_idsposition_idsattention_mask 等进行了相应的处理,以确保模型只处理新生成的 token,并正确地应用位置编码和注意力掩码。

cache_position 的变化
  • 第一次调用

    • cache_position[0, 1, 2, 3, 4],对应初始输入序列的位置。
    • input_ids 使用完整的输入序列。
  • 第二次调用

    • cache_position 更新为 [5],表示新生成的 token 在整个序列中的位置。
    • input_ids 只包含新生成的 token。

_update_model_kwargs_for_generation

函数概述

def _update_model_kwargs_for_generation(self,outputs: ModelOutput,model_kwargs: Dict[str, Any],is_encoder_decoder: bool = False,num_new_tokens: int = 1,
) -> Dict[str, Any]:# 函数内容

这个函数的主要作用是在生成过程中,更新模型的关键字参数 model_kwargs,以便在下一步生成时使用。更新的内容包括:

  • 缓存 past_key_values:用于加速后续生成。
  • 注意力掩码 attention_maskdecoder_attention_mask:告诉模型在生成时应关注的序列长度。
  • 缓存位置 cache_position:用于跟踪当前缓存中的位置,特别是对于需要位置编码的模型。

这个函数通常在生成循环中每次生成一个新 token 之后被调用,以确保模型在下一次调用时具有正确的上下文和输入参数。

代码详解

1. 更新 past_key_values
for possible_cache_name in ALL_CACHE_NAMES:if possible_cache_name in outputs:# TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecatedif possible_cache_name in ("past_buckets_states", "mems"):cache_name = "past_key_values"else:cache_name = possible_cache_namemodel_kwargs[cache_name] = getattr(outputs, possible_cache_name)break
  • 作用:从模型的输出 outputs 中提取缓存 past_key_values,并更新 model_kwargs 中对应的键。

  • 解释

    • ALL_CACHE_NAMES 是一个包含可能的缓存名称的列表,例如 ["past_key_values", "mems", "past_buckets_states"],因为不同的模型可能使用不同的缓存名称。
    • 该循环检查 outputs 中是否存在这些缓存名称。
    • 如果找到:
      • 如果缓存名称是 "past_buckets_states""mems"(针对旧模型如 XLNet、Reformer),将其统一为 "past_key_values",以保持一致性。
      • 将缓存添加到 model_kwargs,以便在下一步生成时使用。
2. 更新 token_type_ids
if "token_type_ids" in model_kwargs:token_type_ids = model_kwargs["token_type_ids"]model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
  • 作用:如果模型使用了 token_type_ids(用于区分不同句子或段落的类型 ID,例如在 BERT 中),则需要在生成过程中更新它。

  • 解释

    • 获取当前的 token_type_ids
    • 将最后一个 token_type_id 复制一份,添加到序列末尾。这样,新生成的 token 将具有与前一个 token 相同的类型 ID。
3. 更新注意力掩码
  • 对于非编码器-解码器模型

    if not is_encoder_decoder:if "attention_mask" in model_kwargs:attention_mask = model_kwargs["attention_mask"]model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
    
  • 对于编码器-解码器模型

    else:if "decoder_attention_mask" in model_kwargs:decoder_attention_mask = model_kwargs["decoder_attention_mask"]model_kwargs["decoder_attention_mask"] = torch.cat([decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],dim=-1,)
    
  • 作用:在生成新 token 后,需要更新注意力掩码,让模型知道现在的序列长度增加了,应该在计算注意力时关注到新生成的 token。

  • 解释

    • 非编码器-解码器模型
      • 获取当前的 attention_mask,形状为 (batch_size, sequence_length)
      • 为每个样本添加一个值为 1 的新列,表示新生成的 token 是有效的(不是填充或被掩码的)。
      • 更新 model_kwargs["attention_mask"]
    • 编码器-解码器模型
      • 类似地,更新 decoder_attention_mask
4. 更新 cache_position
if model_kwargs.get("use_cache", True):model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
else:past_positions = model_kwargs.pop("cache_position")new_positions = torch.arange(past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype).to(past_positions.device)model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
  • 作用:更新缓存位置 cache_position,用于跟踪当前缓存中的位置,在生成过程中正确地处理位置编码。

  • 解释

    • use_cache:从 model_kwargs 中获取,默认为 True,表示模型在生成过程中使用缓存。
    • 如果 use_cacheTrue
      • 取当前的 cache_position 的最后一个值,增加 num_new_tokens(默认为 1)。
      • 这样,新的 cache_position 将是 [last_position + num_new_tokens]
    • 如果 use_cacheFalse
      • 弹出当前的 cache_position
      • 生成一个新的位置序列 new_positions,从 past_positions 的最后一个值加 1 开始,长度为 num_new_tokens
      • past_positionsnew_positions 拼接,更新 cache_position
5. 返回更新后的 model_kwargs
return model_kwargs
  • 返回更新后的 model_kwargs,以供下一步生成使用。

示例说明

假设我们在文本生成过程中,使用一个非编码器-解码器模型,批次大小为 2

初始状态
  • 模型输出 outputs:包含 past_key_values,这是模型生成过程中用于缓存的注意力 key 和 value。

  • model_kwargs

    model_kwargs = {"attention_mask": torch.tensor([[1, 1, 1, 1, 1],  # 样本 1 的注意力掩码[1, 1, 1, 1, 1]   # 样本 2 的注意力掩码]),  # 形状为 (2, 5)"cache_position": torch.tensor([0, 1, 2, 3, 4], dtype=torch.long),  # 初始位置# 可能还有其他参数
    }
    
  • is_encoder_decoderFalse,表示这是一个解码器-only(GPT 类)模型。

  • num_new_tokens1,每次生成一个新 token。

第一次调用
model_kwargs = _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)

步骤解析

  1. 更新 past_key_values

    • outputs 中提取 past_key_values,更新 model_kwargs
  2. 更新 attention_mask

    • 原始 attention_mask

      tensor([[1, 1, 1, 1, 1],  # 样本 1[1, 1, 1, 1, 1]   # 样本 2
      ])  # 形状为 (2, 5)
      
    • 新增一个长度为 1 的列,值为 1,表示新生成的 token:

      new_column = torch.ones((2, 1), dtype=torch.long)
      model_kwargs["attention_mask"] = torch.cat([attention_mask, new_column], dim=-1)
      # 更新后的 attention_mask:
      # tensor([
      #     [1, 1, 1, 1, 1, 1],  # 形状为 (2, 6)
      #     [1, 1, 1, 1, 1, 1]
      # ])
      
  3. 更新 cache_position

    • 原始 cache_position

      cache_position = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
      
    • 更新:

      model_kwargs["cache_position"] = cache_position[-1:] + num_new_tokens
      # 取最后一个值 4,加 1,得到 5
      # 更新后的 cache_position:
      # tensor([5])
      
  4. 返回更新后的 model_kwargs

更新后的 model_kwargs

model_kwargs = {"attention_mask": tensor([[1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1]]),  # 形状为 (2, 6)"cache_position": tensor([5], dtype=torch.long),"past_key_values": outputs.past_key_values,# 其他参数保持不变
}
第二次调用

假设模型又生成了一个新 token,再次调用该函数:

model_kwargs = _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)

步骤解析

  1. 更新 past_key_values

    • 更新 past_key_values
  2. 更新 attention_mask

    • 原始 attention_mask

      tensor([[1, 1, 1, 1, 1, 1],  # 形状为 (2, 6)[1, 1, 1, 1, 1, 1]
      ])
      
    • 新增一个长度为 1 的列:

      new_column = torch.ones((2, 1), dtype=torch.long)
      model_kwargs["attention_mask"] = torch.cat([attention_mask, new_column], dim=-1)
      # 更新后的 attention_mask:
      # tensor([
      #     [1, 1, 1, 1, 1, 1, 1],  # 形状为 (2, 7)
      #     [1, 1, 1, 1, 1, 1, 1]
      # ])
      
  3. 更新 cache_position

    • 原始 cache_position

      cache_position = torch.tensor([5], dtype=torch.long)
      
    • 更新:

      model_kwargs["cache_position"] = cache_position[-1:] + num_new_tokens
      # 取最后一个值 5,加 1,得到 6
      # 更新后的 cache_position:
      # tensor([6])
      
  4. 返回更新后的 model_kwargs

更新后的 model_kwargs

model_kwargs = {"attention_mask": tensor([[1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1]]),  # 形状为 (2, 7)"cache_position": tensor([6], dtype=torch.long),"past_key_values": outputs.past_key_values,# 其他参数保持不变
}
总结

通过上述两次调用,我们可以看到:

  • attention_mask 在每次生成新 token 后,都在序列末尾添加了一个值为 1 的新列,表示序列长度增加,模型应关注新的 token。
  • cache_position 跟踪当前的位置,每次生成一个新 token,位置索引增加 1

结论

该函数在生成循环中起到关键作用,确保模型在每个生成步骤都具有正确的缓存和输入参数,从而生成连贯的序列。通过更新 past_key_valuesattention_maskcache_position 等关键参数,模型能够高效地利用先前的计算,加速生成过程。

理解这个函数有助于深入掌握文本生成模型的工作原理,尤其是在实现自回归生成和缓存机制时。

_has_unfinished_sequences

详细解释:

该函数 _has_unfinished_sequences 的主要作用是在生成序列的过程中,判断是否需要继续生成,即是否还有未完成的序列。这在循环生成模型中非常重要,可以避免不必要的计算,提高效率。

参数说明:

  • this_peer_finishedbool):当前进程(或节点)是否已经完成了生成序列。
  • synced_gpusbool):是否在使用多 GPU 同步模式。如果为 True,表示多个 GPU 需要同步生成过程。
  • devicetorch.device):当前计算所在的设备(CPU 或 GPU)。
  • cur_lenOptional[int]):当前序列的长度。
  • max_lengthOptional[int]):生成序列的最大长度。

返回值:

  • bool:返回 True 表示仍有未完成的序列,需要继续生成;返回 False 表示所有序列都已完成,可以停止生成。

代码逻辑详解:

  1. 处理 torch.compile 的特殊情况:

    if is_torchdynamo_compiling():return cur_len < max_length
    
    • 背景

      • torch.compile(或 torchdynamo)目前不支持数据依赖的控制流,也就是说,无法根据运行时的数据(如模型的预测结果)来决定控制流(如循环、条件判断)。
    • 逻辑

      • 如果正在使用 torch.compile 编译,则无法在生成过程中根据序列是否完成来提前停止,只能按照固定的最大长度 max_length 来控制循环。
      • 因此,这里直接判断当前长度 cur_len 是否小于最大长度 max_length,如果小于则返回 True,表示需要继续生成。
    • 影响

      • 这种方式会失去根据序列生成的内容(如是否生成了 EOS 标记)来提前终止循环的能力。
  2. 正常情况下的处理逻辑:

    else:if synced_gpus:# 同步 GPU 的处理逻辑...elif this_peer_finished:return Falsereturn True
    
    • 说明

      • 如果未使用 torch.compile,则进入正常的处理逻辑。
  3. 处理同步 GPU 的情况:

    if synced_gpus:# 在 synced_gpus 模式下,`forward` 调用必须继续,直到所有 GPU 完成它们的序列。# 以下逻辑允许在所有节点完成生成序列时提前结束this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)# 如果当前节点已完成,发送 0.0,否则发送 1.0dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)# 所有节点都完成了吗?如果归约的和为 0.0,则表示所有节点都已完成if this_peer_finished_flag.item() == 0.0:return False
    
    • 背景

      • 在使用多 GPU 同步模式时,各个 GPU 需要同步进行生成过程,不能有的节点提前退出,否则会导致同步问题和潜在的死锁。
    • 逻辑

      • Step 1:创建一个标志变量 this_peer_finished_flag,表示当前节点是否完成。

        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
        
        • 如果当前节点已完成,则值为 0.0,未完成则为 1.0
        • 放在 device 上以确保在正确的设备上进行计算。
      • Step 2:使用分布式的 all_reduce 操作,将所有节点的标志变量求和。

        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        
        • all_reduce 会收集所有参与进程的 this_peer_finished_flag,并按照指定的操作(这里是求和 SUM)进行归约,然后将结果广播给所有节点。
        • 结果是所有节点的 this_peer_finished_flag 之和。
      • Step 3:判断是否所有节点都已完成。

        if this_peer_finished_flag.item() == 0.0:return False
        
        • 如果归约结果为 0.0,说明所有节点的 this_peer_finished_flag 都是 0.0,即所有节点都已完成生成。
        • 返回 False,表示不需要继续生成。
      • Step 4:否则,返回 True,表示仍有未完成的序列,需要继续生成。

    • 注意

      • 这样可以确保在多 GPU 同步模式下,所有节点同步地决定是否继续生成,防止有的节点提前退出导致同步问题。
  4. 处理未使用同步 GPU 且当前节点已完成的情况:

    elif this_peer_finished:return False
    
    • 如果未使用同步 GPU(synced_gpus == False),并且当前节点已完成生成,则可以直接返回 False,不需要继续。
  5. 默认返回 True

    return True
    
    • 如果上述条件都不满足,表示仍有未完成的序列,需要继续生成,返回 True


http://www.mrgr.cn/news/97008.html

相关文章:

  • [C语言入门] 结构体
  • 接口自动化学习二:session自动管理cookie
  • Maven+Spring实现后端开发
  • Pytorch中预置数据集的加载方式
  • 数据结构学习
  • 大模型学习二:DeepSeek R1+蒸馏模型组本地部署与调用
  • 【MyBatis】深入解析 MyBatis:关于注解和 XML 的 MyBatis 开发方案下字段名不一致的的查询映射解决方案
  • 操作系统(二):实时系统介绍与实例分析
  • python match case语法
  • Vue3 Pinia Store 新建store示例、使用store示例
  • 【大模型】SpringBoot整合LangChain4j实现RAG检索实战详解
  • ros2--gazebo--launch
  • 【gdutthesis模板】章节标题有英文解决方案
  • Tmux 核心操作速查指南
  • 【c++深入系列】:类与对象详解(中)
  • STM32单片机入门学习——第12节: [5-2]对射式红外传感器计次旋转编码器计次
  • 基于yolo11的BGA图像目标检测
  • 动、静态创建任务
  • MySQL - 事务隔离级别和锁的机制
  • WPF设计学习记录滴滴滴4