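# sklearn-transformers / pipeline.py
# Author: merve (HF Staff)
# Last commit: d9266df — "Update pipeline.py" (610 bytes)
# Dependency: whatlies — install with `pip install whatlies` (the notebook-style
# `!pip install whatlies` line is not valid syntax in a plain Python module).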
import json
from typing import Any, Dict, List
import sklearn
import os
import joblib
import numpy as np
import whatlies
class PreTrainedPipeline:
    """Wrap a serialized scikit-learn pipeline for single-text classification.

    The pipeline is loaded from ``pipeline.pkl`` inside a given directory and
    must expose ``predict_proba``.
    """

    def __init__(self, path: str):
        """Load the pipeline stored at ``<path>/pipeline.pkl``.

        Args:
            path: Directory that contains ``pipeline.pkl``.
        """
        # NOTE(review): joblib.load unpickles arbitrary objects and can execute
        # code while doing so; only load pipelines from a trusted source.
        self.model = joblib.load(os.path.join(path, "pipeline.pkl"))

    def __call__(self, inputs: str) -> List[Dict[str, Any]]:
        """Return the probability of every class for a single text.

        Args:
            inputs: The text to classify.

        Returns:
            One ``{"label": "LABEL_<i>", "score": p}`` dict per column of the
            ``predict_proba`` output, where ``i`` is the column index and ``p``
            is that column's probability as a plain Python ``float``.
        """
        # predict_proba works on batches: wrap the single text, keep row 0.
        probabilities = self.model.predict_proba([inputs])[0]
        # The row holds probabilities, not class ids, so the *position* of each
        # value identifies the class; float() makes the score JSON-friendly for
        # any NumPy float dtype.
        return [
            {"label": f"LABEL_{index}", "score": float(score)}
            for index, score in enumerate(probabilities)
        ]