# FIRE / src/model/model_yuan2.py
# zhangbofei - commit a22404f - "fix: import"
# (raw | history | blame - 4.38 kB)
import gc
from threading import Thread
from typing import Iterable
import torch
import transformers
from transformers import TextIteratorStreamer, GenerationConfig
from src.utils import is_partial_stop
def _truncate_at_stop(output, stop_str, rfind_start):
    """Cut ``output`` before the last occurrence of a stop string.

    Args:
        output: Text accumulated so far (may start with the echoed prompt).
        stop_str: A single stop string, an iterable of stop strings, or a
            falsy value to disable stop handling. Empty strings are ignored.
        rfind_start: Index in ``output`` from which stop strings are searched,
            so the echoed prompt is not matched.

    Returns:
        ``(output, stopped, partially_stopped)``: ``output`` truncated before
        the first stop string found (in iteration order), whether a full stop
        string was found, and whether ``output`` ends with a partial stop
        string as reported by ``is_partial_stop``.

    Raises:
        ValueError: If ``stop_str`` is truthy but neither ``str`` nor iterable.
    """
    if not stop_str:
        return output, False, False
    if isinstance(stop_str, str):
        stop_strs = (stop_str,)
    elif isinstance(stop_str, Iterable):
        stop_strs = stop_str
    else:
        raise ValueError("Invalid stop field type.")

    partially_stopped = False
    for each_stop in stop_strs:
        if not each_stop:
            # rfind("") always matches; an empty stop string would end output.
            continue
        pos = output.rfind(each_stop, rfind_start)
        if pos != -1:
            return output[:pos], True, False
        partially_stopped = is_partial_stop(output, each_stop)
        if partially_stopped:
            break
    return output, False, partially_stopped


@torch.inference_mode()
def generate_stream_yuan2(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream text generated by a Yuan2 model for one prompt.

    ``model.generate`` runs in a background thread and feeds a
    ``TextIteratorStreamer``. This generator consumes the decoded text chunks
    and yields progress dicts.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        params: Request dict. Reads ``prompt`` (required) and optionally
            ``temperature``, ``repetition_penalty``, ``top_p``, ``top_k``,
            ``max_new_tokens``, ``stop``, ``echo`` and ``stop_token_ids``.
            ``params`` is not modified.
        device: Device name; only used to choose which cache to empty at the
            end (``"xpu"`` / ``"npu"`` in addition to CUDA).
        context_len: Context length; the prompt is left-truncated to
            ``context_len - max_new_tokens - 8`` tokens when that is positive.
        stream_interval: Yield (and check stop strings) every this many
            streamer chunks; values below 1 are treated as 1.
        judge_sent_end: Unused; kept for interface compatibility.

    Yields:
        Dicts with ``text`` (prefixed by the prompt when ``echo`` is true),
        ``usage`` and ``finish_reason``. Intermediate dicts have
        ``finish_reason=None``; the last dict carries the final reason.
        ``completion_tokens`` is the index of the last streamer chunk seen,
        not an exact token count.

    Raises:
        ValueError: If ``params["stop"]`` has an unsupported type.
    """
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 0))
    top_k = int(params.get("top_k", 1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 512))
    stop_str = params.get("stop", "<eod>")
    echo = bool(params.get("echo", True))
    stream_interval = max(stream_interval, 1)  # avoid modulo by zero

    # Copy so the caller's list in ``params`` is not mutated on every call.
    # NOTE(review): these ids are not forwarded to ``generate``; confirm intent.
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    stop_token_ids.append(tokenizer("<eod>")["input_ids"][0])

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # A single prompt encodes to a (1, seq_len) tensor, so truncate along the
    # sequence axis, keeping the most recent tokens. A non-positive budget
    # would turn the slice into "drop tokens from the start", so skip it.
    max_src_len = context_len - max_new_tokens - 8
    if max_src_len > 0:
        input_ids = input_ids[:, -max_src_len:]
        attention_mask = attention_mask[:, -max_src_len:]
    input_echo_len = input_ids.shape[1]

    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        # Sampling is enabled only for temperature >= 1.2; otherwise greedy.
        do_sample=temperature >= 1.2,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=10,
        top_p=top_p,
        top_k=top_k,
    )
    generation_kwargs = dict(
        inputs=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        generation_config=generation_config,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # With echo, keep the prompt but only search generated text for stops.
    output = prompt if echo else ""
    rfind_start = len_prompt if echo else 0

    i = 0
    stopped = False
    partially_stopped = False
    for i, new_text in enumerate(streamer):
        output += new_text
        if i % stream_interval != 0:
            continue
        output, stopped, partially_stopped = _truncate_at_stop(
            output, stop_str, rfind_start
        )
        # Never yield a partial stop sequence.
        if not partially_stopped:
            yield {
                "text": output,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
                "finish_reason": None,
            }
        if stopped:
            # Drain remaining chunks (text after the stop string is discarded)
            # so the generation thread is not left blocked, then stop.
            for _ in streamer:
                pass
            break
    thread.join()

    # Stop strings are only checked every ``stream_interval`` chunks, so
    # re-check the complete output before emitting the final event.
    if not stopped:
        output, stopped, partially_stopped = _truncate_at_stop(
            output, stop_str, rfind_start
        )
    output = output.strip()

    # Final stream event, which carries the finish reason.
    if stopped:
        finish_reason = "stop"
    elif i == max_new_tokens - 1:
        finish_reason = "length"
    elif partially_stopped:
        finish_reason = None
    else:
        finish_reason = "stop"
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Release memory held by the finished generation.
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()