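"""Streaming token generation for the OpenAI-compatible completion API.

`generate_stream` drives a manual prefill/decode loop (supporting both
decoder-only and encoder-decoder models) and yields incremental
`text_completion` chunks; `generate_stream_v2` delegates decoding to
`model.generate` through a `TextIteratorStreamer` running in a
background thread.
"""
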
import gc
import time
import uuid
from threading import Thread
from types import MethodType
from typing import Iterable, Dict, Any

import torch
from transformers import (
    TextIteratorStreamer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from api.generation.qwen import check_is_qwen
from api.generation.utils import (
    prepare_logits_processor,
    is_partial_stop,
    apply_stopping_strings,
)


def generate_stream(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
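    """Decode `params["inputs"]` token by token with a manual KV-cache loop
    and yield OpenAI-style completion chunks as new text becomes available.
    """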
    # Read parameters
    input_ids = params.get("inputs")
    prompt = params.get("prompt")
    model_name = params.get("model", "llm")
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_tokens", 256))
    logprobs = params.get("logprobs")
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop")

    stop_token_ids = params.get("stop_token_ids") or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    device = model.device
    if model.config.is_encoder_decoder:
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values, sent_interrupt = None, False
    token_logprobs = [None]  # The first token has no logprobs.
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt tokens.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False
                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [output_ids if sent_interrupt else [token]], device=device
                    ),
                    use_cache=True,
                    past_key_values=None if sent_interrupt else past_key_values,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values
        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device.type == "mps":
            # Move to the CPU to work around bugs in the mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)

        if logprobs is not None:
            # Cannot use last_token_logits because logprobs are based on the raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        stopped = token in stop_token_ids
        # Yield the output tokens
        if i % 2 == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len(prompt)
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=False if check_is_qwen(model) else True,  # fix for qwen react
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )

            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs if echo else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}] * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)
            partially_stopped, finish_reason = False, None
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            if each_stop == "Observation:":
                                finish_reason = "function_call"
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding partial stop sequence
            if (not partially_stopped) and output and output[-1] != "�":
                delta_text = output[len(previous_text):]
                previous_text = output
                yield {
                    "id": completion_id,
                    "object": "text_completion",
                    "created": created,
                    "model": model_name,
                    "delta": delta_text,
                    "text": output,
                    "logprobs": ret_logprobs,
                    "finish_reason": finish_reason,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                }

        if stopped:
            break
    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": output,
        "logprobs": ret_logprobs,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }

    # Clean up
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()


def generate_stream_v2(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    params: Dict[str, Any],
):
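    """Stream completions by running `model.generate` in a background thread
    and consuming its output through a `TextIteratorStreamer`.
    """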
input_ids = params.get("inputs") | |
functions = params.get("functions") | |
model_name = params.get("model", "llm") | |
temperature = float(params.get("temperature", 1.0)) | |
repetition_penalty = float(params.get("repetition_penalty", 1.0)) | |
top_p = float(params.get("top_p", 1.0)) | |
top_k = int(params.get("top_k", 40)) | |
max_new_tokens = int(params.get("max_tokens", 256)) | |
stop_token_ids = params.get("stop_token_ids") or [] | |
if tokenizer.eos_token_id not in stop_token_ids: | |
stop_token_ids.append(tokenizer.eos_token_id) | |
stop_strings = params.get("stop", []) | |
input_echo_len = len(input_ids) | |
device = model.device | |
generation_kwargs = dict( | |
input_ids=torch.tensor([input_ids], device=device), | |
do_sample=True, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
max_new_tokens=max_new_tokens, | |
repetition_penalty=repetition_penalty, | |
pad_token_id=tokenizer.pad_token_id, | |
) | |
    if temperature <= 1e-5:
        generation_kwargs["do_sample"] = False
        generation_kwargs.pop("top_k")

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs["streamer"] = streamer

    # Some model classes ship a custom `generate`; rebind the stock
    # GenerationMixin implementation so the kwargs above (including
    # `streamer`) are handled as usual.
    if "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text, func_call_found = "", False
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    previous_text = ""
    for i, new_text in enumerate(streamer):
        generated_text += new_text
        if functions:
            _, func_call_found = apply_stopping_strings(generated_text, ["Observation:"])
        generated_text, stop_found = apply_stopping_strings(generated_text, stop_strings)

        if generated_text and generated_text[-1] != "�":
            delta_text = generated_text[len(previous_text):]
            previous_text = generated_text

            yield {
                "id": completion_id,
                "object": "text_completion",
                "created": created,
                "model": model_name,
                "delta": delta_text,
                "text": generated_text,
                "logprobs": None,
                "finish_reason": "function_call" if func_call_found else None,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
            }

        if stop_found:
            break
    yield {
        "id": completion_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "delta": "",
        "text": generated_text,
        "logprobs": None,
        "finish_reason": "stop",
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
    }
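

if __name__ == "__main__":
    # Minimal usage sketch, not part of the API surface: it assumes the
    # public `gpt2` checkpoint as a stand-in for whatever model the server
    # actually loads, and streams a short completion to stdout.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    demo_model = AutoModelForCausalLM.from_pretrained("gpt2")

    demo_prompt = "Hello, my name is"
    demo_params = {
        "inputs": demo_tokenizer(demo_prompt).input_ids,
        "prompt": demo_prompt,
        "model": "demo",
        "temperature": 0.8,
        "top_p": 0.9,
        "max_tokens": 32,
        "echo": False,
    }

    with torch.inference_mode():
        for chunk in generate_stream(demo_model, demo_tokenizer, demo_params):
            print(chunk["delta"], end="", flush=True)
    print()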