# Spaces: Runtime error
from classifiers.mlp import MajorMlpClassifier | |
from embeddings.bert import BertSentenceEmbedder | |
from sklearn.neighbors import KNeighborsClassifier | |
from sklearn.neural_network import MLPClassifier | |
from classifiers.bert import BertClassifier | |
import pandas as pd | |
import numpy as np | |
from typing import Tuple | |
from sklearn.model_selection import train_test_split, cross_val_score | |
from sklearn.metrics import classification_report | |
from helper import load_data, get_recommendations, plot_confusion_matrix | |
import matplotlib.pyplot as plt | |
import os | |
# Device identifier handed to the sentence embedder and to the classifiers
# built in this module.
# NOTE(review): "mps" is presumably PyTorch's Apple-Metal backend name --
# confirm what the project classes accept before running on other hardware.
device = "mps"
def evaluate(load_weights=False):
    """
    Performs basic train/test split evaluation.

    Loads the data via ``load_data(num_majors=40)``, makes a shuffled 80/20
    split with a fixed seed, embeds both splits with ``BertSentenceEmbedder``
    and then scores three classifiers on the test split:

    * ``BertClassifier`` on the raw sentences,
    * ``KNeighborsClassifier`` on the sentence embeddings,
    * ``MajorMlpClassifier`` on the sentence embeddings.

    For each classifier the top-n accuracy and sklearn's classification
    report are printed, and a confusion matrix is saved to
    ``figures/<name>_cm.png``.

    Args:
        load_weights: When true, the MLP and BERT classifiers load saved
            weights from ``weights/`` instead of being fitted. The KNN
            model has no saved weights and is fitted in both modes.
    """
    os.makedirs("figures", exist_ok=True)

    sentences, labels = load_data(num_majors=40)
    embedder = BertSentenceEmbedder(device, padding_length=1000)

    seed = 2
    x_train, x_test, y_train, y_test = train_test_split(
        sentences, labels, random_state=seed, shuffle=True, train_size=0.8
    )
    train_embeddings = embedder.transform(x_train)
    test_embeddings = embedder.transform(x_test)

    knn = KNeighborsClassifier()
    mlp = MajorMlpClassifier(device)
    bert_classifier = BertClassifier(
        device=device,
        epochs=25,
    )

    if load_weights:
        mlp.load_weights("weights/major_classifier")
        bert_classifier.load_weights("weights/bert_classifier_deployment_weights")
    else:
        bert_classifier.fit(x_train, y_train)
        mlp.fit(train_embeddings, y_train)
    # KNN is fitted unconditionally: nothing is loaded for it above, and
    # report() calls predict_proba on it, which requires a fitted estimator.
    knn.fit(train_embeddings, y_train)

    class_labels = np.array(bert_classifier.labels)

    def report(name, classifier, x, y, n=3, labels=None):
        """
        Print top-n accuracy and a classification report for ``classifier``
        and save its confusion matrix under ``figures/``.

        ``labels`` names the columns of ``classifier.predict_proba``; when
        omitted, the BERT classifier's label order is used.
        """
        column_labels = class_labels if labels is None else np.asarray(labels)
        probs = classifier.predict_proba(x)
        # Sort columns by descending probability and keep the n best per row.
        ordered_choices = column_labels[(-probs).argsort(-1)[:, :n]]
        preds = ordered_choices[:, 0]
        top_n_hits = [label in choices for label, choices in zip(y, ordered_choices)]

        print(name)
        print(f"Top {n} accuracy", np.mean(top_n_hits))
        print(classification_report(y, preds))

        plot_confusion_matrix(y, preds, class_labels)
        plt.savefig(f"figures/{name}_cm.png")
        # Clear the current figure so the next plot starts from a blank one.
        plt.clf()

    report("bert_classifier", bert_classifier, x_test, y_test)
    # sklearn orders predict_proba columns by the estimator's own classes_,
    # which need not match bert_classifier.labels.
    report("KNN", knn, test_embeddings, y_test, labels=knn.classes_)
    # NOTE(review): assumes the MLP's probability columns follow
    # bert_classifier.labels -- confirm against MajorMlpClassifier.
    report("major_mlp", mlp, test_embeddings, y_test)
def demo():
    """
    Interact with a model on the command line.

    Builds a ``BertClassifier`` on the module-level ``device``, loads the
    weights stored at ``weights/bert_classifier_deployment_weights`` and
    then repeatedly prompts for a description, printing the first entry of
    ``get_recommendations(probs, labels, n=3)`` for each one.

    The loop ends when the user types ``q`` or ``quit`` (case-insensitive,
    surrounding whitespace ignored) or when standard input is closed.
    Blank input is skipped and the prompt is shown again.
    """
    # Use the shared module-level device so demo() and evaluate() agree.
    bert_classifier = BertClassifier(device=device)
    weights_path = os.path.join("weights", "bert_classifier_deployment_weights")
    bert_classifier.load_weights(weights_path)
    labels = bert_classifier.labels

    while True:
        try:
            command = input("Describe your ideal major: ").strip()
        except EOFError:
            # stdin was closed (Ctrl-D or piped input ran out): treat as quit.
            break
        if command.lower() in ("q", "quit"):
            break
        if not command:
            # Nothing to classify; prompt again.
            continue
        probs = bert_classifier.predict_proba(command)
        print(get_recommendations(probs, labels, n=3)[0])
# Script entry point: running this file directly performs the train/test
# evaluation with freshly fitted models (load_weights defaults to False).
if __name__ == "__main__":
    evaluate()