# lemur-7B/utils/inference.py

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from typing import Iterator
from variables import SYSTEM, HUMAN, AI


def load_tokenizer_and_model(base_model, load_8bit=True):
    # Prefer the GPU when available; `device` is returned so callers can
    # move their input tensors onto the same device as the model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    # `load_in_8bit` needs the bitsandbytes package; `device_map="auto"`
    # (via accelerate) places the weights automatically.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        device_map="auto",
    )
    return tokenizer, model, device
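

# A minimal usage sketch; "path/to/base-model" is a placeholder checkpoint
# name, not one shipped with this repo.
def _example_load():
    tokenizer, model, device = load_tokenizer_and_model("path/to/base-model")
    model.eval()  # inference only
    return tokenizer, model, device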


class State:
    """Cooperative interrupt flag: a caller (e.g. a streaming UI) sets
    `interrupted` to ask the generation loop to stop early."""

    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()
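
# Sketch of the intended wiring (an assumption from the names here, not
# something this file confirms): a UI "Stop" button calls
# shared_state.interrupt(), and the loop consuming decode() checks the flag
# between yielded chunks:
#
#     for chunk in decode(...):
#         if shared_state.interrupted:
#             shared_state.recover()
#             break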


def decode(
    input_ids: torch.Tensor,
    model: PeftModel,
    tokenizer: AutoTokenizer,
    stop_words: list,
    max_length: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Iterator[str]:
    generated_tokens = []
    past_key_values = None
    for _ in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                # First step: run the full prompt through the model.
                outputs = model(input_ids)
            else:
                # Later steps: feed only the newest token and reuse the
                # cached key/value states for the rest of the sequence.
                outputs = model(
                    input_ids[:, -1:], past_key_values=past_key_values
                )
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            # Temperature scaling: values below 1.0 sharpen the
            # distribution, values above 1.0 flatten it.
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)

            # Nucleus (top-p) sampling: keep the smallest set of tokens
            # whose cumulative probability reaches top_p, zero out the
            # rest, renormalize, and sample from what remains.
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1)
            # Map the sampled index back to the original vocabulary ids.
            next_token = torch.gather(probs_idx, -1, next_token)

        input_ids = torch.cat((input_ids, next_token), dim=-1)
        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)
        # Yield outside the no_grad block so a suspended generator does not
        # leave gradient tracking disabled for the caller.
        yield text
        if any(x in text for x in stop_words):
            return
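

# A minimal streaming sketch. Assumptions: `tok`, `mdl`, and `dev` come from
# load_tokenizer_and_model above, and the HUMAN turn marker doubles as the
# stop word (this repo does not confirm its stop-word list).
def _example_stream(tok, mdl, dev):
    input_ids = tok("Hello!", return_tensors="pt")["input_ids"].to(dev)
    for partial in decode(
        input_ids, mdl, tok, stop_words=[HUMAN], max_length=128
    ):
        print(partial)  # each yield is the full decoded continuation so far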


def get_prompt_with_history(text, history, tokenizer, max_length=2048):
    prompt = SYSTEM
    # Render each (user, assistant) turn, then append the new user message
    # with an open assistant tag for the model to complete.
    history = [f"\n{HUMAN} {x[0]}\n{AI} {x[1]}" for x in history]
    history.append(f"\n{HUMAN} {text}\n{AI}")
    history_text = ""
    flag = False
    # Walk the conversation from the most recent turn backwards, keeping as
    # many turns as still fit within the max_length token budget.
    for x in history[::-1]:
        if (
            tokenizer(prompt + history_text + x, return_tensors="pt")[
                "input_ids"
            ].size(-1)
            <= max_length
        ):
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(
            prompt + history_text, return_tensors="pt"
        )
    # Even the newest message alone exceeds the budget.
    return None
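

# A minimal usage sketch; the history pair below is invented for
# illustration.
def _example_prompt(tok):
    result = get_prompt_with_history(
        "And what about Python?",
        history=[("Hi!", "Hello! How can I help you?")],
        tokenizer=tok,
    )
    if result is not None:
        prompt, inputs = result  # inputs["input_ids"] is what decode() consumes
    return result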


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    # True if `s` ends with a stop word, or with a proper prefix of one, so
    # a streaming consumer can hold back text that may still grow into a
    # stop word.
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
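

# Quick illustration with invented values: "text\nHum" ends with a prefix of
# the stop word "\nHuman:", so a streaming UI should hold it back until more
# tokens arrive.
def _example_stop_check():
    assert is_stop_word_or_prefix("text\nHum", ["\nHuman:"])
    assert not is_stop_word_or_prefix("hello", ["\nHuman:"])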