# Author: waidhoferj — first commit (aadb779)
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np
from typing import List, Tuple
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import json
import os
class MajorMlpClassifier(BaseEstimator, ClassifierMixin):
    """Sklearn-compatible MLP classifier over precomputed sentence embeddings.

    Wraps a small PyTorch network (``ProgramClassifierNetwork``) with an
    sklearn-style fit/predict interface: class-balanced cross-entropy
    training, early stopping on a 10% validation split, and persistence of
    both weights (``weights.pt``) and estimator config (``config.json``).
    """

    def __init__(self, device="cpu", seed=42, epochs=200, patience: int = None):
        super().__init__()
        self.device = device
        self.seed = seed
        self.model = None  # set by fit() or load_weights()
        self.epochs = epochs
        # Defaulting patience to `epochs` effectively disables early stopping.
        self.patience = patience if patience is not None else epochs
        self.class_labels = None  # sorted label vocabulary, set in fit()

    def _preprocess_features(self, X: np.ndarray) -> torch.Tensor:
        """Convert the embedding matrix to a tensor on the configured device."""
        return torch.from_numpy(X).to(self.device)

    def _preprocess_labels(self, y: List[str]) -> torch.Tensor:
        """One-hot encode labels against the sorted label vocabulary."""
        unique_labels = np.array(self._get_classes(y))
        one_hot = np.array([
            unique_labels == label
            for label in y
        ], dtype="float32")
        return torch.from_numpy(one_hot).to(self.device)

    def _get_classes(self, y: List[str]) -> List[str]:
        """Deduplicated labels in sorted order; defines the class index order."""
        return sorted(set(y))

    def fit(self, X: np.ndarray, y: List[str]):
        """Train the network with early stopping on a held-out split.

        Args:
            X: embeddings of shape (n_sentences, embedding_size)
            y: program labels that match with each sentence
        Returns:
            self (sklearn convention).
        """
        # Bug fix: seed torch so weight init / training is reproducible.
        # Previously `seed` only controlled the train/val split.
        torch.manual_seed(self.seed)
        self.class_labels = np.array(self._get_classes(y))
        # Balanced weights counteract class imbalance in the loss.
        class_weights = compute_class_weight("balanced", classes=self.class_labels, y=y).astype("float32")
        class_weights = torch.from_numpy(class_weights).to(self.device)
        X, y = self._preprocess_features(X), self._preprocess_labels(y)
        x_train, x_val, y_train, y_val = train_test_split(
            X, y, test_size=0.1, random_state=self.seed, shuffle=True
        )
        should_stop = EarlyStopping(self.patience)
        val_loss = np.inf
        model = ProgramClassifierNetwork(x_train.shape[1], y_train.shape[1]).to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        epoch = 0
        # Full-batch training: one optimizer step per epoch.
        while not should_stop.step(val_loss) and epoch < self.epochs:
            model.train()
            preds = model(x_train)
            loss = criterion(preds, y_train)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Bug fix: evaluate in eval mode so BatchNorm uses its running
            # statistics rather than the validation batch's own statistics.
            model.eval()
            with torch.no_grad():
                val_loss = criterion(model(x_val), y_val).item()
            epoch += 1
        model.eval()
        self.model = model
        return self

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Class probabilities of shape (n_sentences, n_classes)."""
        # Check before doing any preprocessing work.
        if self.model is None:
            raise Exception("Train model with fit() before predicting.")
        X = self._preprocess_features(X)
        with torch.no_grad():
            logits = self.model(X)
        return F.softmax(logits, dim=-1).cpu().numpy()

    def predict(self, X: np.ndarray) -> List[str]:
        """
        Args:
            X: embeddings of shape (n_sentences, embedding_size)
        Returns:
            predicted classes for each embedding
        """
        pred_i = self.predict_proba(X).argmax(-1)
        return self.class_labels[pred_i]

    def save_weights(self, path: str):
        """Save model weights and estimator config into directory `path`."""
        # Bug fix: fail with an explicit message instead of AttributeError
        # when saving before training.
        if self.model is None:
            raise Exception("Train model with fit() before saving.")
        os.makedirs(path, exist_ok=True)
        weights_path = os.path.join(path, "weights.pt")
        config_path = os.path.join(path, "config.json")
        torch.save(self.model.state_dict(), weights_path)
        state = {
            "device": self.device,
            "seed": self.seed,
            "epochs": self.epochs,
            "patience": self.patience,
            "class_labels": list(self.class_labels)
        }
        with open(config_path, "w") as f:
            json.dump(state, f)

    def load_weights(self, path: str):
        """Restore model weights and estimator config from directory `path`."""
        weights_path = os.path.join(path, "weights.pt")
        config_path = os.path.join(path, "config.json")
        # Bug fix: map_location lets weights saved on GPU load on CPU (and
        # vice versa) instead of raising a device deserialization error.
        state_dict = torch.load(weights_path, map_location=self.device)
        # Network dimensions travel inside the state dict as frozen params.
        input_size = int(state_dict["input_size"].item())
        n_classes = int(state_dict["n_classes"].item())
        model = ProgramClassifierNetwork(input_size, n_classes).to(self.device)
        model.load_state_dict(state_dict)
        model.eval()
        self.model = model
        with open(config_path, "r") as f:
            config = json.load(f)
        config["class_labels"] = np.array(config["class_labels"]) if config["class_labels"] is not None else None
        self.__dict__.update(config)
class ProgramClassifierNetwork(nn.Module):
    """Feed-forward classifier head: BatchNorm -> 512 -> 256 -> 128 -> logits.

    The input size and class count are stored as frozen one-element
    parameters so they are persisted in the state dict, letting the network
    be reconstructed from its weights file alone.
    """

    def __init__(self, input_size: int, n_classes: int) -> None:
        super().__init__()
        # Frozen parameters: saved/loaded with the state dict, never updated
        # by the optimizer (requires_grad=False).
        self.input_size = nn.Parameter(torch.Tensor([input_size]), requires_grad=False)
        self.n_classes = nn.Parameter(torch.Tensor([n_classes]), requires_grad=False)
        # Build the stack layer-by-layer; positional Sequential indices
        # (classifier.0 ... classifier.7) match the persisted state dict.
        layers = [nn.BatchNorm1d(input_size)]
        width = input_size
        for hidden_width in (512, 256, 128):
            layers.append(nn.Linear(width, hidden_width))
            layers.append(nn.ReLU())
            width = hidden_width
        layers.append(nn.Linear(width, n_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for a batch of embeddings."""
        return self.classifier(x)
class EarlyStopping:
    """Tracks a monitored value and signals a stop once it has failed to
    improve (strictly decrease) for more than `patience` consecutive steps.
    """

    def __init__(self, patience=0):
        self.patience = patience
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """Record the latest measurement; return True when training should stop."""
        # A step counts as "improved" only if the value strictly decreased
        # versus the immediately preceding measurement.
        improved = val < self.last_measure
        self.consecutive_increase = 0 if improved else self.consecutive_increase + 1
        self.last_measure = val
        return self.consecutive_increase > self.patience