import torch
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import StoppingCriteria, StoppingCriteriaList

# Pick the compute dtype: bfloat16 on Ampere-or-newer GPUs (compute capability >= 8), float16 otherwise
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

class StoppingCriteriaSub(StoppingCriteria):
  """Stops generation once any of the supplied stop-word token sequences ends the output."""
  def __init__(self, stops=[], encounters=1):
    super().__init__()
    self.stops = stops

  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
    for stop in self.stops:
      stop_len = len(stop)
      # Stop if the last stop_len generated tokens match this stop sequence
      if input_ids.shape[1] >= stop_len:
        if torch.all(stop.to(input_ids.device) == input_ids[:, -stop_len:]).item():
          return True
    # No stop sequence matched: keep generating
    return False

class EndpointHandler:
  def __init__(self, path=""):
    # Load the tokenizer and model from the endpoint repository path
    self.tokenizer = AutoTokenizer.from_pretrained(path)
    self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
    print("model loaded")

  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    inputs = data.pop("inputs", data)
    parameters = data.pop("parameters", None)
    if parameters is None:
        parameters = {}        

    prompt = inputs
    temperature = parameters.get("temperature", 0.8)
    top_p = parameters.get("top_p", 0.9)
    top_k = parameters.get("top_k", 0)
    max_new_tokens = parameters.get("max_new_tokens", 100)
    repetition_penalty = parameters.get("repetition_penalty", 1.1)
    max_length = parameters.get("max_length", 2048)
    stop_words = parameters.get("stop_words", [])
    num_return_sequences = parameters.get("num_return_sequences", 1)

    generation_config = GenerationConfig(
      temperature=temperature,
      top_p=top_p,
      top_k=top_k,
      max_new_tokens=max_new_tokens,
      eos_token_id=self.tokenizer.eos_token_id,
      # Fall back to EOS when the tokenizer defines no pad token (common for causal LMs)
      pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
      repetition_penalty=repetition_penalty,
      num_return_sequences=num_return_sequences,
      do_sample=True,
    )
    
    # Tokenize the prompt, truncating so that prompt + new tokens fit within max_length
    input_tokens = self.tokenizer.encode(
      prompt,
      return_tensors="pt",
      max_length=max_length - max_new_tokens,
      truncation=True,
    ).to(self.model.device)

    # Decode the (possibly truncated) prompt so it can be stripped from the output later
    truncated_prompt = self.tokenizer.decode(input_tokens.squeeze(), skip_special_tokens=True)

    
    # Encode stop words without special tokens so the raw sequences can be matched against generated ids
    stop_words_ids = [
      self.tokenizer(stop_word, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze(0)
      for stop_word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    
    # Create attention mask
    attention_mask = torch.ones_like(input_tokens).to(self.model.device)

    # Run the model
    output = self.model.generate(input_tokens, 
              generation_config=generation_config, 
              stopping_criteria=stopping_criteria,
              attention_mask=attention_mask,
          )
    # Return only the text generated after the (truncated) prompt
    output_text = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0][len(truncated_prompt):]

    return [{"generated_text": output_text}]