hjc-owo
init repo
966ae59
raw
history blame
2.81 kB
# -*- coding: utf-8 -*-
# Author: ximing
# Copyright (c) 2023, XiMing Xing.
# License: MPL-2.0 License
from typing import AnyStr
import matplotlib.pyplot as plt
import torch
from torchvision.utils import make_grid
def plot_couple(input_1: torch.Tensor,
                input_2: torch.Tensor,
                step: int,
                output_dir: str,
                fname: str,  # file name
                prompt: str = '',  # text prompt as image title
                dpi: int = 300):
    """Save a side-by-side comparison of two image batches as a PNG.

    Each input is tiled with ``make_grid`` and drawn in its own subplot:
    ``input_1`` on the left under the title "Input", ``input_2`` on the right
    under "Rendering - {step} steps". The prompt, broken into lines of at most
    nine words, becomes the figure's super-title.

    Args:
        input_1: tensor passed to ``make_grid``; shown on the left.
        input_2: tensor with the same shape as ``input_1``; shown on the right.
        step: step count interpolated into the right-hand title.
        output_dir: directory the image is written to (not created here).
        fname: file name without extension; ".png" is appended.
        prompt: text used as the figure super-title.
        dpi: resolution forwarded to ``savefig``.

    Raises:
        ValueError: if the two inputs differ in shape.
    """
    if input_1.shape != input_2.shape:
        raise ValueError("inputs and outputs must have the same dimensions")

    def to_uint8_image(batch):
        # Tile the batch, scale to 0..255 (the +0.5 makes the uint8 cast round
        # instead of truncate) and move channels last, as imshow expects.
        grid = make_grid(batch, normalize=True, pad_value=2)
        return grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    def insert_newline(string, point=9):
        # Break the text into lines of at most `point` whitespace-separated
        # words; short text is returned untouched.
        words = string.split()
        if len(words) <= point:
            return string
        word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
        return "\n".join(" ".join(chunk) for chunk in word_chunks)

    fig = plt.figure()
    try:
        panels = ((input_1, "Input"), (input_2, f"Rendering - {step} steps"))
        for index, (batch, title) in enumerate(panels, start=1):
            plt.subplot(1, 2, index)  # nrows=1, ncols=2
            plt.imshow(to_uint8_image(batch))
            plt.axis("off")
            plt.title(title)

        plt.suptitle(insert_newline(prompt), fontsize=10)
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi)
    finally:
        # Always release the figure: a failed render or save must not leave it
        # registered as pyplot's current figure for later plotting calls.
        plt.close(fig)
def plot_img(inputs: torch.Tensor,
             output_dir: AnyStr,
             fname: str,  # file name
             dpi: int = 100):
    """Tile a tensor of images into a grid and save it as a tightly cropped PNG.

    Args:
        inputs: tensor passed to ``make_grid``.
        output_dir: directory the image is written to (not created here).
        fname: file name without extension; ".png" is appended.
        dpi: resolution forwarded to ``savefig``.

    Raises:
        AssertionError: if ``inputs`` is not a tensor.
    """
    assert torch.is_tensor(inputs), f"The input must be tensor type, but got {type(inputs)}"

    # Scale to 0..255 (the +0.5 makes the uint8 cast round instead of truncate)
    # and move channels last, as imshow expects.
    grid = make_grid(inputs, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    # Draw on a dedicated figure so a figure that is already current (the
    # caller's, or one left over from an earlier failure) is neither painted
    # over nor closed by this function.
    fig = plt.figure()
    try:
        plt.imshow(ndarr)
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)
def plot_img_title(inputs: torch.Tensor,
                   title: str,
                   output_dir: AnyStr,
                   fname: str,  # file name
                   dpi: int = 500):
    """Tile a tensor of images into a grid and save it as a titled PNG.

    Args:
        inputs: tensor passed to ``make_grid``.
        title: text drawn above the image grid.
        output_dir: directory the image is written to (not created here).
        fname: file name without extension; ".png" is appended.
        dpi: resolution forwarded to ``savefig``.

    Raises:
        AssertionError: if ``inputs`` is not a tensor.
    """
    assert torch.is_tensor(inputs), f"The input must be tensor type, but got {type(inputs)}"

    # Scale to 0..255 (the +0.5 makes the uint8 cast round instead of truncate)
    # and move channels last, as imshow expects.
    grid = make_grid(inputs, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    # Draw on a dedicated figure so a figure that is already current (the
    # caller's, or one left over from an earlier failure) is neither painted
    # over nor closed by this function.
    fig = plt.figure()
    try:
        plt.imshow(ndarr)
        plt.axis("off")
        plt.title(f"{title}")
        plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi)
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)