import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
import numpy as np
from typing import List, Optional
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import json
import os

class MajorMlpClassifier(BaseEstimator, ClassifierMixin):
    """Scikit-learn compatible MLP classifier over sentence embeddings."""

    def __init__(self, device="cpu", seed=42, epochs=200, patience: Optional[int] = None):
        super().__init__()
        self.device = device
        self.seed = seed
        self.model = None
        self.epochs = epochs
        # Without an explicit patience, early stopping never triggers before the epoch cap.
        self.patience = patience if patience is not None else epochs
        self.class_labels = None

    def _preprocess_features(self, X: np.ndarray) -> torch.Tensor:
        # Cast to float32 so the features match the model's parameter dtype.
        return torch.from_numpy(X.astype("float32")).to(self.device)

    def _preprocess_labels(self, y: List[str]) -> torch.Tensor:
        # One-hot encode against the sorted class vocabulary, e.g.
        # y = ["cs", "math", "cs"] -> [[1, 0], [0, 1], [1, 0]].
        unique_labels = np.array(self._get_classes(y))
        one_hot = np.array([
            unique_labels == label
            for label in y
        ], dtype="float32")
        return torch.from_numpy(one_hot).to(self.device)

    def _get_classes(self, y: List[str]) -> List[str]:
        return sorted(set(y))
    
    def fit(self, X: np.ndarray, y: List[str]):
        """
        Args:
            X: embeddings of shape (n_sentences, embedding_size)
            y: program labels that match with each sentence
        """
        torch.manual_seed(self.seed)
        self.class_labels = np.array(self._get_classes(y))
        # Weight the loss inversely to class frequency so rare programs still contribute.
        class_weights = compute_class_weight("balanced", classes=self.class_labels, y=y).astype("float32")
        class_weights = torch.from_numpy(class_weights).to(self.device)
        X, y = self._preprocess_features(X), self._preprocess_labels(y)
        x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=self.seed, shuffle=True)
        should_stop = EarlyStopping(self.patience)
        val_loss = np.inf
        model = ProgramClassifierNetwork(x_train.shape[1], y_train.shape[1])
        model = model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        # CrossEntropyLoss accepts one-hot (probability) targets since PyTorch 1.10.
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        epoch = 0
        while not should_stop.step(val_loss) and epoch < self.epochs:
            # Full-batch gradient step; the dataset is assumed to fit in memory.
            model.train()
            preds = model(x_train)
            loss = criterion(preds, y_train)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Switch to eval mode so BatchNorm uses its running statistics and
            # does not update them on the validation batch.
            model.eval()
            with torch.no_grad():
                val_preds = model(x_val)
                val_loss = criterion(val_preds, y_val).item()
            epoch += 1
        model.eval()
        self.model = model

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        if self.model is None:
            raise NotFittedError("Train model with fit() before predicting.")
        X = self._preprocess_features(X)
        with torch.no_grad():
            logits = self.model(X)
            return F.softmax(logits, dim=-1).cpu().numpy()
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Args:
            X: embeddings of shape (n_sentences, embedding_size)
        Returns:
            predicted classes for each embedding
        """
        pred_i = self.predict_proba(X).argmax(-1)
        return self.class_labels[pred_i]

    def save_weights(self, path: str):
        if self.model is None:
            raise NotFittedError("Train model with fit() before saving.")
        os.makedirs(path, exist_ok=True)
        weights_path = os.path.join(path, "weights.pt")
        config_path = os.path.join(path, "config.json")
        torch.save(self.model.state_dict(), weights_path)
        state = {
            "device": self.device,
            "seed": self.seed,
            "epochs": self.epochs,
            "patience": self.patience,
            "class_labels": list(self.class_labels),
        }
        with open(config_path, "w") as f:
            json.dump(state, f)


    def load_weights(self, path: str):
        weights_path = os.path.join(path, "weights.pt")
        config_path = os.path.join(path, "config.json")
        # map_location lets weights saved on one device be loaded onto another.
        state_dict = torch.load(weights_path, map_location=self.device)
        # The network stores its own dimensions in the state dict, so the
        # architecture can be rebuilt without extra configuration.
        input_size = int(state_dict["input_size"].item())
        n_classes = int(state_dict["n_classes"].item())
        model = ProgramClassifierNetwork(input_size, n_classes).to(self.device)
        model.load_state_dict(state_dict)
        model.eval()
        self.model = model
        with open(config_path, "r") as f:
            config = json.load(f)
        # Keep the device this instance was constructed with rather than the
        # one recorded at save time; the weights were already mapped above.
        config.pop("device", None)
        config["class_labels"] = np.array(config["class_labels"]) if config["class_labels"] is not None else None
        self.__dict__.update(config)

class ProgramClassifierNetwork(nn.Module):
    def __init__(self, input_size: int, n_classes: int) -> None:
        super().__init__()
        # Store the dimensions as buffers so they travel with the state dict
        # (keys "input_size" / "n_classes") without appearing as trainable
        # parameters; load_weights() reads them back to rebuild the network.
        self.register_buffer("input_size", torch.tensor([float(input_size)]))
        self.register_buffer("n_classes", torch.tensor([float(n_classes)]))
        self.classifier = nn.Sequential(
            # BatchNorm on the raw embeddings acts as learned input normalization.
            nn.BatchNorm1d(input_size),
            nn.Linear(input_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, n_classes),
        )

    def forward(self, x):
        return self.classifier(x)

class EarlyStopping:
    """Signals a stop after more than `patience` consecutive epochs without improvement."""

    def __init__(self, patience=0):
        self.patience = patience
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        # Any epoch that fails to improve on the previous loss counts as an increase.
        if self.last_measure <= val:
            self.consecutive_increase += 1
        else:
            self.consecutive_increase = 0
        self.last_measure = val

        return self.patience < self.consecutive_increase
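

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal end-to-end sketch on synthetic data, assuming 64-dimensional
# sentence embeddings and three made-up program labels; the save directory
# "tmp_classifier" is hypothetical. Guarded so importing this module never
# triggers a training run.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    labels = ["biology", "cs", "math"]
    X_demo = rng.normal(size=(300, 64)).astype("float32")
    y_demo = [labels[i % len(labels)] for i in range(300)]

    clf = MajorMlpClassifier(device="cpu", epochs=50, patience=5)
    clf.fit(X_demo, y_demo)
    print("predictions:", clf.predict(X_demo[:5]))
    print("probabilities shape:", clf.predict_proba(X_demo[:5]).shape)

    # Round-trip the weights through disk and check predictions still work.
    clf.save_weights("tmp_classifier")
    restored = MajorMlpClassifier(device="cpu")
    restored.load_weights("tmp_classifier")
    print("restored predictions:", restored.predict(X_demo[:5]))

    # EarlyStopping semantics: with patience=1, two consecutive
    # non-improving losses trigger the stop signal.
    stopper = EarlyStopping(patience=1)
    print([stopper.step(v) for v in [1.0, 0.9, 0.95, 0.96]])  # [False, False, False, True]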