GenerationMixin:_sample方法(GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH)
_sample
_sample
是一个用于文本生成的函数。该函数使用 多项式采样(multinomial sampling) 来生成序列,可用于文本解码器、文本到文本、语音到文本以及视觉到文本的模型。
函数定义
def _sample(self,input_ids: torch.LongTensor,logits_processor: LogitsProcessorList,stopping_criteria: StoppingCriteriaList,generation_config: GenerationConfig,synced_gpus: bool,streamer: Optional["BaseStreamer"],**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:# 函数主体
参数详解
-
input_ids (
torch.LongTensor
):- 形状为
(batch_size, sequence_length)
的输入张量。 - 用作生成的提示序列。
- 形状为
-
logits_processor (
LogitsProcessorList
):LogitsProcessorList
的实例。- 包含了多个
LogitsProcessor
,用于在每个生成步骤中修改模型的输出 logits,以实现诸如温度缩放、重复惩罚等功能。
-
stopping_criteria (
StoppingCriteriaList
):StoppingCriteriaList
的实例。- 包含了多个
StoppingCriteria
,用于判断生成过程是否应该停止,例如达到最大长度或生成了结束符号等。
-
generation_config (
GenerationConfig
):- 生成配置,包含了生成过程中的各种参数,如最大长度、是否进行采样等。
-
synced_gpus (
bool
):- 是否在多 GPU 环境下同步生成循环,避免在某些 GPU 生成完成后出现死锁。
-
streamer (
BaseStreamer
, 可选):- 流处理器对象,用于在生成过程中实时处理生成的序列。
-
model_kwargs:
- 其他模型特定的关键字参数,将传递给模型的
forward
函数。 - 如果模型是编码器-解码器模型,
model_kwargs
应该包括encoder_outputs
。
- 其他模型特定的关键字参数,将传递给模型的
返回值
- GenerateDecoderOnlyOutput、GenerateEncoderDecoderOutput 或
torch.LongTensor
:- 根据配置,返回包含生成的序列和其他信息的对象,或者仅返回生成的 token 序列。
函数流程详解
1. 初始化变量
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
pad_token_id
:用于填充生成完成后的序列。- 输出控制参数:
output_attentions
:是否返回注意力权重。output_hidden_states
:是否返回隐藏状态。output_scores
:是否返回处理后的得分。output_logits
:是否返回原始 logits。
return_dict_in_generate
:是否在生成中返回字典形式的结果。max_length
:生成的最大长度。has_eos_stopping_criteria
:是否存在基于结束符的停止条件。do_sample
:是否进行采样,True
表示采样,False
表示贪心搜索。
2. 初始化保存生成过程中的信息的元组
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
- 根据配置,初始化用于保存生成过程中各项信息的元组。
3. 处理编码器-解码器模型的编码器输出
if return_dict_in_generate and self.config.is_encoder_decoder:encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else Noneencoder_hidden_states = (model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None)
- 如果模型是编码器-解码器模型,并且需要返回注意力权重或隐藏状态,则从
encoder_outputs
中获取。
4. 初始化生成过程的变量
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
batch_size
:批次大小。cur_len
:当前序列长度。this_peer_finished
:在多 GPU 环境下,当前设备是否已完成生成。unfinished_sequences
:用于跟踪哪些序列尚未生成完成。
5. 获取初始缓存位置
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
- 准备缓存机制,用于加速生成过程。
- _get_initial_cache_position
6. 准备模型的前向调用函数
model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), Cache):is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cacheis_compileable = is_compileable and not self.generation_config.disable_compileif is_compileable and (self.device.type == "cuda" or generation_config.compile_config._compile_all_devices):os.environ["TOKENIZERS_PARALLELISM"] = "0"model_forward = self.get_compiled_call(generation_config.compile_config)
- 如果缓存可编译且设备支持,则使用编译后的模型前向函数,以提高性能。
7. 进入生成循环
is_prefill = True
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):# 循环体
- 检查是否存在未完成的序列且未达到最大长度,进入生成循环。
8. 准备模型输入
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
- 调用
prepare_inputs_for_generation
准备模型输入。 - 根据配置,添加控制输出的参数。
- prepare_inputs_for_generation
9. 执行模型前向传播
if is_prefill:outputs = self(**model_inputs, return_dict=True)is_prefill = False
else:outputs = model_forward(**model_inputs, return_dict=True)
- 第一次迭代使用默认的模型调用,以正确初始化缓存。
- 后续迭代可能使用编译后的模型前向函数。
10. 更新模型参数
model_kwargs = self._update_model_kwargs_for_generation(outputs,model_kwargs,is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:continue
- 更新
model_kwargs
,包含新的缓存past_key_values
等。 - 如果在多 GPU 环境中当前设备已完成生成,则跳过后续步骤。
- _update_model_kwargs_for_generation
11. 处理模型输出的 logits
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)
这行代码的目的是从 outputs.logits
中提取最后一个时间步(即最后一列)的所有令牌的预测得分,并进行复制和转换为浮点类型,存储在 next_token_logits
变量中。
代码解释
-
outputs.logits[:, -1, :]
:outputs.logits
是一个三维张量,通常代表模型输出的预测得分(logits),形状一般为[batch_size, sequence_length, vocab_size]
。[:, -1, :]
是一个切片操作,表示选择所有批次的最后一个时间步的所有令牌得分。具体来说::
切片表示选择所有批次(第一维度)。-1
表示选择最后一个时间步(第二维度)。:
切片表示选择所有令牌的得分(第三维度)。
-
.clone()
:- 创建一个新张量,它是
outputs.logits[:, -1, :]
的副本。 - 使用
clone()
是为了避免原始张量的意外改变影响到next_token_logits
。
- 创建一个新张量,它是
-
.float()
:- 将克隆后的张量转换为浮点类型,以确保后续计算的精度。
举例说明
假设 outputs.logits
是如下的一个三维张量,形状为 [2, 4, 5]
(批次大小为2,序列长度为4,词汇大小为5):
outputs.logits = [[[0.1, 0.2, 0.3, 0.4, 0.5],[0.6, 0.7, 0.8, 0.9, 1.0],[1.1, 1.2, 1.3, 1.4, 1.5],[1.6, 1.7, 1.8, 1.9, 2.0],],[[0.2, 0.3, 0.4, 0.5, 0.6],[0.7, 0.8, 0.9, 1.0, 1.1],[1.2, 1.3, 1.4, 1.5, 1.6],[1.7, 1.8, 1.9, 2.0, 2.1],],
]
根据代码,我们执行以下步骤:
- 提取最后一个时间步的所有令牌得分:
next_token_logits = [[1.6, 1.7, 1.8, 1.9, 2.0],[1.7, 1.8, 1.9, 2.0, 2.1] ]
描述到切片操作后,我们从 outputs.logits
张量中获取每个批次的最后一个时间步(第四个时间步)的令牌得分。
- 克隆并转换为浮点类型:
clone()
操作创建新张量,且.float()
将张量元素转为浮点型(假设原始元素非浮点类型)。
最终 next_token_logits
将包含如下张量:
next_token_logits (float):
[[1.6, 1.7, 1.8, 1.9, 2.0],[1.7, 1.8, 1.9, 2.0, 2.1]
]
这用于下一步的处理,例如为模型生成下一个令牌选择提供概率分布(logits)。
12. 对 logits 进行预处理
next_token_scores = logits_processor(input_ids, next_token_logits)
这行代码的目的是通过 logits_processor
对 next_token_logits
进行处理,以生成 next_token_scores
。 logits_processor
是一个函数或包含一系列函数的对象,它对模型输出的 logits 进行某种自定义变换(如归一化、过滤或修改等)。我们来详细解释,并举例说明。
代码解释
-
logits_processor
:logits_processor
是一个函数或 callable 对象,用于对模型输出的 logits 进行处理。处理逻辑可以是许多策略中的一种,比如:- 降低或惩罚某些特定令牌的分数。
- 应用特定的归一化或过滤。
- 结合一些上下文信息对 logits 进行调整。
- 它接受两个参数:
input_ids
和next_token_logits
。
-
input_ids
:input_ids
是模型当前接收到的输入序列,通常用于提供上下文或约束以在logits_processor
中进行处理。
-
next_token_logits
:- 这些是从模型输出中提取并准备处理的原始 logit 值。
-
next_token_scores
:logits_processor
返回处理后的结果,通常表示对每个可能的下一个令牌的更新得分。这个输出用于影响对下一个令牌的选择过程。
举例说明
假设有如下的 input_ids
和 next_token_logits
:
input_ids
:表示当前序列,例如[101, 1045, 2572, 1037, 102]
,其中每个数字都是代表某个词的唯一标识。next_token_logits
:例如为[2.0, 1.5, 3.0, 0.5, -1.0]
,其中每个值是对应给定词汇表中一个可能的下一个词的预测得分(logit)。
假设 logits_processor
的简单处理逻辑是将所有负数值设置为 -inf
,来避免选择这些令牌,如此进行如下调整:
- 原始
next_token_logits
为[2.0, 1.5, 3.0, 0.5, -1.0]
- 处理后的
next_token_scores
为[2.0, 1.5, 3.0, 0.5, -inf]
在这个例子中,logits_processor
的机制是简单的限制,导致原来的负分数被抑制,使得生成更可能选择正得分的令牌。
处理后的 next_token_scores
可以用于进一步的决策,例如通过 softmax 转换为概率分布,以便从词汇表中选出下一个最有可能的令牌。
13. 根据需要保存生成过程中的信息
if return_dict_in_generate:if output_scores:scores += (next_token_scores,)if output_logits:raw_logits += (next_token_logits,)if output_attentions:decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,))if self.config.is_encoder_decoder:cross_attentions += (outputs.cross_attentions,)if output_hidden_states:decoder_hidden_states += ((outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,))
- 根据配置,保存处理过的得分、原始 logits、注意力权重和隐藏状态。
14. 选择下一个 token
if do_sample:probs = nn.functional.softmax(next_token_scores, dim=-1)next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:next_tokens = torch.argmax(next_token_scores, dim=-1)
这段代码实现了从给定的 next_token_scores
中选择下一个令牌(token)。根据标志 do_sample
的值不同,它提供两种选择策略:采样和贪心选择。让我们逐步解释这些代码,并举例说明。
代码解释
-
do_sample
:- 这是一个布尔值标志,用于决定是否进行采样选择。
- 如果
do_sample
为True
,则使用概率采样来选择令牌。 - 如果
do_sample
为False
,则使用贪心策略,选择具有最高得分的令牌。
-
probs = nn.functional.softmax(next_token_scores, dim=-1)
:- 当
do_sample
为True
时,这行代码将next_token_scores
转换为概率分布。 softmax
函数确保所有得分转换为非负值且总和为1,使其成为概率分布。
- 当
-
torch.multinomial(probs, num_samples=1).squeeze(1)
:multinomial
用于从概率分布probs
中进行采样,返回一个样本。num_samples=1
说明每次只选择一个令牌。squeeze(1)
用于去掉维度为1的多余维度,使返回结果变为基本的1D张量。
-
torch.argmax(next_token_scores, dim=-1)
:- 当
do_sample
为False
时,选择得分最高的令牌。 argmax
返回具有最大得分的索引,即最有可能的下一个令牌。
- 当
举例说明
假设 next_token_scores
是如下的一维张量,表示模型生成下一个令牌时的logits:
next_token_scores = [2.0, 1.5, 3.0, 0.5, -inf]
-
在
do_sample=True
的情况下:- 应用
softmax
:probs = softmax([2.0, 1.5, 3.0, 0.5, -inf]) = [0.2138, 0.1421, 0.5820, 0.0621, 0.0]
- 用
multinomial
从probs
中采样可能得到next_tokens
:next_tokens = [2] # 假设采样到索引2(3.0的概率最大,最有可能)
- 应用
-
在
do_sample=False
的情况下:- 使用
argmax
:next_tokens = [2] # 因为索引2对应的得分3.0最大
- 使用
在两种策略中,next_tokens
存储的都是选定的下一个令牌的索引,依据其选择策略可以应用于生成或预测序列的新部分。采样策略尤其适用于生成更有多样性的序列,而贪心策略则保证每一步都选择当前最优决策。
15. 处理已完成的序列
if has_eos_stopping_criteria:next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
这段代码处理了在生成文本时,当遇到结束标记(End of Sentence, EOS)后,确保接下来的令牌是填充标记(padding token)。它通过使用布尔掩码 unfinished_sequences
来实现这一点。让我们分步解释一下,并举例说明。
代码解释
-
next_tokens * unfinished_sequences
:next_tokens
是当前选择的下一个令牌。unfinished_sequences
是一个布尔张量,指示哪些序列尚未完成。- 乘法操作会保持未完成的序列的
next_tokens
值。
-
pad_token_id * (1 - unfinished_sequences)
:pad_token_id
是填充标记的ID。(1 - unfinished_sequences)
会将布尔值反转,这样完成的序列会变成1,未完成的序列变成0。- 乘法操作会将已完成的序列的下一令牌设为
pad_token_id
。
-
结合运算:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
:
这条语句最终会确保,对于未完成的序列,保留其原始的next_tokens
;对于已完成的序列,将next_tokens
设置为pad_token_id
。
举例说明
假设有以下变量:
next_tokens
是当前选择的令牌,假设其值为[5, 6, 7]
。unfinished_sequences
是一个布尔张量,表示哪些序列未完成。例如[1, 0, 1]
(1 表示未完成,0 表示完成)。pad_token_id
是填充标记的ID,例如为0
。
示例数据:
next_tokens = [5, 6, 7]
unfinished_sequences = [1, 0, 1]
pad_token_id = 0
计算步骤:
-
部分1:
next_tokens * unfinished_sequences
:[5, 6, 7] * [1, 0, 1]
:
= [5 * 1, 6 * 0, 7 * 1] = [5, 0, 7]
-
部分2:
pad_token_id * (1 - unfinished_sequences)
:pad_token_id = 0
:1 - unfinished_sequences = [0, 1, 0]
:0 * [0, 1, 0]
:
= [0 * 0, 0 * 1, 0 * 0] = [0, 0, 0]
-
结合结果:
[5, 0, 7] + [0, 0, 0]
:
= [5 + 0, 0 + 0, 7 + 0] = [5, 0, 7]
最终得到的 next_tokens
为 [5, 0, 7]
。在这个例子中,未完成的序列(第1和第3项)保留了原来的令牌值,而已完成的序列(第2项)的下一令牌被设置为填充标记 0
。
这个逻辑确保在生成序列过程中,一旦某个序列到达结束标记(EOS),其后生成的所有令牌都会变成填充标记,以防止再度继续生成。
16. 更新输入序列
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
这段代码用于将新生成的令牌 next_tokens
添加到当前的 input_ids
,更新为下一步模型输入。这是文本生成过程中重要的一步,确保每次生成的输出都成为下一次生成的输入。让我们详细解释这段代码,并举例说明。
代码解释
-
next_tokens[:, None]
:next_tokens
是当前生成的令牌的张量。[:, None]
用作添加一个新的维度,达到扩展张量维度的效果,使其从一维变为二维。具体来说:- 如果
next_tokens
是形状(batch_size,)
,那么[:, None]
将其转变为(batch_size, 1)
。
- 如果
- 这样处理让
next_tokens
可以按列而不是按行进行连接。
-
torch.cat([...], dim=-1)
:torch.cat
函数用于沿着指定维度进行张量连接。dim=-1
指定连接操作沿着最后一个维度进行,也就是在这里,可以理解为列连接。- 将
next_tokens
添加到input_ids
的每一批次的最后。
举例说明
假设 input_ids
和 next_tokens
如下:
input_ids
是之前生成的文本的标识,假如形状是[batch_size, sequence_length]
:
input_ids = [[101, 102],[201, 202]]
next_tokens
是本次生成的新令牌:
next_tokens = [103, 203]
操作步骤:
- 应用
[:, None]
:next_tokens[:, None]
使next_tokens
从一维变为二维:
next_tokens = [[103],[203]]
- 连接操作:
torch.cat([input_ids, next_tokens[:, None]], dim=-1)
将next_tokens
的每批次令牌添加到input_ids
的最后:
input_ids = [[101, 102, 103],[201, 202, 203]]
16. 更新状态
if streamer is not None:streamer.put(next_tokens.cpu())
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1
这段代码用于更新未完成序列的状态,并通过 this_peer_finished
标识当前批次内的序列是否全部完成。下面就逐步解释这段代码,并提供一个例子进行说明。
代码解释
-
unfinished_sequences
:- 这是一个布尔张量,表示哪些序列仍然未完成。通常其元素为1表示未完成,0表示已完成。
-
stopping_criteria(input_ids, scores)
:stopping_criteria
是一个函数,用于判断是否有序列已经达到停止条件。- 接受当前的
input_ids
和scores
(通常是得分或概率)作为参数。 - 返回一个布尔张量,标识哪些序列达到了停止条件。
-
~stopping_criteria(input_ids, scores)
:~
是按位取反操作,将stopping_criteria
的结果反转。- 想达到停止条件的序列标记为
0
,未达到停止条件的标记为1
。
-
unfinished_sequences & ~stopping_criteria(input_ids, scores)
:- 按位与操作更新未完成的序列数。
- 如果一个序列已经达到停止条件,则对应的
unfinished_sequences
将设置为0。 - 如果序列仍未完成,则保持为1。
-
this_peer_finished = unfinished_sequences.max() == 0
:- 检查
unfinished_sequences
中的所有值是否为0。 - 如果
unfinished_sequences.max()
为0,表示所有序列都完成了。 this_peer_finished
为真时,所有序列在当前批次内都完成了。
- 检查
举例说明
假设有以下情境:
unfinished_sequences
是初始的未完成序列状态:
unfinished_sequences = [1, 1, 1]
stopping_criteria(input_ids, scores)
返回一个布尔张量:
stopping_criteria_result = [0, 1, 0]
这是stopping_criteria
的输出,表示第二个序列达到停止条件。
操作步骤:
- 取反操作:
~stopping_criteria_result = [1, 0, 1]
- 按位与更新:
unfinished_sequences & ~stopping_criteria_result = [1 & 1, 1 & 0, 1 & 1] = [1, 0, 1]
表示第一个和最后一个序列仍未完成,第二个序列已完成。
- 检查是否完成:
unfinished_sequences.max() == 0
返回False
,因为unfinished_sequences
中有非零元素。
17. 清理变量防止内存泄漏
del outputs
- 删除
outputs
,防止在下一次迭代中占用不必要的内存。
18. 结束循环
- 当所有序列都完成生成,或达到最大长度时,退出循环。
19. 结束流处理
if streamer is not None:streamer.end()
- 如果使用了
streamer
,调用其end()
方法,表示生成结束。
20. 返回生成结果
if return_dict_in_generate:if self.config.is_encoder_decoder:return GenerateEncoderDecoderOutput(sequences=input_ids,scores=scores,logits=raw_logits,encoder_attentions=encoder_attentions,encoder_hidden_states=encoder_hidden_states,decoder_attentions=decoder_attentions,cross_attentions=cross_attentions,decoder_hidden_states=decoder_hidden_states,past_key_values=model_kwargs.get("past_key_values"),)else:return GenerateDecoderOnlyOutput(sequences=input_ids,scores=scores,logits=raw_logits,attentions=decoder_attentions,hidden_states=decoder_hidden_states,past_key_values=model_kwargs.get("past_key_values"),)
else:return input_ids
- 根据配置,返回包含生成序列和其他信息的对象,或仅返回生成的序列。
_sample
内部的while
循环:
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):# 循环体
这个循环的执行次数取决于以下因素:
- 当前序列长度 (
cur_len
):生成开始时输入序列的长度。 - 最大生成长度 (
max_length
):生成的序列允许的最大长度。 - 停止条件 (
stopping_criteria
):定义了何时停止生成,例如生成了结束符号(EOS token)或达到最大长度。 - 模型在每一步生成的输出:每个时间步生成的token是否满足停止条件。
因此,while
循环的执行次数并不是固定的,它可能会从一次到多次,具体次数取决于上述因素的组合。
循环为何会执行多次?
生成文本的过程是逐步进行的,在每个时间步,模型基于当前的输入序列预测下一个token,并将其添加到序列中。while
循环控制这个生成过程,直到满足停止条件为止。
主要原因包括:
- 序列长度的限制:如果生成的序列尚未达到
max_length
,并且有序列尚未完成,则循环继续。 - 停止条件的影响:如果定义了停止条件(例如生成了EOS token),一旦满足条件的序列就被标记为完成,但其他未完成的序列需要继续生成。
- 批量生成的同步:在批量生成(batch generation)中,需要确保批次中的所有序列都完成生成,循环会一直执行,直到所有序列都完成或达到最大长度。
循环执行次数的详细解释
1. 循环初始化
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
batch_size
:批次大小,即一次生成的序列数量。cur_len
:当前序列的长度。this_peer_finished
:用于多GPU情况下,指示当前GPU是否完成生成。unfinished_sequences
:一个形状为(batch_size,)
的张量,用于跟踪哪些序列尚未完成。
2. 循环条件
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):# 循环体
self._has_unfinished_sequences()
:这是一个函数,用于判断是否还有未完成的序列需要继续生成。- 循环继续的条件:
- 当前未达到最大长度 (
cur_len < max_length
)。 - 存在未完成的序列 (
unfinished_sequences
中存在值为1的元素)。
- 当前未达到最大长度 (
3. 循环内部
在循环内部,模型会逐步生成新的token,并更新相关的状态。
-
模型前向传播:根据当前的
input_ids
,模型预测下一个token的概率分布(logits)。 -
处理logits:通过
logits_processor
对logits进行处理,例如应用温度、top-k或top-p采样等。 -
选择下一个token:
- 如果是采样 (
do_sample=True
),则根据概率分布进行采样。 - 如果是贪心搜索 (
do_sample=False
),则选择概率最高的token。
- 如果是采样 (
-
更新序列:将新生成的token添加到
input_ids
。 -
检查停止条件:
- 使用
stopping_criteria
检查哪些序列已满足停止条件。 - 更新
unfinished_sequences
,将已完成的序列标记为0。
- 使用
-
更新循环变量:增加
cur_len
的值,继续下一次迭代。
实例说明
假设以下设定:
- 批次大小 (
batch_size
):2 - 初始序列长度 (
cur_len
):5 - 最大长度 (
max_length
):10 - 停止条件:生成了EOS token(假设其ID为
eos_token_id = 2
)
初始状态:
input_ids = [[101, 10, 23, 45, 67], # 序列1[101, 11, 34, 56, 78] # 序列2
]
unfinished_sequences = [1, 1] # 两个序列都未完成
cur_len = 5
第一次循环
- 模型生成下一个token:
- 序列1生成token 89
- 序列2生成token 90
- 更新
input_ids
:
input_ids = [[101, 10, 23, 45, 67, 89],[101, 11, 34, 56, 78, 90]
]
cur_len = 6
- 检查停止条件:
- 没有序列生成EOS token。
unfinished_sequences
保持为[1, 1]
第二次循环
- 模型生成下一个token:
- 序列1生成token 2(EOS token)
- 序列2生成token 91
- 更新
input_ids
:
input_ids = [[101, 10, 23, 45, 67, 89, 2],[101, 11, 34, 56, 78, 90, 91]
]
cur_len = 7
- 检查停止条件:
- 序列1生成了EOS token,将其标记为完成。
- 更新
unfinished_sequences
为[0, 1]
第三次循环
- 模型生成下一个token:
- 序列1已完成,不再生成,填充
pad_token_id
- 序列2生成token 2(EOS token)
- 序列1已完成,不再生成,填充
- 更新
input_ids
:
input_ids = [[101, 10, 23, 45, 67, 89, 2, pad_token_id],[101, 11, 34, 56, 78, 90, 91, 2]
]
cur_len = 8
- 检查停止条件:
- 序列2生成了EOS token,将其标记为完成。
- 更新
unfinished_sequences
为[0, 0]
循环结束
- 所有序列都完成了生成,
unfinished_sequences
全为0。 - 循环条件不再满足,退出循环。
总结循环次数
- 循环次数:3次
- 最终序列长度:8(因为
cur_len
从5增加到8)
为什么会循环这么多次?
- 序列长度增长:每次循环,
cur_len
增加1。 - 停止条件未满足:只要有序列未生成EOS token,
unfinished_sequences
中就有1,循环继续。 - 需要等待所有序列完成:在批量生成中,即使某些序列提前完成,循环也会继续,直到所有序列都完成或达到最大长度。
另一种情况的实例
假设模型在默认情况下不太可能生成EOS token,或者没有设置停止条件。
设定:
- 批次大小 (
batch_size
):2 - 初始序列长度 (
cur_len
):5 - 最大长度 (
max_length
):10 - 未设置停止条件(或模型未生成EOS token)
循环过程:
- 从
cur_len
=5开始,每次循环cur_len
增加1。 - 没有序列会被标记为完成,
unfinished_sequences
始终为[1, 1]
。 - 循环持续到
cur_len
达到max_length
。 - 循环次数:5次(从
cur_len
=5到cur_len
=10)
总结
- 循环次数依赖于序列是否完成:只要有未完成的序列,且未达到最大长度,循环就会继续。
- 最大循环次数:
max_length - cur_len_initial
,即可能的最大生成步数。 - 最小循环次数:如果初始序列已经满足停止条件,循环可能不会执行。
_get_initial_cache_position
函数功能概述
_get_initial_cache_position
是一个用于计算 cache_position
的函数。cache_position
用于跟踪生成过程中每个位置的 token 编号,特别是在使用缓存(past_key_values
)的情况下,它确保生成的新 token 的位置编号与缓存中的位置正确对应。
在生成文本的过程中,Transformer 模型可以使用缓存的 key 和 value(past_key_values
),以加速生成。这些缓存通常来自于模型之前的计算,即已经生成的序列。为了正确处理缓存,需要调整新生成的 token 的位置编号,这就是 cache_position
的作用。
代码详解
def _get_initial_cache_position(self, input_ids, model_kwargs):"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:cache_position = (torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1)else:cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1past_length = 0if model_kwargs.get("past_key_values") is not None:cache = model_kwargs["past_key_values"]past_length = 0if not isinstance(cache, Cache):past_length = cache[0][0].shape[2]elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:past_length = cache.get_seq_length()# TODO: this is not torch.compile-friendly, find a work-around. If the cache is not empty,# end-to-end compilation will yield bad results because `cache_position` will be incorrect.if not is_torchdynamo_compiling():cache_position = cache_position[past_length:]model_kwargs["cache_position"] = cache_positionreturn model_kwargs
步骤1:计算 cache_position
首先,根据是否存在嵌入向量,以及模型的类型(解码器或编码器-解码器),计算 cache_position
。
-
如果存在
inputs_embeds
且模型是解码器模型(不是编码器-解码器):if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
model_kwargs["inputs_embeds"]
的形状通常为(batch_size, sequence_length, embed_dim)
。model_kwargs["inputs_embeds"][0, :, 0]
取出第一个样本的所有时间步的第一个嵌入向量分量,形状为(sequence_length,)
。torch.ones_like(..., dtype=torch.int64)
创建一个与形状相同的全1张量,类型为int64
。.cumsum(0)
计算累积和,得到[1, 2, 3, ..., sequence_length]
。- 减去1,得到
[0, 1, 2, ..., sequence_length - 1]
。
-
如果存在
decoder_inputs_embeds
且模型是编码器-解码器模型:elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:cache_position = (torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1)
- 类似地,获取解码器的嵌入,计算位置编号。
-
否则,使用
input_ids
:else:cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
input_ids[0, :]
取第一个样本的所有 token。- 通过创建全1张量,计算累积和并减1,得到位置编号。
总结:cache_position
是一个长度为 sequence_length
的张量,表示每个位置的编号,从 0
开始。
步骤2:处理 past_key_values
(缓存)
past_length = 0
if model_kwargs.get("past_key_values") is not None:cache = model_kwargs["past_key_values"]past_length = 0if not isinstance(cache, Cache):past_length = cache[0][0].shape[2]elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:past_length = cache.get_seq_length()# TODO: this is not torch.compile-friendly, find a work-around. If the cache is not empty,# end-to-end compilation will yield bad results because `cache_position` will be incorrect.if not is_torchdynamo_compiling():cache_position = cache_position[past_length:]
-
首先,初始化
past_length = 0
。 -
如果存在
past_key_values
(缓存):-
获取缓存的长度
past_length
:-
如果缓存不是
Cache
类的实例,那么:past_length = cache[0][0].shape[2]
cache
是一个列表,包含每一层的(key, value)
对。cache[0][0]
是第一层的key
,其形状通常为(batch_size, num_heads, past_length, head_dim)
。cache[0][0].shape[2]
即为past_length
。
-
如果缓存是
Cache
类的实例,且具有get_seq_length
方法:past_length = cache.get_seq_length()
- 直接获取缓存的序列长度。
-
-
调整
cache_position
:-
如果不存在编译(
is_torchdynamo_compiling()
为False
),则切片cache_position
,去掉已经缓存的部分:cache_position = cache_position[past_length:]
- 这样,
cache_position
只包含新生成的 token 的位置编号。
- 这样,
-
-
步骤3:更新 model_kwargs
并返回
model_kwargs["cache_position"] = cache_position
return model_kwargs
- 将计算得到的
cache_position
添加到model_kwargs
中,以供后续生成过程中使用。
示例说明 cache_position
的变化过程
示例1:无缓存的情况
假设:
-
input_ids
:input_ids = torch.tensor([[101, 102, 103], # 样本1[201, 202, 203] # 样本2 ]) # 形状为 (batch_size=2, sequence_length=3)
-
model_kwargs = {}
-
模型为解码器模型(非编码器-解码器),且未提供
inputs_embeds
。
步骤1:计算 cache_position
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
input_ids[0, :]
为[101, 102, 103]
。torch.ones_like(...)
得到[1, 1, 1]
。.cumsum(0)
计算累积和,得到[1, 2, 3]
。- 减去1,得到
[0, 1, 2]
。
得到:
cache_position = torch.tensor([0, 1, 2]) # 位置编号为 0, 1, 2
步骤2:不存在 past_key_values
(缓存),无需调整 cache_position
步骤3:更新 model_kwargs
并返回
model_kwargs["cache_position"] = cache_position
-
model_kwargs
现在包含:model_kwargs = {"cache_position": torch.tensor([0, 1, 2]) }
示例2:存在缓存的情况
假设:
input_ids
与之前相同。model_kwargs
包含past_key_values
,其中缓存长度past_length = 2
。
假设 past_key_values
:
缓存中保存了前两个位置的 key
和 value
,因此 past_length = 2
。
步骤1:计算 cache_position
与之前相同,得到:
cache_position = torch.tensor([0, 1, 2])
步骤2:调整 cache_position
-
由于
past_length = 2
,需要从位置编号中移除前两个位置:cache_position = cache_position[past_length:]
cache_position = cache_position[2:] # 取索引为 2 及之后的元素
-
得到:
cache_position = torch.tensor([2]) # 仅剩下位置编号 2
步骤3:更新 model_kwargs
并返回
model_kwargs["cache_position"] = cache_position
-
model_kwargs
现在包含:model_kwargs = {"past_key_values": ..., # 缓存内容(省略)"cache_position": torch.tensor([2]) }
说明:
- 缓存中已包含了位置 0 和 1 的
key
和value
,因此新的生成只需要处理位置 2 的 token。 cache_position
从[0, 1, 2]
变为[2]
。
cache_position
的作用
- 在 Transformer 模型中,位置编码对于捕获序列顺序信息非常重要。
- 使用缓存时,我们需要知道新生成的 token 应该对应哪些位置,以确保模型正确地生成下一个 token。
- 如果不调整
cache_position
,模型可能会错误地将新 token 视为之前的位置,从而导致生成错误。
总结
- 函数
_get_initial_cache_position
的作用是根据input_ids
和past_key_values
,计算生成过程中新 token 的 位置编号cache_position
。 - 当没有缓存时,
cache_position
从 0 开始,依次递增。 - 当存在缓存时,需要考虑缓存的长度,调整
cache_position
,以对应新添加的 token。 - 通过上述示例可以看到,
cache_position
根据是否存在缓存,以及缓存的长度,发生了相应的变化。
prepare_inputs_for_generation
函数概述
prepare_inputs_for_generation
函数的主要作用是为生成过程准备模型的输入。它根据当前的输入和模型的配置,处理和调整 input_ids
、past_key_values
(缓存)、attention_mask
以及其他相关输入,以确保模型在生成过程中正确运作。
这个函数涉及到了以下几个关键步骤:
- 处理缓存位置
cache_position
:根据是否存在缓存以及是否支持Cache
类,计算或调整cache_position
。 - 根据缓存调整输入序列:如果存在缓存,只需要处理未缓存的部分,因此对
input_ids
、inputs_embeds
等进行切片。 - 准备模型的基础输入:根据模型类型(解码器或编码器-解码器),设置正确的输入键,如
input_ids
或decoder_input_ids
。 - 创建缺失的
position_ids
:如果需要,为输入序列生成位置编码position_ids
。 - 调整其他与输入长度相关的输入:如
token_type_ids
等,确保它们与输入序列的长度一致。 - 创建 4D 注意力掩码(可选):如果使用了
StaticCache
,为提高性能,预先创建固定形状的因果掩码。 - 传递其他未初始化的关键字参数:如
use_cache
等。 - 移除生成过程中不需要的输入:如
labels
。
逐步详解
1. 初始化和处理 cache_position
model_inputs = {}
if self._supports_cache_class:model_inputs["cache_position"] = cache_position
elif cache_position is None:past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
-
model_inputs
:用于存储准备好的模型输入。 -
处理
cache_position
:- 如果模型支持
Cache
类(self._supports_cache_class
为True
),则将cache_position
添加到model_inputs
中。 - 如果模型不支持
Cache
类且cache_position
为None
,则根据past_key_values
来计算cache_position
:past_length
:缓存的长度,即past_key_values
中缓存的序列长度。cache_position
:从past_length
开始,到当前input_ids
的序列长度,创建一个递增的序列,用于表示未处理的 token 的位置。
- 如果模型支持
2. 根据缓存调整输入序列
if past_key_values is not None:model_inputs["past_key_values"] = past_key_valuesif inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]elif (inputs_embeds is not None # Exception 1or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3):input_ids = input_ids[:, -cache_position.shape[0] :]elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)input_ids = input_ids[:, cache_position]
-
目的:当存在缓存时,只需要处理未缓存的部分,所以需要根据
cache_position
对input_ids
或inputs_embeds
进行切片。 -
异常情况:
- 异常1:当传入
inputs_embeds
时,input_ids
可能缺少元素,需要处理inputs_embeds
。 - 异常2:某些生成方法对
input_ids
进行了特殊的切片,这里不需要再处理。 - 异常3:在同步 GPU 时,
cache_position
可能超出input_ids
的范围,此时需要特殊处理。 - 异常4:如果传入了
inputs_embeds
,并且input_ids
长度为 0,则需要对inputs_embeds
进行切片。
- 异常1:当传入
3. 准备模型的基础输入
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
if not self.config.is_encoder_decoder:if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:model_inputs[input_ids_key] = Nonemodel_inputs["inputs_embeds"] = inputs_embedselse:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)model_inputs["inputs_embeds"] = None
else:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
-
根据模型类型,确定输入的键:
- 解码器模型:
input_ids_key
为"input_ids"
。 - 编码器-解码器模型:
input_ids_key
为"decoder_input_ids"
。
- 解码器模型:
-
准备输入:
- 如果传入了
inputs_embeds
且长度匹配:- 仅在初始生成步骤使用
inputs_embeds
。 - 将
inputs_embeds
添加到model_inputs
,并将对应的input_ids
设为None
。
- 仅在初始生成步骤使用
- 否则:
- 将
input_ids
克隆后添加到model_inputs
。 - 设置
inputs_embeds
为None
。
- 将
- 如果传入了
4. 创建缺失的 position_ids
encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
attention_mask = (kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
)
attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
if (attention_mask is not Noneand kwargs.get(position_ids_key) is Noneand position_ids_key in set(inspect.signature(self.forward).parameters.keys())
):position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)kwargs[position_ids_key] = position_ids # placed in kwargs for further processing (see below)
- 创建
position_ids
:- 如果没有提供
position_ids
,则根据attention_mask
计算。 - 计算方式:对
attention_mask
进行累加 (cumsum
),得到位置索引。 - 对于被填充的位置(
attention_mask == 0
),将position_ids
设置为1
。
- 如果没有提供
5. 调整其他与输入长度相关的输入
for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:model_input = kwargs.get(model_input_name)if model_input is not None:if past_key_values is not None:current_input_length = (model_inputs["inputs_embeds"].shape[1]if model_inputs.get("inputs_embeds") is not Noneelse model_inputs[input_ids_key].shape[1])model_input = model_input[:, -current_input_length:]model_input = model_input.clone(memory_format=torch.contiguous_format)model_inputs[model_input_name] = model_input
-
目的:确保其他需要与
input_ids
长度相同的输入(如position_ids
、token_type_ids
)的长度匹配。 -
处理方式:
- 如果存在缓存,且对应的输入存在,则对其进行切片,仅保留当前需要处理的部分。
6. 创建 4D 注意力掩码(可选)
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:if model_inputs["inputs_embeds"] is not None:batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shapedevice = model_inputs["inputs_embeds"].deviceelse:batch_size, sequence_length = model_inputs[input_ids_key].shapedevice = model_inputs[input_ids_key].device# 获取创建 4D 掩码的函数base_model = getattr(self, self.base_model_prefix, None)if base_model is None:causal_mask_creation_function = getattr(self, "_prepare_4d_causal_attention_mask_with_cache_position", None)else:causal_mask_creation_function = getattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None)# 如果存在该函数,则创建 4D 掩码if causal_mask_creation_function is not None:attention_mask = causal_mask_creation_function(attention_mask,sequence_length=sequence_length,target_length=past_key_values.get_max_cache_shape(),dtype=self.dtype,device=device,cache_position=cache_position,batch_size=batch_size,config=self.config,past_key_values=past_key_values,)
-
目的:当使用了
StaticCache
时,为了提高性能,可以预先创建固定形状的 4D 注意力掩码。 -
处理方式:
- 获取创建因果掩码的函数
_prepare_4d_causal_attention_mask_with_cache_position
。 - 如果存在,则调用该函数创建 4D 注意力掩码。
- 获取创建因果掩码的函数
7. 传递其他未初始化的关键字参数
for key, value in kwargs.items():if key not in model_inputs:model_inputs[key] = value
- 目的:将所有未处理的关键字参数(如
use_cache
)传递给模型输入。
8. 移除生成过程中不需要的输入
model_inputs.pop("labels", None)
- 目的:在生成过程中,不需要
labels
,因此将其从model_inputs
中移除。
返回处理好的模型输入
return model_inputs
示例
场景设定
-
模型类型:解码器模型(非编码器-解码器)。
-
批次大小(batch_size):2
-
初始输入
input_ids
:input_ids = torch.tensor([[101, 102, 103, 104, 105], # 样本 1[201, 202, 203, 204, 205] # 样本 2 ]) # 形状为 (2, 5)
-
注意力掩码
attention_mask
:attention_mask = torch.tensor([[1, 1, 1, 1, 1], # 样本 1[1, 1, 1, 1, 1] # 样本 2 ]) # 形状为 (2, 5)
-
缓存
past_key_values
:由DynamicCache()
创建的空Cache
对象,无已缓存的序列。 -
模型支持
Cache
类:_supports_cache_class = True
-
未传入
inputs_embeds
第一次调用 prepare_inputs_for_generation
调用参数
model_inputs = model.prepare_inputs_for_generation(input_ids=input_ids,past_key_values=past_key_values,attention_mask=attention_mask,cache_position=torch.arange(0, input_ids.shape[1], dtype=torch.long).to(input_ids.device),**model_kwargs
)
-
cache_position
:由于模型支持Cache
类,我们需要提供cache_position
。这里,我们直接使用torch.arange
从 0 到input_ids
的序列长度生成位置索引:cache_position = torch.arange(0, input_ids.shape[1], dtype=torch.long).to(input_ids.device) # 对于本例,cache_position = [0, 1, 2, 3, 4]
函数执行步骤
-
初始化
model_inputs
并处理cache_position
model_inputs = {} if self._supports_cache_class:model_inputs["cache_position"] = cache_position
- 将
cache_position
添加到model_inputs
。
- 将
-
处理缓存相关的输入
if past_key_values is not None:model_inputs["past_key_values"] = past_key_values# 由于 inputs_embeds 为 None,且不存在特殊情况,直接进入默认情况if input_ids.shape[1] != cache_position.shape[0]:input_ids = input_ids[:, cache_position]
-
检查
input_ids
的长度是否与cache_position
的长度一致。- 由于
input_ids.shape[1] = 5
,cache_position.shape[0] = 5
,长度一致,无需切片。
- 由于
-
将
past_key_values
添加到model_inputs
。
-
-
准备基础模型输入
input_ids_key = "input_ids" # 因为是解码器模型if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:model_inputs[input_ids_key] = Nonemodel_inputs["inputs_embeds"] = inputs_embeds else:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)model_inputs["inputs_embeds"] = None
- 将
input_ids
克隆后添加到model_inputs
,并确保内存连续。 - 设置
inputs_embeds
为None
。
- 将
-
创建缺失的
position_ids
encoder_attention_mask = None # 因为是解码器模型 attention_mask_key = "attention_mask" position_ids_key = "position_ids"if (attention_mask is not Noneand kwargs.get(position_ids_key) is Noneand position_ids_key in inspect.signature(self.forward).parameters ):position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)kwargs[position_ids_key] = position_ids
-
计算
position_ids
:position_ids = torch.tensor([[0, 1, 2, 3, 4], # 样本 1[0, 1, 2, 3, 4] # 样本 2 ], dtype=torch.long)
-
将
position_ids
添加到kwargs
。
-
-
调整与输入长度相关的其他输入
for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:model_input = kwargs.get(model_input_name)if model_input is not None:if past_key_values is not None:current_input_length = model_inputs[input_ids_key].shape[1] # 5model_input = model_input[:, -current_input_length:]model_input = model_input.clone(memory_format=torch.contiguous_format)model_inputs[model_input_name] = model_input
- 对于
position_ids
,长度与input_ids
一致,无需截断。
- 对于
-
添加
attention_mask
if attention_mask is not None:model_inputs[attention_mask_key] = attention_mask
- 将
attention_mask
添加到model_inputs
。
- 将
-
传递未初始化的其他关键字参数
- 如果有其他参数(如
use_cache
),也会被添加到model_inputs
。
- 如果有其他参数(如
-
移除
labels
model_inputs.pop("labels", None)
- 移除不需要的参数。
第一次调用后的 model_inputs
model_inputs = {"cache_position": torch.tensor([0, 1, 2, 3, 4], dtype=torch.long),"past_key_values": past_key_values, # 空的 DynamicCache()"input_ids": input_ids.clone(memory_format=torch.contiguous_format),"inputs_embeds": None,"position_ids": torch.tensor([[0, 1, 2, 3, 4],[0, 1, 2, 3, 4]], dtype=torch.long),"attention_mask": attention_mask
}
模型生成新 token 后的第二次调用
假设模型在第一次生成后,生成了一个新 token,past_key_values
被更新。同时,cache_position
需要更新,input_ids
也需要更新。
更新的输入
-
新的
input_ids
(新生成的 token):new_tokens = torch.tensor([[106], # 样本 1[206] # 样本 2 ]) # 形状为 (2, 1)
-
更新后的
input_ids
:只保留新生成的 token,因为过去的序列已在past_key_values
中缓存input_ids = new_tokens # 形状为 (2, 1)
-
更新后的
attention_mask
:针对新 tokenattention_mask = torch.ones_like(input_ids, dtype=torch.long) # 形状为 (2, 1)
-
更新的
past_key_values
:包含了先前的缓存 -
更新的
cache_position
:增加新的位置索引cache_position = torch.tensor([5], dtype=torch.long)
第二次调用 prepare_inputs_for_generation
model_inputs = model.prepare_inputs_for_generation(input_ids=input_ids,past_key_values=past_key_values,attention_mask=attention_mask,cache_position=cache_position,**model_kwargs
)
函数执行步骤
-
初始化
model_inputs
并处理cache_position
model_inputs = {} if self._supports_cache_class:model_inputs["cache_position"] = cache_position
-
处理缓存相关的输入
if past_key_values is not None:model_inputs["past_key_values"] = past_key_values# 检查输入长度与 cache_position 长度if input_ids.shape[1] != cache_position.shape[0]:input_ids = input_ids[:, -cache_position.shape[0]:]
input_ids.shape[1] = 1
,cache_position.shape[0] = 1
,长度一致,无需切片。
-
准备基础模型输入
input_ids_key = "input_ids"if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:model_inputs[input_ids_key] = Nonemodel_inputs["inputs_embeds"] = inputs_embeds else:model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)model_inputs["inputs_embeds"] = None
- 将
input_ids
添加到model_inputs
。
- 将
-
创建缺失的
position_ids
encoder_attention_mask = None attention_mask_key = "attention_mask" position_ids_key = "position_ids"if (attention_mask is not Noneand kwargs.get(position_ids_key) is Noneand position_ids_key in inspect.signature(self.forward).parameters ):position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)kwargs[position_ids_key] = position_ids
-
计算新的
position_ids
:position_ids = torch.tensor([[0], # 样本 1[0] ], dtype=torch.long)
-
由于这是新生成的 token,在全新的序列中位置索引为 0。
-
-
调整与输入长度相关的其他输入
for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:model_input = kwargs.get(model_input_name)if model_input is not None:if past_key_values is not None:current_input_length = model_inputs[input_ids_key].shape[1] # 1model_input = model_input[:, -current_input_length:]model_input = model_input.clone(memory_format=torch.contiguous_format)model_inputs[model_input_name] = model_input
- 对于
position_ids
,已是正确的长度。
- 对于
-
添加
attention_mask
if attention_mask is not None:model_inputs[attention_mask_key] = attention_mask
- 将
attention_mask
添加到model_inputs
。
- 将
-
传递未初始化的其他关键字参数
- 如果有其他参数,也会被添加到
model_inputs
。
- 如果有其他参数,也会被添加到
-
移除
labels
- 移除不需要的参数。
第二次调用后的 model_inputs
model_inputs = {"cache_position": torch.tensor([5], dtype=torch.long), # 新的位置索引"past_key_values": past_key_values, # 更新后的缓存"input_ids": input_ids.clone(memory_format=torch.contiguous_format), # 形状为 (2, 1)"inputs_embeds": None,"position_ids": torch.tensor([[0], # 样本 1[0]], dtype=torch.long),"attention_mask": attention_mask
}
总结
在第二次调用 prepare_inputs_for_generation
时,函数根据新的 cache_position
和 past_key_values
,对 input_ids
、position_ids
、attention_mask
等进行了相应的处理,以确保模型只处理新生成的 token,并正确地应用位置编码和注意力掩码。
cache_position
的变化
-
第一次调用:
cache_position
为[0, 1, 2, 3, 4]
,对应初始输入序列的位置。input_ids
使用完整的输入序列。
-
第二次调用:
cache_position
更新为[5]
,表示新生成的 token 在整个序列中的位置。input_ids
只包含新生成的 token。
_update_model_kwargs_for_generation
函数概述
def _update_model_kwargs_for_generation(self,outputs: ModelOutput,model_kwargs: Dict[str, Any],is_encoder_decoder: bool = False,num_new_tokens: int = 1,
) -> Dict[str, Any]:# 函数内容
这个函数的主要作用是在生成过程中,更新模型的关键字参数 model_kwargs
,以便在下一步生成时使用。更新的内容包括:
- 缓存
past_key_values
:用于加速后续生成。 - 注意力掩码
attention_mask
或decoder_attention_mask
:告诉模型在生成时应关注的序列长度。 - 缓存位置
cache_position
:用于跟踪当前缓存中的位置,特别是对于需要位置编码的模型。
这个函数通常在生成循环中每次生成一个新 token 之后被调用,以确保模型在下一次调用时具有正确的上下文和输入参数。
代码详解
1. 更新 past_key_values
for possible_cache_name in ALL_CACHE_NAMES:if possible_cache_name in outputs:# TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecatedif possible_cache_name in ("past_buckets_states", "mems"):cache_name = "past_key_values"else:cache_name = possible_cache_namemodel_kwargs[cache_name] = getattr(outputs, possible_cache_name)break
-
作用:从模型的输出
outputs
中提取缓存past_key_values
,并更新model_kwargs
中对应的键。 -
解释:
ALL_CACHE_NAMES
是一个包含可能的缓存名称的列表,例如["past_key_values", "mems", "past_buckets_states"]
,因为不同的模型可能使用不同的缓存名称。- 该循环检查
outputs
中是否存在这些缓存名称。 - 如果找到:
- 如果缓存名称是
"past_buckets_states"
或"mems"
(针对旧模型如 XLNet、Reformer),将其统一为"past_key_values"
,以保持一致性。 - 将缓存添加到
model_kwargs
,以便在下一步生成时使用。
- 如果缓存名称是
2. 更新 token_type_ids
if "token_type_ids" in model_kwargs:token_type_ids = model_kwargs["token_type_ids"]model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
-
作用:如果模型使用了
token_type_ids
(用于区分不同句子或段落的类型 ID,例如在 BERT 中),则需要在生成过程中更新它。 -
解释:
- 获取当前的
token_type_ids
。 - 将最后一个
token_type_id
复制一份,添加到序列末尾。这样,新生成的 token 将具有与前一个 token 相同的类型 ID。
- 获取当前的
3. 更新注意力掩码
-
对于非编码器-解码器模型:
if not is_encoder_decoder:if "attention_mask" in model_kwargs:attention_mask = model_kwargs["attention_mask"]model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
-
对于编码器-解码器模型:
else:if "decoder_attention_mask" in model_kwargs:decoder_attention_mask = model_kwargs["decoder_attention_mask"]model_kwargs["decoder_attention_mask"] = torch.cat([decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],dim=-1,)
-
作用:在生成新 token 后,需要更新注意力掩码,让模型知道现在的序列长度增加了,应该在计算注意力时关注到新生成的 token。
-
解释:
- 非编码器-解码器模型:
- 获取当前的
attention_mask
,形状为(batch_size, sequence_length)
。 - 为每个样本添加一个值为
1
的新列,表示新生成的 token 是有效的(不是填充或被掩码的)。 - 更新
model_kwargs["attention_mask"]
。
- 获取当前的
- 编码器-解码器模型:
- 类似地,更新
decoder_attention_mask
。
- 类似地,更新
- 非编码器-解码器模型:
4. 更新 cache_position
if model_kwargs.get("use_cache", True):model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
else:past_positions = model_kwargs.pop("cache_position")new_positions = torch.arange(past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype).to(past_positions.device)model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
-
作用:更新缓存位置
cache_position
,用于跟踪当前缓存中的位置,在生成过程中正确地处理位置编码。 -
解释:
use_cache
:从model_kwargs
中获取,默认为True
,表示模型在生成过程中使用缓存。- 如果
use_cache
为True
:- 取当前的
cache_position
的最后一个值,增加num_new_tokens
(默认为1
)。 - 这样,新的
cache_position
将是[last_position + num_new_tokens]
。
- 取当前的
- 如果
use_cache
为False
:- 弹出当前的
cache_position
。 - 生成一个新的位置序列
new_positions
,从past_positions
的最后一个值加1
开始,长度为num_new_tokens
。 - 将
past_positions
和new_positions
拼接,更新cache_position
。
- 弹出当前的
5. 返回更新后的 model_kwargs
return model_kwargs
- 返回更新后的
model_kwargs
,以供下一步生成使用。
示例说明
假设我们在文本生成过程中,使用一个非编码器-解码器模型,批次大小为 2
。
初始状态
-
模型输出
outputs
:包含past_key_values
,这是模型生成过程中用于缓存的注意力 key 和 value。 -
model_kwargs
:model_kwargs = {"attention_mask": torch.tensor([[1, 1, 1, 1, 1], # 样本 1 的注意力掩码[1, 1, 1, 1, 1] # 样本 2 的注意力掩码]), # 形状为 (2, 5)"cache_position": torch.tensor([0, 1, 2, 3, 4], dtype=torch.long), # 初始位置# 可能还有其他参数 }
-
is_encoder_decoder
:False
,表示这是一个解码器-only(GPT 类)模型。 -
num_new_tokens
:1
,每次生成一个新 token。
第一次调用
model_kwargs = _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
步骤解析:
-
更新
past_key_values
:- 从
outputs
中提取past_key_values
,更新model_kwargs
。
- 从
-
更新
attention_mask
:-
原始
attention_mask
:tensor([[1, 1, 1, 1, 1], # 样本 1[1, 1, 1, 1, 1] # 样本 2 ]) # 形状为 (2, 5)
-
新增一个长度为
1
的列,值为1
,表示新生成的 token:new_column = torch.ones((2, 1), dtype=torch.long) model_kwargs["attention_mask"] = torch.cat([attention_mask, new_column], dim=-1) # 更新后的 attention_mask: # tensor([ # [1, 1, 1, 1, 1, 1], # 形状为 (2, 6) # [1, 1, 1, 1, 1, 1] # ])
-
-
更新
cache_position
:-
原始
cache_position
:cache_position = torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)
-
更新:
model_kwargs["cache_position"] = cache_position[-1:] + num_new_tokens # 取最后一个值 4,加 1,得到 5 # 更新后的 cache_position: # tensor([5])
-
-
返回更新后的
model_kwargs
。
更新后的 model_kwargs
:
model_kwargs = {"attention_mask": tensor([[1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1]]), # 形状为 (2, 6)"cache_position": tensor([5], dtype=torch.long),"past_key_values": outputs.past_key_values,# 其他参数保持不变
}
第二次调用
假设模型又生成了一个新 token,再次调用该函数:
model_kwargs = _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
步骤解析:
-
更新
past_key_values
:- 更新
past_key_values
。
- 更新
-
更新
attention_mask
:-
原始
attention_mask
:tensor([[1, 1, 1, 1, 1, 1], # 形状为 (2, 6)[1, 1, 1, 1, 1, 1] ])
-
新增一个长度为
1
的列:new_column = torch.ones((2, 1), dtype=torch.long) model_kwargs["attention_mask"] = torch.cat([attention_mask, new_column], dim=-1) # 更新后的 attention_mask: # tensor([ # [1, 1, 1, 1, 1, 1, 1], # 形状为 (2, 7) # [1, 1, 1, 1, 1, 1, 1] # ])
-
-
更新
cache_position
:-
原始
cache_position
:cache_position = torch.tensor([5], dtype=torch.long)
-
更新:
model_kwargs["cache_position"] = cache_position[-1:] + num_new_tokens # 取最后一个值 5,加 1,得到 6 # 更新后的 cache_position: # tensor([6])
-
-
返回更新后的
model_kwargs
。
更新后的 model_kwargs
:
model_kwargs = {"attention_mask": tensor([[1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1]]), # 形状为 (2, 7)"cache_position": tensor([6], dtype=torch.long),"past_key_values": outputs.past_key_values,# 其他参数保持不变
}
总结
通过上述两次调用,我们可以看到:
attention_mask
在每次生成新 token 后,都在序列末尾添加了一个值为1
的新列,表示序列长度增加,模型应关注新的 token。cache_position
跟踪当前的位置,每次生成一个新 token,位置索引增加1
。
结论
该函数在生成循环中起到关键作用,确保模型在每个生成步骤都具有正确的缓存和输入参数,从而生成连贯的序列。通过更新 past_key_values
、attention_mask
、cache_position
等关键参数,模型能够高效地利用先前的计算,加速生成过程。
理解这个函数有助于深入掌握文本生成模型的工作原理,尤其是在实现自回归生成和缓存机制时。
_has_unfinished_sequences
详细解释:
该函数 _has_unfinished_sequences
的主要作用是在生成序列的过程中,判断是否需要继续生成,即是否还有未完成的序列。这在循环生成模型中非常重要,可以避免不必要的计算,提高效率。
参数说明:
this_peer_finished
(bool
):当前进程(或节点)是否已经完成了生成序列。synced_gpus
(bool
):是否在使用多 GPU 同步模式。如果为True
,表示多个 GPU 需要同步生成过程。device
(torch.device
):当前计算所在的设备(CPU 或 GPU)。cur_len
(Optional[int]
):当前序列的长度。max_length
(Optional[int]
):生成序列的最大长度。
返回值:
bool
:返回True
表示仍有未完成的序列,需要继续生成;返回False
表示所有序列都已完成,可以停止生成。
代码逻辑详解:
-
处理
torch.compile
的特殊情况:if is_torchdynamo_compiling():return cur_len < max_length
-
背景:
torch.compile
(或torchdynamo
)目前不支持数据依赖的控制流,也就是说,无法根据运行时的数据(如模型的预测结果)来决定控制流(如循环、条件判断)。
-
逻辑:
- 如果正在使用
torch.compile
编译,则无法在生成过程中根据序列是否完成来提前停止,只能按照固定的最大长度max_length
来控制循环。 - 因此,这里直接判断当前长度
cur_len
是否小于最大长度max_length
,如果小于则返回True
,表示需要继续生成。
- 如果正在使用
-
影响:
- 这种方式会失去根据序列生成的内容(如是否生成了 EOS 标记)来提前终止循环的能力。
-
-
正常情况下的处理逻辑:
else:if synced_gpus:# 同步 GPU 的处理逻辑...elif this_peer_finished:return Falsereturn True
-
说明:
- 如果未使用
torch.compile
,则进入正常的处理逻辑。
- 如果未使用
-
-
处理同步 GPU 的情况:
if synced_gpus:# 在 synced_gpus 模式下,`forward` 调用必须继续,直到所有 GPU 完成它们的序列。# 以下逻辑允许在所有节点完成生成序列时提前结束this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)# 如果当前节点已完成,发送 0.0,否则发送 1.0dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)# 所有节点都完成了吗?如果归约的和为 0.0,则表示所有节点都已完成if this_peer_finished_flag.item() == 0.0:return False
-
背景:
- 在使用多 GPU 同步模式时,各个 GPU 需要同步进行生成过程,不能有的节点提前退出,否则会导致同步问题和潜在的死锁。
-
逻辑:
-
Step 1:创建一个标志变量
this_peer_finished_flag
,表示当前节点是否完成。this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
- 如果当前节点已完成,则值为
0.0
,未完成则为1.0
。 - 放在
device
上以确保在正确的设备上进行计算。
- 如果当前节点已完成,则值为
-
Step 2:使用分布式的
all_reduce
操作,将所有节点的标志变量求和。dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
all_reduce
会收集所有参与进程的this_peer_finished_flag
,并按照指定的操作(这里是求和SUM
)进行归约,然后将结果广播给所有节点。- 结果是所有节点的
this_peer_finished_flag
之和。
-
Step 3:判断是否所有节点都已完成。
if this_peer_finished_flag.item() == 0.0:return False
- 如果归约结果为
0.0
,说明所有节点的this_peer_finished_flag
都是0.0
,即所有节点都已完成生成。 - 返回
False
,表示不需要继续生成。
- 如果归约结果为
-
Step 4:否则,返回
True
,表示仍有未完成的序列,需要继续生成。
-
-
注意:
- 这样可以确保在多 GPU 同步模式下,所有节点同步地决定是否继续生成,防止有的节点提前退出导致同步问题。
-
-
处理未使用同步 GPU 且当前节点已完成的情况:
elif this_peer_finished:return False
- 如果未使用同步 GPU(
synced_gpus == False
),并且当前节点已完成生成,则可以直接返回False
,不需要继续。
- 如果未使用同步 GPU(
-
默认返回
True
:return True
- 如果上述条件都不满足,表示仍有未完成的序列,需要继续生成,返回
True
。
- 如果上述条件都不满足,表示仍有未完成的序列,需要继续生成,返回