Syko committed
Commit 1ce4cfe · verified · 1 Parent(s): b56fdd4

Upload handler.py

Files changed (1):
  handler.py +21 -23
handler.py CHANGED
@@ -1,29 +1,27 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from typing import Dict, List, Any
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
+# check for GPU: transformers pipelines take device=0 for the first GPU, -1 for CPU
+device = 0 if torch.cuda.is_available() else -1
 
-class EndpointHandler:
-    def __init__(self, path):
-        # Load tokenizer and model
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(path)
-
-    def __call__(self, inputs):
-        # Parse input
-        input_text = inputs.get("inputs", "")
-        parameters = inputs.get("parameters", {})
-        max_new_tokens = parameters.get("max_new_tokens", 50)
-        temperature = parameters.get("temperature", 0.7)
-
-        # Tokenize input
-        input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
-
-        # Generate output
-        output = self.model.generate(
-            input_ids,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-        )
-
-        # Decode output
-        output_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
-        return {"generated_text": output_text}
+class EndpointHandler:
+    def __init__(self, path=""):
+        # load the tokenizer and model from the endpoint's model directory
+        tokenizer = AutoTokenizer.from_pretrained(path)
+        model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
+        # create an inference pipeline on the selected device
+        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        inputs = data.pop("inputs", data)
+        parameters = data.pop("parameters", None)
+
+        # forward the inputs, with any generation kwargs, to the pipeline
+        if parameters is not None:
+            prediction = self.pipeline(inputs, **parameters)
+        else:
+            prediction = self.pipeline(inputs)
+        # the pipeline already post-processes into [{"generated_text": ...}]
+        return prediction
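
For context, a minimal local smoke test of the new handler might look like the sketch below. The "./model" path, the prompt, and the payload contents are illustrative assumptions, not part of the commit; on Inference Endpoints the toolkit instantiates EndpointHandler with the repository path and passes the parsed JSON request body.

# smoke_test.py -- a sketch, assuming the model files live in ./model
from handler import EndpointHandler

handler = EndpointHandler(path="./model")  # hypothetical local model directory

# Inference Endpoints sends the parsed JSON body in this shape
payload = {
    "inputs": "Once upon a time",
    "parameters": {"max_new_tokens": 50, "temperature": 0.7, "do_sample": True},
}

prediction = handler(payload)
# a text-generation pipeline returns a list of dicts, e.g.
# [{"generated_text": "Once upon a time ..."}]
print(prediction)

The switch from manual tokenize/generate/decode to a text-generation pipeline moves tokenization, decoding, and generation-parameter handling into transformers, so the handler only has to route the request payload.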