File size: 2,679 Bytes
d0dbf63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import logging

import joblib
import pandas as pd
from huggingface_hub import hf_hub_download
from sklearn.linear_model import LogisticRegression
from transformers import pipeline
class LogisticRegressionBuzzer:
    """Decide whether to buzz on a guess using a pretrained logistic-regression model.

    The model is a joblib pickle downloaded from the Hugging Face Hub; its input
    is the feature dict produced by ``BuzzerFeatures.get_features``.
    """

    # Class-level logger: diagnostics go here instead of stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        self.model = self.load_from_hf_pkl()
        self.features = BuzzerFeatures()

    def load_from_hf_pkl(
        self,
        repo_id: str = "nes470/pipeline-as-repo",
        filename: str = "logreg_buzzer_model.pkl",
    ) -> "LogisticRegression":
        """Download the buzzer model from the Hugging Face Hub and unpickle it.

        Args:
            repo_id: Hub repository that holds the model file.
            filename: Name of the joblib pickle inside the repository.

        Returns:
            The deserialized model object.
        """
        # SECURITY: joblib.load unpickles arbitrary objects, which can execute
        # code. Only load from a repository you trust (consider pinning a
        # ``revision`` in hf_hub_download).
        return joblib.load(hf_hub_download(repo_id=repo_id, filename=filename))

    def predict_buzz(self, question, guess):
        """Predict whether to buzz for ``guess`` given the ``question`` text.

        Args:
            question: Question text read so far.
            guess: Candidate answer string.

        Returns:
            ``(pred, value)`` where ``pred`` is what ``self.model.predict``
            returns for the single feature row and ``value`` is
            ``float(pred[0])``.
        """
        features = self.features.get_features(question, guess)
        # One-row frame. Keep feature keys/order stable: sklearn compares
        # DataFrame column names with those seen at fit time (if it was fit
        # on a DataFrame).
        X_formatted = pd.DataFrame(features, index=[0])
        pred = self.model.predict(X_formatted)
        # The probabilities are only used for diagnostics, so compute them only
        # when DEBUG logging is actually enabled.
        if self._logger.isEnabledFor(logging.DEBUG):
            prob_pred = self.model.predict_proba(X_formatted)
            self._logger.debug("buzz prediction=%s probabilities=%s", pred, prob_pred)
        # NOTE(review): ``value`` is the predicted label as a float, not a
        # confidence; if a probability was intended, take it from predict_proba.
        return (pred, float(pred[0]))
class BuzzerFeatures:
    """Compute the hand-crafted features fed to the buzzer model.

    NOTE: feature values, dict keys and key order form the model's input
    schema; changing any of them presumably requires retraining the model.
    """

    # One-hot columns emitted by guess_entity, in column order. "" means
    # "no entity found". Do not rename "" (e.g. to "None") without retraining:
    # it is a feature column name.
    _ENTITY_LABELS = ("", "I-LOC", "I-MISC", "I-ORG", "I-PER")

    def __init__(self) -> None:
        # NER pipeline; guess_entity reads the "entity" tag of its first result.
        self.ner = pipeline("ner")

    def get_features(self, question, guess):
        """Return the feature dict for one (question, guess) pair.

        Keys, in order: sentence_count, guess_word_count, guess_has_paren,
        guess_length, then the entity one-hot columns from guess_entity.
        """
        feats = {
            'sentence_count': self.sentence_count(question),
            'guess_word_count': self.guess_word_count(guess),
            'guess_has_paren': self.guess_has_paren(guess),
            'guess_length': self.guess_length(guess),
        }
        return feats | self.guess_entity(guess)

    def sentence_count(self, str):
        """Return the number of "."-separated segments in the text.

        A trailing period yields an extra empty segment ("A. B." -> 3); this is
        kept as-is so feature values stay consistent with the trained model.
        """
        return len(str.split("."))

    def guess_word_count(self, str):
        """Return the number of "_"-separated words in the guess."""
        return len(str.split("_"))

    def guess_has_paren(self, str):
        """Return 1 if the guess contains "(" or ")", else 0."""
        return int("(" in str or ")" in str)

    def guess_length(self, str):
        """Return the guess length in characters."""
        return len(str)

    def guess_entity(self, text):
        """One-hot encode the NER tag of the first entity found in ``text``.

        Returns a dict with one key per label in ``_ENTITY_LABELS`` (in that
        order); exactly one value is 1. "B-X" tags are mapped to "I-X" (same
        entity type); any other unrecognised tag, or no entity at all, maps to
        "". Previously such tags made this method return None, which crashed
        get_features.
        """
        entities = self.ner(text)
        label = entities[0]["entity"] if entities else ""
        if label.startswith("B-"):
            label = "I-" + label[2:]
        if label not in self._ENTITY_LABELS:
            label = ""
        return {name: int(name == label) for name in self._ENTITY_LABELS}
|