food-app / utils.py
Jash-RA's picture
fully app first commit
97daae4
import torch
from pathlib import Path
import matplotlib.pyplot as plt
# SAVE MODEL
def save_model(model: torch.nn.Module, target_dir: str, model_name: str):
"""
Save pytorch model to a traget dir
Args:
model: A traget pytorch model
target_dir: Directory to save the model to
model_name: File name to save model. should include ".pth" or ".pt" at the end of the file extention
Example usage:
save_model(model = model_0, target_dir = "models", model_name="model.pth")
"""
target_dir_path = Path(target_dir)
target_dir_path.mkdir(parents = True, exist_ok = True)
assert model_name.endswith(".pth") or model_name.endswith(".pt"), "model name should be end with .pth or .pt"
model_save_path = target_dir_path / model_name
print(f"\nSaving Model At: {model_save_path}")
torch.save(obj = model.state_dict(), f = model_save_path)
# LOAD MODEL
def load_model(model: torch.nn.Module, model_path: str):
"""
Load pytorch model from source dir
Args:
model: A model which need to load
source_dir: path where trained model is saved. should be full path including model name
Example usage:
load_model(model = model_0, source_path = "models/model.pth")
"""
model.load_state_dict(torch.load(f = model_path, map_location=torch.device('cpu')))
print("\nModel Loaded.")
return model
# PLOT Function
def plot_graph(train_losses: list, test_losses: list, train_accs: list, test_accs: list, fig_name: str):
"""
Plot the grapoh of loss abd accuray of the model
Args:
train_losses: list of train loss
test_losses: list of test loss
train_accs: list of train accuracy
test_accs: list of test accuracy
fig_name: name of image file which with you want to save plot image and must include .jpg
Example usage:
plot_graph(train_losses = train_loss, test_losses = test_loss, train_accs = train_acc,
test_accs = test_acc, fig_name = "plot.jpg")
"""
plt.figure(figsize = (20, 8))
plt.subplot(1, 2, 1)
plt.plot(range(len(train_losses)), train_losses, label = "Train Loss")
plt.plot(range(len(test_losses)), test_losses, label = "Test Loss")
plt.legend()
plt.xlabel("Epoches")
plt.ylabel("Loss")
# plt.show()
plt.subplot(1, 2, 2)
plt.plot(range(len(train_accs)), train_accs, label = "Train Accuracy")
plt.plot(range(len(test_accs)), test_accs, label = "Test Accuracy")
plt.legend()
plt.xlabel("Epoches")
plt.ylabel("Accuracy")
# plt.show()
plt.savefig(fig_name)
# Optimize Model with ONNX
def onnx_inference(model: torch.nn.Module, path: str, device: str):
torch.onnx.export(model, torch.randn(1, 3, 224, 224).to(device), path, verbose=False, input_names=['input'], output_names=['output'], export_params=True)