import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
def prepare_torch_img(img, size_H, size_W, device="cuda", keep_shape=False):
    """
    Resize a batch of images and move them to `device`.
    img (torch): [N, H, W, C]
    Returns [N, C, H, W], or [N, H, W, C] if keep_shape is True.
    """
    # [N, H, W, C] -> [N, C, H, W]
img_new = img.permute(0, 3, 1, 2).to(device)
img_new = F.interpolate(img_new, (size_H, size_W), mode="bilinear", align_corners=False).contiguous()
if keep_shape:
img_new = img_new.permute(0, 2, 3, 1)
return img_new
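# Illustrative usage of prepare_torch_img (hypothetical shapes, CPU device assumed):
#   views = torch.rand(2, 1024, 1024, 3)                                     # [N, H, W, C]
#   chw = prepare_torch_img(views, 512, 512, device="cpu")                   # [2, 3, 512, 512]
#   hwc = prepare_torch_img(views, 512, 512, device="cpu", keep_shape=True)  # [2, 512, 512, 3]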
def torch_imgs_to_pils(images, masks=None, alpha_min=0.1):
"""
images (torch): [N, H, W, C] or [H, W, C]
masks (torch): [N, H, W] or [H, W]
"""
if len(images.shape) == 3:
images = images.unsqueeze(0)
if masks is not None:
masks = masks.to(dtype=images.dtype, device=images.device)
if len(masks.shape) == 2:
masks = masks.unsqueeze(0)
        # zero out pixels whose alpha falls below the threshold; clone first so the
        # caller's tensor is not modified in place
        images = images.clone()
        images[masks < alpha_min] = 0.
        # append the mask as an alpha channel -> [N, H, W, 4]
        images = torch.cat((images, masks.unsqueeze(3)), dim=3)
        mode = "RGBA"
    else:
        mode = "RGB"
pil_image_list = [Image.fromarray((images[i].detach().cpu().numpy() * 255).astype(np.uint8), mode=mode) for i in range(images.shape[0])]
return pil_image_list
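# Illustrative usage of torch_imgs_to_pils (hypothetical data):
#   imgs = torch.rand(4, 256, 256, 3)                # values in [0, 1]
#   masks = (torch.rand(4, 256, 256) > 0.5).float()
#   pils = torch_imgs_to_pils(imgs, masks)           # four RGBA PIL images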
def troch_image_dilate(img):
"""
Remove thin seams on generated texture
img (torch): [H, W, C]
"""
import cv2
    img = np.asarray(img.cpu().numpy(), dtype=np.float32)
    img = (img * 255).clip(0, 255)
    # seam pixels are those whose RGB sum is (near) zero
    mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
    mask = (mask <= 3.0).astype(np.float32)
    # dilate the whole image, then copy the dilated values into the seam pixels only
    kernel = np.ones((3, 3), 'uint8')
    dilate_img = cv2.dilate(img, kernel, iterations=1)
    img = img * (1 - mask) + dilate_img * mask
    img = (img.clip(0, 255) / 255).astype(np.float32)
return torch.from_numpy(img)
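# Illustrative usage of troch_image_dilate (hypothetical texture):
#   texture = torch.rand(1024, 1024, 3)              # [H, W, C] in [0, 1]
#   texture = troch_image_dilate(texture)            # seam pixels filled, same shape, on CPU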
def pils_to_torch_imgs(pils: Union[Image.Image, List[Image.Image]], dtype=torch.float16, device="cuda", force_rgb=True):
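    """
    Convert PIL image(s) into a [N, H, W, C] float tensor in [0, 1] on `device`.
    RGBA inputs are converted to RGB when force_rgb is True.
    """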
if isinstance(pils, Image.Image):
pils = [pils]
images = []
for pil in pils:
if pil.mode == "RGBA" and force_rgb:
pil = pil.convert('RGB')
images.append(TF.to_tensor(pil).permute(1, 2, 0))
images = torch.stack(images, dim=0).to(dtype=dtype, device=device)
return images
def pils_rgba_to_rgb(pils: Union[Image.Image, List[Image.Image]], bkgd="WHITE"):
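    """
    Composite RGBA images over a solid `bkgd` color; non-RGBA images are returned unchanged.
    """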
if isinstance(pils, Image.Image):
pils = [pils]
rgbs = []
for pil in pils:
if pil.mode == 'RGBA':
new_image = Image.new("RGBA", pil.size, bkgd)
new_image.paste(pil, (0, 0), pil)
rgbs.append(new_image.convert('RGB'))
else:
rgbs.append(pil)
return rgbs
def pil_split_image(image, rows=None, cols=None):
"""
inverse function of make_image_grid
"""
# image is in square
if rows is None and cols is None:
# image.size [W, H]
rows = 1
cols = image.size[0] // image.size[1]
assert cols * image.size[1] == image.size[0]
subimg_size = image.size[1]
elif rows is None:
subimg_size = image.size[0] // cols
rows = image.size[1] // subimg_size
assert rows * subimg_size == image.size[1]
elif cols is None:
subimg_size = image.size[1] // rows
cols = image.size[0] // subimg_size
assert cols * subimg_size == image.size[0]
else:
subimg_size = image.size[1] // rows
assert cols * subimg_size == image.size[0]
subimgs = []
for i in range(rows):
for j in range(cols):
subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
subimgs.append(subimg)
return subimgs
def pil_make_image_grid(images, rows=None, cols=None):
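    """
    Arrange same-sized PIL images into a rows x cols grid; missing cells are filled with blank images.
    """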
if rows is None and cols is None:
rows = 1
cols = len(images)
if rows is None:
rows = len(images) // cols
if len(images) % cols != 0:
rows += 1
if cols is None:
cols = len(images) // rows
if len(images) % rows != 0:
cols += 1
    total_imgs = rows * cols
    if total_imgs > len(images):
        # pad with blank images (use + rather than += so the caller's list is left untouched)
        images = images + [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]
w, h = images[0].size
grid = Image.new(images[0].mode, size=(cols * w, rows * h))
for i, img in enumerate(images):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
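# Illustrative round trip between pil_make_image_grid and pil_split_image (tiles are hypothetical):
#   tiles = [Image.new("RGB", (256, 256), c) for c in ("red", "green", "blue", "white")]
#   grid = pil_make_image_grid(tiles, rows=2, cols=2)   # one 512x512 image
#   back = pil_split_image(grid, rows=2, cols=2)        # four 256x256 tiles again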
def pils_erode_masks(mask_list):
    """
    Erode the alpha channel of RGBA masks by 1 px and return them as "L" mode PIL masks.
    """
    # cv2 is only needed here; import it once rather than inside the loop
    import cv2
    out_mask_list = []
    for idx, mask in enumerate(mask_list):
        arr = np.array(mask)
        alpha = (arr[:, :, 3] > 127).astype(np.uint8)
        # erode the binary alpha by one pixel to hide boundary bleeding
        alpha = cv2.erode(alpha, np.ones((3, 3), np.uint8), iterations=1)
        alpha = (alpha * 255).astype(np.uint8)
        out_mask_list.append(Image.fromarray(alpha))
return out_mask_list
def pils_resize_foreground(
pils: Union[Image.Image, List[Image.Image]],
ratio: float,
) -> List[Image.Image]:
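    """
    Crop each RGBA image to its alpha foreground, pad it to a square, then pad again so the
    foreground occupies roughly `ratio` of the output side length.
    """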
if isinstance(pils, Image.Image):
pils = [pils]
new_pils = []
for image in pils:
image = np.array(image)
assert image.shape[-1] == 4
alpha = np.where(image[..., 3] > 0)
y1, y2, x1, x2 = (
alpha[0].min(),
alpha[0].max(),
alpha[1].min(),
alpha[1].max(),
)
# crop the foreground
fg = image[y1:y2, x1:x2]
# pad to square
size = max(fg.shape[0], fg.shape[1])
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
new_image = np.pad(
fg,
((ph0, ph1), (pw0, pw1), (0, 0)),
mode="constant",
constant_values=((0, 0), (0, 0), (0, 0)),
)
# compute padding according to the ratio
new_size = int(new_image.shape[0] / ratio)
# pad to size, double side
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
new_image = np.pad(
new_image,
((ph0, ph1), (pw0, pw1), (0, 0)),
mode="constant",
constant_values=((0, 0), (0, 0), (0, 0)),
)
new_image = Image.fromarray(new_image, mode="RGBA")
new_pils.append(new_image)
return new_pils
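if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only: synthetic data, CPU device assumed,
    # exercising the helpers defined above).
    imgs = torch.rand(2, 64, 64, 3)
    masks = (torch.rand(2, 64, 64) > 0.5).float()

    resized = prepare_torch_img(imgs, 32, 32, device="cpu")
    assert resized.shape == (2, 3, 32, 32)

    pils = torch_imgs_to_pils(imgs, masks)
    assert all(p.mode == "RGBA" for p in pils)

    rgbs = pils_rgba_to_rgb(pils)
    back = pils_to_torch_imgs(rgbs, dtype=torch.float32, device="cpu")
    assert back.shape == (2, 64, 64, 3)

    grid = pil_make_image_grid(rgbs, rows=1, cols=2)
    assert pil_split_image(grid, rows=1)[0].size == (64, 64)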