Spaces:
Runtime error
Runtime error
import datetime | |
import time | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import seaborn as sns | |
import torch | |
def getDevice():
    """Select the torch device to run on, preferring a CUDA GPU.

    Prints a short human-readable note about the choice: either how many
    GPUs are visible plus the name of GPU 0, or a CPU fallback message.

    Returns:
        torch.device: ``torch.device("cuda")`` when ``torch.cuda.is_available()``
        is true, otherwise ``torch.device("cpu")``.
    """
    # Guard clause: no CUDA runtime/device visible -> fall back to the CPU.
    if not torch.cuda.is_available():
        print('No GPU available, using the CPU instead.')
        return torch.device("cpu")

    gpu_total = torch.cuda.device_count()
    print('There are %d GPU(s) available.' % gpu_total)
    # Only the name of device index 0 is reported; the returned device is the
    # generic "cuda" device rather than an explicit index.
    print('We will use the GPU:', torch.cuda.get_device_name(0))
    return torch.device("cuda")
def flatAccuracy(preds, labels):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        preds: 2-D array-like of per-class scores with one row per sample;
            the predicted class is ``argmax`` over axis 1.
        labels: array-like of class ids. It is flattened, so both ``(N,)``
            and ``(N, 1)`` shapes are accepted.

    Returns:
        numpy.float64 accuracy in ``[0, 1]``, or ``nan`` when there are no
        samples (accuracy over zero samples is undefined).

    Raises:
        ValueError: if the number of labels differs from the number of
            prediction rows.
    """
    # np.asarray lets callers pass lists or other array-likes, not only ndarrays.
    pred_flat = np.argmax(np.asarray(preds), axis=1).ravel()
    labels_flat = np.asarray(labels).ravel()

    # Without this check a single label would broadcast against every
    # prediction (yielding accuracies above 1.0), and other mismatches could
    # silently compare as False on older numpy versions.
    if pred_flat.size != labels_flat.size:
        raise ValueError(
            "preds has %d rows but labels has %d entries"
            % (pred_flat.size, labels_flat.size)
        )

    # Avoid the 0/0 RuntimeWarning; report "undefined" explicitly instead.
    if labels_flat.size == 0:
        return np.float64("nan")

    # Mean of a boolean array == number of matches / number of samples.
    return np.mean(pred_flat == labels_flat)
def formatTime(elapsed):
    """Render a duration given in seconds as a ``str(datetime.timedelta)`` string.

    The value is rounded to whole seconds with Python's ``round`` (ties go to
    the even integer, e.g. 2.5 -> 2). The result looks like ``h:mm:ss``:
    hours are not zero-padded, and durations of 24h or more gain a
    ``"N day(s), "`` prefix, e.g. ``"1 day, 1:00:00"``.

    Args:
        elapsed: duration in seconds (int or float).

    Returns:
        The formatted duration string.
    """
    duration = datetime.timedelta(seconds=int(round(elapsed)))
    return f"{duration}"
def plotTrainingLoss(lossValues):
    """Plot a training-loss curve on a fresh figure and display it.

    Side effects: applies the seaborn theme ``style='darkgrid'``,
    ``font_scale=1.5`` and sets ``plt.rcParams["figure.figsize"]`` to
    ``(12, 6)``. Both are global, so later figures inherit them. Finally
    calls ``plt.show()``.

    Args:
        lossValues: sequence of loss values plotted in order. x positions
            are the indices 0..len-1; the axis is labelled "Epoch", so this
            presumably holds one value per epoch (confirm with caller).

    Returns:
        The matplotlib Axes the curve was drawn on. The original returned
        None, so existing callers are unaffected.
    """
    # A single call with both arguments replaces the previous two
    # back-to-back sns.set calls and yields the same theme.
    sns.set(style='darkgrid', font_scale=1.5)
    plt.rcParams["figure.figsize"] = (12, 6)

    # Always draw on a new figure. Plotting into the implicit current figure
    # would overlay curves from earlier calls or other plots, and would
    # ignore the figsize configured just above if that figure already existed.
    _fig, ax = plt.subplots()
    ax.plot(lossValues, 'b-o')
    ax.set_title("Training loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    plt.show()
    return ax