NiCEtmtm committed on
Commit b166186 · verified · 1 Parent(s): 2621401

Upload handler.py

Files changed (1)
  handler.py +40 -0
handler.py ADDED
@@ -0,0 +1,40 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import subprocess
import sys


# Install bitsandbytes and accelerate at runtime if they are not already available
try:
    import bitsandbytes
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "bitsandbytes==0.39.1"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "accelerate==0.20.0"])


class ModelHandler:
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load_model(self):
        # Read the Hugging Face access token from the environment
        model_id = "NiCETmtm/Llama3_kw_gen_new"
        token = os.getenv("HF_API_TOKEN")
        # Load model & tokenizer; the checkpoint is a PyTorch model, so from_tf=True is dropped
        self.model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)

    def predict(self, inputs):
        # Tokenize the input and generate a completion without tracking gradients
        tokens = self.tokenizer(inputs, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(**tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


# Instantiate and load the model once at import time so repeated invocations reuse it
model_handler = ModelHandler()
model_handler.load_model()


def inference(event, context):
    inputs = event["data"]
    outputs = model_handler.predict(inputs)
    return {"predictions": outputs}