"""Helpers that cut an image into a square overview plus a grid of square tiles."""

import asyncio
import math
import os
from io import BytesIO
from urllib.parse import urlparse

import numpy as np
import requests
import torch
import torchvision.transforms.functional as F
from PIL import Image, ImageOps
from torchvision.transforms import InterpolationMode


def split_image_ur(img, max_slice_num, image_size, vit_image_size, force_min_size=False):
    """Slice ``img`` into an overview image followed by its grid tiles.

    Args:
        img: Image to slice; it must expose ``.size`` and ``.resize`` the way a
            PIL image does.
        max_slice_num: Upper bound on the number of grid tiles.
        image_size: Side length (pixels) of the content area of each slice;
            forwarded as ``scale_resolution``.
        vit_image_size: Side length (pixels) of each padded tile.
        force_min_size: When true, ``img`` is first rescaled by
            ``resize_by_patch_size_ur`` so its shorter side is at least
            ``image_size`` and its longer side at most
            ``image_size * max_slice_num``.

    Returns:
        ``(sliced_images, sliced_shapes)``. ``sliced_images`` holds the overview
        image first, then every tile in row-major order. ``sliced_shapes`` holds
        one ``np.array((width // 14, height // 14))`` per image.
    """
    patch_size = 14

    if force_min_size:
        img = resize_by_patch_size_ur(
            img,
            min_size=image_size,
            max_size=image_size * max_slice_num,
            patch_size=patch_size,
        )

    source_image, sub_images, _ = do_slice_by_minicpmv_strategy_ur(
        img,
        max_slice_nums=max_slice_num,
        scale_resolution=image_size,
        patch_size=patch_size,
        vit_image_size=vit_image_size,
    )

    # Overview first, then the tiles row by row. Iterating each row directly
    # avoids indexing every row with the length of the first one.
    sliced_images = [source_image]
    for row in sub_images:
        sliced_images.extend(row)

    sliced_shapes = [
        np.array((piece.size[0] // patch_size, piece.size[1] // patch_size))
        for piece in sliced_images
    ]
    return sliced_images, sliced_shapes


# Strategy: MiniCPM-V
def do_slice_by_minicpmv_strategy_ur(image, max_slice_nums=9, scale_resolution=1120,
                                     patch_size=14, vit_image_size=448, never_split=False):
    """Build the overview image and, when the image is large enough, a tile grid.

    The number of tiles aimed for is
    ``min(ceil(width * height / scale_resolution**2), max_slice_nums)``.

    * If that number is at most 1, or ``never_split`` is true, the image is
      resized to ``(scale_resolution, scale_resolution)`` and padded with a
      black border of ``int((vit_image_size - scale_resolution) / 2)`` pixels
      on every side. No tiles are produced.
    * Otherwise the overview is the image resized to
      ``(scale_resolution, scale_resolution)``, and the tiles are cut from a
      second resize whose grid is chosen to match the image's aspect ratio.

    Args:
        image: PIL image to slice.
        max_slice_nums: Maximum number of tiles.
        scale_resolution: Side length of the overview / of each tile's content.
        patch_size: Forwarded to ``get_refine_size`` (which does not use it).
        vit_image_size: Side length of each padded tile.
        never_split: Force the no-tile branch.

    Returns:
        ``(source_image, patches, best_grid)``. ``patches`` is a list of rows of
        tiles (empty when not split); ``best_grid`` is ``[columns, rows]`` or
        ``None`` when not split.
    """
    original_size = image.size
    original_width, original_height = original_size
    log_ratio = math.log(original_width / original_height)
    ratio = original_width * original_height / (scale_resolution * scale_resolution)
    multiple = min(math.ceil(ratio), max_slice_nums)

    best_grid = None
    patches = []

    if multiple <= 1 or never_split:
        # No slicing needed: squash to a square and pad it out with black.
        source_image = image.resize((scale_resolution, scale_resolution), Image.BICUBIC)
        border_size = (vit_image_size - scale_resolution) / 2
        # NOTE(review): an RGB tuple fill presumably requires an RGB(A) image -- confirm
        # that callers convert other modes before calling.
        source_image = ImageOps.expand(source_image, border=int(border_size), fill=(0, 0, 0))
        return source_image, patches, best_grid

    # Tile counts to consider: one below, exactly, and one above the estimate,
    # skipping a single tile and anything over the configured maximum.
    candidate_split_grids_nums = [
        count
        for count in (multiple - 1, multiple, multiple + 1)
        if count != 1 and count <= max_slice_nums
    ]

    # NOTE(review): unlike the no-split branch, this overview is not padded to
    # vit_image_size -- confirm that is intended.
    source_image = image.resize((scale_resolution, scale_resolution), Image.BICUBIC)

    # Every [columns, rows] factorisation of every candidate tile count.
    candidate_grids = [
        [cols, count // cols]
        for count in candidate_split_grids_nums
        for cols in range(1, count + 1)
        if count % cols == 0
    ]

    # Pick the grid whose log aspect ratio is closest to the image's; min()
    # keeps the first candidate on ties, like a strict "<" scan would.
    best_grid = min(
        candidate_grids,
        key=lambda grid: abs(log_ratio - math.log(grid[0] / grid[1])),
        default=[1, 1],
    )

    refine_size = get_refine_size(
        original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
    )
    refine_image = image.resize(refine_size, Image.BICUBIC)
    patches = split_to_patches(refine_image, best_grid, scale_resolution, vit_image_size)

    return source_image, patches, best_grid


def ensure_divide(length, patch_size):
    """Round ``length`` to the nearest multiple of ``patch_size``, at least one patch."""
    return max(round(length / patch_size) * patch_size, patch_size)


def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
    """Return a patch-aligned ``(width, height)`` with area near ``scale_resolution**2``.

    The size is rescaled (keeping the aspect ratio) when the original area
    exceeds ``scale_resolution**2`` or when ``allow_upscale`` is true. Both
    sides are then rounded to multiples of ``patch_size``, and the width is
    reduced one patch at a time while the area is still above the budget.

    Args:
        original_size: ``(width, height)`` of the input.
        scale_resolution: Side length defining the area budget.
        patch_size: Alignment unit for both sides.
        allow_upscale: Rescale even when the input is already within budget.

    Returns:
        ``(best_width, best_height)``; each is a positive multiple of
        ``patch_size``.
    """
    width, height = original_size
    if (width * height > scale_resolution * scale_resolution) or allow_upscale:
        r = width / height
        height = int(scale_resolution / math.sqrt(r))
        width = int(height * r)

    best_width = ensure_divide(width, patch_size)
    best_height = ensure_divide(height, patch_size)

    # Trim the width until the area fits, but never below one patch: without
    # the floor a tall, narrow input could end up with a width of 0, which is
    # not a usable image size.
    while best_width * best_height > scale_resolution ** 2 and best_width > patch_size:
        best_width -= patch_size

    return (best_width, best_height)


def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False):
    """Return the size of the image from which a ``grid`` of tiles is cut.

    Args:
        original_size: Unused; kept for interface compatibility.
        grid: ``(grid_x, grid_y)`` -- number of tiles horizontally and vertically.
        scale_resolution: Side length of one tile's content.
        patch_size: Unused; kept for interface compatibility.
        allow_upscale: Unused; kept for interface compatibility.

    Returns:
        ``(scale_resolution * grid_x, scale_resolution * grid_y)``.
    """
    grid_x, grid_y = grid
    return (scale_resolution * grid_x, scale_resolution * grid_y)


def split_to_patches(image, grid, scale_resolution, vit_image_size):
    """Pad ``image`` with black and crop overlapping square tiles from it.

    The image is expanded by ``int((vit_image_size - scale_resolution) / 2)``
    pixels on every side; ``vit_image_size`` square windows are then cropped at
    a stride of ``scale_resolution`` in both directions.

    Args:
        image: PIL image to cut up.
        grid: Unused; the tile count follows from the image size and the stride.
        scale_resolution: Stride between neighbouring tiles.
        vit_image_size: Side length of each cropped tile.

    Returns:
        A list of rows (top to bottom), each a list of tiles (left to right).
    """
    border_size = (vit_image_size - scale_resolution) / 2
    padded_img = ImageOps.expand(image, border=int(border_size), fill=(0, 0, 0))
    padded_width, padded_height = padded_img.size

    patches = []
    for top in range(0, padded_height - vit_image_size + 1, scale_resolution):
        row = [
            padded_img.crop((left, top, left + vit_image_size, top + vit_image_size))
            for left in range(0, padded_width - vit_image_size + 1, scale_resolution)
        ]
        patches.append(row)
    return patches


def resize_by_patch_size_ur(img, min_size=1152, max_size=2240, patch_size=14):
    """Rescale ``img`` so its sides fall within ``[min_size, max_size]`` where possible.

    * Shorter side below ``min_size``: upscale so the shorter side reaches
      ``min_size``; if the longer side then exceeds ``max_size``, scale back
      down so it equals ``max_size`` (the shorter side may end up below
      ``min_size`` again).
    * Otherwise: downscale only if the longer side exceeds ``max_size``.

    Args:
        img: Image to resize.
        min_size: Target minimum for the shorter side.
        max_size: Hard maximum for the longer side.
        patch_size: Unused; sizes are not snapped to a patch multiple.

    Returns:
        The result of ``img.resize((new_width, new_height), Image.BICUBIC)``.
    """
    if isinstance(img, torch.Tensor):
        # NOTE(review): the size is read from the first two dims, but the final
        # PIL-style ``resize`` call below is presumably not valid for tensors --
        # confirm whether tensor input is still meant to be supported.
        height, width = img.shape[:2]
    else:
        width, height = img.size

    if min(height, width) < min_size:
        # Bring the shorter side up to min_size.
        scale_factor = min_size / min(height, width)
        new_height = max(min_size, round(height * scale_factor))
        new_width = max(min_size, round(width * scale_factor))

        # The upscale may have pushed the longer side past max_size.
        if max(new_height, new_width) > max_size:
            scale_factor = max_size / max(new_height, new_width)
            new_height = min(max_size, round(new_height * scale_factor))
            new_width = min(max_size, round(new_width * scale_factor))
    else:
        # Never upscale here: the factor is capped at 1.
        scale_factor = min(max_size / max(height, width), 1)
        new_height = round(height * scale_factor)
        new_width = round(width * scale_factor)

    img = img.resize((new_width, new_height), Image.BICUBIC)
    return img