File size: 3,057 Bytes
51e2020 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from typing import Iterator
from variables import SYSTEM, HUMAN, AI
def load_tokenizer_and_model(base_model, load_8bit=True):
    """Load a causal language model and its tokenizer from a hub id or path.

    Args:
        base_model: model identifier or local path passed to ``from_pretrained``.
        load_8bit: whether to load the model weights in 8-bit quantized form.

    Returns:
        ``(tokenizer, model, device)`` where ``device`` is ``"cuda"`` when a GPU
        is available, else ``"cpu"``. NOTE: the model is not moved to ``device``
        here; callers are expected to handle placement.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    # BUG FIX: the transformers kwarg is ``load_in_8bit``, not ``load_8bit``.
    # The old spelling was silently ignored, so 8-bit loading never happened.
    model = AutoModelForCausalLM.from_pretrained(base_model, load_in_8bit=load_8bit)
    return tokenizer, model, device
class State:
    """Process-wide flag used to signal that generation should be cut short."""

    # Class-level default; instances flip their own copy on first write.
    interrupted = False

    def recover(self):
        """Clear the interruption flag so generation may proceed again."""
        self.interrupted = False

    def interrupt(self):
        """Request that the current generation loop stop."""
        self.interrupted = True


# Single shared instance consulted by the generation code.
shared_state = State()
def decode(
    input_ids: torch.Tensor,
    model: "PeftModel",
    tokenizer: "AutoTokenizer",
    stop_words: list,
    max_length: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Iterator[str]:
    """Sample up to ``max_length`` new tokens, yielding the decoded text so far
    after every token.

    Uses the model's KV cache: the full ``input_ids`` are fed once, then only
    the newest token per step. Generation ends early once any entry of
    ``stop_words`` appears as a substring of the decoded text; the text
    containing the stop word is still yielded before returning (callers strip
    it, cf. ``is_stop_word_or_prefix``).

    Args:
        input_ids: prompt token ids, shape ``(1, seq_len)`` — assumes batch
            size 1 (``next_token[0]`` is read below).
        model: causal LM returning ``.logits`` and ``.past_key_values``.
        tokenizer: used only for ``tokenizer.decode``.
        stop_words: substrings that terminate generation.
        max_length: maximum number of NEW tokens to sample.
        temperature: softmax temperature (must be non-zero).
        top_p: nucleus-sampling threshold.

    Yields:
        The cumulative decoded generation after each sampled token.
    """
    generated_tokens = []
    past_key_values = None
    for _ in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                # Cache hit: only the most recent token needs a forward pass.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values
        # Temperature scaling — out-of-place on purpose: ``logits`` is a view
        # of ``outputs.logits`` and the original ``/=`` mutated the model's
        # output tensor in place.
        logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)
        # Nucleus (top-p) filtering: keep the smallest high-probability prefix
        # whose cumulative mass covers ``top_p``, zero out the rest, renormalize.
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # Map the sampled position back to the original vocabulary index.
        next_token = torch.gather(probs_idx, -1, next_token)
        input_ids = torch.cat((input_ids, next_token), dim=-1)
        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)
        yield text
        # Generator expression instead of a throwaway list inside any().
        if any(x in text for x in stop_words):
            return
def get_prompt_with_history(text, history, tokenizer, max_length=2048):
    """Build the longest prompt that fits in ``max_length`` tokens.

    Walks the conversation from newest to oldest turn, keeping each turn only
    while the tokenized candidate prompt still fits. The newest user message
    ``text`` is always tried first.

    Args:
        text: the new user message.
        history: list of ``(user, assistant)`` turn pairs, oldest first.
        tokenizer: callable returning ``{"input_ids": Tensor}``.
        max_length: token budget for the assembled prompt.

    Returns:
        ``(prompt_text, tokenized_prompt)`` on success, or ``None`` when even
        the newest message alone does not fit the budget.
    """
    prompt = SYSTEM
    turns = [f"\n{HUMAN} {user}\n{AI} {reply}" for user, reply in history]
    turns.append(f"\n{HUMAN} {text}\n{AI}")
    kept = ""
    fitted_any = False
    # Newest turn first; stop at the first turn that would overflow the budget.
    for turn in reversed(turns):
        candidate = prompt + kept + turn
        n_tokens = tokenizer(candidate, return_tensors="pt")["input_ids"].size(-1)
        if n_tokens > max_length:
            break
        kept = turn + kept
        fitted_any = True
    if not fitted_any:
        return None
    full_prompt = prompt + kept
    return full_prompt, tokenizer(full_prompt, return_tensors="pt")
def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    """Return True if ``s`` ends with a stop word or with any non-empty
    prefix of one.

    The prefix case matters for streaming: a stop word may be arriving one
    piece at a time, so a partial match at the end of ``s`` counts too.
    """
    for word in stop_words:
        # Prefix lengths 1..len(word); length len(word) is the full stop word.
        if any(s.endswith(word[:k]) for k in range(1, len(word) + 1)):
            return True
    return False