import os |
import sys |
import json |
import pickle |
import random |
import torch |
from tqdm import tqdm |
import matplotlib.pyplot as plt |
def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = False):
    """Split an image-folder dataset into training and validation lists.

    Every immediate sub-directory of ``root`` is treated as one class. Class
    indices are assigned in sorted order of the directory names. The
    index -> class-name mapping is written as JSON to ``class_indices.json``
    in the current working directory.

    Within each class, ``int(len(images) * val_rate)`` images are drawn at
    random for validation and the remainder go to training. Only files whose
    extension is exactly one of ``.jpg``, ``.JPG``, ``.png`` or ``.PNG`` are
    considered.

    Args:
        root: dataset root directory; must exist.
        val_rate: fraction of each class's images reserved for validation.
        plot_image: if True, display a bar chart of per-class image counts.

    Returns:
        A tuple ``(train_images_path, train_images_label, val_images_path,
        val_images_label)`` of parallel path/label lists.

    Raises:
        AssertionError: if ``root`` does not exist.
    """
    # Fixed seed so the split is repeatable; note this reseeds the *global* RNG.
    random.seed(0)
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One class per sub-directory; sorting fixes the name -> index mapping.
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    flower_class.sort()
    class_indices = {name: idx for idx, name in enumerate(flower_class)}
    # Persist index -> name (json.dumps turns the int keys into strings).
    json_str = json.dumps({idx: name for name, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []
    supported = [".jpg", ".JPG", ".png", ".PNG"]

    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # NOTE(review): os.listdir order is filesystem-dependent, so the seeded
        # split below may differ across machines; consider sorting `images`
        # (this would change existing splits, so it is not done here).
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # A set gives O(1) membership tests; a list made this loop quadratic.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))
        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    if plot_image:
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        plt.xticks(range(len(flower_class)), flower_class)
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader):
    """Show up to four de-normalized images from every batch of ``data_loader``.

    Class names are looked up in ``./class_indices.json``, an index -> name
    JSON mapping such as the one ``read_split_data`` writes. One figure is
    shown per batch, and every batch the loader yields is plotted.

    Args:
        data_loader: iterable of ``(images, labels)`` batches that also has
            a ``batch_size`` attribute. The attribute may be ``None``, as it
            is for a DataLoader built from a ``batch_sampler``. Images are
            presumably CHW tensors normalized with ImageNet mean/std
            (inferred from the constants below; confirm against the
            transform).

    Raises:
        AssertionError: if ``./class_indices.json`` does not exist.
    """
    batch_size = data_loader.batch_size
    # batch_size is None when the loader uses a batch_sampler; fall back to 4.
    plot_num = min(batch_size, 4) if batch_size is not None else 4
    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    # Context manager so the file handle is closed (the original leaked it).
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for images, labels in data_loader:
        # The last batch can be smaller than batch_size; never index past it.
        n = min(plot_num, len(images))
        for i in range(n):
            img = images[i].numpy().transpose(1, 2, 0)  # CHW -> HWC for imshow
            # Undo the mean/std normalization and rescale to [0, 255].
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, n, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])
            plt.yticks([])
            # Clip before the uint8 cast so out-of-range values don't wrap around.
            plt.imshow(img.clip(0, 255).astype('uint8'))
        plt.show()
def write_pickle(list_info: list, file_name: str):
    """Pickle ``list_info`` into ``file_name``, replacing any existing file."""
    with open(file_name, 'wb') as out_file:
        pickler = pickle.Pickler(out_file)
        pickler.dump(list_info)
def read_pickle(file_name: str) -> list:
    """Load and return the object pickled in ``file_name``.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file_name, 'rb') as in_file:
        unpickler = pickle.Unpickler(in_file)
        return unpickler.load()
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one pass over ``data_loader``.

    Uses cross-entropy loss and steps ``optimizer`` once per batch. Progress
    is shown through a tqdm bar.

    Args:
        model: classifier whose output feeds ``CrossEntropyLoss``.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        epoch: epoch number, used only in the progress-bar text.

    Returns:
        ``(mean per-batch loss, accuracy over all samples)``, or
        ``(0.0, 0.0)`` if ``data_loader`` yields no batches.

    Exits the process with status 1 if a non-finite loss is encountered.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # running sum of per-batch losses
    accu_num = torch.zeros(1).to(device)   # running count of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    num_steps = 0
    data_loader = tqdm(data_loader)
    for step, (images, labels) in enumerate(data_loader):
        num_steps = step + 1
        labels = labels.to(device)  # transfer once, reused for accuracy and loss
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / num_steps,
            accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    # An empty loader previously raised UnboundLocalError on `step`.
    if num_steps == 0:
        return 0.0, 0.0
    return accu_loss.item() / num_steps, accu_num.item() / sample_num
@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Puts ``model`` in eval mode and leaves it there. Progress is shown
    through a tqdm bar.

    Args:
        model: classifier whose output feeds ``CrossEntropyLoss``.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        epoch: epoch number, used only in the progress-bar text.

    Returns:
        ``(mean per-batch loss, accuracy over all samples)``, or
        ``(0.0, 0.0)`` if ``data_loader`` yields no batches (for example,
        an empty validation split).
    """
    loss_function = torch.nn.CrossEntropyLoss()
    model.eval()
    accu_num = torch.zeros(1).to(device)   # running count of correct predictions
    accu_loss = torch.zeros(1).to(device)  # running sum of per-batch losses

    sample_num = 0
    num_steps = 0
    data_loader = tqdm(data_loader)
    for step, (images, labels) in enumerate(data_loader):
        num_steps = step + 1
        labels = labels.to(device)  # transfer once, reused for accuracy and loss
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / num_steps,
            accu_num.item() / sample_num)

    # An empty loader previously raised UnboundLocalError on `step`.
    if num_steps == 0:
        return 0.0, 0.0
    return accu_loss.item() / num_steps, accu_num.item() / sample_num