from typing import List, Dict
from transformers import PreTrainedTokenizer
import torch
import torch.nn.functional as F

def pad_list_of_tensors(tensor_list, padding_value=0):
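    """Zero-pad a list of (b, c, h, w) reference-image tensors to a common shape and stack them."""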
    # tensor_list: list of tensors, each of shape (b, c, h, w)

    # if every entry is an empty list, all samples in this batch are t2i (no reference image)
    if all(not isinstance(tensor, torch.Tensor) for tensor in tensor_list):
        return []
    else:
        # find the first real tensor to use as a shape/dtype template
        tmp_tensor = next(t for t in tensor_list if isinstance(t, torch.Tensor))
        # when the batch mixes t2i and editing samples, replace the empty t2i entries
        # with zero tensors; t2i is treated as unconditional (no-reference-image) editing
        tensor_list = [
            torch.zeros_like(tmp_tensor) if isinstance(tensor, list) else tensor
            for tensor in tensor_list
        ]
    assert all(tensor.shape[1] == tensor_list[0].shape[1] for tensor in tensor_list)
    # find the maximum b, h, w across the list
    max_b = max(tensor.shape[0] for tensor in tensor_list)
    max_c = tensor_list[0].shape[1]  # c is assumed identical across the list (asserted above)
    max_h = max(tensor.shape[2] for tensor in tensor_list)
    max_w = max(tensor.shape[3] for tensor in tensor_list)

    padded_tensors = []
    for tensor in tensor_list:
        b, c, h, w = tensor.shape
        pad_b = max_b - b
        pad_h = max_h - h
        pad_w = max_w - w

        # pad h and w first (the last two dims)
        tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=padding_value)
        # then pad the leading b dim out to (max_b, c, max_h, max_w)
        if pad_b > 0:
            padding_shape = (pad_b, c, max_h, max_w)
            pad_tensor = torch.full(padding_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device)
            tensor = torch.cat([tensor, pad_tensor], dim=0)

        padded_tensors.append(tensor)

    # finally, stack into (B, max_b, c, max_h, max_w)
    return torch.stack(padded_tensors)

def resize_list_of_tensors(weights):
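    """Nearest-neighbor resize a list of [1, H, W] masks to a common size and stack them."""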
    # suppose weights is your list of [1, H, W] tensors
    # 1) find the max height and width
    heights = [w.shape[-2] for w in weights]
    widths  = [w.shape[-1] for w in weights]
    max_h, max_w = max(heights), max(widths)

    # 2) interpolate each mask to (max_h, max_w)
    resized = []
    for w in weights:
        # F.interpolate expects a 4D tensor (N, C, H, W); w is already [1, H, W],
        # so a single unsqueeze gives [1, 1, H, W]
        w_4d = w.unsqueeze(0)

        w_resized = F.interpolate(
            w_4d, size=(max_h, max_w), mode='nearest'
        )
        # back to [1, H', W']
        w_resized = w_resized.squeeze(0)
        resized.append(w_resized)

    # 3) stack into a single tensor [N, 1, max_h, max_w]
    weights = torch.stack(resized)  # -> [N, 1, max_h, max_w]
    return weights

class DataCollator:
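    """Collates per-sample feature dicts into a padded batch for multimodal training."""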
    def __init__(self, tokenizer: PreTrainedTokenizer, padding_side='right'):
        self.tokenizer = tokenizer
        self.padding_side = padding_side

    def __call__(self, instances: List[Dict]) -> Dict:
        input_ids = [instance["input_ids"][0] for instance in instances]
        labels = [instance["labels"][0] for instance in instances]
        image_position = [instance["image_position"] for instance in instances]

        pixel_values = [
            instance["pixel_values"] for instance in instances if len(instance["pixel_values"]) > 0
        ]
        pixel_values = torch.cat(pixel_values) if len(pixel_values) > 0 else None

        image_grid_thw = [
            instance["image_grid_thw"] for instance in instances if len(instance["image_grid_thw"]) > 0
        ]
        image_grid_thw = torch.cat(image_grid_thw) if len(image_grid_thw) > 0 else None

        pil_pixel_values = [
            instance["pil_pixel_values"] for instance in instances
        ]

        prompts = [instance["prompt"] for instance in instances]

        ref_pixel_values = [
            instance["ref_pixel_values"] for instance in instances
        ]
        ref_pixel_values = pad_list_of_tensors(ref_pixel_values, padding_value=0)

        siglip_pixel_values = [
            instance["siglip_pixel_values"] for instance in instances if len(instance["siglip_pixel_values"]) > 0
        ]
        siglip_pixel_values = torch.cat(siglip_pixel_values, dim=0) if len(siglip_pixel_values) > 0 else []

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id, 
            padding_side=self.padding_side, 
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100, 
            padding_side=self.padding_side, 
        )
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        weights = [
            instance["weights"] for instance in instances if len(instance["weights"]) > 0
        ]
        if len(weights) > 0:
            if all([i.shape == weights[0].shape for i in weights]):
                weights = torch.stack(weights)  
            else:
                weights = [i.unsqueeze(0) for i in weights]
        else:
            weights = None

        generated_image = [
            instance["generated_image"] for instance in instances if len(instance["generated_image"]) > 0
            ]
        if len(generated_image) > 0:
            if all([i.shape == generated_image[0].shape for i in generated_image]):
                generated_image = torch.stack(generated_image)  
            else:
                generated_image = [i.unsqueeze(0) for i in generated_image]
        else:
            generated_image = []
        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_position": image_position,
            "image_grid_thw": image_grid_thw, 
            "prompts": prompts, 
            "ref_pixel_values": ref_pixel_values, 
            "pil_pixel_values": pil_pixel_values, 
            "siglip_pixel_values": siglip_pixel_values, 
            "weights": weights, 
            "generated_image": generated_image, 
        }
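

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hedged example of how this collator can be exercised directly;
# in training it would typically be passed as `collate_fn` to a
# torch.utils.data.DataLoader. The checkpoint name and the dummy sample dict
# below are assumptions for demonstration only: the real pipeline supplies its
# own tokenizer (one that defines a pad token) and a dataset whose __getitem__
# returns the keys accessed in DataCollator.__call__.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")  # placeholder checkpoint
    collator = DataCollator(tokenizer, padding_side="right")

    # a minimal t2i-style sample: no reference image, no vision-tower inputs
    dummy = {
        "input_ids": torch.tensor([[1, 2, 3, 4]]),
        "labels": torch.tensor([[-100, 2, 3, 4]]),
        "image_position": [],
        "pixel_values": [],
        "image_grid_thw": [],
        "pil_pixel_values": [],
        "prompt": "a cat sitting on a mat",
        "ref_pixel_values": [],
        "siglip_pixel_values": [],
        "weights": [],
        "generated_image": torch.randn(3, 64, 64),
    }

    batch = collator([dummy, dummy])
    print(batch["input_ids"].shape, batch["attention_mask"].shape)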