Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Author: ximing | |
# Copyright (c) 2023, XiMing Xing. | |
# License: MPL-2.0 License | |
from typing import AnyStr | |
import matplotlib.pyplot as plt | |
import torch | |
from torchvision.utils import make_grid | |
def plot_couple(input_1: torch.Tensor, | |
input_2: torch.Tensor, | |
step: int, | |
output_dir: str, | |
fname: str, # file name | |
prompt: str = '', # text prompt as image tile | |
dpi: int = 300): | |
if input_1.shape != input_2.shape: | |
raise ValueError("inputs and outputs must have the same dimensions") | |
plt.figure() | |
plt.subplot(1, 2, 1) # nrows=1, ncols=2, index=1 | |
grid = make_grid(input_1, normalize=True, pad_value=2) | |
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() | |
plt.imshow(ndarr) | |
plt.axis("off") | |
plt.title("Input") | |
plt.subplot(1, 2, 2) # nrows=1, ncols=2, index=2 | |
grid = make_grid(input_2, normalize=True, pad_value=2) | |
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() | |
plt.imshow(ndarr) | |
plt.axis("off") | |
plt.title(f"Rendering - {step} steps") | |
def insert_newline(string, point=9): | |
# split by blank | |
words = string.split() | |
if len(words) <= point: | |
return string | |
word_chunks = [words[i:i + point] for i in range(0, len(words), point)] | |
new_string = "\n".join(" ".join(chunk) for chunk in word_chunks) | |
return new_string | |
plt.suptitle(insert_newline(prompt), fontsize=10) | |
plt.tight_layout() | |
plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi) | |
plt.close() | |
def plot_img(inputs: torch.Tensor,
             output_dir: AnyStr,
             fname: str,  # file name
             dpi: int = 100):
    """Save ``inputs`` as an axis-free image grid to ``{output_dir}/{fname}.png``.

    A dedicated figure is used and always closed. Any figure the caller
    currently has open is neither drawn on nor closed.

    Args:
        inputs: image tensor(s). Presumably (B, C, H, W) or (C, H, W); anything
            ``torchvision.utils.make_grid`` accepts.
        output_dir: directory the PNG is written into. It is not created here.
        fname: file name without the ``.png`` extension.
        dpi: resolution passed to ``savefig``.

    Raises:
        AssertionError: if ``inputs`` is not a tensor (check skipped under ``python -O``).
    """
    assert torch.is_tensor(inputs), f"The input must be tensor type, but got {type(inputs)}"
    # normalize=True rescales to [0, 1]. pad_value=2 is clamped to 255 below,
    # so the separators render white.
    grid = make_grid(inputs.detach(), normalize=True, pad_value=2)
    # Scale to 0..255 (the +0.5 rounds before the uint8 truncation), then CHW -> HWC.
    ndarr = (grid.mul(255).add_(0.5).clamp_(0, 255)
             .permute(1, 2, 0).to("cpu", torch.uint8).numpy())

    fig, ax = plt.subplots()
    try:
        ax.imshow(ndarr)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(f"{output_dir}/{fname}.png", dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)
def plot_img_title(inputs: torch.Tensor,
                   title: str,
                   output_dir: AnyStr,
                   fname: str,  # file name
                   dpi: int = 500):
    """Save ``inputs`` as an axis-free image grid with a title to ``{output_dir}/{fname}.png``.

    A dedicated figure is used and always closed. Any figure the caller
    currently has open is neither drawn on nor closed.

    Args:
        inputs: image tensor(s). Presumably (B, C, H, W) or (C, H, W); anything
            ``torchvision.utils.make_grid`` accepts.
        title: text shown above the image; formatted with ``f"{title}"``.
        output_dir: directory the PNG is written into. It is not created here.
        fname: file name without the ``.png`` extension.
        dpi: resolution passed to ``savefig``.

    Raises:
        AssertionError: if ``inputs`` is not a tensor (check skipped under ``python -O``).
    """
    assert torch.is_tensor(inputs), f"The input must be tensor type, but got {type(inputs)}"
    # normalize=True rescales to [0, 1]. pad_value=2 is clamped to 255 below,
    # so the separators render white.
    grid = make_grid(inputs.detach(), normalize=True, pad_value=2)
    # Scale to 0..255 (the +0.5 rounds before the uint8 truncation), then CHW -> HWC.
    ndarr = (grid.mul(255).add_(0.5).clamp_(0, 255)
             .permute(1, 2, 0).to("cpu", torch.uint8).numpy())

    fig, ax = plt.subplots()
    try:
        ax.imshow(ndarr)
        ax.axis("off")
        ax.set_title(f"{title}")
        fig.savefig(f"{output_dir}/{fname}.png", dpi=dpi)
    finally:
        plt.close(fig)