# Generation utilities: chat prompt construction, program-of-thought (PoT)
# post-processing, streaming text iteration, and repetition-penalty logits
# processing.
import re
import torch
import numpy as np
from queue import Queue
from typing import Tuple, List, Union, Iterable
from transformers.utils import logging, add_start_docstrings
from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING, LogitsProcessorList
def make_context(model, tokenizer,
                 messages: List[dict],
                 system: str = "You are a helpful assistant.",
                 max_new_tokens: int = 0,
                 ):
    """Build the prompt token tensor for a chat completion.

    Encodes system / history / query turns with ``<|system|>``, ``<|user|>``
    and ``<|assistant|>`` role tags, dropping the oldest history turns that do
    not fit into ``max_position_embeddings - max_new_tokens`` tokens.

    Args:
        model: provides ``generation_config.max_new_tokens``,
            ``config.max_position_embeddings`` and ``device``.
        tokenizer: tokenizer with an ``encode`` method.
        messages: OpenAI-style list of ``{"role", "content"}`` dicts — an
            optional leading system message, then alternating user/assistant
            pairs, ending with the user query.
        system: fallback system prompt used when ``messages`` carries none.
        max_new_tokens: reserved generation budget; 0 falls back to the
            model's generation config.

    Returns:
        ``torch.LongTensor`` of shape (1, prompt_len) on ``model.device``.

    Raises:
        AssertionError: if ``messages`` does not end with a user message or
            the history is not made of complete user/assistant pairs.
    """
    # Generation budget: explicit argument wins, else the model's configured default.
    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
    # Maximum prompt length once the generation budget is reserved.
    max_input_length = model.config.max_position_embeddings - max_new_tokens
    nl_tokens = tokenizer.encode("\n", add_special_tokens=False)

    def _parse_messages(messages):
        """Split messages into (system, final user query, [user, assistant] history)."""
        system, query, history = "", "", []
        if messages[0]["role"] == "system":
            system = messages[0]["content"]
            messages = messages[1:]
        # The last message must be the user query being answered.
        assert messages[-1]["role"] == "user"
        query = messages[-1]["content"]
        messages = messages[:-1]
        # The remainder must consist of complete user/assistant pairs.
        assert len(messages) % 2 == 0
        for i in range(0, len(messages), 2):
            assert messages[i]["role"] == "user" and messages[i+1]["role"] == "assistant"
            history.append([messages[i]["content"], messages[i+1]["content"]])
        return system, query, history

    _system, query, history = _parse_messages(messages)
    # A system prompt inside the messages overrides the default argument.
    system_text = _system if _system != "" else system
    system_tokens = []
    if system_text:
        system_tokens = tokenizer.encode(text=("<|system|>\n"+system_text.strip()), add_special_tokens=True, truncation=True) + nl_tokens
    query_tokens = tokenizer.encode(text=("<|user|>\n"+query.strip()), add_special_tokens=False, truncation=True) + nl_tokens
    # Open the assistant turn that the model will complete.
    final_tokens = tokenizer.encode("<|assistant|>", add_special_tokens=False, truncation=True) + nl_tokens
    # Token budget left over for conversation history.
    max_history_length = max_input_length - len(system_tokens) - len(query_tokens) - len(final_tokens)
    # Walk the history from most recent to oldest, keeping whole turns while they fit.
    context_tokens = []
    for turn_query, turn_response in reversed(history):
        history_query_tokens = tokenizer.encode("<|user|>\n"+turn_query.strip(), add_special_tokens=False, truncation=True) + nl_tokens
        history_response_tokens = tokenizer.encode("<|assistant|>\n"+turn_response.strip(), add_special_tokens=False, truncation=True) + nl_tokens
        next_context_tokens = history_query_tokens + history_response_tokens
        current_context_size = len(next_context_tokens) + len(context_tokens)
        if current_context_size < max_history_length:
            context_tokens = next_context_tokens + context_tokens
        else:
            # Once the budget is exhausted, all older turns are dropped.
            break
    input_tokens = system_tokens + context_tokens + query_tokens + final_tokens
    return torch.LongTensor([input_tokens]).to(model.device)
