|
import torch |
|
from pathlib import Path |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
class _InvalidModelNameError(ValueError, AssertionError):
    """Raised by ``save_model`` when ``model_name`` lacks a ".pth"/".pt" suffix.

    It also subclasses ``AssertionError`` so that callers written against the
    previous ``assert``-based check keep catching it.
    """


def save_model(model: torch.nn.Module, target_dir: str, model_name: str):
    """
    Save a PyTorch model's ``state_dict`` to a target directory.

    Args:
        model: The PyTorch model whose ``state_dict()`` is saved.
        target_dir: Directory to save the model to. It is created (including
            missing parents) if it does not exist.
        model_name: File name for the saved weights. Must end with ".pth"
            or ".pt".

    Raises:
        ValueError: If ``model_name`` does not end with ".pth" or ".pt".
            The exception is also an ``AssertionError`` for backward
            compatibility. No directory is created in that case.

    Example usage:
        save_model(model=model_0, target_dir="models", model_name="model.pth")
    """
    # Validate before touching the filesystem so that a bad name doesn't leave
    # an empty directory behind. An explicit raise, unlike `assert`, is not
    # stripped when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise _InvalidModelNameError("model name should be end with .pth or .pt")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name

    print(f"\nSaving Model At: {model_save_path}")
    # Only the parameters/buffers are saved, not the pickled module object.
    torch.save(obj=model.state_dict(), f=model_save_path)
|
|
|
|
|
|
|
def load_model(model: torch.nn.Module, model_path: str, map_location="cpu"):
    """
    Load saved weights (a ``state_dict``) into an existing PyTorch model.

    Args:
        model: Model instance whose architecture matches the saved weights.
            It is updated in place via ``load_state_dict``.
        model_path: Full path to the saved weights file, including the file
            name (e.g. "models/model.pth").
        map_location: Where ``torch.load`` places the stored tensors while
            loading. Defaults to "cpu", so weights saved from a GPU can be
            loaded on a CPU-only machine. Pass e.g. "cuda" or a
            ``torch.device`` to load elsewhere.

    Returns:
        The same ``model`` object that was passed in, now holding the loaded
        weights.

    Example usage:
        load_model(model=model_0, model_path="models/model.pth")
    """
    state_dict = torch.load(f=model_path, map_location=map_location)
    model.load_state_dict(state_dict)

    print("\nModel Loaded.")

    return model
|
|
|
|
|
|
|
|
|
def plot_graph(train_losses: list, test_losses: list, train_accs: list, test_accs: list, fig_name: str,
               *, close: bool = False):
    """
    Plot loss and accuracy curves side by side and save them to an image file.

    The left subplot shows train/test loss and the right subplot shows
    train/test accuracy. The x axis is the list index (0, 1, 2, ...),
    labelled "Epochs".

    Args:
        train_losses: Training loss values, one per epoch.
        test_losses: Test loss values, one per epoch.
        train_accs: Training accuracy values, one per epoch.
        test_accs: Test accuracy values, one per epoch.
        fig_name: Output path of the saved image. matplotlib infers the file
            format from the extension (e.g. ".jpg", ".png").
        close: If True, close the figure after saving so that repeated calls
            (e.g. once per epoch) don't accumulate open figures. Defaults to
            False, which leaves the figure open as before.

    Example usage:
        plot_graph(train_losses=train_loss, test_losses=test_loss, train_accs=train_acc,
                   test_accs=test_acc, fig_name="plot.jpg")
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(20, 8))

    loss_ax.plot(range(len(train_losses)), train_losses, label="Train Loss")
    loss_ax.plot(range(len(test_losses)), test_losses, label="Test Loss")
    loss_ax.legend()
    loss_ax.set_xlabel("Epochs")
    loss_ax.set_ylabel("Loss")

    acc_ax.plot(range(len(train_accs)), train_accs, label="Train Accuracy")
    acc_ax.plot(range(len(test_accs)), test_accs, label="Test Accuracy")
    acc_ax.legend()
    acc_ax.set_xlabel("Epochs")
    acc_ax.set_ylabel("Accuracy")

    # Save this specific figure rather than whatever pyplot considers "current".
    fig.savefig(fig_name)

    if close:
        plt.close(fig)
|
|
|
|
|
|
|
def onnx_inference(model: torch.nn.Module, path: str, device: str, input_shape=(1, 3, 224, 224)):
    """
    Export a PyTorch model to an ONNX file.

    Despite its name, this function does not run inference. It passes a
    random dummy input through ``model`` via ``torch.onnx.export`` and writes
    the resulting ONNX graph, including the trained parameters
    (``export_params=True``), to ``path``.

    Args:
        model: Model to export. Assumed to already be on ``device``; the
            model itself is not moved here.
        path: Output file path for the ONNX model (e.g. "model.onnx").
        device: Device the dummy input is moved to. It should match the
            model's device.
        input_shape: Shape of the random dummy input used for the export.
            Defaults to (1, 3, 224, 224), presumably one 224x224 RGB image.
            No dynamic axes are declared, so the exported input shape is
            presumably fixed to this value.

    The graph's input and output are named "input" and "output".
    """
    dummy_input = torch.randn(*input_shape).to(device)
    torch.onnx.export(
        model,
        dummy_input,
        path,
        verbose=False,
        input_names=["input"],
        output_names=["output"],
        export_params=True,
    )