import torch
import ast
import math
from PIL import Image


def has_fn(model, fn_name):
    """Check if model has a function fn_name"""
    return callable(getattr(model, fn_name, None))


def exists(val):
    return val is not None


def num_params(module, filter_to_trainable=False):
    """Returns the number of parameters in the module, or optionally only the trainable parameters"""
    if filter_to_trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())


def hasattr_recursive(obj, att):
    """
    Check if obj has nested attribute
    Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
    """
    if att == "":
        return True
    i = att.find(".")
    if i < 0:
        return hasattr(obj, att)
    else:
        try:
            return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
        except AttributeError:
            return False


def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    """
    if att == "":
        return obj
    i = att.find(".")
    if i < 0:
        return getattr(obj, att)
    else:
        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])


def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    if "." in att:
        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
    setattr(obj, att.split(".")[-1], val)
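
# Illustrative usage of the recursive attribute helpers (the objects below are
# hypothetical, not part of this module):
#
#   from types import SimpleNamespace
#   obj = SimpleNamespace(a=SimpleNamespace(b=SimpleNamespace(c=1)))
#   hasattr_recursive(obj, "a.b.c")    # True
#   getattr_recursive(obj, "a.b.c")    # 1
#   setattr_recursive(obj, "a.b.c", 2)
#   obj.a.b.c                          # 2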


def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
    """
    Stack a list of tensors with padding on one side
    Args:
        list_of_tensors (list[torch.Tensor]): List of tensors to stack
        padding_value (int, optional): Value to pad with. Defaults to 0.
        padding_side (str, optional): Side to pad on. Defaults to "right".
    Returns:
        torch.Tensor: Stacked tensors
    """
    max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
    padded_tensors = []
    for tensor in list_of_tensors:
        num_tokens = tensor.size(0)
        if len(tensor.size()) == 1:
            padding = torch.full(
                (max_tokens - num_tokens,),
                padding_value,
                dtype=tensor.dtype,
                device=tensor.device,
            )
        else:
            padding = torch.full(
                (max_tokens - num_tokens, tensor.size(1)),
                padding_value,
                dtype=tensor.dtype,
                device=tensor.device,
            )
        padded_tensor = (
            torch.cat((tensor, padding), dim=0)
            if padding_side == "right"
            else torch.cat((padding, tensor), dim=0)
        )
        padded_tensors.append(padded_tensor)
    return torch.stack(padded_tensors)
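
# Illustrative example: stacking two 1-D token tensors of different lengths.
# The shorter tensor is right-padded with `padding_value`, giving a (2, 3) result:
#
#   a = torch.tensor([1, 2, 3])
#   b = torch.tensor([4, 5])
#   stack_with_padding([a, b], padding_value=0)
#   # tensor([[1, 2, 3],
#   #         [4, 5, 0]])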


def check_embedding_fns(lang_model):
    """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
    if not has_fn(lang_model, "get_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):
            lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):
            lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
        else:
            raise ValueError(
                "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "transformer.wte", x
            )
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "model.decoder.embed_tokens", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "get_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.get_output_embeddings = lambda: lang_model.lm_head
        else:
            raise ValueError(
                "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.set_output_embeddings = lambda x: setattr_recursive(
                lang_model, "lm_head", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )
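
# Illustrative usage (GPT-2-style models expose input embeddings at `transformer.wte`,
# OPT-style models at `model.decoder.embed_tokens`; Hugging Face models typically already
# provide these accessors, so the fallbacks above cover wrappers that do not):
#
#   lang_model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical choice of model
#   check_embedding_fns(lang_model)
#   emb = lang_model.get_input_embeddings()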


def unpad_image(tensor, original_size, keep_original_shape=False):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image, as (width, height).
        keep_original_shape (bool, optional): If True, return the tensor unchanged together
            with an attention mask that zeroes out the padded rows or columns. Defaults to False.

    Returns:
        tuple: (unpadded or original tensor, attention mask or None).
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Padding was added along the height dimension.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        if keep_original_shape:
            attention_mask = torch.ones((current_height, current_width), device=tensor.device)
            attention_mask[:padding, :] = 0
            attention_mask[current_height - padding:, :] = 0
            return tensor, attention_mask
        else:
            unpadded_tensor = tensor[:, padding:current_height - padding, :]
            return unpadded_tensor, None
    else:
        # Padding was added along the width dimension.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        if keep_original_shape:
            attention_mask = torch.ones((current_height, current_width), device=tensor.device)
            attention_mask[:, :padding] = 0
            attention_mask[:, current_width - padding:] = 0
            return tensor, attention_mask
        else:
            unpadded_tensor = tensor[:, :, padding:current_width - padding]
            return unpadded_tensor, None
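
# Illustrative example: a CxHxW feature map of shape (3, 24, 24) that came from a
# 500x300 image (wider than tall) has its vertical padding removed:
#
#   feats = torch.zeros(3, 24, 24)
#   unpadded, _ = unpad_image(feats, (500, 300))
#   unpadded.shape  # torch.Size([3, 14, 24])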


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit
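
# Illustrative example: for a 500x300 image, (672, 336) preserves the most of the
# original resolution while wasting the fewest padded pixels, so it is selected:
#
#   select_best_resolution((500, 300), [(672, 336), (336, 672), (672, 672)])
#   # (672, 336)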


def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image
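
# Illustrative example: a 500x300 image resized into a 672x336 canvas is scaled by the
# limiting factor (336/300 = 1.12) to 560x336, then centered with black bars left and right:
#
#   padded = resize_and_pad_image(Image.new("RGB", (500, 300)), (672, 336))
#   padded.size  # (672, 336)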


def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches
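
# Illustrative example: a 672x336 image with patch_size=336 yields two 336x336 patches
# (one row of two columns):
#
#   patches = divide_to_patches(Image.new("RGB", (672, 336)), 336)
#   len(patches)  # 2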


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
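
# Illustrative example: with a base patch size of 336 and the pinpoints below, a 500x300
# image maps to the (672, 336) resolution, i.e. a 2x1 grid of patches:
#
#   get_anyres_image_grid_shape((500, 300), [[672, 336], [336, 672]], 336)
#   # (2, 1)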


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    processor_size = processor.transforms[0].size
    patches = divide_to_patches(image_padded, processor_size[0])

    image_original_resize = image.resize((processor_size[0], processor_size[0]))

    image_patches = [image_original_resize] + patches
    image_patches = [processor(image_patch) for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)
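
# Note: `processor` is assumed here to be a torchvision `transforms.Compose` whose first
# transform exposes a `.size` tuple (e.g. Resize) and whose last transform exposes `.mean`
# (e.g. Normalize), matching the attribute accesses above and in `process_images` below.
# A minimal sketch under that assumption (the pipeline and `img` are illustrative only):
#
#   from torchvision import transforms
#   processor = transforms.Compose([
#       transforms.Resize((336, 336)),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
#   ])
#   patches = process_anyres_image(img, processor, "[[336, 672], [672, 336]]")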


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas of `background_color`, centering the original image."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
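
# Illustrative example: a 500x300 image becomes a 500x500 image with the original pasted
# at vertical offset (500 - 300) // 2 = 100:
#
#   square = expand2square(Image.new("RGB", (500, 300)), (0, 0, 0))
#   square.size  # (500, 500)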


def process_images(images, image_processor, model_cfg):
    """Preprocess a list of PIL images according to `model_cfg.image_aspect_ratio`."""
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.transforms[-1].mean))
            image = image_processor(image)
            new_images.append(image)
    elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
        base_img_size = image_processor.transforms[0].size[0]
        for image in images:
            image = process_anyres_image(
                image,
                image_processor,
                [
                    [base_img_size, base_img_size * 2],
                    [base_img_size * 2, base_img_size],
                    [base_img_size * 2, base_img_size * 2],
                    [base_img_size * 3, base_img_size],
                    [base_img_size, base_img_size * 3],
                ],
            )
            new_images.append(image)
    else:
        return image_processor(images)
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
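
# Illustrative usage (the config object is an assumption; `processor`, `img1`, and `img2`
# follow the hypothetical torchvision-style setup sketched after `process_anyres_image`):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(image_aspect_ratio="anyres")
#   batch = process_images([img1, img2], processor, cfg)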