File size: 1,053 Bytes
aadb779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from classifiers.bert import BertClassifier
from classifiers.mlp import MajorMlpClassifier
from embeddings.bert import BertSentenceEmbedder
import pickle
from helper import load_data


def train_bert_classifier(
    device="cpu",
    n_classes=40,
    include_majors=None,
    epochs=25
):
    """Fit a ``BertClassifier`` on the data returned by ``load_data``.

    Args:
        device: Device identifier passed to ``BertClassifier`` as ``device``.
        n_classes: Passed to ``load_data`` as ``num_majors``.
        include_majors: Passed to ``load_data`` as ``include_majors``.
            ``None`` (the default) is replaced by a new empty list.
        epochs: Passed to ``BertClassifier`` as ``epochs``.

    Returns:
        The ``BertClassifier`` instance after ``fit`` has been called on it.
    """
    # Build the empty list per call. A literal ``[]`` default is a single
    # object shared by every call, so any mutation inside load_data would
    # leak into later calls.
    majors = [] if include_majors is None else include_majors
    sentences, labels = load_data(num_majors=n_classes, include_majors=majors)
    bert_classifier = BertClassifier(device=device, epochs=epochs)
    bert_classifier.fit(sentences, labels)
    # Hand the model back to the caller; it was previously discarded as soon
    # as the function returned.
    return bert_classifier


def train_major_classifier(
    device="cpu",
    n_classes=40,
    include_majors=None,
    epochs=200,
    *,
    padding_length=1000,
    weights_path="weights/major_classifier"
):
    """Embed the loaded sentences, fit a ``MajorMlpClassifier``, save its weights.

    Steps: load ``(sentences, labels)`` via ``load_data``, transform the
    sentences with ``BertSentenceEmbedder``, fit ``MajorMlpClassifier`` on the
    resulting embeddings and labels, then call ``save_weights``.

    Args:
        device: Device identifier passed positionally to both
            ``BertSentenceEmbedder`` and ``MajorMlpClassifier``.
        n_classes: Passed to ``load_data`` as ``num_majors``.
        include_majors: Passed to ``load_data`` as ``include_majors``.
            ``None`` (the default) is replaced by a new empty list.
        epochs: Passed to ``MajorMlpClassifier`` as ``epochs``.
        padding_length: Keyword-only. Passed to ``BertSentenceEmbedder`` as
            ``padding_length``; defaults to the previously hard-coded 1000.
        weights_path: Keyword-only. Argument given to ``mlp.save_weights``;
            defaults to the previously hard-coded path.

    Returns:
        The ``MajorMlpClassifier`` instance after fitting and saving.
    """
    # Build the empty list per call. A literal ``[]`` default is a single
    # object shared by every call, so any mutation inside load_data would
    # leak into later calls.
    majors = [] if include_majors is None else include_majors
    sentences, labels = load_data(num_majors=n_classes, include_majors=majors)
    embedder = BertSentenceEmbedder(device, padding_length=padding_length)
    embeddings = embedder.transform(sentences)
    mlp = MajorMlpClassifier(device, epochs=epochs)
    mlp.fit(embeddings, labels)
    mlp.save_weights(weights_path)
    # Returning the model lets callers evaluate it without reloading weights.
    return mlp


if __name__ == "__main__":
    # Script entry point: train the MLP major classifier on the "mps" device,
    # requesting the two listed majors via include_majors.
    train_major_classifier(
        device="mps",
        include_majors=["Computer Science", "Computer Engineering"],
    )