File size: 4,786 Bytes
7fab858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import ntpath
import time
from . import util
import scipy.misc

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch
import numpy as np


class Visualizer:
    """Log losses and sample images during training; save result images at test time.

    In training mode, scalars and image grids are written with a tensorboardX
    ``SummaryWriter`` under ``<checkpoints_dir>/<name>/logs``, and printed loss
    lines are appended to ``<checkpoints_dir>/<name>/loss_log.txt``.
    In test mode (with ``tensorboard_log`` set), image batches are saved as PNG
    files under ``<checkpoints_dir>/<name>/<results_dir>``.
    """

    def __init__(self, opt):
        """Read logging options from *opt* and prepare output locations.

        Args:
            opt: option namespace. This constructor reads ``isTrain``, ``tf_log``,
                ``tensorboard_log``, ``display_winsize``, ``name``,
                ``checkpoints_dir`` and (in test mode) ``results_dir``; other
                methods additionally read ``batchSize`` and ``label_nc``.
        """
        self.opt = opt
        # Per-loss scalar logging is only enabled while training.
        self.tf_log = opt.isTrain and opt.tf_log
        self.tensorboard_log = opt.tensorboard_log
        self.win_size = opt.display_winsize
        self.name = opt.name
        # A single SummaryWriter serves both ``tf_log`` and ``tensorboard_log``.
        # It stays None outside training so the plotting methods can skip safely.
        self.writer = None

        if self.tensorboard_log or self.tf_log:
            if opt.isTrain:
                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs")
                os.makedirs(self.log_dir, exist_ok=True)
                self.writer = SummaryWriter(log_dir=self.log_dir)
            else:
                # Test mode: display_current_results saves PNG files here instead.
                self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir)
                os.makedirs(self.log_dir, exist_ok=True)

        if opt.isTrain:
            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt")
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write("================ Training Loss (%s) ================\n" % now)

    def display_current_results(self, visuals, epoch, step):
        """Write a batch of visuals to TensorBoard (training) or to a PNG file (test).

        Args:
            visuals: dict of name -> image tensor; all tensors are concatenated
                along dim 0, so they must share the remaining dimensions.
            epoch: unused; kept for interface compatibility.
            step: global step for TensorBoard, or the PNG file stem in test mode.

        Does nothing when ``tensorboard_log`` is off or *visuals* is empty.
        """
        if not self.tensorboard_log or not visuals:
            return

        # (x + 1) / 2 maps [-1, 1] to [0, 1]; assumes generator outputs lie in
        # [-1, 1] -- TODO confirm against the model's output activation.
        images = [(tensor.data.cpu() + 1) / 2 for tensor in visuals.values()]
        output = torch.cat(images, 0)

        if self.opt.isTrain:
            img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
            self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
        else:
            vutils.save_image(
                output,
                os.path.join(self.log_dir, str(step) + ".png"),
                nrow=self.opt.batchSize,
                padding=0,
                normalize=False,
            )

    def plot_current_errors(self, errors, step):
        """Log loss values as TensorBoard scalars.

        Args:
            errors: dict of loss name -> tensor; each value is reduced with
                ``.mean()`` before logging.
            step: global step passed to the writer.

        With ``tf_log``, every entry is logged under its own name. With
        ``tensorboard_log``, ``GAN_Feat``/``VGG`` and a combined G/D plot are
        logged; losses missing from *errors* are skipped. Nothing is logged
        when no writer exists (e.g. test mode).
        """
        writer = getattr(self, "writer", None)
        if writer is None:
            return

        if self.tf_log:
            for tag, value in errors.items():
                writer.add_scalar(tag, value.mean().float(), step)

        if self.tensorboard_log:

            def mean_of(key):
                return errors[key].mean().float()

            for key in ("GAN_Feat", "VGG"):
                # Optional losses (e.g. VGG) may be disabled by the options.
                if key in errors:
                    writer.add_scalar("Loss/" + key, mean_of(key), step)

            if all(key in errors for key in ("GAN", "D_Fake", "D_real")):
                writer.add_scalars(
                    "Loss/GAN",
                    {
                        "G": mean_of("GAN"),
                        # Discriminator curve is the average of its fake/real terms.
                        "D": (mean_of("D_Fake") + mean_of("D_real")) / 2,
                    },
                    step,
                )

    def print_current_errors(self, epoch, i, errors, t):
        """Print one loss line and append it to ``loss_log.txt``.

        Args:
            epoch: current epoch number.
            i: iteration counter within the epoch.
            errors: dict of loss name -> tensor (same format as ``plot_current_errors``).
            t: elapsed time value, printed with three decimals.

        Requires ``self.log_name``, which is only set in training mode.
        """
        message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
        for k, v in errors.items():
            message += "%s: %.3f " % (k, v.mean().float())

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write("%s\n" % message)

    def convert_visuals_to_numpy(self, visuals):
        """Convert every tensor in *visuals* to a numpy image via ``util``.

        The ``input_label`` entry goes through ``util.tensor2label``; all others
        go through ``util.tensor2im``. Batches larger than 8 are tiled.
        *visuals* is updated in place and also returned.
        """
        # Tiling depends only on the batch size, not on the individual entry.
        tile = self.opt.batchSize > 8
        for key, t in visuals.items():
            if key == "input_label":
                # Original author's note: result is a B*H*W*C 0-255 numpy array.
                t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
            else:
                t = util.tensor2im(t, tile=tile)
            visuals[key] = t
        return visuals

    def save_images(self, webpage, visuals, image_path):
        """Save each visual as a PNG and register it on *webpage*.

        Args:
            webpage: object exposing ``get_image_dir``, ``add_header`` and
                ``add_images``.
            visuals: dict of label -> image tensor; converted in place.
            image_path: sequence of source paths; the basename (without
                extension) of the first entry names the output files.

        Each image is written to ``<image_dir>/<label>/<name>.png``.
        """
        visuals = self.convert_visuals_to_numpy(visuals)

        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, "%s.png" % name)
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path, create_dir=True)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)