Update handler.py
Browse files — handler.py (+4, −4)
handler.py
CHANGED
@@ -59,15 +59,15 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
59 |
# return response
|
60 |
|
61 |
class EndpointHandler:
|
62 |
-
def __init__(self):
|
63 |
# Load processor and model
|
64 |
self.PROCESSOR = AutoProcessor.from_pretrained(
|
65 |
-
|
66 |
trust_remote_code=True,
|
67 |
# token=API_TOKEN,
|
68 |
)
|
69 |
self.MODEL = AutoModelForCausalLM.from_pretrained(
|
70 |
-
|
71 |
# token=API_TOKEN,
|
72 |
trust_remote_code=True,
|
73 |
torch_dtype=torch.bfloat16,
|
@@ -99,7 +99,7 @@ class EndpointHandler:
|
|
99 |
# inputs = preprocess(model_inputs)
|
100 |
generated_ids = self.MODEL.generate(**inputs, bad_words_ids=self.BAD_WORDS_IDS, max_length=4096)
|
101 |
generated_text = self.PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
102 |
-
return {"
|
103 |
# return {"text":prediction[0]}
|
104 |
|
105 |
# @classmethod
|
|
|
59 |
# return response
|
60 |
|
61 |
class EndpointHandler:
|
62 |
+
def __init__(self,model_path:str):
|
63 |
# Load processor and model
|
64 |
self.PROCESSOR = AutoProcessor.from_pretrained(
|
65 |
+
model_path,
|
66 |
trust_remote_code=True,
|
67 |
# token=API_TOKEN,
|
68 |
)
|
69 |
self.MODEL = AutoModelForCausalLM.from_pretrained(
|
70 |
+
model_path,
|
71 |
# token=API_TOKEN,
|
72 |
trust_remote_code=True,
|
73 |
torch_dtype=torch.bfloat16,
|
|
|
99 |
# inputs = preprocess(model_inputs)
|
100 |
generated_ids = self.MODEL.generate(**inputs, bad_words_ids=self.BAD_WORDS_IDS, max_length=4096)
|
101 |
generated_text = self.PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
102 |
+
return {"text": generated_text}
|
103 |
# return {"text":prediction[0]}
|
104 |
|
105 |
# @classmethod
|