|
import torch
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
|
class ModelHandler:
    """Minimal serving handler wrapping a Hugging Face model.

    Lifecycle: ``initialize`` loads the weights once, then each request
    flows through ``preprocess`` -> ``inference`` -> ``postprocess``.
    """

    def __init__(self):
        # Populated by initialize(); None signals "not yet loaded".
        self.model = None
        self.tokenizer = None

    def initialize(self, model_path):
        """Load model and tokenizer.

        Parameters
        ----------
        model_path : str
            Directory or hub id accepted by ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)
        # Serving is inference-only: switch off dropout / batch-norm updates.
        self.model.eval()

    def preprocess(self, data):
        """Preprocess input data into model-ready tensors.

        Assumes *data* is a dict carrying the request text under a "text"
        key — TODO confirm the upstream request schema; a missing key
        falls back to the empty string.
        """
        text = data.get("text", "")
        inputs = self.tokenizer(text, return_tensors="pt")
        return inputs

    def inference(self, inputs):
        """Run a forward pass with gradient tracking disabled.

        torch.no_grad() avoids building the autograd graph, which is
        wasted memory and time for pure inference.
        """
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs

    def postprocess(self, outputs):
        """Convert model output to a JSON-serializable dict.

        The base ``AutoModel`` output exposes ``last_hidden_state`` and has
        no ``logits`` attribute (only task-specific heads do), so the
        original unconditional ``outputs.logits`` raised AttributeError.
        Prefer ``logits`` when present, fall back to ``last_hidden_state``.
        """
        tensor = getattr(outputs, "logits", None)
        if tensor is None:
            tensor = outputs.last_hidden_state
        return {"output": tensor.tolist()}
|
|
|
|
# Module-level singleton shared across requests; its model/tokenizer are
# loaded lazily on the first call to handle().
_handler = ModelHandler()
|
|
|
|
def handle(data, context):
    """Serving entry point: lazily initialize, then process one request.

    Parameters
    ----------
    data : list | None
        Batch of request payload dicts; only the first element is used.
    context : object
        Serving context; ``context.system_properties["model_dir"]`` is
        read to locate the model artifacts.

    Returns
    -------
    dict
        ``{"output": ...}`` on success, ``{"error": ...}`` for empty input.
    """
    # Lazy one-time initialization on the first request. Explicit `is None`
    # check instead of truthiness on a model object.
    if _handler.model is None:
        model_path = context.system_properties.get("model_dir")
        _handler.initialize(model_path)

    # Guard against both None and an empty batch — the original only
    # checked `data is None` and then indexed data[0] unconditionally,
    # raising IndexError on [].
    if not data:
        return {"error": "No input data"}

    inputs = _handler.preprocess(data[0])
    outputs = _handler.inference(inputs)
    return _handler.postprocess(outputs)
|
|
|