|
import os |
|
import sys |
|
import json |
|
import pickle |
|
import random |
|
|
|
import torch |
|
from tqdm import tqdm |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def read_split_data(root: str, val_rate: float = 0.2):
    """Scan a dataset directory and split image paths into train/val sets.

    Expects ``root`` to contain one sub-directory per class, each holding
    image files.  Also writes the label-index -> class-name mapping to
    ``class_indices.json`` in the current working directory.

    Args:
        root: dataset root directory (one sub-directory per class).
        val_rate: fraction of each class reserved for validation.

    Returns:
        (train_images_path, train_images_label,
         val_images_path, val_images_label)
    """
    random.seed(0)  # fixed seed so the split is reproducible across runs
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # Each sub-directory of root is treated as one class; sort for a
    # stable class -> index assignment.
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    flower_class.sort()

    # Map class name -> integer label and dump the inverse mapping so that
    # plotting/inference code can translate labels back to names.
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []
    supported = [".jpg", ".JPG", ".png", ".PNG"]

    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # BUGFIX: os.listdir order is arbitrary and filesystem-dependent, so
        # random.sample previously picked a different validation set on each
        # platform despite the fixed seed.  Sorting first makes the split
        # deterministic everywhere.
        images.sort()
        image_class = class_indices[cla]
        every_class_num.append(len(images))

        # Reserve val_rate of this class (rounded down) for validation.
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # Bar chart of per-class image counts (disabled by default).
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        plt.xticks(range(len(flower_class)), flower_class)
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label
|
|
|
|
|
def plot_data_loader_image(data_loader):
    """Display the first few images of each batch with their class names.

    Reads the label -> class-name mapping from ``class_indices.json``
    (produced by ``read_split_data``), undoes the ImageNet normalization
    and shows up to 4 images per batch.

    Args:
        data_loader: torch DataLoader yielding (images, labels) batches;
            images are assumed to be ImageNet-normalized CHW tensors
            (mean/std constants below) -- confirm against the transforms
            used by the caller.
    """
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    # BUGFIX: the file handle was opened without ever being closed; use a
    # context manager so it is released promptly.
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # CHW -> HWC layout for matplotlib.
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo ImageNet mean/std normalization and rescale to [0, 255].
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])
            plt.yticks([])
            plt.imshow(img.astype('uint8'))
        plt.show()
|
|
|
|
|
def write_pickle(list_info: list, file_name: str):
    """Serialize ``list_info`` to ``file_name`` with pickle (binary mode)."""
    with open(file_name, 'wb') as handle:
        pickle.dump(list_info, handle)
|
|
|
|
|
def read_pickle(file_name: str) -> list:
    """Load and return the pickled object stored in ``file_name``."""
    with open(file_name, 'rb') as handle:
        return pickle.load(handle)
|
|
|
|
|
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one epoch over ``data_loader``.

    Exits the process (``sys.exit(1)``) if a non-finite loss is produced.

    Args:
        model: module to train (switched to train mode here).
        optimizer: optimizer stepped once per batch.
        data_loader: iterable of (images, labels) batches.
        device: device to run the forward/backward pass on.
        epoch: current epoch index (used only for the progress-bar text).

    Returns:
        (mean_loss, accuracy): loss averaged over batches and the fraction
        of correctly classified samples for the epoch.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # running sum of per-batch losses
    accu_num = torch.zeros(1).to(device)   # running count of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        # Move the batch to the target device once per step (labels were
        # previously transferred twice: once for accuracy, once for loss).
        images = images.to(device)
        labels = labels.to(device)

        pred = model(images)
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        # Check BEFORE backward: a non-finite loss would only poison
        # gradients we are about to abandon anyway (the process exits).
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        loss.backward()
        accu_loss += loss.detach()  # detach: keep the stat out of the graph

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
|
|
|
|
@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    """Run one validation pass and return (mean_loss, accuracy).

    Gradients are disabled via ``torch.no_grad`` and the model is put in
    eval mode; ``epoch`` is used only for the progress-bar text.
    """
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # correct-prediction counter
    accu_loss = torch.zeros(1).to(device)  # summed per-batch losses

    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, (images, labels) in enumerate(data_loader):
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
|
|