Spaces:
Runtime error
Runtime error
File size: 4,786 Bytes
7fab858 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import ntpath
import time
from . import util
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch
import numpy as np
class Visualizer:
    """Collect training/test visuals and loss values and write them to disk.

    Depending on the options object this class:
      * sends image grids and scalar losses to a tensorboardX ``SummaryWriter``
        (training with ``opt.tensorboard_log`` and/or ``opt.tf_log``),
      * saves image grids as PNG files (test mode with ``opt.tensorboard_log``),
      * appends a plain-text loss log to ``loss_log.txt`` (training only).
    """

    def __init__(self, opt):
        """Set up log directories, the summary writer and the text loss log.

        Args:
            opt: options object. Attributes read here: ``isTrain``, ``tf_log``
                (only when ``isTrain`` is truthy), ``tensorboard_log``,
                ``display_winsize``, ``name``, ``checkpoints_dir`` and, in
                test mode with tensorboard logging, ``results_dir``.
        """
        self.opt = opt
        self.tf_log = opt.isTrain and opt.tf_log
        self.tensorboard_log = opt.tensorboard_log
        self.win_size = opt.display_winsize
        self.name = opt.name

        # Always define these so the other methods can test for ``None``
        # instead of failing with AttributeError when logging is disabled.
        self.writer = None
        self.log_name = None

        run_dir = os.path.join(opt.checkpoints_dir, opt.name)

        if opt.isTrain and (self.tensorboard_log or self.tf_log):
            # ``tf_log`` used to rely on a ``self.tf`` module that is never
            # assigned; both flags now share one tensorboardX writer.
            self.log_dir = os.path.join(run_dir, "logs")
            self._ensure_dir(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
        elif self.tensorboard_log:
            # Test mode: image grids are saved as PNG files in the results dir.
            print("hi :)")
            self.log_dir = os.path.join(run_dir, opt.results_dir)
            self._ensure_dir(self.log_dir)

        if opt.isTrain:
            # The run directory may not exist yet when no writer was created.
            self._ensure_dir(run_dir)
            self.log_name = os.path.join(run_dir, "loss_log.txt")
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write("================ Training Loss (%s) ================\n" % now)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (including missing parents) unless it already exists."""
        if not os.path.exists(path):
            os.makedirs(path)

    def display_current_results(self, visuals, epoch, step):
        """Log (training) or save (test) a grid built from ``visuals``.

        Does nothing unless ``opt.tensorboard_log`` is set or when ``visuals``
        is empty.

        Args:
            visuals: dictionary of image tensors to display or save. The
                tensors are concatenated along dimension 0.
            epoch: unused; kept for interface compatibility.
            step: global step for the writer, or the PNG file stem in test mode.
        """
        if not self.tensorboard_log:
            return

        # (x + 1) / 2 -- presumably maps [-1, 1] images to [0, 1]; confirm
        # against the generator's output range.
        all_tensor = [(tensor.data.cpu() + 1) / 2 for tensor in visuals.values()]
        if not all_tensor:
            # torch.cat raises on an empty list; nothing to show anyway.
            return

        output = torch.cat(all_tensor, 0)

        if self.opt.isTrain:
            img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
            self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
        else:
            vutils.save_image(
                output,
                os.path.join(self.log_dir, str(step) + ".png"),
                nrow=self.opt.batchSize,
                padding=0,
                normalize=False,
            )

    def plot_current_errors(self, errors, step):
        """Write scalar losses to the summary writer.

        Args:
            errors: dictionary of error labels and values; each value must
                support ``.mean().float()``.
            step: global step the scalars are recorded at.

        Nothing is written when no summary writer was created (for example in
        test mode). Loss terms missing from ``errors`` are skipped.
        """
        if self.writer is None:
            return

        def scalar(key):
            return errors[key].mean().float()

        if self.tf_log:
            # One scalar per entry, tagged with the raw error label.
            for tag in errors:
                self.writer.add_scalar(tag, scalar(tag), step)

        if self.tensorboard_log:
            if "GAN_Feat" in errors:
                self.writer.add_scalar("Loss/GAN_Feat", scalar("GAN_Feat"), step)
            if "VGG" in errors:
                self.writer.add_scalar("Loss/VGG", scalar("VGG"), step)

            gan_curves = {}
            if "GAN" in errors:
                gan_curves["G"] = scalar("GAN")
            if "D_Fake" in errors and "D_real" in errors:
                # Discriminator curve is the average of its fake and real terms.
                gan_curves["D"] = (scalar("D_Fake") + scalar("D_real")) / 2
            if gan_curves:
                self.writer.add_scalars("Loss/GAN", gan_curves, step)

    def print_current_errors(self, epoch, i, errors, t):
        """Print one line of losses and append it to the text loss log.

        Args:
            epoch: current epoch number.
            i: iteration counter.
            errors: same format as ``errors`` of ``plot_current_errors``.
            t: elapsed time, printed with three decimals.

        The line is only written to disk when a loss log was opened in
        ``__init__`` (training mode); it is always printed.
        """
        pieces = ["(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)]
        pieces.extend("%s: %.3f " % (k, v.mean().float()) for k, v in errors.items())
        message = "".join(pieces)

        print(message)
        if self.log_name is not None:
            with open(self.log_name, "a") as log_file:
                log_file.write("%s\n" % message)

    def convert_visuals_to_numpy(self, visuals):
        """Convert every entry of ``visuals`` in place and return the dict.

        The ``"input_label"`` entry goes through ``util.tensor2label`` (noted
        in the original code as producing a B*H*W*C 0-255 numpy array); every
        other entry goes through ``util.tensor2im``.
        """
        # Tiling is requested only for batches larger than 8.
        tile = self.opt.batchSize > 8
        # Snapshot the keys: values are replaced while walking the dict.
        for key in list(visuals):
            t = visuals[key]
            if key == "input_label":
                visuals[key] = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
            else:
                visuals[key] = util.tensor2im(t, tile=tile)
        return visuals

    def save_images(self, webpage, visuals, image_path):
        """Save each visual as ``<label>/<name>.png`` and list it on ``webpage``.

        Args:
            webpage: object providing ``get_image_dir()``, ``add_header()``
                and ``add_images()``.
            visuals: dictionary of image tensors; converted in place by
                ``convert_visuals_to_numpy``.
            image_path: sequence whose first element is the source image path;
                its basename without extension names the output files.
        """
        visuals = self.convert_visuals_to_numpy(visuals)

        image_dir = webpage.get_image_dir()
        # ntpath.basename is what the original uses to strip the directory part.
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, "%s.png" % (name))
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path, create_dir=True)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)
|