def parse_pot_no_stream(inputs):
    """ Parse and evaluate inline program-of-thought snippets of the form <<...>>.

    Two snippet shapes are handled:
    1. Snippets containing "func": the text after "=" is exec'd (it is expected
       to define a callable named ``func`` in module globals), ``func()`` is
       called, and its result(s) replace the variable name(s) left of "="
       throughout the rest of the text. When the snippet mentions 'sympy',
       ``res[x]``/``res[y]`` are rewritten to ``res[0][0]``/``res[0][1]`` and
       the exec is retried.
    2. Plain ``name=expr`` snippets: ``expr`` is eval'd, rounded, and the
       result replaces ``name`` in the remaining text (near-integer floats are
       rendered as ints).

    The bare excepts are deliberate best-effort control flow: on any failure
    the string processed so far is returned unchanged from that point.

    NOTE(review): this exec/eval's model-generated code in module globals —
    only safe on trusted model output; never expose to untrusted input.
    """
    try:
        # Find all "<<...>>" snippets (non-greedy, DOTALL spans newlines).
        s = re.findall(r'<<(.*?)>>', inputs, re.DOTALL)
        # No snippets: nothing to do.
        if not s:
            #print("err inputs: ", origin_inputs, flush=True)
            return inputs
        index = 0
        # Process every matched snippet in order.
        for k in s:
            try:
                # Function-style snippet: "vars = def func(): ..."
                if "func" in k:
                    # Split into variable name(s) and the function source.
                    var = k.split("=", 1)
                    try:
                        # Strip surrounding spaces and exec the function definition.
                        var[1] = var[1].strip(" ")
                        exec(var[1], globals())
                        # ``func`` was just defined by the exec above.
                        ans = func()
                    except:
                        # Retry once with sympy-specific result-index rewrites.
                        if 'sympy' in var[1]:
                            var[1] = var[1].replace('res[x]', 'res[0][0]').replace('res[y]', 'res[0][1]')
                            exec(var[1], globals())
                            ans = func()
                        pass
                    var_list = [c.strip(" ") for c in var[0].split(",")]
                    # Single variable name: wrap the result in a list for uniform handling.
                    if len(var_list) == 1:
                        ans = [ans]
                    # Normalize results: floats that are (nearly) whole become int strings.
                    for i in range(len(ans)):
                        try:
                            ans[i] = float(ans[i])
                            if abs(ans[i] - int(ans[i])) < 1e-10:
                                ans[i] = str(int(ans[i]))
                        except:
                            pass
                    # Remove the snippet and substitute the variable names with their values.
                    inputs = inputs.replace("<<"+k+">>", "")
                    for i in range(len(var_list)):
                        inputs = inputs.replace(var_list[i], str(ans[i]))
                    index += 1
                    # Propagate the computed values into the not-yet-processed snippets.
                    for c in range(index, len(s)):
                        for i in range(len(var_list)):
                            s[c] = s[c].replace(var_list[i], str(ans[i]))
                else:
                    # Expression-style snippet: "name=expr" — evaluate directly.
                    var = k.replace(" ", "").split("=")
                    var[1] = var[1].replace("eval", "")
                    ans = round(eval(var[1]), 10)
                    ans = float(ans)
                    if abs(ans - int(ans)) < 1e-10:
                        ans = str(int(ans))
                    # Remove the snippet and substitute the variable name with its value.
                    inputs = inputs.replace("<<"+k+">>", "").replace(var[0], str(ans))
                    index += 1
                    # Propagate the computed value into the not-yet-processed snippets.
                    for c in range(index, len(s)):
                        s[c] = s[c].replace(var[0], str(ans))
            except:
                return inputs
    except Exception as e:
        return inputs
    return inputs
