jeffreykthomas committed
Commit 96f6ccb · Parent: 2ad1387

Refactored handler

Files changed (1):
  1. handler.py +10 -18
handler.py CHANGED
@@ -1,35 +1,27 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers import GenerationConfig
+from transformers import GenerationConfig, pipeline
 import torch
 from typing import Any, Dict
 
+dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
+
 
 class EndpointHandler:
     def __init__(self, path=""):
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(path,
-                                                          torch_dtype=torch.float16,
-                                                          device_map="auto")
+        tokenizer = AutoTokenizer.from_pretrained(path)
+        model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
     def __call__(self, data: Dict[str, Any]) -> [str]:
-        input_text = data.pop("inputs", data)
+        inputs = data.pop("inputs", data)
         generation_config = GenerationConfig(
+            max_length=1024,
             max_new_tokens=250, do_sample=True, top_k=50,
-            eos_token_id=self.model.config.eos_token_id,
             temperature=0.8, pad_token_id=2, num_return_sequences=1,
             min_new_tokens=30, repetition_penalty=1.2,
         )
-
-        self.model.generation_config = generation_config
-        inputs = self.tokenizer(input_text, return_tensors="pt")
-        inputs = {key: val.to(self.device) for key, val in inputs.items()}
-        outputs = self.model.generate(**inputs)
-
-        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        # remove the inputs from outputs
-        decoded_output = decoded_output.replace(input_text + ' Expert: ', '')
+        output = self.pipeline(inputs, **generation_config.to_dict())
 
-        return [{'generated_text': decoded_output}]
+        return output
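
One caveat in the new version: the module-level `dtype` line runs `torch.cuda.get_device_capability()` at import time, so loading handler.py on a host without a usable GPU fails outright, and the `== 8` check treats only compute capability 8.x (Ampere) as bfloat16-capable, excluding newer generations. A minimal sketch of a more defensive selection with the same intent (this guard is not part of the commit):

import torch

# Fall back to float32 on CPU-only hosts; treat any compute
# capability >= 8 (Ampere or newer) as bfloat16-capable.
if torch.cuda.is_available():
    major, _ = torch.cuda.get_device_capability()
    dtype = torch.bfloat16 if major >= 8 else torch.float16
else:
    dtype = torch.float32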
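
For checking the refactor locally, a rough sketch of calling the handler the way the endpoint runtime would, with a hypothetical `./model` path and a prompt shaped like the `Expert:` format the removed post-processing assumed:

from handler import EndpointHandler

# "./model" is a placeholder for a local snapshot of the model repo
handler = EndpointHandler(path="./model")

# the endpoint delivers the parsed request body as a dict;
# __call__ pops the prompt out of its "inputs" key
result = handler({"inputs": "Patient: I've been feeling anxious lately. Expert:"})
print(result)  # e.g. [{'generated_text': '...'}]

Note that the pipeline's default output keeps the prompt inside `generated_text`, and the old manual `replace(...)` stripping is gone, so callers now receive prompt plus continuation unless something like `return_full_text=False` is forwarded to the pipeline.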