# univa/dataset/data_collator.py
from typing import List, Dict
from transformers import PreTrainedTokenizer
import torch
import torch.nn.functional as F


def pad_list_of_tensors(tensor_list, padding_value=0):
    # tensor_list: list of tensors, each of shape (b, c, h, w).
    # If every entry is an empty list, all data in this batch are t2i.
    if all(not isinstance(tensor, torch.Tensor) for tensor in tensor_list):
        return []

    # Find the first real tensor to use as a shape template.
    for tmp_tensor in tensor_list:
        if isinstance(tmp_tensor, torch.Tensor):
            break
    # Substitute a zero tensor for each t2i entry when the batch mixes t2i and
    # other tasks: t2i can be treated as unconditional (no-reference-image) editing.
    tensor_list = [
        torch.zeros_like(tmp_tensor) if isinstance(tensor, list) else tensor
        for tensor in tensor_list
    ]
    assert all(tensor.shape[1] == tensor_list[0].shape[1] for tensor in tensor_list)
    # Find the maximum b, h, and w across the batch.
    max_b = max(tensor.shape[0] for tensor in tensor_list)
    max_c = tensor_list[0].shape[1]  # assume c is the same for all tensors
    max_h = max(tensor.shape[2] for tensor in tensor_list)
    max_w = max(tensor.shape[3] for tensor in tensor_list)

    padded_tensors = []
    for tensor in tensor_list:
        b, c, h, w = tensor.shape
        pad_b = max_b - b
        pad_h = max_h - h
        pad_w = max_w - w
        # Pad h and w first (the last two dims).
        tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=padding_value)
        # Then pad the b dim (the front), expanding to (max_b, c, max_h, max_w).
        if pad_b > 0:
            padding_shape = (pad_b, c, max_h, max_w)
            pad_tensor = torch.full(
                padding_shape, fill_value=padding_value,
                dtype=tensor.dtype, device=tensor.device,
            )
            tensor = torch.cat([tensor, pad_tensor], dim=0)
        padded_tensors.append(tensor)
    # Finally, stack into (B, max_b, c, max_h, max_w).
    return torch.stack(padded_tensors)
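
# A minimal shape sketch (illustrative values, not from the real data pipeline):
# mixing a 2-reference sample, a t2i sample (empty list), and a 1-reference
# sample pads everything to a common (max_b, c, max_h, max_w) grid:
#
#   refs = [torch.randn(2, 3, 32, 48), [], torch.randn(1, 3, 64, 32)]
#   out = pad_list_of_tensors(refs, padding_value=0)
#   assert out.shape == (3, 2, 3, 64, 48)  # (B, max_b, c, max_h, max_w)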


def resize_list_of_tensors(weights):
    # weights: list of [1, H, W] tensors (per-sample weight masks).
    # 1) Find the max height and width across the list.
    heights = [w.shape[-2] for w in weights]
    widths = [w.shape[-1] for w in weights]
    max_h, max_w = max(heights), max(widths)
    # 2) Interpolate each mask to (max_h, max_w).
    resized = []
    for w in weights:
        # F.interpolate expects a 4D tensor (N, C, H, W); since w is already
        # [1, H, W], a single unsqueeze gives [1, 1, H, W].
        w_4d = w.unsqueeze(0)
        w_resized = F.interpolate(
            w_4d, size=(max_h, max_w), mode='nearest'
        )
        # Back to [1, max_h, max_w].
        w_resized = w_resized.squeeze(0)
        resized.append(w_resized)
    # 3) Stack into a single tensor [N, 1, max_h, max_w].
    return torch.stack(resized)
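
# A minimal shape sketch (illustrative values):
#
#   masks = [torch.ones(1, 16, 16), torch.ones(1, 32, 8)]
#   out = resize_list_of_tensors(masks)
#   assert out.shape == (2, 1, 32, 16)  # (N, 1, max_h, max_w)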


class DataCollator:
    def __init__(self, tokenizer: PreTrainedTokenizer, padding_side='right'):
        self.tokenizer = tokenizer
        self.padding_side = padding_side

    def __call__(self, instances: List[Dict]) -> Dict:
        input_ids = [instance["input_ids"][0] for instance in instances]
        labels = [instance["labels"][0] for instance in instances]
        image_position = [instance["image_position"] for instance in instances]

        # Vision-encoder inputs; None when no instance in the batch has images.
        pixel_values = [
            instance["pixel_values"] for instance in instances if len(instance["pixel_values"]) > 0
        ]
        pixel_values = torch.cat(pixel_values) if len(pixel_values) > 0 else None
        image_grid_thw = [
            instance["image_grid_thw"] for instance in instances if len(instance["image_grid_thw"]) > 0
        ]
        image_grid_thw = torch.cat(image_grid_thw) if len(image_grid_thw) > 0 else None

        pil_pixel_values = [instance["pil_pixel_values"] for instance in instances]
        prompts = [instance["prompt"] for instance in instances]

        # Reference images can differ in count and resolution per sample,
        # so pad them to a common (max_b, c, max_h, max_w).
        ref_pixel_values = [instance["ref_pixel_values"] for instance in instances]
        ref_pixel_values = pad_list_of_tensors(ref_pixel_values, padding_value=0)

        siglip_pixel_values = [
            instance["siglip_pixel_values"] for instance in instances if len(instance["siglip_pixel_values"]) > 0
        ]
        siglip_pixel_values = torch.cat(siglip_pixel_values, dim=0) if len(siglip_pixel_values) > 0 else []

        # Pad token sequences; the `padding_side` kwarg of pad_sequence
        # requires a recent PyTorch (added in 2.5).
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id,
            padding_side=self.padding_side,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100,
            padding_side=self.padding_side,
        )
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        # Loss-weight masks: stack when shapes match, otherwise keep a list
        # of [1, ...] tensors for the caller to handle.
        weights = [
            instance["weights"] for instance in instances if len(instance["weights"]) > 0
        ]
        if len(weights) > 0:
            if all(i.shape == weights[0].shape for i in weights):
                weights = torch.stack(weights)
            else:
                weights = [i.unsqueeze(0) for i in weights]
        else:
            weights = None

        # Target (generated) images follow the same stack-or-list rule.
        generated_image = [
            instance["generated_image"] for instance in instances if len(instance["generated_image"]) > 0
        ]
        if len(generated_image) > 0:
            if all(i.shape == generated_image[0].shape for i in generated_image):
                generated_image = torch.stack(generated_image)
            else:
                generated_image = [i.unsqueeze(0) for i in generated_image]
        else:
            generated_image = []

        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_position": image_position,
            "image_grid_thw": image_grid_thw,
            "prompts": prompts,
            "ref_pixel_values": ref_pixel_values,
            "pil_pixel_values": pil_pixel_values,
            "siglip_pixel_values": siglip_pixel_values,
            "weights": weights,
            "generated_image": generated_image,
        }
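

if __name__ == "__main__":
    # A minimal collation sketch with made-up instances, not real pipeline
    # data. DataCollator only reads `pad_token_id` from the tokenizer, so a
    # tiny stand-in suffices here; in training you would pass a real
    # tokenizer (e.g. transformers.AutoTokenizer.from_pretrained(...)).
    class _StubTokenizer:
        pad_token_id = 0

    collator = DataCollator(tokenizer=_StubTokenizer())  # type: ignore[arg-type]
    instances = [
        {
            "input_ids": [torch.tensor([1, 2, 3])],
            "labels": [torch.tensor([-100, 2, 3])],
            "image_position": [],
            "pixel_values": [],        # empty -> batched as None
            "image_grid_thw": [],      # empty -> batched as None
            "pil_pixel_values": [],
            "prompt": "a red apple",
            "ref_pixel_values": [],    # all-t2i batch -> pad_list_of_tensors returns []
            "siglip_pixel_values": [],
            "weights": [],             # empty -> batched as None
            "generated_image": [],
        },
    ]
    batch = collator(instances)
    print(batch["input_ids"].shape, batch["attention_mask"].shape)  # (1, 3) (1, 3)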