sklearn-transformers / pipeline.py
merve's picture
merve HF Staff
Update pipeline.py
78460d0
raw
history blame
572 Bytes
import json
from typing import Any, Dict, List
import sklearn
import os
import joblib
import numpy as np
class PreTrainedPipeline:
    """Text-classification wrapper around a serialized scikit-learn pipeline.

    The loaded object is used as a model exposing ``predict_proba`` that
    accepts a batch of raw strings (NOTE(review): presumably a vectorizer +
    classifier pipeline -- confirm against the saved artifact).
    """

    def __init__(self, path: str):
        """Load ``pipeline.pkl`` from the directory *path*.

        Args:
            path: Directory containing the serialized ``pipeline.pkl`` file.
        """
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code. Only load artifacts from a trusted source.
        self.model = joblib.load(os.path.join(path, "pipeline.pkl"))

    def __call__(self, inputs: str) -> List[Dict[str, Any]]:
        """Return one ``{"label", "score"}`` entry per class for *inputs*.

        Args:
            inputs: A single text to classify.

        Returns:
            A list ordered by probability column index. Each entry has
            ``"label"`` set to ``"LABEL_<index>"`` and ``"score"`` set to
            that column's predicted probability as a plain Python ``float``.
        """
        # predict_proba takes a batch, so wrap the single input; row 0 of the
        # result holds this input's probability for each class column.
        probabilities = self.model.predict_proba([inputs])[0]
        # enumerate supplies the class column index; iterating the row
        # directly would yield probability values, not usable indices.
        # float() turns NumPy scalars into JSON-serializable Python floats.
        return [
            {"label": f"LABEL_{index}", "score": float(score)}
            for index, score in enumerate(probabilities)
        ]