from typing import List, Dict

from transformers import PreTrainedTokenizer
import torch
import torch.nn.functional as F

def pad_list_of_tensors(tensor_list, padding_value=0):
    # tensor_list: list of tensors, each of shape (b, c, h, w).
    # If every entry is an empty list, all samples in this batch are t2i
    # (text-to-image) and carry no reference images.
    if all(not isinstance(tensor, torch.Tensor) for tensor in tensor_list):
        return []
    else:
        for tmp_tensor in tensor_list:
            if isinstance(tmp_tensor, torch.Tensor):
                # found a tensor to copy shape/dtype/device from
                break
        # When the batch mixes t2i and other samples, replace each empty
        # entry with an all-zero tensor: t2i can be treated as unconditional
        # (no reference image) editing.
        tensor_list = [
            torch.zeros_like(tmp_tensor) if isinstance(tensor, list) else tensor for tensor in tensor_list
        ]
    assert all(tensor.shape[1] == tensor_list[0].shape[1] for tensor in tensor_list)
    # Find the maximum b, h, w
    max_b = max(tensor.shape[0] for tensor in tensor_list)
    max_c = tensor_list[0].shape[1]  # assume c is the same for every tensor
    max_h = max(tensor.shape[2] for tensor in tensor_list)
    max_w = max(tensor.shape[3] for tensor in tensor_list)
    padded_tensors = []
    for tensor in tensor_list:
        b, c, h, w = tensor.shape
        pad_b = max_b - b
        pad_h = max_h - h
        pad_w = max_w - w
        # First pad h and w (the last two dims)
        tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=padding_value)
        # Then pad the b dim (the leading dim) up to (max_b, c, max_h, max_w)
        if pad_b > 0:
            padding_shape = (pad_b, c, max_h, max_w)
            pad_tensor = torch.full(padding_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device)
            tensor = torch.cat([tensor, pad_tensor], dim=0)
        padded_tensors.append(tensor)
    # Finally, stack into (B, max_b, c, max_h, max_w)
    return torch.stack(padded_tensors)
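
# Example (illustrative sketch, not part of the original module): a batch
# mixing one editing sample with two reference images, one t2i sample with
# no reference image (an empty list), and one sample with a single larger
# reference image. The shapes below are made up for the demo.
def _demo_pad_list_of_tensors():
    refs = [
        torch.ones(2, 3, 32, 48),  # two reference images, 32x48
        [],                        # t2i sample: no reference image
        torch.ones(1, 3, 64, 40),  # one reference image, 64x40
    ]
    padded = pad_list_of_tensors(refs, padding_value=0)
    # Every entry is padded to (max_b, c, max_h, max_w) = (2, 3, 64, 48),
    # then the entries are stacked along a new leading batch dim.
    print(padded.shape)  # torch.Size([3, 2, 3, 64, 48])
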
def resize_list_of_tensors(weights):
    # weights: list of [1, H, W] tensors
    # 1) find the max height and width
    heights = [w.shape[-2] for w in weights]
    widths = [w.shape[-1] for w in weights]
    max_h, max_w = max(heights), max(widths)
    # 2) interpolate each mask to (max_h, max_w)
    resized = []
    for w in weights:
        # F.interpolate expects a 4D tensor (N, C, H, W); since w is
        # already [1, H, W], a single unsqueeze is enough.
        w_4d = w.unsqueeze(0)  # -> [1, 1, H, W]
        w_resized = F.interpolate(
            w_4d, size=(max_h, max_w), mode='nearest'
        )
        # back to [1, H', W']
        w_resized = w_resized.squeeze(0)
        resized.append(w_resized)
    # 3) stack into a single tensor [N, 1, max_h, max_w]
    weights = torch.stack(resized)  # -> [N, 1, max_h, max_w]
    return weights
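
# Example (illustrative sketch, not part of the original module): three
# [1, H, W] masks of different sizes are upsampled with nearest-neighbor
# interpolation to the largest H and W in the list, then stacked.
def _demo_resize_list_of_tensors():
    masks = [
        torch.rand(1, 16, 16),
        torch.rand(1, 32, 24),
        torch.rand(1, 8, 40),
    ]
    stacked = resize_list_of_tensors(masks)
    print(stacked.shape)  # torch.Size([3, 1, 32, 40])
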
class DataCollator:
    def __init__(self, tokenizer: PreTrainedTokenizer, padding_side='right'):
        self.tokenizer = tokenizer
        self.padding_side = padding_side

    def __call__(self, instances: List[Dict]) -> Dict:
        input_ids = [instance["input_ids"][0] for instance in instances]
        labels = [instance["labels"][0] for instance in instances]
        image_position = [instance["image_position"] for instance in instances]
        pixel_values = [
            instance["pixel_values"] for instance in instances if len(instance["pixel_values"]) > 0
        ]
        pixel_values = torch.cat(pixel_values) if len(pixel_values) > 0 else None
        image_grid_thw = [
            instance["image_grid_thw"] for instance in instances if len(instance["image_grid_thw"]) > 0
        ]
        image_grid_thw = torch.cat(image_grid_thw) if len(image_grid_thw) > 0 else None
        pil_pixel_values = [
            instance["pil_pixel_values"] for instance in instances
        ]
        prompts = [instance["prompt"] for instance in instances]
        ref_pixel_values = [
            instance["ref_pixel_values"] for instance in instances
        ]
        ref_pixel_values = pad_list_of_tensors(ref_pixel_values, padding_value=0)
        siglip_pixel_values = [
            instance["siglip_pixel_values"] for instance in instances if len(instance["siglip_pixel_values"]) > 0
        ]
        siglip_pixel_values = torch.cat(siglip_pixel_values, dim=0) if len(siglip_pixel_values) > 0 else []
        # Pad token ids with pad_token_id and labels with -100, which the
        # cross-entropy loss ignores. (pad_sequence's padding_side kwarg
        # requires a recent PyTorch release.)
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id,
            padding_side=self.padding_side,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100,
            padding_side=self.padding_side,
        )
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        weights = [
            instance["weights"] for instance in instances if len(instance["weights"]) > 0
        ]
        if len(weights) > 0:
            if all(i.shape == weights[0].shape for i in weights):
                # uniform spatial size: batch into a single tensor
                weights = torch.stack(weights)
            else:
                # mixed sizes: keep a list of per-sample tensors
                weights = [i.unsqueeze(0) for i in weights]
        else:
            weights = None
        generated_image = [
            instance["generated_image"] for instance in instances if len(instance["generated_image"]) > 0
        ]
        if len(generated_image) > 0:
            if all(i.shape == generated_image[0].shape for i in generated_image):
                generated_image = torch.stack(generated_image)
            else:
                generated_image = [i.unsqueeze(0) for i in generated_image]
        else:
            generated_image = []
        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_position": image_position,
            "image_grid_thw": image_grid_thw,
            "prompts": prompts,
            "ref_pixel_values": ref_pixel_values,
            "pil_pixel_values": pil_pixel_values,
            "siglip_pixel_values": siglip_pixel_values,
            "weights": weights,
            "generated_image": generated_image,
        }
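
# Example (illustrative sketch; real instances come from the dataset and the
# tokenizer from transformers): collate two text-only (t2i) instances using a
# dummy tokenizer stand-in. Only pad_token_id is used by the collator, so the
# stand-in ignores the PreTrainedTokenizer annotation; the instance keys and
# dummy values below mirror exactly what __call__ reads.
def _demo_data_collator():
    class _DummyTokenizer:
        pad_token_id = 0

    def _make_instance(ids, prompt):
        return {
            "input_ids": [torch.tensor(ids)],
            "labels": [torch.tensor(ids)],
            "image_position": [],
            "pixel_values": [],
            "image_grid_thw": [],
            "pil_pixel_values": [],
            "prompt": prompt,
            "ref_pixel_values": [],
            "siglip_pixel_values": [],
            "weights": [],
            "generated_image": [],
        }

    collator = DataCollator(_DummyTokenizer(), padding_side='right')
    batch = collator([_make_instance([5, 6, 7], "a cat"), _make_instance([5, 6], "a dog")])
    print(batch["input_ids"])       # tensor([[5, 6, 7], [5, 6, 0]])
    print(batch["attention_mask"])  # False at right-padded positions
    print(batch["labels"])          # padding labeled -100 (ignored by the loss)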