import os
from typing import Union

import numpy as np
import PIL
import PIL.Image
import PIL.ImageOps
import torch
import torch.nn.functional as F
import torchvision.transforms as TT
from PIL import Image, ImageSequence, ImageOps

import folder_paths
from nodes import LoadImage
def tokenize_with_trigger_word(tokens, weights, num_images, num_tokens, img_token, start_token=49406, end_token=49407, pad_token=0, max_len=77, return_mask=False):
"""
Filters out the image token(s).
Repeats the preceding token if any.
Rebatches.
"""
    count = 0
    # strip start/end/padding tokens before locating the trigger token
    mask = (tokens != start_token) & (tokens != end_token) & (tokens != pad_token)
    clean_tokens, clean_tokens_mask = tokens[mask], weights[mask]
    img_token_indices = (clean_tokens == img_token).nonzero().view(-1)
    # split so that each chunk ends with (at most) one trigger token
    split = torch.tensor_split(clean_tokens, img_token_indices + 1, dim=-1)
    split_mask = torch.tensor_split(clean_tokens_mask, img_token_indices + 1, dim=-1)
    tt = []
    ww = []
    for chunk, chunk_mask in zip(split, split_mask):
        img_token_exists = chunk == img_token
        img_token_not_exists = ~img_token_exists
        # reserve num_images * num_tokens slots per trigger occurrence
        pad_amount = int(img_token_exists.sum()) * num_images * num_tokens
        chunk_clean, chunk_mask_clean = chunk[img_token_not_exists], chunk_mask[img_token_not_exists]
        if pad_amount > 0 and len(chunk_clean) > 0:
            count += 1
            # repeat the token preceding the trigger to fill the slots
            # (-1 marks the slots when building a mask instead)
            tt.append(torch.nn.functional.pad(chunk_clean[:-1], (0, pad_amount), 'constant', chunk_clean[-1] if not return_mask else -1))
            ww.append(torch.nn.functional.pad(chunk_mask_clean[:-1], (0, pad_amount), 'constant', chunk_mask_clean[-1] if not return_mask else -1))
if count == 0:
return (tokens, weights, count)
# rebatch and pad
out = []
outw = []
one = torch.tensor([1.0])
for tc, tcw in zip(torch.cat(tt).split(max_len - 2), torch.cat(ww).split(max_len - 2)):
out.append(torch.cat([torch.tensor([start_token]), tc, torch.tensor([end_token])]))
outw.append(torch.cat([one, tcw, one]))
out = torch.nn.utils.rnn.pad_sequence(out, batch_first=True, padding_value=pad_token)
outw = torch.nn.utils.rnn.pad_sequence(outw, batch_first=True, padding_value=1.0)
out = torch.nn.functional.pad(out, (0, max(0, max_len - out.shape[1])), 'constant', pad_token)
outw = torch.nn.functional.pad(outw, (0, max(0, max_len - outw.shape[1])), 'constant', 1.0)
return (out, outw, count)
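# Usage sketch (token ids below are hypothetical): with CLIP-style special
# tokens (49406/49407) and a prompt tokenized as [start, 320, 1234, 2368, end]
# where 1234 is the trigger id,
#   out, outw, n = tokenize_with_trigger_word(tokens, weights, num_images=1, num_tokens=4, img_token=1234)
# yields n == 1 and a 77-token sequence in which token 320 is repeated 4
# times where the trigger used to be; note that text after the last
# trigger occurrence (2368 here) is dropped.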
def load_pil_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            import requests
            img = Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image_path = folder_paths.get_annotated_filepath(image)
            img = Image.open(image_path)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path"
            )
    elif isinstance(image, PIL.Image.Image):
        img = image
    else:
        raise ValueError(
            "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
        )
    return img
# from diffusers.utils import load_image
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Loads `image` to a PIL Image.
Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
Returns:
`PIL.Image.Image`:
A PIL Image.
"""
image = load_pil_image(image)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
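# Usage sketch (the URL and path below are placeholders): both forms
# return an RGB, EXIF-rotated PIL image.
#   img = load_image("https://example.com/reference.png")
#   img = load_image("input/reference.png")  # resolved via folder_paths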
class LoadImageCustom(LoadImage):
    def load_image(self, image):
        # same as ComfyUI's LoadImage, but also accepts URLs and PIL images
        img = load_pil_image(image)
        output_images = []
        output_masks = []
        for i in ImageSequence.Iterator(img):
            i = ImageOps.exif_transpose(i)
            if i.mode == 'I':
                # 32-bit integer images: rescale to the usual 8-bit range
                i = i.point(lambda px: px * (1 / 255))
            image = i.convert("RGB")
            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            if 'A' in i.getbands():
                # mask is the inverted alpha channel
                mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                mask = 1. - torch.from_numpy(mask)
            else:
                mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
            output_images.append(image)
            output_masks.append(mask.unsqueeze(0))
        if len(output_images) > 1:
            output_image = torch.cat(output_images, dim=0)
            output_mask = torch.cat(output_masks, dim=0)
        else:
            output_image = output_images[0]
            output_mask = output_masks[0]
        return (output_image, output_mask)
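# Usage sketch (hypothetical URL, assumes a running ComfyUI session):
#   image, mask = LoadImageCustom().load_image("https://example.com/ref.png")
#   image.shape  # (frames, H, W, 3), float32 in [0, 1]
#   mask.shape   # (frames, H, W), 1.0 where the alpha channel was transparent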
def crop_image_pil(image, crop_position):
"""
Crop a PIL image based on the specified crop_position.
Parameters:
- image: PIL Image object
- crop_position: One of "top", "bottom", "left", "right", "center", or "pad"
Returns:
- Cropped PIL Image object
"""
    width, height = image.size
    if "pad" in crop_position:
        # pad (rather than crop) to a centered square
        target_length = max(width, height)
        pad_l = max((target_length - width) // 2, 0)
        pad_t = max((target_length - height) // 2, 0)
        return ImageOps.expand(image, border=(pad_l, pad_t, target_length - width - pad_l, target_length - height - pad_t), fill=0)
    else:
        # square crop of side min(width, height), centered by default
        crop_size = min(width, height)
        x = (width - crop_size) // 2
        y = (height - crop_size) // 2
        if "top" in crop_position:
            y = 0
        elif "bottom" in crop_position:
            y = height - crop_size
        elif "left" in crop_position:
            x = 0
        elif "right" in crop_position:
            x = width - crop_size
        return image.crop((x, y, x + crop_size, y + crop_size))
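# Usage sketch, assuming a 640x480 PIL image `img` already loaded with
# load_image(); every non-"pad" position returns a 480x480 square.
#   crop_image_pil(img, "center")  # centered square crop
#   crop_image_pil(img, "left")    # square crop flush with the left edge
#   crop_image_pil(img, "pad")     # 640x640, letterboxed with black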
def prepImages(images, *args, **kwargs):
    # convert a list of PIL images to NHWC tensors and prep them one by one
    to_tensor = TT.ToTensor()
    images_ = []
    for img in images:
        image = to_tensor(img)
        if len(image.shape) <= 3:
            image.unsqueeze_(0)
        images_.append(prepImage(image.movedim(1, -1), *args, **kwargs))
    return torch.cat(images_)
def prepImage(image, interpolation="LANCZOS", crop_position="center", size=(224, 224), sharpening=0.0, padding=0):
    # note: the `sharpening` parameter is currently unused
_, oh, ow, _ = image.shape
output = image.permute([0,3,1,2])
if "pad" in crop_position:
target_length = max(oh, ow)
pad_l = (target_length - ow) // 2
pad_r = (target_length - ow) - pad_l
pad_t = (target_length - oh) // 2
pad_b = (target_length - oh) - pad_t
output = F.pad(output, (pad_l, pad_r, pad_t, pad_b), value=0, mode="constant")
else:
crop_size = min(oh, ow)
x = (ow-crop_size) // 2
y = (oh-crop_size) // 2
if "top" in crop_position:
y = 0
elif "bottom" in crop_position:
y = oh-crop_size
elif "left" in crop_position:
x = 0
elif "right" in crop_position:
x = ow-crop_size
x2 = x+crop_size
y2 = y+crop_size
# crop
output = output[:, :, y:y2, x:x2]
# resize (apparently PIL resize is better than torchvision interpolate)
imgs = []
to_PIL_image = TT.ToPILImage()
to_tensor = TT.ToTensor()
for i in range(output.shape[0]):
img = to_PIL_image(output[i])
img = img.resize(size, resample=PIL.Image.Resampling[interpolation])
imgs.append(to_tensor(img))
output = torch.stack(imgs, dim=0)
    imgs = None  # zealous GC
    if padding > 0:
        # tensor values are in [0, 1] here, so a white border is 1.0, not 255
        output = F.pad(output, (padding, padding, padding, padding), value=1.0, mode="constant")
output = output.permute([0,2,3,1])
return output
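# Usage sketch: prepare a batch of reference images for a 224x224 image
# encoder (paths and values below are illustrative, not prescriptive).
#   pil_images = [load_image("input/a.png"), load_image("input/b.png")]
#   batch = prepImages(pil_images, interpolation="LANCZOS", crop_position="center", size=(224, 224))
#   batch.shape  # (2, 224, 224, 3), NHWC float in [0, 1]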