import torch
from transformers import AutoModel, AutoTokenizer
class ModelHandler:
    """Lazily-initialized wrapper around a Hugging Face model for serving.

    Lifecycle: ``initialize()`` loads the artifacts once, then each request
    flows through ``preprocess -> inference -> postprocess``.
    """

    def __init__(self):
        # Populated by initialize(); both stay None until the first request.
        self.model = None
        self.tokenizer = None

    def initialize(self, model_path):
        """Load the tokenizer and model from *model_path*.

        Args:
            model_path: Directory or hub id accepted by ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)
        # Serving is inference-only: switch off dropout / batch-norm updates
        # so results are deterministic.
        self.model.eval()

    def preprocess(self, data):
        """Tokenize the request payload.

        Args:
            data: Mapping with a ``"text"`` key; a missing key falls back
                to the empty string.

        Returns:
            Tokenizer output (PyTorch tensors) ready to feed the model.
        """
        text = data.get("text", "")
        return self.tokenizer(text, return_tensors="pt")

    def inference(self, inputs):
        """Run a forward pass without building the autograd graph."""
        # no_grad avoids allocating gradient buffers on the hot serving path.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs

    def postprocess(self, outputs):
        """Convert the model output into a JSON-serializable dict.

        NOTE(review): a plain ``AutoModel`` output exposes
        ``last_hidden_state``, not ``logits`` -- this assumes a task head
        (e.g. ``*ForSequenceClassification``) was saved at the model path.
        Confirm against the deployed artifacts.
        """
        return {"output": outputs.logits.tolist()}
# Module-level singleton so the loaded model is reused across requests.
_handler = ModelHandler()


def handle(data, context):
    """Entry point invoked by the serving framework for each request.

    Args:
        data: Batch of request payloads (list of dicts with a ``"text"``
            key); only the first element is served.
        context: Serving context whose ``system_properties["model_dir"]``
            locates the model artifacts.

    Returns:
        The postprocessed model output, or an error dict for empty input.
    """
    # Lazy one-time initialization on the first request.
    if _handler.model is None:
        model_path = context.system_properties.get("model_dir")
        _handler.initialize(model_path)
    # `not data` covers both None and an empty batch -- the original only
    # guarded None, so data[0] raised IndexError on an empty list.
    if not data:
        return {"error": "No input data"}
    inputs = _handler.preprocess(data[0])
    outputs = _handler.inference(inputs)
    return _handler.postprocess(outputs)
|