from typing import Any, Callable, List, Optional
import base64
import json
import os
import random
from fractions import Fraction
from io import BytesIO

import numpy as np
import requests
import torch
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional
from tqdm import tqdm
from transformers import PreTrainedTokenizer

# from qwen_vl_utils.vision_process import fetch_image, fetch_video
from qwen_vl_utils.vision_process import to_rgb, smart_resize, fetch_video
from univa.utils.prompter import Prompter
from univa.utils.constant import SPACIAL_TOKEN, GENERATE_TOKEN
from univa.utils.get_mask import get_weight_mask
from univa.utils.get_ocr import get_ocr_result
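
# Dataset utilities for training Qwen2-VL / Qwen2.5-VL style models in the `univa`
# pipeline: each sample mixes image-understanding turns with an optional
# image-generation turn, so an item carries both vision-encoder inputs
# (pixel_values / image_grid_thw) and, for generation tasks, a diffusion target
# (generated_image).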


def get_aspect_ratio(img):
    width, height = img.size
    return Fraction(width, height).limit_denominator()


def has_same_aspect_ratio(img1, img2):
    if not isinstance(img1, Image.Image):
        img1 = Image.open(img1).convert('RGB')
    if not isinstance(img2, Image.Image):
        img2 = Image.open(img2).convert('RGB')
    ratio1 = get_aspect_ratio(img1)
    ratio2 = get_aspect_ratio(img2)
    return ratio1 == ratio2


def has_same_resolution(img1, img2):
    if not isinstance(img1, Image.Image):
        img1 = Image.open(img1).convert('RGB')
    if not isinstance(img2, Image.Image):
        img2 = Image.open(img2).convert('RGB')
    return img1.size == img2.size
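
# For example, a 1920x1080 image reduces to Fraction(16, 9), so it shares an
# aspect ratio with a 1280x720 image but not the same resolution.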


class Qwen2VLDataset(Dataset):
    def __init__(
        self,
        dataset_type: str,
        data_txt: str,
        transform: Callable,
        tokenizer: PreTrainedTokenizer,
        prompter: Prompter,
        image_processor: Callable,
        processor: Callable = None,
        min_pixels: int = 384 * 384,
        max_pixels: int = 384 * 384,
        image_token_length: int = 729,
        only_generated_task: bool = False,
        drop_prompt_rate: float = 0.0,
        joint_ref_feature: bool = False,
        anyres: bool = False,
        mask_weight_type: str = 'log',
        siglip_processor: Callable = None,
        ocr_enhancer: bool = False,
        random_data: bool = False,
        maxnum_per_data: int = -1,
        notry: bool = False,
    ):
        assert dataset_type in ('qwen2vl', 'qwen2p5vl'), (
            "dataset_type must be 'qwen2vl' or 'qwen2p5vl'"
        )
        with open(data_txt, "r") as f:
            self.datasets = [line.strip() for line in f.readlines()]
        self.data = []
        self._load_data(maxnum_per_data)

        self.transform = transform
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.prompter = prompter
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.image_token = SPACIAL_TOKEN[dataset_type]['image_token']
        self.image_begin_token = SPACIAL_TOKEN[dataset_type]['image_begin_token']
        self.image_end_token = SPACIAL_TOKEN[dataset_type]['image_end_token']
        self.generated_image_token = GENERATE_TOKEN
        self.image_processor = processor.image_processor
        # self.factor = 4 if joint_ref_feature else 1
        self.factor = 2

        self.only_generated_task = only_generated_task  # For denoiser training
        self.drop_prompt_rate = drop_prompt_rate
        if self.drop_prompt_rate > 0:
            assert self.only_generated_task, (
                "Only generated task is supported when drop_prompt_rate > 0"
            )
        self.mask_weight_type = mask_weight_type
        self.siglip_processor = siglip_processor
        self.ocr_enhancer = ocr_enhancer
        self.random_data = random_data
        self.notry = notry

        # The image token must already exist in the tokenizer vocabulary.
        assert self.image_token in self.tokenizer.get_vocab()
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        self.image_begin_token_id = self.tokenizer.convert_tokens_to_ids(
            self.image_begin_token
        )
        assert isinstance(self.image_begin_token_id, int), (
            f"tokenizer missing image begin token `{self.image_begin_token}`"
        )
        self.image_end_token_id = self.tokenizer.convert_tokens_to_ids(
            self.image_end_token
        )
        assert isinstance(self.image_end_token_id, int), (
            f"tokenizer missing image end token `{self.image_end_token}`"
        )
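
    # Each line of `data_txt` is expected to be "image_root,json_file,need_weight",
    # where `json_file` holds a list of samples with a `conversations` list and an
    # optional `image` entry whose paths are relative to `image_root`.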
    def _load_data(self, maxnum_per_data=-1):
        for dataset in self.datasets:
            image_root, json_file, need_weight = dataset.split(",")

            # Load json file
            with open(json_file, "r") as f:
                data = json.load(f)
            if maxnum_per_data > 0 and maxnum_per_data < len(data):
                print(f'original data: {len(data)}, sample: {maxnum_per_data}')
                data = random.sample(data, maxnum_per_data)

            dataset_data = []
            for line in tqdm(data):
                if "image" not in line:
                    line["image"] = []
                # Ensure `image` is a list
                if isinstance(line["image"], str):
                    line["image"] = [line["image"]]
                assert isinstance(line["image"], list), (
                    "`image` must be a str or a list."
                )
                # Convert image paths to absolute paths
                line["need_weight"] = need_weight
                line["image"] = [
                    os.path.join(image_root, image_path) for image_path in line["image"]
                ]
                dataset_data.append(line)

            print(f"Loaded {len(dataset_data)} samples from {json_file}.")
            self.data.extend(dataset_data)
    def __len__(self):
        return len(self.data)
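
    # Builds a synthetic sample: a random 448x448 noise image plus a fixed prompt.
    # Used when `random_data=True`, e.g. to smoke-test the training pipeline
    # without touching real data.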
    def _get_random_data(self):
        prompt = self.prompter(
            [
                {"from": "system", "value": "You are a helpful assistant."},
                {
                    "from": "user",
                    "value": f"test an image {self.image_token}",
                },
            ]
        )
        input_ids = self.tokenizer.batch_encode_plus(
            [prompt], return_tensors="pt", truncation=False,
        ).input_ids
        labels = input_ids

        width, height = 448, 448
        random_data = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
        image = Image.fromarray(random_data, 'RGB')

        image_slice = [image]
        image_dict = self._load_image(
            image_slice, self.max_pixels, self.min_pixels,
            processor=self.processor, image_token=self.image_token,
            factor=self.factor,
            last_image=image,
            vae_image_transform=self.transform,
            drop_prompt=False,
            prompt=prompt,
            mask_weight_type=self.mask_weight_type,
            siglip_processor=self.siglip_processor,
        )
        image_token_lengths = image_dict['image_token_lengths']
        pixel_values = image_dict['pixel_values']
        image_grid_thw = image_dict['image_grid_thw']
        ref_pixel_values = image_dict['ref_pixel_values']
        pil_pixel_values = image_dict['pil_pixel_values']
        siglip_pixel_values = image_dict['siglip_pixel_values']
        weights = image_dict['weights']

        input_ids, labels, image_position = self._process_image_token(
            input_ids,
            labels=labels,
            image_token_id=self.image_token_id,
            image_begin_token_id=self.image_begin_token_id,
            image_end_token_id=self.image_end_token_id,
            image_token_lengths=image_token_lengths,
        )

        generated_image = torch.randn(3, 512, 512)
        return_data = {
            "input_ids": input_ids,
            "labels": labels,
            "pixel_values": pixel_values,
            "image_position": image_position,
            "image_grid_thw": image_grid_thw,
            "prompt": prompt,
            "ref_pixel_values": ref_pixel_values,
            "pil_pixel_values": pil_pixel_values,
            "siglip_pixel_values": siglip_pixel_values,
            "weights": weights,
            "generated_image": generated_image,
        }
        return return_data
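
    # Converts one raw sample (a `conversations` list of "human"/"gpt" turns plus an
    # `image` path list) into model-ready tensors. If an assistant turn ends with the
    # generated-image token, the sample is a generation task and the last image becomes
    # the diffusion target rather than a vision-encoder input.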
    def getitem(self, data):
        # Reformat the conversation to the prompter's format
        conversations = []
        prompt = ""
        for item in data["conversations"]:
            if item["from"] == "human":
                role = self.prompter.user_role
                prompt = item["value"]
            elif item["from"] == "gpt":
                role = self.prompter.assistant_role
            else:
                raise ValueError(f"Unknown role: {item['from']}")
            conversations.append({"from": role, "value": item["value"]})
        assert prompt != "", "prompt must not be empty"
        # The last user instruction will be used for the t5_embed
        prompt = prompt.replace('<image>', '').replace('\n', '')

        # Make prompt
        drop_prompt = False
        if self.only_generated_task:
            if self.drop_prompt_rate < random.random():
                # Keep the original conversation (probability 1 - drop_prompt_rate)
                prompt_list = self.prompter.get_train_prompt(conversations)
            else:
                # Drop the prompt and fall back to an unconditional generation instruction
                drop_prompt = True
                num_images = (''.join([i['value'] for i in conversations])).count('<image>')
                prompt_list = [
                    {
                        "from": self.prompter.system_role,
                        "value": "You are a helpful assistant.",
                    },
                    {
                        "from": self.prompter.user_role,
                        # "value": f"{num_images * '<image>'} Generate an image.",
                        "value": "Generate an image.",
                    },
                    {
                        "from": self.prompter.assistant_role,
                        "value": self.generated_image_token,
                    },
                ]
                prompt_list = self.prompter.get_train_prompt(prompt_list)
        else:
            prompt_list = self.prompter.get_train_prompt(conversations)

        input_ids = []
        labels = []
        has_generated_image = False
        cur_i = 0
        for item in prompt_list:
            item["prompt"] = item["prompt"].replace('<image>', self.image_token)
            if self.generated_image_token in item["prompt"]:
                # This turn asks the model to generate an image
                assert item["from"] == self.prompter.assistant_role, (
                    "Generated image token must be in an assistant turn"
                )
                assert (
                    f"{self.generated_image_token}{self.prompter.eos_token}"
                    in item["prompt"]
                ), "Generated image token must be at the end of the prompt"
                # Replace the generated image token with the image begin token and drop the eos token
                item["prompt"] = item["prompt"].replace(
                    f"{self.generated_image_token}{self.prompter.eos_token}",
                    self.image_begin_token,
                )
                has_generated_image = True

            if self.ocr_enhancer and (self.image_token in item["prompt"]):
                # print('item["prompt"]', item["prompt"])
                if not has_generated_image:
                    num_img = item["prompt"].count(self.image_token)
                    ocr_sentences = []
                    for i in range(num_img):
                        ocr_sentences.append(get_ocr_result(data["image"][cur_i], cur_i))
                        cur_i += 1
                    ocr_sentences = '\n'.join(ocr_sentences)
                    if len(ocr_sentences.split()) > 256:
                        print(f'ocr_sentences too long, total len {len(ocr_sentences.split())}, truncating to first 256 words')
                        ocr_sentences = ' '.join(ocr_sentences.split()[:256])
                    # ocr_sentences = ''
                    assert item['prompt'][-len(self.prompter.eos_token):] == self.prompter.eos_token, \
                        "prompt must end with the eos token"
                    assert item['prompt'].count(self.prompter.eos_token) == 1, \
                        "prompt must contain exactly one eos token"
                    item["prompt"] = item["prompt"].replace(self.prompter.eos_token, f'{ocr_sentences} {self.prompter.eos_token}')

            tokenized_item = self.tokenizer(
                item["prompt"],
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            )
            if item["is_labels"]:  # If this turn provides the labels
                labels.append(tokenized_item.input_ids)
            else:
                labels.append(torch.full_like(tokenized_item.input_ids, -100))
            input_ids.append(tokenized_item.input_ids)

        if (
            self.only_generated_task and not has_generated_image
        ):  # For denoiser training
            raise ValueError(
                "Only-generated-task mode is enabled, but this prompt does not contain "
                f"the generated image token: {prompt_list[0]['prompt']}"
            )
        input_ids = torch.cat(input_ids, dim=1)
        labels = torch.cat(labels, dim=1)
        # Load images
        if has_generated_image:
            # Generation task: process all images except the last one, which is the target to generate
            image_slice = data["image"][:-1]
        else:
            # Understanding task: process all images
            image_slice = data["image"]
        image_dict = self._load_image(
            image_slice, self.max_pixels, self.min_pixels,
            processor=self.processor, image_token=self.image_token,
            factor=self.factor,
            last_image=data["image"][-1] if has_generated_image else None,
            vae_image_transform=self.transform,
            drop_prompt=drop_prompt,
            prompt=prompt,
            mask_weight_type=self.mask_weight_type,
            siglip_processor=self.siglip_processor,
            need_weight=data['need_weight'],
        )
        image_token_lengths = image_dict['image_token_lengths']
        pixel_values = image_dict['pixel_values']
        image_grid_thw = image_dict['image_grid_thw']
        ref_pixel_values = image_dict['ref_pixel_values']
        pil_pixel_values = image_dict['pil_pixel_values']
        siglip_pixel_values = image_dict['siglip_pixel_values']
        weights = image_dict['weights']

        input_ids, labels, image_position = self._process_image_token(
            input_ids,
            labels=labels,
            image_token_id=self.image_token_id,
            image_begin_token_id=self.image_begin_token_id,
            image_end_token_id=self.image_end_token_id,
            image_token_lengths=image_token_lengths,
        )

        return_data = {
            "input_ids": input_ids,
            "labels": labels,
            "pixel_values": pixel_values,
            "image_position": image_position,
            "image_grid_thw": image_grid_thw,
            "prompt": prompt,
            "ref_pixel_values": ref_pixel_values,
            "pil_pixel_values": pil_pixel_values,
            "siglip_pixel_values": siglip_pixel_values,
            "weights": weights,
        }

        if has_generated_image:  # If this item is a generation task
            image = Image.open(data["image"][-1]).convert("RGB")
            # if self.anyres:
            #     image = image.resize(pil_pixel_values[-1].size)
            image_tensor = torch.tensor(np.array(image)) / 255.0  # scale to [0, 1]
            image_tensor = rearrange(image_tensor, "h w c -> c h w")
            return_data["generated_image"] = self.transform(image_tensor)
        else:
            return_data["generated_image"] = []
        return return_data
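
    # Standard Dataset entry point. Unless `notry` is set, a failing sample is logged
    # and replaced by a randomly chosen one, so a single corrupt file does not kill a
    # long training run.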
    def __getitem__(self, idx):
        if self.random_data:
            return self._get_random_data()
        data: Any = self.data[idx]
        if self.notry:
            return self.getitem(data)
        try:
            return self.getitem(data)
        except Exception as e:
            print(f'Error with {e}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))
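
    # Prepares all vision inputs for one sample. Reference images go through the
    # Qwen-VL processor (variable resolution within min/max pixels); when their
    # resolution differs from the image to be generated, they are resized to match it.
    # Returns a dict with pixel_values, image_grid_thw, per-image token lengths,
    # reference/SigLIP tensors, PIL copies, and optional loss weights.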
    @staticmethod
    def _load_image(
        image_slice: List[str],
        max_pixels: int = 448 * 448,
        min_pixels: int = 448 * 448,
        processor: Callable = None,
        image_processor: Callable = None,
        image_token_lengths: int = 729,
        image_token: str = '<|image_pad|>',
        factor: int = 1,
        last_image: Optional[str] = None,
        vae_image_transform: Callable = None,
        drop_prompt: bool = False,
        prompt: str = '',
        mask_weight_type: str = None,
        siglip_processor: Callable = None,
        need_weight: str = 'true',
    ):
        resize_ref_image = False
        pil_pixel_values_last = []
        if last_image is not None:
            last_vision_infos = dict(
                image=last_image, min_pixels=min_pixels, max_pixels=max_pixels
            )
            # last_image will be resized by the Qwen-VL processor automatically,
            # so the generated image has a variable resolution
            last_image_inputs, last_video_inputs = process_vision_info([last_vision_infos], factor=factor)
            # Keep the PIL image so we know the size the processor will use
            pil_pixel_values_last.append(last_image_inputs[0])

            # Not all reference images share the same resolution. If any reference
            # differs from the generated image (last_image), resize it to match.
            if not all([has_same_resolution(image_path, last_image) for image_path in image_slice]):
                resize_ref_image = True
                resize_w, resize_h = last_image_inputs[0].size

        image_token_lengths = []
        pixel_values = []
        image_grid_thw = []
        ref_pixel_values = []
        pil_pixel_values = []
        siglip_pixel_values = []
        # The last image (the one to generate) is not part of image_slice
        for image_path in image_slice:
            vision_infos = dict(image=image_path, min_pixels=min_pixels, max_pixels=max_pixels)
            # Resize references that do not match the generated image (last_image)
            if resize_ref_image:
                vision_infos.update(
                    dict(resized_height=resize_h, resized_width=resize_w)
                )
            image_inputs, video_inputs = process_vision_info([vision_infos], factor=factor)
            inputs = processor(text=[f'dummy {image_token}'], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

            if not drop_prompt:
                pixel_values.append(inputs.pixel_values)  # inputs.pixel_values shape is (token, dim)
                image_grid_thw.append(inputs.image_grid_thw)  # image_grid_thw is List[int, int, int]
            image_token_length = (inputs.input_ids[0] == processor.tokenizer.convert_tokens_to_ids(image_token)).sum()
            image_token_lengths.append(image_token_length)

            image_tensor = torch.tensor(np.array(image_inputs[0])) / 255.0  # scale to [0, 1]
            image_tensor = rearrange(image_tensor, "h w c -> 1 c h w")
            if vae_image_transform is not None:
                # image_tensor has already been resized by the Qwen-VL processor
                image_tensor = (image_tensor - 0.5) / 0.5  # shift [0, 1] to [-1, 1]
            pil_pixel_values.append(image_inputs[0])

            if siglip_processor is not None:
                siglip_pixel_value = siglip_processor.preprocess(
                    images=Image.open(image_path).convert('RGB') if isinstance(image_path, str) else image_path,
                    do_resize=True, return_tensors="pt", do_convert_rgb=True
                ).pixel_values  # 1 c h w
                if drop_prompt:
                    siglip_pixel_values.append(torch.zeros_like(siglip_pixel_value))
                else:
                    siglip_pixel_values.append(siglip_pixel_value)

            # Use a zero image as the unconditional reference image
            if drop_prompt:
                ref_pixel_values.append(torch.zeros_like(image_tensor))
            else:
                ref_pixel_values.append(image_tensor)

        # If a sample has multiple images, concatenate them, e.g.
        # pixel_values[0] (n1, 1176) and pixel_values[1] (n2, 1176) become (n1+n2, 1176)
        if len(pixel_values) > 0:
            pixel_values = torch.concat(pixel_values)
            image_grid_thw = torch.concat(image_grid_thw)  # (b, 3), the grid of t, h, w
        # if len(ref_pixel_values) > 0:
        #     ref_pixel_values = torch.concat(ref_pixel_values)  # b c h w
        ref_pixel_values = []
        if len(siglip_pixel_values) > 0:
            siglip_pixel_values = torch.concat(siglip_pixel_values)  # b c h w

        pil_pixel_values = pil_pixel_values + pil_pixel_values_last

        if mask_weight_type is not None:
            _, weights = get_weight_mask(pil_pixel_values, prompt, mask_weight_type, need_weight)
            if need_weight.lower() == 'false':
                assert torch.all(weights == 1)
        else:
            weights = []

        return {
            'pixel_values': pixel_values,
            'image_grid_thw': image_grid_thw,
            'image_token_lengths': image_token_lengths,
            'ref_pixel_values': ref_pixel_values,
            'pil_pixel_values': pil_pixel_values,
            'siglip_pixel_values': siglip_pixel_values,
            'weights': weights,
        }
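
    # Expands every single image placeholder token in `input_ids` into
    # `<begin> <image> * image_token_length <end>`, records the position of the first
    # expanded image token, and masks the expanded span in `labels` (only the begin
    # token stays supervised). A minimal sketch, with placeholder ids:
    #   input_ids = [[11, 22, IMG, 33]], image_token_lengths = [3]
    #   -> input_ids = [[11, 22, BEGIN, IMG, IMG, IMG, END, 33]]
    #   -> labels keep BEGIN and set the three IMG tokens and END to -100
    #   -> image_position = [3] (index of the first expanded IMG token)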
    @staticmethod
    def _process_image_token(
        input_ids: torch.Tensor,
        image_token_id: int,
        image_begin_token_id: int,
        image_end_token_id: int,
        image_token_lengths: List[int],
        labels: Optional[torch.Tensor] = None,
    ):
        # Find the indices of the image tokens
        image_token_indices = (input_ids == image_token_id).nonzero(as_tuple=True)
        # assert len(image_token_lengths) == image_token_indices[1].numel()
        image_position = []
        offset = 0
        cur_i = 0
        if isinstance(image_token_lengths, int):
            image_token_lengths = [image_token_lengths] * len(image_token_indices[1])

        for idx in image_token_indices[1]:
            image_token_length = image_token_lengths[cur_i]
            adjusted_idx = idx + offset
            assert input_ids[0, adjusted_idx] == image_token_id, "expected an image token at the adjusted index"

            # Add image begin and end tokens
            input_ids = torch.cat(
                [
                    input_ids[:, :adjusted_idx],
                    input_ids.new_full((1, 1), image_begin_token_id),  # image begin token
                    input_ids.new_full(
                        (1, image_token_length), image_token_id
                    ),  # repeat the image token image_token_length times
                    input_ids.new_full((1, 1), image_end_token_id),  # image end token
                    input_ids[:, adjusted_idx + 1:],
                ],
                dim=1,
            )
            if labels is not None:
                labels = torch.cat(
                    [
                        labels[:, :adjusted_idx],
                        labels.new_full((1, 1), image_begin_token_id),  # keep the begin token as a label
                        labels.new_full((1, image_token_length), -100),
                        labels.new_full((1, 1), -100),
                        labels[:, adjusted_idx + 1:],
                    ],
                    dim=1,
                )

            adjusted_idx += 1  # skip the image begin token
            image_position.append(adjusted_idx.item())
            offset += image_token_length - 1
            offset += 2  # begin and end tokens
            cur_i += 1
        return input_ids, labels, image_position
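

# Local replacement for qwen_vl_utils.vision_process.fetch_image (see the commented-out
# import above): resolves PIL images, http(s) URLs, file:// paths, base64 data URIs,
# and plain local paths, then resizes with smart_resize so both sides are multiples
# of `size_factor`.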
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = 28) -> Image.Image:
    if "image" in ele:
        image = ele["image"]
    else:
        image = ele["image_url"]
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        response = requests.get(image, stream=True)
        image_obj = Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = to_rgb(image_obj)
    ## resize
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels")
        max_pixels = ele.get("max_pixels")
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    image = image.resize((resized_width, resized_height), resample=Image.Resampling.BICUBIC)
    return image
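

# Thin wrapper around fetch_image / fetch_video. `factor` scales the patch-size
# multiple used for resizing (28 * factor), so with the dataset's self.factor = 2
# reference images are snapped to multiples of 56 pixels per side.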
def process_vision_info(
    vision_infos: list,
    return_video_kwargs: bool = False,
    factor: int = 1,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
    ## Read images or videos
    image_inputs = []
    video_inputs = []
    video_sample_fps_list = []
    for vision_info in vision_infos:
        if "image" in vision_info or "image_url" in vision_info:
            image_inputs.append(fetch_image(vision_info, size_factor=28 * factor))
        elif "video" in vision_info:
            video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
            video_sample_fps_list.append(video_sample_fps)
            video_inputs.append(video_input)
        else:
            raise ValueError("image, image_url or video should be in content.")
    if len(image_inputs) == 0:
        image_inputs = None
    if len(video_inputs) == 0:
        video_inputs = None
    if return_video_kwargs:
        return image_inputs, video_inputs, {'fps': video_sample_fps_list}
    return image_inputs, video_inputs
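

# A minimal usage sketch (the paths, prompter, and transform below are illustrative
# assumptions, not part of this module):
#
#   from torchvision import transforms
#   from transformers import AutoProcessor
#
#   processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
#   dataset = Qwen2VLDataset(
#       dataset_type='qwen2p5vl',
#       data_txt='data/train_data.txt',           # lines of "image_root,json_file,need_weight"
#       transform=transforms.Resize((512, 512)),  # applied to the generation target
#       tokenizer=processor.tokenizer,
#       prompter=prompter,                        # a univa.utils.prompter.Prompter instance
#       image_processor=processor.image_processor,
#       processor=processor,
#       min_pixels=384 * 384,
#       max_pixels=384 * 384,
#   )
#   sample = dataset[0]  # dict with input_ids, labels, pixel_values, image_grid_thw, ...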