Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from sklearn.base import BaseEstimator, ClassifierMixin | |
import numpy as np | |
from typing import List, Tuple | |
import pandas as pd | |
from sklearn.model_selection import train_test_split | |
from sklearn.utils.class_weight import compute_class_weight | |
import json | |
import os | |
class MajorMlpClassifier(BaseEstimator, ClassifierMixin): | |
def __init__(self, device="cpu", seed=42, epochs=200, patience:int=None): | |
super().__init__() | |
self.device = device | |
self.seed = seed | |
self.model = None | |
self.epochs = epochs | |
self.patience = patience if patience is not None else epochs | |
self.class_labels = None | |
def _preprocess_features(self, X: np.ndarray) -> np.ndarray: | |
return torch.from_numpy(X).to(self.device) | |
def _preprocess_labels(self, y: List[str]) -> np.ndarray: | |
unique_labels = np.array(self._get_classes(y)) | |
one_hot = np.array([ | |
unique_labels == label | |
for label in y | |
], dtype="float32") | |
return torch.from_numpy(one_hot).to(self.device) | |
def _get_classes(self, y: List[str]) -> List[str]: | |
return sorted(set(y)) | |
def fit(self, X:np.ndarray, y:List[str]): | |
""" | |
Args: | |
X: embeddings of shape (n_sentences, embedding_size) | |
y: program labels that match with each sentence | |
""" | |
self.class_labels = np.array(self._get_classes(y)) | |
class_weights = compute_class_weight("balanced", classes=self.class_labels, y=y).astype("float32") | |
class_weights = torch.from_numpy(class_weights).to(self.device) | |
X, y = self._preprocess_features(X), self._preprocess_labels(y) | |
x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=self.seed, shuffle=True) | |
should_stop = EarlyStopping(self.patience) | |
val_loss = np.inf | |
model = ProgramClassifierNetwork(x_train.shape[1], y_train.shape[1]) | |
model = model.to(self.device) | |
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) | |
criterion = nn.CrossEntropyLoss(weight=class_weights) | |
epoch = 0 | |
while not should_stop.step(val_loss) and epoch < self.epochs: | |
preds = model(x_train) | |
loss = criterion(preds, y_train) | |
optimizer.zero_grad() | |
loss.backward() | |
optimizer.step() | |
with torch.no_grad(): | |
val_preds = model(x_val) | |
val_loss = criterion(val_preds, y_val).item() | |
epoch += 1 | |
model.eval() | |
self.model = model | |
def predict_proba(self, X:np.ndarray) -> np.ndarray: | |
X = self._preprocess_features(X) | |
if self.model is None: | |
raise Exception("Train model with fit() before predicting.") | |
with torch.no_grad(): | |
logits = self.model(X) | |
return F.softmax(logits, dim=-1).cpu().numpy() | |
def predict(self, X:np.ndarray) -> List[str]: | |
""" | |
Args: | |
X: embeddings of shape (n_sentences, embedding_size) | |
Returns: | |
predicted classes for each embedding | |
""" | |
pred_i = self.predict_proba(X).argmax(-1) | |
return self.class_labels[pred_i] | |
def save_weights(self,path:str): | |
os.makedirs(path, exist_ok=True) | |
weights_path = os.path.join(path, "weights.pt") | |
config_path = os.path.join(path,"config.json") | |
torch.save(self.model.state_dict(), weights_path) | |
state = { | |
"device": self.device, | |
"seed": self.seed, | |
"epochs": self.epochs, | |
"patience": self.patience, | |
"class_labels": list(self.class_labels) | |
} | |
with open(config_path, "w") as f: | |
json.dump(state, f) | |
def load_weights(self, path:str): | |
weights_path = os.path.join(path, "weights.pt") | |
config_path = os.path.join(path,"config.json") | |
state_dict = torch.load(weights_path) | |
input_size = int(state_dict["input_size"].item()) | |
n_classes = int(state_dict["n_classes"].item()) | |
model = ProgramClassifierNetwork(input_size,n_classes).to(self.device) | |
model.load_state_dict(state_dict) | |
model.eval() | |
self.model = model | |
with open(config_path, "r") as f: | |
config = json.load(f) | |
config["class_labels"] = np.array(config["class_labels"]) if config["class_labels"] is not None else None | |
self.__dict__.update(config) | |
class ProgramClassifierNetwork(nn.Module): | |
def __init__(self, input_size:int, n_classes:int) -> None: | |
super().__init__() | |
self.input_size = nn.Parameter(torch.Tensor([input_size]), requires_grad=False) | |
self.n_classes = nn.Parameter(torch.Tensor([n_classes]), requires_grad=False) | |
self.classifier = nn.Sequential( | |
nn.BatchNorm1d(input_size), | |
nn.Linear(input_size, 512), | |
nn.ReLU(), | |
nn.Linear(512, 256), | |
nn.ReLU(), | |
nn.Linear(256, 128), | |
nn.ReLU(), | |
nn.Linear(128, n_classes), | |
) | |
def forward(self,x): | |
return self.classifier(x) | |
class EarlyStopping: | |
def __init__(self, patience=0): | |
self.patience = patience | |
self.last_measure = np.inf | |
self.consecutive_increase = 0 | |
def step(self, val) -> bool: | |
if self.last_measure <= val: | |
self.consecutive_increase +=1 | |
else: | |
self.consecutive_increase = 0 | |
self.last_measure = val | |
return self.patience < self.consecutive_increase | |