import itertools
from dataclasses import dataclass
from typing import Dict, Sequence

import torch
import transformers
from torch.utils.data import default_collate
from transformers.trainer_pt_utils import LabelSmoother

# Label value ignored by the loss, used to mask out positions such as padding.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


@dataclass
class DataCollatorForSupervisedDataset:
    """Pad variable-length supervised examples into a right-padded batch and
    carry along any optional multimodal fields."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]

        # Right-pad every sequence to the longest one in the batch.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_TOKEN_ID,
        )

        # Truncate to the tokenizer's maximum context length.
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            # Attend to every position that is not padding.
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if "images" in instances[0]:
            images = [instance["images"] for instance in instances]
            batch["images"] = torch.cat(images, dim=0)

        if "doclm_images" in instances[0]:
            images = [instance["doclm_images"] for instance in instances]
            batch["doclm_images"] = torch.cat(images, dim=0)

        if "image_paths" in instances[0]:
            image_paths = [instance["image_paths"] for instance in instances]
            batch["image_paths"] = image_paths

        if "pixel_values" in instances[0]:
            pixel_values = torch.cat([instance["pixel_values"] for instance in instances])
            batch["pixel_values"] = pixel_values

        if "image_flags" in instances[0]:
            image_flags = torch.cat([instance["image_flags"] for instance in instances])
            batch["image_flags"] = image_flags

        return batch
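

# A minimal, self-contained sketch of the collator on two synthetic,
# unequal-length samples. This helper is an illustration added for clarity;
# it is not used by the training code, and the "gpt2" checkpoint is an
# arbitrary choice.
def _demo_supervised_collator() -> None:
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    instances = [
        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([1, 2, 3])},
        {"input_ids": torch.tensor([4, 5]), "labels": torch.tensor([4, 5])},
    ]
    batch = collator(instances)
    # Both tensors come out as (2, 3); the shorter sample is right-padded and
    # its padded label positions hold IGNORE_TOKEN_ID.
    print(batch["input_ids"].shape, batch["attention_mask"].sum(dim=1))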


def collate_fn_deepspeed_old(batch):
    """Older variant: pop image-related entries out of each sample, run
    default_collate on the rest, then re-attach the image entries by
    concatenation so samples with differing image counts can share a batch.
    """
    keys = list(set().union(*[set(x.keys()) for x in batch]))
    tmp_batch = [{} for _ in range(len(batch))]
    if "actual_seq_len" in batch[0]:
        actual_seq_len = [x["actual_seq_len"] for x in batch]
    else:
        actual_seq_len = None

    # Move image-related entries aside; default_collate would otherwise try to
    # stack them, which fails when samples carry different numbers of images.
    for k in keys:
        if "images" in k or k == "image_indices":
            for x, y in zip(tmp_batch, batch):
                if k in y:
                    x[k] = y.pop(k)

    new_batch = default_collate(batch)

    # Re-attach the image entries: image tensors concatenate along dim 0,
    # image_indices along dim 1, with each index tensor's leading entry
    # rewritten to the sample's position in the batch.
    for k in keys:
        if "images" in k or k == "image_indices":
            cat_dim = 0 if k != "image_indices" else 1
            if k == "image_indices":
                cnt = 0
                for sample in tmp_batch:
                    if k in sample:
                        sample[k][0] = cnt
                    cnt += 1
            new_batch[k] = torch.cat([x[k] for x in tmp_batch if k in x], dim=cat_dim)

    # Offset each sample's cumulative sequence-length boundaries so they stay
    # monotonic after the per-sample sequences are packed back to back.
    if actual_seq_len is not None:
        seq_len = actual_seq_len[0][-1]
        actual_seq_len = [elem + i * seq_len for i, elem in enumerate(actual_seq_len)]
        new_batch["actual_seq_len"] = torch.cat(actual_seq_len)

    return new_batch


def collate_fn_deepspeed(batch):
    """Collate multimodal samples: image/audio payloads are popped out of each
    sample before default_collate, then re-attached with their index tensors
    rewritten to point at the sample's position in the batch.
    """
    tmp_batch = [{} for _ in range(len(batch))]
    if "actual_seq_len" in batch[0]:
        actual_seq_len = [x["actual_seq_len"] for x in batch]
    else:
        actual_seq_len = None

    # Pop variable-sized image/audio payloads so default_collate only sees the
    # uniformly shaped fields.
    if "images" in batch[0]:
        for new_x, x in zip(tmp_batch, batch):
            new_x["images"] = x.pop("images")
            new_x["image_indices"] = x.pop("image_indices")

    if "audios" in batch[0]:
        for new_x, x in zip(tmp_batch, batch):
            new_x["audios"] = x.pop("audios")
            new_x["audio_indices"] = x.pop("audio_indices")

    new_batch = default_collate(batch)

    if "images" in tmp_batch[0]:
        new_batch["images"] = torch.cat([x["images"] for x in tmp_batch], dim=0)

        # Rewrite the leading slice of each index tensor to the sample's batch
        # position so the indices remain valid after concatenation.
        for sample_idx, sample in enumerate(tmp_batch):
            for j in range(len(sample["image_indices"])):
                sample["image_indices"][j][0, :, :] = sample_idx

        new_batch["image_indices"] = torch.cat([x["image_indices"] for x in tmp_batch], dim=1)

    if "audios" in tmp_batch[0]:
        # Audio clips vary in length, so they are kept as a flat list rather
        # than concatenated into a single tensor.
        new_batch["audios"] = list(itertools.chain.from_iterable(x["audios"] for x in tmp_batch))

        for sample_idx, sample in enumerate(tmp_batch):
            for j in range(len(sample["audio_indices"])):
                sample["audio_indices"][j][0, :, :] = sample_idx

        new_batch["audio_indices"] = list(
            itertools.chain.from_iterable(x["audio_indices"] for x in tmp_batch)
        )

    # Offset each sample's cumulative sequence-length boundaries so they stay
    # monotonic after the per-sample sequences are packed back to back.
    if actual_seq_len is not None:
        seq_len = actual_seq_len[0][-1]
        actual_seq_len = [elem + i * seq_len for i, elem in enumerate(actual_seq_len)]
        new_batch["actual_seq_len"] = torch.cat(actual_seq_len)

    return new_batch
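

# Worked sketch (hypothetical numbers) of the "actual_seq_len" offsetting done
# by both collate functions above: each sample's cumulative boundary list is
# shifted by i * seq_len so the concatenated boundaries stay monotonic for the
# packed batch. This helper is an illustration only and is not used elsewhere.
def _demo_actual_seq_len_offsets() -> None:
    per_sample = [torch.tensor([5, 9, 12]), torch.tensor([5, 9, 12])]
    seq_len = per_sample[0][-1]  # every sample is assumed packed to length 12
    shifted = [elem + i * seq_len for i, elem in enumerate(per_sample)]
    # Prints tensor([ 5,  9, 12, 17, 21, 24]): one boundary list for the batch.
    print(torch.cat(shifted))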