# pygmalion7b-20230517 / handler.py
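"""Custom handler for serving Pygmalion-7B on Hugging Face Inference Endpoints.

Loads the model once at startup, then answers requests of the form
{"inputs": <prompt>, "parameters": {...}} with a sampled continuation,
optionally stopping early on caller-supplied stop words.
"""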
import torch
from typing import Any, Dict, List

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import StoppingCriteria, StoppingCriteriaList

# Use bfloat16 on GPUs with compute capability 8.0 or newer (Ampere onwards);
# older GPUs fall back to float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any stop-token sequence appears at the end of the output."""

    def __init__(self, stops=None):
        super().__init__()
        # Move the stop sequences to the GPU once so the per-step comparison
        # stays on-device (the handler assumes a CUDA device, as above).
        self.stops = [stop.to("cuda") for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            stop_len = len(stop)
            # Compare only once enough tokens exist, against the tail of the sequence.
            if input_ids.shape[1] >= stop_len and torch.all(stop == input_ids[:, -stop_len:]).item():
                return True
        return False
class EndpointHandler:
    def __init__(self, path=""):
        # Load tokenizer and model once at endpoint startup; device_map="auto"
        # places the weights on the available GPU(s).
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
        print("model loaded")
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        # Request payload: {"inputs": <prompt string>, "parameters": {...}}.
        prompt = data.pop("inputs", data)
        parameters = data.pop("parameters", None) or {}

        temperature = parameters.get("temperature", 0.8)
        top_p = parameters.get("top_p", 0.9)
        top_k = parameters.get("top_k", 0)
        max_new_tokens = parameters.get("max_new_tokens", 100)
        # This value feeds generate()'s repetition_penalty; accept the matching
        # key, falling back to the legacy "diversity_penalty" name.
        repetition_penalty = parameters.get(
            "repetition_penalty", parameters.get("diversity_penalty", 1.1)
        )
        max_length = parameters.get("max_length", 2048)
        stop_words = parameters.get("stop_words", [])
        num_return_sequences = parameters.get("num_return_sequences", 1)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,  # takes precedence over max_length
            max_length=max_length,
            eos_token_id=self.tokenizer.eos_token_id,
            # LLaMA-family tokenizers often lack a pad token; fall back to EOS
            # so generate() is not handed pad_token_id=None.
            pad_token_id=(
                self.tokenizer.pad_token_id
                if self.tokenizer.pad_token_id is not None
                else self.tokenizer.eos_token_id
            ),
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            do_sample=True,
        )
        # Tokenize the prompt, reserving room for the new tokens within max_length.
        input_tokens = self.tokenizer.encode(
            prompt,
            return_tensors="pt",
            max_length=max_length - max_new_tokens,
            truncation=True,
        ).to(self.model.device)

        # Decode the (possibly truncated) prompt so it can be stripped from the output.
        truncated_prompt = self.tokenizer.decode(input_tokens[0], skip_special_tokens=True)

        # Encode each stop word; add_special_tokens=False keeps BOS/EOS markers
        # out of the sequence so it can actually match the tail of the output.
        stop_words_ids = [
            self.tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
            for stop_word in stop_words
        ]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        # Every prompt token is real (no padding), so the mask is all ones.
        attention_mask = torch.ones_like(input_tokens)
        # Run the model.
        output = self.model.generate(
            input_tokens,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            attention_mask=attention_mask,
        )

        # Only return the part after the prompt.
        output_text = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0][len(truncated_prompt):]
        return [{"generated_text": output_text}]