Syko commited on
Commit
b56fdd4
·
verified ·
1 Parent(s): e0a3333

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +23 -22
handler.py CHANGED
@@ -1,28 +1,29 @@
1
- import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
 
4
- # Load model and tokenizer
5
- model_name = "Syko/SykoNaught-v1"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
8
 
9
- def handle(inputs):
10
- """
11
- Handle incoming inference requests.
12
- """
13
- input_text = inputs.get("inputs", "")
14
- max_new_tokens = inputs.get("parameters", {}).get("max_new_tokens", 50)
15
- temperature = inputs.get("parameters", {}).get("temperature", 0.7)
16
 
17
- # Tokenize input
18
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids
19
 
20
- # Generate output
21
- output = model.generate(
22
- input_ids,
23
- max_new_tokens=max_new_tokens,
24
- temperature=temperature,
25
- )
26
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
27
 
28
- return {"generated_text": output_text}
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
 
4
+ class EndpointHandler:
5
+ def __init__(self, path):
6
+ # Load tokenizer and model
7
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
8
+ self.model = AutoModelForCausalLM.from_pretrained(path)
9
 
10
+ def __call__(self, inputs):
11
+ # Parse input
12
+ input_text = inputs.get("inputs", "")
13
+ parameters = inputs.get("parameters", {})
14
+ max_new_tokens = parameters.get("max_new_tokens", 50)
15
+ temperature = parameters.get("temperature", 0.7)
 
16
 
17
+ # Tokenize input
18
+ input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
19
 
20
+ # Generate output
21
+ output = self.model.generate(
22
+ input_ids,
23
+ max_new_tokens=max_new_tokens,
24
+ temperature=temperature,
25
+ )
 
26
 
27
+ # Decode output
28
+ output_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
29
+ return {"generated_text": output_text}