from PIL import Image
from io import BytesIO
import base64
import torch
import math
import ast
import copy
import numpy as np
import random

from transformers import StoppingCriteria, CLIPImageProcessor, SiglipImageProcessor

from llava.constants import MM_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str): A string representation of a list of possible resolutions.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
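
# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only, not referenced anywhere in the
# repo): shows how the helpers above compose into the "anyres" tiling step.
# The grid pinpoints and 336-pixel patch size below are example values, not
# the project's configured defaults.
# For instance, _demo_anyres_tiling(Image.new('RGB', (1000, 500))) should
# select a (672, 336) canvas and produce 2 square patches.
# ---------------------------------------------------------------------------
def _demo_anyres_tiling(image, grid_pinpoints=((336, 672), (672, 336), (672, 672)), patch_size=336):
    """Return the selected canvas resolution and the number of patches produced."""
    best_resolution = select_best_resolution(image.size, grid_pinpoints)
    padded = resize_and_pad_image(image, best_resolution)
    patches = divide_to_patches(padded, patch_size)
    return best_resolution, len(patches)
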
""" if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: possible_resolutions = ast.literal_eval(grid_pinpoints) width, height = select_best_resolution(image_size, possible_resolutions) return width // patch_size, height // patch_size def process_anyres_image(image, processor, grid_pinpoints): """ Process an image with variable resolutions. Args: image (PIL.Image.Image): The input image to be processed. processor: The image processor object. grid_pinpoints (str): A string representation of a list of possible resolutions. Returns: torch.Tensor: A tensor containing the processed image patches. """ if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: possible_resolutions = ast.literal_eval(grid_pinpoints) best_resolution = select_best_resolution(image.size, possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) patches = divide_to_patches(image_padded, processor.crop_size['height'] if hasattr(processor, 'crop_size') else processor.size['height']) if isinstance(processor, CLIPImageProcessor) or isinstance(processor, SiglipImageProcessor): image_original_resize = image.resize((processor.size['height'], processor.size['width'])) image_patches = [image_original_resize] + patches image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] for image_patch in image_patches] else: image_original_resize = image.resize((processor.img_size, processor.img_size)) image_patches = [image_original_resize] + patches image_patches = [processor.preprocess(image_patch) for image_patch in image_patches] return torch.stack(image_patches, dim=0) def load_image_from_base64(image): return Image.open(BytesIO(base64.b64decode(image))) def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result def process_images(images, image_processor, model_cfg): image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) new_images = [] if image_aspect_ratio == 'pad': for image in images: image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] new_images.append(image) elif image_aspect_ratio == "anyres": for image in images: image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) new_images.append(image) else: return image_processor(images, return_tensors='pt')['pixel_values'] if all(x.shape == new_images[0].shape for x in new_images): new_images = torch.stack(new_images, dim=0) return new_images def process_images_v2(images, image_processor, model_cfg): image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) new_images = [] if image_aspect_ratio == 'pad': for image in images: image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) if isinstance(image_processor, CLIPImageProcessor) or isinstance(image_processor, SiglipImageProcessor): image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] else: image = image_processor.preprocess(image) new_images.append(image) elif image_aspect_ratio == "anyres": for image in images: image = process_anyres_image(image, image_processor, 
def tokenizer_image_token(prompt, tokenizer, MM_TOKEN_INDEX=MM_TOKEN_INDEX, return_tensors=None):
    mm_token = DEFAULT_VIDEO_TOKEN if DEFAULT_VIDEO_TOKEN in prompt else DEFAULT_IMAGE_TOKEN
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(mm_token)]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [MM_TOKEN_INDEX] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
            if torch.equal(truncated_output_ids, keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
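
# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative only): tokenizer_image_token splits the
# prompt on the image/video placeholder token, tokenizes each text chunk, and
# splices MM_TOKEN_INDEX between the chunks; get_model_name_from_path keeps
# the parent directory name when pointing at an intermediate checkpoint so
# runs stay distinguishable. The paths below are made-up examples.
# ---------------------------------------------------------------------------
def _demo_model_name():
    assert get_model_name_from_path("/ckpts/llava-v1.5-7b/") == "llava-v1.5-7b"
    assert get_model_name_from_path("/ckpts/llava-v1.5-7b/checkpoint-1000") == "llava-v1.5-7b_checkpoint-1000"
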
def get_frame_indices(num_segments, vlen, sample='rand', fix_start=None, input_fps=1, pad_last=False):
    if sample in ['rand', 'middle']:  # uniform sampling
        num_segments = min(num_segments, vlen)
        intervals = np.linspace(start=0, stop=vlen, num=num_segments + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except:
                frame_indices = np.random.permutation(vlen)[:num_segments]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]

        if pad_last:
            if len(frame_indices) < num_segments:
                padded_frame_indices = [frame_indices[-1]] * num_segments
                padded_frame_indices[:len(frame_indices)] = frame_indices
                frame_indices = padded_frame_indices
    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if num_segments > 0 and len(frame_indices) > num_segments:
            cand_indices = copy.deepcopy(frame_indices)
            intervals = np.linspace(start=0, stop=len(cand_indices), num=num_segments + 1).astype(int)
            ranges = []
            for idx, interv in enumerate(intervals[:-1]):
                ranges.append((interv, intervals[idx + 1] - 1))
            try:
                frame_indices = [cand_indices[random.choice(range(x[0], x[1]))] for x in ranges]
            except:
                frame_indices = [cand_indices[x[0]] for x in ranges]
    else:
        raise NotImplementedError

    if len(frame_indices) == 0:
        frame_indices = [0]
    return frame_indices
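
# ---------------------------------------------------------------------------
# Hedged smoke test (illustrative only): exercises the pure helpers above on
# synthetic inputs; it touches no model, tokenizer, or image processor, and
# the grid pinpoints are example values rather than project defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Deterministic 'middle' sampling: 8 indices spread evenly over 100 frames.
    print(get_frame_indices(num_segments=8, vlen=100, sample='middle'))
    # Sequential 1-fps sampling over a 100-frame clip recorded at 25 fps
    # (a 4-second clip, so roughly 4 frames are returned).
    print(get_frame_indices(num_segments=8, vlen=100, sample='fps1', input_fps=25))
    # Patch-grid shape for a 1000x500 image on an example 336-pixel grid.
    print(get_anyres_image_grid_shape((1000, 500), [(336, 672), (672, 336), (672, 672)], 336))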