from typing import List, Dict
from transformers import PreTrainedTokenizer
import torch
import torch.nn.functional as F
def pad_list_of_tensors(tensor_list, padding_value=0):
    # tensor_list: list of tensors, each of shape (b, c, h, w).
    # An empty list entry marks a t2i sample (no reference image); if every
    # entry is empty, the whole batch is t2i and there is nothing to pad.
    if all(not isinstance(tensor, torch.Tensor) for tensor in tensor_list):
        return []
    else:
        for tmp_tensor in tensor_list:
            if isinstance(tmp_tensor, torch.Tensor):
                # found a reference tensor to copy shape / dtype / device from
                break
        # Replace the empty-list placeholders with zero tensors when the batch
        # mixes t2i and other tasks: t2i is treated as unconditional
        # (no reference image) editing.
        tensor_list = [
            torch.zeros_like(tmp_tensor) if not isinstance(tensor, torch.Tensor) else tensor
            for tensor in tensor_list
        ]
        assert all(tensor.shape[1] == tensor_list[0].shape[1] for tensor in tensor_list)
        # find the maximum b, h, w across the batch
        max_b = max(tensor.shape[0] for tensor in tensor_list)
        max_c = tensor_list[0].shape[1]  # assume c is the same for every tensor
        max_h = max(tensor.shape[2] for tensor in tensor_list)
        max_w = max(tensor.shape[3] for tensor in tensor_list)
        padded_tensors = []
        for tensor in tensor_list:
            b, c, h, w = tensor.shape
            pad_b = max_b - b
            pad_h = max_h - h
            pad_w = max_w - w
            # pad h and w first (the last two dimensions)
            tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=padding_value)
            # then pad the b dimension (the front) up to (max_b, c, max_h, max_w)
            if pad_b > 0:
                padding_shape = (pad_b, c, max_h, max_w)
                pad_tensor = torch.full(padding_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device)
                tensor = torch.cat([tensor, pad_tensor], dim=0)
            padded_tensors.append(tensor)
        # finally stack into (B, max_b, c, max_h, max_w)
        return torch.stack(padded_tensors)
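# Shape example for pad_list_of_tensors (illustrative values, not from the
# training data): given [ref, []] where ref has shape (2, 3, 512, 384) and []
# marks a t2i sample, the result has shape (2, 2, 3, 512, 384) with the t2i
# slot filled by zeros.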
def resize_list_of_tensors(weights):
    # `weights` is a list of [1, H, W] mask tensors
    # 1) find the max height and width across the list
    heights = [w.shape[-2] for w in weights]
    widths = [w.shape[-1] for w in weights]
    max_h, max_w = max(heights), max(widths)
    # 2) interpolate each mask to (max_h, max_w)
    resized = []
    for w in weights:
        # F.interpolate expects a 4D tensor (N, C, H, W), so add a batch dim
        w_4d = w.unsqueeze(0)  # [1, H, W] -> [1, 1, H, W]
        w_resized = F.interpolate(
            w_4d, size=(max_h, max_w), mode='nearest'
        )
        # back to [1, max_h, max_w]
        w_resized = w_resized.squeeze(0)
        resized.append(w_resized)
    # 3) stack into a single tensor of shape [N, 1, max_h, max_w]
    weights = torch.stack(resized)
    return weights
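# Shape example for resize_list_of_tensors (illustrative values): given masks
# of shapes (1, 32, 32) and (1, 64, 48), both are nearest-resized to the max
# size and stacked into a tensor of shape (2, 1, 64, 48).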
class DataCollator:
def __init__(self, tokenizer: PreTrainedTokenizer, padding_side='right'):
self.tokenizer = tokenizer
self.padding_side = padding_side
def __call__(self, instances: List[Dict]) -> Dict:
input_ids = [instance["input_ids"][0] for instance in instances]
labels = [instance["labels"][0] for instance in instances]
image_position = [instance["image_position"] for instance in instances]
pixel_values = [
instance["pixel_values"] for instance in instances if len(instance["pixel_values"]) > 0
]
pixel_values = torch.cat(pixel_values) if len(pixel_values) > 0 else None
image_grid_thw = [
instance["image_grid_thw"] for instance in instances if len(instance["image_grid_thw"]) > 0
]
image_grid_thw = torch.cat(image_grid_thw) if len(image_grid_thw) > 0 else None
pil_pixel_values = [
instance["pil_pixel_values"] for instance in instances
]
prompts = [instance["prompt"] for instance in instances]
ref_pixel_values = [
instance["ref_pixel_values"] for instance in instances
]
ref_pixel_values = pad_list_of_tensors(ref_pixel_values, padding_value=0)
siglip_pixel_values = [
instance["siglip_pixel_values"] for instance in instances if len(instance["siglip_pixel_values"]) > 0
]
siglip_pixel_values = torch.cat(siglip_pixel_values, dim=0) if len(siglip_pixel_values) > 0 else []
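        # note: the `padding_side` keyword of pad_sequence below is only
        # available in newer PyTorch releases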
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id,
padding_side=self.padding_side,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=-100,
padding_side=self.padding_side,
)
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
weights = [
instance["weights"] for instance in instances if len(instance["weights"]) > 0
]
if len(weights) > 0:
if all([i.shape == weights[0].shape for i in weights]):
weights = torch.stack(weights)
else:
weights = [i.unsqueeze(0) for i in weights]
else:
weights = None
generated_image = [
instance["generated_image"] for instance in instances if len(instance["generated_image"]) > 0
]
if len(generated_image) > 0:
if all([i.shape == generated_image[0].shape for i in generated_image]):
generated_image = torch.stack(generated_image)
else:
generated_image = [i.unsqueeze(0) for i in generated_image]
else:
generated_image = []
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"labels": labels,
"attention_mask": attention_mask,
"image_position": image_position,
"image_grid_thw": image_grid_thw,
"prompts": prompts,
"ref_pixel_values": ref_pixel_values,
"pil_pixel_values": pil_pixel_values,
"siglip_pixel_values": siglip_pixel_values,
"weights": weights,
"generated_image": generated_image,
}
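# A minimal smoke test sketched for illustration only: the stub tokenizer and
# the dummy instances below are assumptions (the real collator is fed a
# transformers PreTrainedTokenizer and dataset samples), but the collator only
# reads `pad_token_id` from the tokenizer, so a SimpleNamespace is enough here.
if __name__ == "__main__":
    from types import SimpleNamespace

    collator = DataCollator(SimpleNamespace(pad_token_id=0))

    def _fake_instance(seq_len):
        # pure t2i sample: no reference, siglip, or weight tensors
        return {
            "input_ids": torch.arange(1, seq_len + 1).unsqueeze(0),
            "labels": torch.arange(1, seq_len + 1).unsqueeze(0),
            "image_position": [],
            "pixel_values": [],
            "image_grid_thw": [],
            "pil_pixel_values": [],
            "prompt": "a photo of a cat",
            "ref_pixel_values": [],
            "siglip_pixel_values": [],
            "weights": [],
            "generated_image": torch.randn(3, 64, 64),
        }

    batch = collator([_fake_instance(5), _fake_instance(8)])
    print(batch["input_ids"].shape)        # torch.Size([2, 8])
    print(batch["attention_mask"].shape)   # torch.Size([2, 8])
    print(batch["generated_image"].shape)  # torch.Size([2, 3, 64, 64])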