major-matcher / test.py
waidhoferj's picture
first commit
aadb779
from classifiers.mlp import MajorMlpClassifier
from embeddings.bert import BertSentenceEmbedder
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from classifiers.bert import BertClassifier
import pandas as pd
import numpy as np
from typing import Tuple
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report
from helper import load_data, get_recommendations, plot_confusion_matrix
import matplotlib.pyplot as plt
import os
device = "mps"
def evaluate(load_weights=False):
"""
Performs basic train/test split evaluation.
"""
os.makedirs("figures", exist_ok=True)
sentences, labels = load_data(num_majors=40)
embedder = BertSentenceEmbedder(device, padding_length=1000)
seed = 2
x_train, x_test, y_train, y_test = train_test_split(
sentences, labels, random_state=seed, shuffle=True, train_size=0.8
)
train_embeddings = embedder.transform(x_train)
test_embeddings = embedder.transform(x_test)
knn = KNeighborsClassifier()
mlp = MajorMlpClassifier(device)
bert_classifier = BertClassifier(
device=device,
epochs=25,
)
if load_weights:
mlp.load_weights("weights/major_classifier")
bert_classifier.load_weights("weights/bert_classifier_deployment_weights")
else:
bert_classifier.fit(x_train, y_train)
mlp.fit(train_embeddings, y_train)
knn.fit(train_embeddings, y_train)
class_labels = np.array(bert_classifier.labels)
def report(name, classifier, x, y, n=3):
probs = classifier.predict_proba(x)
ordered_choices = class_labels[(-probs).argsort(-1)[:, :n]]
preds = ordered_choices[:, 0]
print(name)
print(
f"Top {n} accuracy",
np.mean([label in choices for label, choices in zip(y, ordered_choices)]),
)
print(classification_report(y, preds))
plot_confusion_matrix(y, preds, class_labels)
plt.savefig(f"figures/{name}_cm.png")
plt.clf()
report("bert_classifier", bert_classifier, x_test, y_test)
report("KNN", knn, test_embeddings, y_test)
report("major_mlp", mlp, test_embeddings, y_test)
def demo():
"""
Interact with a model on the command line.
"""
bert_classifier = BertClassifier(device="mps")
weights_path = os.path.join("weights", "bert_classifier_deployment_weights")
bert_classifier.load_weights(weights_path)
while True:
command = input("Describe your ideal major: ")
if command.lower() == "q" or command.lower() == "quit":
break
probs = bert_classifier.predict_proba(command)
labels = bert_classifier.labels
print(get_recommendations(probs, labels, n=3)[0])
if __name__ == "__main__":
evaluate()