class TextIterStreamer:
    """Iterator-style streamer yielding incrementally decoded generation text.

    Token ids are accumulated across `put` calls; after each call the full
    accumulated sequence is re-decoded and the cumulative string is pushed to
    an internal queue, which is drained by iterating over this object.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False, use_pot=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        # All token ids received so far (re-decoded in full on every put).
        self.tokens = []
        # Decoded text chunks buffered for the consuming iterator.
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True
        # Whether to post-process decoded text with parse_pot_no_stream.
        self.use_pot = use_pot

    def put(self, value):
        # The first batch of tokens is the prompt; drop it when requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        # Batched input: keep only the first sequence.
        if value.ndim > 1:
            value = value[0]
        self.tokens.extend(value.tolist())
        text = self.tokenizer.decode(
            self.tokens, skip_special_tokens=self.skip_special_tokens, errors='ignore')
        if self.use_pot:
            text = parse_pot_no_stream(text)
        self.text_queue.put(text)

    def end(self):
        # None is the sentinel that terminates iteration.
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        # Block until the next chunk is available; stop on the sentinel.
        chunk = self.text_queue.get()
        if chunk is None:
            raise StopIteration()
        return chunk
class OutputRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that penalizes tokens already seen in the prompt and/or
    the generated output. The multiplicative repetition penalty follows CTRL
    ([paper](https://arxiv.org/pdf/1909.05858.pdf), ~1.2 is a good balance);
    the additive frequency/presence penalties follow the OpenAI API definition.

    Args:
        input_length (`int`):
            Number of prompt tokens; positions up to and including this index
            are treated as prompt, the rest as generated output.
        presence_penalties (`float`, *optional*, defaults to 1.0):
            Additive penalty in [-2, 2] applied once per token that is present
            in the output so far.
        frequency_penalties (`float`, *optional*, defaults to 0):
            Additive penalty in [-2, 2] scaled by each token's occurrence count
            in the output so far.
        repetition_penalties (`float`, *optional*, defaults to 1.0):
            Strictly positive multiplicative penalty applied to tokens seen in
            the prompt or output. 1.0 means no penalty; above 1.0 penalizes
            repetition; between 0.0 and 1.0 rewards it.
    """

    def __init__(self, input_length: int,
                 presence_penalties: float = 1.0,
                 frequency_penalties: float = 0,
                 repetition_penalties: float = 1.0):
        # NOTE: the former default of 0 could never pass the `> 0` check below,
        # making the defaults unusable; 1.0 is the documented neutral value.
        if not (repetition_penalties > 0):
            raise ValueError(f"`repetition_penalties` has to be a strictly positive float, but is {repetition_penalties}")
        if not ( (frequency_penalties >= -2) and (frequency_penalties <= 2) ):
            raise ValueError(f"`frequency_penalties` has to be [-2, 2], but is {frequency_penalties}")
        if not ( (presence_penalties >= -2) and (presence_penalties <= 2) ):
            raise ValueError(f"`presence_penalties` has to be [-2, 2], but is {presence_penalties}")
        self.repetition_penalties = repetition_penalties
        self.frequency_penalties = frequency_penalties
        self.presence_penalties = presence_penalties
        self.input_length = input_length

    def _get_bin_counts_and_mask(
        self,
        tokens: torch.Tensor,
        vocab_size: int,
        num_seqs: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Per-sequence token occurrence counts and a token-presence mask.

        Returns (bin_counts, mask), both of shape (num_seqs, vocab_size).
        """
        # vocab_size + 1 so out-of-range padding ids don't break scatter_add_.
        bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                                 dtype=torch.long,
                                 device=tokens.device)
        bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
        bin_counts = bin_counts[:, :vocab_size]
        mask = bin_counts > 0
        return bin_counts, mask

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        # Split the ids into prompt (up to and including input_length) and output.
        prompt_tokens_tensor = input_ids[:, :self.input_length+1]
        output_tokens_tensor = input_ids[:, self.input_length+1:]
        num_seqs, vocab_size = logits.shape

        _, prompt_mask = self._get_bin_counts_and_mask(
            prompt_tokens_tensor, vocab_size, num_seqs)
        output_bin_counts, output_mask = self._get_bin_counts_and_mask(
            output_tokens_tensor, vocab_size, num_seqs)

        # Multiplicative repetition penalty, built at full (num_seqs, vocab_size)
        # so the boolean-mask assignment below also works for batched generation
        # (the previous (1, vocab_size) tensor failed mask indexing for num_seqs > 1).
        repetition_penalties = torch.full((num_seqs, vocab_size),
                                          self.repetition_penalties,
                                          dtype=logits.dtype,
                                          device=logits.device)
        # Unseen tokens keep a neutral penalty of 1.0.
        repetition_penalties[~(prompt_mask | output_mask)] = 1.0
        logits = torch.where(logits > 0, logits / repetition_penalties,
                             logits * repetition_penalties)

        # Additive penalties, following the OpenAI API definition.
        # Refer to https://platform.openai.com/docs/api-reference/parameter-details
        logits -= self.frequency_penalties * output_bin_counts
        logits -= self.presence_penalties * output_mask
        return logits