from pathlib import Path
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from numpy.core.shape_base import block
from PIL.Image import Image
from torchvision.transforms import functional as TF


def show_tensor_image(tensor: torch.Tensor, range_zero_one: bool = False) -> None:
    """Show every image of a batched tensor, one blocking figure per image.

    Args:
        tensor (torch.Tensor): Tensor of shape [N, 3, H, W] in range [-1, 1]
            or in range [0, 1].
        range_zero_one (bool): Set to True when ``tensor`` is already in
            range [0, 1]. When False, the tensor is mapped from [-1, 1]
            to [0, 1] before being displayed.
    """
    if not range_zero_one:
        tensor = (tensor + 1) / 2
    # ``Tensor.clamp`` is out-of-place: the result has to be assigned back,
    # otherwise out-of-range values are passed on to the PIL conversion.
    tensor = tensor.clamp(0, 1)

    for i, image_tensor in enumerate(tensor):
        plt.title(f"Fig_{i}")
        pil_image = TF.to_pil_image(image_tensor)
        plt.imshow(pil_image)
        plt.show(block=True)


def show_editied_masked_image(
    title: str,
    source_image: Image,
    edited_image: Image,
    mask: Optional[Image] = None,
    path: Optional[Union[str, Path]] = None,
    distance: Optional[str] = None,
) -> None:
    """Plot the source image, the optional mask and the edited image in a row.

    Args:
        title (str): Text shown in the figure title as ``Prompt: "<title>"``.
        source_image (Image): Image drawn in the first panel.
        edited_image (Image): Image drawn in the last panel.
        mask (Optional[Image]): When given, drawn in a middle panel; the
            grayscale colormap is activated after it is drawn.
        path (Optional[Union[str, Path]]): When given, the figure is saved to
            this path instead of being shown in a blocking window.
        distance (Optional[str]): When given, appended to the figure title
            in parentheses.
    """
    rows = 1
    cols = 3 if mask is not None else 2

    figure_title = f'Prompt: "{title}"'
    if distance is not None:
        figure_title += f" ({distance})"

    fig = plt.figure(figsize=(12, 5))
    # Close the figure even if drawing, saving or showing raises, so a failed
    # call does not leave an open figure behind.
    try:
        plt.title(figure_title)
        plt.axis("off")

        fig_idx = 1
        fig.add_subplot(rows, cols, fig_idx)
        fig_idx += 1
        _set_image_plot_name("Source Image")
        plt.imshow(source_image)

        if mask is not None:
            fig.add_subplot(rows, cols, fig_idx)
            fig_idx += 1
            _set_image_plot_name("Mask")
            plt.imshow(mask)
            plt.gray()

        fig.add_subplot(rows, cols, fig_idx)
        _set_image_plot_name("Edited Image")
        plt.imshow(edited_image)

        if path is not None:
            plt.savefig(path, bbox_inches="tight")
        else:
            plt.show(block=True)
    finally:
        plt.close(fig)


def _set_image_plot_name(name: str) -> None:
    """Title the current axes with ``name`` and remove its x/y tick marks."""
    plt.title(name)
    plt.xticks([])
    plt.yticks([])