usagent100 committed (verified) · Commit d0d714b · Parent: 567a500

Update handler.py

Files changed (1): handler.py (+22 −17)
handler.py CHANGED
@@ -1,12 +1,10 @@
-from typing import Any, Dict, List
-
+from typing import Any, Dict
 import torch
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
+
 dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
-
-
+
 class EndpointHandler:
     def __init__(self, path=""):
         tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
@@ -18,20 +16,27 @@ class EndpointHandler:
             torch_dtype=dtype,
             trust_remote_code=True,
         )
-
-        generation_config = model.generation_config
-        generation_config.max_new_tokens = 1000
-        generation_config.temperature = 0
-        generation_config.num_return_sequences = 1
-        generation_config.pad_token_id = tokenizer.eos_token_id
-        generation_config.eos_token_id = tokenizer.eos_token_id
-        self.generation_config = generation_config
-
+
+        self.generation_config = model.generation_config
+        self.generation_config.max_new_tokens = 1000
+        self.generation_config.temperature = 0.7  # Changed from 0 to 0.7
+        self.generation_config.num_return_sequences = 1
+        self.generation_config.pad_token_id = tokenizer.eos_token_id
+        self.generation_config.eos_token_id = tokenizer.eos_token_id
+
         self.pipeline = transformers.pipeline(
             "text-generation", model=model, tokenizer=tokenizer
         )
-
+
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         prompt = data.pop("inputs", data)
-        result = self.pipeline(prompt, generation_config=self.generation_config)
-        return result
+        result = self.pipeline(
+            prompt,
+            max_length=1000,  # Added this line to set max_length
+            temperature=0.7,  # Added this line to set temperature
+            top_p=0.9,  # Added this line to set top_p
+            num_return_sequences=1,  # Added this line to set num_return_sequences
+            pad_token_id=self.generation_config.pad_token_id,
+            eos_token_id=self.generation_config.eos_token_id
+        )
+        return {"generated_text": result}
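
For context, a minimal local smoke test of the updated handler might look like the sketch below. This is not part of the commit: the model directory ./model and the prompt are placeholders, and it assumes a CUDA GPU is available, since the module-level torch.cuda.get_device_capability() call fails on CPU-only machines.

# Hypothetical smoke test for the updated EndpointHandler -- not part of the commit.
# Assumes a CUDA GPU and a local model snapshot at ./model; the path and the
# prompt below are placeholders, not values from the repository.
from handler import EndpointHandler

handler = EndpointHandler(path="./model")

# Inference Endpoints delivers the request body as a dict; "inputs" holds the prompt.
payload = {"inputs": "Write a haiku about text generation."}
output = handler(payload)

# After this commit the handler returns {"generated_text": <pipeline result>}.
print(output["generated_text"])

One caveat when running this: in transformers, temperature and top_p only influence generation when sampling is enabled, so depending on the model's default generation_config, do_sample=True may also need to be passed to the pipeline call. Note as well that max_length counts the prompt plus generated tokens, whereas the max_new_tokens=1000 set on the generation config counts only newly generated tokens.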