Spaces:
Runtime error
Runtime error
from classifiers.bert import BertClassifier | |
from classifiers.mlp import MajorMlpClassifier | |
from embeddings.bert import BertSentenceEmbedder | |
import pickle | |
from helper import load_data | |
def train_bert_classifier(
    device="cpu",
    n_classes=40,
    include_majors=None,
    epochs=25,
):
    """Train a ``BertClassifier`` on the data returned by ``load_data``.

    Args:
        device: Device string forwarded to ``BertClassifier``.
        n_classes: Forwarded to ``load_data`` as ``num_majors``.
        include_majors: Forwarded to ``load_data``; ``None`` (the default)
            means an empty list.
        epochs: Number of training epochs forwarded to ``BertClassifier``.

    Returns:
        The fitted ``BertClassifier``, so callers can evaluate or persist it
        (previously the trained model was discarded).
    """
    # Build a fresh list per call instead of sharing one ``[]`` default
    # object across every invocation of this function.
    if include_majors is None:
        include_majors = []
    sentences, labels = load_data(num_majors=n_classes, include_majors=include_majors)
    bert_classifier = BertClassifier(device=device, epochs=epochs)
    bert_classifier.fit(sentences, labels)
    return bert_classifier
def train_major_classifier(
    device="cpu",
    n_classes=40,
    include_majors=None,
    epochs=200,
    weights_path="weights/major_classifier",
    padding_length=1000,
):
    """Embed sentences with BERT, train an MLP on them, and save its weights.

    Args:
        device: Device string passed positionally to both
            ``BertSentenceEmbedder`` and ``MajorMlpClassifier``.
        n_classes: Forwarded to ``load_data`` as ``num_majors``.
        include_majors: Forwarded to ``load_data``; ``None`` (the default)
            means an empty list.
        epochs: Number of training epochs forwarded to ``MajorMlpClassifier``.
        weights_path: Location passed to ``MajorMlpClassifier.save_weights``.
            Defaults to the previously hard-coded ``"weights/major_classifier"``.
        padding_length: ``padding_length`` forwarded to
            ``BertSentenceEmbedder``. Defaults to the previously hard-coded 1000.

    Returns:
        The fitted ``MajorMlpClassifier``.
    """
    # Build a fresh list per call instead of sharing one ``[]`` default
    # object across every invocation of this function.
    if include_majors is None:
        include_majors = []
    sentences, labels = load_data(num_majors=n_classes, include_majors=include_majors)

    # Sentences are turned into fixed inputs by the BERT embedder; the MLP is
    # then trained on those embeddings rather than on raw text.
    embedder = BertSentenceEmbedder(device, padding_length=padding_length)
    embeddings = embedder.transform(sentences)

    mlp = MajorMlpClassifier(device, epochs=epochs)
    mlp.fit(embeddings, labels)
    mlp.save_weights(weights_path)
    return mlp
def _default_device():
    """Return the preferred available PyTorch device: "mps", then "cuda", else "cpu".

    The script previously hard-coded "mps", which only works with PyTorch on
    Apple silicon and fails on Linux/CUDA or CPU-only hosts.
    """
    import torch

    # ``torch.backends.mps`` is absent on older PyTorch releases, so probe it
    # defensively instead of assuming the attribute exists.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


if __name__ == "__main__":
    train_major_classifier(
        device=_default_device(),
        include_majors=["Computer Science", "Computer Engineering"],
    )