|
from torch import nn, tensor, float32 |
|
import os |
|
import torch |
|
from torch.utils.data.dataloader import DataLoader |
|
from sklearn.metrics import r2_score |
|
import numpy as np |
|
from models.viscosity_models import CNN3D |
|
from pytorchtools import EarlyStopping |
|
from typing import List, Optional, Callable, Tuple |
|
from utils.datastruct import history, metrics |
|
from tqdm import tqdm |
|
|
|
def train(model : CNN3D, |
|
data_loader : DataLoader, |
|
optimizer : torch.optim.Optimizer, |
|
criterion : torch.nn.modules.loss._Loss, |
|
device : torch.device) -> float: |
|
|
|
train_loss = [] |
|
model.train() |
|
|
|
for (X,y) in data_loader: |
|
X = X.to(device) |
|
y = y.to(device) |
|
|
|
y = y.to(float32) |
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
out = model(X) |
|
|
|
loss = criterion(out,y) |
|
loss.backward() |
|
optimizer.step() |
|
train_loss.append(loss.item()) |
|
|
|
return np.mean(train_loss) |
|
|
|
|
|
def test(model : CNN3D, data_loader : DataLoader, criterion : torch.nn.modules.loss._Loss, device : torch.device) -> Tuple[float, float]: |
|
model.eval() |
|
|
|
y_h_all = [] |
|
y_all =[] |
|
test_loss = [] |
|
|
|
with torch.no_grad(): |
|
for (X,y) in data_loader: |
|
X = X.to(device) |
|
|
|
y = y.to(float32) |
|
|
|
y_h = model(X) |
|
|
|
loss = criterion(y_h.detach().cpu(),y) |
|
test_loss.append(loss) |
|
|
|
y_h_all.extend(y_h.detach().cpu().numpy()) |
|
y_all.extend(y.numpy()) |
|
|
|
return (np.mean(test_loss), r2_score(np.array(y_all),np.array(y_h_all))) |
|
|
|
|
|
|
|
def train_epochs(model : CNN3D , |
|
dataloaders : List[DataLoader], |
|
optimizer : torch.optim.Optimizer, |
|
criterion : torch.nn.modules.loss._Loss, |
|
epochs : int, |
|
early_stop : Optional[int], |
|
device : torch.device, |
|
path : str, |
|
save_weights_frequency : int) -> Tuple[CNN3D, history]: |
|
|
|
|
|
''' |
|
if len(data_loader)>2 : |
|
(train_loader, val_loader, test_loader) = dataloaders |
|
else : |
|
(train_loader, val_loader) = dataloaders |
|
''' |
|
(train_loader, val_loader, test_loader) = dataloaders |
|
|
|
if early_stop : early_stopping = EarlyStopping(patience=early_stop, verbose=True) |
|
|
|
train_loss_list=[] |
|
val_loss_list=[] |
|
test_loss_list=[] |
|
|
|
train_r2_list=[] |
|
val_r2_list=[] |
|
test_r2_list=[] |
|
|
|
for epoch in tqdm(range(epochs)): |
|
|
|
loss = train(model, |
|
data_loader=train_loader, |
|
optimizer=optimizer, |
|
criterion=criterion, |
|
device = device) |
|
|
|
|
|
(_,r2_train) = test(model = model,data_loader = train_loader, criterion = criterion, device = device) |
|
(val_loss, r2_val) = test(model = model, data_loader=val_loader, criterion = criterion, device = device) |
|
(test_loss, r2_test) = test(model = model, data_loader=test_loader, criterion = criterion, device = device) |
|
(test_loss, r2_test) = test(model = model, data_loader=test_loader, criterion = criterion, device = device) |
|
|
|
|
|
if early_stop : |
|
early_stopping(val_loss, model) |
|
if early_stopping.early_stop: |
|
print("Early stopping") |
|
break |
|
|
|
train_loss_list.append(loss) |
|
val_loss_list.append(val_loss) |
|
test_loss_list.append(test_loss) |
|
|
|
|
|
train_r2_list.append(r2_train) |
|
val_r2_list.append(r2_val) |
|
test_r2_list.append(r2_test) |
|
|
|
|
|
|
|
if (epoch+1) % save_weights_frequency == 0: |
|
torch.save(model.state_dict(), os.path.join(path,'cnn3d_epoch_'+str(epoch+1)+'.pt')) |
|
|
|
print(f'Epoch: {epoch:03d}, train loss: {loss : .4f}, val loss: {val_loss:.4f}, test loss : {test_loss:.4f}') |
|
|
|
print(f'Epoch: {epoch:03d}, train r2: {r2_train : .4f}, val r2: {r2_val:.4f}, test r2: {r2_test:.4f}') |
|
|
|
|
|
return model, history(metrics(r2_train, train_loss_list), metrics(r2_val, val_loss_list), metrics(r2_test, test_loss_list)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|