import json

from transformers import pipeline
# Load the model
def model_fn(model_dir):
    """Build the text-classification pipeline from the artifacts in *model_dir*.

    Args:
        model_dir: Location of the model, passed straight through as the
            ``model`` argument of ``transformers.pipeline``.

    Returns:
        The ``"text-classification"`` pipeline object created for that model.
    """
    return pipeline("text-classification", model=model_dir)
# Handle requests
def predict_fn(input_data, model):
    """Run *model* on the texts carried by a request payload.

    Args:
        input_data: The request payload. Either JSON text (``str``, ``bytes``
            or ``bytearray``) or an already-decoded mapping. In both cases
            the texts are read from its ``"inputs"`` key.
        model: Callable that receives the value of ``"inputs"`` and returns
            the predictions (presumably the pipeline built by ``model_fn`` --
            confirm against the hosting setup).

    Returns:
        Whatever ``model`` returns for the inputs, or an empty list when the
        payload has no ``"inputs"`` key or the list is empty.

    Raises:
        json.JSONDecodeError: If a text payload is not valid JSON.
    """
    # Some serving stacks deserialize the body before calling this hook, so
    # only decode when we were actually handed raw JSON text.
    if isinstance(input_data, (str, bytes, bytearray)):
        data = json.loads(input_data)
    else:
        data = input_data

    texts = data.get("inputs", [])

    # Nothing to classify: skip the model call entirely.
    if isinstance(texts, list) and not texts:
        return []

    return model(texts)