flashsloth / mm_utils.py
Tongbo's picture
Upload folder using huggingface_hub
04f8e39 verified
from PIL import Image
from io import BytesIO
import base64
import torch
from transformers import StoppingCriteria
from flashsloth.constants import IMAGE_TOKEN_INDEX
def load_image_from_base64(image):
return Image.open(BytesIO(base64.b64decode(image)))
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def extract_patches(image, patch_size, overlap_ratio):
assert isinstance(image, Image.Image), "Input should be a Pillow Image"
assert patch_size > 0, "Patch size should be greater than 0"
assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
W, H = image.size
patches = []
stride = int(patch_size * (1 - overlap_ratio))
num_patches_y = (H - patch_size) // stride + 1
num_patches_x = (W - patch_size) // stride + 1
y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
for y in range(y_start, y_start + num_patches_y * stride, stride):
for x in range(x_start, x_start + num_patches_x * stride, stride):
patch = image.crop((x, y, x + patch_size, y + patch_size))
patches.append(patch)
return patches
def process_images_hd(image, processor, model_cfg):
select_size = 768
image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
image_padded = image_padded.resize((select_size, select_size))
image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
image_patches = [image_original_resize] + image_patches
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def process_images_hd_inference(image_list, processor, model_cfg):
select_size = 768
processed_images = []
for image in image_list:
image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
image_padded = image_padded.resize((select_size, select_size))
image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
image_patches = [image_original_resize] + image_patches
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
processed_images.append(torch.stack(image_patches, dim=0))
return torch.stack(processed_images, dim=0)
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == 'pad':
for image in images:
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
else:
return image_processor(images, return_tensors='pt')['pixel_values']
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == 'pt':
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f'Unsupported tensor type: {return_tensors}')
return input_ids
def get_model_name_from_path(model_path):
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith('checkpoint-'):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
self.max_keyword_len = 0
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
cur_keyword_ids = cur_keyword_ids[1:]
if len(cur_keyword_ids) > self.max_keyword_len:
self.max_keyword_len = len(cur_keyword_ids)
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
for keyword_id in self.keyword_ids:
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
return True
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
outputs = []
for i in range(output_ids.shape[0]):
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
return all(outputs)