Muennighoff commited on
Commit
18652d8
·
1 Parent(s): d13896f

Cp over files

Browse files
Files changed (17) hide show
  1. beam_search.py +1087 -0
  2. config_molmoe.py +9 -5
  3. constants.py +571 -0
  4. data_factory.py +222 -0
  5. data_utils.py +827 -0
  6. dataset_sizes.py +262 -0
  7. exceptions.py +50 -0
  8. iterable_dataset.py +266 -0
  9. modeling_molmoe.py +4 -4
  10. multimodal_preprocessor.py +1549 -0
  11. preprocesssors.py +2472 -0
  12. prompts.py +385 -0
  13. seqio_tokenizer.py +659 -0
  14. tasks.py +2548 -0
  15. torch_util.py +183 -0
  16. util.py +1 -1
  17. utils.py +195 -0
beam_search.py ADDED
@@ -0,0 +1,1087 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is a self-contained and flexible beam search implementation adapted from
3
+ AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
4
+ """
5
+
6
+ import copy
7
+ import warnings
8
+ from abc import abstractmethod
9
+ from inspect import signature
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
11
+
12
+ import torch
13
+
14
+ __all__ = [
15
+ "Sampler",
16
+ "DeterministicSampler",
17
+ "MultinomialSampler",
18
+ "TopKSampler",
19
+ "TopPSampler",
20
+ "GumbelSampler",
21
+ "FinalSequenceScorer",
22
+ "SequenceLogProbabilityScorer",
23
+ "LengthNormalizedSequenceLogProbabilityScorer",
24
+ "Constraint",
25
+ "RepeatedNGramBlockingConstraint",
26
+ "BeamSearch",
27
+ ]
28
+
29
+ StateType = Dict[str, torch.Tensor]
30
+ StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
31
+ StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
32
+
33
+ StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
34
+ """
35
+ The type of step function that can be passed to [`BeamSearch.search`](#search).
36
+
37
+ This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
38
+ or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
39
+ """
40
+
41
+ ConstraintStateType = List[List[Dict[str, Any]]]
42
+
43
+
44
+ class Sampler:
45
+ """
46
+ An abstract class that can be used to sample candidates (either nodes or beams)
47
+ within `BeamSearch`.
48
+
49
+ A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
50
+
51
+ `init_state()` takes three arguments:
52
+
53
+ - a tensor of starting log probs with shape `(batch_size,, num_classes)`,
54
+ - the batch size, an int,
55
+ - and the number of classes, also an int.
56
+
57
+ It returns a state dictionary with any state tensors needed for subsequent
58
+ calls to `sample_nodes()` and `sample_beams()`.
59
+
60
+ By default this method just returns an empty dictionary.
61
+
62
+ Both `sample_nodes()` and `sample_beams()` should take three arguments:
63
+
64
+ - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
65
+ - an integer representing the number of samples to take for each example in the batch,
66
+ - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
67
+ track of state.
68
+
69
+ For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
70
+ `num_examples = beam_size * per_node_beam_size`.
71
+
72
+ The return value should be a tuple containing:
73
+
74
+ - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
75
+ - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
76
+ - and the updated state dictionary.
77
+
78
+ A default implementation of `sample_beams` is provided, which just deterministically
79
+ picks the `k` examples with highest log probability.
80
+ """
81
+
82
+ def init_state(
83
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
84
+ ) -> StateType:
85
+ del start_class_log_probabilities, batch_size, num_classes
86
+ return {}
87
+
88
+ @abstractmethod
89
+ def sample_nodes(
90
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
91
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
92
+ raise NotImplementedError
93
+
94
+ def sample_beams(
95
+ self, log_probs: torch.Tensor, beam_size: int, state: StateType
96
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
97
+ del state
98
+ selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
99
+ return selected_log_probs, selected_indices, {}
100
+
101
+
102
+ class DeterministicSampler(Sampler):
103
+ """
104
+ A `Sampler` that just deterministically returns the `k` nodes or beams with highest
105
+ log probability.
106
+ """
107
+
108
+ def sample_nodes(
109
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
110
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
111
+ del state
112
+ selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
113
+ return selected_log_probs, selected_indices, {}
114
+
115
+
116
+ class MultinomialSampler(Sampler):
117
+ """
118
+ A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
119
+ in the default, non-deterministic way.
120
+
121
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
122
+ above 1.0 produces a flatter probability distribution.
123
+ :param with_replacement: Whether to sample with replacement.
124
+
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ temperature: float = 1.0,
130
+ with_replacement: bool = False,
131
+ ) -> None:
132
+ self.temperature = temperature
133
+ self.with_replacement = with_replacement
134
+
135
+ def sample_nodes(
136
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
137
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
138
+ if self.temperature != 1.0:
139
+ _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
140
+ else:
141
+ _probabilities = log_probs.exp()
142
+
143
+ selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
144
+
145
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
146
+
147
+
148
+ class TopKSampler(Sampler):
149
+ """
150
+ A `Sampler` which redistributes the probability mass function for nodes among the
151
+ top `k` choices, then samples from that subset after re-normalizing the probabilities.
152
+
153
+ Beams are sampled in the default, deterministic way.
154
+
155
+ :param k: The number of top choices to be selected from.
156
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
157
+ above 1.0 produces a flatter probability distribution.
158
+ :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ k: int = 1,
164
+ temperature: float = 1.0,
165
+ with_replacement: bool = False,
166
+ ):
167
+ self.k = k
168
+ self.temperature = temperature or 1.0
169
+ self.with_replacement = with_replacement
170
+
171
+ def sample_nodes(
172
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
173
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
174
+ if not per_node_beam_size <= self.k <= log_probs.size()[1]:
175
+ raise ValueError(
176
+ "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
177
+ )
178
+
179
+ # shape (both): (batch_size, k)
180
+ top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
181
+
182
+ # Apply temperature if necessary.
183
+ # shape: (batch_size, k)
184
+ if self.temperature != 1.0:
185
+ top_k_log_probs = top_k_log_probs / self.temperature
186
+
187
+ # Re-normalize the subset.
188
+ # shape: (batch_size, k)
189
+ normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
190
+
191
+ # Sample from the re-normalized subset.
192
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
193
+ # shape: (batch_size, per_node_beam_size)
194
+ sampled_indices = torch.multinomial(
195
+ normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
196
+ )
197
+
198
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
199
+ # shape: (batch_size, per_node_beam_size)
200
+ indices = top_k_indices.gather(-1, sampled_indices)
201
+
202
+ return log_probs.gather(1, indices), indices, state
203
+
204
+
205
+ class TopPSampler(Sampler):
206
+ """
207
+ A `Sampler` which redistributes the probability mass function for nodes among
208
+ the top choices with a cumulative probability of at least `p`, then samples from that subset
209
+ after re-normalizing the probabilities.
210
+
211
+ Beams are sampled in the default, deterministic way.
212
+
213
+ :param p:
214
+ The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
215
+ examples to sample from. If `with_replacement` is `False` and the number of possible samples is
216
+ insufficient to sample without replacement from when calling `sample_nodes`, then the top
217
+ `per_node_beam_size` examples will be chosen.
218
+ :param temperature:
219
+ A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
220
+ above 1.0 produces a flatter probability distribution.
221
+ :param with_replacement:
222
+ If set to `True`, samples will be selected with replacement from the top choices.
223
+
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ p: float = 0.9,
229
+ temperature: float = 1.0,
230
+ with_replacement: bool = False,
231
+ ):
232
+ if p < 0.0 or p > 1.0:
233
+ raise ValueError("p must be a positive float no greater than 1.0")
234
+ self.p = p
235
+ self.temperature = temperature or 1.0
236
+ self.with_replacement = with_replacement
237
+
238
+ def sample_nodes(
239
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
240
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
241
+ if not per_node_beam_size <= log_probs.size()[1]:
242
+ raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
243
+
244
+ # First apply temperature coefficient:
245
+ if self.temperature != 1.0:
246
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
247
+ else:
248
+ _log_probs = log_probs
249
+
250
+ # Sort the probabilities in descending order to then find cumulative sum
251
+ log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
252
+
253
+ # shape: (batch_size, num_classes)
254
+ probabilities_descending = log_probs_descending.exp()
255
+ probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
256
+
257
+ # Create a mask for filtering out probabilities that don't make the top `p`.
258
+ # shape: (batch_size, num_classes)
259
+ exclusion_mask = probabilities_summed >= self.p
260
+
261
+ # We want to include the first index where probabilities_summed >= p, so we shift over one.
262
+ exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
263
+ exclusion_mask[..., 0] = False
264
+
265
+ # Make sure there's at least `per_node_beam_size` options to be selected.
266
+ if not self.with_replacement:
267
+ exclusion_mask[..., :per_node_beam_size] = False
268
+
269
+ log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
270
+
271
+ # Now re-normalized the included log probs.
272
+ # shape: (batch_size, num_classes)
273
+ filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
274
+
275
+ # Sample from the re-normalized subset.
276
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
277
+ # shape: (batch_size, per_node_beam_size)
278
+ sampled_indices = torch.multinomial(
279
+ filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
280
+ )
281
+
282
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
283
+ # shape: (batch_size, per_node_beam_size)
284
+ selected_indices = sorting_indices.gather(-1, sampled_indices)
285
+
286
+ # Return (selected log probabilities, selected classes)
287
+ # shape: (len(log_probs),1) , (len(log_probs), 1)
288
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
289
+
290
+
291
+ class GumbelSampler(Sampler):
292
+ """
293
+ A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
294
+ [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
295
+ Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010]
296
+ (https://api.semanticscholar.org/CorpusID:76662039).
297
+
298
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
299
+ above 1.0 produces a flatter probability distribution.
300
+ """
301
+
302
+ def __init__(self, temperature: float = 1.0):
303
+ self.temperature = temperature
304
+
305
+ def init_state(
306
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
307
+ ) -> StateType:
308
+ # shape: (batch_size, num_classes)
309
+ zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
310
+
311
+ # shape: (batch_size, num_classes)
312
+ G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
313
+
314
+ return {"G_phi_S": G_phi_S}
315
+
316
+ def sample_nodes(
317
+ self,
318
+ log_probs: torch.Tensor,
319
+ per_node_beam_size: int,
320
+ state: StateType,
321
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
322
+ # First apply temperature coefficient:
323
+ # shape: (batch_size * beam_size, num_classes)
324
+ if self.temperature != 1.0:
325
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
326
+ else:
327
+ _log_probs = log_probs
328
+
329
+ # shape: (group_size,)
330
+ phi_S = state["phi_S"]
331
+
332
+ # shape: (group_size, num_classes)
333
+ phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
334
+
335
+ # shape: (group_size, num_classes)
336
+ phi_S_new = phi_S + _log_probs
337
+
338
+ # shape: (group_size, 1)
339
+ G_phi_S = state["G_phi_S"].unsqueeze(-1)
340
+
341
+ # shape: (group_size, num_classes)
342
+ G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
343
+
344
+ # Replace NaNs with very negative number.
345
+ # shape: (group_size, num_classes)
346
+ # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
347
+
348
+ # shape (both): (group_size, per_node_beam_size)
349
+ top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
350
+
351
+ # shape: (group_size, per_node_beam_size)
352
+ top_log_probs = log_probs.gather(1, top_indices)
353
+
354
+ return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
355
+
356
+ def sample_beams(
357
+ self,
358
+ log_probs: torch.Tensor,
359
+ beam_size: int,
360
+ state: StateType,
361
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
362
+ """
363
+ Returns the beams with the highest perturbed log probabilities.
364
+ """
365
+ # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
366
+
367
+ batch_size = log_probs.size()[0]
368
+
369
+ # shape: (batch_size * beam_size, per_node_beam_size)
370
+ G_phi_S = state["G_phi_S"]
371
+
372
+ # shape: (batch_size, beam_size * per_node_beam_size)
373
+ G_phi_S = G_phi_S.reshape_as(log_probs)
374
+
375
+ # shape (both): (batch_size, beam_size)
376
+ G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
377
+
378
+ # shape: (batch_size, beam_size)
379
+ selected_log_probs = log_probs.gather(1, selected_indices)
380
+
381
+ # Now sort the selected beams by their true log prob.
382
+ # shape (all): (batch_size, beam_size)
383
+ selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
384
+ selected_indices = selected_indices.gather(1, sort_indices)
385
+ G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
386
+
387
+ # shape: (batch_size * beam_size,)
388
+ G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
389
+
390
+ # shape: (batch_size * beam_size,)
391
+ phi_S = selected_log_probs.reshape(batch_size * beam_size)
392
+
393
+ return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
394
+
395
+ def gumbel(self, phi) -> torch.Tensor:
396
+ """
397
+ Sample `Gumbel(phi)`.
398
+
399
+ `phi` should have shape `(batch_size, num_classes)`.
400
+ """
401
+ return -torch.log(-torch.log(torch.rand_like(phi))) + phi
402
+
403
+ def gumbel_with_max(self, phi, T) -> torch.Tensor:
404
+ """
405
+ Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
406
+
407
+ `phi` should have shape `(batch_size, num_classes)` and `T` should have
408
+ shape `(batch_size, 1)`.
409
+ """
410
+ # Shape: (batch_size, num_classes)
411
+ G_phi = self.gumbel(phi)
412
+
413
+ # Now we find the maximum from these samples.
414
+ # Shape: (batch_size, )
415
+ Z, _ = G_phi.max(dim=-1)
416
+
417
+ # Shape: (batch_size, num_classes)
418
+ v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
419
+
420
+ # Shape: (batch_size, num_classes)
421
+ return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
422
+
423
+
424
+ class FinalSequenceScorer:
425
+ """
426
+ An abstract class that can be used to score the final generated sequences found
427
+ by beam search. Given the predicted sequences and the corresponding log probabilities of
428
+ those sequences, the class calculates and returns the final score of the sequences.
429
+
430
+ The default implementation scores the sequences using the sum of the log probabilities of
431
+ the sequence, which is passed as input.
432
+ """
433
+
434
+ @abstractmethod
435
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
436
+ """
437
+ Score the final predictions found by beam search.
438
+ Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
439
+
440
+ :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
441
+ :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
442
+ of the log probabilities per token, with shape `(batch_size, beam_size)`.
443
+ :param end_index: The index of the end symbol.
444
+
445
+ """
446
+ raise NotImplementedError
447
+
448
+
449
+ class SequenceLogProbabilityScorer(FinalSequenceScorer):
450
+ """
451
+ A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
452
+ across the sequence's tokens.
453
+ """
454
+
455
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
456
+ del predictions, end_index
457
+ # The sum of the sequence log probabilities is the input parameter, so just
458
+ # return it.
459
+ return log_probabilities
460
+
461
+
462
+ class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
463
+ """
464
+ A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
465
+ tokens in the sequence. It optionally includes a length penalty which promotes
466
+ or demotes sequences based on their lengths. The final score for a sequence will
467
+ be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
468
+ here includes the end token.
469
+
470
+ :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
471
+ A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
472
+ """
473
+
474
+ def __init__(self, length_penalty: float = 1.0):
475
+ super().__init__()
476
+ self.length_penalty = length_penalty
477
+
478
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
479
+ # shape: (batch_size, beam_size)
480
+ lengths = (predictions != end_index).long().sum(dim=2)
481
+
482
+ # If the sequence ended during beam search, the `log_probabilities` will include
483
+ # the transition to the end token. Therefore, in such situations, `lengths` is
484
+ # actually off by 1. This corrects for that.
485
+ # shape: (batch_size, beam_size)
486
+ is_end_token = predictions[:, :, -1] == end_index
487
+ lengths += is_end_token.long()
488
+
489
+ # shape: (batch_size, beam_size)
490
+ average_log_probs = log_probabilities / (lengths**self.length_penalty)
491
+ return average_log_probs
492
+
493
+
494
+ class Constraint:
495
+ """
496
+ An abstract class that can be used to enforce constraints on the output predictions
497
+ by manipulating the class log probabilities during beam search.
498
+
499
+ A `Constraint` just has three methods that need to be implemented by subclasses:
500
+ `init_state()`, `apply()` and `_update_state()`.
501
+
502
+ `init_state()` takes one argument:
503
+
504
+ - the batch size, an int
505
+
506
+ It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
507
+ calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
508
+ Each inner list should be of length 1.
509
+
510
+ `apply()` takes two arguments:
511
+
512
+ - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
513
+ and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
514
+ - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
515
+ log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
516
+
517
+ The `apply()` method should return new `class_log_probabilities` that enforce the constraint
518
+ for this step of beam search. For instance, it may prevent a specific class from being selected by setting
519
+ the corresponding log probability to a negligible value such as `float("-inf")` or
520
+ `torch.finfo(class_log_probabilities.dtype).min`.
521
+
522
+ `_update_state()` takes two arguments:
523
+
524
+ - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
525
+ copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
526
+ directly edited in-place without affecting the others.
527
+ - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
528
+ step of beam search.
529
+
530
+ The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
531
+ length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
532
+
533
+ """
534
+
535
+ @abstractmethod
536
+ def init_state(
537
+ self,
538
+ batch_size: int,
539
+ ) -> ConstraintStateType:
540
+ raise NotImplementedError
541
+
542
+ @abstractmethod
543
+ def apply(
544
+ self,
545
+ state: ConstraintStateType,
546
+ class_log_probabilities: torch.Tensor,
547
+ ) -> torch.Tensor:
548
+ raise NotImplementedError
549
+
550
+ @staticmethod
551
+ def _copy_state(
552
+ state: ConstraintStateType,
553
+ batch_size: int,
554
+ beam_size: int,
555
+ last_backpointer: Optional[torch.Tensor] = None,
556
+ ) -> ConstraintStateType:
557
+ """
558
+ Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this
559
+ is not appropriate for your constraint, you will need to implement the copying yourself.
560
+ """
561
+ new_state = []
562
+ for i in range(batch_size):
563
+ batch_state = []
564
+ for j in range(beam_size):
565
+ if last_backpointer is None:
566
+ # This is the first prediction, so the backpointer is 0
567
+ backpointer = 0
568
+ else:
569
+ backpointer = last_backpointer[i, j].item()
570
+ batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
571
+ new_state.append(batch_state)
572
+ return new_state
573
+
574
+ def update_state(
575
+ self,
576
+ state: ConstraintStateType,
577
+ last_prediction: torch.Tensor,
578
+ last_backpointer: Optional[torch.Tensor] = None,
579
+ ) -> ConstraintStateType:
580
+ batch_size, beam_size = last_prediction.size()
581
+ new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
582
+ return self._update_state(new_state, last_prediction)
583
+
584
+ @abstractmethod
585
+ def _update_state(
586
+ self,
587
+ state: ConstraintStateType,
588
+ last_prediction: torch.Tensor,
589
+ ) -> ConstraintStateType:
590
+ raise NotImplementedError
591
+
592
+
593
+ class RepeatedNGramBlockingConstraint(Constraint):
594
+ def __init__(self, ngram_size: int, **kwargs) -> None:
595
+ super().__init__(**kwargs)
596
+ self.ngram_size = ngram_size
597
+
598
+ def init_state(
599
+ self,
600
+ batch_size: int,
601
+ ) -> ConstraintStateType:
602
+ return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
603
+
604
+ def apply(
605
+ self,
606
+ state: ConstraintStateType,
607
+ class_log_probabilities: torch.Tensor,
608
+ ) -> torch.Tensor:
609
+ for i, batch in enumerate(state):
610
+ for j, beam in enumerate(batch):
611
+ current_prefix = tuple(beam["current_prefix"])
612
+ seen_ngrams = beam["seen_ngrams"]
613
+ try:
614
+ disallowed_indices = seen_ngrams[current_prefix]
615
+ class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
616
+ class_log_probabilities.dtype
617
+ ).min
618
+ except KeyError:
619
+ # We have not seen this prefix before, so there is no index
620
+ # that needs to be blocked
621
+ pass
622
+ return class_log_probabilities
623
+
624
+ def _update_state(
625
+ self,
626
+ state: ConstraintStateType,
627
+ last_prediction: torch.Tensor,
628
+ ) -> ConstraintStateType:
629
+ for i, batch in enumerate(state):
630
+ for j, beam in enumerate(batch):
631
+ prediction = last_prediction[i, j].item()
632
+ prefix = beam["current_prefix"]
633
+ seen_ngrams = beam["seen_ngrams"]
634
+
635
+ if len(prefix) == self.ngram_size - 1:
636
+ # This is a new ngram that we have to remember
637
+ if tuple(prefix) not in seen_ngrams:
638
+ seen_ngrams[tuple(prefix)] = []
639
+ seen_ngrams[tuple(prefix)].append(prediction)
640
+
641
+ # Create the new prefix, removing the oldest index if the prefix
642
+ # is too long
643
+ prefix.append(prediction)
644
+ if len(prefix) == self.ngram_size:
645
+ prefix.pop(0)
646
+ return state
647
+
648
+
649
+ class BeamSearch:
650
+ """
651
+ Implements the beam search algorithm for decoding the most likely sequences.
652
+
653
+ :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
654
+
655
+ :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
656
+ of the predicted sequences.
657
+
658
+ :param beam_size: The width of the beam used.
659
+
660
+ :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
661
+ If not given, this just defaults to `beam_size`. Setting this parameter
662
+ to a number smaller than `beam_size` may give better results, as it can introduce
663
+ more diversity into the search. See
664
+ [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
665
+ (https://api.semanticscholar.org/CorpusID:2229477).
666
+
667
+ :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
668
+ If not specified, `DeterministicSampler` will be used, which just takes the
669
+ `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
670
+
671
+ Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
672
+ [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
673
+
674
+ :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
675
+ the predicted sequences. This does not include the start or end tokens. If `None`,
676
+ no minimum is enforced.
677
+
678
+ :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
679
+ The output from this module is what is returned by the `search` method. If not
680
+ specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
681
+ by the sum of the token log probabilities.
682
+
683
+ :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
684
+ provided, no constraints will be enforced.
685
+
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ end_index: int,
691
+ *,
692
+ max_steps: int = 50,
693
+ beam_size: int = 10,
694
+ per_node_beam_size: Optional[int] = None,
695
+ sampler: Optional[Sampler] = None,
696
+ min_steps: Optional[int] = None,
697
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
698
+ constraints: Optional[List[Constraint]] = None,
699
+ distributed_model: bool = False
700
+ ) -> None:
701
+ if not max_steps > 0:
702
+ raise ValueError("max_steps must be positive")
703
+ if not beam_size > 0:
704
+ raise ValueError("beam_size must be positive")
705
+ if per_node_beam_size is not None and not per_node_beam_size > 0:
706
+ raise ValueError("per_node_beam_size must be positive")
707
+ if min_steps is not None:
708
+ if not min_steps >= 0:
709
+ raise ValueError("min_steps must be non-negative")
710
+ if not min_steps <= max_steps:
711
+ raise ValueError("min_steps must be less than or equal to max_steps")
712
+
713
+ self._end_index = end_index
714
+ self.max_steps = max_steps
715
+ self.beam_size = beam_size
716
+ self.per_node_beam_size = per_node_beam_size or beam_size
717
+ self.sampler = sampler or DeterministicSampler()
718
+ self.min_steps = min_steps or 0
719
+ self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
720
+ self.constraints = constraints or []
721
+ self.distributed_model = distributed_model
722
+
723
+ @staticmethod
724
+ def _reconstruct_sequences(predictions, backpointers):
725
+ # Reconstruct the sequences.
726
+ # shape: [(batch_size, beam_size, 1)]
727
+ reconstructed_predictions = [predictions[-1].unsqueeze(2)]
728
+
729
+ if not backpointers:
730
+ return reconstructed_predictions
731
+
732
+ # shape: (batch_size, beam_size)
733
+ cur_backpointers = backpointers[-1]
734
+
735
+ for timestep in range(len(predictions) - 2, 0, -1):
736
+ # shape: (batch_size, beam_size, 1)
737
+ cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
738
+
739
+ reconstructed_predictions.append(cur_preds)
740
+
741
+ # shape: (batch_size, beam_size)
742
+ cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
743
+
744
+ # shape: (batch_size, beam_size, 1)
745
+ final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
746
+
747
+ reconstructed_predictions.append(final_preds)
748
+
749
+ return reconstructed_predictions
750
+
751
+ def search(
752
+ self,
753
+ start_predictions: torch.Tensor,
754
+ start_state: StateType,
755
+ step: StepFunctionType,
756
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
757
+ """
758
+ Given a starting state and a step function, apply beam search to find the
759
+ most likely target sequences.
760
+
761
+ Returns a tuple of `(predictions, final_scores)`, where `predictions`
762
+ has shape `(batch_size, beam_size, max_steps)` and `final_scores`
763
+ has shape `(batch_size, beam_size)`.
764
+
765
+ .. note::
766
+ If your step function returns `-inf` for some log probabilities
767
+ (like if you're using a masked log-softmax) then some of the "best"
768
+ sequences returned may also have `-inf` log probability. Specifically
769
+ this happens when the beam size is smaller than the number of actions
770
+ with finite log probability (non-zero probability) returned by the step function.
771
+ Therefore if you're using a mask you may want to check the results from `search`
772
+ and potentially discard sequences with non-finite log probability.
773
+
774
+ :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
775
+ Usually the initial predictions are just the index of the "start" token
776
+ in the target vocabulary.
777
+
778
+ :param start_state: The initial state passed to the `step` function. Each value of the state dict
779
+ should be a tensor of shape `(batch_size, *)`, where `*` means any other
780
+ number of dimensions.
781
+
782
+ :param step: A function that is responsible for computing the next most likely tokens,
783
+ given the current state and the predictions from the last time step.
784
+ The function should accept two or three arguments:
785
+
786
+ - a tensor of shape `(group_size,)` or representing the index of the predicted
787
+ tokens from the last time step,
788
+ - the current state, a `StateType`, and
789
+ - optionally, the timestep, an `int`.
790
+
791
+ The `group_size` will be `batch_size * beam_size`, except in the initial
792
+ step, for which it will just be `batch_size`.
793
+
794
+ The function is expected to return a tuple, where the first element
795
+ is a tensor of shape `(group_size, vocab_size)` containing
796
+ the log probabilities of the tokens for the next step, and the second
797
+ element is the updated state. The tensor in the state should have shape
798
+ `(group_size, *)`, where `*` means any other number of dimensions.
799
+
800
+ """
801
+ step_signature = signature(step)
802
+ if len(step_signature.parameters) < 3:
803
+ # If the step function we're given does not take the time step argument, wrap it
804
+ # in one that does.
805
+ old_step = cast(StepFunctionTypeNoTimestep, step)
806
+
807
+ def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
808
+ del time_step
809
+ return old_step(last_predictions, state)
810
+
811
+ return self._search(start_predictions, start_state, new_step)
812
+ else:
813
+ return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
814
+
815
+ def _search(
816
+ self,
817
+ start_predictions: torch.Tensor,
818
+ start_state: StateType,
819
+ step: StepFunctionTypeWithTimestep,
820
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
821
+ batch_size = start_predictions.size()[0]
822
+
823
+ # List of (batch_size, beam_size) tensors. One for each time step. Does not
824
+ # include the start symbols, which are implicit.
825
+ predictions: List[torch.Tensor] = []
826
+
827
+ # List of (batch_size, beam_size) tensors. One for each time step. None for
828
+ # the first. Stores the index n for the parent prediction, i.e.
829
+ # predictions[t-1][i][n], that it came from.
830
+ backpointers: List[torch.Tensor] = []
831
+
832
+ constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
833
+
834
+ # Calculate the first timestep. This is done outside the main loop
835
+ # because we are going from a single decoder input (the output from the
836
+ # encoder) to the top `beam_size` decoder outputs. On the other hand,
837
+ # within the main loop we are going from the `beam_size` elements of the
838
+ # beam to `beam_size`^2 candidates from which we will select the top
839
+ # `beam_size` elements for the next iteration.
840
+ # shape: (batch_size, num_classes)
841
+ start_class_log_probabilities, state = step(start_predictions, start_state, 0)
842
+
843
+ num_classes = start_class_log_probabilities.size()[1]
844
+
845
+ # Make sure `per_node_beam_size` is not larger than `num_classes`.
846
+ if self.per_node_beam_size > num_classes:
847
+ raise ValueError(
848
+ f"Vocab size ({num_classes:d}) too small "
849
+ f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
850
+ f"Please decrease beam_size or per_node_beam_size."
851
+ )
852
+
853
+ sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
854
+
855
+ # Apply all constraints.
856
+ if self.constraints:
857
+ # shape: (batch_size, 1, num_classes)
858
+ expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
859
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
860
+ expanded_start_class_log_probabilities = constraint.apply(
861
+ constraint_state, expanded_start_class_log_probabilities
862
+ )
863
+ start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
864
+
865
+ # Prevent selecting the end symbol if there is any min_steps constraint
866
+ if self.min_steps >= 1:
867
+ start_class_log_probabilities[:, self._end_index] = torch.finfo(
868
+ start_class_log_probabilities.dtype
869
+ ).min
870
+
871
+ # Get the initial predicted classed and their log probabilities.
872
+ # shape: (batch_size, beam_size), (batch_size, beam_size)
873
+ (
874
+ start_top_log_probabilities,
875
+ start_predicted_classes,
876
+ sampler_state,
877
+ ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
878
+
879
+ if (
880
+ self.beam_size == 1 and
881
+ (start_predicted_classes == self._end_index).all() and
882
+ not self.distributed_model
883
+ ):
884
+ warnings.warn(
885
+ "Empty sequences predicted. You may want to increase the beam size or ensure "
886
+ "your step function is working properly.",
887
+ RuntimeWarning,
888
+ )
889
+ return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
890
+
891
+ # The log probabilities for the last time step.
892
+ # shape: (batch_size, beam_size)
893
+ last_log_probabilities = start_top_log_probabilities
894
+
895
+ # shape: [(batch_size, beam_size)]
896
+ predictions.append(start_predicted_classes)
897
+
898
+ # Log probability tensor that mandates that the end token is selected.
899
+ # shape: (batch_size * beam_size, num_classes)
900
+ log_probs_after_end = start_class_log_probabilities.new_full(
901
+ (batch_size * self.beam_size, num_classes),
902
+ torch.finfo(start_class_log_probabilities.dtype).min,
903
+ )
904
+ log_probs_after_end[:, self._end_index] = 0.0
905
+
906
+ # Set the same state for each element in the beam.
907
+ self._update_initial_state(state, batch_size)
908
+
909
+ for i, constraint in enumerate(self.constraints):
910
+ constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
911
+
912
+ for timestep in range(self.max_steps - 1):
913
+ # shape: (batch_size * beam_size,)
914
+ last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
915
+
916
+ # If every predicted token from the last step is `self._end_index`,
917
+ # then we can stop early.
918
+ # FIXME for distributed model we cannot stop early unless all devices are done,
919
+ # for now we just always run to the max limit, ideally we should check all devices
920
+ if not self.distributed_model and (last_predictions == self._end_index).all():
921
+ # finished
922
+ break
923
+ # Take a step. This get the predicted log probs of the next classes
924
+ # and updates the state.
925
+ # shape: (batch_size * beam_size, num_classes)
926
+ class_log_probabilities, state = step(last_predictions, state, timestep + 1)
927
+
928
+ # Apply all constraints.
929
+ if self.constraints:
930
+ # shape: (batch_size, beam_size, num_classes)
931
+ reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
932
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
933
+ reshaped_class_log_probabilities = constraint.apply(
934
+ constraint_state, reshaped_class_log_probabilities
935
+ )
936
+ # shape: (batch_size * beam_size, num_classes)
937
+ class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
938
+
939
+ # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
940
+ # of the sequence (because `timestep` is 0-indexed and we generated the first token
941
+ # before the for loop). Here we block the end index if the search is not allowed to
942
+ # terminate on this iteration.
943
+ if timestep + 2 <= self.min_steps:
944
+ class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
945
+
946
+ # shape: (batch_size * beam_size, num_classes)
947
+ last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
948
+ batch_size * self.beam_size, num_classes
949
+ )
950
+
951
+ # Here we are finding any beams where we predicted the end token in
952
+ # the previous timestep and replacing the distribution with a
953
+ # one-hot distribution, forcing the beam to predict the end token
954
+ # this timestep as well.
955
+ # shape: (batch_size * beam_size, num_classes)
956
+ cleaned_log_probabilities = torch.where(
957
+ last_predictions_expanded == self._end_index,
958
+ log_probs_after_end,
959
+ class_log_probabilities,
960
+ )
961
+
962
+ # shape (both): (batch_size * beam_size, per_node_beam_size)
963
+ top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
964
+ cleaned_log_probabilities, self.per_node_beam_size, sampler_state
965
+ )
966
+
967
+ # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
968
+ # so that we can add them to the current log probs for this timestep.
969
+ # This lets us maintain the log probability of each element on the beam.
970
+ # shape: (batch_size * beam_size, per_node_beam_size)
971
+ expanded_last_log_probabilities = (
972
+ last_log_probabilities.unsqueeze(2)
973
+ .expand(batch_size, self.beam_size, self.per_node_beam_size)
974
+ .reshape(batch_size * self.beam_size, self.per_node_beam_size)
975
+ )
976
+
977
+ # shape: (batch_size * beam_size, per_node_beam_size)
978
+ summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
979
+
980
+ # shape: (batch_size, beam_size * per_node_beam_size)
981
+ reshaped_summed = summed_top_log_probabilities.reshape(
982
+ batch_size, self.beam_size * self.per_node_beam_size
983
+ )
984
+
985
+ # shape: (batch_size, beam_size * per_node_beam_size)
986
+ reshaped_predicted_classes = predicted_classes.reshape(
987
+ batch_size, self.beam_size * self.per_node_beam_size
988
+ )
989
+
990
+ # Keep only the top `beam_size` beam indices.
991
+ # shape (both): (batch_size, beam_size)
992
+ (
993
+ restricted_beam_log_probs,
994
+ restricted_beam_indices,
995
+ sampler_state,
996
+ ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
997
+
998
+ # Use the beam indices to extract the corresponding classes.
999
+ # shape: (batch_size, beam_size)
1000
+ restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
1001
+
1002
+ predictions.append(restricted_predicted_classes)
1003
+
1004
+ # shape: (batch_size, beam_size)
1005
+ last_log_probabilities = restricted_beam_log_probs
1006
+
1007
+ # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
1008
+ # indices with a common ancestor are grouped together. Hence
1009
+ # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
1010
+ # division as the tensor is a LongTensor.)
1011
+ # shape: (batch_size, beam_size)
1012
+ backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
1013
+ backpointers.append(backpointer)
1014
+
1015
+ # Keep only the pieces of the state tensors corresponding to the
1016
+ # ancestors created this iteration.
1017
+ self._update_state(state, backpointer)
1018
+
1019
+ for i, constraint in enumerate(self.constraints):
1020
+ constraint_states[i] = constraint.update_state(
1021
+ constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
1022
+ )
1023
+
1024
+ # Warn about "-inf" log probabilities if not using any constraints (negligible
1025
+ # log probabilities are expected when using constraints).
1026
+ if not self.constraints and (
1027
+ not torch.isfinite(last_log_probabilities).all()
1028
+ or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
1029
+ ):
1030
+ warnings.warn(
1031
+ "Negligible log probabilities encountered ('-inf' or equivalent). "
1032
+ "Some final sequences may not make sense. "
1033
+ "This can happen when the beam size is larger than the number of valid (non-zero "
1034
+ "probability) transitions that the step function produces.",
1035
+ RuntimeWarning,
1036
+ )
1037
+
1038
+ reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
1039
+
1040
+ # shape: (batch_size, beam_size, max_steps)
1041
+ all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
1042
+
1043
+ # Calculate the final sequence scores
1044
+ # shape: (batch_size, beam_size)
1045
+ final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
1046
+
1047
+ # Sort the sequences based on the final scores so the best scoring
1048
+ # sequence is at index 0
1049
+ sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
1050
+ sorted_all_predictions = torch.gather(
1051
+ all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
1052
+ )
1053
+
1054
+ return sorted_all_predictions, sorted_final_scores
1055
+
1056
+ def _update_initial_state(self, state: StateType, batch_size: int):
1057
+ """
1058
+ Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
1059
+ """
1060
+ for key, state_tensor in state.items():
1061
+ if state_tensor is None:
1062
+ continue
1063
+ # shape: (batch_size * beam_size, *)
1064
+ _, *last_dims = state_tensor.size()
1065
+ state[key] = (
1066
+ state_tensor.unsqueeze(1)
1067
+ .expand(batch_size, self.beam_size, *last_dims)
1068
+ .reshape(batch_size * self.beam_size, *last_dims)
1069
+ )
1070
+
1071
+ def _update_state(self, state: StateType, backpointer: torch.Tensor):
1072
+ batch_size = backpointer.size()[0]
1073
+
1074
+ for key, state_tensor in state.items():
1075
+ if state_tensor is None:
1076
+ continue
1077
+ _, *last_dims = state_tensor.size()
1078
+ # shape: (batch_size, beam_size, *)
1079
+ expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
1080
+ batch_size, self.beam_size, *last_dims
1081
+ )
1082
+ # shape: (batch_size * beam_size, *)
1083
+ state[key] = (
1084
+ state_tensor.reshape(batch_size, self.beam_size, *last_dims)
1085
+ .gather(1, expanded_backpointer)
1086
+ .reshape(batch_size * self.beam_size, *last_dims)
1087
+ )
config_molmoe.py CHANGED
@@ -27,11 +27,15 @@ import gin
27
 
28
  #from olmo.aliases import PathOrStr
29
  from .aliases import PathOrStr
30
- from olmo.exceptions import OLMoConfigurationError
31
- from olmo.util import StrEnum, resource_path
32
-
33
- from olmo.mm_data.data_utils import build_tokenizer
34
- from olmo.multimodal_preprocessor import MultiModalPreprocessor
 
 
 
 
35
 
36
  __all__ = [
37
  "ActivationType",
 
27
 
28
  #from olmo.aliases import PathOrStr
29
  from .aliases import PathOrStr
30
+ #from olmo.exceptions import OLMoConfigurationError
31
+ from .exceptions import OLMoConfigurationError
32
+ #from olmo.util import StrEnum, resource_path
33
+ from .util import StrEnum, resource_path
34
+
35
+ #from olmo.mm_data.data_utils import build_tokenizer
36
+ from .data_utils import build_tokenizer
37
+ #from olmo.multimodal_preprocessor import MultiModalPreprocessor
38
+ from .multimodal_preprocessor import MultiModalPreprocessor
39
 
40
  __all__ = [
41
  "ActivationType",
constants.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DEFAULT_IMAGE_PATCH_TOKEN = f"<im_patch>"
2
+ DEFAULT_IM_START_TOKEN = f"<im_start>"
3
+ DEFAULT_IM_END_TOKEN = f"<im_end>"
4
+ DEFAULT_IM_COL_TOKEN = f"<im_col>"
5
+ IMAGE_PROMPT = "<|image|>"
6
+
7
+ EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
8
+
9
+
10
+ VIT_STANDARD_CONFIGS = {
11
+ "dinov2-large": {
12
+ "image_emb_dim": 1024,
13
+ "image_mlp_dim": 4096,
14
+ 'image_patch_size': 14,
15
+ 'image_pos_patch_size': 14,
16
+ 'image_num_layers': 24,
17
+ 'image_num_heads': 16,
18
+ 'image_num_key_value_heads': 16,
19
+ 'image_head_dim': 64,
20
+ 'image_mlp_activations': 'gelu',
21
+ 'image_default_input_size': (224, 224),
22
+ 'image_num_pos': 257,
23
+ 'image_norm_eps': 1e-6,
24
+ "image_model_type": "dino"
25
+ },
26
+ "SigLIP-So400m-14-384": {
27
+ "image_emb_dim": 1152,
28
+ 'image_num_layers': 27,
29
+ "image_mlp_dim": 4304,
30
+ 'image_patch_size': 14,
31
+ 'image_pos_patch_size': 14,
32
+ 'image_num_heads': 16,
33
+ 'image_num_key_value_heads': 16,
34
+ 'image_head_dim': 72,
35
+ 'image_mlp_activations': 'gelu',
36
+ # Although it is called "384" that seems to be an error of the author's
37
+ # part, it actually only handles 378 inputs
38
+ 'image_default_input_size': (378, 378),
39
+ 'image_num_pos': 729, # note not CLS token
40
+ 'image_norm_eps': 1e-6,
41
+ "image_model_type": "siglip",
42
+ "resize_mode": "siglip"
43
+ },
44
+ "DFN5B-CLIP-ViT-H-14-378": {
45
+ "image_emb_dim": 1280,
46
+ 'image_patch_size': 14,
47
+ 'image_pos_patch_size': 14,
48
+ 'image_num_layers': 32,
49
+ 'image_num_heads': 16,
50
+ 'image_num_key_value_heads': 16,
51
+ 'image_head_dim': 80,
52
+ 'image_mlp_dim': 5120,
53
+ 'image_dropout_rate': 0.0,
54
+ 'image_mlp_activations': 'quick_gelu',
55
+ 'image_default_input_size': (378, 378),
56
+ 'image_num_pos': 730,
57
+ 'image_norm_eps': 1e-5,
58
+ "image_model_type": "openai",
59
+ "resize_mode": "no_aspect_ratio"
60
+ },
61
+ 'ViT-L/14-336': {
62
+ 'image_patch_size': 14,
63
+ 'image_pos_patch_size': 14,
64
+ 'image_emb_dim': 1024,
65
+ 'image_num_heads': 16,
66
+ 'image_num_layers': 23,
67
+ 'image_head_dim': 64,
68
+ 'image_mlp_dim': 4096,
69
+ 'image_mlp_activations': 'quick_gelu',
70
+ 'image_dropout_rate': 0.0,
71
+ 'image_num_pos': 577,
72
+ 'image_default_input_size': (336, 336),
73
+ 'image_norm_eps': 1e-5,
74
+ 'image_num_key_value_heads': 16,
75
+ "image_model_type": "openai"
76
+ },
77
+ 'EVA02-L-14-336': {
78
+ 'image_patch_size': 14,
79
+ 'image_pos_patch_size': 14,
80
+ 'image_emb_dim': 1024,
81
+ 'image_num_heads': 16,
82
+ 'image_num_layers': 24,
83
+ 'image_head_dim': 64,
84
+ 'image_mlp_dim': 2730,
85
+ 'image_mlp_activations': 'silu',
86
+ 'image_dropout_rate': 0.0,
87
+ 'image_num_pos': 577,
88
+ 'image_default_input_size': (336, 336),
89
+ 'image_norm_eps': 1e-6,
90
+ 'image_num_key_value_heads': 16,
91
+ "image_model_type": "eva"
92
+ },
93
+ 'ViT-L/14': {
94
+ 'image_patch_size': 14,
95
+ 'image_pos_patch_size': 14,
96
+ 'image_emb_dim': 1024,
97
+ 'image_num_heads': 16,
98
+ # Note the original model has 24 layers, but we don't use the last layer
99
+ 'image_num_layers': 23,
100
+ 'image_head_dim': 64,
101
+ 'image_mlp_dim': 4096,
102
+ 'image_mlp_activations': 'quick_gelu',
103
+ 'image_dropout_rate': 0.0,
104
+ 'image_num_pos': 257,
105
+ 'image_default_input_size': (224, 224),
106
+ 'image_norm_eps': 1e-5,
107
+ 'image_num_key_value_heads': 16,
108
+ "image_model_type": "openai"
109
+ },
110
+ 'debug': {
111
+ 'image_patch_size': 14,
112
+ 'image_pos_patch_size': 14,
113
+ 'image_emb_dim': 1024,
114
+ 'image_num_heads': 16,
115
+ 'image_num_layers': 1,
116
+ 'image_head_dim': 64,
117
+ 'image_mlp_dim': 1024,
118
+ 'image_mlp_activations': 'quick_gelu',
119
+ 'image_dropout_rate': 0.0,
120
+ 'image_num_pos': 577,
121
+ 'image_default_input_size': (336, 336),
122
+ 'image_norm_eps': 1e-5,
123
+ 'image_num_key_value_heads': 16,
124
+ "image_model_type": "openai"
125
+ }
126
+ }
127
+
128
+ OPEN_LLM_STANDARD_CONFIGS = {
129
+ "qwen1.5_7b": {
130
+ 'vocab_size': 151936,
131
+ 'hidden_size': 4096,
132
+ 'intermediate_size': 11008,
133
+ 'num_hidden_layers': 32,
134
+ 'num_attention_heads': 32,
135
+ 'num_key_value_heads': 32,
136
+ 'max_sequence_length': 2048,
137
+ 'max_position_embeddings': 32768,
138
+ 'rope_theta': 1000000.0,
139
+ 'initializer_range': 0.02,
140
+ 'rms_norm_eps': 1e-6,
141
+ "qkv_bias": True,
142
+ 'tie_word_embeddings': False,
143
+ 'hidden_act': 'silu',
144
+ 'norm_module': 'RMSNorm',
145
+ "tokenizer": "hf-Qwen/Qwen1.5-7B",
146
+ },
147
+ "qwen1.5_14b": {
148
+ 'vocab_size': 152064,
149
+ 'hidden_size': 5120,
150
+ 'intermediate_size': 13696,
151
+ 'num_hidden_layers': 40,
152
+ 'num_attention_heads': 40,
153
+ 'num_key_value_heads': 40,
154
+ 'max_sequence_length': 2048,
155
+ 'max_position_embeddings': 32768,
156
+ 'rope_theta': 1000000.0,
157
+ 'initializer_range': 0.02,
158
+ 'rms_norm_eps': 1e-6,
159
+ "qkv_bias": True,
160
+ 'tie_word_embeddings': False,
161
+ 'hidden_act': 'silu',
162
+ 'norm_module': 'RMSNorm',
163
+ "tokenizer": "hf-Qwen/Qwen1.5-14B",
164
+ },
165
+ "qwen1.5_32b": {
166
+ "vocab_size": 152064,
167
+ "hidden_size": 5120,
168
+ "intermediate_size": 27392,
169
+ "num_hidden_layers": 64,
170
+ "num_attention_heads": 40,
171
+ "num_key_value_heads": 8,
172
+ 'max_sequence_length': 2048,
173
+ 'max_position_embeddings': 32768,
174
+ "rope_theta": 1000000.0,
175
+ 'initializer_range': 0.02,
176
+ "rms_norm_eps": 1e-6,
177
+ "qkv_bias": True,
178
+ "tie_word_embeddings": False,
179
+ 'hidden_act': 'silu',
180
+ 'norm_module': 'RMSNorm',
181
+ "tokenizer": "hf-Qwen/Qwen1.5-32B",
182
+ },
183
+ 'llama_7b': {
184
+ 'vocab_size': 32000,
185
+ 'hidden_size': 4096,
186
+ 'intermediate_size': 11008,
187
+ 'num_hidden_layers': 32,
188
+ 'num_attention_heads': 32,
189
+ 'num_key_value_heads': 32,
190
+ 'max_sequence_length': 2048,
191
+ 'max_position_embeddings': 8192,
192
+ 'rope_theta': 10000.0,
193
+ 'initializer_range': 0.02,
194
+ 'rms_norm_eps': 1e-5,
195
+ 'tie_word_embeddings': False,
196
+ 'hidden_act': 'silu',
197
+ 'norm_module': 'RMSNorm',
198
+ "tokenizer": "llama"
199
+ },
200
+ 'yi_6b': {
201
+ 'vocab_size': 64000,
202
+ 'hidden_size': 4096,
203
+ 'intermediate_size': 11008,
204
+ 'num_hidden_layers': 32,
205
+ 'num_attention_heads': 32,
206
+ 'num_key_value_heads': 4,
207
+ 'max_sequence_length': 4096,
208
+ 'max_position_embeddings': 4096,
209
+ 'rope_theta': 5000000.0,
210
+ 'initializer_range': 0.02,
211
+ 'rms_norm_eps': 1e-5,
212
+ 'tie_word_embeddings': False,
213
+ 'hidden_act': 'silu',
214
+ 'norm_module': 'RMSNorm',
215
+ "tokenizer": "yi"
216
+ },
217
+ 'yi_9b': {
218
+ 'vocab_size': 64000,
219
+ 'hidden_size': 4096,
220
+ 'intermediate_size': 11008,
221
+ 'num_hidden_layers': 48,
222
+ 'num_attention_heads': 32,
223
+ 'num_key_value_heads': 4,
224
+ 'max_sequence_length': 4096,
225
+ 'max_position_embeddings': 4096,
226
+ 'rope_theta': 10000,
227
+ 'initializer_range': 0.02,
228
+ 'rms_norm_eps': 1e-06,
229
+ 'tie_word_embeddings': False,
230
+ 'hidden_act': 'silu',
231
+ 'norm_module': 'RMSNorm',
232
+ "tokenizer": "yi"
233
+ },
234
+ 'yi_34b': {
235
+ 'vocab_size': 64000,
236
+ 'hidden_size': 7168,
237
+ 'intermediate_size': 20480,
238
+ 'num_hidden_layers': 60,
239
+ 'num_attention_heads': 56,
240
+ 'num_key_value_heads': 8,
241
+ 'max_sequence_length': 4096,
242
+ 'max_position_embeddings': 4096,
243
+ 'rope_theta': 5000000.0,
244
+ 'initializer_range': 0.02,
245
+ 'rms_norm_eps': 1e-5,
246
+ 'tie_word_embeddings': False,
247
+ 'hidden_act': 'silu',
248
+ 'norm_module': 'RMSNorm',
249
+ "tokenizer": "yi"
250
+ },
251
+ "olmo_1b": {
252
+ 'vocab_size': 50304,
253
+ 'hidden_size': 2048,
254
+ 'intermediate_size': 8192,
255
+ 'num_hidden_layers': 16,
256
+ 'num_attention_heads': 16,
257
+ 'num_key_value_heads': 16,
258
+ 'max_sequence_length': 4096,
259
+ 'max_position_embeddings': 32768,
260
+ 'rope_theta': 10000.0,
261
+ 'initializer_range': 0.02,
262
+ 'rms_norm_eps': 1e-5,
263
+ 'tie_word_embeddings': True,
264
+ 'hidden_act': 'silu',
265
+ 'norm_module': 'OlmoLayerNorm',
266
+ "tokenizer": "hf-allenai/OLMo-1B"
267
+ },
268
+ "olmo_7b": {
269
+ 'vocab_size': 50304,
270
+ 'hidden_size': 4096,
271
+ 'intermediate_size': 22016//2,
272
+ 'num_hidden_layers': 32,
273
+ 'num_attention_heads': 32,
274
+ 'num_key_value_heads': 32,
275
+ 'max_sequence_length': 4096,
276
+ 'max_position_embeddings': 32768,
277
+ 'rope_theta': 10000.0,
278
+ 'initializer_range': 0.02,
279
+ 'rms_norm_eps': 1e-5,
280
+ 'tie_word_embeddings': False,
281
+ 'hidden_act': 'silu',
282
+ 'norm_module': 'OlmoLayerNorm',
283
+ "tokenizer": "hf-allenai/OLMo-7B",
284
+ },
285
+ "olmo_1.7_7b": {
286
+ 'vocab_size': 50304,
287
+ 'hidden_size': 4096,
288
+ 'intermediate_size': 22016//2,
289
+ 'num_hidden_layers': 32,
290
+ 'num_attention_heads': 32,
291
+ 'num_key_value_heads': 32,
292
+ 'max_sequence_length': 4096,
293
+ 'max_position_embeddings': 32768,
294
+ 'rope_theta': 10000.0,
295
+ 'initializer_range': 0.02,
296
+ 'rms_norm_eps': 1e-5,
297
+ 'tie_word_embeddings': False,
298
+ 'hidden_act': 'silu',
299
+ "qkv_clip": 8,
300
+ 'norm_module': 'OlmoLayerNorm',
301
+ "tokenizer": "hf-allenai/OLMo-1.7-7B",
302
+ },
303
+ 'mistral_7b': {
304
+ 'vocab_size': 32000,
305
+ 'hidden_size': 4096,
306
+ 'intermediate_size': 14336,
307
+ 'num_hidden_layers': 32,
308
+ 'num_attention_heads': 32,
309
+ 'num_key_value_heads': 8,
310
+ 'max_sequence_length': 4096,
311
+ 'max_position_embeddings': 32768,
312
+ 'rope_theta': 10000.0,
313
+ 'initializer_range': 0.02,
314
+ 'rms_norm_eps': 1e-5,
315
+ 'tie_word_embeddings': False,
316
+ 'hidden_act': 'silu',
317
+ 'norm_module': 'RMSNorm',
318
+ "tokenizer": "mistral"
319
+ },
320
+ 'mistral0.3_7b': {
321
+ 'vocab_size': 32768,
322
+ 'hidden_size': 4096,
323
+ 'intermediate_size': 14336,
324
+ 'num_hidden_layers': 32,
325
+ 'num_attention_heads': 32,
326
+ 'num_key_value_heads': 8,
327
+ 'max_sequence_length': 4096,
328
+ 'max_position_embeddings': 32768,
329
+ 'rope_theta': 1000000.0,
330
+ 'initializer_range': 0.02,
331
+ 'rms_norm_eps': 1e-5,
332
+ 'tie_word_embeddings': False,
333
+ 'hidden_act': 'silu',
334
+ 'norm_module': 'RMSNorm',
335
+ "tokenizer": "mistral0.3"
336
+ },
337
+ "mistral0.2_22b": {
338
+ 'vocab_size': 32000,
339
+ 'hidden_size': 6144,
340
+ 'intermediate_size': 16384,
341
+ 'num_hidden_layers': 56,
342
+ 'num_attention_heads': 48,
343
+ 'num_key_value_heads': 8,
344
+ 'max_sequence_length': 4096,
345
+ 'max_position_embeddings': 32768,
346
+ 'rope_theta': 1000000,
347
+ 'initializer_range': 0.02,
348
+ 'rms_norm_eps': 1e-5,
349
+ 'tie_word_embeddings': False,
350
+ 'hidden_act': 'silu',
351
+ 'norm_module': 'RMSNorm',
352
+ "tokenizer": "mistral"
353
+ },
354
+ 'llama_13b': {
355
+ 'vocab_size': 32000,
356
+ 'hidden_size': 5120,
357
+ 'intermediate_size': 13824,
358
+ 'num_hidden_layers': 40,
359
+ 'num_attention_heads': 40,
360
+ 'num_key_value_heads': 40,
361
+ 'max_sequence_length': 2048,
362
+ 'max_position_embeddings': 8192,
363
+ 'initializer_range': 0.02,
364
+ 'rms_norm_eps': 1e-5,
365
+ 'tie_word_embeddings': False,
366
+ 'hidden_act': 'silu',
367
+ "norm_module": 'RMSNorm',
368
+ 'rope_theta': 10000.0,
369
+ "tokenizer": "llama"
370
+ },
371
+ 'llama_70b': {
372
+ 'vocab_size': 32000,
373
+ 'hidden_size': 8192,
374
+ 'intermediate_size': 28672,
375
+ 'num_hidden_layers': 80,
376
+ 'num_attention_heads': 64,
377
+ 'num_key_value_heads': 8,
378
+ 'max_sequence_length': 8192,
379
+ 'max_position_embeddings': 8192,
380
+ 'rope_theta': 10000.0,
381
+ 'initializer_range': 0.02,
382
+ 'rms_norm_eps': 1e-5,
383
+ 'tie_word_embeddings': False,
384
+ 'hidden_act': 'silu',
385
+ "tokenizer": "llama"
386
+ },
387
+ 'llama_70bflash': {
388
+ 'vocab_size': 32000,
389
+ 'hidden_size': 8192,
390
+ 'intermediate_size': 28672,
391
+ 'num_hidden_layers': 80,
392
+ 'num_attention_heads': 64,
393
+ 'num_key_value_heads': 8,
394
+ 'max_sequence_length': 8192,
395
+ 'max_position_embeddings': 8192,
396
+ 'rope_theta': 10000.0,
397
+ 'initializer_range': 0.02,
398
+ 'rms_norm_eps': 1e-5,
399
+ 'tie_word_embeddings': False,
400
+ 'scan_attention': True,
401
+ 'scan_mlp': True,
402
+ 'hidden_act': 'silu',
403
+ "tokenizer": "llama"
404
+ },
405
+ 'llama3_8b': {
406
+ 'vocab_size': 128256,
407
+ 'hidden_size': 4096,
408
+ 'intermediate_size': 14336,
409
+ 'num_hidden_layers': 32,
410
+ 'num_attention_heads': 32,
411
+ 'num_key_value_heads': 8,
412
+ 'max_sequence_length': 8192,
413
+ 'max_position_embeddings': 8192,
414
+ 'rope_theta': 500000.0,
415
+ 'initializer_range': 0.02,
416
+ 'rms_norm_eps': 1e-5,
417
+ 'tie_word_embeddings': False,
418
+ 'hidden_act': 'silu',
419
+ 'norm_module': 'RMSNorm',
420
+ "tokenizer": "hf-meta-llama/Meta-Llama-3-8B",
421
+
422
+ },
423
+ 'llama3_70b': {
424
+ 'vocab_size': 128256,
425
+ 'hidden_size': 8192,
426
+ 'intermediate_size': 28672,
427
+ 'num_hidden_layers': 80,
428
+ 'num_attention_heads': 64,
429
+ 'num_key_value_heads': 8,
430
+ 'max_sequence_length': 8192,
431
+ 'max_position_embeddings': 8192,
432
+ 'rope_theta': 500000.0,
433
+ 'initializer_range': 0.02,
434
+ 'rms_norm_eps': 1e-5,
435
+ 'tie_word_embeddings': False,
436
+ 'hidden_act': 'silu',
437
+ 'norm_module': 'RMSNorm',
438
+ "tokenizer": "hf-meta-llama/Meta-Llama-3-70B",
439
+ },
440
+ 'open_llama_3b': {
441
+ 'vocab_size': 32000,
442
+ 'hidden_size': 3200,
443
+ 'intermediate_size': 8640,
444
+ 'num_hidden_layers': 26,
445
+ 'num_attention_heads': 32,
446
+ 'max_sequence_length': 2048,
447
+ 'initializer_range': 0.02,
448
+ 'rms_norm_eps': 1e-6,
449
+ 'max_position_embeddings': 2048,
450
+ 'num_key_value_heads': 32,
451
+ 'rope_theta': 10000.0,
452
+ 'tie_word_embeddings': False,
453
+ 'hidden_act': 'silu',
454
+ 'norm_module': 'RMSNorm',
455
+ "tokenizer": "llama"
456
+ },
457
+ 'gemma_2b': {
458
+ 'vocab_size': 256000,
459
+ 'hidden_size': 2048,
460
+ 'intermediate_size': 16384,
461
+ 'num_hidden_layers': 18,
462
+ 'num_attention_heads': 8,
463
+ 'max_sequence_length': 8192,
464
+ 'initializer_range': 0.02,
465
+ 'rms_norm_eps': 1e-6,
466
+ 'max_position_embeddings': 8192,
467
+ 'num_key_value_heads': 1,
468
+ 'rope_theta': 10000.0,
469
+ 'tie_word_embeddings': True,
470
+ 'normalize_input_embeds': True,
471
+ 'norm_module': 'GemmaRMSNorm',
472
+ 'hidden_act': 'gelu',
473
+ "tokenizer": "gemma"
474
+ },
475
+ 'gemma_7b': {
476
+ 'vocab_size': 256000,
477
+ 'hidden_size': 3072,
478
+ 'intermediate_size': 24576,
479
+ 'num_hidden_layers': 28,
480
+ 'num_attention_heads': 16,
481
+ 'max_sequence_length': 8192,
482
+ 'initializer_range': 0.02,
483
+ 'rms_norm_eps': 1e-6,
484
+ 'max_position_embeddings': 8192,
485
+ 'num_key_value_heads': 16,
486
+ 'rope_theta': 10000.0,
487
+ 'tie_word_embeddings': True,
488
+ 'normalize_input_embeds': True,
489
+ 'norm_module': 'GemmaRMSNorm',
490
+ 'hidden_act': 'gelu',
491
+ "tokenizer": "gemma"
492
+ },
493
+ 'tiny_llama_1b': {
494
+ 'vocab_size': 32000,
495
+ 'hidden_size': 2048,
496
+ 'intermediate_size': 5632,
497
+ 'num_hidden_layers': 22,
498
+ 'num_attention_heads': 32,
499
+ 'max_sequence_length': 2048,
500
+ 'initializer_range': 0.02,
501
+ 'rms_norm_eps': 1e-5,
502
+ 'max_position_embeddings': 2048,
503
+ 'num_key_value_heads': 4,
504
+ 'rope_theta': 10000.0,
505
+ 'tie_word_embeddings': False,
506
+ 'hidden_act': 'silu',
507
+ 'norm_module': 'RMSNorm',
508
+ "tokenizer": "llama"
509
+ },
510
+ 'debug': { # A small model for debugging
511
+ 'vocab_size': 32000,
512
+ 'hidden_size': 512,
513
+ 'intermediate_size': 512,
514
+ 'num_hidden_layers': 1,
515
+ 'num_attention_heads': 8,
516
+ 'max_sequence_length': 4096,
517
+ 'initializer_range': 0.02,
518
+ 'rms_norm_eps': 1e-5,
519
+ 'max_position_embeddings': 4096,
520
+ 'num_key_value_heads': 8,
521
+ 'rope_theta': 10000.0,
522
+ 'tie_word_embeddings': False,
523
+ 'hidden_act': 'silu',
524
+ 'norm_module': 'RMSNorm',
525
+ "tokenizer": "llama"
526
+ },
527
+ 'gemma2_9b': {
528
+ 'vocab_size': 256000,
529
+ 'hidden_size': 3584,
530
+ 'head_dim': 256,
531
+ 'intermediate_size': 14336,
532
+ 'num_hidden_layers': 42,
533
+ 'num_attention_heads': 16,
534
+ 'max_sequence_length': 8192,
535
+ "query_pre_attn_scalar": 224,
536
+ 'initializer_range': 0.02,
537
+ 'rms_norm_eps': 1e-6,
538
+ 'max_position_embeddings': 8192,
539
+ 'num_key_value_heads': 8,
540
+ 'rope_theta': 10000.0,
541
+ 'tie_word_embeddings': False,
542
+ 'normalize_input_embeds': True,
543
+ 'norm_module': 'GemmaRMSNorm',
544
+ 'hidden_act': 'gelu_tanh',
545
+ "tokenizer": "hf-google/gemma-2-9b",
546
+ "attn_logit_softcapping": 50.0,
547
+ "final_logit_softcapping": 30.0,
548
+ },
549
+ 'gemma2_27b': {
550
+ 'vocab_size': 256000,
551
+ 'hidden_size': 4608,
552
+ 'head_dim': 128,
553
+ 'intermediate_size': 36864,
554
+ 'num_hidden_layers': 46,
555
+ 'num_attention_heads': 32,
556
+ 'max_sequence_length': 8192,
557
+ "query_pre_attn_scalar": 144,
558
+ 'initializer_range': 0.02,
559
+ 'rms_norm_eps': 1e-6,
560
+ 'max_position_embeddings': 8192,
561
+ 'num_key_value_heads': 16,
562
+ 'rope_theta': 10000.0,
563
+ 'tie_word_embeddings': False,
564
+ 'normalize_input_embeds': True,
565
+ 'norm_module': 'GemmaRMSNorm',
566
+ 'hidden_act': 'gelu_tanh',
567
+ "tokenizer": "hf-google/gemma-2-27b",
568
+ "attn_logit_softcapping": 50.0,
569
+ "final_logit_softcapping": 30.0,
570
+ },
571
+ }
data_factory.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Dataset factory to load data from huggingface and others.
3
+ '''
4
+ import dataclasses
5
+ import logging
6
+ from typing import List, Optional
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+
11
+ from .data_utils import add_segment_ids
12
+ from .dataset_sizes import get_dataset_size
13
+ from .tasks import get_task
14
+ from .multimodal_preprocessor import MultiModalPreprocessor
15
+ import seqio
16
+
17
+ from .torch_util import get_global_rank
18
+
19
+ log = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class SeqioDataset:
24
+ mixture_or_task_name: str
25
+ seq_len: int
26
+ global_batch_size: int
27
+ max_crops: int = None
28
+ is_training: bool = False
29
+ for_inference: bool = False
30
+ split: str = 'train'
31
+ shuffle: bool = True
32
+ num_epochs: int = None
33
+ drop_remainder: bool = True
34
+ seed: int = None
35
+ pack: bool = False
36
+ use_custom_packing_ops: bool = False
37
+ use_memory_cache: bool = False
38
+ shuffle_buffer_size: Optional[int] = None
39
+ different_host_mixture_seeds: bool = True
40
+ disable_autotune: bool = True
41
+ trim_output_features: bool = True
42
+
43
+ @classmethod
44
+ def from_dict(cls, data):
45
+ return cls(**data)
46
+
47
+ def get_task_feature_lengths_dict(self, max_crops):
48
+ if self.max_crops is not None:
49
+ assert self.max_crops >= max_crops
50
+ max_crops = self.max_crops
51
+ return dict(
52
+ target_tokens=self.seq_len,
53
+ loss_masks=self.seq_len,
54
+ images=max_crops,
55
+ image_positions=max_crops,
56
+ image_input_idx=max_crops,
57
+ is_training=self.is_training
58
+ )
59
+
60
+ def build(self, preprocessor: MultiModalPreprocessor, shard_id, num_shards):
61
+ shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)
62
+ task_feature_lengths_dict = self.get_task_feature_lengths_dict(
63
+ preprocessor.get_max_total_crops())
64
+
65
+ seed = self.seed
66
+ assert seed is not None
67
+
68
+ batch_size = self.global_batch_size // num_shards
69
+
70
+ if isinstance(self.mixture_or_task_name, (dict, list, tuple)):
71
+ if isinstance(self.mixture_or_task_name, dict):
72
+ items = self.mixture_or_task_name.items()
73
+ else:
74
+ items = self.mixture_or_task_name
75
+ task_list = []
76
+ for task, weight in items:
77
+ task = get_task(preprocessor, task, self.is_training, self.for_inference)
78
+ task_list.append((task, weight))
79
+ mixture_or_task = task_list
80
+ else:
81
+ mixture_or_task = get_task(
82
+ preprocessor, self.mixture_or_task_name, self.is_training, self.for_inference)
83
+
84
+ in_memory_shuffle = self.shuffle
85
+ if not self.drop_remainder:
86
+ # Used if we want to evaluate on an eval dataset without dropping any examples.
87
+ # To do this, we pad the dataset with dummy examples marked as invalid in their
88
+ # metadata so we can still get fixed-sized batches.
89
+ assert self.num_epochs is not None
90
+ assert not self.pack
91
+ assert not isinstance(mixture_or_task, list), "Inference datasets cannot be mixtures"
92
+ logging.info(
93
+ f"Initializing inf. dataset {mixture_or_task.name}: replica_batch_size={batch_size}"
94
+ f' seed={seed}, sharding={shard_info.index}/{shard_info.num_shards}'
95
+ )
96
+ ds = mixture_or_task.get_dataset(
97
+ sequence_length=task_feature_lengths_dict,
98
+ split=self.split,
99
+ shuffle=in_memory_shuffle,
100
+ num_epochs=self.num_epochs,
101
+ seed=seed,
102
+ try_in_mem_cache=self.use_memory_cache,
103
+ trim_output_features=self.trim_output_features
104
+ )
105
+
106
+ try:
107
+ n = len(ds)
108
+ except TypeError:
109
+ dataset_len = get_dataset_size(self.mixture_or_task_name, self.split)
110
+ logging.info(f"Setting dataset len to {dataset_len} based on DATASET_SIZES")
111
+ n = dataset_len
112
+ ds = tf.data.experimental.assert_cardinality(n)(ds)
113
+
114
+ remainder = n % self.global_batch_size
115
+ if remainder > 0:
116
+ n_to_pad = self.global_batch_size - remainder
117
+ else:
118
+ n_to_pad = 0
119
+ assert "metadata/valid" not in ds.element_spec
120
+ def add_valid(x):
121
+ x["metadata/valid"] = True
122
+ return x
123
+ def add_invalid(x):
124
+ x["metadata/valid"] = False
125
+ return x
126
+ ds = ds.map(add_valid)
127
+ if n_to_pad > 0:
128
+ to_pad = ds.take(1).map(add_invalid).cache().repeat(n_to_pad)
129
+ ds = ds.concatenate(to_pad)
130
+
131
+ # Shard after padding to ensure shards are the same length
132
+ ds = ds.shard(num_shards=num_shards, index=shard_id)
133
+
134
+ ds = preprocessor.get_post_mixing_preprocessor()(
135
+ ds, task_feature_lengths=task_feature_lengths_dict)
136
+ data_iter = ds.batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.experimental.AUTOTUNE)
137
+ # Make it possible for client to get the size of the batched/sharded dataset with `len()`
138
+ new_len = (n + n_to_pad) // self.global_batch_size
139
+ data_iter = tf.data.experimental.assert_cardinality(new_len)(data_iter)
140
+ else:
141
+ if isinstance(mixture_or_task, list):
142
+ total_rate = sum(x[1] for x in mixture_or_task)
143
+ mixture_or_task = [(task, r/total_rate) for task, r in mixture_or_task]
144
+ sorted_tasks: List[seqio.Task] = sorted(mixture_or_task, key=lambda x: -x[1])
145
+
146
+ if self.different_host_mixture_seeds and shard_info:
147
+ # If each process has the same seed they will draw from the datasets in the same
148
+ # order, which can make the global batches very non-random if there are
149
+ # many processes each with a small batch size. To fix this, we give each host
150
+ # a different seed based on its rank to use when mixing
151
+ mix_seed = seed + shard_info.index*4397
152
+ else:
153
+ mix_seed = seed
154
+
155
+ logging.info(
156
+ f"Initializing mixture: replica_batch_size={batch_size} seed={seed}, "
157
+ f"mix_seed={mix_seed}, sharding={shard_info.index}/{shard_info.num_shards} rates:"
158
+ )
159
+ for task, rate in sorted_tasks:
160
+ logging.info(f"\t{task.name}: {rate:0.4f}")
161
+
162
+ datasets = []
163
+ rates = []
164
+ for task, rate in sorted_tasks:
165
+ assert rate > 0
166
+ datasets.append(task.get_dataset(
167
+ task_feature_lengths_dict,
168
+ split=self.split,
169
+ shuffle=self.shuffle,
170
+ seed=seed,
171
+ shard_info=shard_info,
172
+ num_epochs=self.num_epochs,
173
+ try_in_mem_cache=self.use_memory_cache,
174
+ trim_output_features=self.trim_output_features
175
+ ))
176
+ rates.append(rate)
177
+
178
+ # If any of the sub-tasks have subsegment_ids, we need to ensure all the tasks have
179
+ # a subsegment_ids field so they can be mixed
180
+ if any("subsegment_ids" in ds.element_spec for ds in datasets):
181
+ for ix, ds in enumerate(datasets):
182
+ if "subsegment_ids" not in ds.element_spec:
183
+ datasets[ix] = add_segment_ids(ds)
184
+
185
+ ds = tf.data.Dataset.sample_from_datasets(
186
+ datasets, rates, seed=mix_seed, stop_on_empty_dataset=False)
187
+ else:
188
+ logging.info(
189
+ f"Initializing dataset {mixture_or_task.name}: replica_batch_size={batch_size}"
190
+ f' seed={seed}, sharding={shard_info.index}/{shard_info.num_shards}'
191
+ )
192
+ ds = mixture_or_task.get_dataset(
193
+ task_feature_lengths_dict,
194
+ split=self.split,
195
+ shuffle=self.shuffle,
196
+ seed=seed,
197
+ shard_info=shard_info,
198
+ num_epochs=self.num_epochs,
199
+ try_in_mem_cache=self.use_memory_cache,
200
+ trim_output_features=self.trim_output_features
201
+ )
202
+ data_iter = preprocessor.get_post_mixing_preprocessor()(
203
+ ds, task_feature_lengths=task_feature_lengths_dict)
204
+ data_iter = data_iter.batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.experimental.AUTOTUNE)
205
+ ds = ds.prefetch(2)
206
+
207
+ # Following https://github.com/google-research/big_vision/blob/b8dab6e4de3436849415f37c591399c93b1eaf39/big_vision/input_pipeline.py#L228
208
+ # These options try to stop tf datasets from eating all our RAM if we are using a
209
+ # large mixture
210
+ # This options are used by default in some google codebases
211
+ # For example: (https://github.com/google-research/big_vision/blob/b8dab6e4de3436849415f37c591399c93b1eaf39/big_vision/input_pipeline.py#L228)
212
+ # They don't seem to harm throughput and can save RAM so we use them as well
213
+ options = tf.data.Options()
214
+ options.experimental_optimization.inject_prefetch = False
215
+ options.threading.max_intra_op_parallelism = 1
216
+ if self.disable_autotune:
217
+ # Following https://www.tensorflow.org/datasets/performances
218
+ # This reduces RAM and checkpoint size by a lot
219
+ options.autotune.enabled = False
220
+ data_iter = data_iter.with_options(options)
221
+
222
+ return data_iter
data_utils.py ADDED
@@ -0,0 +1,827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import dataclasses
3
+ import functools
4
+ import os
5
+ from os import environ
6
+ from typing import Mapping, Optional, Sequence, List
7
+ from absl import logging
8
+ import clu
9
+ import gin
10
+ from pathlib import Path
11
+
12
+ import seqio
13
+ from seqio import utils
14
+ from seqio.feature_converters import _check_exact_match, _check_lengths
15
+
16
+ import tensorflow as tf
17
+ from tensorflow.python.ops import control_flow_ops
18
+ from tensorflow.python.ops.image_ops_impl import _ImageDimensions, _CheckAtLeast3DImage, _assert, _is_tensor
19
+
20
+ from tensorflow.python.framework import ops
21
+ from tensorflow.python.ops import array_ops
22
+ from transformers import PreTrainedTokenizerFast
23
+
24
+ from . import seqio_tokenizer as vocab
25
+ from .constants import *
26
+ from .utils import pop_metadata
27
+ from .util import is_url
28
+
29
+ DEFAULT_EXTRA_IDS = 0
30
+ OutputFeaturesType = Mapping[str, utils.Feature]
31
+
32
+
33
+ def build_tokenizer(
34
+ tokenizer_type, has_extra_token=True,
35
+ adds_space=False,
36
+ olmo_bos_token_id=1, olmo_eos_token_id=2,
37
+ tokenizer_dir="gs://mm-olmo/tokenizer",
38
+ pad_tokenizer_to=None, cache={},
39
+ ):
40
+ cache_key = (tokenizer_type, has_extra_token, adds_space, olmo_bos_token_id,
41
+ olmo_eos_token_id, pad_tokenizer_to)
42
+ if cache_key in cache:
43
+ return cache[cache_key]
44
+
45
+ if tokenizer_type == 'llama':
46
+ tok = vocab.SentencePieceVocabulary(
47
+ os.path.join(tokenizer_dir, "llama_tokenizer.model"),
48
+ extra_ids=DEFAULT_EXTRA_IDS,
49
+ reverse_extra_ids=True,
50
+ extra_tokens=EXTRA_TOKENS if has_extra_token else None,
51
+ )
52
+ elif tokenizer_type == 'yi':
53
+ tok = vocab.SentencePieceVocabulary(
54
+ os.path.join(tokenizer_dir, "yi_tokenizer.model"),
55
+ extra_ids=DEFAULT_EXTRA_IDS,
56
+ reverse_extra_ids=True,
57
+ extra_tokens=EXTRA_TOKENS if has_extra_token else None,
58
+ )
59
+ elif tokenizer_type == 'mistral':
60
+ tok = vocab.SentencePieceVocabulary(
61
+ os.path.join(tokenizer_dir, "mistral_tokenizer.model"),
62
+ extra_ids=DEFAULT_EXTRA_IDS,
63
+ reverse_extra_ids=True,
64
+ extra_tokens=EXTRA_TOKENS if has_extra_token else None,
65
+ )
66
+
67
+ elif tokenizer_type == "mistral0.3":
68
+ tok = vocab.SentencePieceVocabulary(
69
+ os.path.join(tokenizer_dir, "mistral0.3_tokenizer.model.v3"),
70
+ extra_ids=DEFAULT_EXTRA_IDS,
71
+ reverse_extra_ids=True,
72
+ extra_tokens=EXTRA_TOKENS if has_extra_token else None,
73
+ )
74
+ elif tokenizer_type == 'gemma':
75
+ tok = vocab.SentencePieceVocabulary(
76
+ os.path.join(tokenizer_dir, "gemma_tokenizer.model"),
77
+ extra_ids=DEFAULT_EXTRA_IDS,
78
+ reverse_extra_ids=True,
79
+ extra_tokens=EXTRA_TOKENS if has_extra_token else None,
80
+ )
81
+ elif tokenizer_type.startswith("hf-"):
82
+ # FIXME When using the beaker image "sanghol/mm-olmo" for hosting endpoints,
83
+ # we should set the cache_dir, otherwise FileNotFound errors will be raised
84
+ cache_dir = None if tokenizer_dir is None or is_url(tokenizer_dir) else tokenizer_dir
85
+ from transformers import AutoTokenizer
86
+
87
+ extra_tokens = list(EXTRA_TOKENS)
88
+ if pad_tokenizer_to is not None:
89
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_type[3:], token=environ.get("HF_ACCESS_TOKEN"), cache_dir=cache_dir)
90
+ n_extra_tokens = pad_tokenizer_to - len(tokenizer)
91
+ # This handles a case where the LLM embedding matrix is larger than the vocab size
92
+ # We need the extra tokens in `EXTRA_TOKENS` to be assigned id's higher than the embedding
93
+ # matrix size, not the vocab size, since we will concat the embedding and matrix with
94
+ # the special token embedding matrix, so we pad the vocab with additional special tokens
95
+ if n_extra_tokens > 0:
96
+ logging.info(f"Padding tokenizer with {n_extra_tokens} tokens")
97
+ extra_tokens = [f"|<EXTRA_TOKENS_{i}>|" for i in range(n_extra_tokens)] + extra_tokens
98
+
99
+ bos_token_id = None
100
+
101
+ tokenizer = AutoTokenizer.from_pretrained(
102
+ tokenizer_type[3:], additional_special_tokens=extra_tokens,
103
+ token=environ.get("HF_ACCESS_TOKEN"),
104
+ cache_dir=cache_dir,
105
+ )
106
+ if ("qwen2" in tokenizer_type.lower()) or ("olmo" in tokenizer_type.lower()):
107
+ # These tokenizers do not have a BOS, and instead use EOS as a generic seperator token.
108
+ # In this case we will use EOS as BOS
109
+ assert tokenizer.bos_token_id is None
110
+ bos_token_id = tokenizer.eos_token_id
111
+
112
+ if pad_tokenizer_to is not None:
113
+ for ix, tok in enumerate(EXTRA_TOKENS):
114
+ ids = tokenizer.encode(tok, add_special_tokens=False)
115
+ assert ids == [pad_tokenizer_to + ix]
116
+
117
+ tok = vocab.HfTokenizerWrapper(tokenizer, bos_token_id=bos_token_id, adds_space=adds_space)
118
+ elif tokenizer_type.startswith("olmo-"):
119
+ from olmo.tokenizer import Tokenizer
120
+ assert Path(tokenizer_type[5:]).is_file()
121
+ tokenizer = Tokenizer.from_file(
122
+ tokenizer_type[5:],
123
+ eos_token_id=olmo_eos_token_id,
124
+ pad_token_id=-1,
125
+ )
126
+ tok = vocab.OLMoTokenizerWrapper(tokenizer, bos_token_id=olmo_bos_token_id, adds_space=adds_space)
127
+ else:
128
+ raise NotImplementedError(tokenizer_type)
129
+ cache[cache_key] = tok
130
+ return tok
131
+
132
+
133
+ def get_special_token_ids(tokenizer):
134
+ if isinstance(tokenizer, (vocab.HfTokenizerWrapper, vocab.OLMoTokenizerWrapper)):
135
+ ids = tokenizer.encode("".join(EXTRA_TOKENS))
136
+ if len(ids) == len(EXTRA_TOKENS) + 1:
137
+ ids = ids[1:]
138
+ elif ("gemma_tokenizer" in tokenizer._sentencepiece_model_file or
139
+ "yi_tokenizer" in tokenizer._sentencepiece_model_file
140
+ ):
141
+ # Not sure why ATM, but the LLaMa tokenizer will add an extra space token
142
+ # if this string starts with a space, while the gemma one needs the leading space
143
+ ids = tokenizer.encode(" " + " ".join(EXTRA_TOKENS))
144
+ else:
145
+ ids = tokenizer.encode(" ".join(EXTRA_TOKENS))
146
+
147
+ assert len(ids) == len(EXTRA_TOKENS)
148
+ return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
149
+
150
+
151
+ def _append_to_innermost_axis(
152
+ tensor: tf.Tensor, scalar: tf.Tensor,
153
+ ) -> tf.Tensor:
154
+ """Appends `scalar` to each slice in the innermost axis of `tensor`.
155
+
156
+ >>> _append_to_innermost_axis([1, 2, 3], -1)
157
+ [1, 2, 3, -1]
158
+ >>> _append_to_innermost_axis([[1, 2], [3, 4]], -1)
159
+ [[1, 2, -1], [3, 4, -1]]
160
+ >>> _append_to_innermost_axis(tf.ragged.constant([[1, 2], [3]]), -1)
161
+ [[1, 2, -1], [3, -1]]
162
+
163
+ Args:
164
+ tensor: The tensor that should have a value appended.
165
+ scalar: The value to append.
166
+
167
+ Returns:
168
+ A copy of `tensor` with `scalar` appended to each slice along
169
+ the innermost axis.
170
+ """
171
+ if isinstance(tensor, tf.RaggedTensor):
172
+ if tensor.shape.rank > 2:
173
+ return tensor.with_values(
174
+ _append_to_innermost_axis(tensor.values, scalar)
175
+ )
176
+ else:
177
+ return tf.concat([tensor, tf.fill([tensor.nrows(), 1], scalar)], axis=1)
178
+ else:
179
+ ndims = tf.rank(tensor)
180
+ paddings = tf.concat(
181
+ [tf.zeros((ndims - 1, 2), dtype=tf.int32), tf.constant([[0, 1]])],
182
+ axis=0,
183
+ )
184
+ return tf.pad(tensor, paddings=paddings, constant_values=scalar)
185
+
186
+
187
+ def _shift_right_by_one(tensor: tf.Tensor, bos_id: int = 0) -> tf.Tensor:
188
+ """Shift the input tensor to the right by one position without wrapping."""
189
+
190
+ if not (tensor.dtype.is_integer or tensor.dtype.is_floating):
191
+ raise ValueError(f"Only numeric types are supported. Got: {tensor.dtype}")
192
+ # tf.roll wraps around the axis.
193
+ rolled = tf.roll(tensor, shift=1, axis=0)
194
+
195
+ # Zero out the first position by multiplying with [0, 1, 1, ..., 1].
196
+ depth = tf.shape(tensor)[0]
197
+ mask = tf.one_hot(0, depth=depth, on_value=0, off_value=1, dtype=tensor.dtype)
198
+
199
+ # Expand dims of mask to broadcast to rolled.
200
+ dim_expansion = [slice(None, None)] + [None] * (len(rolled.shape) - 1)
201
+ mask = mask[dim_expansion]
202
+ return rolled * mask + (1 - mask) * bos_id
203
+
204
+
205
+ def make_autoregressive_inputs(
206
+ targets: tf.Tensor,
207
+ sequence_id: tf.Tensor = None,
208
+ output_dtype: Optional[tf.dtypes.DType] = None,
209
+ bos_id: int = 0,
210
+ ) -> tf.Tensor:
211
+ """Generate inputs for an autoregressive model, by shifting the targets.
212
+
213
+ Modified from mesh_tensorflow.transformer.transformer.autoregressive_inputs.
214
+
215
+ For the first element of each sequence, the returned input id is 0.
216
+
217
+ For a "packed" dataset, also pass the sequence_id tensor, which aligns
218
+ with the targets tensor and contains different values for different
219
+ concatenated examples.
220
+
221
+ Example for a packed dataset:
222
+
223
+ ```
224
+ targets = [3, 8, 2, 9, 2, 5, 4, 2, -1, -1]
225
+ sequence_id = [1, 1, 1, 2, 2, 3, 3, 3, 0, 0]
226
+ inputs = [1, 3, 8, 1, 9, 1, 5, 4, -1, -1]
227
+ | | |
228
+ These positions are set to 0 if sequence_id is not
229
+ None.
230
+ ```
231
+
232
+ Args:
233
+ targets: a tf.int32 tensor with shape [length].
234
+ sequence_id: an optional tensor with the same shape as targets.
235
+ output_dtype: an optional output data type.
236
+ bos_id: bos id.
237
+
238
+ Returns:
239
+ a tensor with dtype tf.int32 and the same shape as targets.
240
+ """
241
+ output_dtype = output_dtype or targets.dtype
242
+ if sequence_id is not None and not sequence_id.dtype.is_integer:
243
+ raise ValueError(
244
+ "The sequence_id should be integer-valued tensors for a packed dataset."
245
+ )
246
+ if sequence_id is not None and len(targets.shape) > 1:
247
+ raise ValueError(
248
+ "Only 1-D sequences are supported with packing. Got a "
249
+ f"packed {len(targets.shape)}-D sequence."
250
+ )
251
+
252
+ inputs = _shift_right_by_one(targets, bos_id)
253
+ if inputs.dtype != output_dtype:
254
+ inputs = tf.cast(inputs, output_dtype)
255
+
256
+ # We should have a 0 at the beginning of each sequence rather than the
257
+ # shifted EOS (e.g. 1) from the previous sequence.
258
+ if sequence_id is not None:
259
+ not_first_in_sequence = tf.equal(
260
+ sequence_id, _shift_right_by_one(sequence_id)
261
+ )
262
+ not_first_in_sequence = tf.cast(not_first_in_sequence, output_dtype)
263
+ first_ids = tf.cast((1 - not_first_in_sequence) * bos_id, output_dtype)
264
+ inputs = inputs * not_first_in_sequence + first_ids
265
+ return inputs
266
+
267
+
268
+ @tf.function
269
+ def sum_except_first_axis(tensor):
270
+ # Compute the sum along all axes except the first
271
+ axes_to_sum = tuple(range(1, len(tensor.shape)))
272
+ return tf.reduce_sum(tensor, axis=axes_to_sum)
273
+
274
+
275
+ @seqio.map_over_dataset()
276
+ def add_segment_ids(ex):
277
+ ex["subsegment_ids"] = tf.zeros_like(ex["target_tokens"], dtype=tf.int32)
278
+ return ex
279
+
280
+
281
+ def trim_and_pad_dataset(
282
+ dataset: tf.data.Dataset, feature_lengths: Mapping[str, int]
283
+ ) -> tf.data.Dataset:
284
+ """Trim and pad first dimension of features to `feature_lengths`.
285
+
286
+ Args:
287
+ dataset: tf.data.Dataset, the dataset to trim/pad examples in.
288
+ feature_lengths: map from feature key to final length. Other features will
289
+ be returned unchanged.
290
+
291
+ Returns:
292
+ Trimmed/padded tf.data.Dataset.
293
+ """
294
+
295
+ def _trim_and_pad(k: str, t: tf.Tensor) -> tf.Tensor:
296
+ """Trim/pad to the first axis of `t` to be of size `length`."""
297
+ if k not in feature_lengths:
298
+ return t
299
+ if isinstance(t, tf.RaggedTensor):
300
+ t = t.to_tensor()
301
+
302
+ constant_values = -1
303
+ length_k = feature_lengths[k]
304
+ if isinstance(length_k, int):
305
+ t = t[:length_k]
306
+ pad_amt = length_k - tf.shape(t)[0]
307
+ padded_t = tf.pad(t, [(0, pad_amt)] + [(0, 0)] * (len(t.shape) - 1), constant_values=constant_values)
308
+ padded_t.set_shape([length_k] + t.shape.as_list()[1:])
309
+ return padded_t
310
+
311
+ slices = tuple((slice(0, limit) for limit in length_k))
312
+ t = t[slices]
313
+ pad_amt = tf.pad((length_k - tf.shape(t))[..., None], ((0, 0), (1, 0)), constant_values=constant_values)
314
+ padded_t = tf.pad(t, pad_amt, constant_values=constant_values)
315
+ padded_t.set_shape(length_k)
316
+ return padded_t
317
+
318
+ return dataset.map(
319
+ lambda x: {k: _trim_and_pad(k, t) for k, t in x.items()},
320
+ num_parallel_calls=tf.data.experimental.AUTOTUNE,
321
+ )
322
+
323
+
324
+ def get_3d_subsegments(segmented_suffix):
325
+ q_lens, text_lens = segmented_suffix.nested_row_lengths()
326
+ text_segments = tf.range(0, tf.shape(text_lens)[0], dtype=tf.int32)
327
+ question_repeat = tf.reshape(tf.stack([tf.ones_like(q_lens), q_lens-1], 1), [-1])
328
+ question_offset = tf.range(1, tf.shape(q_lens)[0]+1, dtype=tf.int32)*200
329
+ question_offset = tf.reshape(tf.stack([question_offset, question_offset-100], 1), [-1])
330
+ text_segments = text_segments + tf.repeat(question_offset, question_repeat)
331
+ segment_ids = tf.cast(tf.repeat(text_segments, text_lens), tf.int32)
332
+ return segment_ids
333
+
334
+
335
+ def assert_not_truncated(ds, keys, max_val):
336
+ def _check(ex):
337
+ for k in keys:
338
+ tf.assert_less(tf.shape(ex[k])[0], max_val+1,
339
+ message=f"Field {k} was unexpectedly truncated max_len={max_val}")
340
+ return ex
341
+ return ds.map(_check)
342
+
343
+
344
+ def apply_with_random_selector(x, func, num_cases):
345
+ """Computes func(x, sel), with sel sampled from [0...num_cases-1].
346
+ Args:
347
+ x: input Tensor.
348
+ func: Python function to apply.
349
+ num_cases: Python int32, number of cases to sample sel from.
350
+ Returns:
351
+ The result of func(x, sel), where func receives the value of the
352
+ selector as a python integer, but sel is sampled dynamically.
353
+ """
354
+ sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
355
+ # Pass the real x only to one of the func calls.
356
+ return control_flow_ops.merge([
357
+ func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
358
+ for case in range(num_cases)])[0]
359
+
360
+
361
+ def denormalize_boxes(boxes, image_shape):
362
+ """Converts boxes normalized by [height, width] to pixel coordinates.
363
+ Args:
364
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
365
+ boxes in ymin, xmin, ymax, xmax order.
366
+ image_shape: a list of two integers, a two-element vector or a tensor such
367
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
368
+ dimension is 2, which represents [height, width].
369
+ Returns:
370
+ denormalized_boxes: a tensor whose shape is the same as `boxes` representing
371
+ the denormalized boxes.
372
+ Raises:
373
+ ValueError: If the last dimension of boxes is not 4.
374
+ """
375
+ with tf.name_scope('denormalize_boxes'):
376
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
377
+ height, width = image_shape
378
+ height = tf.cast(height, dtype=boxes.dtype)
379
+ width = tf.cast(width, dtype=boxes.dtype)
380
+ else:
381
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
382
+ height, width = tf.split(image_shape, 2, axis=-1)
383
+
384
+ ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1)
385
+ ymin = ymin * height
386
+ xmin = xmin * width
387
+ ymax = ymax * height
388
+ xmax = xmax * width
389
+
390
+ denormalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
391
+ return denormalized_boxes
392
+
393
+ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
394
+ target_width, value=0):
395
+
396
+ return pad_to_bounding_box_internal(
397
+ image,
398
+ offset_height,
399
+ offset_width,
400
+ target_height,
401
+ target_width,
402
+ check_dims=True,
403
+ value=value)
404
+
405
+ def pad_to_bounding_box_internal(image, offset_height, offset_width,
406
+ target_height, target_width, check_dims, value):
407
+
408
+ with ops.name_scope(None, 'pad_to_bounding_box_with_one_internal', [image]):
409
+ image = ops.convert_to_tensor(image, name='image')
410
+
411
+ is_batch = True
412
+ image_shape = image.get_shape()
413
+ if image_shape.ndims == 3:
414
+ is_batch = False
415
+ image = array_ops.expand_dims(image, 0)
416
+ elif image_shape.ndims is None:
417
+ is_batch = False
418
+ image = array_ops.expand_dims(image, 0)
419
+ image.set_shape([None] * 4)
420
+ elif image_shape.ndims != 4:
421
+ raise ValueError(
422
+ '\'image\' (shape %s) must have either 3 or 4 dimensions.' %
423
+ image_shape)
424
+
425
+ batch, height, width, depth = _ImageDimensions(image, rank=4)
426
+
427
+ after_padding_width = target_width - offset_width - width
428
+
429
+ after_padding_height = target_height - offset_height - height
430
+
431
+ if check_dims:
432
+ assert_ops = _CheckAtLeast3DImage(image, require_static=False)
433
+ assert_ops += _assert(offset_height >= 0, ValueError,
434
+ 'offset_height must be >= 0')
435
+ assert_ops += _assert(offset_width >= 0, ValueError,
436
+ 'offset_width must be >= 0')
437
+ assert_ops += _assert(after_padding_width >= 0, ValueError,
438
+ 'width must be <= target - offset')
439
+ assert_ops += _assert(after_padding_height >= 0, ValueError,
440
+ 'height must be <= target - offset')
441
+ image = control_flow_ops.with_dependencies(assert_ops, image)
442
+
443
+ # Do not pad on the depth dimensions.
444
+ paddings = array_ops.reshape(
445
+ tf.stack([
446
+ 0, 0, offset_height, after_padding_height, offset_width,
447
+ after_padding_width, 0, 0
448
+ ]), [4, 2])
449
+ padded = array_ops.pad(image, paddings, constant_values=value)
450
+
451
+ padded_shape = [
452
+ None if _is_tensor(i) else i
453
+ for i in [batch, target_height, target_width, depth]
454
+ ]
455
+ padded.set_shape(padded_shape)
456
+
457
+ if not is_batch:
458
+ padded = array_ops.squeeze(padded, axis=[0])
459
+
460
+ return padded
461
+
462
+ def resize_and_crop_boxes(boxes, image_scale, output_size, offset, paddings):
463
+ """Resizes boxes to output size with scale and offset.
464
+ Args:
465
+ boxes: `Tensor` of shape [N, 4] representing ground truth boxes.
466
+ image_scale: 2D float `Tensor` representing scale factors that apply to
467
+ [height, width] of input image.
468
+ output_size: 2D `Tensor` or `int` representing [height, width] of target
469
+ output image size.
470
+ offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
471
+ boxes.
472
+ paddings: 2D `Tensor` representing top/left paddings.
473
+ Returns:
474
+ boxes: `Tensor` of shape [N, 4] representing the scaled boxes.
475
+ """
476
+ # Adjusts box coordinates based on image_scale, offset and paddings.
477
+ boxes *= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
478
+ boxes -= tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
479
+ boxes += tf.tile(tf.expand_dims(paddings, axis=0), [1, 2])
480
+ # Clips the boxes.
481
+ boxes = clip_boxes(boxes, output_size)
482
+ return boxes
483
+
484
+ def clip_boxes(boxes, image_shape):
485
+ """Clips boxes to image boundaries.
486
+ Args:
487
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
488
+ boxes in ymin, xmin, ymax, xmax order.
489
+ image_shape: a list of two integers, a two-element vector or a tensor such
490
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
491
+ dimension is 2, which represents [height, width].
492
+ Returns:
493
+ clipped_boxes: a tensor whose shape is the same as `boxes` representing the
494
+ clipped boxes.
495
+ Raises:
496
+ ValueError: If the last dimension of boxes is not 4.
497
+ """
498
+ if boxes.shape[-1] != 4:
499
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
500
+ boxes.shape[-1]))
501
+
502
+ with tf.name_scope('clip_boxes'):
503
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
504
+ height, width = image_shape
505
+ max_length = [height, width, height, width]
506
+ else:
507
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
508
+ height, width = tf.unstack(image_shape, axis=-1)
509
+ max_length = tf.stack(
510
+ [height, width, height, width], axis=-1)
511
+
512
+ clipped_boxes = tf.math.maximum(tf.math.minimum(boxes, max_length), 0.0)
513
+ return clipped_boxes
514
+
515
+
516
+ def get_non_empty_box_indices(boxes):
517
+ """Get indices for non-empty boxes."""
518
+ # Selects indices if box height or width is 0.
519
+ height = boxes[:, 2] - boxes[:, 0]
520
+ width = boxes[:, 3] - boxes[:, 1]
521
+ indices = tf.where(
522
+ tf.logical_and(tf.greater(height, 0), tf.greater(width, 0)))
523
+ return indices[:, 0]
524
+
525
+
526
+ def resize_and_pad(image, desired_output_size, masks=None, boxes=None, labels=None,
527
+ random_scale_min=0.1, random_scale_max=2.0, do_random_scale=False,
528
+ shrink_both_sides=True, boxes1=None, filter_box=True,
529
+ desired_target_size=None, random_scale_ratio=0.0,
530
+ resize_method=tf.image.ResizeMethod.BILINEAR, return_outputs=True,
531
+ pad_value=0, normalize=True):
532
+ desired_height, desired_width = desired_output_size
533
+ desired_height_f = tf.cast(desired_height, dtype=tf.float32)
534
+ desired_width_f = tf.cast(desired_width, dtype=tf.float32)
535
+
536
+ height = tf.cast(tf.shape(image)[0], tf.float32)
537
+ width = tf.cast(tf.shape(image)[1], tf.float32)
538
+
539
+ if boxes is not None:
540
+ # Converts boxes from normalized coordinates to pixel coordinates.
541
+ # Now the coordinates of boxes are w.r.t. the original image.
542
+ boxes = denormalize_boxes(boxes, [height, width])
543
+
544
+ if boxes1 is not None:
545
+ boxes1 = denormalize_boxes(boxes1, [height, width])
546
+
547
+ if do_random_scale:
548
+ random_scale_factor = tf.random.uniform([], random_scale_min, random_scale_max)
549
+ if not shrink_both_sides:
550
+ # Max random is where scale * W > W_desired
551
+ # scale * H > H_desired
552
+ rsf_max = tf.maximum(desired_width_f / width, desired_height_f / height)
553
+ random_scale_factor = tf.minimum(rsf_max, random_scale_factor)
554
+
555
+ scaled_y = tf.cast(random_scale_factor * desired_height_f, tf.int32)
556
+ scaled_x = tf.cast(random_scale_factor * desired_width_f, tf.int32)
557
+
558
+ # Recompute the accurate scale_factor using rounded scaled image size.
559
+ image_scale_y = tf.cast(scaled_y, tf.float32) / height
560
+ image_scale_x = tf.cast(scaled_x, tf.float32) / width
561
+
562
+ image_scale = tf.cond(tf.less(
563
+ tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
564
+ tf.cast(random_scale_ratio, tf.float32)),
565
+ lambda: tf.maximum(image_scale_x, image_scale_y),
566
+ lambda: tf.minimum(image_scale_x, image_scale_y))
567
+
568
+ # image_scale = tf.minimum(image_scale_x, image_scale_y)
569
+
570
+ # Conceptual captions has some REALLY WIDE images I believe
571
+ # this ensures that we won't scale any side lower than to 64
572
+ image_scale = tf.maximum(image_scale, 64.0 / tf.minimum(height, width))
573
+
574
+ # Select non-zero random offset (x, y) if scaled image is larger than
575
+ # self._output_size.
576
+ scaled_height = tf.cast(height * image_scale, tf.int32)
577
+ scaled_width = tf.cast(width * image_scale, tf.int32)
578
+ offset_y = tf.cast(scaled_height - desired_height, tf.float32)
579
+ offset_x = tf.cast(scaled_width - desired_width, tf.float32)
580
+ offset_y = tf.maximum(0.0, offset_y) * tf.random.uniform([], 0, 1)
581
+ offset_x = tf.maximum(0.0, offset_x) * tf.random.uniform([], 0, 1)
582
+ offset_y = tf.cast(offset_y, tf.int32)
583
+ offset_x = tf.cast(offset_x, tf.int32)
584
+ else:
585
+ image_scale_y = desired_height_f / height
586
+ image_scale_x = desired_width_f / width
587
+ image_scale = tf.minimum(image_scale_x, image_scale_y)
588
+ scaled_height = tf.cast(height * image_scale, tf.int32)
589
+ scaled_width = tf.cast(width * image_scale, tf.int32)
590
+ offset_y = tf.constant(0)
591
+ offset_x = tf.constant(0)
592
+
593
+ # Now resize and crop
594
+ if resize_method == 'random' and do_random_scale:
595
+ resize_methods = sorted([k for k in tf.image.ResizeMethod.__dict__.keys() if k.isupper()])
596
+ image = apply_with_random_selector(
597
+ image,
598
+ lambda x, method_idx: tf.image.resize(x, [scaled_height, scaled_width],
599
+ tf.image.ResizeMethod.__dict__[resize_methods[method_idx]],
600
+ antialias=True),
601
+ num_cases=len(resize_methods))
602
+
603
+ elif resize_method != 'random':
604
+ image = tf.image.resize(image, [scaled_height, scaled_width], method=resize_method, antialias=True)
605
+ else:
606
+ image = tf.image.resize(image, [scaled_height, scaled_width],
607
+ method=tf.image.ResizeMethod.BILINEAR, antialias=True)
608
+
609
+ image = tf.clip_by_value(image, 0.0, 1.0)
610
+
611
+ # H x W x C
612
+ image = image[offset_y:offset_y + desired_height, offset_x:offset_x + desired_width, :]
613
+
614
+ H = tf.shape(image)[0]
615
+ W = tf.shape(image)[1]
616
+
617
+ top_pad = (desired_height - H) // 2
618
+ left_pad = (desired_width - W) // 2
619
+
620
+ image_mask = pad_to_bounding_box(
621
+ tf.ones_like(image, dtype=tf.bool), top_pad, left_pad, desired_height, desired_width)[:,:,0]
622
+
623
+ image = pad_to_bounding_box(image, top_pad, left_pad, desired_height, desired_width, value=pad_value)
624
+
625
+ if isinstance(desired_height, int) and isinstance(desired_width, int):
626
+ image.set_shape([desired_height, desired_width, 3])
627
+
628
+ if masks is not None and tf.size(masks) != 0:
629
+ masks = tf.image.resize(masks, [scaled_height, scaled_width],
630
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
631
+
632
+ if len(masks.shape) == 3:
633
+ masks = masks[offset_y:offset_y + desired_height, offset_x:offset_x + desired_width]
634
+ else:
635
+ masks = masks[:, offset_y:offset_y + desired_height, offset_x:offset_x + desired_width]
636
+
637
+ masks = pad_to_bounding_box(masks, top_pad, left_pad, desired_height, desired_width)
638
+ masks = tf.image.resize(masks, desired_target_size,
639
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
640
+
641
+ indices = None
642
+ if boxes is not None:
643
+ # assert ValueError("the box need to be shift which is not tested yet.")
644
+ boxes = resize_and_crop_boxes(
645
+ boxes,
646
+ tf.stack([image_scale, image_scale]),
647
+ [desired_height, desired_width],
648
+ tf.cast(tf.stack([offset_y, offset_x]), dtype=tf.float32),
649
+ tf.cast(tf.stack([top_pad, left_pad]), dtype=tf.float32))
650
+
651
+ if filter_box:
652
+ indices = get_non_empty_box_indices(boxes)
653
+ else:
654
+ indices = tf.range(tf.shape(boxes)[0])
655
+ boxes = tf.gather(boxes, indices)
656
+
657
+ if labels is not None:
658
+ labels = tf.gather(labels, indices)
659
+
660
+ if boxes1 is not None:
661
+ boxes1 = resize_and_crop_boxes(
662
+ boxes1,
663
+ tf.stack([image_scale, image_scale]),
664
+ [desired_height, desired_width],
665
+ tf.cast(tf.stack([offset_y, offset_x]), dtype=tf.float32),
666
+ tf.cast(tf.stack([top_pad, left_pad]), dtype=tf.float32))
667
+
668
+ image_info = tf.stack([
669
+ tf.cast(top_pad, tf.float32),
670
+ tf.cast(left_pad, tf.float32),
671
+ 1.0 / image_scale,
672
+ height,
673
+ width,
674
+ tf.cast(offset_y, dtype=tf.float32) / height,
675
+ tf.cast(offset_x, dtype=tf.float32) / width,
676
+ tf.cast(offset_y, dtype=tf.float32),
677
+ tf.cast(offset_x, dtype=tf.float32),
678
+ tf.cast(scaled_height, dtype=tf.float32),
679
+ tf.cast(scaled_width, dtype=tf.float32),
680
+ ])
681
+
682
+ if boxes1 is not None:
683
+ outputs = (image_info, masks, boxes, labels, indices, boxes1)
684
+ else:
685
+ outputs = (image_info, masks, boxes, labels, indices)
686
+
687
+ if normalize:
688
+ image = normalize_image(image)
689
+
690
+ if return_outputs:
691
+ return image, image_mask, outputs
692
+ else:
693
+ return image, image_mask
694
+
695
+
696
+ def _remove_bars_from_frames(frames, black_bar=True, threshold=32, max_perc_to_trim=0.3):
697
+ """
698
+ :param frames: [num_frames, height, width, 3]
699
+ :param blackbar_threshold: Pixels must be this intense for us to not trim
700
+ :param max_perc_to_prim: Will trim x% by default of the image at most in each dimension
701
+ :return:
702
+ """
703
+ # Detect black bars####################
704
+ frames_shape = tf.shape(frames)
705
+ h, w = frames_shape[1], frames_shape[2]
706
+ if black_bar:
707
+ has_content = tf.reduce_max(frames, axis=(0, -1)) >= threshold
708
+ else:
709
+ has_content = tf.reduce_min(frames, axis=(0, -1)) <= threshold
710
+
711
+ y_frames = tf.cast(tf.reshape(tf.where(tf.reduce_any(has_content, axis=1)), [-1]), tf.int32)
712
+ nhbars = tf.shape(y_frames)[0]
713
+ y_frames = tf.cond(nhbars > 0, lambda: y_frames, lambda: tf.expand_dims(tf.cast(h // 2, tf.int32), axis=0))
714
+
715
+ y1 = tf.minimum(y_frames[0], tf.cast(tf.cast(h, tf.float32) * max_perc_to_trim, tf.int32))
716
+ y2 = tf.maximum(y_frames[-1] + 1, tf.cast(tf.cast(h, tf.float32) * (1 - max_perc_to_trim), tf.int32))
717
+
718
+ x_frames = tf.cast(tf.reshape(tf.where(tf.reduce_any(has_content, axis=0)), [-1]), tf.int32)
719
+ nvbars = tf.shape(x_frames)[0]
720
+ x_frames = tf.cond(nvbars > 0, lambda: x_frames, lambda: tf.expand_dims(tf.cast(w // 2, tf.int32), axis=0))
721
+
722
+ x1 = tf.minimum(x_frames[0], tf.cast(tf.cast(w, tf.float32) * max_perc_to_trim, tf.int32))
723
+ x2 = tf.maximum(x_frames[-1] + 1, tf.cast(tf.cast(w, tf.float32) * (1 - max_perc_to_trim), tf.int32))
724
+
725
+ frames = frames[:, y1:y2, x1:x2]
726
+ return frames
727
+
728
+ def convert_video_dtype(video,dtype):
729
+ """
730
+ Converts tensor to dtype and scales the values.
731
+ Video equivalent of tf.convert_image_dtype: https://www.tensorflow.org/api_docs/python/tf/image/convert_image_dtype
732
+ """
733
+ return tf.map_fn(
734
+ fn=functools.partial(
735
+ tf.image.convert_image_dtype,
736
+ dtype=dtype),
737
+ elems=video,
738
+ fn_output_signature=dtype)
739
+
740
+
741
+ def stateless_shuffle(x: tf.Tensor, seed):
742
+ if hasattr(tf.random.experimental, 'stateless_shuffle'):
743
+ return tf.random.experimental.stateless_shuffle(x, seed=seed)
744
+ else:
745
+ vals = tf.random.stateless_uniform(tf.shape(x)[:1], seed)
746
+ ixs = tf.argsort(vals)
747
+ return tf.gather(x, ixs)
748
+
749
+
750
+ def stateless_permutation(n: int, seed):
751
+ if hasattr(tf.random.experimental, 'stateless_shuffle'):
752
+ ix = tf.range(0, n, dtype=tf.int32)
753
+ return tf.random.experimental.stateless_shuffle(ix, seed=seed)
754
+ else:
755
+ vals = tf.random.stateless_uniform(n, seed)
756
+ return tf.argsort(vals)
757
+
758
+
759
+ @seqio.map_over_dataset
760
+ def _strip_metadata(example):
761
+ return pop_metadata(example)[0]
762
+
763
+
764
+ def sample_patches(mask, n_patches, stateless=False, seeds=None):
765
+ input_sample_valid = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)
766
+ input_sample_masked = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask == 0)
767
+ if stateless:
768
+ encoder_pos_ids = tf.concat([
769
+ stateless_shuffle(input_sample_valid, seeds[0]),
770
+ stateless_shuffle(input_sample_masked, seeds[1])], axis=0)[:n_patches]
771
+ else:
772
+ encoder_pos_ids = tf.concat([
773
+ tf.random.shuffle(input_sample_valid),
774
+ tf.random.shuffle(input_sample_masked)], axis=0)[:n_patches]
775
+ encoder_pos_ids = tf.reshape(encoder_pos_ids, (n_patches,))
776
+ encoder_pos_ids = tf.cast(encoder_pos_ids, tf.int32)
777
+ return encoder_pos_ids
778
+
779
+
780
+ @gin.configurable()
781
+ def normalize_image(image,
782
+ offset=(0.48145466, 0.4578275, 0.40821073),
783
+ scale=(0.26862954, 0.26130258, 0.27577711)):
784
+ """Normalizes the image to zero mean and unit variance."""
785
+ offset = tf.constant(offset)
786
+ offset = tf.expand_dims(offset, axis=0)
787
+ offset = tf.expand_dims(offset, axis=0)
788
+ image -= tf.cast(offset, image.dtype)
789
+
790
+ scale = tf.constant(scale)
791
+ scale = tf.expand_dims(scale, axis=0)
792
+ scale = tf.expand_dims(scale, axis=0)
793
+ image /= tf.cast(scale, image.dtype)
794
+ return image
795
+
796
+
797
+ def unnormalize_image(image,
798
+ offset=(0.48145466, 0.4578275, 0.40821073),
799
+ scale=(0.26862954, 0.26130258, 0.27577711)):
800
+ """Normalizes the image to zero mean and unit variance."""
801
+ scale = tf.cast(tf.expand_dims(tf.expand_dims(tf.constant(scale), axis=0), axis=0), image.dtype)
802
+ image *= scale
803
+
804
+ offset = tf.cast(tf.expand_dims(tf.expand_dims(tf.constant(offset), axis=0), axis=0), image.dtype)
805
+ image += offset
806
+ return image
807
+
808
+
809
+ def flatten_parts(ds: tf.data.Dataset, parts: List[str], add_index=False, dataset_size=None) -> tf.data.Dataset:
810
+ def _flatten(ex):
811
+ flat_key = {k: ex[k] for k in parts}
812
+ if add_index:
813
+ flat_key['index'] = tf.range(len(ex[parts[0]]))
814
+
815
+ flat_ds = tf.data.Dataset.from_tensor_slices(flat_key)
816
+
817
+ def _merge(_flat_ex):
818
+ for k, v in ex.items():
819
+ if k not in parts:
820
+ _flat_ex[k] = v
821
+ return _flat_ex
822
+ return flat_ds.map(_merge)
823
+
824
+ ds = ds.flat_map(_flatten)
825
+ if dataset_size is not None:
826
+ ds = tf.data.experimental.assert_cardinality(dataset_size)(ds)
827
+ return ds
dataset_sizes.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET_SIZES = {
2
+ ("cockatoo_qa_v2", "train"): 194820,
3
+ ("user_qa", "train"): 71172,
4
+
5
+ ("text_vqa", "train"): 34602,
6
+ ("chart_qa", "train"): 28299,
7
+ ("chart_qa_prompting", "train"): 28299,
8
+ ("chart_qa_weighted", "train"): 28299,
9
+ ("tally_qa", "train"): 132981,
10
+ ("doc_qa", "train"): 39463,
11
+ ("info_qa", "train"): 23946,
12
+ ("okvqa", "train"): 9009,
13
+ ("gqa", "train"): 943000,
14
+ ("gqa_multi", "train"): 72140,
15
+ ("coco_2014_vqa", "train"): 443757, # (82783, 443757)
16
+ ("coco_captioning_karpathy", "train"): 414113, # (82783, 414113)
17
+ ("coco_captioning_karpathy_multi", "train"): 82783,
18
+ ("coco_2014_vqa_multi", "train"): 82783,
19
+ ("science_qa_img", "train"): 6218,
20
+ ("ai2_diagram", "train"): 11389,
21
+ ("a_okvqa_mc", "train"): 17056,
22
+ ("a_okvqa_da", "train"): 17056,
23
+ ("ocr_vqa", "train"): 166043,
24
+ ("st_qa", "train"): 25050,
25
+ ("ocr_qa", "train"): 166043,
26
+
27
+ ("dv_qa", "train"): 200000,
28
+ ("tabwmp_da", "train"): 23059,
29
+ ("figure_qa", "train"): 100000,
30
+ ("figure_qa_zero_shot", "train"): 100000,
31
+ ("plot_qa", "train"): 157070,
32
+ ('clocks', 'train'): 800269,
33
+ ('clocks', 'validation'): 25600,
34
+
35
+ ("st_qa", "test"): 4070,
36
+ ('text_vqa', "test"): 5734,
37
+ ('okvqa', "test"): 5046,
38
+ ('chart_qa', "test"): 1250,
39
+ ('doc_qa', "test"): 5188,
40
+ ('info_qa', "test"): 3288,
41
+ ('gqa', "test"): 95336,
42
+ ('coco_captioning_karpathy', "test"): 25010,
43
+ ("science_qa_img", "test"): 2017,
44
+ ("ai2_diagram", "test"): 3088,
45
+ ("a_okvqa_mc_eval", "test"): 6702,
46
+ ("a_okvqa_da_eval", "test"): 6109,
47
+
48
+ ("ai2_diagram_v2", "train"): 10950,
49
+ ("ai2_diagram_v2", "validation"): 1463,
50
+ ("ai2_diagram_v2", "test"): 3088,
51
+ ("vqa_v2_test", "test2015"): 555187,
52
+
53
+ ("ai2_diagram_v2_transparent", "train"): 10950,
54
+ ("ai2_diagram_v2_transparent", "validation"): 1463,
55
+ ("ai2_diagram_v2_transparent", "test"): 3088,
56
+
57
+ # splits in mix_data include both transparent + opaque boxes
58
+ ("ai2_diagram_v2_mix_transparent", "train"): 15042,
59
+ ("ai2_diagram_v2_mix_transparent", "validation"): 1980,
60
+ ("ai2_diagram_v2_mix_transparent", "test"): 4272,
61
+
62
+ # vaia_qa
63
+ ('vaia_qa', 'train'): 477052,
64
+ ('vaia_qa', 'validation'): 1024,
65
+
66
+ ('vaia_qa_latex_image', 'train'): 477052,
67
+ ('vaia_qa_latex_image', 'validation'): 1024,
68
+ ('vaia_qa_latex_image_only', 'train'): 42605,
69
+ ('vaia_qa_latex_image_only', 'validation'): 1024,
70
+ ('vaia_qa_latex_all_image_only', 'train'): 154266,
71
+ ('vaia_qa_latex_all_image_only', 'validation'): 1024,
72
+
73
+ ("vaia_qa_latex_image_math_subset_short_answer", 'train'): 198161,
74
+ ("vaia_qa_latex_image_math_subset_short_answer", 'validation'): 419,
75
+ ("vaia_qa_latex_image_math_subset_mc_only_short_answer", "train"): 57568,
76
+ ("vaia_qa_latex_image_math_subset_mc_only_short_answer", "validation"): 118,
77
+ ("vaia_qa_latex_image_math_subset_mc_only_short_answer_first", "train"): 57568,
78
+ ("vaia_qa_latex_image_math_subset_mc_only_short_answer_first", "validation"): 118,
79
+
80
+ ("vaia_qa_latex_image_all_image_only_short_answer", "train"): 86752,
81
+ ("vaia_qa_latex_image_all_image_only_short_answer", "validation"): 92,
82
+ ("vaia_qa_latex_image_all_image_only_short_answer_first", "train"): 86752,
83
+ ("vaia_qa_latex_image_all_image_only_short_answer_first", "validation"): 92,
84
+ ("vaia_qa_latex_image_math_subset_image_only_short_answer", "train"): 21726,
85
+ ("vaia_qa_latex_image_math_subset_image_only_short_answer", "validation"): 48,
86
+
87
+ ('vqa_online', 'train'): 62722,
88
+ ('vqa_online', 'validation'): 1024,
89
+ ('vqa_online', 'test'): 1024,
90
+
91
+ ('vqa_online_gpt_longQ_longA', 'train'): 62722,
92
+ ('vqa_online_gpt_longQ_longA', 'validation'): 1024,
93
+ ('vqa_online_gpt_longQ_longA', 'test'): 1024,
94
+
95
+ ("tally_qa", "validation"): 38589,
96
+ ('text_vqa', "validation"): 5000,
97
+ ('okvqa', "validation"): 5046,
98
+ ('chart_qa', "validation"): 960*2,
99
+ ('chart_qa_prompting_explanation', "validation"): 960*2,
100
+ ('chart_qa_ex', "validation"): 960*2,
101
+ ('chart_qa_human', "validation"): 960,
102
+ ('chart_qa_aug', "validation"): 960,
103
+ ('doc_qa', "validation"): 5349,
104
+ ('info_qa', "validation"): 2801,
105
+ ('coco_2014_vqa', "validation"): 214354, # 40504 images
106
+ ('coco_2014_vqa_multi', "validation"): 214354,
107
+ ('coco_captioning_karpathy', "validation"): 25010,
108
+ ('gqa', "validation"): 132062,
109
+ ("science_qa_img", "validation"): 2097,
110
+ ("ai2_diagram", "validation"): 1024,
111
+ ("a_okvqa_mc", "validation"): 1145,
112
+ ("a_okvqa_da", "validation"): 1075,
113
+ ("charxiv_descriptive", "validation"): 1000,
114
+ ("charxiv_descriptive", "test"): 1320,
115
+ ("charxiv_reasoning", "validation"): 1000,
116
+ ("charxiv_reasoning", "test"): 1320,
117
+ ("fintabnetqa", "validation"): 125,
118
+ ("fintabnetqa", "test"): 250,
119
+ ("vwtq", "validation"): 125,
120
+ ("vwtq", "test"): 750,
121
+ ("vwtq_syn", "validation"): 125,
122
+ ("vwtq_syn", "test"): 250,
123
+ ("vtabfact", "validation"): 125,
124
+ ("vtabfact", "test"): 250,
125
+ ("nutrition_fact", "validation"): 100,
126
+ ("nutrition_fact", "test"): 100,
127
+
128
+ ("mmmu_test", "validation"): 900,
129
+ ("count_bench", "test"): 500,
130
+ ("mmmu_test", "test"): 10500,
131
+ ("real_world_qa_test", "test"): 765,
132
+ ("real_world_qa_no_instruction", "test"): 765,
133
+ ("real_world_qa_dbg", "test"): 765,
134
+ ("real_world_qa_as_user_qa", "test"): 765,
135
+
136
+ ("seed_bench_test", "test"): 19241,
137
+ ("pope_test", "test"): 9000,
138
+ ("mme_test", "test"): 2374,
139
+ ("math_vista_test", "validation"): 1000,
140
+ ("math_vista_demo", "validation"): 1000,
141
+ ("math_vista_v2", "validation"): 1000,
142
+
143
+ ("math_vista_test", "test"): 5141,
144
+ ("mmbench_test", "validation"): 4329,
145
+ ("mmbench_test", "test"): 6666,
146
+ ("sugar_crepe_test", "test"): 15022,
147
+ ("blink_test", "validation"): 1901,
148
+ ("dense_caption_eval_dbg", "validation"): 1,
149
+
150
+ ("refclef_unc", "train"): 17978,
151
+ ("refclef_unc", "validation"): 12029,
152
+ ("refcoco_unc", "train"): 16994,
153
+ ("refcoco_unc", "validation"): 10834,
154
+ ("refcocoplus_unc", "train"): 16992,
155
+ ("refcocoplus_unc", "validation"): 10758,
156
+ ("refcocog_umd", "train"): 21899,
157
+ ("refcocog_umd", "validation"): 4896,
158
+ ("refclef_unc", "testA"): 3449,
159
+ ("refclef_unc", "testB"): 3221,
160
+ ("refclef_unc", "testC"): 2664,
161
+ ("refclef_unc", "testAB"): 116,
162
+ ("refclef_unc", "testBC"): 86,
163
+ ("refcoco_unc", "testA"): 5657,
164
+ ("refcoco_unc", "testB"): 5095,
165
+ ("refcocoplus_unc", "testA"): 5726,
166
+ ("refcocoplus_unc", "testB"): 4889,
167
+ ("refcocog_umd", "test"): 9602,
168
+ ("countbench_qa_point_count", "huggingface"): 490,
169
+ ('countbench_qa', 'huggingface'): 490,
170
+
171
+ ('cockatoo_712k_sept6', 'train'): 712121,
172
+ ('cockatoo_712k_sept6', 'validation'): 5120,
173
+ ('user_qa', 'train'): 71172,
174
+ ('user_qa', 'validation'): 2048,
175
+
176
+ # pointing
177
+ ("pointing_test", "test"): 436,
178
+
179
+ ("fast_flickr_count_qa_point_count", "train"): 36916,
180
+ ("fast_flickr_count_qa_point_count", "validation"): 163,
181
+ ("fast_flickr_count_qa_point_count", "test"): 540,
182
+ ("fast_flickr_count_qa_pointing", "train"): 36916,
183
+ ("fast_flickr_count_qa_pointing", "validation"): 163,
184
+ ("fast_flickr_count_qa_pointing", "test"): 540,
185
+ ('pointing', 'train'): 309216,
186
+ ('point_count', 'train'): 309216,
187
+ ('pointing', 'validation'): 2054,
188
+ ('point_count', 'validation'): 2054,
189
+ ('point_count_high_freq', 'train'): 113840,
190
+ ('point_count_high_freq', 'validation'): 3969,
191
+ ('pointing_high_freq', 'train'): 113840,
192
+ ('pointing_high_freq', 'validation'): 3969,
193
+ ('point_qa', 'train'): 27856,
194
+ ('point_qa', 'validation'): 978,
195
+ ("a_okvqa_da", "test"): 6109,
196
+ ("a_okvqa_mc", "test"): 6702,
197
+ ("user_questions_for_elo", "test"): 14851,
198
+ ("user_questions_for_elo_long", "test"): 1368,
199
+ ("user_questions_for_elo_9_to_12", "test"): 3000,
200
+
201
+ ("sim_point_count_qa", "train"): 522611,
202
+ ("sim_point_count_qa", "validation"): 800,
203
+ ("sim_point_count_qa", "test"): 800,
204
+ ("sim_count_qa", "train"): 522611,
205
+ ("sim_count_qa", "validation"): 800,
206
+ ("sim_count_qa", "test"): 800,
207
+
208
+ ("scifi_charts_qa", "validation"): 1024,
209
+ ("scifi_table_qa", "validation"): 1024,
210
+ ("scifi_natural_qa", "validation"): 128,
211
+ ("scifi_nutrition_qa", "validation"): 128,
212
+ ("scifi_document_qa", "validation"): 1024,
213
+ ("scifi_diagram_qa", "validation"): 1024,
214
+ ("scifi_charts_qa", "train"): 233622,
215
+ ("scifi_table_qa", "train"): 93036,
216
+ ("scifi_document_qa", "train"): 142559,
217
+ ("scifi_diagram_qa", "train"): 33102,
218
+
219
+ ("scifi_charts_qa_split", "train"): 116814,
220
+ ("scifi_table_qa_split", "train"): 46518,
221
+ ("scifi_document_qa_split", "train"): 71282,
222
+ ("scifi_diagram_qa_split", "train"): 16551,
223
+
224
+ ("scifi_charts_qa_exp_split", "train"): 116814,
225
+ ("scifi_table_qa_exp_split", "train"): 46518,
226
+ ("scifi_document_qa_exp_split", "train"): 71282,
227
+ ("scifi_diagram_qa_exp_split", "train"): 16551,
228
+
229
+ ("android_control", "train"): 74714,
230
+ ("android_control", "validation"): 690,
231
+ ("android_control", "test"): 3897,
232
+
233
+ ("synthetic_qa_v3_multi_turn", "train"): 9824,
234
+ ("synthetic_qa_v3", "train"): 162855,
235
+ ("synthetic_qa_v3_style_tag", "train"): 162855,
236
+ ("synthetic_qa_v3_as_user_qa", "train"): 162855,
237
+ }
238
+
239
+
240
+ for (name, split), count in list(DATASET_SIZES.items()):
241
+ if name in ["chart_qa"]:
242
+ DATASET_SIZES[(name + "_scifi", split)] = count
243
+ if name in ["android_control"]:
244
+ for k in ["ll", "hl", "hl_ll", "hl_cot"]:
245
+ DATASET_SIZES[(f"{name}_{k}", split)] = count
246
+ if name in ["scifi_charts_qa" ,"scifi_table_qa", "scifi_document_qa", "scifi_diagram_qa", "scifi_datikz_qa"]:
247
+ DATASET_SIZES[(name + "_exp", split)] = count
248
+ DATASET_SIZES[(name[:-3] + "_exp", split)] = count
249
+ DATASET_SIZES[(name[:-3] + "_demo", split)] = count
250
+ if name in ["ai2_diagram_v2_mix_transparent"]:
251
+ DATASET_SIZES[("ai2_diagram_v2_mix_transparent_one_style", split)] = count
252
+ if name in ["chart_qa", "info_qa", "doc_qa", "text_vqa", "coco_2014_vqa",
253
+ "ai2_diagram_v2_mix_transparent", "countbench_qa", "chart_qa_human"]:
254
+ DATASET_SIZES[(name + "_demo", split)] = count
255
+
256
+
257
+ def get_dataset_size(name, split):
258
+ if name.endswith("_eval"):
259
+ if (name, split) in DATASET_SIZES:
260
+ return DATASET_SIZES[(name, split)]
261
+ name = name[:-len('_eval')]
262
+ return DATASET_SIZES[(name, split)]
exceptions.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ "OLMoError",
3
+ "OLMoConfigurationError",
4
+ "OLMoCliError",
5
+ "OLMoEnvironmentError",
6
+ "OLMoNetworkError",
7
+ "OLMoCheckpointError",
8
+ ]
9
+
10
+
11
+ class OLMoError(Exception):
12
+ """
13
+ Base class for all custom OLMo exceptions.
14
+ """
15
+
16
+
17
+ class OLMoConfigurationError(OLMoError):
18
+ """
19
+ An error with a configuration file.
20
+ """
21
+
22
+
23
+ class OLMoCliError(OLMoError):
24
+ """
25
+ An error from incorrect CLI usage.
26
+ """
27
+
28
+
29
+ class OLMoEnvironmentError(OLMoError):
30
+ """
31
+ An error from incorrect environment variables.
32
+ """
33
+
34
+
35
+ class OLMoNetworkError(OLMoError):
36
+ """
37
+ An error with a network request.
38
+ """
39
+
40
+
41
+ class OLMoCheckpointError(OLMoError):
42
+ """
43
+ An error occurred reading or writing from a checkpoint.
44
+ """
45
+
46
+
47
+ class OLMoThreadError(Exception):
48
+ """
49
+ Raised when a thread fails.
50
+ """
iterable_dataset.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import pickle
6
+ import queue
7
+ import socket
8
+ import time
9
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
10
+ from multiprocessing.managers import BaseManager
11
+ from multiprocessing.shared_memory import SharedMemory
12
+ from os.path import exists
13
+ from pathlib import Path
14
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
15
+
16
+ import psutil
17
+ import tensorflow as tf
18
+ import numpy as np
19
+ import torch
20
+ import torch.utils.data
21
+ import clu
22
+ from clu.data.dataset_iterator import Element
23
+
24
+
25
+ from .aliases import PathOrStr
26
+ from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size, get_node_rank, \
27
+ get_local_world_size, get_local_rank, move_to_device
28
+ from .util import roundrobin, threaded_generator
29
+ from .data_factory import SeqioDataset
30
+ from .multimodal_preprocessor import MultiModalPreprocessor
31
+ from .preprocesssors import rename
32
+ import torch.distributed as dist
33
+ from . import tasks
34
+
35
+ __all__ = ["MMIterableDataset"]
36
+
37
+ log = logging.getLogger(__name__)
38
+
39
+
40
+ def batch_fn(batch, for_inference):
41
+ if for_inference:
42
+ out = {}
43
+ for k, v in batch.items():
44
+ if k.startswith("metadata/"):
45
+ out[k] = v
46
+ else:
47
+ out[k] = torch.from_numpy(v)
48
+ return out
49
+ else:
50
+ out = {k: torch.from_numpy(v) for k, v in batch.items() if not k.startswith("metadata/")}
51
+ out["metadata"] = [{} for _ in out["input_ids"]]
52
+ return out
53
+
54
+
55
+ class PyTorchDatasetIterator(clu.data.dataset_iterator.TfDatasetIterator):
56
+ def __init__(self, dataset, *, checkpoint: bool, for_inference: bool):
57
+ self.for_inference = for_inference
58
+ super().__init__(dataset, checkpoint=checkpoint)
59
+
60
+ def __next__(self) -> Element:
61
+ batch = {k: v.numpy() for k, v in next(self.iterator).items()}
62
+ return batch_fn(batch, self.for_inference)
63
+
64
+ def __len__(self) -> int:
65
+ return len(self._dataset)
66
+
67
+
68
+ class MMIterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
69
+ def __init__(
70
+ self,
71
+ dataset: SeqioDataset,
72
+ preprocessor: MultiModalPreprocessor,
73
+ world_size: Optional[int] = None,
74
+ rank: Optional[int] = None,
75
+ ):
76
+ self.preprocessor = preprocessor
77
+ self.rank = rank if rank is not None else get_global_rank()
78
+ self.world_size = world_size if world_size is not None else get_world_size()
79
+ self.dataset_config = dataset
80
+
81
+ data_iter = dataset.build(
82
+ self.preprocessor,
83
+ self.rank,
84
+ self.world_size,
85
+ )
86
+
87
+ data_iter: tf.data.Dataset = rename(input_ids="input_tokens", labels="target_tokens")(data_iter)
88
+ self.dataset = data_iter
89
+ self.data_iter = PyTorchDatasetIterator(
90
+ data_iter, checkpoint=True, for_inference=dataset.for_inference)
91
+
92
+ def reset(self):
93
+ self.data_iter.reset()
94
+
95
+ def save(self, filename: PathOrStr):
96
+ self.data_iter.save(filename)
97
+
98
+ def restore(self, filename: PathOrStr):
99
+ self.data_iter.restore(filename)
100
+
101
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
102
+ return self.data_iter
103
+
104
+
105
+ def _split_batch(batch, n):
106
+ subbatches = [{} for _ in range(n)]
107
+ for k, v in batch.items():
108
+ assert len(v) % n == 0, f"n={n} but {k} has {len(v)}"
109
+ subatch_dim = len(v) // n
110
+ for i, subbatch in enumerate(subbatches):
111
+ subbatch[k] = v[i * subatch_dim:(i + 1) * subatch_dim]
112
+ return subbatches
113
+
114
+
115
+ def tf_to_torch_dtype(tf_dtype):
116
+ dtype_mapping = {
117
+ tf.float16: torch.float16,
118
+ tf.float32: torch.float32,
119
+ tf.float64: torch.float64,
120
+ tf.int8: torch.int8,
121
+ tf.uint8: torch.uint8,
122
+ tf.int16: torch.int16,
123
+ tf.int32: torch.int32,
124
+ tf.int64: torch.int64,
125
+ tf.bool: torch.bool,
126
+ }
127
+ return dtype_mapping[tf_dtype]
128
+
129
+
130
+ class PeerToPeer(torch.utils.data.IterableDataset[Dict[str, Any]]):
131
+ """
132
+ This dataloader runs the tf.data.Dataset on one processes per a node, and then
133
+ transfers the batch to the other processes. For 7B model about a 10% performance
134
+ despite my attempts to make it asynchronous
135
+
136
+ The advantage is that it avoids the overhead of running multiple tf.data.Dataset
137
+ in one node
138
+ """
139
+
140
+ def __init__(
141
+ self,
142
+ dataset: SeqioDataset,
143
+ preprocessor: MultiModalPreprocessor,
144
+ world_size: Optional[int] = None,
145
+ rank: Optional[int] = None,
146
+ device=None
147
+ ):
148
+ assert get_world_size() % get_local_world_size() == 0
149
+ self.device = device
150
+ self.device_batch_size = dataset.global_batch_size // get_world_size()
151
+
152
+ self.preprocessor = preprocessor
153
+ self.seqio_dataset = dataset
154
+
155
+ lws = get_local_world_size()
156
+
157
+ if get_local_rank() == 0:
158
+ tf_dataset = dataset.build(
159
+ self.preprocessor,
160
+ get_node_rank(),
161
+ get_world_size() // lws,
162
+ )
163
+
164
+ tf_dataset = rename(input_ids="input_tokens", labels="target_tokens")(tf_dataset)
165
+ self.dataset = tf_dataset
166
+ device_spec = {k: ((v.shape[0]//lws,) + tuple(v.shape[1:]), tf_to_torch_dtype(v.dtype))
167
+ for k, v in tf_dataset.element_spec.items()}
168
+ else:
169
+ self.dataset = None
170
+ device_spec = None
171
+
172
+ broadcast = [device_spec]
173
+ torch.distributed.broadcast_object_list(broadcast)
174
+ self.device_spec = broadcast[0]
175
+
176
+ self._node_group_ranks = ranks = [(i + get_node_rank()*lws) for i in range(lws)]
177
+ if get_local_rank() == 0:
178
+ assert get_global_rank() == self._node_group_ranks[0]
179
+ self._keys = sorted(self.device_spec)
180
+ self.multithread_pin = False
181
+
182
+ def _pin(self, it, on):
183
+ batch = next(it)
184
+ batch = {k: torch.from_numpy(v) for k, v in batch.items()}
185
+ batch = _split_batch(batch, len(self._node_group_ranks))
186
+ return [{k: v.pin_memory() for k, v in subbatch.items()} for subbatch in batch]
187
+
188
+ def _send_pinned(self, batch):
189
+ requests = []
190
+ for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
191
+ for k in self._keys:
192
+ batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
193
+ requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
194
+ ops = dist.batch_isend_irecv(requests)
195
+ return batch[0], ops
196
+
197
+ def _send(self, it, on):
198
+ if get_local_rank() == 0:
199
+ try:
200
+ batch = next(it)
201
+ batch = {k: torch.from_numpy(v) for k, v in batch.items()}
202
+ batch = _split_batch(batch, len(self._node_group_ranks))
203
+ except StopIteration:
204
+ # Special batch to indicate iteration is done
205
+ batch = [
206
+ {k: torch.full(sh, -10, dtype=dtype, device=self.device)
207
+ for k, (sh, dtype) in self.device_spec.items()}
208
+ for _ in range(len(self._node_group_ranks))
209
+ ]
210
+
211
+ # pin_memory so the device transfer can be non_blocking
212
+ batch = [{k: v.pin_memory() for k, v in subbatch.items()}
213
+ for subbatch in batch]
214
+
215
+ requests = []
216
+ for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
217
+ for k in self._keys:
218
+ batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
219
+ requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
220
+ ops = dist.batch_isend_irecv(requests)
221
+ batch = batch[0]
222
+ else:
223
+ batch = {k: torch.zeros(sh, dtype=dtype, device=self.device)
224
+ for k, (sh, dtype) in self.device_spec.items()}
225
+ requests = []
226
+ for k in self._keys:
227
+ requests.append(dist.P2POp(dist.irecv, batch[k], self._node_group_ranks[0]))
228
+ ops = dist.batch_isend_irecv(requests)
229
+ return batch, ops
230
+
231
+ def __iter__(self):
232
+ on = 0
233
+ if get_local_rank() == 0:
234
+ it = iter(self.dataset.as_numpy_iterator())
235
+ else:
236
+ it = None
237
+
238
+ if get_local_rank() == 0 and self.multithread_pin:
239
+ # Try to be clever and do memory pinning in a seperate thread, in practice
240
+ # didn't seem to help much so off by default for now
241
+ # Currently does not support finite dataset
242
+ with ThreadPoolExecutor(max_workers=1) as pool:
243
+ _is_sending = self._send_pinned(self._pin(it, on))
244
+ _is_pinning = pool.submit(self._pin, it, on)
245
+ on += 1
246
+ while True:
247
+ result = _is_sending
248
+ _is_sending = self._send_pinned(_is_pinning.result())
249
+ _is_pinning = pool.submit(self._pin, it, on)
250
+ on += 1
251
+ for op in result[1]:
252
+ op.wait()
253
+ yield result[0]
254
+ else:
255
+ _in_flight = self._send(it, on)
256
+ on += 1
257
+ while True:
258
+ on += 1
259
+ next_batch = self._send(it, on) # queue up the next batch
260
+ for op in _in_flight[1]: # wait for the current batch
261
+ op.wait()
262
+ if _in_flight["input_ids"][0] != -10: # indicates no more data
263
+ return
264
+ yield _in_flight[0]
265
+ _in_flight = next_batch
266
+
modeling_molmoe.py CHANGED
@@ -39,14 +39,14 @@ import einops
39
  from transformers import PreTrainedModel
40
  from transformers.modeling_outputs import CausalLMOutputWithPast
41
 
42
- from olmo.aliases import PathOrStr
43
- from olmo.beam_search import (
44
  BeamSearch,
45
  Constraint,
46
  FinalSequenceScorer,
47
  Sampler
48
  )
49
- from olmo.config import (
50
  ActivationType,
51
  BlockType,
52
  LayerNormType,
@@ -56,7 +56,7 @@ from olmo.config import (
56
  AttentionType,
57
  )
58
 
59
- from olmo.util import resource_path
60
  from .config_molmoe import (
61
  MolmoConfig,
62
  VisionBackboneConfig
 
39
  from transformers import PreTrainedModel
40
  from transformers.modeling_outputs import CausalLMOutputWithPast
41
 
42
+ from .aliases import PathOrStr
43
+ from .beam_search import (
44
  BeamSearch,
45
  Constraint,
46
  FinalSequenceScorer,
47
  Sampler
48
  )
49
+ from .config import (
50
  ActivationType,
51
  BlockType,
52
  LayerNormType,
 
56
  AttentionType,
57
  )
58
 
59
+ from .util import resource_path
60
  from .config_molmoe import (
61
  MolmoConfig,
62
  VisionBackboneConfig
multimodal_preprocessor.py ADDED
@@ -0,0 +1,1549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+ import re
4
+ from collections import defaultdict
5
+ from typing import Tuple, Optional, Any, Dict, List, Union, Mapping
6
+
7
+ import einops
8
+ import seqio
9
+ import numpy as np
10
+ import tensorflow as tf
11
+
12
+ from .mm_data import seqio_tokenizer
13
+ from .data_utils import pad_to_bounding_box, \
14
+ get_3d_subsegments, _append_to_innermost_axis, resize_and_pad, \
15
+ apply_with_random_selector, get_special_token_ids, make_autoregressive_inputs, \
16
+ trim_and_pad_dataset, assert_not_truncated
17
+ from .prompts import apply_keyword_prompt, STYLE_TO_GENERAL_PROMPT, GENERAL_PROMPTS_V1
18
+ import .constants as config
19
+
20
+
21
+ def siglip_resize(src, imgsize, truncate):
22
+ """Resize and preprocess for SigLIP ViT in the offical jax implementation"""
23
+ assert src.dtype == tf.uint8
24
+ # SigCLIP removes aspect ratio by default
25
+ resized = tf.image.resize(src, imgsize, method=tf.image.ResizeMethod.BILINEAR, antialias=False)
26
+ dtype = src.dtype
27
+ tf_dtype = tf.type_spec_from_value(src).dtype
28
+ resized = tf.cast(tf.clip_by_value(resized, tf_dtype.min, tf_dtype.max), dtype)
29
+
30
+ # Normalize between -1 and 1 without using imagenet standard mean/std
31
+ vmin=-1; vmax=1; in_min=0; in_max=255.0
32
+ in_min_t = tf.constant(in_min, tf.float32)
33
+ in_max_t = tf.constant(in_max, tf.float32)
34
+ image = tf.cast(resized, tf.float32)
35
+ image = (image - in_min_t) / (in_max_t - in_min_t)
36
+ image = vmin + image * (vmax - vmin)
37
+ if truncate:
38
+ image = image[:truncate, :truncate]
39
+ return image
40
+
41
+
42
+ def extract_bboxes(text, image_w, image_h):
43
+ points = extract_points(text, image_w, image_h)
44
+ boxes = []
45
+ for i in range(len(points)//2):
46
+ x1, y1 = points[i*2]
47
+ x2, y2 = points[i*2 + 1]
48
+ boxes.append([x1, y1, x2, y2])
49
+ return boxes
50
+
51
+
52
+ def extract_annotated_points(caption, image_w, image_h):
53
+ points = []
54
+ for match in re.finditer("<point x=\"([0-9\\.]*)\" y=\"([0-9\\.]*)\" alt=\"([^\"]*)\">", caption):
55
+ x = float(match.group(1))
56
+ y = float(match.group(2))
57
+ points.append(([[x, y]], match.group(3)))
58
+ for match in re.finditer("<points ([^<]*) alt=\"([^\"]*)\">", caption):
59
+ loc_str = match.group(1)
60
+ locations = defaultdict(dict)
61
+ if loc_str.startswith("points="):
62
+ point_grp = []
63
+ for point_match in re.finditer(r"([0-9]+\.[0-9]),? ([0-9]+\.[0-9])", loc_str):
64
+ try:
65
+ point = [float(point_match.group(i)) for i in range(1, 3)]
66
+ point_grp.append(point)
67
+ except ValueError:
68
+ pass
69
+ else:
70
+ for val in loc_str.split():
71
+ try:
72
+ key, val = val.split("=")
73
+ locations[key[1:]][key[:1]] = float(val.strip("\""))
74
+ except ValueError:
75
+ import pdb; pdb.set_trace()
76
+ logging.warning(f"Failed to parse {val} from {match.group(0)}")
77
+ point_grp = []
78
+ for key, coords in locations.items():
79
+ if sorted(coords) == ["x", "y"]:
80
+ point_grp.append([coords["x"], coords["y"]])
81
+ if point_grp:
82
+ points.append((point_grp, match.group(2)))
83
+
84
+ normalized = []
85
+ for point_grp, point_text in points:
86
+ normalized.append((
87
+ np.array(point_grp) / 100.0 * np.array([image_w, image_h]),
88
+ point_text,
89
+ ))
90
+ return normalized
91
+
92
+
93
+ def extract_points(text, image_w, image_h):
94
+ all_points = []
95
+ for match in re.finditer(r"Click\(([0-9]+\.[0-9]), ?([0-9]+\.[0-9])\)", text):
96
+ try:
97
+ point = [float(match.group(i)) for i in range(1, 3)]
98
+ except ValueError:
99
+ pass
100
+ else:
101
+ point = np.array(point)
102
+ if np.max(point) > 100:
103
+ # Treat as an invalid output
104
+ continue
105
+ point /= 100.0
106
+ point = point * np.array([image_w, image_h])
107
+ all_points.append(point)
108
+
109
+ for match in re.finditer(r"\(([0-9]+\.[0-9]),? ?([0-9]+\.[0-9])\)", text):
110
+ try:
111
+ point = [float(match.group(i)) for i in range(1, 3)]
112
+ except ValueError:
113
+ pass
114
+ else:
115
+ point = np.array(point)
116
+ if np.max(point) > 100:
117
+ # Treat as an invalid output
118
+ continue
119
+ point /= 100.0
120
+ point = point * np.array([image_w, image_h])
121
+ all_points.append(point)
122
+ for match in re.finditer(r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"', text):
123
+ try:
124
+ point = [float(match.group(i)) for i in range(1, 3)]
125
+ except ValueError:
126
+ pass
127
+ else:
128
+ point = np.array(point)
129
+ if np.max(point) > 100:
130
+ # Treat as an invalid output
131
+ continue
132
+ point /= 100.0
133
+ point = point * np.array([image_w, image_h])
134
+ all_points.append(point)
135
+ for match in re.finditer(r'(?:\d+|p)\s*=\s*([0-9]{3})\s*,\s*([0-9]{3})', text):
136
+ try:
137
+ point = [int(match.group(i)) / 10.0 for i in range(1, 3)]
138
+ except ValueError:
139
+ pass
140
+ else:
141
+ point = np.array(point)
142
+ if np.max(point) > 100:
143
+ # Treat as an invalid output
144
+ continue
145
+ point /= 100.0
146
+ point = point * np.array([image_w, image_h])
147
+ all_points.append(point)
148
+ return all_points
149
+
150
+
151
+ def extract_points_from_point_count(text, image_w, image_h):
152
+ all_points = []
153
+ points = re.findall(r"(\d+\.\d+),\s*(\d+\.\d+)", text)
154
+
155
+ for match in points:
156
+ try:
157
+ point = [float(match[0]), float(match[1])]
158
+ except ValueError:
159
+ pass
160
+ else:
161
+ point = np.array(point)
162
+ if np.max(point) > 100:
163
+ # Treat as an invalid output
164
+ continue
165
+ point = point * np.array([image_w, image_h])
166
+ all_points.append(point)
167
+ return all_points
168
+
169
+
170
+ def select_tiling(h, w, patch_size, max_num_patches):
171
+ """Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
172
+ original_size = tf.stack([h, w]) # [1, 2]
173
+ original_res = h * w
174
+ tilings = []
175
+ for i in range(1, max_num_patches+1):
176
+ for j in range(1, max_num_patches+1):
177
+ if i*j <= max_num_patches:
178
+ tilings.append((i, j))
179
+ # sort so argmin and argmax favour smaller tilings in the event of a tie
180
+ tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
181
+ candidate_tilings = tf.constant(tilings, dtype=tf.int32) # [n_resolutions, 2]
182
+ candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
183
+
184
+ # How much we would need to scale the image to fit exactly in each tiling
185
+ required_scale_d = tf.cast(candidate_resolutions, tf.float32) / tf.cast(original_size[None, :], tf.float32)
186
+ required_scale = tf.reduce_min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
187
+ if tf.reduce_all(required_scale < 1):
188
+ # We are forced to downscale, so try to minimize the amount of downscaling
189
+ ix = tf.argmax(required_scale)[0]
190
+ else:
191
+ # Pick the resolution that required the least upscaling so that it most closely fits the image
192
+ required_scale = tf.where(required_scale < 1.0, 10e9, required_scale)
193
+ ix = tf.argmin(required_scale)[0]
194
+ return candidate_tilings[ix]
195
+
196
+
197
+ DEMO_STYLES = [
198
+ "point_count",
199
+ "pointing",
200
+ "user_qa",
201
+ "scifi_charts_exp",
202
+ "scifi_charts_exp",
203
+ "scifi_charts_exp",
204
+ "scifi_charts_exp",
205
+ "long_caption",
206
+ "named_entity"
207
+ ]
208
+
209
+
210
+ @dataclasses.dataclass
211
+ class MultiModalPreprocessor:
212
+ """Turns text/image inputs into tensors that can be input to the model"""
213
+ tokenizer: Any
214
+
215
+ # How to prompt the model
216
+ prompt_templates: str = "none" # How to template prompts for examples
217
+ message_format: str = "none" # How to format messages
218
+ system_prompt: Optional[str] = None # How to generate system prompts
219
+ prompt_override: Optional[str] = None # Used for setting prompt manually
220
+ always_start_with_space: bool = False # Always include a leading space for the first bit of text
221
+ default_inference_len: int = 65 # Inference len for length-conditioned prompting
222
+
223
+ # How to crops/resize images
224
+ crop_mode: str = "resize"
225
+ max_crops: int = 6
226
+ overlap_margins: Tuple[int, int] = (4, 4)
227
+ do_random_scale: Optional[bool] = False
228
+ resize: str = "default"
229
+ random_scale_max: float = 1.1
230
+ random_scale_min: float = 0.9
231
+ random_scale_ratio: float = 0.5
232
+ use_col_tokens: bool = True
233
+
234
+ # Data about the ViT and connector we need when deciding the crops
235
+ base_image_input_size: Tuple[int, int] = (336, 336)
236
+ image_token_length_w: int = 12
237
+ image_token_length_h: int = 12
238
+ image_patch_size: int = 14
239
+ image_padding_mask: bool = False
240
+
241
+ # Other settings
242
+ loss_token_weighting: Optional[str] = None
243
+ unconditioned: Union[bool, float] = False # Ignore images
244
+ fix_image_input_idx: int = 2 # backwards compatibility fix
245
+ pad_to: Optional[int] = None # experimental feature
246
+
247
+ _special_tokens: Dict[str, int] = None
248
+ split_at: Optional[int] = None
249
+
250
+ def get_max_total_crops(self):
251
+ if self.crop_mode == "resize":
252
+ return 1
253
+ elif "resize" in self.crop_mode:
254
+ return 1 + self.max_crops
255
+ else:
256
+ return self.max_crops
257
+
258
+ @property
259
+ def image_num_patch(self):
260
+ h, w = self.base_image_input_size
261
+ return h//self.image_patch_size, w//self.image_patch_size
262
+
263
+ @property
264
+ def special_token_ids(self):
265
+ if self._special_tokens is None:
266
+ self._special_tokens = get_special_token_ids(self.tokenizer)
267
+ return self._special_tokens
268
+
269
+ def image_to_patches_and_tokens(self, image, is_training):
270
+ """Preprocesses an image
271
+
272
+ Args:
273
+ image: [h, w, 3] image to preprocessing
274
+ Returns:
275
+ crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
276
+ change between images but the other dimension are fixed
277
+ tokens: (n_tokens,) tf.int32 tokens, pad tokens indicate where to insert the
278
+ patch features, might include other special tokens as well
279
+ patch_ordering: (n_crops, n_tokens_per_crop) order image features should be inserted
280
+ into the `tokens`, negative values indicates patches features to exclude
281
+ padding_mask: (n_crops, h, w) mask of what pixels are padding, can be None
282
+ """
283
+ do_random_scale = self.do_random_scale
284
+ if do_random_scale:
285
+ do_random_scale = is_training
286
+
287
+ base_image_input_size = self.base_image_input_size
288
+ if isinstance(base_image_input_size, int):
289
+ base_image_input_size = (base_image_input_size, base_image_input_size)
290
+
291
+ image_token_length_w, image_token_length_h = self.image_token_length_w, self.image_token_length_h
292
+ base_image_input_d = self.image_patch_size
293
+ tokens_per_image = image_token_length_w * image_token_length_h
294
+ image_base_patch_w = base_image_input_size[1] // base_image_input_d
295
+ image_base_patch_h = base_image_input_size[0] // base_image_input_d
296
+ extra_image = False
297
+ patch_ordering = None
298
+
299
+ if self.resize == "default":
300
+ image = tf.image.convert_image_dtype(image, dtype=tf.float32)
301
+ def _resize(_image, sz):
302
+ return resize_and_pad(
303
+ _image, sz,
304
+ do_random_scale=do_random_scale,
305
+ random_scale_max=self.random_scale_max,
306
+ random_scale_min=self.random_scale_min,
307
+ random_scale_ratio=self.random_scale_ratio,
308
+ return_outputs=False,
309
+ resize_method='random' if is_training else tf.image.ResizeMethod.BILINEAR)
310
+ elif self.resize == "stretch":
311
+ image = tf.image.convert_image_dtype(image, dtype=tf.float32)
312
+ assert not do_random_scale
313
+
314
+ def _resize(_image, sz):
315
+ if not is_training:
316
+ img = tf.image.resize(_image, sz, antialias=True, method=tf.image.ResizeMethod.BILINEAR)
317
+ else:
318
+ resize_methods = sorted([k for k in tf.image.ResizeMethod.__dict__.keys() if k.isupper()])
319
+ img = apply_with_random_selector(
320
+ _image,
321
+ lambda x, method_idx: tf.image.resize(x, sz,
322
+ tf.image.ResizeMethod.__dict__[resize_methods[method_idx]],
323
+ antialias=True),
324
+ num_cases=len(resize_methods))
325
+ return img, tf.ones(tf.shape(img)[:2], dtype=tf.bool)
326
+ elif self.resize in "siglip":
327
+ assert not do_random_scale
328
+
329
+ def _resize(_image, sz):
330
+ img = siglip_resize(_image, sz, truncate=None)
331
+ return img, tf.ones(tf.shape(img)[:2], dtype=tf.bool)
332
+ else:
333
+ raise NotImplementedError(self.resize)
334
+
335
+ def _img_to_patches(_img, _img_mask, dy=1, dx=1):
336
+ _img = einops.rearrange(
337
+ _img, '(dy h dh) (dx w dw) c -> (dy dx) (h w) (dh dw c)',
338
+ dh=base_image_input_d,
339
+ dw=base_image_input_d,
340
+ dy=dy,
341
+ dx=dx,
342
+ h=image_base_patch_h,
343
+ w=image_base_patch_w
344
+ )
345
+ _img_mask = einops.rearrange(
346
+ _img_mask, '(dy h dh) (dx w dw) -> (dy dx) (h w) (dh dw)',
347
+ dh=base_image_input_d,
348
+ dw=base_image_input_d,
349
+ dy=dy,
350
+ dx=dx,
351
+ h=image_base_patch_h,
352
+ w=image_base_patch_w
353
+ )
354
+ return _img, tf.reduce_mean(tf.cast(_img_mask, tf.float32), -1)
355
+
356
+ mode = self.crop_mode
357
+ if mode == "resize":
358
+ patches, img_mask = _resize(image, base_image_input_size)
359
+ patches, img_mask = _img_to_patches(patches, img_mask)
360
+ image_layout_impatch_w = 1
361
+ image_layout_impatch_h = 1
362
+ patch_ordering = tf.range(tokens_per_image)[None, :]
363
+
364
+ elif mode in ["overlap", "overlap-and-resize-c2"]:
365
+ original_image_h = tf.shape(image, out_type=tf.int32)[0]
366
+ original_image_w = tf.shape(image, out_type=tf.int32)[1]
367
+ crop_size = base_image_input_size[0]
368
+
369
+ # Discard this many patches from the (left/top, right/bottom) of crops
370
+ left_margin, right_margin = self.overlap_margins
371
+ # left_margin, right_margin = 2, 2
372
+ assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
373
+ total_margin_pixels = base_image_input_d*(right_margin + left_margin) # pixels removed per dim
374
+ crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
375
+ crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
376
+ crop_window_size = crop_window_patches * base_image_input_d
377
+ tiling = select_tiling(original_image_h - total_margin_pixels, original_image_w - total_margin_pixels,
378
+ crop_window_size, self.max_crops)
379
+ src, img_mask = _resize(
380
+ image, [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels])
381
+
382
+ n_crops = tiling[0]*tiling[1]
383
+ patches_arr = tf.TensorArray(
384
+ tf.float32, n_crops, element_shape=[crop_size, crop_size, 3])
385
+ mask_arr = tf.TensorArray(
386
+ tf.bool, n_crops, element_shape=[crop_size, crop_size])
387
+ # We assume 2x2 pooling, but can allow padding the right/bottom with extra
388
+ # patches if the number of patches per side is not even
389
+ assert (crop_patches+1)//2 == image_token_length_h
390
+ assert (crop_patches+1)//2 == image_token_length_w
391
+ patch_ordering_arr = tf.TensorArray(
392
+ tf.int32, n_crops, element_shape=[image_token_length_h, image_token_length_w])
393
+ on = 0
394
+ on_patch = 0
395
+ for i in range(tiling[0]):
396
+ y0 = i*crop_window_size
397
+ if i == 0:
398
+ crop_y0 = 0
399
+ else:
400
+ crop_y0 = left_margin // 2
401
+
402
+ crop_h = image_base_patch_h - (right_margin + left_margin)
403
+ if i == 0:
404
+ crop_h += left_margin
405
+ if i == (tiling[0]-1):
406
+ crop_h += right_margin
407
+ for j in range(tiling[1]):
408
+ x0 = j*crop_window_size
409
+ if j == 0:
410
+ crop_x0 = 0
411
+ else:
412
+ crop_x0 = left_margin // 2
413
+
414
+ crop_w = image_base_patch_w - (right_margin + left_margin)
415
+ if j == 0:
416
+ crop_w += left_margin
417
+ if j == (tiling[1]-1):
418
+ crop_w += right_margin
419
+
420
+ pooled_w = (crop_w + 1) // 2
421
+ pooled_h = (crop_h + 1) // 2
422
+ patch_ordering_arr = patch_ordering_arr.write(
423
+ on_patch,
424
+ pad_to_bounding_box(
425
+ tf.reshape(tf.range(on, on+pooled_h*pooled_w, dtype=tf.int32), (pooled_h, pooled_w, 1)),
426
+ crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
427
+ )[:, :, 0]
428
+ )
429
+ patches_arr = patches_arr.write(on_patch, src[y0:y0+crop_size, x0:x0+crop_size])
430
+ mask_arr = mask_arr.write(on_patch, img_mask[y0:y0+crop_size, x0:x0+crop_size])
431
+
432
+ on += pooled_h*pooled_w
433
+ on_patch += 1
434
+ patches = patches_arr.stack()
435
+ patch_ordering = patch_ordering_arr.stack()
436
+ img_mask = mask_arr.stack()
437
+
438
+ image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
439
+ patches = einops.rearrange(
440
+ patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
441
+ dh=base_image_input_d,
442
+ dw=base_image_input_d,
443
+ h=image_base_patch_h,
444
+ w=image_base_patch_w
445
+ )
446
+ img_mask = einops.rearrange(
447
+ img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
448
+ dh=base_image_input_d,
449
+ dw=base_image_input_d,
450
+ h=image_base_patch_h,
451
+ w=image_base_patch_w
452
+ )
453
+ img_mask = tf.reduce_mean(tf.cast(img_mask, tf.float32), -1)
454
+ patch_ordering = tf.reshape(patch_ordering, [-1])
455
+ valid = patch_ordering >= 0
456
+
457
+ # Transpose, to get left-to-right order
458
+ patch_ordering_rh = tf.reshape(patch_ordering,
459
+ [tiling[0], tiling[1], image_token_length_h, image_token_length_w])
460
+ patch_ordering_rh = tf.transpose(patch_ordering_rh, [0, 2, 1, 3])
461
+ patch_ordering_rh = tf.reshape(patch_ordering_rh, [-1])
462
+
463
+ # The tranpose will screw up which patches are masked, project the
464
+ # new order into sparse structure of `patch_ordering` to fix this
465
+ patch_ordering = tf.tensor_scatter_nd_update(
466
+ patch_ordering,
467
+ tf.where(valid),
468
+ tf.boolean_mask(patch_ordering_rh, patch_ordering_rh >= 0),
469
+ name="patch_order_transpose_Scatter"
470
+ )
471
+
472
+ h = tiling[0]*crop_window_patches + (right_margin+left_margin)
473
+ w = tiling[1]*crop_window_patches + (right_margin+left_margin)
474
+ special_token_ids = self.special_token_ids
475
+ per_row = tf.fill(((w+1)//2,),
476
+ special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
477
+ if self.use_col_tokens:
478
+ per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
479
+
480
+ joint = tf.tile(per_row, [(h+1)//2])
481
+ joint = [
482
+ [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
483
+ joint,
484
+ [special_token_ids[config.DEFAULT_IM_END_TOKEN]]
485
+ ]
486
+
487
+ if "resize" in mode:
488
+ resized, resized_mask = _resize(image, base_image_input_size)
489
+ resized, resized_mask = _img_to_patches(resized, resized_mask)
490
+ if 'c2' in mode:
491
+ patches = tf.concat([resized, patches], 0)
492
+ image_mask = tf.concat([resized_mask, img_mask], 0)
493
+ else:
494
+ patches = tf.concat([patches, resized], 0)
495
+ image_mask = tf.concat([img_mask, resized_mask], 0)
496
+
497
+ if patch_ordering is not None:
498
+ if 'c2' in mode:
499
+ patch_ordering = tf.where(
500
+ patch_ordering >= 0,
501
+ patch_ordering + tokens_per_image,
502
+ -1
503
+ )
504
+ patch_ordering = tf.concat([tf.range(0, tokens_per_image), patch_ordering], 0)
505
+ else:
506
+ raise ValueError()
507
+ per_row = tf.fill((image_token_length_w,), special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
508
+ if self.use_col_tokens:
509
+ per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
510
+ extra_tokens = tf.tile(per_row, [image_token_length_h])
511
+ joint = [
512
+ [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
513
+ extra_tokens,
514
+ [special_token_ids[config.DEFAULT_IM_END_TOKEN]],
515
+ ] + joint
516
+
517
+ joint = tf.concat(joint, 0)
518
+ return patches, joint, patch_ordering, img_mask
519
+
520
+ elif mode in ["patchify", "patchify-and-resize", "patchify-v2", "patchify-v2-and-resize", "patchify-v2-and-resize-c2"]:
521
+ original_image_w = tf.shape(image, out_type=tf.int32)[0]
522
+ original_image_h = tf.shape(image, out_type=tf.int32)[1]
523
+ assert base_image_input_size[0] == base_image_input_size[1]
524
+ base_patch_size = base_image_input_size[0]
525
+ tiling = select_tiling(original_image_w, original_image_h, base_patch_size, self.max_crops)
526
+
527
+ patches, img_mask = _resize(
528
+ image, [tiling[0]*base_patch_size, tiling[1]*base_patch_size])
529
+ patches, img_mask = _img_to_patches(patches, img_mask, tiling[0], tiling[1])
530
+ if 'v2' in mode:
531
+ # Order patches left-to-right not crop-by-crop
532
+ patch_ordering = tf.reshape(
533
+ tf.range(tokens_per_image*tiling[0]*tiling[1]),
534
+ [tiling[0], tiling[1], image_token_length_w, image_token_length_h])
535
+ patch_ordering = tf.transpose(patch_ordering, [0, 2, 1, 3])
536
+ patch_ordering = tf.reshape(patch_ordering, (-1, tokens_per_image))
537
+ else:
538
+ patch_ordering = None
539
+
540
+ # given image size, determine the number of patch size.
541
+ image_layout_impatch_w = tiling[0]
542
+ image_layout_impatch_h = tiling[1]
543
+
544
+ if "resize" in mode:
545
+ extra_image = True
546
+ resized, resized_mask = _resize(image, base_image_input_size)
547
+ resized, resized_mask = _img_to_patches(resized, resized_mask)
548
+ if 'c2' in mode:
549
+ patches = tf.concat([resized, patches], 0)
550
+ image_mask = tf.concat([resized_mask, img_mask], 0)
551
+ else:
552
+ patches = tf.concat([patches, resized], 0)
553
+ image_mask = tf.concat([img_mask, resized_mask], 0)
554
+
555
+ if patch_ordering is not None:
556
+ if 'c2' in mode:
557
+ patch_ordering = tf.concat(
558
+ [tf.range(0, tokens_per_image)[None, :], patch_ordering+tokens_per_image], 0)
559
+ else:
560
+ n = tf.shape(patch_ordering)[0]
561
+ patch_ordering = tf.concat(patch_ordering, [tf.range(n, n+tokens_per_image)[None, :]], 0)
562
+ else:
563
+ raise NotImplementedError(mode)
564
+
565
+ special_token_ids = self.special_token_ids
566
+
567
+ per_row = tf.fill((image_token_length_w*image_layout_impatch_w,),
568
+ special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
569
+ if self.use_col_tokens:
570
+ per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
571
+
572
+ joint = tf.tile(per_row, [image_token_length_h * image_layout_impatch_h])
573
+ joint = [
574
+ [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
575
+ joint,
576
+ [special_token_ids[config.DEFAULT_IM_END_TOKEN]]
577
+ ]
578
+ if extra_image:
579
+ assert not self.image_padding_mask
580
+ per_row = tf.fill((image_token_length_w,), special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
581
+ if self.use_col_tokens:
582
+ per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
583
+ extra_tokens = tf.tile(per_row, [image_token_length_h])
584
+ if 'c2' in mode:
585
+ joint = [
586
+ [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
587
+ extra_tokens,
588
+ [special_token_ids[config.DEFAULT_IM_END_TOKEN]],
589
+ ] + joint
590
+ else:
591
+ joint += [
592
+ [special_token_ids[config.DEFAULT_IM_START_TOKEN]],
593
+ extra_tokens,
594
+ [special_token_ids[config.DEFAULT_IM_END_TOKEN]]
595
+ ]
596
+ if self.pad_to is not None:
597
+ n = [tf.shape(x)[0] for x in joint]
598
+ assert len(joint[-1]) == 1
599
+ to_pad = self.pad_to - tf.reduce_sum(tf.stack(n))
600
+ joint = tf.concat(joint[:-1] + [
601
+ tf.zeros(to_pad, dtype=tf.int32) - 1,
602
+ joint[-1]
603
+ ], axis=0)
604
+ else:
605
+ joint = tf.concat(joint, 0)
606
+ return patches, tf.concat(joint, 0), patch_ordering, img_mask
607
+
608
+ def build_image_input_idx(self, input_tokens, patch_order, no_image=None):
609
+ """Builds the index used to insert patch features into `input_tokens`"""
610
+ tokens_per_image = self.image_token_length_w * self.image_token_length_h
611
+ if no_image is not None and no_image:
612
+ return tf.zeros((0, tokens_per_image), tf.int32)
613
+
614
+ image_input_idx = input_tokens == self.special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN]
615
+ image_input_idx = tf.experimental.numpy.nonzero(image_input_idx)[0]
616
+ image_input_idx = tf.cast(image_input_idx, tf.int32)
617
+
618
+ if patch_order is not None:
619
+ n_tokens = tf.shape(image_input_idx)[0]
620
+ # Item N should have the value of image_input_index[where(patch_order == n)] if >= 0 else -1
621
+ patch_order = tf.reshape(patch_order, [-1])
622
+ n_patches = tf.shape(patch_order)[0]
623
+ if n_tokens != n_patches:
624
+ # Most complex case where some patches are dropped
625
+ # First invert the valid tokens
626
+ valid = patch_order >= 0
627
+ sorted_patch_ixs = tf.scatter_nd(
628
+ tf.boolean_mask(patch_order, valid)[:, None],
629
+ tf.range(tf.reduce_sum(tf.cast(valid, tf.int32)), dtype=tf.int32),
630
+ [n_tokens],
631
+ name="valid_order_scatter"
632
+ )
633
+
634
+ # Project the inverted mapping into same sparse structure
635
+ tmp = tf.fill(tf.shape(patch_order), -1)
636
+ sorted_patch_ixs_ex = tf.tensor_scatter_nd_update(
637
+ tmp,
638
+ tf.where(valid),
639
+ sorted_patch_ixs,
640
+ name="order_with_padding_scatter"
641
+ )
642
+
643
+ # Do the gather and then re-masked outputs that were masked in `sorted_patch_ixs`
644
+ valid = tf.cast(sorted_patch_ixs_ex >= 0, tf.int32)
645
+ image_input_idx = tf.gather(image_input_idx, sorted_patch_ixs_ex*valid)
646
+ image_input_idx = image_input_idx*valid - 100*(1 - valid)
647
+ else:
648
+ sorted_patch_ixs = tf.scatter_nd(patch_order[:, None], tf.range(n_patches), [n_patches])
649
+ image_input_idx = tf.gather(tf.reshape(image_input_idx, [-1]), sorted_patch_ixs)
650
+ image_input_idx = tf.reshape(image_input_idx, [-1, tokens_per_image])
651
+ return image_input_idx
652
+
653
+ def build_multimodel_features(self, tokens, mask, subsegments, images, is_training):
654
+ """Builds input features by pre-processing `images` and modifying `tokens`
655
+ to include image col/pad/start/end tokens instead image placeholder tokens
656
+ """
657
+ image_token_id = self.special_token_ids[config.IMAGE_PROMPT]
658
+ image_idx = tf.experimental.numpy.nonzero(tokens == image_token_id)[0]
659
+ if images is None or tf.shape(images)[0] == 0:
660
+ tf.debugging.assert_equal(image_idx, tf.cast(0, tf.int64),
661
+ "Image placeholders in input, but no images given!")
662
+ tokens_per_image = self.image_token_length_w * self.image_token_length_h
663
+ n_pixels = self.image_patch_size ** 2 * 3
664
+ image_num_patch = np.prod(self.image_num_patch)
665
+ crops = tf.zeros((0, image_num_patch, n_pixels), dtype=tf.float32)
666
+ image_idx = tf.zeros((0, tokens_per_image), tf.int32)
667
+ out = dict(
668
+ target_tokens=tokens,
669
+ images=crops,
670
+ image_input_idx=image_idx,
671
+ loss_masks=mask
672
+ )
673
+ if self.image_padding_mask:
674
+ out["image_masks"] = tf.zeros((0, image_num_patch), dtype=tf.float32)
675
+ if subsegments is not None:
676
+ out["subsegment_ids"] = subsegments
677
+ return out
678
+ elif tf.shape(image_idx)[0] == 0 and tf.shape(images)[0] > 0:
679
+ # As a special case, no image prompt means the images are all at the start
680
+ image_idx = tf.zeros([tf.shape(images)[0]], tf.int64) - 1
681
+ else:
682
+ tf.debugging.assert_equal(
683
+ tf.shape(images)[0], tf.shape(image_idx)[0],
684
+ message="Different number of images and image placeholders")
685
+
686
+ # Each image will produce a variable number of crops/tokens, so we aggregate things
687
+ # the results tensor arrays and the concat them
688
+ tokens_per_image = self.image_token_length_w * self.image_token_length_h
689
+ n_pixels = self.image_patch_size*self.image_patch_size*3
690
+ n_patches = self.image_num_patch[0]*self.image_num_patch[1]
691
+
692
+ n = tf.shape(images)[0]
693
+ all_crops = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
694
+ element_shape=[None, n_patches, n_pixels])
695
+ all_image_idx = tf.TensorArray(dtype=tf.int32, size=n, infer_shape=False,
696
+ element_shape=[None, tokens_per_image])
697
+ out_tokens = tf.TensorArray(dtype=tf.int32, size=n, infer_shape=False,
698
+ element_shape=[None])
699
+ out_masks = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
700
+ element_shape=[None])
701
+ if self.image_padding_mask:
702
+ all_crop_masks = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
703
+ element_shape=[None, None])
704
+ else:
705
+ # Dummy array to keep tensorflow's control analysis happy
706
+ all_crop_masks = tf.TensorArray(dtype=tf.float32, size=0, infer_shape=False,
707
+ element_shape=[None, None])
708
+ if subsegments is not None:
709
+ out_subsegments = tf.TensorArray(dtype=tf.int32, size=n, element_shape=[None])
710
+ else:
711
+ out_subsegments = tf.TensorArray(dtype=tf.int32, size=0, element_shape=[None])
712
+
713
+ image_idx = tf.cast(image_idx, tf.int32)
714
+ for ix in range(tf.shape(image_idx)[0]):
715
+ token_ix = image_idx[ix]
716
+ crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(images[ix], is_training)
717
+ patch_idx = self.build_image_input_idx(image_tokens, patch_ordering)
718
+
719
+ if token_ix == -1: # -1 is an image inserted at the very start
720
+ start = 0
721
+ token_ix = 0
722
+ end = 0
723
+ else:
724
+ start = 0 if ix == 0 else image_idx[ix-1] + 1
725
+ end = token_ix + 1
726
+
727
+ all_image_idx = all_image_idx.write(ix, patch_idx + token_ix)
728
+ all_crops = all_crops.write(ix, crops)
729
+ image_token_mask = tf.zeros_like(image_tokens, dtype=tf.float32)
730
+
731
+ if ix == (tf.shape(images)[0] - 1):
732
+ tokens_part = tf.concat([tokens[start:token_ix], image_tokens, tokens[end:]], 0)
733
+ mask_part = tf.concat([mask[start:token_ix], image_token_mask, mask[end:]], 0)
734
+ else:
735
+ tokens_part = tf.concat([tokens[start:token_ix], image_tokens], 0)
736
+ mask_part = tf.concat([mask[start:token_ix], image_token_mask], 0)
737
+
738
+ out_tokens = out_tokens.write(ix, tokens_part)
739
+ out_masks = out_masks.write(ix, mask_part)
740
+ if self.image_padding_mask:
741
+ all_crop_masks = all_crop_masks.write(ix, img_mask)
742
+ if subsegments is not None:
743
+ parts = tf.fill([tf.shape(image_tokens)[0]], subsegments[token_ix])
744
+ if ix == (tf.shape(images)[0] - 1):
745
+ seg = tf.concat([subsegments[start:token_ix], parts, subsegments[end:]], 0)
746
+ else:
747
+ seg = tf.concat([subsegments[start:token_ix], parts], 0)
748
+ out_subsegments = out_subsegments.write(ix, seg)
749
+
750
+ out = dict(
751
+ target_tokens=out_tokens.concat(),
752
+ images=all_crops.concat(),
753
+ image_input_idx=all_image_idx.concat(),
754
+ loss_masks=out_masks.concat()
755
+ )
756
+ if self.image_padding_mask:
757
+ out["image_masks"] = all_crop_masks.concat()
758
+ if subsegments is not None:
759
+ out["subsegment_ids"] = out_subsegments.concat()
760
+ return out
761
+
762
+ def _format_message(self, args):
763
+ message, ix = args
764
+ return self.format_message(message, ix)
765
+
766
+ def format_message(self, message, ix):
767
+ """Applies system formatting to ith message from a sequence of messages"""
768
+ # If the image placeholder text is not preceded by space it will not get tokenized
769
+ # correctly by some tokenizers, so double check it here
770
+ assert config.IMAGE_PROMPT == "<|image|>"
771
+ tf.debugging.assert_equal(
772
+ tf.strings.regex_full_match(message, r".*[^ ]<\|image\|>.*"),
773
+ False,
774
+ message="Image token must always be preceded by a space"
775
+ )
776
+ is_user = ix % 2 == 0
777
+ if self.message_format == "none" or self.message_format is None:
778
+ pass
779
+ elif self.message_format == "role":
780
+ if is_user:
781
+ # We put the "System:" prefix here since it doesn't need a loss
782
+ message = tf.strings.join(["User: ", message, " Assistant:"])
783
+ elif self.message_format == "cleanup":
784
+ if is_user:
785
+ # We put the "System:" prefix here since it doesn't need a loss
786
+ message = tf.strings.join(
787
+ [
788
+ "[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} ",
789
+ message,
790
+ "\n[[Assistant]]: {after}"
791
+ ]
792
+ )
793
+ elif self.message_format == "mistral":
794
+ if is_user:
795
+ message = tf.strings.join(["[INST] ", message, " [/INST]"])
796
+ else:
797
+ raise NotImplementedError(self.message_format)
798
+
799
+ # For now assume a space will be used to separate the messages
800
+ if not self.tokenizer.adds_space:
801
+ if ix != 0 or self.always_start_with_space:
802
+ message = tf.strings.join([" ", message])
803
+ # Else space added automatically by the tokenizer
804
+
805
+ return message
806
+
807
+ def get_multi_message_token_input(self, conversations, text_weights=None):
808
+ """Build inputs for a ragged tensor of conversations, where each row of the tensor,
809
+ is a different conversation"""
810
+ tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(
811
+ conversations.values, re.escape(config.IMAGE_PROMPT))), False, "Segmented prompts must start with the image")
812
+
813
+ n_conversation = tf.shape(conversations)[0]
814
+ ar = tf.TensorArray(dtype=tf.int32, infer_shape=False, element_shape=[None],
815
+ size=n_conversation)
816
+ n_messages_per_conversation = conversations.row_lengths()
817
+ for ix in range(n_conversation):
818
+ ar = ar.write(ix, tf.range(n_messages_per_conversation[ix], dtype=tf.int32))
819
+ message_ix = ar.concat()
820
+ messages = tf.map_fn(
821
+ self._format_message, elems=(conversations.values, message_ix), fn_output_signature=tf.string)
822
+ messages = self.tokenizer.encode_tf(messages)
823
+
824
+ # Append EOS
825
+ is_response = message_ix % 2 == 1
826
+ is_response_int = tf.cast(is_response, tf.int32)
827
+ eos = tf.RaggedTensor.from_row_lengths(
828
+ tf.fill([tf.reduce_sum(is_response_int)], self.tokenizer.eos_token_id),
829
+ tf.cast(is_response_int, messages.row_splits.dtype)
830
+ )
831
+ messages = tf.concat([messages, eos], axis=1)
832
+
833
+ # Build mask over system responses
834
+ mask = tf.ones_like(messages) * tf.cast(tf.expand_dims(is_response, axis=1), tf.int32)
835
+ decoder_loss_weights = tf.cast(mask.values, tf.float32)
836
+
837
+ # Build subsegment ids for each conversation
838
+ tokens_per_message = tf.RaggedTensor.from_row_splits(
839
+ row_splits=conversations.row_splits,
840
+ values=messages.row_lengths()
841
+ )
842
+ token_per_conversation = tf.reduce_sum(tokens_per_message, axis=1)
843
+ subsegment_ids = tf.repeat(tf.range(n_conversation, dtype=tf.int32)+1, token_per_conversation)
844
+
845
+ image_ix = self.special_token_ids[config.IMAGE_PROMPT]
846
+ messages = tf.concat([[image_ix], messages.values], axis=0)
847
+ decoder_loss_weights = tf.concat([[0], decoder_loss_weights], axis=0)
848
+ subsegment_ids = tf.concat([[10000], subsegment_ids], axis=0)
849
+ return messages, decoder_loss_weights, subsegment_ids
850
+
851
+ def get_multi_response_token_input(self, user_prompt, text, text_weights=None):
852
+ """Build tokens for a multi-response-per-image example"""
853
+ # FIXME this could be relaxed to just having the same prefix
854
+ tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(
855
+ user_prompt, re.escape(config.IMAGE_PROMPT))), False, "Segmented prompts must start with the image")
856
+ user_prompt = self.format_message(user_prompt, 0)
857
+ vocab = self.tokenizer
858
+ prompts = vocab.encode_tf(user_prompt)
859
+ response = self.format_message(text, 1)
860
+ responses = vocab.encode_tf(response)
861
+ responses = _append_to_innermost_axis(responses, vocab.eos_token_id)
862
+ response_mask = tf.ones_like(responses, dtype=tf.float32)
863
+ if text_weights is not None:
864
+ response_mask *= text_weights
865
+ image_tokens = tf.constant([self.special_token_ids[config.IMAGE_PROMPT]])
866
+
867
+ if len(responses.shape) == 3:
868
+ # Tricky case where we have multiple questions, each of which has multiple answers
869
+ assert len(prompts.shape) == 2
870
+
871
+ # Also shift the last tokens to the response segment since that tokens will
872
+ # have multiple possible target tokens to predict
873
+ last_prompt_tokens = prompts[:, -1:]
874
+ last_prompt_tokens = tf.repeat(last_prompt_tokens, responses.row_lengths())
875
+ last_prompt_tokens = tf.RaggedTensor.from_row_splits(
876
+ values=tf.RaggedTensor.from_row_lengths(
877
+ values=last_prompt_tokens,
878
+ row_lengths=tf.ones_like(last_prompt_tokens, dtype=responses.row_splits.dtype)
879
+ ),
880
+ row_splits=responses.row_splits
881
+ )
882
+ responses = tf.concat([last_prompt_tokens, responses], 2)
883
+ prompts = prompts[:, :-1]
884
+
885
+ shared_prefix = image_tokens
886
+ segmented_suffix = tf.concat([tf.expand_dims(prompts, 1), responses], 1)
887
+ targets = tf.concat([shared_prefix, segmented_suffix.values.values], 0)
888
+
889
+ segmented_mask = tf.concat([
890
+ tf.zeros_like(tf.expand_dims(prompts, 1), dtype=tf.float32),
891
+ tf.concat([
892
+ tf.zeros_like(last_prompt_tokens, dtype=tf.float32),
893
+ response_mask
894
+ ], 2)
895
+ ], 1).values.values
896
+ decoder_loss_weights = tf.concat(
897
+ [tf.zeros_like(shared_prefix, dtype=tf.float32), segmented_mask], 0)
898
+
899
+ text_segment_ids = get_3d_subsegments(segmented_suffix)
900
+ subsegment_ids = tf.concat([
901
+ tf.zeros_like(shared_prefix) + tf.reduce_max(text_segment_ids)+1,
902
+ text_segment_ids], 0)
903
+ subsegment_ids = tf.cast(subsegment_ids, tf.int32)
904
+ else:
905
+ if len(prompts.shape) == 1:
906
+ # One prompt for all responses, we use the last token of the prompt as the
907
+ # first token of each response segment since there will be multiple targets
908
+ # for that token, the remaining targets are part of the prefix
909
+ shared_prefix = tf.concat([image_tokens, prompts[:-1]], 0)
910
+ prompts = prompts[-1:]
911
+ prompts = tf.tile(tf.expand_dims(prompts, axis=0), [tf.shape(text)[0], 1])
912
+ else:
913
+ shared_prefix = image_tokens
914
+
915
+ # Separate prompt for each response
916
+ segmented_suffix = tf.concat([prompts, responses], 1)
917
+ segmented_mask = tf.concat([tf.zeros_like(prompts, dtype=tf.float32), response_mask], 1).values
918
+
919
+ targets = tf.concat([shared_prefix, segmented_suffix.values], 0)
920
+ decoder_loss_weights = tf.concat(
921
+ [tf.zeros_like(shared_prefix, dtype=tf.float32), segmented_mask], 0)
922
+ subsegments = tf.ragged.row_splits_to_segment_ids(segmented_suffix.row_splits) + 1
923
+ subsegment_ids = tf.concat([tf.zeros_like(shared_prefix)+10000,
924
+ tf.cast(subsegments, tf.int32)], 0)
925
+ return targets, decoder_loss_weights, subsegment_ids
926
+
927
+ def get_tokens_input(self, messages, for_inference=False, text_weights=None):
928
+ """Gets the token input for an example, using image placeholder tokens to
929
+ indicate where images features should be inserted
930
+
931
+ inputs
932
+ messages: List or tensor users/system text messages, can have image placeholder tokens
933
+ for_inference: bool, if true truncate the messages if it is a system message
934
+ text_weights: Weights per a system message
935
+
936
+ returns
937
+ tokens: [n_tokens] tf.int32 token inputs with image placeholder tokens
938
+ loss_mask: [n_tokens] tf.float32 token weights for loss
939
+ subsegment: [n_tokens] tf.int32 or None, subsegment ids used to build more complex
940
+ attention masks if needed
941
+ """
942
+ if isinstance(messages, tf.RaggedTensor):
943
+ assert not for_inference, "Cannot have multiple target messages for inference"
944
+ return self.get_multi_message_token_input(messages, text_weights)
945
+ elif len(tf.shape(messages[-1])) > 0:
946
+ assert not for_inference, "Cannot have multiple target messages for inference"
947
+ assert len(messages) == 2
948
+ prompt = messages[0]
949
+ response = messages[1]
950
+ return self.get_multi_response_token_input(prompt, response, text_weights)
951
+ else:
952
+ messages = tf.convert_to_tensor(messages)
953
+ if for_inference:
954
+ if tf.shape(messages) % 2 == 0:
955
+ # Remove the last message since the model should predict it
956
+ messages = messages[:-1]
957
+
958
+ # Apply system formatting
959
+ ix = tf.range(tf.shape(messages)[0])
960
+ is_response = ix % 2 == 1
961
+ messages = tf.map_fn(
962
+ self._format_message, elems=(messages, ix), fn_output_signature=tf.string)
963
+
964
+ # Tokenize
965
+ messages = self.tokenizer.encode_tf(messages)
966
+
967
+ # Add EOS to system messages
968
+ is_response_int = tf.cast(is_response, tf.int32)
969
+ eos = tf.RaggedTensor.from_row_lengths(
970
+ tf.fill([tf.reduce_sum(is_response_int)], self.tokenizer.eos_token_id),
971
+ tf.cast(is_response_int, messages.row_splits.dtype)
972
+ )
973
+ messages = tf.concat([messages, eos], axis=1)
974
+ targets = messages.values
975
+
976
+ # Build mask over system responses
977
+ mask = tf.ones_like(messages) * tf.cast(tf.expand_dims(is_response, axis=1), tf.int32)
978
+ decoder_loss_weights = tf.cast(mask.values, tf.float32)
979
+ if text_weights is not None:
980
+ decoder_loss_weights = decoder_loss_weights * text_weights
981
+ return messages.values, decoder_loss_weights, None
982
+
983
+ def preprocess(self, image, input_text, is_training=False,
984
+ seq_len=None, pad_images=1, style=None, for_inference=True):
985
+ """Get input tensors for the given image/text data
986
+
987
+ image: [h, w, 3] numpy uint8 array of image pixels
988
+ input_text: string input text, a list of text for a multi-turn conversation or dictionary
989
+ of inputs to use to build the prompt from a template
990
+ is_training: allow training-time preprocessing (e.g., image augmentation)
991
+ seq_len: pad input tokens to `seq_len`
992
+ pad_images: pad input images to `self.get_max_total_crops()`
993
+ style: Style to use for prompt templating
994
+ """
995
+ if image is not None and len(tf.shape(image)) == 3:
996
+ image = tf.expand_dims(image, axis=0)
997
+
998
+ messages = self.get_messages(input_text, style, is_training, for_inference=for_inference, user_prompt_seed=None, system_prompt_seed=None)
999
+ targets, loss_masks, subsegments = self.get_tokens_input(messages, for_inference=for_inference)
1000
+ batch = self.build_multimodel_features(
1001
+ targets, loss_masks, subsegments, image, is_training)
1002
+
1003
+ # Optionally padding to get constant sized arrays
1004
+ if pad_images:
1005
+ max_crops = self.get_max_total_crops() * pad_images
1006
+ image = batch["images"]
1007
+ n = max_crops - tf.shape(batch["images"])[0]
1008
+ batch["images"] = tf.pad(image, [[0, n], [0, 0], [0, 0]], constant_values=-1)
1009
+ if self.image_padding_mask:
1010
+ m = max_crops - tf.shape(batch["image_masks"])[0]
1011
+ batch["image_masks"] = tf.pad(batch["image_masks"], [[0, m], [0, 0]], constant_values=-1)
1012
+ batch["image_input_idx"] = tf.pad(batch["image_input_idx"], [[0, n], [0, 0]], constant_values=-1)
1013
+
1014
+ if seq_len is not None:
1015
+ targets = batch["target_tokens"]
1016
+ if seq_len < len(targets):
1017
+ raise ValueError("Sequence length too short")
1018
+ n = seq_len - len(targets)
1019
+ batch["target_tokens"] = tf.pad(targets, [[0, n]], constant_values=-1)
1020
+ batch["loss_masks"] = tf.pad(batch["loss_masks"], [[0, n]], constant_values=-1)
1021
+
1022
+ batch = self.get_post_mixing_preprocessor(pack=False)._convert_example(batch)
1023
+ return batch
1024
+
1025
+ def get_user_prompt(self, style, example, is_training=True, for_inference=False, seed=None):
1026
+ """Build a list of strings of what a user might type in to the model for the given example,
1027
+ and its responses, by applying a prompt template to the fields in `example`
1028
+
1029
+ Can return multiple strings for one message for multi-response examples
1030
+ """
1031
+ if "style" in example:
1032
+ style = example["style"]
1033
+
1034
+ if "prompt" in example:
1035
+ # Examples have a complete user prompt pre-specified, usually for eval sets
1036
+ prompt = example["prompt"]
1037
+
1038
+ elif self.prompt_templates == "none":
1039
+ # Bare-bone prompt with not templating of instructions
1040
+ if "prompt" in example:
1041
+ prompt = example["prompt"]
1042
+ elif "refexp" in example:
1043
+ prompt = example["refexp"]
1044
+ elif "question" in example and "options" in example:
1045
+ prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1046
+ elif "question" in example:
1047
+ prompt = example["question"]
1048
+ else:
1049
+ prompt = ""
1050
+
1051
+ elif self.prompt_templates == "uber_model":
1052
+ if not isinstance(style, str):
1053
+ tf.debugging.assert_equal(tf.logical_or(
1054
+ style == "ai2_diagram_no_letter",
1055
+ style == "ai2_diagram",
1056
+ ), True)
1057
+ prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1058
+ else:
1059
+ # We template long captions and pointing since they are "demo" tasks, and use
1060
+ # plain text for everything else
1061
+ if style == "long_caption":
1062
+ prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], example, seed)
1063
+ elif style == "pointing":
1064
+ prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["pointing"], example, seed)
1065
+ elif style == "point_count":
1066
+ prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["point_count"], example, seed)
1067
+ elif "prompt" in example:
1068
+ prompt = example["prompt"]
1069
+ elif "refexp" in example:
1070
+ prompt = example["refexp"]
1071
+ elif "question" in example and "options" in example:
1072
+ prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1073
+ elif "question" in example:
1074
+ prompt = example["question"]
1075
+ else:
1076
+ prompt = ""
1077
+
1078
+ elif self.prompt_templates == "uber_model_pointing":
1079
+ if style == "long_caption":
1080
+ long_captions = GENERAL_PROMPTS_V1["long_caption_no_pointing"]
1081
+ prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], example, seed)
1082
+ elif style == "pointing":
1083
+ prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["pointing"], example, seed)
1084
+ elif style in [
1085
+ "scifi_charts_explanation",
1086
+ "scifi_table_explanation",
1087
+ "scifi_document_explanation",
1088
+ "scifi_diagram_explanation",
1089
+ "user_qa",
1090
+ "long_caption",
1091
+ ]:
1092
+ raise NotImplementedError()
1093
+ if style == "long_caption":
1094
+ prompts = GENERAL_PROMPTS_V1["long_caption"]
1095
+ elif "prompt" in example:
1096
+ prompts = tf.expand_dims(example["prompt"], axis=0)
1097
+ else:
1098
+ prompts = tf.expand_dims(example["question"], axis=0)
1099
+ suffixes = []
1100
+ for suffix in GENERAL_PROMPTS_V1["no_pointing_suffix"]:
1101
+ if not suffix[0].isspace():
1102
+ suffix = " " + suffix
1103
+ suffixes.append(suffix)
1104
+ no_point_prompts = tf.reshape(tf.strings.join([
1105
+ tf.tile(tf.expand_dims(suffixes, 1), [1, tf.shape(prompts)[1]]),
1106
+ tf.tile(prompts, [len(suffixes), 1]),
1107
+ ]), [-1])
1108
+ # prefixes = []
1109
+ # for prefix in GENERAL_PROMPTS_V1["no_pointing_prefix"]:
1110
+ # if not prefix[0].isspace():
1111
+ # prefix = prefix + " "
1112
+ # prefixes.append(prompts + prefix)
1113
+ prompt = apply_keyword_prompt(no_point_prompts, example, seed, keywords=[])
1114
+ elif "prompt" in example:
1115
+ prompt = example["prompt"]
1116
+ elif "refexp" in example:
1117
+ prompt = example["refexp"]
1118
+ elif "question" in example and "options" in example:
1119
+ prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
1120
+ elif "question" in example:
1121
+ prompt = example["question"]
1122
+ else:
1123
+ prompt = ""
1124
+
1125
+ elif self.prompt_templates == "general_instructions_v1":
1126
+ if isinstance(style, str):
1127
+ prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1[STYLE_TO_GENERAL_PROMPT[style]], example, seed)
1128
+ elif isinstance(style, list):
1129
+ # This ia bit of hack to allow apply prompts to joint caption/transcript data
1130
+ # FIXME ideally we can apply the templating to multiple styles more generally
1131
+ def _apply(_style, ix):
1132
+ tmp = dict(example)
1133
+ # prevent apply_keyword_prompt for generating multiple templates
1134
+ tmp["text"] = tmp["text"][0]
1135
+ if _style == "long_caption":
1136
+ return apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], tmp, seed)
1137
+ elif _style == "transcript":
1138
+ return apply_keyword_prompt(GENERAL_PROMPTS_V1["transcript"], tmp, seed)
1139
+ else:
1140
+ raise NotImplementedError(_style)
1141
+ prompt = [_apply(x, ix) for ix, x in enumerate(style)]
1142
+ else:
1143
+ raise NotImplementedError()
1144
+
1145
+ elif self.prompt_templates == "zero_shot_v1":
1146
+ assert style is not None
1147
+ if not isinstance(style, str):
1148
+ # FIXME can we handle tensor style's in a better way?
1149
+ if style == "ai2_diagram":
1150
+ prompt = "Question: {question}\nAnswer with correct answer option letter only\nOptions: {options}\nAnswer:"
1151
+ prompt = apply_keyword_prompt([prompt], example, seed)
1152
+ elif style == "ai2_diagram_no_letter":
1153
+ prompt = "Question: {question}\nAnswer with correct answer option only\nOptions: {options}\nAnswer:"
1154
+ prompt = apply_keyword_prompt([prompt], example, seed)
1155
+ else:
1156
+ prompt = ""
1157
+ tf.debugging.assert_equal(prompt != "", True)
1158
+ else:
1159
+ general_style = STYLE_TO_GENERAL_PROMPT[style]
1160
+ if general_style == "short_answer":
1161
+ prompt = apply_keyword_prompt(["Question: {question} Answer with as few words as possible. Answer:"], example, seed)
1162
+ elif general_style == "multiple_choice":
1163
+ prompt = apply_keyword_prompt(["Question: {question}\nAnswer with correct answer option letter only\nOptions: {options}\nAnswer:"], example, seed)
1164
+ elif general_style == "count_bench":
1165
+ prompt = apply_keyword_prompt(["Question: How many {object} are there?\nRespond with only a number.\nAnswer:"], example, seed)
1166
+ else:
1167
+ raise NotImplementedError(general_style)
1168
+
1169
+ elif self.prompt_templates == "zero_shot_v2":
1170
+ assert style is not None
1171
+
1172
+ if self.prompt_override:
1173
+ prompt = apply_keyword_prompt([self.prompt_override], example, seed)
1174
+ elif not isinstance(style, str):
1175
+ if style == "ai2_diagram":
1176
+ prompt = "{question} Answer with correct answer option letter only. Options: {options}"
1177
+ prompt = apply_keyword_prompt([prompt], example, seed)
1178
+ elif style == "ai2_diagram_no_letter":
1179
+ prompt = "{question} Answer with correct answer option only. Options: {options}"
1180
+ prompt = apply_keyword_prompt([prompt], example, seed)
1181
+ else:
1182
+ prompt = ""
1183
+ tf.debugging.assert_equal(prompt != "", True)
1184
+ else:
1185
+ if style in ["vqa2", "gqa", "tally_qa", "okvqa", "a_okvqa_da"]:
1186
+ prompt = "Answer with a single word. {question}"
1187
+ elif style in ["text_vqa", "doc_qa", "info_qa", "chart_qa", "st_qa", "ocr_vqa", "dv_qa", "tabwmp_da", "figure_qa", "figure_qa_zero_shot", "plot_qa"]:
1188
+ prompt = "{question}\nRespond as concisely as possible, do not output anything other than the answer."
1189
+ elif STYLE_TO_GENERAL_PROMPT[style] == "multiple_choice":
1190
+ prompt = "{question} Answer with correct answer option letter only. Options: {options}"
1191
+ elif STYLE_TO_GENERAL_PROMPT[style] == "short_answer":
1192
+ prompt = "{question} Answer with as few words as possible."
1193
+ elif style == "vtabfact":
1194
+ prompt = "{question}"
1195
+ elif style == "count_bench":
1196
+ prompt = "How many {object} are there?\nRespond with only a number."
1197
+ else:
1198
+ raise NotImplementedError(style)
1199
+ prompt = apply_keyword_prompt([prompt], example, seed)
1200
+ else:
1201
+ raise NotImplementedError(self.prompt_templates)
1202
+
1203
+ if for_inference:
1204
+ return [prompt]
1205
+ else:
1206
+ return [prompt, example["text"]]
1207
+
1208
+ def get_system_prompt(self, style, example, for_inference,
1209
+ messages, seed=None):
1210
+ if isinstance(style, str) and style == "count_bench":
1211
+ style = "ok_vqa"
1212
+
1213
+ if self.system_prompt == "style":
1214
+ if isinstance(style, str):
1215
+ prefix = style + ":"
1216
+ else:
1217
+ prefix = tf.strings.join([style, ":"])
1218
+
1219
+ elif self.system_prompt == "demo_or_style":
1220
+ if isinstance(style, str):
1221
+ if style == "android_control" or style == "demo":
1222
+ # android is a special case since I hacked in prefix in the preprocessor
1223
+ prefix = ""
1224
+ elif style in ["scifi_demo", "synthetic_qa"] or style in DEMO_STYLES:
1225
+ if style == "scifi_demo":
1226
+ p_no_prompt = 0.2
1227
+ elif style == "synthetic_qa":
1228
+ p_no_prompt = 0.25
1229
+ else:
1230
+ p_no_prompt = 0.9
1231
+ if len(tf.shape(messages)) > 1:
1232
+ n_messages = tf.shape(messages)[1]
1233
+ style = tf.tile(tf.expand_dims(style, axis=0), [n_messages])
1234
+ r = tf.random.stateless_uniform([n_messages], seed, 0, 1)
1235
+ else:
1236
+ r = tf.random.stateless_uniform((), seed, 0, 1)
1237
+ prefix = tf.where(r < p_no_prompt, "", tf.strings.join([style + ":"]))
1238
+ else:
1239
+ prefix = style + ":"
1240
+ else:
1241
+ if tf.reduce_any(style == tf.constant(DEMO_STYLES + ["scifi_demo", "android_control", "demo"])):
1242
+ prefix = ""
1243
+ else:
1244
+ prefix = tf.strings.join([style, ":"])
1245
+
1246
+ elif self.system_prompt in ["long_caption_length_hint", "style_long_caption_length_hint"]:
1247
+ if seed is not None:
1248
+ raise NotImplementedError("Determinism")
1249
+ std = 25
1250
+ use_hint = tf.logical_or(
1251
+ tf.equal(style, "long_caption"), tf.equal(style, "transcript"))
1252
+ if self.system_prompt == "style_long_caption_length_hint":
1253
+ default = tf.strings.join([style, ": "])
1254
+ else:
1255
+ default = ""
1256
+ if for_inference:
1257
+ assert len(tf.shape(use_hint)) == 0
1258
+ if self.default_inference_len and use_hint:
1259
+ prefix = tf.strings.join([style, " ", str(self.default_inference_len), ": "])
1260
+ else:
1261
+ prefix = default
1262
+ else:
1263
+ std = 25
1264
+ n = tf.strings.length(messages[-1])
1265
+ n += tf.cast(tf.random.normal(n.shape)*std, tf.int32)
1266
+ hint = tf.strings.join([style, " ", tf.strings.as_string(n//15), ": "])
1267
+ use_hint = tf.logical_and(use_hint, tf.random.uniform(tf.shape(hint)) > 0.1)
1268
+ prefix = tf.where(use_hint, hint, default)
1269
+
1270
+ elif for_inference and self.system_prompt in ["style_and_length", "style_and_length_v2"]:
1271
+ v2 = self.system_prompt == "style_and_length_v2"
1272
+ if example.get("length_cond") is not None:
1273
+ # Examples have individual length conditioning
1274
+ n = tf.strings.as_string(example["length_cond"])
1275
+ else:
1276
+ inference_len = self.default_inference_len
1277
+ n = None if inference_len is None else str(inference_len)
1278
+ logging.warning(f"eval len: {n}")
1279
+ if n is not None and tf.strings.length(n) > 0: # allow empty string to signal unconditioned
1280
+ prefix = tf.strings.join([style, " ", n, ":"])
1281
+ else:
1282
+ prefix = tf.strings.join([style, ":" if v2 else " :"])
1283
+ elif self.system_prompt in ["style_and_length", "style_and_length_v2"]:
1284
+ v2 = self.system_prompt == "style_and_length_v2"
1285
+ std = 25
1286
+ logging.info(f"style prompt std={std}, percent=10")
1287
+ if seed is not None:
1288
+ seeds = tf.random.split(seed)
1289
+ p = tf.random.stateless_uniform((), seed=seeds[0])
1290
+ else:
1291
+ p = tf.random.uniform(())
1292
+ if p > 0.10:
1293
+ n = tf.strings.length(messages[-1])
1294
+ if seed is not None:
1295
+ n += tf.cast(tf.random.stateless_normal(n.shape, seed=seeds[1])*std, tf.int32)
1296
+ else:
1297
+ n += tf.cast(tf.random.normal(n.shape)*std, tf.int32)
1298
+ n = tf.strings.as_string(n//15)
1299
+ prefix = tf.strings.join([style, " ", n, ":"])
1300
+ else:
1301
+ prefix = tf.strings.join([style, ":" if v2 else " :"])
1302
+ else:
1303
+ raise NotImplementedError(self.system_prompt)
1304
+
1305
+ return prefix
1306
+
1307
+ def preprend_system_prompt(self, style, example, for_inference, messages, seed=None):
1308
+ prefix = self.get_system_prompt(style, example, for_inference, messages, seed=seed)
1309
+ separator = tf.where(tf.logical_and(
1310
+ tf.strings.length(prefix) > 0, tf.strings.length(messages[0]) > 0), " ", "")
1311
+ with_system_prompt = tf.strings.join([prefix, separator, messages[0]])
1312
+ if isinstance(messages, list):
1313
+ messages = [with_system_prompt] + messages[1:]
1314
+ else:
1315
+ messages = tf.concat([tf.expand_dims(with_system_prompt, 0), messages[1:]], axis=0)
1316
+ return messages
1317
+
1318
+ def get_messages(self, ex, style, is_training, for_inference, user_prompt_seed, system_prompt_seed):
1319
+ if isinstance(ex, list):
1320
+ messages = ex
1321
+ elif isinstance(ex, str):
1322
+ messages = [ex]
1323
+ elif "messages" in ex:
1324
+ messages = ex["messages"]
1325
+ else:
1326
+ # Apply a prompt template
1327
+ messages = self.get_user_prompt(style, ex, is_training, for_inference=for_inference, seed=user_prompt_seed)
1328
+
1329
+ # Maybe add a system prompt. The system prompt gets concatenated with the first user input
1330
+ if self.system_prompt and self.system_prompt != "none":
1331
+ if isinstance(ex, dict):
1332
+ style = ex.get("style", style)
1333
+
1334
+ if isinstance(messages, tf.RaggedTensor):
1335
+ n = tf.shape(messages)[0]
1336
+ message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
1337
+ seeds = tf.random.split(system_prompt_seed, n)
1338
+ for i in range(n):
1339
+ message_arr = message_arr.write(i, self.preprend_system_prompt(style, None, for_inference, messages[i], seed=seeds[i]))
1340
+ messages = tf.RaggedTensor.from_row_splits(
1341
+ values=message_arr.concat(), row_splits=messages.row_splits)
1342
+ else:
1343
+ messages = self.preprend_system_prompt(style, ex, for_inference, messages, seed=system_prompt_seed)
1344
+
1345
+ return messages
1346
+
1347
+ def get_preprocessor(self, is_training, for_inference, style=None, include_metadata=None):
1348
+ """Build a preprocessing function that can be applied ot a tf.data.Dataset"""
1349
+ vocab = self.tokenizer
1350
+ include_response = not for_inference
1351
+ if include_metadata is None:
1352
+ include_metadata = for_inference
1353
+
1354
+ @seqio.map_over_dataset(num_seeds=2)
1355
+ def to_inputs_and_targets(ex, seeds):
1356
+ if "unconditioned" in ex:
1357
+ raise NotImplementedError()
1358
+ if "image" not in ex:
1359
+ image = None
1360
+ elif ex['image'].dtype == tf.string:
1361
+ image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1362
+ else:
1363
+ image = ex['image']
1364
+ raw_image = image
1365
+ if image is not None and len(tf.shape(image)) == 3:
1366
+ image = tf.expand_dims(image, axis=0)
1367
+
1368
+ unconditioned = self.unconditioned
1369
+ if unconditioned and isinstance(unconditioned, float):
1370
+ assert image is not None
1371
+ if is_training and tf.random.uniform((), 0, 1, dtype=tf.float32) < unconditioned:
1372
+ image = image[:0]
1373
+ elif unconditioned:
1374
+ image = None
1375
+
1376
+ messages = self.get_messages(ex, style, is_training, for_inference, seeds[0], seeds[1])
1377
+ targets, loss_masks, subsegments = self.get_tokens_input(
1378
+ messages, for_inference, ex.get("text_weights"))
1379
+ # if "scifi" in style and style.endswith("_explanation"):
1380
+ # logging.warning(f"No loss on EOS for {style}")
1381
+ # loss_masks = tf.where(targets == self.tokenizer.eos_token_id, tf.zeros_like(loss_masks), loss_masks)
1382
+ out = self.build_multimodel_features(targets, loss_masks, subsegments, image, is_training)
1383
+
1384
+ if include_metadata:
1385
+ # FIXME remove these special cases
1386
+ if "text" in ex:
1387
+ if len(ex["text"].shape) > 0:
1388
+ # FIXME can this be variable lengths after all?
1389
+ out["metadata/captions"] = tf.strings.reduce_join(
1390
+ tf.strings.regex_replace(ex['text'], "\\s+", " "),
1391
+ separator="\n"
1392
+ )
1393
+ else:
1394
+ out["metadata/captions"] = ex["text"]
1395
+
1396
+ if "image_url" in ex:
1397
+ out["metadata/image_url"] = ex["image_url"]
1398
+ elif "url" in ex:
1399
+ out["metadata/image_url"] = ex["url"]
1400
+ if "image_id" in ex:
1401
+ out["metadata/image_id"] = ex["image_id"]
1402
+ for k, v in ex.items():
1403
+ if k.startswith("metadata"):
1404
+ out[k] = v
1405
+ if raw_image is not None and "metadata/image_size" not in out:
1406
+ img_h = tf.shape(raw_image)[0]
1407
+ img_w = tf.shape(raw_image)[1]
1408
+ out["metadata/image_size"] = [img_w, img_h]
1409
+ if "metadata/image_url" not in out and raw_image is not None:
1410
+ if len(ex["image"].shape) < 4:
1411
+ # For visualizations FIXME can we make this variable length
1412
+ out["metadata/image"] = tf.io.encode_jpeg(
1413
+ tf.image.convert_image_dtype(raw_image, tf.uint8))
1414
+ return out
1415
+ return to_inputs_and_targets
1416
+
1417
+ def get_post_mixing_preprocessor(self, pack=False):
1418
+ """Build a feature conversion function that can be applied ot a tf.data.Dataset
1419
+
1420
+ This function applies a second stage of pre-processing, but unlike `self.get_preprocessor`
1421
+ this stage can be applied after mixing tf.data.Datasets into a mixture
1422
+ """
1423
+ return MultiModalLMFeatureConverter(
1424
+ loss_token_weighting=self.loss_token_weighting,
1425
+ bos_id=self.tokenizer.bos_token_id,
1426
+ fix_image_input_idx=self.fix_image_input_idx,
1427
+ pack=pack,
1428
+ special_tokens=list(self.special_token_ids.values()),
1429
+ )
1430
+
1431
+
1432
+ class MultiModalLMFeatureConverter:
1433
+
1434
+ def __init__(
1435
+ self, pack: bool = False, loss_token_weighting: str=None, bos_id: int = 1,
1436
+ special_tokens=None, fix_image_input_idx=2
1437
+ ):
1438
+ self.pack = pack
1439
+ self.bos_id = bos_id
1440
+ self.fix_image_input_idx = fix_image_input_idx
1441
+ self.special_tokens = tf.constant(special_tokens) if special_tokens else None
1442
+ self.loss_token_weighting = loss_token_weighting
1443
+
1444
+ def _convert_example(
1445
+ self, features: Mapping[str, tf.Tensor]
1446
+ ) -> Mapping[str, tf.Tensor]:
1447
+ """Convert an LM example into an example with model features."""
1448
+ # targets_segment_id is present only for a packed dataset.
1449
+ decoder_input_tokens = make_autoregressive_inputs(
1450
+ features["target_tokens"],
1451
+ sequence_id=features.get("targets_segment_ids", None),
1452
+ bos_id=self.bos_id,
1453
+ )
1454
+
1455
+ tf.assert_equal(
1456
+ True,
1457
+ tf.reduce_all(decoder_input_tokens[-1] != self.special_tokens),
1458
+ message="An input ends with an image special token",
1459
+ )
1460
+
1461
+ image_input_idx = features["image_input_idx"]
1462
+ if self.fix_image_input_idx == 2:
1463
+ # plus one sine we have added BOS to the inputs
1464
+ image_input_idx = tf.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
1465
+ else:
1466
+ # Some old models trained like this, sometimes image_input_idx will go from -1 -> 0 didn't
1467
+ # effect performance but keep this code path for backwards compatiblity with those checkpoints
1468
+ image_input_idx = image_input_idx + 1
1469
+
1470
+ d = {
1471
+ "target_tokens": features["target_tokens"],
1472
+ "input_tokens": decoder_input_tokens,
1473
+ "loss_masks": features["loss_masks"],
1474
+ "images": features["images"],
1475
+ "image_input_idx": image_input_idx
1476
+ }
1477
+ if "image_masks" in features:
1478
+ d["image_masks"] = features["image_masks"]
1479
+
1480
+ has_custom_text_weight = features.get("has_custom_loss_weight", False)
1481
+
1482
+ if "subsegment_ids" in features:
1483
+ subsegment_ids = make_autoregressive_inputs(
1484
+ features["subsegment_ids"],
1485
+ sequence_id=features.get("targets_segment_ids", None),
1486
+ bos_id=features["subsegment_ids"][0],
1487
+ )
1488
+
1489
+ # Subsegment have a position based on the sum of previous positions they can attend to
1490
+ position_ids = tf.zeros_like(subsegment_ids)
1491
+ unique_segments = tf.unique(subsegment_ids)[0]
1492
+ for i in unique_segments:
1493
+ segment_position_ids = tf.cumsum(tf.cast(subsegment_ids >= i, tf.int32)) - 1
1494
+ position_ids = tf.where(subsegment_ids == i, segment_position_ids, position_ids)
1495
+
1496
+ # Apply loss weighting, this is done here so it occurs after truncation
1497
+ if has_custom_text_weight:
1498
+ pass
1499
+ elif self.loss_token_weighting in ["subsegments", "root_subsegments"]:
1500
+ n_loss_segments = tf.shape(tf.unique(tf.boolean_mask(subsegment_ids, d["loss_masks"] > 0))[0])[0]
1501
+ n_loss_segments = tf.maximum(tf.cast(n_loss_segments, tf.float32), 1)
1502
+ weight = 1/n_loss_segments if self.loss_token_weighting == "subsegments" else tf.math.rsqrt(n_loss_segments)
1503
+ d["loss_masks"] = tf.where(d["loss_masks"] > 0, d["loss_masks"]*weight, d["loss_masks"])
1504
+ elif self.loss_token_weighting is not None:
1505
+ raise NotImplementedError(self.loss_token_weighting)
1506
+
1507
+ d["subsegment_ids"] = subsegment_ids
1508
+ d["position_ids"] = position_ids
1509
+ else:
1510
+ if self.loss_token_weighting not in [None, "subsegments", "root_subsegments"] and not has_custom_text_weight:
1511
+ raise NotImplementedError(self.loss_token_weighting)
1512
+ if self.pack:
1513
+ d["decoder_segment_ids"] = features["targets_segment_ids"]
1514
+ d["decoder_positions"] = features["targets_positions"]
1515
+
1516
+ for k in features:
1517
+ if k.startswith("metadata/"):
1518
+ d[k] = features[k]
1519
+ return d
1520
+
1521
+ def _pack_or_pad(self, ds, task_feature_lengths):
1522
+ if self.pack:
1523
+ raise NotImplementedError()
1524
+ else:
1525
+ return trim_and_pad_dataset(ds, task_feature_lengths)
1526
+
1527
+ def __call__(self, ds: tf.data.Dataset, task_feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
1528
+ """Convert the dataset to be fed to a language model."""
1529
+ task_feature_lengths = dict(task_feature_lengths)
1530
+
1531
+ if "images" in ds.element_spec and "images" in task_feature_lengths:
1532
+ # Images should never be truncated
1533
+ ds = assert_not_truncated(ds, ["images", "image_input_idx"], task_feature_lengths["images"])
1534
+
1535
+ if any(x.startswith("metadata/") for x in ds.element_spec):
1536
+ # Metadata indicates the dataset is being used for inference, inference datasets
1537
+ # should not be truncated
1538
+ ds = assert_not_truncated(ds, ["target_tokens"], task_feature_lengths["target_tokens"])
1539
+
1540
+ if "image_masks" in ds.element_spec and "images" in task_feature_lengths:
1541
+ task_feature_lengths["image_masks"] = task_feature_lengths["images"]
1542
+ if "subsegment_ids" in ds.element_spec and "target_tokens" in task_feature_lengths:
1543
+ task_feature_lengths["subsegment_ids"] = task_feature_lengths["target_tokens"]
1544
+ if "loss_masks" not in task_feature_lengths and "target_tokens" in task_feature_lengths:
1545
+ task_feature_lengths["loss_masks"] = task_feature_lengths["target_tokens"]
1546
+ ds = self._pack_or_pad(ds, task_feature_lengths)
1547
+
1548
+ return ds.map(
1549
+ self._convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
preprocesssors.py ADDED
@@ -0,0 +1,2472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import json
3
+ import math
4
+ from functools import reduce
5
+ from typing import Mapping, Optional, Sequence
6
+
7
+ import numpy as np
8
+ import tensorflow as tf
9
+ import seqio
10
+ import gin
11
+
12
+ from .data_utils import flatten_parts, stateless_permutation, stateless_shuffle
13
+ from .. import config
14
+
15
+
16
+ def get_from_dict(data, keys):
17
+ """Iterate nested dictionary"""
18
+ return reduce(dict.get, keys, data)
19
+
20
+ def get_blank_image():
21
+ image = tf.zeros([224, 224, 3], dtype=tf.uint8)
22
+ image = tf.expand_dims(image, 0)[:1]
23
+ return image
24
+
25
+
26
+ @seqio.utils.map_over_dataset
27
+ def rekey(x, key_map=None):
28
+ """Replace the feature keys according to the mapping in `key_map`.
29
+ For example, if the dataset returns examples of the format:
30
+ {'foo': 'something', 'bar': 'something else'}
31
+ and key_map = {'boo': 'foo', 'spar': 'bar'} then this function will return
32
+ examples with the format
33
+ {'boo': 'something', 'spar': 'something else'}
34
+ If a mapping is to an empty key or None, set the new key to an empty string.
35
+ Args:
36
+ x: an example to process.
37
+ key_map: dictionary mapping new keys to original keys
38
+ Returns:
39
+ A preprocessed example with the format listed above.
40
+ """
41
+ if key_map:
42
+ out = {}
43
+ for new_key, old_key in key_map.items():
44
+ if isinstance(old_key, list):
45
+ out[new_key] = get_from_dict(x, old_key)
46
+ else:
47
+ out[new_key] = x[old_key]
48
+ return out
49
+ return x
50
+
51
+
52
+ def rename(**kwargs):
53
+ @seqio.map_over_dataset
54
+ def _fn(x):
55
+ updates = {}
56
+ for new_key, old_key in kwargs.items():
57
+ if isinstance(old_key, list):
58
+ val = x[old_key[0]]
59
+ for k in old_key[1:-1]:
60
+ val = val[k]
61
+ updates[new_key] = val.pop(old_key[-1])
62
+ else:
63
+ updates[new_key] = x.pop(old_key)
64
+ x.update(updates)
65
+ return x
66
+ return _fn
67
+
68
+
69
+ def extract_transcripts(ds):
70
+ ds = flatten_parts(ds, ["transcripts"])
71
+ def _map(ex):
72
+ return dict(
73
+ image=ex["image"],
74
+ text=ex["transcripts"],
75
+ url=ex["url"]
76
+ )
77
+ return ds.map(_map)
78
+
79
+
80
+ @seqio.map_over_dataset
81
+ def extract_caption_and_all_transcripts(ex):
82
+ transcripts = tf.random.shuffle(ex["transcripts"])[:3]
83
+ weight = 1.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
84
+ return dict(
85
+ image=ex["image"],
86
+ text=tf.concat([tf.expand_dims(ex["caption"], 0), transcripts], 0),
87
+ url=ex["url"],
88
+ text_weights=tf.pad(
89
+ tf.ones((1,), dtype=tf.float32), [[0, tf.shape(transcripts)[0]]],
90
+ constant_values=weight),
91
+ )
92
+
93
+
94
+ @seqio.map_over_dataset
95
+ def extract_all_transcripts(ex):
96
+ transcripts = tf.random.shuffle(ex["transcripts"])[:3]
97
+ weight = 3.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
98
+ return dict(
99
+ image=ex["image"],
100
+ text=transcripts,
101
+ url=ex["url"],
102
+ text_weights=tf.fill((tf.shape(transcripts)[0],), weight),
103
+ )
104
+
105
+
106
+ @seqio.map_over_dataset
107
+ def extract_transcript(ex):
108
+ transcripts = tf.random.shuffle(ex["transcripts"])
109
+ return dict(
110
+ image=ex["image"],
111
+ text=transcripts[0],
112
+ url=ex["url"],
113
+ )
114
+
115
+
116
+ @seqio.map_over_dataset
117
+ def extract_caption(ex):
118
+ caption = ex["caption"]
119
+ if len(caption.shape) > 0:
120
+ ex["text"] = caption[0]
121
+ else:
122
+ ex["text"] = caption
123
+ return ex
124
+
125
+
126
+ @seqio.map_over_dataset
127
+ def extract_joint_captions(ex):
128
+ caption = ex["caption"]
129
+ if len(caption.shape) > 0:
130
+ caption = caption[0]
131
+ _ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
132
+ _ix = _ix % tf.shape(ex["transcripts"])[0]
133
+ return dict(
134
+ image=ex["image"],
135
+ text=tf.stack([caption, ex["mistral_caption"], ex["transcripts"][_ix]], 0),
136
+ url=ex["url"]
137
+ )
138
+
139
+
140
+ @seqio.map_over_dataset(num_seeds=1)
141
+ def extract_caption_and_transcript(ex, seed):
142
+ caption = ex["caption"]
143
+ if len(caption.shape) > 0:
144
+ caption = caption[0]
145
+ _ix = tf.random.stateless_uniform((), seed, 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
146
+ return dict(
147
+ image=ex["image"],
148
+ text=tf.stack([caption, ex["transcripts"][_ix]], 0),
149
+ url=ex["url"]
150
+ )
151
+
152
+
153
+ @seqio.map_over_dataset
154
+ def caption_transcript_augmented(ex, sequence_length):
155
+ caption = ex["caption"]
156
+ if len(caption.shape) > 0:
157
+ caption = caption[0]
158
+ image = ex["image"]
159
+ properties = []
160
+
161
+ do_augmentation = sequence_length["is_training"]
162
+ # do_augmentation = False
163
+
164
+ # Keep this off, it screws up OCR
165
+ # do_hflip = (tf.random.uniform(()) > 0.2 and do_augmentation)
166
+ do_hflip = False
167
+ if do_hflip:
168
+ image = image[:, ::-1]
169
+
170
+ # Mild color jitter
171
+ do_color = (tf.random.uniform(()) > 0.5 and do_augmentation)
172
+ if do_color:
173
+ image = tf.image.random_hue(image, max_delta=0.05)
174
+ image = tf.image.random_brightness(image, max_delta=0.2)
175
+ image = tf.image.random_saturation(image, 0.7, 1.3)
176
+ image = tf.image.random_contrast(image, 0.7, 1.3)
177
+
178
+ # Mild affine transformation
179
+ do_affine = (tf.random.uniform(()) > 0.5 and do_augmentation)
180
+ if do_affine and do_augmentation:
181
+ shift_x = tf.random.uniform((), -10, 10) * 0
182
+ shift_y = tf.random.uniform((), -10, 10) * 0
183
+ shear_x = tf.random.uniform((), -2, 2)
184
+ shear_y = tf.random.uniform((), -2, 2)
185
+ rotation = tf.random.uniform((), -6, 6)
186
+ max_scale = 1.1
187
+ scale = tf.random.uniform((), 0.8, max_scale)
188
+ center = tf.cast(tf.shape(image), tf.float32)/2
189
+
190
+ image = tf.keras.ops.image.affine_transform(
191
+ image,
192
+ tf.stack(get_affine_matrix(
193
+ [center[0], center[1]],
194
+ rotation,
195
+ [shift_x, shift_y],
196
+ 1/scale,
197
+ [shear_x, shear_y]
198
+ ) + [0., 0.]),
199
+ interpolation='bilinear',
200
+ fill_mode='constant',
201
+ fill_value=1.,
202
+ data_format='channels_last'
203
+ )
204
+
205
+ properties = tf.stack([
206
+ ("[hflip]" if do_hflip else ""),
207
+ ("[color]" if do_color else ""),
208
+ ("[affine]" if do_affine else "")
209
+ ])
210
+ properties = tf.boolean_mask(properties, tf.strings.length(properties) > 0)
211
+ prompt = tf.strings.reduce_join(properties, separator=" ")
212
+ ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
213
+ out = dict(
214
+ image=image,
215
+ text=tf.stack([caption, ex["transcripts"][ix]], 0),
216
+ url=ex["url"],
217
+ prompt=prompt,
218
+ )
219
+ # out["metadata/unaugmented_image"] = image
220
+ return out
221
+
222
+
223
+ def extract_caption_and_transcript_hflip(ds):
224
+
225
+ # Just in case they are ordered somehow in Matt's data
226
+ @seqio.map_over_dataset
227
+ def _shuffle_transcripts(_ex):
228
+ _ex["transcripts"] = tf.random.shuffle(_ex["transcripts"])
229
+ _ex["hflip"] = tf.random.uniform((), 0, 3, dtype=tf.int32)
230
+ return _ex
231
+
232
+ ds = _shuffle_transcripts(ds)
233
+
234
+ # Build a 3x long dataset with each individual transcript so we iterate through
235
+ # each transcript
236
+ @seqio.map_over_dataset
237
+ def _with_transcript(ex, _ix):
238
+ caption = ex["caption"]
239
+ if len(caption.shape) > 0:
240
+ caption = caption[0]
241
+ hflip = ex["hflip"] == _ix
242
+ if hflip:
243
+ ex["image"] = ex["image"][:, ::-1]
244
+ style = ["long_caption_flipped", "transcript_flipped"]
245
+ else:
246
+ style = ["long_caption", "transcript"]
247
+ return dict(
248
+ image=ex["image"],
249
+ text=tf.stack([caption, ex["transcripts"][_ix]], 0),
250
+ url=ex["url"],
251
+ style=style
252
+ )
253
+
254
+ joint_ds = _with_transcript(ds, 0)
255
+ for i in range(1, 3):
256
+ joint_ds = joint_ds.concatenate(_with_transcript(ds, i))
257
+
258
+ return joint_ds
259
+
260
+
261
+ @seqio.map_over_dataset
262
+ def extract_llava(ex, sequence_length, output_features):
263
+ tf.assert_equal(tf.shape(ex['conversations']['value'])[0], 2)
264
+ prompt = ex['conversations']['value'][0]
265
+ text = ex['conversations']['value'][1]
266
+ ex.pop('conversations')
267
+ ex["text"] = text
268
+ ex["prompt"] = prompt
269
+ return ex
270
+
271
+
272
+ def extract_localized_narrative(ds):
273
+ ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)
274
+ def _map(ex):
275
+ return dict(
276
+ image=ex["image"],
277
+ text=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
278
+ )
279
+ return ds.map(_map)
280
+
281
+
282
+ def float_to_text(val):
283
+ return tf.strings.as_string(tf.cast(val * 100, tf.int32))
284
+
285
+
286
+ @seqio.map_over_dataset
287
+ def extract_vqa(ex):
288
+ questions = ex["vqa"]["questions"]
289
+ answers = ex["vqa"]["answers"]
290
+ answers = tf.strings.reduce_join(answers, 1, separator="; ")
291
+ qas = tf.strings.reduce_join(tf.stack([questions, answers], 1), separator=" ")
292
+ return dict(
293
+ image=ex["image"],
294
+ text=tf.strings.reduce_join(qas, separator="\n")
295
+ )
296
+
297
+
298
+ @seqio.map_over_dataset
299
+ def coco_image_id_from_path(ex):
300
+ image_id = tf.strings.substr(ex["image/filename"], 0, tf.strings.length(ex["image/filename"])-4)
301
+ ex["image_id"] = tf.strings.to_number(image_id)
302
+ return ex
303
+
304
+
305
+ @seqio.map_over_dataset
306
+ def add_coco_url(ex):
307
+ """Turns a COCO path into a URL, which can then be used in visualizations"""
308
+ path = ex["image/filename"]
309
+ if not tf.strings.regex_full_match(path, ".*/.*"):
310
+ prefix = tf.strings.regex_replace(path, "COCO_", "")
311
+ prefix = tf.strings.regex_replace(prefix, "_[0-9]+.jpg", "")
312
+ path = tf.strings.join([prefix, path], separator="/")
313
+
314
+ # images are hosted by the COCO website here
315
+ url = tf.strings.join(["https://s3.us-east-1.amazonaws.com/images.cocodataset.org/", path])
316
+ ex["metadata/image_url"] = url
317
+ return ex
318
+
319
+
320
+ def flatten_vqa(ds):
321
+ parts = ["questions", "answers"]
322
+ for k in ["id", "question_id"]:
323
+ if k in ds.element_spec:
324
+ parts.append(k)
325
+ return flatten_parts(ds, parts)
326
+
327
+
328
+ def format_gqa(ds, is_balanced=True, flatten=True):
329
+ if is_balanced:
330
+ ds = ds.filter(lambda x: tf.reduce_any(x["questions"]["is_balanced"]))
331
+ def _filter_qs(ex):
332
+ qs = ex["questions"]
333
+ mask = qs["is_balanced"]
334
+ qs = {k: tf.boolean_mask(v, mask) for k, v in qs.items()}
335
+ ex["questions"] = qs
336
+ return ex
337
+ ds = ds.map(_filter_qs)
338
+
339
+ if flatten:
340
+ ds = flatten_parts(ds, ["questions"])
341
+
342
+ def _rename(ex):
343
+ out = ex["questions"]
344
+ out["image"] = ex["image"]
345
+ out["image_id"] = ex["image_id"]
346
+ return out
347
+ return ds.map(_rename)
348
+
349
+
350
+ @seqio.map_over_dataset
351
+ def fix_doqa_url(x):
352
+ x["image_url"] = tf.strings.regex_replace(x["image_url"], "gs://", "")
353
+ return x
354
+
355
+
356
+ def _add_metadata(ex):
357
+ out = {}
358
+ if "id" in ex:
359
+ out["metadata/example_id"] = ex["id"]
360
+ elif "example_id" in ex:
361
+ out["metadata/example_id"] = ex["example_id"]
362
+ elif "question_id" in ex:
363
+ out["metadata/example_id"] = ex["question_id"]
364
+ if "image_url" in ex:
365
+ out["metadata/image_url"] = ex["image_url"]
366
+ for k, v in ex.items():
367
+ if k.startswith("metadata/"):
368
+ out[k] = v
369
+ return out
370
+
371
+
372
+ def image_only(ds):
373
+ return ds.filter(lambda x: x["has_image"])
374
+
375
+
376
+ def filter_difficult_direct_answer(ds):
377
+ return ds.filter(lambda x: not x["difficult_direct_answer"])
378
+
379
+
380
+ @seqio.map_over_dataset()
381
+ def format_ai2d(ex, variable_style=True):
382
+ abc = tf.constant(list("abcdefg".upper()))
383
+ out = dict(image=ex["image"])
384
+ out.update(_add_metadata(ex))
385
+
386
+ options = ex["choices"]
387
+ # >= 3 in case of none of the above like answers
388
+ n_options = tf.shape(ex["option_is_abc"])[0]
389
+ if ex["abc_label"] and tf.reduce_sum(tf.cast(ex["option_is_abc"], tf.int32)) >= (n_options - 1):
390
+ # The image labels are always upper, so use upper in the answer ptions
391
+ options = tf.where(
392
+ ex["option_is_abc"],
393
+ tf.strings.upper(options),
394
+ options
395
+ )
396
+ short_options = options
397
+ style = "ai2_diagram_no_letter"
398
+ else:
399
+ short_options = abc[:tf.shape(options)[0]]
400
+ options = tf.stack([short_options, options,], 1)
401
+ options = tf.strings.reduce_join(options, axis=-1, separator=": ")
402
+ style = "ai2_diagram"
403
+
404
+ options = tf.strings.reduce_join(options, separator="\n")
405
+ out["question"] = ex["question"]
406
+ out["options"] = options
407
+ if variable_style:
408
+ out["style"] = style
409
+ if ex["answer_idx"] < 0:
410
+ out["text"] = "?"
411
+ else:
412
+ out["text"] = short_options[ex["answer_idx"]]
413
+ out["metadata/answer_idx"] = ex["answer_idx"]
414
+ tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
415
+ out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
416
+ out["metadata/has_transparent_box"] = ex.get("has_transparent_box", tf.constant(False))
417
+ out["metadata/abc_label"] = ex["abc_label"]
418
+ return out
419
+
420
+
421
+ @gin.configurable()
422
+ @seqio.map_over_dataset()
423
+ def format_multiple_choice_qa(ex, option_format="abc"):
424
+ assert option_format == "abc"
425
+ abc = tf.constant(list("abcdefg".upper()))
426
+ out = dict(image=ex["image"])
427
+ out.update(_add_metadata(ex))
428
+ options = ex["choices"]
429
+ short_options = abc[:tf.shape(options)[0]]
430
+ options = tf.stack([short_options, options,], 1)
431
+ options = tf.strings.reduce_join(options, axis=-1, separator=": ")
432
+ options = tf.strings.reduce_join(options, separator="\n")
433
+ out["question"] = ex["question"]
434
+ out["options"] = options
435
+ if ex["answer_idx"] < 0:
436
+ out["text"] = "?"
437
+ else:
438
+ out["text"] = short_options[ex["answer_idx"]]
439
+ out["metadata/answer_idx"] = ex["answer_idx"]
440
+ tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
441
+ out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
442
+ # out["metadata/option_names"] = tf.RaggedTensor.from_row_lengths(short_options, tf.shape(short_options))
443
+ # out["metadata/option_names"] = short_options
444
+ return out
445
+
446
+
447
+ @seqio.map_over_dataset()
448
+ def output_options(ex):
449
+ ex["metadata/options"] = ex["options"]
450
+ return ex
451
+
452
+
453
+ @seqio.map_over_dataset()
454
+ def extract_tally_qa(ex):
455
+ questions = ex.pop("questions")
456
+ ex["questions"] = questions["question"]
457
+ ex["answers"] = tf.strings.as_string(questions["answer"])
458
+ ex["question_id"] = questions["question_id"]
459
+ return ex
460
+
461
+
462
+ @seqio.map_over_dataset()
463
+ def count_bench_preprocessor(ex):
464
+ return {
465
+ "image": ex["image"],
466
+ "text": tf.strings.as_string(ex["number"]),
467
+ "object": ex["noun"],
468
+ "question": tf.strings.join([
469
+ "How many ", ex["noun"], " are there?"
470
+ ]),
471
+ "metadata/count": ex["number"],
472
+ }
473
+
474
+
475
+ def filter_human(ds):
476
+ return ds.filter(lambda x: x["is_human"])
477
+
478
+
479
+ def filter_aug(ds):
480
+ return ds.filter(lambda x: not x["is_human"])
481
+
482
+
483
+ @seqio.map_over_dataset()
484
+ def reweight_chartqa(ex, human, aug):
485
+ is_human = ex["metadata/is_human"]
486
+ ex["text_weights"] = human if is_human else aug
487
+ return ex
488
+
489
+
490
+ @seqio.map_over_dataset()
491
+ def chartqa_prompting(ex):
492
+ question = tf.strings.join([ex["question"], " Answer:"])
493
+ return dict(
494
+ image=ex["image"],
495
+ question=question,
496
+ answer=ex["answer"]
497
+ )
498
+
499
+
500
+ @seqio.map_over_dataset()
501
+ def chartqa_explanation(ex):
502
+ question = tf.strings.join([ex["question"], " Explanation:"])
503
+ out = {
504
+ "image": ex["image"],
505
+ "question": question,
506
+ "answer": ex["answer"],
507
+ }
508
+ out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
509
+ return out
510
+
511
+
512
+ @seqio.map_over_dataset(num_seeds=1)
513
+ def _preprocess_scifi(ex, seed):
514
+ if "qa_pairs" in ex:
515
+ q = ex["qa_pairs"]
516
+ else:
517
+ q = ex["qa"]
518
+ ix = stateless_permutation(tf.shape(q["question"])[0], seed)
519
+ return dict(
520
+ image=ex["image"],
521
+ question=tf.gather(q["question"], ix),
522
+ explanation=tf.gather(q["explanation"], ix),
523
+ answer=tf.gather(q["answer"], ix),
524
+ )
525
+
526
+ @seqio.map_over_dataset
527
+ def scifi_explanation_only(ex):
528
+ return dict(
529
+ image=ex["image"],
530
+ question=ex["question"],
531
+ answer=ex["explanation"],
532
+ )
533
+
534
+
535
+ def filter_named_entity(ds):
536
+ @seqio.map_over_dataset
537
+ def _load_image(ex):
538
+ ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
539
+ return ex
540
+
541
+ ds = _load_image(ds)
542
+ return ds.filter(lambda x: tf.reduce_min(tf.shape(x["image"])[:2]) >= 32)
543
+
544
+
545
+ @seqio.map_over_dataset()
546
+ def extract_named_entity(ex):
547
+ qs = ex["questions"]
548
+ return {
549
+ "image": ex["image"],
550
+ "metadata/image_url": ex["url"],
551
+ "metadata/entity": ex["entity"],
552
+ "questions": qs["question"],
553
+ "answers": qs["answer"],
554
+ }
555
+
556
+ @gin.configurable()
557
+ def extract_individual_vqa(ds, test=False, answer_mode="best"):
558
+
559
+ @seqio.map_over_dataset(num_seeds=1)
560
+ def _extract(ex, seed):
561
+ if "questions" in ex:
562
+ question = ex["questions"]
563
+ else:
564
+ question = ex["question"]
565
+ out = dict(
566
+ image=ex["image"],
567
+ question=question,
568
+ )
569
+ out.update(_add_metadata(ex))
570
+ out["metadata/question"] = question
571
+ if ex.get("answers") is not None:
572
+ out["metadata/references"] = tf.strings.reduce_join(ex["answers"], separator="\n")
573
+ elif ex.get("answer") is not None:
574
+ out["metadata/references"] = ex["answer"]
575
+
576
+ if not test:
577
+ if "answer" in ex:
578
+ answer = ex["answer"]
579
+ else:
580
+ answer = ex["answers"]
581
+ if answer.dtype in [tf.int32, tf.int64]:
582
+ answer = tf.strings.as_string(answer)
583
+ if len(answer.shape) == 1 and tf.shape(answer)[0] == 0:
584
+ answer = tf.expand_dims("", 0)
585
+ if len(answer.shape) == len(question.shape):
586
+ pass
587
+ # Handle questions with multiple answers
588
+ elif answer_mode == "random":
589
+ assert len(answer.shape) == 1
590
+ answer = answer[tf.random.stateless_uniform((), seed, 0, tf.shape(answer)[0], dtype=tf.int32)]
591
+ elif answer_mode == "best":
592
+ def _get_best(_answer):
593
+ vals, _, counts = tf.unique_with_counts(_answer)
594
+ count_thresh = tf.reduce_max(counts)
595
+ vals = tf.boolean_mask(vals, counts >= count_thresh)
596
+ return vals[tf.random.stateless_uniform((), seed, 0, tf.shape(vals)[0], dtype=tf.int32)]
597
+ if len(answer.shape) == 1:
598
+ answer = _get_best(answer)
599
+ elif isinstance(answer, tf.RaggedTensor):
600
+ n = tf.shape(answer)[0]
601
+ answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
602
+ for i in range(n):
603
+ answer_arr = answer_arr.write(i, _get_best(answer[i]))
604
+ answer = answer_arr.stack()
605
+ else:
606
+ answer = tf.map_fn(_get_best, answer)
607
+ elif answer_mode == "all_segments":
608
+ out["text"] = answer
609
+ elif answer_mode == "all_segments_weighted":
610
+ out["text"] = answer
611
+ out["text_weights"] = 1.0 / tf.cast(tf.shape(answer)[-1], tf.float32)
612
+ elif answer_mode == "all":
613
+ if len(answer.shape) == 1:
614
+ answer = stateless_shuffle(answer, seed)
615
+ answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
616
+ elif isinstance(answer, tf.RaggedTensor):
617
+ n = tf.shape(answer)[0]
618
+ answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
619
+ for i in range(n):
620
+ answer_arr = answer_arr.write(i, tf.strings.reduce_join(tf.random.shuffle(answer[i]), separator="\n", axis=-1))
621
+ answer = answer_arr.stack()
622
+ else:
623
+ answer = tf.map_fn(tf.random.shuffle, answer)
624
+ answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
625
+ else:
626
+ raise NotImplementedError()
627
+ out["text"] = answer
628
+ return out
629
+ return _extract(ds)
630
+
631
+
632
+ @seqio.map_over_dataset()
633
+ def extract_khan_academy(ex):
634
+ return dict(
635
+ image=ex["image"],
636
+ image_url=ex["image_url"],
637
+ prompt="Answer this question",
638
+ text=ex["gptResponse"]
639
+ )
640
+
641
+ @seqio.map_over_dataset()
642
+ def extract_vaia_qa_latex_image(ex, add_short_answer=False, set_short_answer_first=False):
643
+ if ex["has_image"]:
644
+ image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
645
+ image = tf.expand_dims(image, 0)[:1]
646
+ else:
647
+ # image = get_blank_image() # blank image
648
+ image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
649
+ image = tf.expand_dims(image, 0)[:0]
650
+ img_h = tf.shape(image)[1]
651
+ img_w = tf.shape(image)[2]
652
+
653
+ if add_short_answer:
654
+ if set_short_answer_first:
655
+ answer = tf.strings.join(["Answer: ", ex["short_answer"], "\n\n", ex["answer"]])
656
+ else:
657
+ answer = tf.strings.join([ex["answer"], "\n\n", "Answer: ", ex["short_answer"]])
658
+ else:
659
+ answer = ex["answer"]
660
+ out = dict(
661
+ image=image, # 4-d tensor
662
+ text=answer,
663
+ prompt=tf.strings.join([ex["latex_question"], "\n"]),
664
+ )
665
+ out["metadata/images"] = image
666
+ out.update(_add_metadata(ex))
667
+ out["metadata/batch_id"] = ex["batch_id"]
668
+ out["metadata/image_size"] = [img_w, img_h]
669
+ return out
670
+
671
+ @seqio.map_over_dataset()
672
+ def extract_vqa_online(ex):
673
+ out = dict(
674
+ image=ex["image"],
675
+ prompt=tf.strings.join([ex["question"], "\n"]),
676
+ text=ex["answer"]
677
+ )
678
+ out.update(_add_metadata(ex))
679
+ out["metadata/row_id"] = ex["row_id"]
680
+ return out
681
+
682
+
683
+ @seqio.map_over_dataset()
684
+ def extract_scifi_joint(ex):
685
+ if "qa_pairs" in ex:
686
+ q = ex["qa_pairs"]
687
+ else:
688
+ q = ex["qa"]
689
+ prompts = tf.concat([["Describe this image in detail."], q["question"]], 0)
690
+ responses = tf.concat([ex["summary"][None], q["answer"]], 0)
691
+ return dict(
692
+ image=ex["image"],
693
+ prompt=prompts,
694
+ text=responses,
695
+ )
696
+
697
+
698
+ def remove_no_qa(ds):
699
+ def _filter(ex):
700
+ if "qa_pairs" in ex:
701
+ q = ex["qa_pairs"]
702
+ else:
703
+ q = ex["qa"]
704
+ return tf.shape(q["question"])[0] > 0
705
+ return ds.filter(_filter)
706
+
707
+
708
+ @seqio.map_over_dataset()
709
+ def extract_scifi_qa_exp(ex):
710
+ return dict(
711
+ image=ex["image"],
712
+ question=ex["question"], # Array of questions
713
+ answer=tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]]),
714
+ )
715
+
716
+
717
+ @seqio.map_over_dataset(num_seeds=1)
718
+ def extract_scifi_qa_demo(ex, seed):
719
+ # if tf.random.stateless_uniform((), 0, 1) > 0.5:
720
+ answer = tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]])
721
+ # else:
722
+ # answer = ex["explanation"]
723
+ return dict(
724
+ image=ex["image"],
725
+ question=ex["question"], # Array of questions
726
+ answer=answer,
727
+ )
728
+
729
+
730
+ @seqio.map_over_dataset()
731
+ def clock_bench_preprocessor(ex):
732
+ out = dict(
733
+ image=ex["image"],
734
+ prompt="What time is being shown?",
735
+ )
736
+ for k in ["hour", "minute", "second", "answerable"]:
737
+ out[f"metadata/{k}"] = ex[k]
738
+ return out
739
+
740
+
741
+ def deg2rad(x):
742
+ return x*math.pi/180.0
743
+
744
+
745
+ def get_affine_matrix(center, angle, translate, scale, shear):
746
+ # From https://github.com/pytorch/vision/blob/f96c42fca53230057b16941b078a0a9eee06e20f/torchvision/transforms/functional.py#L1006
747
+ rot = deg2rad(angle)
748
+ sx = deg2rad(shear[0])
749
+ sy = deg2rad(shear[1])
750
+
751
+ cx, cy = center
752
+ tx, ty = translate
753
+
754
+ # RSS without scaling
755
+ a = tf.cos(rot - sy) / tf.cos(sy)
756
+ b = -tf.cos(rot - sy) * tf.tan(sx) / tf.cos(sy) - tf.sin(rot)
757
+ c = tf.sin(rot - sy) / tf.cos(sy)
758
+ d = -tf.sin(rot - sy) * tf.tan(sx) / tf.cos(sy) + tf.cos(rot)
759
+
760
+ matrix = [a, b, 0.0, c, d, 0.0]
761
+ matrix = [x * scale for x in matrix]
762
+ # Apply inverse of center translation: RSS * C^-1
763
+ matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
764
+ matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
765
+ # Apply translation and center : T * C * RSS * C^-1
766
+ matrix[2] += cx + tx
767
+ matrix[5] += cy + ty
768
+ return matrix
769
+
770
+
771
+ def quantize_point(coor, max_dim, mode="percent-precision-1"):
772
+ max_dim = tf.cast(max_dim, tf.float32)
773
+ coor = tf.cast(coor, tf.float32)
774
+ x = (coor / max_dim)
775
+ if mode == "percent-precision-1":
776
+ return tf.strings.as_string(x*100, precision=1)
777
+ elif mode == "zero_to_one":
778
+ return tf.strings.as_string(x, precision=3)
779
+ elif mode == "1k":
780
+ return tf.strings.as_string(x*1000, precision=0)
781
+ else:
782
+ raise NotImplementedError(mode)
783
+
784
+
785
+ def construct_pointing_format(label_text, alt_text, x_str, y_str):
786
+ if alt_text is None:
787
+ alt_text = label_text
788
+ np = tf.shape(x_str)[0]
789
+ if np == 0:
790
+ output = ""
791
+ elif np == 1:
792
+ output = tf.strings.join([
793
+ '<point x="', x_str[0], '" y="', y_str[0], '" alt="',
794
+ alt_text, '">', label_text, '</point>'
795
+ ])
796
+ else:
797
+ ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
798
+ xs = tf.strings.join(["x", ids, '="', x_str, '"'])
799
+ ys = tf.strings.join(["y", ids, '="', y_str, '"'])
800
+ points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
801
+ output = tf.strings.join(
802
+ ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
803
+ return output
804
+
805
+
806
+ def order_points(x, y, seed, point_order):
807
+ if point_order == "natural":
808
+ return x, y
809
+
810
+ if point_order == "random":
811
+ ix = stateless_permutation(tf.shape(x)[0], seed)
812
+ elif point_order == "xy":
813
+ x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
814
+ ix = tf.argsort(x_float*100000 + y_float)
815
+ elif point_order == "yx":
816
+ x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
817
+ ix = tf.argsort(y_float*100000 + x_float)
818
+ else:
819
+ raise NotImplementedError(point_order)
820
+ return tf.gather(x, ix), tf.gather(y, ix)
821
+
822
+
823
+ @gin.configurable()
824
+ def points_to_text(x, y, w, h, seed, label=None, alt_text=None, point_mode="percent-precision-1",
825
+ point_order="xy", point_list_mode="tag"):
826
+ """Returns a string encoding of a list of points"""
827
+ x = quantize_point(x, w, point_mode)
828
+ y = quantize_point(y, h, point_mode)
829
+ # Order the quantized points to make the order matches what was generated, this can matter
830
+ # when points have the same quantized value e.g, (10.001, 20) (10.002, 10) should be
831
+ # represented (10, 10), (10, 20), but if we sort before quantization we get (10, 20), (10, 10)
832
+ x, y = order_points(x, y, seed, point_order)
833
+ if point_list_mode == "tag":
834
+ return construct_pointing_format(label, alt_text, x, y)
835
+ elif point_list_mode == "paren":
836
+ n = tf.shape(x)[0]
837
+ return tf.strings.reduce_join(tf.strings.join([
838
+ "(", x, ", ", y, ")"
839
+ ]), separator=", ")
840
+ # if n == 0:
841
+ # output = ""
842
+ # else:
843
+ # ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
844
+ # xs = tf.strings.join(["x", ids, '="', x_str, '"'])
845
+ # ys = tf.strings.join(["y", ids, '="', y_str, '"'])
846
+ # points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
847
+ # output = tf.strings.join(
848
+ # ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
849
+ # return output
850
+ else:
851
+ raise NotImplementedError(point_list_mode)
852
+
853
+
854
+ def points_to_answer(x, y, w, h, seed, label, is_counting, alt_text=None):
855
+ count = tf.shape(x)[0]
856
+ if is_counting:
857
+ if count == 0:
858
+ return "There are none."
859
+ else:
860
+ point_text = points_to_text(x, y, w, h, seed, label, alt_text)
861
+ return tf.strings.join([
862
+ "Counting the ", point_text,
863
+ " shows a total of ",
864
+ tf.strings.as_string(count),
865
+ "."
866
+ ])
867
+ else:
868
+ if count == 0:
869
+ return "There are none."
870
+ else:
871
+ return points_to_text(x, y, w, h, seed, label, alt_text)
872
+
873
+
874
+ @seqio.map_over_dataset(num_seeds=2)
875
+ def extract_point_qa(ex, seeds, answer_type="y_major"):
876
+ ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
877
+ img_h = tf.shape(ex["image"])[0]
878
+ img_w = tf.shape(ex["image"])[1]
879
+
880
+ questions = ex["questions"]
881
+ question = questions["question"]
882
+ n = tf.shape(question)[0]
883
+ answers = tf.TensorArray(tf.string, size=n, element_shape=())
884
+ point_text = questions["annotations"]["point_text"]
885
+ point_seeds = tf.RaggedTensor.from_row_splits(
886
+ row_splits=point_text.row_splits,
887
+ values=tf.random.split(seeds[0], num=tf.shape(point_text.values)[0])
888
+ )
889
+ for question_ix in range(n):
890
+ anno = questions["annotations"]
891
+ answer = questions["answer_with_placeholders"][question_ix]
892
+ n_anno = tf.shape(anno["point_text"][question_ix])[0]
893
+ for anno_ix in range(n_anno):
894
+ points = anno["points"][question_ix, anno_ix]
895
+ point_text = points_to_answer(
896
+ points[:, 0], points[:, 1], 100, 100,
897
+ point_seeds[question_ix, anno_ix],
898
+ anno["point_text"][question_ix, anno_ix],
899
+ False,
900
+ alt_text=anno["alt_text"][question_ix, anno_ix],
901
+ )
902
+ answer_split = tf.strings.split(answer, sep="<|POINT|>", maxsplit=1)
903
+ answer = tf.strings.join([answer_split[0], point_text, answer_split[1]])
904
+ # Make sure all placeholders where used
905
+ tf.debugging.assert_equal(tf.shape(tf.strings.split(answer, sep="<|POINT|>"))[0], 1)
906
+ answers = answers.write(question_ix, answer)
907
+
908
+ messages = tf.stack([question, answers.stack()], axis=1)
909
+ messages = tf.reshape(messages, [-1])
910
+ conversation_ids = tf.range(tf.shape(messages)[0] // 2, dtype=tf.int32)
911
+ conversation_ids = tf.repeat(conversation_ids, 2)
912
+ out = dict(
913
+ image=ex["image"],
914
+ messages=tf.RaggedTensor.from_value_rowids(messages, conversation_ids)
915
+ )
916
+ ix = stateless_permutation(tf.shape(messages)[0], seeds[1])
917
+ messages = tf.gather(messages, ix)
918
+ out.update(_add_metadata(ex))
919
+ out["metadata/image_size"] = [img_w, img_h]
920
+ return out
921
+
922
+
923
+ def select_point(mask):
924
+ bs = tf.shape(mask)[0]
925
+ valid = tf.cast(mask, tf.float32)
926
+ h, w = tf.shape(mask)[1], tf.shape(mask)[2]
927
+ ys = tf.range(h, dtype=tf.int32)
928
+ xs = tf.range(w, dtype=tf.int32)
929
+
930
+ n = tf.reduce_sum(valid, [1, 2])
931
+ cy = tf.reduce_sum(tf.cast(ys[None, :, None], tf.float32) * valid, [1, 2]) / n # [bs]
932
+ cx = tf.reduce_sum(tf.cast(xs[None, None, :], tf.float32) * valid, [1, 2]) / n # [bs]
933
+
934
+ dist_y = tf.square(tf.range(h, dtype=tf.float32)[None, :] - cy[:, None]) # [bs, h]
935
+ dist_x = tf.square(tf.range(w, dtype=tf.float32)[None, :] - cx[:, None]) # [bs, w]
936
+ dist = dist_y[:, :, None] + dist_x[:, None, :] # [batch, h, w]
937
+ dist = dist + (1 - valid) * 1e12
938
+ min_dist = tf.argmin(tf.reshape(dist, [bs, -1]), axis=-1) # [batch]
939
+ w = tf.cast(w, min_dist.dtype)
940
+ cy = tf.cast(min_dist // w, tf.float32)
941
+ cx = tf.cast(min_dist % w, tf.float32)
942
+ return cx, cy
943
+
944
+
945
+ @seqio.map_over_dataset
946
+ def refexp_pointing(ex):
947
+ img_h = tf.shape(ex["image"])[0]
948
+ img_w = tf.shape(ex["image"])[1]
949
+ objects = ex["objects"]
950
+
951
+ # Shuffle objects so what object gets truncated if the sequence gets truncated is randomized
952
+ refexps = objects['refexp']['raw']
953
+ bbox = objects["bbox"]
954
+ mask = tf.squeeze(objects["mask"], -1)
955
+
956
+ ix = tf.range(0, tf.shape(refexps)[0], dtype=tf.int32)
957
+ ix = tf.random.shuffle(ix)
958
+ refexps = tf.gather(refexps, ix)
959
+ bbox = tf.gather(bbox, ix)
960
+ mask = tf.gather(mask, ix)
961
+
962
+ cx, cy = select_point(mask)
963
+ answers = points_to_text(img_h, img_w, cx, cy)
964
+
965
+ out = {
966
+ "image": ex["image"],
967
+ "refexp": refexps.values,
968
+ "metadata/image_size": tf.stack([img_w, img_h,]),
969
+ "text": tf.repeat(answers, refexps.row_lengths()),
970
+ }
971
+ if "image_url" in ex:
972
+ out["metadata/image_url"] = ex["image_url"]
973
+ return out
974
+
975
+
976
+ @seqio.map_over_dataset
977
+ def refexp_pointing_inf(ex):
978
+ img_h = tf.shape(ex["image"])[0]
979
+ img_w = tf.shape(ex["image"])[1]
980
+
981
+ objects = ex["objects"]
982
+ mask = tf.squeeze(objects["mask"], -1)
983
+ cx, cy = select_point(mask)
984
+ answers = points_to_text(img_h, img_w, cx, cy)
985
+
986
+ refexps = objects["refexp"]["raw"]
987
+
988
+ # We can't use `mask` directly since it is variable size, and thus it
989
+ # will break batching. Here we serialize it instead
990
+ serialized_masks = tf.map_fn(tf.io.serialize_tensor, mask, fn_output_signature=tf.string)
991
+ out = {
992
+ "image": ex["image"],
993
+ "refexp": refexps,
994
+ "metadata/bbox": objects["bbox"],
995
+ "metadata/answer": answers,
996
+ "metadata/mask": serialized_masks,
997
+ "metadata/image_size": tf.stack([img_w, img_h]),
998
+ }
999
+ out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
1000
+ return out
1001
+
1002
+ @seqio.map_over_dataset
1003
+ def extract_andriod_control_inf(ex, mode):
1004
+ if mode == "ll":
1005
+ prompt = tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]])
1006
+ elif mode == "hl_ll":
1007
+ prompt = tf.strings.join([
1008
+ "high_level: ", ex["metadata/hl_instruction"],
1009
+ " low_level: ", ex["metadata/ll_instruction"]
1010
+ ])
1011
+ elif mode == "hl":
1012
+ prompt = tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]])
1013
+ elif mode == "hl_cot":
1014
+ prompt = tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]])
1015
+ else:
1016
+ raise NotImplementedError()
1017
+
1018
+ out = dict(
1019
+ image=ex["image"],
1020
+ prompt=prompt,
1021
+ text=ex["metadata/target_action"]
1022
+ )
1023
+ out.update(_add_metadata(ex))
1024
+ return out
1025
+
1026
+ @seqio.map_over_dataset
1027
+ def extract_android_control(ex):
1028
+ # Each image has three tasks:
1029
+ # low level -> action
1030
+ # high+low level -> action
1031
+ # high level -> action
1032
+ # high level -> low level + action (CoT)
1033
+ out = dict(
1034
+ image=ex["image"],
1035
+ prompt=tf.stack([
1036
+ tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]]),
1037
+ tf.strings.join([
1038
+ "high_level: ", ex["metadata/hl_instruction"],
1039
+ " low_level: ", ex["metadata/ll_instruction"]
1040
+ ]),
1041
+ tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]]),
1042
+ tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]]),
1043
+ ]),
1044
+ text=tf.stack([
1045
+ ex["metadata/target_action"],
1046
+ ex["metadata/target_action"],
1047
+ ex["metadata/target_action"],
1048
+ tf.strings.join(["Plan: ", ex["metadata/ll_instruction"], " Action: ", ex["metadata/target_action"]]),
1049
+ ])
1050
+ )
1051
+ # Only needed if visualizing
1052
+ # ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1053
+ # img_h = tf.shape(ex["image"])[0]
1054
+ # img_w = tf.shape(ex["image"])[1]
1055
+ # out["metadata/image_size"] = tf.stack([img_w, img_h,])
1056
+ out.update(_add_metadata(ex))
1057
+ return out
1058
+
1059
+
1060
+ @seqio.map_over_dataset(num_seeds=1)
1061
+ def refexp(ex, seed):
1062
+ img_h = tf.shape(ex["image"])[0]
1063
+ img_w = tf.shape(ex["image"])[1]
1064
+ objects = ex["objects"]
1065
+
1066
+ # Shuffle objects so what object gets truncated if the sequence gets truncated is randomized
1067
+ refexps = objects['refexp']['raw']
1068
+ bbox = objects["bbox"]
1069
+ ix = stateless_permutation(tf.shape(refexps)[0], seed)
1070
+ refexps = tf.gather(refexps, ix)
1071
+ bbox = tf.gather(bbox, ix)
1072
+
1073
+ x2 = bbox[:, 0] + bbox[:, 2]
1074
+ y2 = bbox[:, 1] + bbox[:, 3]
1075
+ with tf.control_dependencies([
1076
+ tf.debugging.assert_equal(tf.reduce_any(x2 <= tf.cast(img_w, tf.float32)), True),
1077
+ tf.debugging.assert_equal(tf.reduce_any(y2 <= tf.cast(img_h, tf.float32)), True)
1078
+ ]):
1079
+ answers = points_to_text(
1080
+ img_h, img_w,
1081
+ tf.reshape(tf.stack([bbox[:, 0], x2], 1), [-1]),
1082
+ tf.reshape(tf.stack([bbox[:, 1], y2], 1), [-1]))
1083
+ answers = tf.strings.reduce_join(tf.reshape(answers, [-1, 2]), separator=" ", axis=1)
1084
+
1085
+ out = {
1086
+ "image": ex["image"],
1087
+ "refexp": refexps.values,
1088
+ "metadata/bbox": bbox,
1089
+ "metadata/image_size": tf.stack([img_w, img_h,]),
1090
+ "text": tf.repeat(answers, refexps.row_lengths()),
1091
+ }
1092
+
1093
+ if "image_url" in ex:
1094
+ out["image_url"] = ex["image_url"]
1095
+ return out
1096
+
1097
+
1098
+ @seqio.map_over_dataset
1099
+ def refexp_inf(ex):
1100
+ img_h = tf.shape(ex["image"])[0]
1101
+ img_w = tf.shape(ex["image"])[1]
1102
+ out = {
1103
+ "image": ex["image"],
1104
+ "refexp": ex["objects"]["refexp"]["raw"],
1105
+ "metadata/bbox": ex["objects"]["bbox"],
1106
+ "metadata/image_size": tf.stack([img_w, img_h,]),
1107
+ }
1108
+ out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
1109
+ return out
1110
+
1111
+
1112
+ def point_text_interleaved(*args):
1113
+ raise NotImplementedError()
1114
+
1115
+
1116
+ @seqio.map_over_dataset
1117
+ def web_pointing_preprocessor(ex):
1118
+ img_h = tf.shape(ex["image"])[0]
1119
+ img_w = tf.shape(ex["image"])[1]
1120
+
1121
+ question = point_text_interleaved(
1122
+ img_h, img_w, ex["question"], ex["question_points"]["x"], ex["question_points"]["y"])
1123
+ answer = point_text_interleaved(
1124
+ img_h, img_w, ex["answer"], ex["answer_points"]["x"], ex["answer_points"]["y"])
1125
+ answer_points = tf.stack([ex["answer_points"]["x"], ex["answer_points"]["y"]], axis=1)
1126
+ return {
1127
+ "question": question,
1128
+ "answer": answer,
1129
+ "image": ex["image"],
1130
+ "metadata/image_size": [img_w, img_h],
1131
+ "metadata/question_type": ex["question_type"],
1132
+ "metadata/answer_points": tf.io.serialize_tensor(answer_points),
1133
+ "metadata/answer": answer,
1134
+ }
1135
+
1136
+
1137
+ def filter_pointing(ds):
1138
+ return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] >= 1)
1139
+
1140
+
1141
+ def filter_qa(ds):
1142
+ return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] == 0)
1143
+
1144
+ # vaia filtering
1145
+ def filter_image_only(ds):
1146
+ return ds.filter(lambda ex: ex["has_image"])
1147
+
1148
+ def filter_mc(ds):
1149
+ return ds.filter(lambda ex: ex["is_mc"])
1150
+
1151
+ def remove_is_long(ds):
1152
+ return ds.filter(lambda ex: not ex["is_long"])
1153
+
1154
+ def remove_has_multiple_parts(ds):
1155
+ return ds.filter(lambda ex: not ex["has_multiple_parts"])
1156
+
1157
+
1158
+ def _split(ds: tf.data.Dataset, keys, n_splits=2):
1159
+ def _map(ex):
1160
+ n = tf.shape(ex[keys[0]])[0]
1161
+ if n < n_splits:
1162
+ return tf.data.Dataset.from_tensors(ex)
1163
+ else:
1164
+ # import pdb; pdb.set_trace()
1165
+ bs = n // n_splits
1166
+ remainder = n - bs*n_splits
1167
+ lens = tf.concat([
1168
+ tf.ones([remainder], dtype=tf.int32),
1169
+ tf.zeros([n_splits-remainder], dtype=tf.int32),
1170
+ ], axis=0) + bs
1171
+ tf.debugging.assert_equal(tf.reduce_sum(lens), n)
1172
+ ends = tf.cumsum(lens)
1173
+
1174
+ parts = []
1175
+ for split_ix in range(n_splits):
1176
+ part_ex = dict(ex)
1177
+ e = ends[split_ix]
1178
+ s = e - lens[split_ix]
1179
+ for k in keys:
1180
+ if isinstance(k, tuple):
1181
+ assert len(k) == 2
1182
+ part_ex[k[0]][k[1]] = ex[k[0]][k[1]][s:e]
1183
+ else:
1184
+ part_ex[k] = ex[k][s:e]
1185
+ parts.append(part_ex)
1186
+
1187
+ ds = tf.data.Dataset.from_tensors(parts[0])
1188
+ for sub_ds in parts[1:]:
1189
+ sub_ds = tf.data.Dataset.from_tensors(sub_ds)
1190
+ ds = ds.concatenate(sub_ds)
1191
+ return ds
1192
+
1193
+ return ds.flat_map(_map)
1194
+
1195
+
1196
+
1197
+ def split(ds, n=2):
1198
+ # return ds
1199
+ return _split(ds, [k for k in [
1200
+ "question",
1201
+ "label",
1202
+ "text",
1203
+ "entity",
1204
+ "messages"
1205
+ ] if k in ds.element_spec], n_splits=n)
1206
+
1207
+
1208
+ def split_points(ds, max_points=50):
1209
+ label = "question" if "question" in ds.element_spec else "label"
1210
+ return _split(ds, [
1211
+ "question", label, "notInImage",
1212
+ ("answer_points", "x"),
1213
+ ("answer_points", "y"),
1214
+ ])
1215
+
1216
+
1217
+ @seqio.map_over_dataset
1218
+ def fix_count_qa(ex):
1219
+ ex["label"] = ex["label"][::2]
1220
+ tf.debugging.assert_equal(tf.shape(ex["answer_points"]["x"])[0], tf.shape(ex["label"])[0])
1221
+ return ex
1222
+
1223
+
1224
+ def filter_points(ds, max_number=40):
1225
+
1226
+ def _add_valid(ex):
1227
+ valid = (
1228
+ tf.reduce_all(ex["answer_points"]["x"] >= 0.0, axis=-1) &
1229
+ tf.reduce_all(ex["answer_points"]["x"] <= 100.0, axis=-1) &
1230
+ tf.reduce_all(ex["answer_points"]["y"] >= 0.0, axis=-1) &
1231
+ tf.reduce_all(ex["answer_points"]["y"] <= 100.0, axis=-1) &
1232
+ (ex["answer_points"]["y"].row_lengths() <= max_number)
1233
+ )
1234
+ ex["valid"] = valid
1235
+ return ex
1236
+ ds = ds.map(_add_valid)
1237
+ ds = ds.filter(lambda ex: tf.reduce_any(ex["valid"]))
1238
+ return ds
1239
+
1240
+
1241
+ # def filter_points(ds, max_number=30):
1242
+ # n_points = ds["answer_points"]["x"].row_lengths()
1243
+ # parts = tf.TensorArray(tf.int32, size=tf.shape(n_points[0]), element_shape=tf.TensorShape([None]))
1244
+ # total = 0
1245
+ # on_row = 0
1246
+ # for i in range(n_points):
1247
+ # n = n_points[i]
1248
+ # if n > max_number:
1249
+ # continue
1250
+ # if n + total > max_number:
1251
+ #
1252
+ # return ds
1253
+
1254
+
1255
+ @seqio.map_over_dataset(num_seeds=2)
1256
+ def pointing_preprocessor(ex, sequence_length, seeds, with_count=False):
1257
+ image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1258
+ img_h = tf.shape(image)[0]
1259
+ img_w = tf.shape(image)[1]
1260
+
1261
+ ix = tf.where(ex["valid"])[:, 0]
1262
+ ix = stateless_shuffle(ix, seeds[0])
1263
+ if "label" in ex:
1264
+ question = tf.strings.lower(ex["label"])
1265
+ else:
1266
+ question = ex["question"]
1267
+ question = tf.gather(question, ix) # [n_question]
1268
+ points_x = tf.gather(ex["answer_points"]["x"], ix) # [n_question, n_points[ragged]]]
1269
+ points_y = tf.gather(ex["answer_points"]["y"], ix)
1270
+ not_in_image = tf.gather(ex["notInImage"], ix) # [n_question]
1271
+
1272
+ n = tf.shape(points_x)[0]
1273
+ point_text = tf.TensorArray(dtype=tf.string, size=n, element_shape=()) # [n_question]
1274
+ point_seeds = tf.random.split(seeds[1], n)
1275
+ for i in range(n):
1276
+ answer = points_to_answer(points_x[i], points_y[i], 100, 100, point_seeds[i], question[i], with_count)
1277
+ point_text = point_text.write(i, answer)
1278
+ return {
1279
+ "image": image,
1280
+ "metadata/image_size": [img_w, img_h],
1281
+ "entity": question,
1282
+ "question": question,
1283
+ "text": point_text.stack(),
1284
+ }
1285
+
1286
+
1287
+ @seqio.map_over_dataset
1288
+ def pointing_inf_preprocessor(ex):
1289
+ ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1290
+ img_h = tf.shape(ex["image"])[0]
1291
+ img_w = tf.shape(ex["image"])[1]
1292
+
1293
+ question = ex["question"]
1294
+ not_in_image = tf.shape(ex["answer_points"]["x"])[0] == 0
1295
+
1296
+ # points are stored in normalized format, de-normalize here
1297
+ points_x = ex["answer_points"]["x"] * tf.cast(img_w, tf.float32) / 100.0
1298
+ points_y = ex["answer_points"]["y"] * tf.cast(img_h, tf.float32) / 100.0
1299
+
1300
+ out = dict(
1301
+ image=ex["image"],
1302
+ question=question,
1303
+ entity=question,
1304
+ )
1305
+ out.update(_add_metadata(ex))
1306
+ out["metadata/not_in_image"] = not_in_image
1307
+ # We can't use `mask` directly since it is variable size, and thus it
1308
+ # will break batching. Here we serialize it instead
1309
+ serialized_masks = tf.map_fn(tf.io.serialize_tensor, ex["masks"], fn_output_signature=tf.string)
1310
+ serialized_masks = tf.strings.reduce_join(serialized_masks, separator="|||")
1311
+ out["metadata/mask"] = serialized_masks
1312
+ out["metadata/question"] = question
1313
+ out["metadata/answer_points"] = tf.io.serialize_tensor(tf.stack([points_x, points_y], 1))
1314
+ out["metadata/image_size"] = [img_w, img_h]
1315
+
1316
+ return out
1317
+
1318
+
1319
+ @seqio.map_over_dataset(num_seeds=1)
1320
+ def count_qa_preprocessor_inf(ex, sequence_length, seed):
1321
+ image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1322
+ img_h = tf.shape(image)[0]
1323
+ img_w = tf.shape(image)[1]
1324
+
1325
+ entity = tf.strings.substr(
1326
+ ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
1327
+ entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
1328
+ entity = tf.strings.lower(entity)
1329
+ tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
1330
+
1331
+ return {
1332
+ "image": image,
1333
+ "metadata/image_size": [img_w, img_h],
1334
+ "metadata/count": tf.strings.to_number(ex["answer"]),
1335
+ "question": ex["question"],
1336
+ "entity": entity,
1337
+ }
1338
+
1339
+
1340
+ @seqio.map_over_dataset(num_seeds=1)
1341
+ def count_qa_preprocessor(ex, sequence_length, seed, with_count=False,
1342
+ for_inference=False):
1343
+ point_answer = ex["point_answer"]
1344
+ numbers_str = tf.strings.regex_replace(point_answer, r'\.$', '')
1345
+ numbers_str = tf.strings.regex_replace(numbers_str, r'[^\d\.\s]+', '')
1346
+ numbers_str = tf.strings.strip(numbers_str)
1347
+ numbers = tf.strings.split(numbers_str)
1348
+ float_numbers = tf.strings.to_number(numbers, out_type=tf.float32)
1349
+ coordinates = tf.reshape(float_numbers, (-1, 3))
1350
+ points_x = coordinates[:, 1]
1351
+ points_y = coordinates[:, 2]
1352
+
1353
+ image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1354
+ img_h = tf.shape(image)[0]
1355
+ img_w = tf.shape(image)[1]
1356
+ entity = tf.strings.substr(
1357
+ ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
1358
+ entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
1359
+ entity = tf.strings.lower(entity)
1360
+ tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
1361
+ count = tf.strings.to_number(ex["answer"], out_type=tf.int32)
1362
+ if for_inference:
1363
+ return {
1364
+ "image": image,
1365
+ "metadata/image_size": [img_w, img_h],
1366
+ "metadata/count": count,
1367
+ "question": ex["question"],
1368
+ "entity": entity,
1369
+ }
1370
+ else:
1371
+ tf.debugging.assert_equal(count, tf.shape(points_x)[0])
1372
+ # points are already normalized so use w=1, h=1
1373
+ answer = points_to_answer(points_x, points_y, 1, 1, seed, entity, with_count)
1374
+ return {
1375
+ "image": image,
1376
+ "metadata/image_size": [img_w, img_h],
1377
+ "metadata/count": count,
1378
+ "question": ex["question"],
1379
+ "entity": entity,
1380
+ "text": answer,
1381
+ }
1382
+
1383
+
1384
+ @gin.configurable()
1385
+ @seqio.map_over_dataset
1386
+ def cleanup_preprocessor(ex, preprocess=False):
1387
+ if preprocess:
1388
+ ex["prompt"] = tf.strings.join(
1389
+ [
1390
+ "[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} ",
1391
+ ex["prompt"],
1392
+ "\n[[Assistant]]: {after}"
1393
+ ]
1394
+ )
1395
+ return ex
1396
+ else:
1397
+ return ex
1398
+
1399
+
1400
+ @gin.configurable()
1401
+ @seqio.map_over_dataset
1402
+ def random_text_preprocessor(ex, preprocess=False):
1403
+ ex["prompt"] = "What does the text say in this image?"
1404
+ if preprocess:
1405
+ ex["prompt"] = tf.strings.join(["[[User]]: ", ex["prompt"], "\n[[Assistant]]:"])
1406
+ return ex
1407
+ else:
1408
+ return ex
1409
+
1410
+
1411
+ @seqio.map_over_dataset(num_seeds=25)
1412
+ def clock_augmentation(ex, seeds):
1413
+ seeds = list(seeds)
1414
+ image = ex["image"]
1415
+
1416
+ # Apply shear, rotation, and scale through one affine matrix
1417
+ height = tf.cast(tf.shape(image)[0], tf.float32)
1418
+ width = tf.cast(tf.shape(image)[1], tf.float32)
1419
+
1420
+ _call_id = [0]
1421
+
1422
+ def _rng(_minval=0, _maxval=1, shape=(), dtype=tf.float32):
1423
+ return tf.random.stateless_uniform(shape, seeds.pop(), _minval, _maxval, dtype=dtype)
1424
+
1425
+ sel = _rng(0, 1)
1426
+ if sel < 0.1:
1427
+ # Straight on
1428
+ shear_x = 0.
1429
+ shear_y = 0.
1430
+ rotation = 0.
1431
+ elif sel < 0.5:
1432
+ # Normal looking
1433
+ shear_x = _rng(-10, 10)
1434
+ shear_y = _rng(-10, 10)
1435
+ rotation = _rng(-25, 25)
1436
+ else:
1437
+ # Allowed to be very wonky
1438
+ # if tf.random.stateless_uniform((), seeds.pop(), 0, 1) > 0.8:
1439
+ # image = image[:, ::-1]
1440
+
1441
+ if _rng() > 0.5:
1442
+ shear_x = _rng( -30, 30)
1443
+ shear_y = _rng( -30, 30)
1444
+ else:
1445
+ shear_x = _rng( -10, 10)
1446
+ shear_y = _rng( -10, 10)
1447
+ rng = _rng( 0, 1)
1448
+ if rng < 0.2:
1449
+ rotation = _rng( -25, 25)
1450
+ elif rng < 0.6:
1451
+ rotation = _rng( -80, 80)
1452
+ else:
1453
+ rotation = _rng( -180, 180)
1454
+
1455
+ if _rng() > 0.5:
1456
+ scale = _rng( 0.3, 2)
1457
+ else:
1458
+ scale = _rng( 0.3, 1)
1459
+ # Pad so upscaling/rotation will not move the image out of bounds
1460
+ pad = tf.cast(tf.maximum(height, width)*0.5, tf.int32)
1461
+ image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
1462
+ height = tf.cast(tf.shape(image)[0], tf.float32)
1463
+ width = tf.cast(tf.shape(image)[1], tf.float32)
1464
+
1465
+ image = tf.keras.ops.image.affine_transform(
1466
+ image,
1467
+ tf.stack(get_affine_matrix(
1468
+ [height/2, width/2],
1469
+ rotation,
1470
+ [0, 0],
1471
+ 1/scale,
1472
+ [shear_x, shear_y]
1473
+ ) + [0., 0.]),
1474
+ interpolation='bilinear',
1475
+ fill_mode='constant',
1476
+ fill_value=1.,
1477
+ data_format='channels_last'
1478
+ )
1479
+
1480
+ # Crop, otherwise it would be impossible to put the image at the corner of the image
1481
+ not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
1482
+ no_white_ix = tf.where(not_white)
1483
+ top_left = tf.reduce_min(no_white_ix, axis=0)
1484
+ bottom_right = tf.reduce_max(no_white_ix, axis=0)
1485
+ image = tf.image.crop_to_bounding_box(
1486
+ image,
1487
+ offset_height=tf.cast(top_left[0], tf.int32),
1488
+ offset_width=tf.cast(top_left[1], tf.int32),
1489
+ target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
1490
+ target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
1491
+ )
1492
+
1493
+ # Translate
1494
+ height, width = tf.shape(image)[0], tf.shape(image)[1]
1495
+ translation_seed = _rng(0, 1)
1496
+ if translation_seed < 0.2:
1497
+ h_pad = _rng(0, height//2, (2,), dtype=tf.int32)
1498
+ w_pad = _rng(0, width//2, (2,), dtype=tf.int32)
1499
+ else:
1500
+ h_pad = _rng(0, height*2, (2,), dtype=tf.int32)
1501
+ w_pad = _rng(0, width*2, (2,), dtype=tf.int32)
1502
+ image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
1503
+ constant_values=1)
1504
+
1505
+ # Random background color
1506
+ # color_rng = tf.random.stateless_uniform((4,), seeds.pop(), 0, 1)
1507
+ # random_color = color_rng[:3]
1508
+ # valid = tf.reduce_all(tf.reduce_sum(tf.abs(random_color[None, None, :] - image), -1) > 0.03)
1509
+ # if color_rng[0] < 0.2 and valid:
1510
+ # image = tf.where(tf.reduce_all(image < 0.99, axis=-1, keepdims=True),
1511
+ # image, image * 0 + random_color[None, None, :])
1512
+
1513
+ # Mild color hitter
1514
+ image = tf.image.stateless_random_hue(image, max_delta=0.05, seed=seeds.pop())
1515
+ image = tf.image.stateless_random_brightness(image, max_delta=0.15, seed=seeds.pop())
1516
+ image = tf.image.stateless_random_saturation(image, 0.8, 1.2, seed=seeds.pop())
1517
+ image = tf.image.stateless_random_contrast(image, 0.8, 1.2, seed=seeds.pop())
1518
+
1519
+ # ex["metadata/unaugmented_image"] = ex["image"]
1520
+ ex["image"] = image
1521
+ return ex
1522
+
1523
+
1524
+ @seqio.map_over_dataset
1525
+ def clocks_preprocessor(ex):
1526
+ time_format = ex["time_format"]
1527
+ shows_seconds = ex["shows_seconds"]
1528
+ hour, minute, second = [tf.cast(ex[k], tf.int32) for k in ["hour", "minute", "second"]]
1529
+ if hour == 0: # Midnight of the previous day
1530
+ am_pm = "PM"
1531
+ hour_str = 12
1532
+ hour = 24
1533
+ elif hour > 12:
1534
+ am_pm = "PM"
1535
+ hour_str = hour - 12
1536
+ else:
1537
+ hour_str = hour
1538
+ am_pm = "AM"
1539
+ hour_str = tf.strings.as_string(hour_str)
1540
+ minute_str = tf.strings.as_string(minute)
1541
+ if tf.strings.length(minute_str) == 1:
1542
+ minute_str = tf.strings.join(["0", minute_str])
1543
+
1544
+ second_str = tf.strings.as_string(second)
1545
+ if tf.strings.length(second_str) == 1:
1546
+ second_str = tf.strings.join(["0", second_str])
1547
+
1548
+ prefix = "The time shown is "
1549
+
1550
+ if time_format == "The time is not shown":
1551
+ text = "The time is not shown in the image."
1552
+ hour, minute, second = -1, -1, -1
1553
+ else:
1554
+ if not shows_seconds:
1555
+ second = -1
1556
+ if time_format == "12 hour clock (without AM/PM)" and shows_seconds:
1557
+ if hour > 12:
1558
+ hour = hour - 12
1559
+ time = tf.strings.join([hour_str, ":", minute_str, ":", second_str])
1560
+ elif time_format == "12 hour clock (with AM/PM)" and shows_seconds:
1561
+ time = tf.strings.join([hour_str, ":", minute_str, ":", second_str, " ", am_pm])
1562
+ elif time_format == "12 hour clock (with AM/PM)" and not shows_seconds:
1563
+ time = tf.strings.join([hour_str, ":", minute_str, " ", am_pm])
1564
+ elif time_format == "12 hour clock (without AM/PM)" and not shows_seconds:
1565
+ if hour > 12:
1566
+ hour = hour - 12
1567
+ time = tf.strings.join([hour_str, ":", minute_str])
1568
+ else:
1569
+ time = "" # Should never occur, but needed for tf analysis
1570
+ tf.debugging.assert_equal(tf.strings.length(time) > 0, True)
1571
+ text = tf.strings.join(["The time shown is ", time])
1572
+ image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1573
+ image = tf.image.convert_image_dtype(image, tf.float32)[:-120] # remove the black shadow at the bottom
1574
+ return {
1575
+ "image": image,
1576
+ "prompt": "What time is being shown?",
1577
+ "text": text,
1578
+ "metadata/time_format": time_format,
1579
+ "metadata/hour": hour,
1580
+ "metadata/minute": minute,
1581
+ "metadata/text": text,
1582
+ "metadata/second": second,
1583
+ }
1584
+
1585
+
1586
+ @seqio.map_over_dataset()
1587
+ def atlas_obscura_preprocessor(ex):
1588
+ out = dict(
1589
+ image=ex["image"],
1590
+ prompt="Where was this picture taken?",
1591
+ text=tf.strings.join([
1592
+ ex["place"],
1593
+ " in ",
1594
+ ex["city"]
1595
+ ])
1596
+ )
1597
+ out["metadata/image_url"] = ex["image_url"]
1598
+ out["metadata/references"] = out["text"]
1599
+ return out
1600
+
1601
+
1602
+ @seqio.map_over_dataset()
1603
+ def famous_birthdays_preprocessor(ex):
1604
+ out = dict(
1605
+ image=ex["image"],
1606
+ image_url=ex["image_url"],
1607
+ prompt="Who is this?",
1608
+ text=ex["name"]
1609
+ )
1610
+ out["metadata/references"] = out["text"]
1611
+ return out
1612
+
1613
+
1614
+ @seqio.map_over_dataset()
1615
+ def mild_color_aug_preprocessor(ex):
1616
+ if "image_url" in ex: # URL won't show the augmentations
1617
+ del ex["image_url"]
1618
+ # ex["metadata/unaugmented_image"] = ex["image"]
1619
+ ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1620
+ ex["image"] = mild_color_aug(ex["image"])
1621
+ return ex
1622
+
1623
+
1624
+ def build_text_with_points(text, points, img_h, img_w):
1625
+ points = points_to_text(img_h, img_w, points[:, 0], points[:, 1])
1626
+ parts = tf.strings.split(text, sep="<ANS>")
1627
+ with_points = tf.strings.reduce_join(tf.reshape(tf.stack([
1628
+ parts,
1629
+ tf.pad(points, [[0, 1]], constant_values=""),
1630
+ ], 1), [-1]), separator="")
1631
+ return tf.strings.split(with_points, "\n\n")
1632
+
1633
+
1634
+ @seqio.map_over_dataset()
1635
+ def synth_count_preprocessor(example):
1636
+ image_shape = tf.shape(example["image"])
1637
+ h, w = image_shape[0], image_shape[1]
1638
+ questions = build_text_with_points(example["questions"], example["question_points"], h, w)
1639
+ answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
1640
+ keep_q = tf.strings.regex_full_match(questions, "How many.*")
1641
+ keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
1642
+ keep = tf.logical_and(keep_q, keep_ans)
1643
+ questions = tf.boolean_mask(questions, keep)
1644
+ answers = tf.boolean_mask(answers, keep)
1645
+ ix = tf.range(0, tf.shape(answers)[0], dtype=tf.int32)
1646
+ ix = tf.random.shuffle(ix)
1647
+ return dict(
1648
+ image=example["image"],
1649
+ prompt=tf.gather(questions, ix),
1650
+ text=tf.gather(answers, ix),
1651
+ )
1652
+
1653
+
1654
+ def synth_count_inf_preprocessor(ds):
1655
+
1656
+ @seqio.map_over_dataset(num_seeds=1)
1657
+ def get_two(example, seed):
1658
+ image_shape = tf.shape(example["image"])
1659
+ h, w = image_shape[0], image_shape[1]
1660
+ questions = build_text_with_points(example["questions"], example["question_points"], h, w)
1661
+ answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
1662
+ keep_q = tf.strings.regex_full_match(questions, "How many.*")
1663
+ keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
1664
+ keep = tf.logical_and(keep_q, keep_ans)
1665
+ questions = tf.boolean_mask(questions, keep)
1666
+ answers = tf.boolean_mask(answers, keep)
1667
+
1668
+ ix = stateless_permutation(tf.shape(answers)[0], seed)[:2]
1669
+ return {
1670
+ "image": example["image"],
1671
+ "prompt": tf.gather(questions, ix),
1672
+ "metadata/references": tf.gather(answers, ix),
1673
+ }
1674
+
1675
+ ds = get_two(ds)
1676
+ return flatten_parts(ds, ["prompt", "metadata/references"])
1677
+
1678
+
1679
+ def mild_color_aug(image):
1680
+ image = tf.image.random_hue(image, max_delta=0.05)
1681
+ image = tf.image.random_brightness(image, max_delta=0.15)
1682
+ image = tf.image.random_saturation(image, 0.7, 1.3)
1683
+ image = tf.image.random_contrast(image, 0.8, 1.2)
1684
+ return image
1685
+
1686
+
1687
+ @seqio.map_over_dataset()
1688
+ def name_entity_augmentation(ex, p_high_color=0.7):
1689
+ ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
1690
+ image = ex["image"]
1691
+ image = tf.image.convert_image_dtype(image, tf.float32)
1692
+
1693
+ # Horizontal flip
1694
+ if tf.random.uniform((), 0, 1) > 0.85:
1695
+ image = image[:, ::-1]
1696
+
1697
+ # Random crop
1698
+ height = tf.cast(tf.shape(image)[0], tf.float32)
1699
+ width = tf.cast(tf.shape(image)[1], tf.float32)
1700
+ crop_rng = tf.random.uniform((), 0, 1)
1701
+ if crop_rng < 0.2:
1702
+ pass
1703
+ else:
1704
+ if crop_rng < 0.4:
1705
+ h_crop = height * 0.15
1706
+ w_crop = width * 0.15
1707
+ else:
1708
+ h_crop = height * 0.4
1709
+ w_crop = width * 0.4
1710
+ crop_h = tf.cast(tf.random.uniform((2,), 0, h_crop/2), tf.int32)
1711
+ crop_w = tf.cast(tf.random.uniform((2,), 0, w_crop/2), tf.int32)
1712
+ image = image[crop_h[0]:-crop_h[1]-1, crop_w[0]:-crop_w[1]-1]
1713
+ height = tf.cast(tf.shape(image)[0], tf.float32)
1714
+ width = tf.cast(tf.shape(image)[1], tf.float32)
1715
+
1716
+ if tf.random.uniform(()) > p_high_color:
1717
+ image = tf.image.random_hue(image, max_delta=0.05)
1718
+ image = tf.image.random_brightness(image, max_delta=0.15)
1719
+ image = tf.image.random_saturation(image, 0.7, 1.3)
1720
+ image = tf.image.random_contrast(image, 0.8, 1.2)
1721
+ else:
1722
+ image = tf.image.random_hue(image, max_delta=0.1)
1723
+ image = tf.image.random_brightness(image, max_delta=0.3)
1724
+ image = tf.image.random_saturation(image, 0.0, 2.0)
1725
+ image = tf.image.random_contrast(image, 0.2, 1.5)
1726
+
1727
+ # Apply shear, rotation, and scale through one affine matrix
1728
+ sel = tf.random.uniform((), 0, 1)
1729
+ if sel < 0.1:
1730
+ pass
1731
+ else:
1732
+ if sel < 0.15: # Scale only
1733
+ shear_x = 0
1734
+ shear_y = 0
1735
+ rotation = 0
1736
+ if sel < 0.7: # Mild
1737
+ shear_x = tf.random.uniform((), -2, 2)
1738
+ shear_y = tf.random.uniform((), -2, 2)
1739
+ rotation = tf.random.uniform((), -5, 5)
1740
+ else: # Severe
1741
+ shear_x = tf.random.uniform((), -10, 10)
1742
+ shear_y = tf.random.uniform((), -10, 10)
1743
+ rotation = tf.random.uniform((), -20, 20)
1744
+
1745
+ max_scale = 1.2
1746
+ scale = tf.random.uniform((), 0.4, max_scale)
1747
+
1748
+ # Pad so upscaling/rotation will not move the image out of bounds
1749
+ pad = tf.cast(tf.maximum(height, width)*0.2, tf.int32)
1750
+ image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
1751
+
1752
+ image = tf.keras.ops.image.affine_transform(
1753
+ image,
1754
+ tf.stack(get_affine_matrix(
1755
+ [height/2, width/2],
1756
+ rotation,
1757
+ [0, 0],
1758
+ 1/scale,
1759
+ [shear_x, shear_y]
1760
+ ) + [0., 0.]),
1761
+ interpolation='bilinear',
1762
+ fill_mode='constant',
1763
+ fill_value=1.,
1764
+ data_format='channels_last'
1765
+ )
1766
+
1767
+ # Crop, otherwise it would be impossible to put the image at the corner of the image
1768
+ not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
1769
+ no_white_ix = tf.where(not_white)
1770
+ top_left = tf.reduce_min(no_white_ix, axis=0)
1771
+ bottom_right = tf.reduce_max(no_white_ix, axis=0)
1772
+
1773
+ # Very low chance center crop will get nothing but white space, we just skip
1774
+ if (
1775
+ (bottom_right[0] - top_left[0]) > 1 and (bottom_right[1] - top_left[1]) > 1
1776
+ ):
1777
+ image = tf.image.crop_to_bounding_box(
1778
+ image,
1779
+ offset_height=tf.cast(top_left[0], tf.int32),
1780
+ offset_width=tf.cast(top_left[1], tf.int32),
1781
+ target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
1782
+ target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
1783
+ )
1784
+
1785
+ # Translate
1786
+ height, width = tf.shape(image)[0], tf.shape(image)[1]
1787
+ if tf.random.uniform((), 0, 1) < 0.1:
1788
+ h_pad = tf.zeros((2,), dtype=tf.int32)
1789
+ w_pad = tf.zeros((2,), dtype=tf.int32)
1790
+ elif tf.random.uniform((), 0, 1) < 0.8:
1791
+ h_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
1792
+ w_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
1793
+ else:
1794
+ pad = tf.cast(tf.maximum(height, width), tf.int32)
1795
+ h_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
1796
+ w_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
1797
+ image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
1798
+ constant_values=1)
1799
+
1800
+ if "image_url" in ex: # URL won't show the augmentations
1801
+ del ex["image_url"]
1802
+ # ex["metadata/unaugmented_image"] = ex["image"]
1803
+ ex["image"] = image
1804
+ return ex
1805
+
1806
+
1807
+ @seqio.map_over_dataset()
1808
+ def wiki_art_preprocessor(ex):
1809
+ out = dict(
1810
+ image=ex["image"],
1811
+ prompt="What is this?",
1812
+ text=ex["question"]
1813
+ )
1814
+ out["metadata/title"] = ex["title"]
1815
+ out["metadata/gt"] = ex["question"]
1816
+ out["metadata/artist"] = ex["artist"]
1817
+ out["metadata/painting_url"] = ex["painting_url"]
1818
+ # if "metadata/unaugmented_image" in ex:
1819
+ # out["metadata/unaugmented_image"] = ex["metadata/unaugmented_image"]
1820
+ return out
1821
+
1822
+ @seqio.map_over_dataset()
1823
+ def oscar_preprocessor(ex):
1824
+ out = dict(
1825
+ image=ex["image"],
1826
+ prompt=ex["question"]
1827
+ )
1828
+ out.update(_add_metadata(ex))
1829
+ out["metadata/question"] = ex["question"]
1830
+ out["metadata/answer"] = ex["answer"]
1831
+ out["metadata/category"] = ex["category"]
1832
+ return out
1833
+
1834
+
1835
+ @seqio.map_over_dataset()
1836
+ def tulu_preprocessor(ex):
1837
+ return {
1838
+ "messages": ex["messages"]["content"],
1839
+ }
1840
+ # logging.info("Debugging tulue")
1841
+ # return {"messages": ex["messages"]["content"], "text_weights": 1e-6}
1842
+
1843
+
1844
+ WIKI_DATA_QUESTION = "What is this? Respond with just a proper name."
1845
+
1846
+
1847
+ @seqio.map_over_dataset()
1848
+ def extract_wiki_data(ex):
1849
+ return dict(
1850
+ image=ex["image"],
1851
+ image_url=ex["image_url"],
1852
+ prompt=[
1853
+ WIKI_DATA_QUESTION,
1854
+ "What is this? Respond with the proper name of the main focus of the image and a few details about it."
1855
+ ],
1856
+ text=[
1857
+ tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", "")),
1858
+ ex["gptResponse"],
1859
+ ]
1860
+ )
1861
+
1862
+
1863
+ @seqio.map_over_dataset()
1864
+ def extract_wiki_data_name(ex):
1865
+ target = tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", ""))
1866
+ out = dict(
1867
+ image=ex["image"],
1868
+ image_url=ex["image_url"],
1869
+ prompt=WIKI_DATA_QUESTION,
1870
+ text=target,
1871
+ )
1872
+ out["metadata/references"] = target
1873
+ return out
1874
+
1875
+
1876
+ @seqio.map_over_dataset()
1877
+ def extract_wiki_data_describe(ex):
1878
+ out = dict(
1879
+ image=ex["image"],
1880
+ image_url=ex["image_url"],
1881
+ prompt="What is this? Respond with the proper name of the main focus of the image and a few details about it.",
1882
+ )
1883
+ out["metadata/references"] = ex["gptResponse"]
1884
+ return out
1885
+
1886
+
1887
+ @gin.configurable()
1888
+ def format_multiple_style_qa(ds, types=['multiple_choice', 'short_answer'], styles=['ai2_diagram', 'vqa2'], default_style='vqa2',
1889
+ strip_instruction=False):
1890
+ def _extract(ex):
1891
+ prompt = ex["question"]
1892
+ out = dict(image=ex["image"])
1893
+ out.update(_add_metadata(ex))
1894
+
1895
+ out["text"] = ex["answer"]
1896
+ out["metadata/references"] = ex["answer"]
1897
+
1898
+ if ex["metadata/question_type"] == 'multiple_choice':
1899
+ style = styles[0]
1900
+ else:
1901
+ style = styles[1]
1902
+ if strip_instruction:
1903
+ if ex["metadata/question_type"] == "multiple_choice":
1904
+ # parts = tf.strings.split(prompt, "\n")
1905
+ # parts 1 is blank and part -1 is the instruction
1906
+ # prompt = tf.strings.reduce_join(tf.concat([parts[:1], parts[2:-1]], 0), separator="\n")
1907
+ prompt = prompt
1908
+ else:
1909
+ prompt = tf.strings.split(prompt, "\n")[0]
1910
+
1911
+ out["style"] = style
1912
+ out["prompt"] = prompt
1913
+ return out
1914
+ ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
1915
+ return ds
1916
+
1917
+
1918
+ @gin.configurable()
1919
+ def extract_mmmu(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
1920
+ assert option_format == "abc"
1921
+ keys_tensor = tf.constant(types, dtype=tf.string)
1922
+ values_tensor = tf.constant(styles, dtype=tf.string)
1923
+ table = tf.lookup.StaticHashTable(
1924
+ tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
1925
+ default_value=tf.constant(default_style, dtype=tf.string),
1926
+ )
1927
+ def _extract(ex):
1928
+ out = dict(image=tf.expand_dims(ex["image_1"], 0))
1929
+ out.update(_add_metadata(ex))
1930
+ style = table.lookup(ex["metadata/question_type"])
1931
+ out["style"] = style
1932
+ out["text"] = ex["answer"]
1933
+ out["metadata/references"] = ex["answer"]
1934
+
1935
+ if style == styles[0]:
1936
+ abc = tf.constant(list("abcdefghi".upper()))
1937
+ options = ex["options"]
1938
+ num_options = tf.shape(options)[0]
1939
+ dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
1940
+ out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
1941
+ out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
1942
+
1943
+ short_options = abc[:num_options]
1944
+ options = tf.stack([short_options, options,], 1)
1945
+ options = tf.strings.reduce_join(options, axis=-1, separator=": ")
1946
+ options = tf.strings.reduce_join(options, separator="\n")
1947
+ out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
1948
+ if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
1949
+ # Following LLaVa, don't use any images if there are multiple images paths
1950
+ # I think the rationale is that this means the image are answer-options
1951
+ out["image"] = out["image"][:0]
1952
+ else:
1953
+ out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
1954
+ out["prompt"] = ex["question"]
1955
+ out["image"] = out["image"][:0]
1956
+ return out
1957
+ ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
1958
+ return ds
1959
+
1960
+ @gin.configurable()
1961
+ def extract_mmmu_cot(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
1962
+ assert option_format == "abc"
1963
+ keys_tensor = tf.constant(types, dtype=tf.string)
1964
+ values_tensor = tf.constant(styles, dtype=tf.string)
1965
+ table = tf.lookup.StaticHashTable(
1966
+ tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
1967
+ default_value=tf.constant(default_style, dtype=tf.string),
1968
+ )
1969
+ def _extract(ex):
1970
+ # out = dict(image=tf.expand_dims(ex["image_with_question"], 0))
1971
+ out = dict(image=tf.expand_dims(ex["image_1"], 0))
1972
+ out.update(_add_metadata(ex))
1973
+ style = table.lookup(ex["metadata/question_type"])
1974
+ # out["style"] = style
1975
+ out["text"] = ex["answer"]
1976
+ out["metadata/question"] = ex["question"]
1977
+ out["metadata/references"] = ex["answer"]
1978
+
1979
+ if style == styles[0]:
1980
+ abc = tf.constant(list("abcdefghi".upper()))
1981
+ options = ex["options"]
1982
+ num_options = tf.shape(options)[0]
1983
+ dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
1984
+ out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
1985
+ out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
1986
+
1987
+ short_options = abc[:num_options]
1988
+ options = tf.stack([short_options, options,], 1)
1989
+ options = tf.strings.reduce_join(options, axis=-1, separator=": ")
1990
+ options = tf.strings.reduce_join(options, separator="\n")
1991
+ out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
1992
+ # out["prompt"] = ex["question"]
1993
+ if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
1994
+ # Following LLaVa, don't use any images if there are multiple images paths
1995
+ # I think the rationale is that this means the image are answer-options
1996
+ out["image"] = out["image"][:0]
1997
+ else:
1998
+ out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
1999
+ out["prompt"] = ex["question"]
2000
+ # out["image"] = out["image"][:0]
2001
+ return out
2002
+ ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
2003
+ return ds
2004
+
2005
+
2006
+ @seqio.map_over_dataset
2007
+ def reformat_math_vista(ex):
2008
+ query = ex["query"]
2009
+ query = tf.strings.split(query, sep="Question:")[-1]
2010
+ query = tf.strings.strip(tf.strings.split(query, sep="Hint:")[0])
2011
+ ex["query"] = query
2012
+ return ex
2013
+
2014
+
2015
+ @seqio.map_over_dataset
2016
+ def extract_math_vista(ex, styles=['ai2_diagram', 'vqa2']):
2017
+ out = dict(image=ex["image"])
2018
+ out.update(_add_metadata(ex))
2019
+
2020
+ is_mc = ex["metadata/question_type"] == 'multi_choice'
2021
+ if is_mc:
2022
+ style = styles[0]
2023
+ abc = tf.constant(list("abcdefghi".upper()))
2024
+ options = ex["choices"]
2025
+ num_options = tf.shape(options)[0]
2026
+ dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
2027
+ out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
2028
+ out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
2029
+
2030
+ if ex["metadata/split"] != "test":
2031
+ short_options = abc[:num_options]
2032
+ answer_short_option = tf.boolean_mask(short_options, options == ex["answer"])[0]
2033
+ out["text"] = answer_short_option
2034
+ else:
2035
+ out["text"] = ex["answer"]
2036
+ else:
2037
+ style = styles[1]
2038
+ out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
2039
+ out["text"] = ex["answer"]
2040
+ out["style"] = style
2041
+ out["prompt"] = ex["query"]
2042
+ out["metadata/query"] = ex["query"]
2043
+ out["metadata/references"] = ex["answer"]
2044
+ return out
2045
+
2046
+
2047
+ NO_POINT_PREFIX = [
2048
+ "No pointing: ",
2049
+ "No pointing: ",
2050
+ "no pointing:\n",
2051
+ "No pointing:\n",
2052
+ "Not pointing:\n",
2053
+ "No Points: ",
2054
+ "No Points: ",
2055
+ "NO POINTING\n",
2056
+ "No pontiing\n",
2057
+ "No Points:\n ",
2058
+ "No pointing\n",
2059
+ "Do not point. ",
2060
+ "Refrain from pointing. ",
2061
+ "Avoid generating points . ",
2062
+ "For this question, do not use points. ",
2063
+ "Refrain from using points:\n",
2064
+ "Don't include points in your response. ",
2065
+ "Don't point. ",
2066
+ "Don't use points. ",
2067
+ "Please don't use points.\n\n",
2068
+ "Please don't use points.\n\n",
2069
+ "Respond without using points. ",
2070
+ "Respond without pointing:\n",
2071
+ "Do not generate ponits: ",
2072
+ "Do not point. ",
2073
+ "Do not point\n",
2074
+ "no pointing\n\n",
2075
+ "Answer without points: ",
2076
+ "Answer this question without pointing: ",
2077
+ "Answer without poiints. ",
2078
+ "answer without points: ",
2079
+ "answer with text only, do not points\n"
2080
+ ]
2081
+ assert all(x[-1].isspace() for x in NO_POINT_PREFIX)
2082
+ NO_POINT_PREFIX_TF = tf.constant(NO_POINT_PREFIX)
2083
+
2084
+
2085
+ def prefix_how_many(messages, seed):
2086
+ question = messages[0]
2087
+ if tf.strings.regex_full_match(tf.strings.lower(question), "how many.*"):
2088
+ ix = tf.random.stateless_uniform((), seed, 0, len(NO_POINT_PREFIX), tf.int32)
2089
+ question = tf.strings.join([NO_POINT_PREFIX_TF[ix], question])
2090
+ return tf.concat([tf.expand_dims(question, 0), messages[1:]], axis=0)
2091
+ else:
2092
+ return messages
2093
+
2094
+
2095
+ @seqio.map_over_dataset(num_seeds=1)
2096
+ def prefix_how_many_messages(ex, seed):
2097
+ messages = ex["messages"]
2098
+ n = tf.shape(messages)[0]
2099
+ seeds = tf.random.split(seed, n)
2100
+ message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
2101
+ for i in range(n):
2102
+ message_arr = message_arr.write(i, prefix_how_many(messages[i], seeds[i]))
2103
+ ex["messages"] = tf.RaggedTensor.from_row_splits(
2104
+ values=message_arr.concat(), row_splits=messages.row_splits)
2105
+ return ex
2106
+
2107
+
2108
+ def filter_single_turn(ds):
2109
+ @seqio.map_over_dataset
2110
+ def _filter(ex):
2111
+ multi_turn = ex["messages"].row_lengths() > 2
2112
+ ex["messages"] = tf.ragged.boolean_mask(ex["messages"], multi_turn)
2113
+ return ex
2114
+
2115
+ ds = _filter(ds)
2116
+ ds = ds.filter(lambda x: tf.shape(x["messages"])[0] > 0)
2117
+ return ds
2118
+
2119
+
2120
+ @seqio.map_over_dataset(num_seeds=1)
2121
+ def extract_cockatoo_qa_v2(ex, seed):
2122
+ messages = tf.RaggedTensor.from_value_rowids(ex["messages"], ex["conversation_ids"])
2123
+ ix = stateless_permutation(tf.shape(messages)[0], seed)
2124
+ messages = tf.gather(messages, ix)
2125
+ out = dict(
2126
+ image=ex["image"],
2127
+ messages=messages
2128
+ )
2129
+ out.update(_add_metadata(ex))
2130
+ return out
2131
+
2132
+
2133
+ def format_mmbench(ds):
2134
+
2135
+ def _trim(ex):
2136
+ num_passes = tf.shape(ex["id"])[0]
2137
+ ex["choices"] = ex["choices"][:num_passes, :num_passes]
2138
+ ex["answer"] = ex["answer"][:num_passes]
2139
+ return ex
2140
+
2141
+ ds = ds.map(_trim)
2142
+ ds = flatten_parts(ds, ["id", "query", "choices", "answer"])
2143
+
2144
+ def _extract(ex):
2145
+ out = dict(image=ex["image"])
2146
+ out.update(_add_metadata(ex))
2147
+ out["prompt"] = ex["query"]
2148
+ out["text"] = ex["answer"]
2149
+ options = ex["choices"]
2150
+ tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
2151
+ out["metadata/options"] = tf.strings.reduce_join(options, separator="|||")
2152
+ out["metadata/question"] = ex["question"]
2153
+ out["metadata/references"] = ex["answer"]
2154
+ return out
2155
+
2156
+ ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
2157
+ return ds
2158
+
2159
+
2160
+ @seqio.map_over_dataset
2161
+ def extract_lvis(ex, class_name_file="gs://oe-training-chrisc/cockatoo/data/lvis_class_names.json"):
2162
+ with tf.io.gfile.GFile(class_name_file) as f:
2163
+ class_names = json.load(f)
2164
+ class_names_arr = [None]*len(class_names)
2165
+ for k, v in class_names.items():
2166
+ class_names_arr[int(k)] = v
2167
+ assert all(x is not None for x in class_names_arr)
2168
+ class_names_arr = tf.constant(class_names_arr)
2169
+
2170
+ return dict(
2171
+ image=ex["image"],
2172
+ bbox=ex["objects"]["bbox"],
2173
+ label=tf.gather(class_names_arr, ex["objects"]["label"]),
2174
+ )
2175
+
2176
+
2177
+ def extract_open_images_boxes(ds):
2178
+ # ds = ds.filter(lambda ex: tf.logical_or(
2179
+ # tf.shape(ex["cap/cap_caption"])[0] > 0,
2180
+ # tf.shape(ex["detection/bbox"])[0] > 0
2181
+ # ))
2182
+ ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)
2183
+
2184
+ @seqio.map_over_dataset
2185
+ def _map(ex):
2186
+ bbox = tf.reshape(ex["detection/bbox"], (-1, 4))
2187
+ bbox = tf.stack([
2188
+ bbox[:, 2],
2189
+ bbox[:, 0],
2190
+ bbox[:, 3],
2191
+ bbox[:, 1]
2192
+ ], 1)
2193
+ return dict(
2194
+ image=tf.image.decode_jpeg(ex["image"]),
2195
+ bbox=bbox,
2196
+ label=ex["detection/label"],
2197
+ caption=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
2198
+ )
2199
+
2200
+ return _map(ds)
2201
+
2202
+
2203
+ @seqio.map_over_dataset
2204
+ def region_captions_to_dense(ex):
2205
+ if "captions" in ex:
2206
+ captions = ex["captions"]["text"]
2207
+ boxes = ex["captions"]["bbox"]
2208
+ else:
2209
+ captions = ex["label"]
2210
+ boxes = ex["bbox"]
2211
+
2212
+
2213
+ sh = tf.cast(tf.shape(ex["image"])[:2], tf.float32)
2214
+ # image_h, image_w = sh[0], sh[1]
2215
+ w = boxes[:, 2] - boxes[:, 0]
2216
+ h = boxes[:, 3] - boxes[:, 1]
2217
+
2218
+ cx = tf.cast(boxes[:, 0] + w/2, tf.float32)
2219
+ cy = tf.cast(boxes[:, 1] + h/2, tf.float32)
2220
+ # w = w / image_w
2221
+ # h = h / image_h
2222
+ coor = tf.strings.reduce_join(
2223
+ float_to_text(tf.stack([cx, cy, w, h], 1)), separator=",", axis=1)
2224
+
2225
+ area = w*h
2226
+ if tf.random.uniform(()) < 0.5:
2227
+ coor_text = "before"
2228
+ captions = tf.strings.join([coor, captions], separator=": ")
2229
+ else:
2230
+ coor_text = "after"
2231
+ captions = tf.strings.join([captions, coor], separator=": ")
2232
+
2233
+ ix = tf.random.uniform((), 0, 6, tf.int32)
2234
+ center = boxes
2235
+ if ix == 0:
2236
+ order_text = "left"
2237
+ sort_by = boxes[:, 0]
2238
+ elif ix == 1:
2239
+ order_text = "right"
2240
+ sort_by = -boxes[:, 2]
2241
+ elif ix == 2:
2242
+ order_text = "top"
2243
+ sort_by = boxes[:, 1]
2244
+ elif ix == 3:
2245
+ order_text = "bottom"
2246
+ sort_by = -boxes[:, 3]
2247
+ elif ix == 4:
2248
+ order_text = "largest"
2249
+ sort_by = area
2250
+ else:
2251
+ order_text = "smallest"
2252
+ sort_by = -area
2253
+ ixs = tf.argsort(sort_by)
2254
+ captions = tf.gather(captions, ixs)
2255
+ text = tf.strings.join([
2256
+ order_text,
2257
+ coor_text,
2258
+ tf.strings.reduce_join(captions, separator="\n")
2259
+ ], separator="; ")
2260
+
2261
+ if "caption" in ex:
2262
+ if tf.random.uniform(()) > 0.5:
2263
+ text = tf.strings.join([text, "\ncaption: ", ex["caption"]])
2264
+ else:
2265
+ text = tf.strings.join(["caption: ", ex["caption"], "\n", text])
2266
+
2267
+ return dict(
2268
+ image=ex["image"],
2269
+ text=text
2270
+ )
2271
+
2272
+
2273
+ @seqio.map_over_dataset()
2274
+ def join_captions(ex):
2275
+ text = tf.random.shuffle(ex['text'])
2276
+ ex["text"] = tf.strings.reduce_join(text, separator="\n")
2277
+ return ex
2278
+
2279
+
2280
+ @seqio.map_over_dataset(num_seeds=1)
2281
+ def extract_figureqa(ex, seed):
2282
+ questions = ex["questions"]
2283
+ n = stateless_permutation(tf.shape(questions["question"])[0], seed)
2284
+ return dict(
2285
+ image=ex["image"],
2286
+ questions=tf.gather(questions["question"], n),
2287
+ question_id=tf.gather(questions["question_id"], n),
2288
+ answer=tf.gather(tf.strings.as_string(questions["answer"]), n)
2289
+ )
2290
+
2291
+
2292
+ @seqio.map_over_dataset
2293
+ def convert_figureqa_answer(ex):
2294
+ keys_tensor = tf.constant(["0", "1"])
2295
+ values_tensor = tf.constant(["no", "yes"])
2296
+ table = tf.lookup.StaticHashTable(
2297
+ tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
2298
+ default_value=tf.constant("nan", dtype=tf.string),
2299
+ )
2300
+ answer = table.lookup(ex["answer"])
2301
+ ex["answer"] = answer
2302
+ return ex
2303
+
2304
+
2305
+ @seqio.map_over_dataset()
2306
+ def build_question_with_hint(ex):
2307
+ hint = ex["hint"]
2308
+ if tf.strings.length(hint) > 0:
2309
+ ex["question"] = tf.strings.join([hint, ex["question"]], separator="\n")
2310
+ return ex
2311
+
2312
+ @seqio.map_over_dataset()
2313
+ def build_question_with_context(ex):
2314
+ context = ex["context"]
2315
+ if tf.strings.length(context) > 0:
2316
+ ex["question"] = tf.strings.join([context, ex["question"]], separator="\n")
2317
+ return ex
2318
+
2319
+
2320
+ def max_words(ds, max_words):
2321
+ return ds.filter(lambda x: x["n_words"] <= max_words)
2322
+
2323
+
2324
+ @seqio.map_over_dataset
2325
+ def format_pdfa_eng_wds(example):
2326
+ return dict(
2327
+ image=example["image"],
2328
+ text=tf.strings.reduce_join(example["lines"]["text"], separator="\n"),
2329
+ )
2330
+
2331
+
2332
+ @gin.configurable()
2333
+ def accuracy_conditioned_joint(ds, sequence_length, is_eval=False, eval_quality=17,
2334
+ transcript_quality=None):
2335
+ # v2: Transcripts no longer get a quality score
2336
+ is_training = sequence_length.get('is_training', True)
2337
+ if not is_training:
2338
+ if is_eval:
2339
+ prompt = f"quality {eval_quality}:"
2340
+ else:
2341
+ prompt = f"quality 17:"
2342
+
2343
+ @seqio.map_over_dataset
2344
+ def _with_prompt(ex):
2345
+ out = dict(
2346
+ image=ex["image"],
2347
+ url=ex["url"],
2348
+ prompt=prompt,
2349
+ )
2350
+ if "text" in ex:
2351
+ out["text"] = ex["text"]
2352
+ elif "caption" in ex:
2353
+ out["text"] = ex["caption"]
2354
+ return out
2355
+ return _with_prompt(ds)
2356
+
2357
+ elif is_eval:
2358
+ raise ValueError("is_eval=True and is_training=False")
2359
+
2360
+ # each transcript
2361
+ @seqio.map_over_dataset
2362
+ def _with_transcript(ex):
2363
+ if tf.shape(ex["edited_captions"]["caption"])[0] > 0:
2364
+ edited_caption = ex["edited_captions"]["caption"][0]
2365
+ n = ex["edited_captions"]["n_edits"][0]
2366
+ else:
2367
+ edited_caption = ""
2368
+ n = 0
2369
+ text = [
2370
+ ex["caption"],
2371
+ ex["transcripts"][tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)],
2372
+ edited_caption
2373
+ ]
2374
+ edit_quality = 17 - n
2375
+ prompt = [
2376
+ "quality 17:",
2377
+ "" if transcript_quality is None else f"quality: {edit_quality}:",
2378
+ tf.strings.join(["quality ", tf.strings.as_string(edit_quality), ":"])
2379
+ ]
2380
+ return dict(
2381
+ image=ex["image"],
2382
+ text=tf.stack(text, 0),
2383
+ url=ex["url"],
2384
+ prompt=tf.stack(prompt, 0),
2385
+ style=["long_caption", "transcript", "long_caption"]
2386
+ )
2387
+ return _with_transcript(ds)
2388
+
2389
+
2390
+ def select_dense_caption_sample(ds, samples=200):
2391
+ def compute_hash(string: str) -> str:
2392
+ return hashlib.sha256(string.encode("utf-8")).hexdigest()
2393
+
2394
+ with tf.io.gfile.GFile("gs://oe-training-chrisc/cockatoo/data/dense-caption-eval-v0-final-data.json") as f:
2395
+ data = json.load(f)
2396
+ for ex in data:
2397
+ ex["image_id"] = compute_hash(ex["image"])
2398
+ data.sort(key=lambda x: x["image_id"])
2399
+ np.random.RandomState(12312).shuffle(data)
2400
+ keep = tf.constant([x["image"] for x in data[:samples]])
2401
+
2402
+ def _keep(ex):
2403
+ return tf.reduce_any(ex["url"] == keep)
2404
+ ds = ds.filter(_keep)
2405
+ ds = tf.data.experimental.assert_cardinality(samples)(ds)
2406
+ return ds
2407
+
2408
+ @seqio.map_over_dataset()
2409
+ def charxiv_preprocessor(ex):
2410
+ question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4", "reasoning_q"]
2411
+ answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4", "reasoning_a"]
2412
+
2413
+ questions = [ex[name] for name in question_names]
2414
+ answers = [ex[name] for name in answer_names]
2415
+
2416
+ return dict(
2417
+ image=ex["image"],
2418
+ question=tf.stack(questions, 0),
2419
+ answer=tf.stack(answers, 0)
2420
+ )
2421
+
2422
+ @seqio.map_over_dataset()
2423
+ def charxiv_descriptive_preprocessor(ex):
2424
+ question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4"]
2425
+ answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4"]
2426
+
2427
+ questions = [ex[name] for name in question_names]
2428
+ answers = [ex[name] for name in answer_names]
2429
+
2430
+ return dict(
2431
+ image=ex["image"],
2432
+ question=tf.stack(questions, 0),
2433
+ answer=tf.stack(answers, 0)
2434
+ )
2435
+
2436
+ @seqio.map_over_dataset()
2437
+ def charxiv_reasoning_preprocessor(ex):
2438
+ return dict(
2439
+ image=ex["image"],
2440
+ question=ex["reasoning_q"],
2441
+ answer=ex["reasoning_a"]
2442
+ )
2443
+
2444
+ @seqio.map_over_dataset()
2445
+ def tablevqa_preprocessor(ex):
2446
+ return dict(
2447
+ image=ex["image"],
2448
+ question=ex["question"],
2449
+ answer=ex["gt"]
2450
+ )
2451
+
2452
+ @seqio.map_over_dataset()
2453
+ def vtabfact_preprocessor(ex):
2454
+ return dict(
2455
+ image=ex["image"],
2456
+ question=tf.strings.join([ex["question"], "Answer with yes or no."], separator="\n"),
2457
+ answer=ex["gt"]
2458
+ )
2459
+
2460
+ @seqio.map_over_dataset()
2461
+ def nutrition_fact_preprocessor(ex):
2462
+ question_names = ["descriptive_q", "reasoning_q"]
2463
+ answer_names = ["descriptive_a", "reasoning_a"]
2464
+
2465
+ questions = [ex[name] for name in question_names]
2466
+ answers = [ex[name] for name in answer_names]
2467
+
2468
+ return dict(
2469
+ image=ex["image"],
2470
+ question=tf.stack(questions, 0),
2471
+ answer=tf.stack(answers, 0)
2472
+ )
prompts.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import tensorflow as tf
4
+
5
+ IMAGE_PROMPT = "<|image|>"
6
+
7
+
8
+ GENERAL_PROMPTS_V1 = {
9
+ "short_answer": [
10
+ "Answer this question very briefly\n{question}",
11
+ "{question} Answer with a few words",
12
+ "{question} Response very briefly",
13
+ "{question} Answer directly without any details, explanation, or elaboration",
14
+ "I have a question about this image, please answer it very briefly: {question}",
15
+ "Question: {question} Short Answer:",
16
+ "Question: {question}\nShort Answer:",
17
+ '{question}\nAnswer the question as briefly as possible.',
18
+ 'Answer very briefly:\n{question}',
19
+ 'The question "{question}" can be answered using the image. A short answer is',
20
+ "{question} Based on the image, respond to this question with a short answer:",
21
+ "{question} Short answer:",
22
+ "{question} A short answer to the question is",
23
+ "Give a short, matter-of-fact answer to this question: {question}",
24
+ "Give me a simple, direct answer to this question, do not elaborate or explain your answer:\n{question}"
25
+ ],
26
+ "short_caption": [
27
+ 'Caption the image with 1 or two sentences',
28
+ 'Write a very short description of this image.',
29
+ 'Briefly describe the image.',
30
+ 'Look and this image, and then summarize it in a sentence or two.',
31
+ 'Write a brief caption describing the image',
32
+ 'Brief Caption:'
33
+ 'A short image caption:',
34
+ 'A short image description',
35
+ 'Briefly describe the content of the image.',
36
+ 'Can you give me one sentence summary of the picture?',
37
+ 'How would you describe this image in a sentence or two?',
38
+ ],
39
+ "long_caption": [
40
+ 'Describe this image.',
41
+ 'Describe this image',
42
+ 'describe the image',
43
+ 'Write a long description of this image.',
44
+ 'caption the picture',
45
+ 'Caption',
46
+ 'caption',
47
+ 'Construct a long caption for this image',
48
+ 'Generate a caption',
49
+ 'Create a detailed caption',
50
+ 'Write a long caption',
51
+ 'Describe this image in detail',
52
+ 'Describe this',
53
+ 'describe this',
54
+ 'Caption this',
55
+ 'What can be seen in this image?',
56
+ 'What do you see in the image?',
57
+ 'Look at this photo carefully and then tell me about it in detail',
58
+ 'Write a long description of this image',
59
+ 'Tell me about this picture.',
60
+ 'Write a paragraph about this image.',
61
+ 'Look at this image carefully and then describe it in detail',
62
+ 'Generate a long caption about this image.'
63
+ ],
64
+ "long_caption_no_pointing": [
65
+ 'Describe this image in detail, but without any pointing.',
66
+ 'Write a long description of this image, do not produce any points.',
67
+ 'Tell me about this picture, use plain text only.',
68
+ 'Generate a plain text description of this caption',
69
+ "What is in this image?\nNo pointing\nGive lots of detail"
70
+ "Write a long caption.\nDo not use image coordinates\nOutput a full paragraph"
71
+ ],
72
+ "transcript": [
73
+ 'Describe this image as if you are a person speaking',
74
+ 'Imagine you are a person talking about this image. Generate a transcript of what you would say.',
75
+ "Generate an audio transcript of a person describing this image",
76
+ "Create a transcript of a human describing this image out load",
77
+ "Describe this in this style of a human talking",
78
+ ],
79
+ "refexp": [
80
+ 'What region does \"{refexp}\" refer to?',
81
+ ],
82
+ "count_bench": [
83
+ 'How many {object} are there?',
84
+ ],
85
+ "refexp_pointing": [
86
+ 'Where is the \"{refexp}\"?',
87
+ 'Point to {refexp}',
88
+ 'point at {refexp}',
89
+ 'Find the {refexp}.',
90
+ 'Which object in the image does \"{refexp}\" refer to?',
91
+ 'Locate the object \"{refexp}\" refers to.',
92
+ 'Point to the object that best matches the expression:\n{refexp}\n',
93
+ 'What object could be described as: {refexp}.\nPoint:',
94
+ 'Referring Expression: {refexp}.\nPoint:',
95
+ 'Expression: {refexp}\nPoint to the refexp',
96
+ 'Task: Point to the object that best matches the expression.\nExpression: {refexp}\nPoint:',
97
+ 'Instruction: Locate the object that matches the expression by returning a point.\nReferring Expression: {refexp}\n',
98
+ 'Help me find an object in this image by pointing to the {refexp}',
99
+ 'What point of the image might the expression \'{refexp}\' refer to?',
100
+ ],
101
+ "plain": ["{question}"],
102
+ "multiple_choice": [
103
+ "{question}\n{options}\nReturn only the letter of the best answer option",
104
+ "Answer this question by naming one of the provided options:\n{question}\n{options}",
105
+ "{question}\n{options}\nWhat option best answers the question?",
106
+ "{question}\n{options}\nReturn the best answer option",
107
+ "Look at the options, then return the letter of the option that best answers the question.\nQuesiton: {question}\nOptions: {options}",
108
+ "{question}? Select an answer option from:\n{options}",
109
+ "{question}\nSelect an answer option from:\n{options}\n\n",
110
+ "Question: {question}? Options: {options} Answer:",
111
+ "Answer the question by selecting an answer options\nQuestion: {question}\nOptions: {options}",
112
+ "{question}?\n{options}\nReturn only the letter of the correct answer",
113
+ "Help me answer this question: \"{question}\", by stating which of the following options is correct\n{options}."
114
+ ],
115
+ "binary": ["{question}\nAnswer with 'yes' or 'no'"],
116
+ "pointing": [
117
+ "Point to {entity}\nPlease say 'This isn't in the image.' if it is not in the image.",
118
+ "Point to all occurrences of \"{entity}\"",
119
+ "Point to any {entity} in the image",
120
+ "Point to any {entity} in the image.",
121
+ "Point: Where are the {entity}",
122
+ "Show me where the {entity} are",
123
+ "Can you show me where the {entity} are?",
124
+ "Show me where the {entity} are",
125
+ "Show me where a {entity} is",
126
+ "Show me where a {entity} is.",
127
+ "If there are any {entity} in the image? Show me where they are.",
128
+ "Where are the {entity}?",
129
+ "Generate a list of points showing where the {entity} are.",
130
+ "Find the \"{entity}\".",
131
+ "Find a \"{entity}\".",
132
+ "Locate all {entity}.",
133
+ "Locate an {entity}.",
134
+ "Locate a {entity}.",
135
+ "Locate every {entity}.",
136
+ "Locate {entity}.",
137
+ "Locate the {entity}.",
138
+ "Object: {entity}\nInstruction: Point to the object.",
139
+ "find {entity}",
140
+ "find {entity}.",
141
+ "Point to every {entity}",
142
+ "find any {entity} in the picture",
143
+ "Find the {entity}",
144
+ "Find any {entity}",
145
+ "Point to a {entity}",
146
+ "Point to an {entity}",
147
+ "Look for {entity} in the image and show me where they are.",
148
+ "Help me find an object in the image by pointing to them.\nObject: {entity}.",
149
+ "I am looking for {entity}, where can they be found in the image?",
150
+ "Can you see any {entity} in the image? Point to them.",
151
+ "Point out each {entity} in the image.",
152
+ "Point out every {entity} in the image.",
153
+ "Point to the {entity} in the image.",
154
+ "Locate each {entity} in the image.",
155
+ "Can you point out all {entity} in this image?",
156
+ "Please find {entity} and show me where they are.",
157
+ "If there are any {entity} present, indicate their positions.",
158
+ "If there is a {entity} present, indicate its positions.",
159
+ "show me all visible {entity}",
160
+ ],
161
+ "point_count": [
162
+ "How many {entity} are there?",
163
+ "How many {entity}?",
164
+ "How many {entity}.",
165
+ "how many {entity}.",
166
+ "how many {entity}?",
167
+ "How many {entity} are there in the image?",
168
+ "Tell me how many {entity} there are",
169
+ "Tell me how many {entity} there are and point to them.",
170
+ "how many {entity}",
171
+ "Tell me where each {entity} is.",
172
+ "Tell me how many {entity} are in the image",
173
+ "count {entity}",
174
+ "count every {entity}",
175
+ "count each {entity}",
176
+ "count {entity}.",
177
+ "Count the {entity}.",
178
+ "How many {entity} do you see?",
179
+ "How many {entity} are visible?",
180
+ "Count all the {entity}",
181
+ "how mmny {entity}?",
182
+ "Count every {entity} in the picture.",
183
+ "Count all the {entity}",
184
+ "Count each {entity}",
185
+ "Point to and count the {entity} in the picture.",
186
+ "Point and count {entity}",
187
+ "Point to every {entity}",
188
+ "Locate the {entity} and count them",
189
+ "Locate every {entity} and count them",
190
+ "Find all the {entity}. How many are there?",
191
+ "Find each {entity}. How many are there?",
192
+ "Point at {entity} and then tell me the count.",
193
+ "What is the total number of {entity} in the image?",
194
+ "In all the picture, how many {entity} are there?",
195
+ "Point at the {entity} and then count them.",
196
+ "Point to all the visible {entity} output the total count.",
197
+ "Point to all the {entity} visible and output the total count. \nPlease say 'This isn't in the image.' if it is not in the image.",
198
+ "Point to all occurrences of \"{entity}\" and output the total count.",
199
+ "Show me where the {entity} are and output the total count.",
200
+ "Where are the {entity}? How many are there?",
201
+ "Generate list of points showing where the {entity} are and output the total count.",
202
+ "Object: {entity}\nInstruction: Point to the object and output the total count.",
203
+ "find any {entity} in the picture and output the total count.",
204
+ "Can you see any {entity} in the image? Point to them and output the total count.",
205
+ "Can you point out all {entity} in this image? How many are there?",
206
+ "If there are any {entity} present, indicate their positions and output the total count.",
207
+ "How many {entity} are there in the image? Point to them and output the total count.",
208
+ "How many {entity} are there in the image?",
209
+ "Give me the count of {entity} in the image.",
210
+ "How many {entity} are visible in the image?",
211
+ "How many {entity} are there?",
212
+ "In the image, how many {entity} are there?",
213
+ "Can you count the number of {entity} in the image?",
214
+ "Can you count every {entity} in the picture?",
215
+ "Can you see any {entity} in the image? How many are there?",
216
+ "Are there any {entity} in the image? How many are there?",
217
+ "If you see any {entity} in the image, give me the count. Otherwise, say 'This isn't in the image.'",
218
+ "Object: {entity}\nInstruction: How many are there?",
219
+ ],
220
+
221
+ # vaia
222
+ "detailed_solution": [
223
+ "Answer the question providing a step by step solution and answer in the end.\n"
224
+ "Provide a step-by-step solution to the question, ending with your final answer.\n",
225
+ "Please provide a step-by-step solution to the question shown in the image.\n",
226
+ "Give a detailed explanation for the question, concluding with your final answer.\n",
227
+ "Solve the problem presented in the question with a thorough explanation. Give me your final answer at the end.\n",
228
+ "Please analyze the question and provide a complete solution, finishing with your final answer.\n",
229
+ "Work through the problem, offering detailed reasoning before stating your final answer.\n",
230
+ "Interpret the question and guide me through the solution, concluding with your answer.\n",
231
+ "Review the question and deliver a well-explained solution, making sure to include your final answer.\n",
232
+ "Examine the question: provide a detailed explanation followed by your final answer.\n"
233
+ ],
234
+
235
+ # vaia first answer with short_answer
236
+ "detailed_solution_answer_first": [
237
+ "Answer the question directly, then provide a step-by-step solution.\n",
238
+ "Please provide the answer first, followed by a step-by-step solution to the question shown in the image.\n",
239
+ "Give the final answer first, then provide a detailed explanation for the question.\n",
240
+ "Provide the final answer, then solve the problem presented in the question with a thorough explanation.\n",
241
+ "First, give the final answer, then analyze the question and provide a complete solution.\n",
242
+ "State the final answer first, then work through the problem, offering detailed reasoning.\n",
243
+ "Provide the final answer, then interpret the question and guide me through the solution.\n",
244
+ "Give the final answer first, then review the question and deliver a well-explained solution.\n",
245
+ "First, provide the final answer, then examine the question and give a detailed explanation.\n"
246
+ ],
247
+
248
+ # vqa_online
249
+ "detailed_answer": [
250
+ "Answer the question providing a step-by-step explanation and answer in the end.\n",
251
+ "Provide a step-by-step explanation to the question, ending with your final answer.\n",
252
+ "Please provide a step-by-step explanation to the question shown in the image.\n",
253
+ "Give a detailed explanation for the question, concluding with your final answer.\n",
254
+ "Address the problem presented in the question with a thorough explanation. Give me your final answer at the end.\n",
255
+ "Please analyze the question and provide a complete explanation, finishing with your final answer.\n",
256
+ "Work through the problem, offering detailed reasoning before stating your final answer.\n",
257
+ "Interpret the question and guide me through the explanation, concluding with your answer.\n",
258
+ "Review the question and deliver a well-explained answer, making sure to include your final answer.\n",
259
+ "Examine the question: provide a detailed explanation followed by your final answer.\n"
260
+ ],
261
+ }
262
+
263
+ GENERAL_PROMPTS_V1["pointing_tag"] = [txt + " Make the alt text and the inside of the tag the target label." for txt in GENERAL_PROMPTS_V1["pointing"]]
264
+
265
+ STYLE_TO_GENERAL_PROMPT = {
266
+ "vqa2": "short_answer",
267
+ "coco_captioning": "short_caption",
268
+ "gqa": "short_answer",
269
+ "ocr_vqa": "short_answer",
270
+ "tally_qa": "short_answer",
271
+ "text_vqa": "short_answer",
272
+ "okvqa": "short_answer",
273
+ "chart_qa": "short_answer",
274
+ "doc_qa": "short_answer",
275
+ "info_qa": "short_answer",
276
+ "science_qa": "multiple_choice",
277
+ "ai2_diagram": "multiple_choice",
278
+ "a_okvqa_mc": "multiple_choice",
279
+ "a_okvqa_da": "short_answer",
280
+ "long_caption": "long_caption",
281
+ "web_pointing": "plain",
282
+ "count_bench": "count_bench",
283
+ "refexp": "refexp",
284
+ "refexp_pointing": "refexp_pointing",
285
+ "vtabfact": "binary",
286
+ "vwtq": "short_answer",
287
+ "vwtq_syn": "short_answer",
288
+ "fintabnetqa": "short_answer",
289
+ "scifi_charts": "short_answer",
290
+ "scifi_charts_qa": "short_answer",
291
+ "charxiv_descriptive": "short_answer",
292
+ "charxiv_reasoning": "short_answer",
293
+ "pointing": "pointing",
294
+ "pointing_tag": "pointing_tag",
295
+ "point_count": "point_count",
296
+ "plain": "plain",
297
+ }
298
+
299
+
300
+ # def maybe_format_options(example, option_style="basic"):
301
+ # abc = tf.constant(list("abcdefg".upper()))
302
+ # if option_style == "random-v1":
303
+ # letter_option_sep = [": ", ". ", ")"]
304
+ # option_sep = ["\n", "\n", "\n", " ", ". ", ".\n", "; ", ", "]
305
+ # option_sep = tf.constant(option_sep)[tf.random.uniform((), 0, len(option_sep), tf.int32)]
306
+ # elif option_style == "basic":
307
+ # letter_option_sep = ": "
308
+ # option_sep = "\n"
309
+ # else:
310
+ # raise NotImplementedError(option_style)
311
+ #
312
+ # options = example["options"]
313
+ # short_options = abc[:tf.shape(options)[0]]
314
+ # sep = tf.constant(letter_option_sep)[tf.random.uniform((), 0, len(letter_option_sep), tf.int32)]
315
+ #
316
+ # options = tf.stack([short_options, options,], 1)
317
+ #
318
+ # options = tf.strings.reduce_join(options, axis=-1, separator=sep)
319
+ #
320
+ # options = tf.strings.reduce_join(options, separator=option_sep)
321
+ # example["options"] = options
322
+ # tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
323
+ # example["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
324
+ #
325
+ # if "answer_idx" in example:
326
+ # if example["answer_idx"] < 0:
327
+ # example["text"] = "?"
328
+ # else:
329
+ # example["text"] = short_options[example["answer_idx"]]
330
+ # example["metadata/answer_idx"] = example["answer_idx"]
331
+ # return example
332
+
333
+
334
+ def apply_keyword_prompt(prompts, example, seed=None, weights=None, keywords=None):
335
+ if isinstance(prompts, list):
336
+ assert keywords is None
337
+ all_keywords = [sorted(re.findall("{([^{}]+)}", x)) for x in prompts]
338
+ keywords = all_keywords[0]
339
+ assert len(keywords) == len(set(keywords)), f"Repeated keywords in {keywords}"
340
+ assert all(keywords == x for x in all_keywords), f"Inconsistent keywords in prompts {all_keywords}"
341
+ assert not any("{" not in word[1:-1] and "}" in word[1:-1] for word in keywords)
342
+
343
+ for k in keywords:
344
+ assert k in example, f"Example missing expected field {k}, example={example}"
345
+ prompts = tf.constant(prompts)
346
+
347
+ multiple = False
348
+ if "text" in example and len(example["text"].shape) > 0:
349
+ multiple = True
350
+
351
+ if weights is not None:
352
+ weights = tf.expand_dims(tf.math.log(weights), 0)
353
+
354
+ if seed is None:
355
+ raise ValueError()
356
+
357
+ if not multiple:
358
+ if weights is None:
359
+ prompt = prompts[tf.random.stateless_uniform((), seed, 0, len(prompts), dtype=tf.int32)]
360
+ else:
361
+ prompt = prompts[tf.random.stateless_categorical(weights, 1, seed, 0, len(prompts), dtype=tf.int32)][0, 0]
362
+ for keyword in keywords:
363
+ # We use split not regex_replace because regex_replace has issues with
364
+ # value strings with backslashes
365
+ res = tf.strings.split(prompt, "{"+keyword+"}", maxsplit=2)
366
+ prompt = tf.strings.join([res[0], example[keyword], res[1]])
367
+ return prompt
368
+ else:
369
+ n_prompts = tf.shape(example["text"])[0]
370
+ if weights is None:
371
+ ix = tf.random.stateless_uniform(
372
+ (n_prompts,), seed, 0, tf.shape(prompts)[0], dtype=tf.int32)
373
+ else:
374
+ ix = tf.random.stateless_categorical(
375
+ weights, tf.shape(prompts)[0], seed, 0, len(prompts), dtype=tf.int32)[0]
376
+ prompt = tf.gather(prompts, ix)
377
+ out = tf.TensorArray(dtype=tf.string, size=n_prompts, element_shape=())
378
+ for i in range(n_prompts):
379
+ modified = prompt[i]
380
+ for keyword in keywords:
381
+ res = tf.strings.split(modified, "{"+keyword+"}", maxsplit=2)
382
+ modified = tf.strings.join([res[0], example[keyword][i], res[1]])
383
+ out = out.write(i, modified)
384
+ return out.stack()
385
+
seqio_tokenizer.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The SeqIO Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Vocabularies."""
16
+
17
+ import abc
18
+ import dataclasses
19
+ import functools
20
+ import hashlib
21
+ import threading
22
+ from typing import Any, ClassVar, Dict, Iterable, Optional, Sequence, Union, List, Tuple
23
+
24
+ import numpy as np
25
+ from absl import logging
26
+ import tensorflow.compat.v2 as tf
27
+
28
+ from sentencepiece import sentencepiece_model_pb2
29
+ import sentencepiece as sentencepiece_processor
30
+
31
+ PAD_ID = -1 # -1 for llama tokenizer
32
+
33
+
34
+ class Vocabulary(metaclass=abc.ABCMeta):
35
+ """Abstract class for all vocabularies.
36
+
37
+ Subclasses must implement methods for converting between strings and tokens
38
+ both in pure python (`_encode`/`_decode`) and in TensorFlow
39
+ (`_encode_tf`/`_decode_tf`).
40
+
41
+ Subclasses are responsible for reserving PAD_ID=0 as well as optionally
42
+ reserving EOS_ID and UNK_ID
43
+
44
+ `_base_vocab_size` should account for PAD, EOS, and UNK but not `extra_ids`.
45
+ """
46
+
47
+ def __init__(self, extra_ids: int = 0):
48
+ """Vocabulary constructor.
49
+
50
+ Args:
51
+ extra_ids: The number of extra IDs to reserve.
52
+ """
53
+ self._extra_ids = extra_ids or 0
54
+
55
+ @property
56
+ def bos_token_id(self) -> Optional[int]:
57
+ raise NotImplementedError("need to implement bos_id")
58
+
59
+ @property
60
+ @abc.abstractmethod
61
+ def eos_token_id(self) -> Optional[int]:
62
+ raise NotImplementedError("need to implement eos_id")
63
+
64
+ @property
65
+ def pad_id(self) -> int:
66
+ return PAD_ID
67
+
68
+ @property
69
+ @abc.abstractmethod
70
+ def unk_id(self) -> Optional[int]:
71
+ raise NotImplementedError("need to implement unk_id")
72
+
73
+ @property
74
+ def extra_ids(self) -> int:
75
+ return self._extra_ids
76
+
77
+ @property
78
+ def vocab_size(self) -> int:
79
+ """Vocabulary size, including extra ids."""
80
+ return self._base_vocab_size + self.extra_ids
81
+
82
+ @property
83
+ @abc.abstractmethod
84
+ def _base_vocab_size(self) -> int:
85
+ """Vocabulary size, excluding extra ids but including PAD/EOS/UNK."""
86
+ # TODO(fjord): add a check that pad_id and unk_id (if present)
87
+ # are less than _base_vocab_size.
88
+ raise NotImplementedError
89
+
90
+ @abc.abstractmethod
91
+ def _encode(self, s: str) -> Sequence[int]:
92
+ raise NotImplementedError
93
+
94
+ def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]:
95
+ """Tokenizes string to an int sequence, without adding EOS."""
96
+ return self._encode(s)
97
+
98
+ @abc.abstractmethod
99
+ def _decode(self, ids):
100
+ raise NotImplementedError
101
+
102
+ def decode(self, ids: Iterable[int], truncate_at_eos=True):
103
+ """Detokenizes int32 iterable to a string, up through first EOS."""
104
+ clean_ids = list(ids)
105
+
106
+ if self.unk_id is not None:
107
+ vocab_size = self._base_vocab_size
108
+ clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids]
109
+
110
+ if truncate_at_eos and (self.eos_token_id is not None and self.eos_token_id in clean_ids):
111
+ clean_ids = clean_ids[: clean_ids.index(self.eos_token_id) + 1]
112
+
113
+ return self._decode(clean_ids)
114
+
115
+ @abc.abstractmethod
116
+ def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
117
+ raise NotImplementedError
118
+
119
+ def encode_tf(self, s: tf.Tensor) -> tf.Tensor:
120
+ """Tokenizes string Scalar to an int32 Tensor, without adding EOS."""
121
+ return self._encode_tf(s)
122
+
123
+ @abc.abstractmethod
124
+ def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
125
+ raise NotImplementedError
126
+
127
+ def decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
128
+ """Detokenizes int32 batched Tensor through first EOS."""
129
+ clean_ids = ids
130
+
131
+ if self.unk_id is not None:
132
+ base_vocab_size = self._base_vocab_size
133
+ clean_ids = tf.where(
134
+ tf.less(clean_ids, base_vocab_size), clean_ids, self.unk_id
135
+ )
136
+
137
+ if self.eos_id is not None:
138
+ # Replace everything after the first eos_id with pad_id.
139
+ after_eos = tf.cumsum(
140
+ tf.cast(tf.equal(clean_ids, self.eos_id), tf.int32),
141
+ exclusive=True,
142
+ axis=-1,
143
+ )
144
+ clean_ids = tf.where(tf.cast(after_eos, tf.bool), self.pad_id, clean_ids)
145
+
146
+ return self._decode_tf(clean_ids)
147
+
148
+
149
+ class PassThroughVocabulary(Vocabulary):
150
+ """Vocabulary that passes through inputs unchanged."""
151
+
152
+ def __init__(self, size: int, eos_id: Optional[Any] = None):
153
+ """PassThroughVocabulary constructor.
154
+
155
+ Args:
156
+ size: the full size of the vocabulary.
157
+ eos_id: the end-of-sequence token.
158
+ """
159
+ self._size = size
160
+ self._eos_id = eos_id
161
+ super().__init__()
162
+
163
+ @property
164
+ def _base_vocab_size(self):
165
+ return self._size
166
+
167
+ def _encode(self, s: Sequence[Any]) -> Sequence[Any]:
168
+ return s
169
+
170
+ def _decode(self, ids: Sequence[Any]) -> Sequence[Any]:
171
+ return ids
172
+
173
+ def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
174
+ return s
175
+
176
+ def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
177
+ return ids
178
+
179
+ @property
180
+ def eos_id(self) -> Optional[Any]:
181
+ return self._eos_id
182
+
183
+ @property
184
+ def unk_id(self) -> Optional[Any]:
185
+ return None
186
+
187
+ def __eq__(self, other):
188
+ if not isinstance(other, PassThroughVocabulary):
189
+ return False
190
+ return self._size == other._size and self.eos_id == other.eos_id
191
+
192
+ def __str__(self) -> str:
193
+ return f"PassThroughVocabulary(size={self._size}, eos_id={self.eos_id})"
194
+
195
+
196
+ class UnigramVocabulary(Vocabulary):
197
+ """Vocabulary that does table-lookup of unigrams."""
198
+
199
+ def __init__(self, unigrams: Sequence[str]):
200
+ """UnigramVocabulary constructor.
201
+
202
+ Args:
203
+ unigrams: the collection of in-vocabulary tokens. This collection should
204
+ not include PAD or UNK, which are automatically assigned ids and managed
205
+ as possible decode tokens.
206
+ """
207
+
208
+ super().__init__()
209
+ unigrams_as_list = list(unigrams)
210
+ self._unigram_by_id = ["PAD"] + unigrams_as_list + ["UNK"]
211
+ self._id_by_unigram = {u: i for i, u in enumerate(self._unigram_by_id)}
212
+ initializer = tf.lookup.KeyValueTensorInitializer(
213
+ keys=tf.constant(["PAD"] + unigrams_as_list),
214
+ # One extra value because the leading 0 corresponds to PAD
215
+ values=tf.constant(range(len(unigrams) + 1), dtype=tf.int64),
216
+ )
217
+ self._id_by_unigram_tf = tf.lookup.StaticVocabularyTable(
218
+ initializer, num_oov_buckets=1
219
+ )
220
+ self._unigram_by_id_tf = tf.constant(self._unigram_by_id)
221
+
222
+ def _encode(self, s: str) -> Sequence[int]:
223
+ return [self._id_by_unigram.get(s, self.unk_id)]
224
+
225
+ def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
226
+ tf_ids = self._id_by_unigram_tf.lookup(s)
227
+ return tf.expand_dims(tf.dtypes.cast(tf_ids, tf.int32), -1)
228
+
229
+ def _decode(self, ids: Sequence[int]) -> str:
230
+ return " ".join(self._unigram_by_id[id] for id in ids)
231
+
232
+ def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
233
+ return self._unigram_by_id_tf[ids[0]]
234
+
235
+ @property
236
+ def _base_vocab_size(self):
237
+ return len(self._unigram_by_id)
238
+
239
+ @property
240
+ def eos_id(self):
241
+ return None
242
+
243
+ @property
244
+ def unk_id(self):
245
+ return self._base_vocab_size - 1
246
+
247
+
248
+ class SentencePieceVocabulary(Vocabulary):
249
+ """Wrapper for nlp/sentencepiece encoder.
250
+
251
+ Assumes the model was built using flags to reserve ID=0 for padding, ID=1 for
252
+ EOS, and ID=2 for UNK.
253
+
254
+ If using extra ids, you can represent them in string-form as `<extra_id_0>`,
255
+ `<extra_id_1>`, etc. They will be indexed starting from the end of the
256
+ vocabulary to match how the masking preprocessors are set up.
257
+
258
+ IMPORTANT NOTE: these placeholders only work properly when they are used at
259
+ word starts (e.g., "I like peanut butter and <extra_id_0> sandwiches." or
260
+ "I like peanut butter and <extra_id_0>ly sandwiches" are both okay, but
261
+ "I like peanut butter and jel<extra_id_0> sandwiches" is not.).
262
+ """
263
+
264
+ @dataclasses.dataclass
265
+ class _ModelContext:
266
+ tokenizer: sentencepiece_processor.SentencePieceProcessor
267
+ sp_model: bytes
268
+
269
+ _load_model_lock: ClassVar[threading.Lock] = threading.Lock()
270
+
271
+ def __init__(
272
+ self,
273
+ sentencepiece_model_file: str,
274
+ extra_ids: int = 0,
275
+ normalizer_spec_overrides: Optional[
276
+ sentencepiece_model_pb2.NormalizerSpec
277
+ ] = None,
278
+ reverse_extra_ids: bool = False,
279
+ extra_tokens: Tuple[str] = None,
280
+ hack_to_t5_start_tokens: bool = False,
281
+ ):
282
+ """Create a SentencePieceVocabulary.
283
+
284
+ Optionally, specify a number of extra ids to add to the end of the
285
+ vocabulary for use as sentinels.
286
+
287
+ Args:
288
+ sentencepiece_model_file: path of the sentence piece model.
289
+ extra_ids: number of extra ids to include.
290
+ normalizer_spec_overrides: If not None, this proto will be merged into the
291
+ model's normalizer and denormalizer specs. Thus, any options set on this
292
+ object will override the values of those options in the loaded model.
293
+ reverse_extra_ids: if True, extra_ids are numbered in descending order, so
294
+ the first extra_id has the highest number. This is done for
295
+ compatibility with span_corruption mask generation in T5.
296
+ """
297
+ self._sentencepiece_model_file = sentencepiece_model_file
298
+ self._normalizer_spec_overrides = normalizer_spec_overrides
299
+ self._reverse_extra_ids = reverse_extra_ids
300
+ self._model: Optional[SentencePieceVocabulary._ModelContext] = None
301
+ self._extra_tokens = extra_tokens
302
+ self._hack_to_t5_start_tokens = hack_to_t5_start_tokens
303
+ super().__init__(extra_ids=extra_ids)
304
+
305
+ def __getstate__(self):
306
+ state = self.__dict__.copy()
307
+ # Gin config makes a deep copy of the keyword arguments of configurables.
308
+ # When a SentencePieceVocabulary vocabulary is used as a keyword argument
309
+ # in a Gin configurable, it must be picklable. We therefore remove
310
+ # _model; will be initialized lazily as needed.
311
+ del state["_model"]
312
+ return state
313
+
314
+ def __setstate__(self, state):
315
+ self.__dict__.update(state)
316
+ self._model = None
317
+
318
+ def load_model(self) -> None:
319
+ _ = self._model_context()
320
+
321
+ def _model_context(
322
+ self,
323
+ ) -> _ModelContext:
324
+ """Loads model if not yet loaded and returns the model context.
325
+
326
+ Returns:
327
+ The model context as a tuple of (tokenizer, sp_model).
328
+ """
329
+ if self._model:
330
+ return self._model
331
+
332
+ normalizer_spec_overrides_serialized = (
333
+ self._normalizer_spec_overrides.SerializeToString(deterministic=True)
334
+ if self._normalizer_spec_overrides
335
+ else None
336
+ )
337
+
338
+ self._model = self._load_model(
339
+ self._sentencepiece_model_file,
340
+ self._extra_ids,
341
+ normalizer_spec_overrides_serialized,
342
+ self._reverse_extra_ids,
343
+ extra_tokens=self._extra_tokens,
344
+ hack_to_t5_start_tokens=self._hack_to_t5_start_tokens,
345
+ )
346
+ return self._model
347
+
348
+ @classmethod
349
+ @functools.lru_cache(maxsize=None)
350
+ def _load_model(
351
+ cls,
352
+ sentencepiece_model_file: str,
353
+ extra_ids: int,
354
+ normalizer_spec_overrides_serialized: Optional[bytes] = None,
355
+ reverse_extra_ids: bool = True,
356
+ extra_tokens: Tuple[str] = None,
357
+ hack_to_t5_start_tokens=False,
358
+ ) -> _ModelContext:
359
+ """Load SPM, Python tokenizer, and cache results to the class definition."""
360
+ # SentencePieceProcessor::LoadFromSerializedProto is not thread-safe.
361
+ # Without a lock, users may randomly see SIGSEGV on
362
+ # sentencepiece::ModelInterface::pad_piece when using the vocabulary in
363
+ # SeqIO preprocessors.
364
+ with cls._load_model_lock:
365
+ # Handle cases where SP can't load the file, but gfile can.
366
+ with tf.io.gfile.GFile(sentencepiece_model_file, "rb") as f:
367
+ sp_model = f.read()
368
+ model = sentencepiece_model_pb2.ModelProto.FromString(sp_model)
369
+
370
+ if hack_to_t5_start_tokens:
371
+ # PAD token would still be 0 same as BOS for consistency as previous!
372
+ unk = model.pieces[0]
373
+ bos = model.pieces[1]
374
+ eos = model.pieces[2]
375
+ model.pieces.remove(unk)
376
+ model.pieces.remove(bos)
377
+ model.pieces.remove(eos)
378
+ model.pieces.insert(0, bos) # BOS is token 0
379
+ model.pieces.insert(1, eos) # EOS is token 1
380
+ model.pieces.insert(2, unk) # UNK is token 2
381
+
382
+ # Add placeholder strings for extra IDs.
383
+ if extra_ids:
384
+ # By default, we them in reverse order to match span corruption.
385
+ if reverse_extra_ids:
386
+ extra_id_tokens = reversed(range(extra_ids))
387
+ else:
388
+ extra_id_tokens = range(extra_ids)
389
+
390
+ for i in extra_id_tokens:
391
+ model.pieces.add(
392
+ piece=f"▁<extra_id_{i}>",
393
+ score=0.0,
394
+ type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED,
395
+ )
396
+
397
+ if extra_tokens:
398
+ for s in extra_tokens:
399
+ model.pieces.add(
400
+ piece=f"▁"+s,
401
+ score=0.0,
402
+ type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED,
403
+ )
404
+
405
+ if normalizer_spec_overrides_serialized is not None:
406
+ normalizer_spec_overrides = (
407
+ sentencepiece_model_pb2.NormalizerSpec.FromString(
408
+ normalizer_spec_overrides_serialized
409
+ )
410
+ )
411
+
412
+ model.normalizer_spec.MergeFrom(normalizer_spec_overrides)
413
+ model.denormalizer_spec.MergeFrom(normalizer_spec_overrides)
414
+ sp_model = model.SerializeToString()
415
+ # Load Python tokenizer and ensure the EOS and PAD IDs are correct.
416
+ tokenizer = sentencepiece_processor.SentencePieceProcessor()
417
+ tokenizer.LoadFromSerializedProto(sp_model)
418
+ if tokenizer.pad_id() != PAD_ID:
419
+ logging.warning(
420
+ (
421
+ "T5 library uses PAD_ID=%s, which is different from the "
422
+ "sentencepiece vocabulary, which defines pad_id=%s"
423
+ ),
424
+ PAD_ID,
425
+ tokenizer.pad_id(),
426
+ )
427
+
428
+ return cls._ModelContext(tokenizer=tokenizer, sp_model=sp_model)
429
+
430
+ @property
431
+ def num_extra_tokens(self):
432
+ if self._extra_tokens:
433
+ return len(self._extra_tokens)
434
+ return 0
435
+
436
+ @property
437
+ def bos_id(self) -> Optional[int]:
438
+ return self.tokenizer.bos_id()
439
+
440
+ @property
441
+ def bos_token_id(self) -> Optional[int]:
442
+ return self.tokenizer.bos_id()
443
+
444
+ @property
445
+ def eos_token_id(self) -> Optional[int]:
446
+ return self.tokenizer.eos_id()
447
+
448
+ @property
449
+ def eos_id(self) -> Optional[int]:
450
+ return self.tokenizer.eos_id()
451
+
452
+ @property
453
+ def unk_id(self) -> Optional[int]:
454
+ return self.tokenizer.unk_id()
455
+
456
+ @property
457
+ def sp_model(self) -> Optional[bytes]:
458
+ """Retrieve the SPM."""
459
+ return self._model_context().sp_model
460
+
461
+ @property
462
+ def sentencepiece_model_file(self) -> str:
463
+ return self._sentencepiece_model_file
464
+
465
+ @property
466
+ def tokenizer(self) -> sentencepiece_processor.SentencePieceProcessor:
467
+ """Returns the Python tokenizer."""
468
+ return self._model_context().tokenizer
469
+
470
+ @property
471
+ def tf_tokenizer(self):
472
+ """Instantiate and return a TF tokenizer."""
473
+ import tensorflow_text as tf_text # import here to keep the dependency optional
474
+ return tf_text.SentencepieceTokenizer(model=self.sp_model)
475
+
476
+ @property
477
+ def vocab_size(self):
478
+ return self._base_vocab_size
479
+
480
+ @property
481
+ def _base_vocab_size(self):
482
+ """Number of ids (including 0=PAD, 1=EOS, and 2=UNK).
483
+
484
+ Returns:
485
+ an integer, the vocabulary size
486
+ """
487
+ return self.tokenizer.GetPieceSize()
488
+
489
+ def _encode(self, s):
490
+ """Encode a python string as a list of integers.
491
+
492
+ Args:
493
+ s: a string
494
+
495
+ Returns:
496
+ a list of integers (not terminated by EOS)
497
+ """
498
+ return self.tokenizer.EncodeAsIds(s)
499
+
500
+ def _decode(self, ids):
501
+ """Decode a list of integers to a python string.
502
+
503
+ Args:
504
+ ids: a list of integers (not terminated by EOS)
505
+
506
+ Returns:
507
+ a string
508
+ """
509
+ # convert all the extra ids (sentinels) to UNK=2
510
+ unk_id = self.tokenizer.unk_id()
511
+ piece_size = self.tokenizer.GetPieceSize()
512
+ ids = [unk_id if i >= piece_size else int(i) for i in ids]
513
+ return self.tokenizer.DecodeIds(ids)
514
+
515
+ def _encode_tf(self, s):
516
+ """Encode a tf.Scalar string to a tf.Tensor.
517
+
518
+ This will be necessary for on-the-fly tokenization.
519
+
520
+ Args:
521
+ s: a tf.Scalar with dtype tf.string
522
+
523
+ Returns:
524
+ a 1d tf.Tensor with dtype tf.int32
525
+ """
526
+ return self.tf_tokenizer.tokenize(s)
527
+
528
+ def _decode_tf(self, ids):
529
+ """Decode in TensorFlow.
530
+
531
+ Args:
532
+ ids: a 1d or 2d tf.Tensor with dtype tf.int32
533
+
534
+ Returns:
535
+ a 1d or 2d tf.Tensor with dtype tf.string
536
+ """
537
+ return self.tf_tokenizer.detokenize(ids)
538
+
539
+ def __eq__(self, other):
540
+ if not isinstance(other, SentencePieceVocabulary):
541
+ return False
542
+ try:
543
+ their_md5 = hashlib.md5(other.sp_model).hexdigest()
544
+ # If other has no sp_model attribute, we can't test for equality
545
+ except AttributeError:
546
+ return False
547
+ if self.sp_model is None:
548
+ return False
549
+ our_md5 = hashlib.md5(self.sp_model).hexdigest()
550
+ return our_md5 == their_md5
551
+
552
+ def __str__(self) -> str:
553
+ return (
554
+ f"SentencePieceVocabulary(file={self.sentencepiece_model_file}, "
555
+ f"extra_ids={self._extra_ids}, "
556
+ f"spm_md5={hashlib.md5(self.sp_model).hexdigest()})"
557
+ )
558
+
559
+ @property
560
+ def adds_space(self):
561
+ return True
562
+
563
+
564
+ class HfTokenizerWrapper:
565
+ def __init__(self, tokenizer, bos_token_id=None, adds_space=False):
566
+ """
567
+ tokenizer: Tokenizer to wrap
568
+ bos_token_id: BOS token id to use if not `tokenizer.bos_token_id`
569
+ adds_space: If concatenating interdependently tokenized pieces of text, will the tokens
570
+ already including a seerating space?
571
+ """
572
+ self.adds_space = adds_space
573
+ self.tokenizer = tokenizer
574
+ if bos_token_id is None:
575
+ self.bos_token_id = tokenizer.bos_token_id
576
+ else:
577
+ self.bos_token_id = bos_token_id
578
+ self.eos_token_id = self.tokenizer.eos_token_id
579
+ self.pad_id = -1
580
+
581
+ def encode(self, x: str):
582
+ return self.tokenizer.encode(x, add_special_tokens=False)
583
+
584
+ def decode(self, x: List[int], truncate_at_eos=True):
585
+ x = [int(t) for t in x]
586
+
587
+ if self.eos_token_id == self.bos_token_id and (len(x) > 0 and x[0] == self.eos_token_id):
588
+ # Assume an EOS at the start is functioning as BOS
589
+ x = x[1:]
590
+
591
+ if truncate_at_eos:
592
+ # Follow seqio and automatically cut off at EOS
593
+ try:
594
+ eos_ix = x.index(self.eos_token_id)
595
+ x = x[:eos_ix]
596
+ except ValueError:
597
+ pass
598
+ return self.tokenizer.decode(x, skip_special_tokens=True)
599
+
600
+
601
+ def vocab_size(self):
602
+ return len(self.tokenizer)
603
+
604
+ def encode_tf(self, x):
605
+ if isinstance(x, str) or len(x.shape) == 0:
606
+ def _enc(_data):
607
+ _data = _data.item() if isinstance(_data, np.ndarray) else _data
608
+ return self.tokenizer.encode(_data.decode("utf-8"), add_special_tokens=False, return_tensors="np")[0].astype(np.int32)
609
+ return tf.ensure_shape(tf.numpy_function(_enc, [x], tf.int32, stateful=False), [None])
610
+
611
+ flattened = tf.reshape(x, [-1])
612
+
613
+ def _enc(_data):
614
+ tokens = [self.tokenizer.encode(x.decode("utf-8"), add_special_tokens=False, return_tensors="np")[0].astype(np.int32)
615
+ for x in _data]
616
+ if len(tokens) == 0:
617
+ return np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32)
618
+ else:
619
+ return np.concatenate(tokens, 0), np.array([len(x) for x in tokens]).astype(np.int32)
620
+ if not (isinstance(x, str) or x.dtype == tf.string):
621
+ raise ValueError("Input be a string or a string numpy array")
622
+ text, lens = tf.numpy_function(_enc, [flattened], (tf.int32, tf.int32), stateful=False)
623
+ lens = tf.ensure_shape(lens, [None])
624
+ text = tf.ensure_shape(text, [None])
625
+ if len(x.shape) == 2:
626
+ n = x.shape[1]
627
+ assert n is not None
628
+ return tf.RaggedTensor.from_nested_row_lengths(
629
+ text,
630
+ [tf.ones(tf.shape(x)[0], dtype=lens.dtype)*n, lens]
631
+ )
632
+ else:
633
+ return tf.RaggedTensor.from_row_lengths(text, lens)
634
+
635
+
636
+ class OLMoTokenizerWrapper(HfTokenizerWrapper):
637
+
638
+ def encode(self, x: str):
639
+ return self.tokenizer.encode(x, add_special_tokens=False)
640
+
641
+ def encode_tf(self, x):
642
+ if isinstance(x, str) or len(x.shape) == 0:
643
+ def _enc(_data):
644
+ return np.asarray(self.tokenizer.encode(_data.numpy().decode("utf-8"), add_special_tokens=False), dtype=np.int32)
645
+ out = tf.py_function(_enc, (x,), tf.int32)
646
+ return tf.ensure_shape(out, [None])
647
+ else:
648
+ def _enc(_data):
649
+ tokens = [self.tokenizer.encode(x.decode("utf-8"), add_special_tokens=False)
650
+ for x in _data.numpy()]
651
+ if len(tokens) == 0:
652
+ return np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32)
653
+ else:
654
+ return np.concatenate(tokens, 0), np.array([len(x) for x in tokens])
655
+ text, lens = tf.py_function(_enc, (x,), (tf.int32, tf.int32))
656
+ lens = tf.ensure_shape(lens, [None])
657
+ text = tf.ensure_shape(text, [None])
658
+ return tf.RaggedTensor.from_row_lengths(text, lens)
659
+
tasks.py ADDED
@@ -0,0 +1,2548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Module that can be imported to register all tasks
2
+ import dataclasses
3
+ import functools
4
+ import logging
5
+ import os
6
+ from collections import OrderedDict
7
+ from typing import List, Dict, Any
8
+
9
+ import seqio
10
+ from seqio import dataset_providers
11
+ import tensorflow_datasets as tfds
12
+
13
+ from .data_utils import _strip_metadata, build_tokenizer
14
+ from .preprocesssors import *
15
+ from .preprocesssors import _preprocess_scifi
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class TaskSpec:
20
+ name: str
21
+ source: seqio.DataSourceInterface
22
+ preprocessors: List
23
+ style: str
24
+ inference_preprocessors: List = None
25
+ inference_only: bool = False
26
+ decode_image: bool = False
27
+ shuffle_after: Optional[int] = None
28
+ ignore_errors: bool = False
29
+
30
+
31
+ MULTITASK_TFDS_DATA_DIR = "/weka/oe-training-default/mm-olmo/tensorflow_datasets"
32
+
33
+ TASKS: Dict[str, TaskSpec] = {}
34
+
35
+
36
+ def add_task(
37
+ name,
38
+ source: seqio.DataSourceInterface,
39
+ preprocessors: List,
40
+ style: str,
41
+ inf_preprocessor=None,
42
+ inf_only=False,
43
+ decode_image=False,
44
+ shuffle_after=None,
45
+ ignore_errors=False
46
+ ):
47
+ TASKS[name] = TaskSpec(
48
+ name, source, preprocessors, style, inf_preprocessor, inf_only, decode_image,
49
+ shuffle_after, ignore_errors)
50
+
51
+
52
+ @seqio.map_over_dataset
53
+ def add_image_size(ex):
54
+ if ex["image"].dtype == tf.string:
55
+ ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
56
+ img_h = tf.shape(ex["image"])[0]
57
+ img_w = tf.shape(ex["image"])[1]
58
+ ex["metadata/image_size"] = [img_w, img_h]
59
+
60
+
61
+ @dataclasses.dataclass
62
+ class TaskDatasetBuilder:
63
+ """tf.data.Dataset builder for task after shuffling, sharding, and initial model pre-processing
64
+ have been applied"""
65
+ # This class is a simplified and customized version of seqio.Task
66
+ #
67
+ # The main differences are:
68
+ # 1: Does not prefetch by default, which wastes a small amount of RAM if we are using the
69
+ # dataset in a mixture which can just have its own top-level prefetch
70
+ # 2: Reduce threshold for memory caching which is way too high for image datasets by default
71
+ # 3: Can customize when shuffling occurs to help minimizes RAM usage, in general shuffling
72
+ # should happen before building image crops and tokenization so the shuffle and
73
+ # dataset checkpoint take less memory
74
+ # 4: Don't decoding images until after shuffling for the same reason
75
+ # 5: Support splitting with tfds.map_split so we never have to fall back to example sharding
76
+ # not default at the moment since its not well tested
77
+ # 6: Removes caching/output feature spec stuff from seqio that we don't need
78
+
79
+ name: str
80
+ source: Any
81
+ preprocessors: List
82
+ keep_metadata: bool
83
+ shuffle_after: int
84
+ sharding: str = "tfds_split"
85
+ decode_image: bool = False
86
+ ignore_errors: bool = False
87
+
88
+ def get_dataset(
89
+ self, # pytype: disable=signature-mismatch # overriding-default-value-checks
90
+ sequence_length: Optional[Mapping[str, int]] = None,
91
+ split: str = tfds.Split.TRAIN,
92
+ shuffle: bool = True,
93
+ shuffle_buffer_size: Optional[int] = 1000,
94
+ seed: Optional[int] = None,
95
+ shard_info: Optional[seqio.ShardInfo] = None,
96
+ num_epochs: Optional[int] = 1,
97
+ try_in_mem_cache: bool = True,
98
+ trim_output_features: bool=True
99
+ ) -> tf.data.Dataset:
100
+ source = self.source
101
+
102
+ if self.sharding == "seqio":
103
+ if source.supports_arbitrary_sharding:
104
+ shard_data_source = True
105
+ elif shard_info:
106
+ # Whether we should shard at source or on the examples from the source.
107
+ shard_data_source = (
108
+ len(source.list_shards(split=split)) >= shard_info.num_shards
109
+ )
110
+ logging.info(
111
+ "Sharding at the %s: %d of %d",
112
+ "data source" if shard_data_source else "examples",
113
+ shard_info.index + 1,
114
+ shard_info.num_shards,
115
+ )
116
+ else:
117
+ # Call get_dataset on the source without a shard_info.
118
+ shard_data_source = True
119
+ shard_info = None
120
+
121
+ if "image" in source.tfds_dataset.info.features:
122
+ if not self.decode_image:
123
+ source.tfds_dataset._decoders = dict(image=tfds.decode.SkipDecoding())
124
+
125
+ if shard_data_source:
126
+ ds = source.get_dataset(
127
+ split=split, shuffle=shuffle, seed=seed, shard_info=shard_info)
128
+ else:
129
+ ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
130
+ ds = ds.shard(shard_info.num_shards, shard_info.index)
131
+ elif self.sharding == "tfds_split":
132
+ # Shard with `tfds.even_splits`, which is seems to be recommended for mult-host training
133
+ # https://github.com/tensorflow/datasets/blob/master/docs/splits.md#tfdseven_splits--multi-host-training
134
+ assert isinstance(self.source, seqio.TfdsDataSource)
135
+ loader: seqio.LazyTfdsLoader = self.source.tfds_dataset
136
+ dataset, data_dir = loader.get_split_params(split)
137
+ shard_split = loader._map_split(split)
138
+ if shard_info and shard_info.num_shards > 1:
139
+ shard_split = tfds.even_splits(shard_split, n=shard_info.num_shards, drop_remainder=False)[shard_info.index]
140
+ else:
141
+ shard_split = shard_split
142
+ read_config = loader.read_config
143
+ read_config.shuffle_seed = seed
144
+ read_config.skip_prefetch = True
145
+ read_config.input_context = None
146
+ # Don't decode images until after shuffling to save RAM
147
+ if "image" in loader.info.features:
148
+ decoders = dict(image=tfds.decode.SkipDecoding())
149
+ else:
150
+ decoders = None
151
+ ds = tfds.load(
152
+ dataset,
153
+ split=shard_split,
154
+ data_dir=data_dir,
155
+ shuffle_files=shuffle,
156
+ download=True,
157
+ try_gcs=True,
158
+ read_config=read_config,
159
+ decoders=decoders
160
+ )
161
+ else:
162
+ raise NotImplementedError(self.sharding)
163
+
164
+ num_shards = shard_info.num_shards if shard_info else 1
165
+ if try_in_mem_cache and (
166
+ source.num_input_examples(split)
167
+ and source.num_input_examples(split)
168
+ < 10000 * num_shards
169
+ ):
170
+ logging.info(f"Automatically caching small dataset in memory: {self.name}:{split}")
171
+ ds = ds.cache()
172
+
173
+ # We repeat before calling any (potentially) stochastic
174
+ # preprocessors in order to take new samples each epoch.
175
+ if num_epochs != 1:
176
+ ds = ds.repeat(num_epochs)
177
+
178
+ preprocessors = [
179
+ seqio.add_kwargs_to_transform(
180
+ _fn,
181
+ sequence_length=sequence_length,
182
+ output_features=None,
183
+ ) for _fn in self.preprocessors
184
+ ]
185
+
186
+ with seqio.utils.map_seed_manager(seed):
187
+ for fn in preprocessors[:self.shuffle_after]:
188
+ ds = fn(ds)
189
+
190
+ # Strip metadata before shuffling if possible so its doesn't waste space
191
+ if not self.keep_metadata:
192
+ ds = _strip_metadata(ds)
193
+
194
+ if shuffle:
195
+ if shuffle_buffer_size is None:
196
+ raise ValueError("Shuffle is true, but shuffle_buffer_size is None")
197
+ else:
198
+ ds = ds.shuffle(shuffle_buffer_size, seed=seed)
199
+
200
+ for fn in preprocessors[self.shuffle_after:]:
201
+ ds = fn(ds)
202
+
203
+ if self.ignore_errors:
204
+ ds = ds.ignore_errors(log_warning=True)
205
+
206
+ if trim_output_features:
207
+ ds = seqio.trim_dataset(ds, sequence_length, sequence_length)
208
+
209
+ return ds
210
+
211
+
212
+ def get_task(preprocessor, name, is_training, for_inference,
213
+ include_metadata=None, style_override=None) -> TaskDatasetBuilder:
214
+ """Get a builder for task `name` that is pre-processed by `preprocessor`"""
215
+
216
+ task_spec = TASKS[name]
217
+ if for_inference is None:
218
+ for_inference = task_spec.inference_only
219
+ elif task_spec.inference_only and not for_inference:
220
+ raise ValueError(f"Inference=only task {task_spec.name} can only be used in inference mode")
221
+
222
+ if include_metadata is None:
223
+ include_metadata = for_inference
224
+
225
+ if preprocessor is not None:
226
+ style = style_override if style_override else task_spec.style
227
+ preprocessor = preprocessor.get_preprocessor(
228
+ is_training, for_inference, style, include_metadata)
229
+ preprocessor = [preprocessor]
230
+ else:
231
+ preprocessor = []
232
+ task_preprocessors = task_spec.preprocessors
233
+ if for_inference and task_spec.inference_preprocessors is not None:
234
+ task_preprocessors = task_spec.inference_preprocessors
235
+ if isinstance(task_spec.source, seqio.TfdsDataSource):
236
+ from seqio.utils import _TFDS_DATA_DIR_OVERRIDE
237
+ if _TFDS_DATA_DIR_OVERRIDE:
238
+ # Stop annoying override warnings flooding the log
239
+ task_spec.source.tfds_dataset._data_dir = None
240
+
241
+ return TaskDatasetBuilder(
242
+ task_spec.name,
243
+ task_spec.source,
244
+ task_preprocessors + preprocessor,
245
+ keep_metadata=include_metadata,
246
+ shuffle_after=(task_spec.shuffle_after if task_spec.shuffle_after
247
+ else len(task_spec.preprocessors)),
248
+ sharding="seqio",
249
+ decode_image=task_spec.decode_image,
250
+ ignore_errors=task_spec.ignore_errors,
251
+ )
252
+
253
+
254
+ add_task(
255
+ "coco_caption_2017",
256
+ source=seqio.TfdsDataSource(
257
+ tfds_name="coco_all:1.0.1",
258
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
259
+ ),
260
+ preprocessors=[
261
+ functools.partial(rekey, key_map={
262
+ "image/filename": ["image/filename"],
263
+ "image": ["image"],
264
+ "text": ["captions", "text"]
265
+ }),
266
+ functools.partial(flatten_parts, parts=["text"]),
267
+ ],
268
+ inf_preprocessor=[
269
+ functools.partial(rekey, key_map={
270
+ "image/filename": ["image/filename"],
271
+ "image": ["image"],
272
+ "text": ["captions", "text"]
273
+ })
274
+ ],
275
+ style="coco_captioning",
276
+ )
277
+
278
+
279
+ add_task(
280
+ "coco_captioning_karpathy",
281
+ source=seqio.TfdsDataSource(
282
+ tfds_name="coco_captioning_karpathy:1.0.2",
283
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
284
+ splits={"train": "train", "validation": "val", "test": "test"}
285
+ ),
286
+ preprocessors=[
287
+ rename(text="captions"),
288
+ functools.partial(flatten_parts, parts=["text"]),
289
+ ],
290
+ inf_preprocessor=[add_coco_url],
291
+ style="coco_captioning",
292
+ )
293
+
294
+
295
+ add_task(
296
+ "synth_counting",
297
+ source=seqio.TfdsDataSource(
298
+ tfds_name="synth_counting:0.0.3",
299
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
300
+ splits={"train": "train[5120:]", "validation": "train[:5120]"}
301
+ ),
302
+ preprocessors=[synth_count_preprocessor],
303
+ inf_preprocessor=[synth_count_inf_preprocessor],
304
+ style="synth_counting",
305
+ )
306
+
307
+
308
+ add_task(
309
+ "khan_academy",
310
+ source=seqio.TfdsDataSource(
311
+ tfds_name="khan_academy:1.0.0",
312
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
313
+ splits={"train": "train[1024:]", "validation": "train[:1024]"}
314
+ ),
315
+ preprocessors=[extract_khan_academy],
316
+ style="khan_academy",
317
+ )
318
+
319
+ for name, src in [
320
+ ("vaia_qa_latex_image_math_subset", seqio.TfdsDataSource(
321
+ tfds_name=f"vaia_qa_latex_image_short_answer:0.1.2",
322
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
323
+ splits={"train": "train", "validation": "validation"}
324
+ )),
325
+ ("vaia_qa_latex_image_all", seqio.TfdsDataSource(
326
+ tfds_name=f"vaia_qa_latex_image_short_answer:0.1.3",
327
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
328
+ splits={"train": "train", "validation": "validation"}
329
+ )),
330
+ ]:
331
+ add_task(
332
+ f"{name}_short_answer",
333
+ source=src,
334
+ preprocessors=[
335
+ remove_is_long,
336
+ remove_has_multiple_parts,
337
+ functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
338
+ ],
339
+ style="vaia_qa",
340
+ )
341
+ add_task(
342
+ f"{name}_short_answer_first",
343
+ source=src,
344
+ preprocessors=[
345
+ remove_is_long,
346
+ remove_has_multiple_parts,
347
+ functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
348
+ ],
349
+ style="vaia_qa_short_answer_first",
350
+ )
351
+ add_task(
352
+ f"{name}_mc_only_short_answer",
353
+ source=src,
354
+ preprocessors=[
355
+ remove_is_long,
356
+ remove_has_multiple_parts,
357
+ filter_mc,
358
+ functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
359
+ ],
360
+ style="vaia_qa_short_answer",
361
+ )
362
+ add_task(
363
+ f"{name}_mc_only_short_answer_first",
364
+ source=src,
365
+ preprocessors=[
366
+ remove_is_long,
367
+ remove_has_multiple_parts,
368
+ filter_mc,
369
+ functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
370
+ ],
371
+ style="vaia_qa_short_answer_first",
372
+ )
373
+ add_task(
374
+ f"{name}_image_only_short_answer",
375
+ source=src,
376
+ preprocessors=[
377
+ image_only,
378
+ remove_is_long,
379
+ remove_has_multiple_parts,
380
+ functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
381
+ ],
382
+ style="vaia_qa_short_answer",
383
+ )
384
+ add_task(
385
+ f"{name}_image_only_short_answer_first",
386
+ source=src,
387
+ preprocessors=[
388
+ image_only,
389
+ remove_is_long,
390
+ remove_has_multiple_parts,
391
+ functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
392
+ ],
393
+ style="vaia_qa_short_answer_first",
394
+ )
395
+
396
+ add_task(
397
+ "vqa_online",
398
+ source=seqio.TfdsDataSource(
399
+ tfds_name="vqa_online:1.0.1",
400
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
401
+ splits={"train": "train", "validation": "validation", "test": "validation"}
402
+ ),
403
+ preprocessors=[
404
+ build_question_with_context,
405
+ extract_vqa_online,
406
+ ],
407
+ style="vqa_online",
408
+ )
409
+
410
+ add_task(
411
+ "vqa_online_gpt_longQ_longA",
412
+ source=seqio.TfdsDataSource(
413
+ tfds_name="vqa_online_gpt_parsed:1.1.0",
414
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
415
+ splits={"train": "train", "validation": "validation", "test": "validation"}
416
+ ),
417
+ preprocessors=[
418
+ rename(question="question_long", answer="answer_long"),
419
+ extract_vqa_online,
420
+ ],
421
+ style="vqa_online",
422
+ )
423
+
424
+
425
+ add_task(
426
+ "famous_birthdays",
427
+ source=seqio.TfdsDataSource(
428
+ tfds_name="famous_birth_days:1.0.0",
429
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
430
+ splits={"train": "train[5120:]", "validation": "train[:5120]"}
431
+ ),
432
+ preprocessors=[
433
+ famous_birthdays_preprocessor,
434
+ functools.partial(name_entity_augmentation, p_high_color=0.0),
435
+ ],
436
+ style="famous_birthdays",
437
+ )
438
+
439
+
440
+ add_task(
441
+ "wiki_art",
442
+ source=seqio.TfdsDataSource(
443
+ tfds_name="wiki_art:1.0.0",
444
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
445
+ splits={"train": "train[5120:]", "validation": "train[:5120]"}
446
+ ),
447
+ preprocessors=[name_entity_augmentation, wiki_art_preprocessor],
448
+ style="wiki_art",
449
+ )
450
+
451
+ add_task(
452
+ "wiki_art_no_aug",
453
+ source=seqio.TfdsDataSource(
454
+ tfds_name="wiki_art:1.0.0",
455
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
456
+ splits={"train": "train[5120:]", "validation": "train[:5120]"}
457
+ ),
458
+ preprocessors=[wiki_art_preprocessor],
459
+ style="wiki_art",
460
+ )
461
+
462
+ add_task(
463
+ "atlas_obscura",
464
+ source=seqio.TfdsDataSource(
465
+ tfds_name="atlas_obscura:1.0.0",
466
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
467
+ splits={"train": "train[5120:]", "validation": "train[:5120]"}
468
+ ),
469
+ preprocessors=[
470
+ atlas_obscura_preprocessor,
471
+ mild_color_aug_preprocessor
472
+ ],
473
+ style="atlas_obscura",
474
+ )
475
+
476
+
477
+ add_task(
478
+ "clocks",
479
+ source=seqio.TfdsDataSource(
480
+ tfds_name="clocks:1.0.1",
481
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
482
+ ),
483
+ preprocessors=[
484
+ clocks_preprocessor,
485
+ clock_augmentation
486
+ ],
487
+ style="clocks",
488
+ shuffle_after=0
489
+ )
490
+
491
+
492
+ add_task(
493
+ "count_bench",
494
+ source=seqio.TfdsDataSource(
495
+ tfds_name="count_bench:1.0.0",
496
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
497
+ ),
498
+ preprocessors=[
499
+ count_bench_preprocessor,
500
+ ],
501
+ style="count_bench",
502
+ )
503
+
504
+
505
+ add_task(
506
+ "tulu_v2_sft",
507
+ source=seqio.TfdsDataSource(
508
+ tfds_name="allenai__tulu_v2_sft_mixture:1.0.0",
509
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
510
+ ),
511
+ preprocessors=[tulu_preprocessor],
512
+ style="tulu_v2",
513
+ )
514
+
515
+
516
+ # Pointing / Point+Count datasets
517
+ for is_count in [True, False]:
518
+ if is_count:
519
+ task = "point_count"
520
+ else:
521
+ task = "pointing"
522
+ add_task(
523
+ task,
524
+ source=seqio.TfdsDataSource(
525
+ tfds_name="pointing:1.0.1",
526
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
527
+ splits={"train": "train", "validation": "validation"}
528
+ ),
529
+ preprocessors=[
530
+ filter_points,
531
+ functools.partial(pointing_preprocessor, with_count=is_count),
532
+ split
533
+ ],
534
+ style=task,
535
+ )
536
+ add_task(
537
+ task + "_eval", # pointing validation set
538
+ source=seqio.TfdsDataSource(
539
+ tfds_name="pointing:1.0.2",
540
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
541
+ ),
542
+ preprocessors=[
543
+ filter_points,
544
+ functools.partial(pointing_preprocessor, with_count=is_count),
545
+ split
546
+ ],
547
+ style=task,
548
+ )
549
+ add_task(
550
+ task + "_high_freq",
551
+ source=seqio.TfdsDataSource(
552
+ tfds_name="count_qa:0.0.2",
553
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
554
+ splits=dict(
555
+ train="train[2048:]",
556
+ validation="train[:2048]"
557
+ )
558
+ ),
559
+ preprocessors=[
560
+ filter_points,
561
+ fix_count_qa, # Fix a tfrecord bug TODO fix the underlying records
562
+ functools.partial(pointing_preprocessor, with_count=is_count),
563
+ split,
564
+ ],
565
+ style=task,
566
+ )
567
+ add_task(
568
+ "fast_flickr_count_qa_" + task,
569
+ source=seqio.TfdsDataSource(
570
+ tfds_name="fast_flickr_count_qa:1.0.4",
571
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
572
+ ),
573
+ preprocessors=[
574
+ functools.partial(count_qa_preprocessor, with_count=is_count),
575
+ ],
576
+ inf_preprocessor=[
577
+ functools.partial(count_qa_preprocessor, with_count=is_count, for_inference=True),
578
+ ],
579
+ style=task,
580
+ )
581
+
582
+
583
+ add_task(
584
+ "countbench_qa",
585
+ source=seqio.TfdsDataSource(
586
+ tfds_name="countbench_qa:1.2.0",
587
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
588
+ ),
589
+ inf_only=True,
590
+ preprocessors=[
591
+ count_qa_preprocessor_inf,
592
+ ],
593
+ style="point_count",
594
+ )
595
+
596
+
597
+ add_task(
598
+ f"pointing_test", # pointing set with segmentation ground truths
599
+ source=seqio.TfdsDataSource(
600
+ tfds_name="pointing:1.0.3",
601
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
602
+ ),
603
+ preprocessors=[
604
+ pointing_inf_preprocessor
605
+ ],
606
+ style=task,
607
+ inf_only=True,
608
+ )
609
+
610
+
611
+ add_task(
612
+ "point_qa",
613
+ source=seqio.TfdsDataSource(
614
+ tfds_name="point_qa:0.0.5",
615
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
616
+ splits=dict(
617
+ train="train[512:]",
618
+ validation="train[:512]"
619
+ )
620
+ ),
621
+ preprocessors=[extract_point_qa, split],
622
+ style="point_qa",
623
+ )
624
+
625
+ add_task(
626
+ "clocks_no_aug",
627
+ source=seqio.TfdsDataSource(
628
+ tfds_name="clocks:1.0.1",
629
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
630
+ ),
631
+ preprocessors=[
632
+ clocks_preprocessor
633
+ ],
634
+ style="clocks",
635
+ )
636
+
637
+
638
+ add_task(
639
+ "clock_bench",
640
+ source=seqio.TfdsDataSource(
641
+ tfds_name="clock_bench:1.0.0",
642
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
643
+ ),
644
+ preprocessors=[
645
+ clock_bench_preprocessor
646
+ ],
647
+ inf_only=True,
648
+ style="clocks",
649
+ )
650
+
651
+ add_task(
652
+ "wiki_data",
653
+ source=seqio.TfdsDataSource(
654
+ tfds_name="cockatoo_wiki:1.0.0",
655
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
656
+ splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
657
+ ),
658
+ preprocessors=[extract_wiki_data],
659
+ style="wiki_data",
660
+ )
661
+
662
+
663
+ add_task(
664
+ "wiki_data_name",
665
+ source=seqio.TfdsDataSource(
666
+ tfds_name="cockatoo_wiki:1.0.0",
667
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
668
+ splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
669
+ ),
670
+ preprocessors=[
671
+ extract_wiki_data_name,
672
+ mild_color_aug_preprocessor
673
+ ],
674
+ style="wiki_data",
675
+ )
676
+
677
+ add_task(
678
+ "wiki_data_describe",
679
+ source=seqio.TfdsDataSource(
680
+ tfds_name="cockatoo_wiki:1.0.0",
681
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
682
+ splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
683
+ ),
684
+ preprocessors=[extract_wiki_data_describe],
685
+ inf_only=True,
686
+ style="wiki_data",
687
+ )
688
+
689
+ add_task(
690
+ "wiki_data_describe",
691
+ source=seqio.TfdsDataSource(
692
+ tfds_name="cockatoo_wiki:1.0.0",
693
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
694
+ splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
695
+ ),
696
+ preprocessors=[extract_wiki_data_describe],
697
+ inf_only=True,
698
+ style="wiki_data",
699
+ )
700
+
701
+
702
+ for name, src in [
703
+ ("scifi_charts", seqio.TfdsDataSource(
704
+ tfds_name="sci_fi_charts:1.0.6",
705
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
706
+ splits={"train": "train[1024:]", "validation": "train[:1024]"}
707
+ )),
708
+ ("scifi_table", seqio.TfdsDataSource(
709
+ tfds_name="sci_fi_table:1.0.3",
710
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
711
+ splits={"train": "train[1024:]", "validation": "train[:1024]"}
712
+ )),
713
+ ("scifi_document", seqio.TfdsDataSource(
714
+ tfds_name="sci_fi_document:1.0.3",
715
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
716
+ splits={"train": "train[1024:]", "validation": "train[:1024]"}
717
+ )),
718
+ ("scifi_diagram", seqio.TfdsDataSource(
719
+ tfds_name="sci_fi_diagram:1.0.0",
720
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
721
+ splits={"train": "train[1024:]", "validation": "train[:1024]"}
722
+ )),
723
+ ("scifi_natural", seqio.TfdsDataSource(
724
+ tfds_name="sci_fi_natural:1.0.1",
725
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
726
+ splits={"train": "train[128:]", "validation": "train[:128]"}
727
+ )),
728
+ ("scifi_nutrition", seqio.TfdsDataSource(
729
+ tfds_name="sci_fi_nutrition:1.0.0",
730
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
731
+ splits={"train": "train[128:]", "validation": "train[:128]"}
732
+ ))
733
+ ]:
734
+ add_task(
735
+ name + "_qa",
736
+ source=src,
737
+ preprocessors=[
738
+ remove_no_qa,
739
+ _preprocess_scifi,
740
+ extract_individual_vqa,
741
+ ],
742
+ inf_preprocessor=[
743
+ remove_no_qa, _preprocess_scifi,
744
+ functools.partial(flatten_parts, parts=["question", "answer"]),
745
+ extract_individual_vqa,
746
+ ],
747
+ style=name,
748
+ )
749
+ add_task(
750
+ name + "_qa_split",
751
+ source=src,
752
+ preprocessors=[
753
+ remove_no_qa,
754
+ _preprocess_scifi,
755
+ extract_individual_vqa,
756
+ split
757
+ ],
758
+ inf_preprocessor=[
759
+ remove_no_qa, _preprocess_scifi,
760
+ functools.partial(flatten_parts, parts=["question", "answer"]),
761
+ extract_individual_vqa,
762
+ ],
763
+ style=name,
764
+ )
765
+ add_task(
766
+ name + "_qa_exp",
767
+ source=src,
768
+ preprocessors=[
769
+ remove_no_qa,
770
+ _preprocess_scifi,
771
+ extract_scifi_qa_exp,
772
+ extract_individual_vqa,
773
+ ],
774
+ inf_preprocessor=[
775
+ remove_no_qa, _preprocess_scifi,
776
+ extract_scifi_qa_exp,
777
+ functools.partial(flatten_parts, parts=["question", "answer"]),
778
+ extract_individual_vqa,
779
+ ],
780
+ style=name + "_qa_exp",
781
+ )
782
+ add_task(
783
+ name + "_qa_exp_split",
784
+ source=src,
785
+ preprocessors=[
786
+ remove_no_qa,
787
+ _preprocess_scifi,
788
+ extract_scifi_qa_exp,
789
+ extract_individual_vqa,
790
+ split,
791
+ ],
792
+ inf_preprocessor=[
793
+ remove_no_qa, _preprocess_scifi,
794
+ extract_scifi_qa_exp,
795
+ functools.partial(flatten_parts, parts=["question", "answer"]),
796
+ extract_individual_vqa,
797
+ ],
798
+ style=name + "_qa_exp",
799
+ )
800
+ add_task(
801
+ name + "_exp",
802
+ source=src,
803
+ preprocessors=[
804
+ remove_no_qa,
805
+ _preprocess_scifi,
806
+ scifi_explanation_only,
807
+ extract_individual_vqa,
808
+ split
809
+ ],
810
+ style=name + "_exp"
811
+ )
812
+ add_task(
813
+ name + "_demo",
814
+ source=src,
815
+ preprocessors=[
816
+ remove_no_qa,
817
+ _preprocess_scifi,
818
+ extract_scifi_qa_demo,
819
+ extract_individual_vqa,
820
+ split
821
+ ],
822
+ style="scifi_demo"
823
+ )
824
+
825
+
826
+ add_task(
827
+ "chart_qa_scifi",
828
+ source=seqio.TfdsDataSource(
829
+ tfds_name="chart_qa:1.0.2",
830
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
831
+ splits={"train": "train", "validation": "val", "test": "test"}
832
+ ),
833
+ preprocessors=[
834
+ rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
835
+ extract_individual_vqa,
836
+ ],
837
+ style="scifi_charts_qa_exp",
838
+ )
839
+
840
+
841
+ add_task(
842
+ "chart_qa_prompting",
843
+ source=seqio.TfdsDataSource(
844
+ tfds_name="chart_qa:1.0.2",
845
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
846
+ splits={"train": "train", "validation": "val", "test": "test"}
847
+ ),
848
+ preprocessors=[
849
+ rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
850
+ chartqa_prompting,
851
+ extract_individual_vqa,
852
+ ],
853
+ style="chart_qa",
854
+ )
855
+
856
+
857
+ add_task(
858
+ "chart_qa_prompting_explanation",
859
+ source=seqio.TfdsDataSource(
860
+ tfds_name="chart_qa:1.0.2",
861
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
862
+ splits={"train": "train", "validation": "val", "test": "test"}
863
+ ),
864
+ preprocessors=[
865
+ rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
866
+ chartqa_explanation,
867
+ extract_individual_vqa,
868
+ ],
869
+ style="chart_qa",
870
+ )
871
+
872
+
873
+
874
+ add_task(
875
+ "coco_captioning_karpathy_multi",
876
+ source=seqio.TfdsDataSource(
877
+ tfds_name="coco_captioning_karpathy:1.0.2",
878
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
879
+ splits={"train": "train", "validation": "val", "test": "test"}
880
+ ),
881
+ preprocessors=[rename(text="captions")],
882
+ inf_preprocessor=[add_coco_url],
883
+ style="coco_captioning",
884
+ )
885
+
886
+
887
+ add_task(
888
+ "coco_caption_2017_grouped",
889
+ source=seqio.TfdsDataSource(
890
+ tfds_name="coco_all:1.0.1",
891
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
892
+ ),
893
+ preprocessors=[
894
+ functools.partial(
895
+ rekey, key_map={
896
+ "image/filename": ["image/filename"],
897
+ "image": ["image"],
898
+ "text": ["captions", "text"]
899
+ }),
900
+ join_captions
901
+ ],
902
+ style="coco_captioning_multiple",
903
+ )
904
+
905
+
906
+ add_task(
907
+ "llava_pretrain",
908
+ source=seqio.TfdsDataSource(
909
+ tfds_name="llava_pretrain:1.0.0",
910
+ tfds_data_dir="gs://mm-olmo-datasets/",
911
+ splits=dict(
912
+ train="train[4096:]",
913
+ validation="train[:4096]"
914
+ )
915
+ ),
916
+ preprocessors=[extract_llava],
917
+ style="web_caption"
918
+ )
919
+
920
+
921
+ add_task(
922
+ "rohun_images",
923
+ source=seqio.TfdsDataSource(
924
+ tfds_name="rohun_images:1.0.0",
925
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
926
+ ),
927
+ preprocessors=[],
928
+ style="long_caption",
929
+ inf_only=True
930
+ )
931
+
932
+
933
+ add_task(
934
+ "dense_caption_eval",
935
+ source=seqio.TfdsDataSource(
936
+ tfds_name="dense_captioning_eval:1.0.0",
937
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
938
+ splits=dict(validation="train")
939
+ ),
940
+ preprocessors=[],
941
+ style="long_caption",
942
+ inf_only=True
943
+ )
944
+
945
+
946
+ add_task(
947
+ "dense_caption_eval_dbg",
948
+ source=seqio.TfdsDataSource(
949
+ tfds_name="dense_captioning_eval:1.0.0",
950
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
951
+ splits=dict(validation="train")
952
+ ),
953
+ preprocessors=[
954
+ lambda ds: ds.filter(lambda x: x["url"] == "https://explore-multimodal-datasets.s3.us-west-2.amazonaws.com/eval-set/v0/eval-set/a211be07e2c9c722ef75093026a608856bd07ad935ebdedea6f2944b1f2d2b0e.jpg")
955
+ ],
956
+ style="long_caption",
957
+ inf_only=True
958
+ )
959
+
960
+
961
+ add_task(
962
+ "dense_caption_sample",
963
+ source=seqio.TfdsDataSource(
964
+ tfds_name="dense_captioning_eval:1.0.0",
965
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
966
+ splits=dict(
967
+ validation="train"
968
+ )
969
+ ),
970
+ preprocessors=[select_dense_caption_sample],
971
+ style="long_caption",
972
+ )
973
+
974
+
975
+ add_task(
976
+ "cockatoo_1per_caption_287k",
977
+ source=seqio.TfdsDataSource(
978
+ tfds_name="cockatoo_1per_caption_287k:1.0.5",
979
+ tfds_data_dir="gs://mm-olmo-data/",
980
+ splits=dict(
981
+ train="train[5120:]",
982
+ validation="train[:5120]"
983
+ )
984
+ ),
985
+ preprocessors=[
986
+ rename(text="caption"),
987
+ ],
988
+ style="long_caption"
989
+ )
990
+
991
+
992
+ def _filter_large_ratio(ds):
993
+ return ds.filter(
994
+ lambda x: tf.shape(x["image"])[0] > tf.shape(x["image"])[1]*2
995
+ )
996
+
997
+
998
+ add_task(
999
+ f"cockatoo_dbg",
1000
+ source= seqio.TfdsDataSource(
1001
+ tfds_name="cockatoo_476k:1.0.5",
1002
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1003
+ splits=dict(
1004
+ train="train[5120:]",
1005
+ validation="train[:5120]"
1006
+ )
1007
+ )
1008
+ ,
1009
+ preprocessors=[
1010
+ _filter_large_ratio,
1011
+ extract_caption_and_transcript
1012
+ ],
1013
+ style=["long_caption", "transcript"]
1014
+ )
1015
+
1016
+
1017
+ for name, src in [
1018
+ ("712k_sept6", seqio.TfdsDataSource(
1019
+ tfds_name="cockatoo_712k_sept6:1.0.5",
1020
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1021
+ splits=dict(
1022
+ train="train[5120:]",
1023
+ validation="train[:5120]"
1024
+ )
1025
+ )),
1026
+ ("476k", seqio.TfdsDataSource(
1027
+ tfds_name="cockatoo_476k:1.0.5",
1028
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1029
+ splits=dict(
1030
+ train="train[5120:]",
1031
+ validation="train[:5120]"
1032
+ )
1033
+ )),
1034
+ ("476k_gpt_captions", seqio.TfdsDataSource(
1035
+ tfds_name="cockatoo_476k_gpt_captions:1.0.5",
1036
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1037
+ splits=dict(
1038
+ train="train[5120:]",
1039
+ validation="train[:5120]"
1040
+ )
1041
+ )),
1042
+ ("100k_of_476k_gpt_captions", seqio.TfdsDataSource(
1043
+ tfds_name="cockatoo_476k_gpt_captions:1.0.5",
1044
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1045
+ splits=dict(
1046
+ train="train[5120:105120]",
1047
+ validation="train[:5120]"
1048
+ )
1049
+ )),
1050
+ ("200k_of_476k_gpt_captions", seqio.TfdsDataSource(
1051
+ tfds_name="cockatoo_476k_gpt_captions:1.0.5",
1052
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1053
+ splits=dict(
1054
+ train="train[5120:205120]",
1055
+ validation="train[:5120]"
1056
+ )
1057
+ )),
1058
+ ("300k_of_476k_gpt_captions", seqio.TfdsDataSource(
1059
+ tfds_name="cockatoo_476k_gpt_captions:1.0.5",
1060
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1061
+ splits=dict(
1062
+ train="train[5120:305120]",
1063
+ validation="train[:5120]"
1064
+ )
1065
+ )),
1066
+ ("400k_of_476k_gpt_captions", seqio.TfdsDataSource(
1067
+ tfds_name="cockatoo_476k_gpt_captions:1.0.5",
1068
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1069
+ splits=dict(
1070
+ train="train[5120:405120]",
1071
+ validation="train[:5120]"
1072
+ )
1073
+ )),
1074
+ ("400k_of_476k", seqio.TfdsDataSource(
1075
+ tfds_name="cockatoo_476k:1.0.5",
1076
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1077
+ splits=dict(
1078
+ train="train[5120:405120]",
1079
+ validation="train[:5120]"
1080
+ )
1081
+ )),
1082
+ ("300k_of_476k", seqio.TfdsDataSource(
1083
+ tfds_name="cockatoo_476k:1.0.5",
1084
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1085
+ splits=dict(
1086
+ train="train[5120:305120]",
1087
+ validation="train[:5120]"
1088
+ )
1089
+ )),
1090
+ ("200k_of_476k", seqio.TfdsDataSource(
1091
+ tfds_name="cockatoo_476k:1.0.5",
1092
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1093
+ splits=dict(
1094
+ train="train[5120:205120]",
1095
+ validation="train[:5120]"
1096
+ )
1097
+ )),
1098
+ ("100k_of_476k", seqio.TfdsDataSource(
1099
+ tfds_name="cockatoo_476k:1.0.5",
1100
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1101
+ splits=dict(
1102
+ train="train[5120:105120]",
1103
+ validation="train[:5120]"
1104
+ )
1105
+ )),
1106
+ ("276k", seqio.TfdsDataSource(
1107
+ tfds_name="cockatoo:1.0.5",
1108
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1109
+ splits=dict(
1110
+ train="train[5120:]",
1111
+ validation="train[:5120]"
1112
+ )
1113
+ )),
1114
+ ("180k", seqio.TfdsDataSource(
1115
+ tfds_name="cockatoo:1.0.3",
1116
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1117
+ splits=dict(
1118
+ train="train[4096:]",
1119
+ validation="train[:4096]"
1120
+ )
1121
+ )),
1122
+ ("84k_claude_captions", seqio.TfdsDataSource(
1123
+ tfds_name="cockatoo_84k_claude_captions:1.0.0",
1124
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1125
+ splits=dict(
1126
+ train="train[1000:]",
1127
+ validation="train[:1000]"
1128
+ )
1129
+ )),
1130
+ ]:
1131
+ add_task(
1132
+ f"cockatoo_{name}",
1133
+ source=src,
1134
+ preprocessors=[extract_caption],
1135
+ style="long_caption"
1136
+ )
1137
+
1138
+ add_task(
1139
+ f"cockatoo_and_transcript_{name}",
1140
+ source=src,
1141
+ preprocessors=[extract_caption_and_transcript],
1142
+ style=["long_caption", "transcript"]
1143
+ )
1144
+
1145
+ add_task(
1146
+ f"cockatoo_and_transcript_stratified_{name}",
1147
+ source=src,
1148
+ preprocessors=[
1149
+ extract_caption_and_transcript,
1150
+ # put this here to hack seqio into repeating the dataset after
1151
+ # `extract_caption_and_transcript` which will properly stratify the transcripts
1152
+ seqio.CacheDatasetPlaceholder(),
1153
+ ],
1154
+ style=["long_caption", "transcript"]
1155
+ )
1156
+ add_task(
1157
+ f"cockatoo_and_all_transcripts_{name}",
1158
+ source=src,
1159
+ preprocessors=[extract_caption_and_all_transcripts],
1160
+ style=["long_caption", "transcript", "transcript", "transcript"]
1161
+ )
1162
+
1163
+ add_task(
1164
+ f"cockatoo_all_transcripts_{name}",
1165
+ source=src,
1166
+ preprocessors=[extract_all_transcripts],
1167
+ style="transcript"
1168
+ )
1169
+ add_task(
1170
+ f"cockatoo_transcripts_{name}",
1171
+ source=src,
1172
+ preprocessors=[extract_transcript],
1173
+ style="transcript"
1174
+ )
1175
+
1176
+
1177
+ TFRECORD_IMAGE_TEXT_FEATURES = {
1178
+ 'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
1179
+ 'text':tf.io.FixedLenFeature(shape=(), dtype=tf.string),
1180
+ }
1181
+
1182
+
1183
+ add_task(
1184
+ "laion400m",
1185
+ source=seqio.TFExampleDataSource(
1186
+ split_to_filepattern={
1187
+ "train": os.path.join("gs://unified-io-2-us-east/", "pretrain-datasets", "laion400m", "1.0.0", "laion400m-train*"),
1188
+ },
1189
+ feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
1190
+ ),
1191
+ preprocessors=[
1192
+ functools.partial(rekey, key_map={
1193
+ "image": ["image"],
1194
+ "text": ["text"]
1195
+ }),
1196
+ ],
1197
+ style="laion",
1198
+ )
1199
+
1200
+
1201
+ add_task(
1202
+ "laion_2B",
1203
+ source=seqio.TFExampleDataSource(
1204
+ split_to_filepattern={
1205
+ "train": os.path.join(MULTITASK_TFDS_DATA_DIR, "laion2b_en", "1.0.0", "laion2b_en-train*"),
1206
+ },
1207
+ feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
1208
+ ),
1209
+ preprocessors=[
1210
+ functools.partial(rekey, key_map={
1211
+ "image": ["image"],
1212
+ "text": ["text"]
1213
+ }),
1214
+ ],
1215
+ style="laion",
1216
+ )
1217
+
1218
+
1219
+ add_task(
1220
+ "region_caption_vg",
1221
+ source=seqio.TfdsDataSource(
1222
+ tfds_name="vg:1.0.1",
1223
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1224
+ ),
1225
+ preprocessors=[region_captions_to_dense],
1226
+ style="region_captions",
1227
+ )
1228
+
1229
+
1230
+ add_task(
1231
+ "pdfa_eng_wds",
1232
+ source=seqio.TfdsDataSource(
1233
+ tfds_name="pdfa_eng_wds:1.0.0",
1234
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1235
+ ),
1236
+ preprocessors=[
1237
+ functools.partial(max_words, max_words=400),
1238
+ format_pdfa_eng_wds
1239
+ ],
1240
+ style="pdfa_eng_wds",
1241
+ )
1242
+
1243
+
1244
+ add_task(
1245
+ "idl_words",
1246
+ source=seqio.TfdsDataSource(
1247
+ tfds_name="idl_words:1.0.0",
1248
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1249
+ ),
1250
+ preprocessors=[],
1251
+ style="idl_words",
1252
+ )
1253
+
1254
+
1255
+
1256
+ open_image_v6_keys_to_features = {
1257
+ 'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
1258
+ 'image_id': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
1259
+ 'detection/label':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
1260
+ 'detection/bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
1261
+ 'detection/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
1262
+ 'vrd/sub_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
1263
+ 'vrd/obj_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
1264
+ 'vrd/sub_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
1265
+ 'vrd/obj_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
1266
+ 'vrd/relation': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
1267
+ 'vrd/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
1268
+ 'cap/cap_caption': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
1269
+ 'cap/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
1270
+ 'seg/masks': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
1271
+ 'seg/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
1272
+ 'seg/label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
1273
+ 'seg/bbox': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
1274
+ }
1275
+
1276
+
1277
+ add_task(
1278
+ "localized_narratives_v6",
1279
+ source=seqio.TFExampleDataSource(
1280
+ split_to_filepattern={
1281
+ "train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"),
1282
+ },
1283
+ feature_description=open_image_v6_keys_to_features,
1284
+ ),
1285
+ preprocessors=[extract_localized_narrative],
1286
+ style="localized_narratives",
1287
+ )
1288
+
1289
+
1290
+ add_task(
1291
+ "lvis_objects",
1292
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
1293
+ source=seqio.TfdsDataSource(
1294
+ tfds_name="lvis:1.2.0",
1295
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1296
+ ),
1297
+ preprocessors=[
1298
+ extract_lvis,
1299
+ region_captions_to_dense,
1300
+ ],
1301
+ style="lvis_objects",
1302
+ )
1303
+
1304
+
1305
+ add_task(
1306
+ "open_images_with_objects",
1307
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
1308
+ source=seqio.TFExampleDataSource(
1309
+ split_to_filepattern={
1310
+ "train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"),
1311
+ },
1312
+ feature_description=open_image_v6_keys_to_features,
1313
+ ),
1314
+ preprocessors=[
1315
+ extract_open_images_boxes,
1316
+ region_captions_to_dense,
1317
+ ],
1318
+ style="visual_narratives_with_objects",
1319
+ )
1320
+
1321
+
1322
+ add_task(
1323
+ "cockatoo_with_acc_476k_gpt_captions",
1324
+ source=seqio.TfdsDataSource(
1325
+ tfds_name="cockatoo_with_acc_476k_gpt_captions:1.0.0",
1326
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1327
+ splits=dict(
1328
+ train="train[5120:]",
1329
+ validation="train[:5120]"
1330
+ )
1331
+ ),
1332
+ preprocessors=[accuracy_conditioned_joint],
1333
+ inf_preprocessor=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
1334
+ style=None
1335
+ )
1336
+
1337
+
1338
+ add_task(
1339
+ "dense_caption_eval_with_acc",
1340
+ source=seqio.TfdsDataSource(
1341
+ tfds_name="dense_captioning_eval:1.0.0",
1342
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1343
+ splits=dict(validation="train")
1344
+ ),
1345
+ preprocessors=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
1346
+ style="long_caption",
1347
+ inf_only=True
1348
+ )
1349
+
1350
+ # ************************
1351
+ # VQA Datasets
1352
+ # ************************
1353
+
1354
+ add_task(
1355
+ "science_qa_img",
1356
+ source=seqio.TfdsDataSource(
1357
+ tfds_name="science_qa:1.0.0",
1358
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1359
+ splits={"train": "train", "validation": "val", "test": "test"}
1360
+ ),
1361
+ preprocessors=[
1362
+ image_only,
1363
+ rename(answer_idx="answer"),
1364
+ build_question_with_hint,
1365
+ format_multiple_choice_qa
1366
+ ],
1367
+ style="science_qa",
1368
+ )
1369
+
1370
+
1371
+ add_task(
1372
+ "tabwmp_da",
1373
+ source=seqio.TfdsDataSource(
1374
+ tfds_name="tab_mwp:1.0.0",
1375
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1376
+ splits={"train": "train", "validation": "dev", "test": "test"}
1377
+ ),
1378
+ preprocessors=[
1379
+ rename(text="answer")
1380
+ ],
1381
+ style="tabwmp_da",
1382
+ )
1383
+
1384
+
1385
+ add_task(
1386
+ "figure_qa",
1387
+ source=seqio.TfdsDataSource(
1388
+ tfds_name="figure_qa:1.0.2",
1389
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1390
+ splits={"train": "train1", "validation": "validation1", "test": "no_annot_test1"}
1391
+ ),
1392
+ preprocessors=[extract_figureqa, extract_individual_vqa],
1393
+ style="figure_qa",
1394
+ )
1395
+
1396
+ add_task(
1397
+ "figure_qa_zero_shot",
1398
+ source=seqio.TfdsDataSource(
1399
+ tfds_name="figure_qa:1.0.2",
1400
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1401
+ splits={"train": "train1", "validation": "validation1", "test": "no_annot_test1"}
1402
+ ),
1403
+ preprocessors=[extract_figureqa, convert_figureqa_answer, extract_individual_vqa],
1404
+ style="figure_qa",
1405
+ )
1406
+
1407
+
1408
+ add_task(
1409
+ "plot_qa",
1410
+ source=seqio.TfdsDataSource(
1411
+ tfds_name="plot_qa:1.0.0",
1412
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1413
+ ),
1414
+ preprocessors=[extract_figureqa, extract_individual_vqa],
1415
+ inf_preprocessor=[
1416
+ extract_figureqa,
1417
+ functools.partial(flatten_parts, parts=["questions", "answer", "question_id"]),
1418
+ extract_individual_vqa
1419
+ ],
1420
+ style="plot_qa",
1421
+ )
1422
+
1423
+
1424
+ add_task(
1425
+ "ai2_diagram",
1426
+ source=seqio.TfdsDataSource(
1427
+ tfds_name="ai2_diagram:1.0.2",
1428
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1429
+ splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
1430
+ ),
1431
+ preprocessors=[
1432
+ rename(choices="answer_texts", answer_idx="correct_answer"),
1433
+ format_multiple_choice_qa
1434
+ ],
1435
+ style="ai2_diagram",
1436
+ )
1437
+
1438
+
1439
+ add_task(
1440
+ "ai2_diagram_v2",
1441
+ source=seqio.TfdsDataSource(
1442
+ tfds_name="ai2_diagram_v2:1.0.1",
1443
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1444
+ ),
1445
+ preprocessors=[
1446
+ rename(choices="answer_texts", answer_idx="correct_answer"),
1447
+ format_ai2d
1448
+ ],
1449
+ style="ai2_diagram",
1450
+ )
1451
+
1452
+
1453
+ add_task(
1454
+ "ai2_diagram_v2_transparent",
1455
+ source=seqio.TfdsDataSource(
1456
+ tfds_name="ai2_diagram_v2_transparent:1.0.5",
1457
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1458
+ ),
1459
+ preprocessors=[
1460
+ rename(choices="answer_texts", answer_idx="correct_answer"),
1461
+ format_ai2d
1462
+ ],
1463
+ style="ai2_diagram",
1464
+ )
1465
+
1466
+ # ai2_diagram_v2 mixed with addiitonal abc label questions with transparent box.
1467
+ # Shares the same image split as ai2_diagram_v2.
1468
+ add_task(
1469
+ "ai2_diagram_v2_mix_transparent",
1470
+ source=seqio.TfdsDataSource(
1471
+ tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
1472
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1473
+ splits={
1474
+ "train": "train_mix",
1475
+ "validation": "validation_mix",
1476
+ "test": "test_mix", # test should only use either transparent or opaque
1477
+ # "test": "test_opaque",
1478
+ }
1479
+ ),
1480
+ preprocessors=[
1481
+ rename(choices="answer_texts", answer_idx="correct_answer"),
1482
+ format_ai2d
1483
+ ],
1484
+ style="ai2_diagram",
1485
+ )
1486
+
1487
+ add_task(
1488
+ "ai2_diagram_v2_mix_transparent_one_style",
1489
+ source=seqio.TfdsDataSource(
1490
+ tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
1491
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1492
+ splits={
1493
+ "train": "train_mix",
1494
+ "validation": "validation_mix",
1495
+ "test": "test_mix", # test should only use either transparent or opaque
1496
+ # "test": "test_opaque",
1497
+ }
1498
+ ),
1499
+ preprocessors=[
1500
+ rename(choices="answer_texts", answer_idx="correct_answer"),
1501
+ functools.partial(format_ai2d, variable_style=False),
1502
+ ],
1503
+ style="ai2_diagram",
1504
+ )
1505
+
1506
+
1507
+ for src, test_sets in [
1508
+ ["refclef_unc", ["testA", "testB", "testC", "testAB", "testBC"]],
1509
+ ["refcoco_unc", ["testA", "testB"]],
1510
+ ["refcocoplus_unc", ["testA", "testB"]],
1511
+ ["refcocog_umd", ["test"]],
1512
+ ]:
1513
+ if "coco" in src:
1514
+ add_url = [add_coco_url]
1515
+ else:
1516
+ add_url = []
1517
+ splits = {x: x for x in test_sets}
1518
+ splits.update({"train": "train", "validation": "val"})
1519
+ add_task(
1520
+ src,
1521
+ source=seqio.TfdsDataSource(
1522
+ tfds_name=f"{src}:1.0.2",
1523
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1524
+ splits=splits
1525
+ ),
1526
+ preprocessors=[refexp],
1527
+ inf_preprocessor=add_url + [
1528
+ refexp_inf,
1529
+ # Flatten objects
1530
+ functools.partial(flatten_parts, parts=["refexp", "metadata/bbox"]),
1531
+ # Flatten expressions
1532
+ functools.partial(flatten_parts, parts=["refexp"])
1533
+ ],
1534
+ style="refexp",
1535
+ decode_image=True,
1536
+ )
1537
+ add_task(
1538
+ src + "_pointing",
1539
+ source=seqio.TfdsDataSource(
1540
+ tfds_name=f"{src}:1.0.2",
1541
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1542
+ splits=splits
1543
+ ),
1544
+ preprocessors=[refexp_pointing],
1545
+ inf_preprocessor=add_url + [
1546
+ refexp_pointing_inf,
1547
+ functools.partial(flatten_parts, parts=["refexp", "metadata/bbox", "metadata/mask", "metadata/answer"]),
1548
+ functools.partial(flatten_parts, parts=["refexp"])
1549
+ ],
1550
+ decode_image=True,
1551
+ style="refexp_pointing",
1552
+ )
1553
+
1554
+
1555
+ # FIXME
1556
+ add_task(
1557
+ "ai2_diagram_test",
1558
+ source=seqio.TfdsDataSource(
1559
+ tfds_name="ai2_diagram:1.0.2",
1560
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1561
+ splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
1562
+ ),
1563
+ preprocessors=[
1564
+ rename(choices="answer_texts", answer_idx="correct_answer"),
1565
+ format_multiple_choice_qa
1566
+ ],
1567
+ style="ai2_diagram",
1568
+ )
1569
+
1570
+
1571
+ add_task(
1572
+ "gqa",
1573
+ source=seqio.TfdsDataSource(
1574
+ tfds_name="gqa:1.0.1",
1575
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1576
+ splits={"train": "train", "validation": "val", "test": "test"}
1577
+ ),
1578
+ preprocessors=[
1579
+ functools.partial(format_gqa, is_balanced=True),
1580
+ extract_individual_vqa,
1581
+ ],
1582
+ inf_preprocessor=[
1583
+ functools.partial(format_gqa, is_balanced=True),
1584
+ extract_individual_vqa,
1585
+ ],
1586
+ style="gqa",
1587
+ )
1588
+
1589
+
1590
+ add_task(
1591
+ "gqa_multi",
1592
+ source=seqio.TfdsDataSource(
1593
+ tfds_name="gqa:1.0.1",
1594
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1595
+ splits={"train": "train", "validation": "val", "test": "test"}
1596
+ ),
1597
+ preprocessors=[
1598
+ functools.partial(format_gqa, is_balanced=True, flatten=False),
1599
+ extract_individual_vqa,
1600
+ ],
1601
+ inf_preprocessor=[
1602
+ functools.partial(format_gqa, is_balanced=True, flatten=False),
1603
+ extract_individual_vqa,
1604
+ ],
1605
+ style="gqa",
1606
+ )
1607
+
1608
+
1609
+ add_task(
1610
+ "text_vqa",
1611
+ source=seqio.TfdsDataSource(
1612
+ tfds_name="text_vqa:1.0.3",
1613
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1614
+ ),
1615
+ preprocessors=[
1616
+ functools.partial(
1617
+ rekey, key_map={
1618
+ "image": ["image"],
1619
+ "questions": ["question"],
1620
+ "answers": ["answers"],
1621
+ "id": ["question_id"]
1622
+ }),
1623
+ extract_individual_vqa,
1624
+ ],
1625
+ style="text_vqa",
1626
+ )
1627
+
1628
+
1629
+ add_task(
1630
+ "okvqa",
1631
+ source=seqio.TfdsDataSource(
1632
+ tfds_name="ok_vqa:1.0.2",
1633
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1634
+ ),
1635
+ preprocessors=[
1636
+ rename(example_id="question_id"),
1637
+ add_coco_url,
1638
+ extract_individual_vqa,
1639
+ ],
1640
+ style="okvqa",
1641
+ )
1642
+
1643
+ add_task(
1644
+ "a_okvqa_da",
1645
+ source=seqio.TfdsDataSource(
1646
+ tfds_name="a_ok_vqa:1.0.2",
1647
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1648
+ splits={"train": "train", "validation": "val", "test": "test"}
1649
+ ),
1650
+ preprocessors=[
1651
+ rename(**{
1652
+ "example_id": "question_id",
1653
+ "answers": "direct_answers",
1654
+ "metadata/difficult_direct_answer": "difficult_direct_answer"
1655
+ }),
1656
+ extract_individual_vqa,
1657
+ ],
1658
+ inf_preprocessor=[
1659
+ filter_difficult_direct_answer,
1660
+ rename(**{
1661
+ "example_id": "question_id",
1662
+ "answers": "direct_answers",
1663
+ "metadata/difficult_direct_answer": "difficult_direct_answer"
1664
+ }),
1665
+ add_coco_url,
1666
+ extract_individual_vqa,
1667
+ ],
1668
+ style="a_okvqa_da",
1669
+ )
1670
+
1671
+
1672
+ add_task(
1673
+ "a_okvqa_mc",
1674
+ source=seqio.TfdsDataSource(
1675
+ tfds_name="a_ok_vqa:1.0.2",
1676
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1677
+ splits={"train": "train", "validation": "val", "test": "test"}
1678
+ ),
1679
+ preprocessors=[
1680
+ rename(**{
1681
+ "example_id": "question_id",
1682
+ "metadata/difficult_direct_answer": "difficult_direct_answer",
1683
+ "answer_idx": "correct_choice_idx"
1684
+ }),
1685
+ add_coco_url,
1686
+ format_multiple_choice_qa,
1687
+ ],
1688
+ style="a_okvqa_mc",
1689
+ )
1690
+
1691
+
1692
+ add_task(
1693
+ "dv_qa",
1694
+ source=seqio.TfdsDataSource(
1695
+ tfds_name="dv_qa:1.0.0",
1696
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1697
+ splits={"train": "train", "validation": "val_easy"}
1698
+ ),
1699
+ preprocessors=[
1700
+ extract_figureqa,
1701
+ extract_individual_vqa,
1702
+ ],
1703
+ inf_preprocessor=[
1704
+ extract_figureqa,
1705
+ flatten_vqa,
1706
+ extract_individual_vqa
1707
+ ],
1708
+ style="dv_qa",
1709
+ )
1710
+
1711
+
1712
+ @seqio.map_over_dataset
1713
+ def add_image_question_example_id(ex):
1714
+ key = tf.strings.join([ex["question"], "\n\n", ex["image"]])
1715
+ ex["metadata/example_id"] = tf.strings.to_hash_bucket(key, 2**30)
1716
+ return ex
1717
+
1718
+
1719
+ add_task(
1720
+ "chart_qa",
1721
+ source=seqio.TfdsDataSource(
1722
+ tfds_name="chart_qa:1.0.2",
1723
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1724
+ splits={"train": "train", "validation": "val", "test": "test"}
1725
+ ),
1726
+ preprocessors=[
1727
+ rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
1728
+ add_image_question_example_id,
1729
+ extract_individual_vqa,
1730
+ ],
1731
+ style="chart_qa",
1732
+ )
1733
+
1734
+
1735
+ add_task(
1736
+ "chart_qa_ex",
1737
+ source=seqio.TfdsDataSource(
1738
+ tfds_name="chart_qa:1.0.2",
1739
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1740
+ splits={"train": "train", "validation": "val", "test": "test"}
1741
+ ),
1742
+ preprocessors=[
1743
+ rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
1744
+ extract_individual_vqa,
1745
+ ],
1746
+ style="scifi_charts_qa_exp",
1747
+ )
1748
+
1749
+
1750
+ add_task(
1751
+ "chart_qa_weighted",
1752
+ source=seqio.TfdsDataSource(
1753
+ tfds_name="chart_qa:1.0.2",
1754
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1755
+ splits={"train": "train", "validation": "val", "test": "test"}
1756
+ ),
1757
+ preprocessors=[
1758
+ rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
1759
+ extract_individual_vqa,
1760
+ functools.partial(reweight_chartqa, human=2*20901/(20901+7398), aug=2*7398/(20901+7398)),
1761
+ ],
1762
+ style="chart_qa",
1763
+ )
1764
+
1765
+
1766
+ add_task(
1767
+ "chart_qa_human",
1768
+ source=seqio.TfdsDataSource(
1769
+ tfds_name="chart_qa:1.0.2",
1770
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1771
+ splits={"train": "train", "validation": "val", "test": "test"}
1772
+ ),
1773
+ preprocessors=[
1774
+ rename(question="query", answer="label"),
1775
+ add_image_question_example_id,
1776
+ filter_human,
1777
+ extract_individual_vqa,
1778
+ ],
1779
+ style="chart_qa",
1780
+ )
1781
+
1782
+
1783
+ add_task(
1784
+ "chart_qa_aug",
1785
+ source=seqio.TfdsDataSource(
1786
+ tfds_name="chart_qa:1.0.2",
1787
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1788
+ splits={"train": "train", "validation": "val", "test": "test"}
1789
+ ),
1790
+ preprocessors=[
1791
+ rename(question="query", answer="label"),
1792
+ filter_aug,
1793
+ extract_individual_vqa,
1794
+ ],
1795
+ style="chart_qa",
1796
+ )
1797
+
1798
+
1799
+ add_task(
1800
+ "doc_qa",
1801
+ source=seqio.TfdsDataSource(
1802
+ tfds_name="doc_qa:1.0.1",
1803
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1804
+ splits={"train": "train", "validation": "val", "test": "test"}
1805
+ ),
1806
+ preprocessors=[fix_doqa_url, extract_individual_vqa],
1807
+ style="doc_qa",
1808
+ )
1809
+
1810
+
1811
+ add_task(
1812
+ "ocr_qa",
1813
+ source=seqio.TfdsDataSource(
1814
+ tfds_name="ocr_vqa:1.0.0",
1815
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1816
+ ),
1817
+ preprocessors=[extract_individual_vqa],
1818
+ inf_preprocessor=[flatten_vqa, extract_individual_vqa],
1819
+ style="ocr_vqa",
1820
+ )
1821
+
1822
+
1823
+ add_task(
1824
+ "st_qa",
1825
+ source=seqio.TfdsDataSource(
1826
+ tfds_name="st_vqa:1.0.2",
1827
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1828
+ splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
1829
+ ),
1830
+ preprocessors=[extract_individual_vqa],
1831
+ inf_preprocessor=[extract_individual_vqa],
1832
+ style="st_qa",
1833
+ )
1834
+
1835
+
1836
+ add_task(
1837
+ "tally_qa",
1838
+ source=seqio.TfdsDataSource(
1839
+ tfds_name="tally_qa:1.0.2",
1840
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1841
+ splits={"train": "train", "validation": "test"}
1842
+ ),
1843
+ preprocessors=[
1844
+ extract_tally_qa,
1845
+ extract_individual_vqa
1846
+ ],
1847
+ inf_preprocessor=[
1848
+ extract_tally_qa,
1849
+ flatten_vqa,
1850
+ extract_individual_vqa
1851
+ ],
1852
+ style="tally_qa",
1853
+ )
1854
+
1855
+
1856
+ add_task(
1857
+ "info_qa",
1858
+ source=seqio.TfdsDataSource(
1859
+ tfds_name="info_qa:1.0.0",
1860
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1861
+ splits={"train": "train", "validation": "val", "test": "test"}
1862
+ ),
1863
+ preprocessors=[extract_individual_vqa],
1864
+ style="info_qa",
1865
+ )
1866
+
1867
+ add_task(
1868
+ "android_control",
1869
+ source=seqio.TfdsDataSource(
1870
+ tfds_name="android_control:2.0.0",
1871
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1872
+ splits={"train": "train", "validation": "val", "test": "test"}
1873
+ ),
1874
+ preprocessors=[extract_android_control],
1875
+ style="android_control",
1876
+ )
1877
+
1878
+ for mode in ["ll", "hl", "hl_ll", "hl_cot"]:
1879
+ add_task(
1880
+ f"android_control_{mode}",
1881
+ source=seqio.TfdsDataSource(
1882
+ tfds_name="android_control:2.0.0",
1883
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1884
+ splits={"train": "train", "validation": "val", "test": "test"}
1885
+ ),
1886
+ preprocessors=[functools.partial(extract_andriod_control_inf, mode=mode)],
1887
+ style="android_control",
1888
+ )
1889
+
1890
+
1891
+ map_coco_vqa = functools.partial(rekey, key_map={
1892
+ "image": ["image"],
1893
+ "questions": ["vqa", "questions"],
1894
+ "answers": ["vqa", "answers"],
1895
+ "id": ["vqa", "id"],
1896
+ "metadata/image_url": ["metadata/image_url"],
1897
+ })
1898
+
1899
+
1900
+ add_task(
1901
+ "coco_2017_vqa",
1902
+ source=seqio.TfdsDataSource(
1903
+ tfds_name="coco_all:1.0.1",
1904
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1905
+ ),
1906
+ preprocessors=[
1907
+ add_coco_url,
1908
+ map_coco_vqa,
1909
+ flatten_vqa,
1910
+ extract_individual_vqa
1911
+ ],
1912
+ style="vqa2",
1913
+ )
1914
+
1915
+
1916
+ add_task(
1917
+ "cockatoo_qa",
1918
+ source=seqio.TfdsDataSource(
1919
+ tfds_name="cockatoo_qa:1.0.0",
1920
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1921
+ splits=dict(
1922
+ train="train[5120:]",
1923
+ validation="train[:5120]"
1924
+ )
1925
+ ),
1926
+ preprocessors=[rename(text="answer")],
1927
+ style=None,
1928
+ )
1929
+
1930
+
1931
+ add_task(
1932
+ "synthetic_qa_v3",
1933
+ source=seqio.TfdsDataSource(
1934
+ tfds_name="synthetic_qa_v3:0.0.4",
1935
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1936
+ splits=dict(
1937
+ train="train[2048:]",
1938
+ validation="train[:2048]"
1939
+ )
1940
+ ),
1941
+ preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
1942
+ style="synthetic_qa",
1943
+ )
1944
+
1945
+
1946
+ add_task(
1947
+ "synthetic_qa_v3_style_tag",
1948
+ source=seqio.TfdsDataSource(
1949
+ tfds_name="synthetic_qa_v3:0.0.4",
1950
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1951
+ splits=dict(
1952
+ train="train[2048:]",
1953
+ validation="train[:2048]"
1954
+ )
1955
+ ),
1956
+ preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
1957
+ style="llm_qa",
1958
+ )
1959
+
1960
+
1961
+ add_task(
1962
+ "synthetic_qa_v3_as_user_qa",
1963
+ source=seqio.TfdsDataSource(
1964
+ tfds_name="synthetic_qa_v3:0.0.4",
1965
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1966
+ splits=dict(
1967
+ train="train[2048:]",
1968
+ validation="train[:2048]"
1969
+ )
1970
+ ),
1971
+ preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
1972
+ style="user_qa",
1973
+ )
1974
+
1975
+
1976
+ add_task(
1977
+ "synthetic_qa_v3_multi_turn",
1978
+ source=seqio.TfdsDataSource(
1979
+ tfds_name="synthetic_qa_v3:0.0.4",
1980
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1981
+ splits=dict(
1982
+ train="train[2048:]",
1983
+ validation="train[:2048]"
1984
+ )
1985
+ ),
1986
+ preprocessors=[extract_cockatoo_qa_v2, filter_single_turn, prefix_how_many_messages],
1987
+ style="synthetic_qa",
1988
+ )
1989
+
1990
+
1991
+ NE_SHARDS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
1992
+
1993
+ for i in NE_SHARDS:
1994
+ add_task(
1995
+ f"named_entity{i}",
1996
+ source=seqio.TfdsDataSource(
1997
+ tfds_name=f"named_entities_qa_{i}_of_18:1.0.0",
1998
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
1999
+ splits=dict(
2000
+ train="train[1024:]",
2001
+ validation="train[:1024]"
2002
+ )
2003
+ ),
2004
+ preprocessors=[filter_named_entity, extract_named_entity, extract_individual_vqa],
2005
+ inf_preprocessor=[
2006
+ filter_named_entity,
2007
+ extract_named_entity,
2008
+ flatten_vqa,
2009
+ extract_individual_vqa
2010
+ ],
2011
+ style="named_entity",
2012
+ ignore_errors=True
2013
+ )
2014
+
2015
+
2016
+ add_task(
2017
+ "user_qa",
2018
+ source=seqio.TfdsDataSource(
2019
+ tfds_name="user_qa:0.0.1",
2020
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2021
+ splits=dict(
2022
+ train="train[2048:]",
2023
+ validation="train[:2048]"
2024
+ )
2025
+ ),
2026
+ preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
2027
+ style="user_qa",
2028
+ )
2029
+
2030
+ add_task(
2031
+ "user_questions_for_elo",
2032
+ source=seqio.TfdsDataSource(
2033
+ tfds_name="user_questions_for_elo:0.0.3",
2034
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2035
+ ),
2036
+ preprocessors=[functools.partial(extract_individual_vqa, test=True)],
2037
+ inf_only=True,
2038
+ style="demo",
2039
+ )
2040
+
2041
+
2042
+ def _filter_by_id(ds, prediction_file, max_seq_len):
2043
+ with open(prediction_file) as f:
2044
+ predictions = json.load(f)
2045
+ is_long = []
2046
+ lens = []
2047
+ tokenizer = build_tokenizer("hf-Qwen/Qwen2-7B")
2048
+ for pred in predictions:
2049
+ n_tokens = len(tokenizer.encode(pred["prediction"]))
2050
+ lens.append(n_tokens)
2051
+ if n_tokens >= max_seq_len:
2052
+ is_long.append(pred["example_id"])
2053
+ is_long = tf.constant(is_long)
2054
+ logging.info(f"Filtering for {len(is_long)} ids")
2055
+ return ds.filter(lambda ex: tf.reduce_any(ex["example_id"] == is_long))
2056
+
2057
+
2058
+
2059
+ add_task(
2060
+ "user_questions_for_elo",
2061
+ source=seqio.TfdsDataSource(
2062
+ tfds_name="user_questions_for_elo:0.0.3",
2063
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2064
+ ),
2065
+ preprocessors=[functools.partial(extract_individual_vqa, test=True)],
2066
+ inf_only=True,
2067
+ style="demo",
2068
+ )
2069
+
2070
+
2071
+ add_task(
2072
+ "user_questions_for_elo_long",
2073
+ source=seqio.TfdsDataSource(
2074
+ tfds_name="user_questions_for_elo:0.0.3",
2075
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2076
+ ),
2077
+ preprocessors=[
2078
+ functools.partial(_filter_by_id, prediction_file="/weka/oe-training-default/chrisc/cockatoo/models/uber-model-v11/70b-335-30k-3.2-resume8k-noopt/predictions-ck20000-user_questions_for_elo-test/predictions.json", max_seq_len=230),
2079
+ functools.partial(extract_individual_vqa, test=True)
2080
+ ],
2081
+ inf_only=True,
2082
+ style="demo",
2083
+ )
2084
+
2085
+
2086
+ add_task(
2087
+ "coco_2014_vqa",
2088
+ source=seqio.TfdsDataSource(
2089
+ tfds_name="coco_2014_all:1.0.1",
2090
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2091
+ ),
2092
+ preprocessors=[
2093
+ add_coco_url,
2094
+ map_coco_vqa,
2095
+ flatten_vqa,
2096
+ extract_individual_vqa
2097
+ ],
2098
+ inf_preprocessor=[
2099
+ add_coco_url,
2100
+ map_coco_vqa,
2101
+ flatten_vqa,
2102
+ extract_individual_vqa
2103
+ ],
2104
+ style="vqa2",
2105
+ )
2106
+
2107
+
2108
+ add_task(
2109
+ "coco_2014_vqa_multi",
2110
+ source=seqio.TfdsDataSource(
2111
+ tfds_name="coco_2014_all:1.0.1",
2112
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2113
+ ),
2114
+ preprocessors=[
2115
+ add_coco_url,
2116
+ map_coco_vqa,
2117
+ extract_individual_vqa
2118
+ ],
2119
+ inf_preprocessor=[
2120
+ add_coco_url,
2121
+ map_coco_vqa,
2122
+ flatten_vqa,
2123
+ extract_individual_vqa
2124
+ ],
2125
+ style="vqa2",
2126
+ )
2127
+
2128
+
2129
+ add_task(
2130
+ "coco_2017_vqa_multi",
2131
+ source=seqio.TfdsDataSource(
2132
+ tfds_name="coco_all:1.0.1",
2133
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2134
+ ),
2135
+ preprocessors=[
2136
+ add_coco_url,
2137
+ map_coco_vqa,
2138
+ extract_individual_vqa
2139
+ ],
2140
+ inf_preprocessor=[
2141
+ add_coco_url,
2142
+ map_coco_vqa,
2143
+ flatten_vqa,
2144
+ extract_individual_vqa
2145
+ ],
2146
+ style="vqa2",
2147
+ )
2148
+
2149
+
2150
+ add_task(
2151
+ "vqa_v2_test",
2152
+ source=seqio.TfdsDataSource(
2153
+ tfds_name="coco_test_all:1.0.1",
2154
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2155
+ ),
2156
+ preprocessors=[
2157
+ functools.partial(rekey, key_map={
2158
+ "image": ["image"],
2159
+ "questions": ["vqa", "questions"],
2160
+ "answers": ["vqa", "answers"],
2161
+ "id": ["vqa", "id"],
2162
+ }),
2163
+ flatten_vqa,
2164
+ functools.partial(extract_individual_vqa, test=True)
2165
+ ],
2166
+ style="vqa2",
2167
+ inf_only=True
2168
+ )
2169
+
2170
+ # ************************
2171
+ # Eval-only Datasets
2172
+ # ************************
2173
+
2174
+ add_task(
2175
+ "seed_bench_test",
2176
+ source=seqio.TfdsDataSource(
2177
+ tfds_name="seed_bench:1.0.0",
2178
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2179
+ ),
2180
+ preprocessors=[
2181
+ format_multiple_choice_qa,
2182
+ ],
2183
+ style="a_okvqa_mc",
2184
+ inf_only=True
2185
+ )
2186
+
2187
+
2188
+ add_task(
2189
+ "pope_test",
2190
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2191
+ source=seqio.TfdsDataSource(
2192
+ tfds_name="pope:1.0.0",
2193
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2194
+ ),
2195
+ preprocessors=[
2196
+ add_coco_url,
2197
+ extract_individual_vqa
2198
+ ],
2199
+ style="vqa2",
2200
+ inf_only=True
2201
+ )
2202
+
2203
+
2204
+ MME_SOURCE = seqio.TfdsDataSource(
2205
+ tfds_name="mme:1.0.0",
2206
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2207
+ )
2208
+
2209
+
2210
+ add_task(
2211
+ "mme_test",
2212
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2213
+ source=MME_SOURCE,
2214
+ preprocessors=[
2215
+ functools.partial(flatten_parts, parts=["questions", "answers"]),
2216
+ rename(question="questions", answer="answers"),
2217
+ extract_individual_vqa,
2218
+ ],
2219
+ style="vqa2",
2220
+ inf_only=True
2221
+ )
2222
+
2223
+ add_task(
2224
+ "real_world_qa_test",
2225
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2226
+ source=seqio.TfdsDataSource(
2227
+ tfds_name="real_world_qa:1.0.0",
2228
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2229
+ ),
2230
+ preprocessors=[
2231
+ functools.partial(
2232
+ format_multiple_style_qa,
2233
+ types=['multiple_choice', 'short_answer'],
2234
+ styles=['a_okvqa_mc', 'vqa2'],
2235
+ default_style="a_okvqa_mc",
2236
+ ),
2237
+ ],
2238
+ style=None,
2239
+ inf_only=True
2240
+ )
2241
+
2242
+ add_task(
2243
+ "real_world_qa_no_instruction",
2244
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2245
+ source=seqio.TfdsDataSource(
2246
+ tfds_name="real_world_qa:1.0.0",
2247
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2248
+ ),
2249
+ preprocessors=[
2250
+ functools.partial(
2251
+ functools.partial(format_multiple_style_qa, strip_instruction=True),
2252
+ types=['multiple_choice', 'short_answer'],
2253
+ styles=['a_okvqa_mc', 'vqa2'],
2254
+ default_style="a_okvqa_mc",
2255
+ ),
2256
+ ],
2257
+ style=None,
2258
+ inf_only=True
2259
+ )
2260
+
2261
+ add_task(
2262
+ "real_world_qa_dbg",
2263
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2264
+ source=seqio.TfdsDataSource(
2265
+ tfds_name="real_world_qa:1.0.0",
2266
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2267
+ ),
2268
+ preprocessors=[
2269
+ functools.partial(
2270
+ format_multiple_style_qa,
2271
+ types=['multiple_choice', 'short_answer'],
2272
+ styles=['user_qa', 'user_qa'],
2273
+ default_style="user_qa",
2274
+ ),
2275
+ ],
2276
+ style=None,
2277
+ inf_only=True
2278
+ )
2279
+
2280
+
2281
+ add_task(
2282
+ "mmmu",
2283
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2284
+ source=seqio.TfdsDataSource(
2285
+ tfds_name="mmmu:1.0.0",
2286
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2287
+ splits={"train": "dev"},
2288
+ ),
2289
+ preprocessors=[
2290
+ rename(img_type="metadata/img_type"),
2291
+ functools.partial(
2292
+ extract_mmmu,
2293
+ types=['multiple-choice', 'open'],
2294
+ styles=['a_okvqa_mc', 'vqa2'],
2295
+ default_style="a_okvqa_mc",
2296
+ ),
2297
+ ],
2298
+ style=None,
2299
+ )
2300
+
2301
+
2302
+ add_task(
2303
+ "mmmu_test",
2304
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2305
+ source=seqio.TfdsDataSource(
2306
+ tfds_name="mmmu:1.0.0",
2307
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2308
+ splits={"validation": "validation", "test": "test"},
2309
+ ),
2310
+ preprocessors=[
2311
+ rename(img_type="metadata/img_type"),
2312
+ extract_mmmu,
2313
+ ],
2314
+ style=None,
2315
+ inf_only=True
2316
+ )
2317
+
2318
+ for style in ["vaia_qa", "vaia_qa_short_answer_first", "vqa_online", ]:
2319
+ add_task(
2320
+ f"mmmu_test_{style}",
2321
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2322
+ source=seqio.TfdsDataSource(
2323
+ tfds_name="mmmu:1.0.0",
2324
+ # tfds_name="mmmu_khan_academy:1.0.1",
2325
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2326
+ splits={"validation": "validation", "test": "test", "dev": "dev"},
2327
+ ),
2328
+ preprocessors=[
2329
+ rename(img_type="metadata/img_type"),
2330
+ extract_mmmu_cot,
2331
+ ],
2332
+ style=style,
2333
+ inf_only=True
2334
+ )
2335
+
2336
+
2337
+ add_task(
2338
+ "math_vista_test",
2339
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2340
+ source=seqio.TfdsDataSource(
2341
+ tfds_name="math_vista:1.0.0",
2342
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2343
+ splits={"validation": "testmini", "test": "test"},
2344
+ ),
2345
+ preprocessors=[
2346
+ functools.partial(rekey, key_map={
2347
+ "id": ["id"],
2348
+ "query": ["query"],
2349
+ "image": ["image"],
2350
+ "choices": ["choices"],
2351
+ "answer": ["answer"],
2352
+ "metadata/question_type": ["question_type"],
2353
+ "metadata/answer_type": ["answer_type"],
2354
+ "metadata/precision": ["precision"],
2355
+ "metadata/split": ["metadata/split"],
2356
+ }),
2357
+ functools.partial(extract_math_vista, styles=['a_okvqa_mc', 'vqa2']),
2358
+ ],
2359
+ style=None,
2360
+ inf_only=True
2361
+ )
2362
+
2363
+
2364
+ add_task(
2365
+ "math_vista_v2",
2366
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2367
+ source=seqio.TfdsDataSource(
2368
+ tfds_name="math_vista:1.0.0",
2369
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2370
+ splits={"validation": "testmini", "test": "test"},
2371
+ ),
2372
+ preprocessors=[
2373
+ functools.partial(rekey, key_map={
2374
+ "id": ["id"],
2375
+ "query": ["query"],
2376
+ "image": ["image"],
2377
+ "choices": ["choices"],
2378
+ "answer": ["answer"],
2379
+ "metadata/question_type": ["question_type"],
2380
+ "metadata/answer_type": ["answer_type"],
2381
+ "metadata/precision": ["precision"],
2382
+ "metadata/split": ["metadata/split"],
2383
+ }),
2384
+ reformat_math_vista,
2385
+ functools.partial(
2386
+ extract_math_vista,
2387
+ styles=['a_okvqa_mc', 'vqa2'],
2388
+ ),
2389
+ ],
2390
+ style=None,
2391
+ inf_only=True
2392
+ )
2393
+
2394
+
2395
+ MM_BENCH_SRC = seqio.TfdsDataSource(
2396
+ tfds_name="mmbench:1.0.0",
2397
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2398
+ splits={"validation": "dev", "test": "test"},
2399
+ )
2400
+
2401
+ add_task(
2402
+ "mmbench_test",
2403
+ source=MM_BENCH_SRC,
2404
+ preprocessors=[format_mmbench],
2405
+ style="a_okvqa_mc",
2406
+ inf_only=True
2407
+ )
2408
+
2409
+
2410
+ add_task(
2411
+ "sugar_crepe_test",
2412
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2413
+ source=seqio.TfdsDataSource(
2414
+ tfds_name="sugar_crepe:1.0.0",
2415
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2416
+ ),
2417
+ preprocessors=[
2418
+ add_coco_url,
2419
+ functools.partial(flatten_parts, parts=["choices", "answer_idx", "metadata/answer_type"]),
2420
+ format_multiple_choice_qa,
2421
+ ],
2422
+ style="a_okvqa_mc",
2423
+ inf_only=True
2424
+ )
2425
+
2426
+
2427
+ add_task(
2428
+ "blink_test",
2429
+ # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
2430
+ source=seqio.TfdsDataSource(
2431
+ tfds_name="blink:1.0.0",
2432
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2433
+ ),
2434
+ preprocessors=[
2435
+ functools.partial(rekey, key_map={
2436
+ "id": ["id"],
2437
+ "question": ["prompt"],
2438
+ "image": ["image_concat"],
2439
+ "choices": ["choices"],
2440
+ "answer_idx": ["answer_idx"],
2441
+ "metadata/subtask": ["metadata/subtask"],
2442
+ "metadata/question": ["question"],
2443
+ }),
2444
+ format_multiple_choice_qa,
2445
+ output_options,
2446
+ ],
2447
+ style="a_okvqa_mc",
2448
+ inf_only=True
2449
+ )
2450
+
2451
+ add_task(
2452
+ "oscarbench_qa",
2453
+ source=seqio.TfdsDataSource(
2454
+ tfds_name="oscarbench_qa:1.0.0",
2455
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2456
+ splits={"validation": "val"}
2457
+ ),
2458
+ preprocessors=[oscar_preprocessor],
2459
+ style="oscarbench_qa"
2460
+
2461
+ )
2462
+
2463
+ add_task(
2464
+ "charxiv",
2465
+ source=seqio.TfdsDataSource(
2466
+ tfds_name="charxiv:1.0.0",
2467
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2468
+ splits={"validation": "validation", "test": "test"}
2469
+ ),
2470
+ preprocessors=[charxiv_preprocessor, extract_individual_vqa],
2471
+ inf_preprocessor=[
2472
+ charxiv_preprocessor,
2473
+ functools.partial(flatten_parts, parts=["question", "answer"]),
2474
+ extract_individual_vqa,
2475
+ ],
2476
+ style="charxiv",
2477
+ )
2478
+
2479
+ add_task(
2480
+ "charxiv_descriptive",
2481
+ source=seqio.TfdsDataSource(
2482
+ tfds_name="charxiv:1.0.0",
2483
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2484
+ splits={"validation": "validation", "test": "test"}
2485
+ ),
2486
+ preprocessors=[charxiv_descriptive_preprocessor, extract_individual_vqa],
2487
+ inf_preprocessor=[
2488
+ charxiv_descriptive_preprocessor,
2489
+ functools.partial(flatten_parts, parts=["question", "answer"]),
2490
+ extract_individual_vqa,
2491
+ ],
2492
+ style="charxiv_descriptive",
2493
+ )
2494
+
2495
+ add_task(
2496
+ "charxiv_reasoning",
2497
+ source=seqio.TfdsDataSource(
2498
+ tfds_name="charxiv:1.0.0",
2499
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2500
+ splits={"validation": "validation", "test": "test"}
2501
+ ),
2502
+ preprocessors=[charxiv_reasoning_preprocessor, extract_individual_vqa],
2503
+ style="charxiv_reasoning",
2504
+ )
2505
+
2506
+ for tablevqa_name in ["fintabnetqa", "vwtq", "vwtq_syn"]:
2507
+ add_task(
2508
+ tablevqa_name,
2509
+ source=seqio.TfdsDataSource(
2510
+ tfds_name=f"{tablevqa_name}:1.0.0",
2511
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2512
+ splits={"validation": "test[:125]", "test": "test"}
2513
+ ),
2514
+ preprocessors=[tablevqa_preprocessor, extract_individual_vqa],
2515
+ style=tablevqa_name,
2516
+ )
2517
+
2518
+ add_task(
2519
+ "vtabfact",
2520
+ source=seqio.TfdsDataSource(
2521
+ tfds_name="vtabfact:1.0.0",
2522
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2523
+ splits={"validation": "test[:125]", "test": "test"}
2524
+ ),
2525
+ preprocessors=[vtabfact_preprocessor, extract_individual_vqa],
2526
+ style="vtabfact",
2527
+ )
2528
+
2529
+ add_task(
2530
+ "nutrition_fact",
2531
+ source=seqio.TfdsDataSource(
2532
+ tfds_name="nutrition_fact:1.0.0",
2533
+ tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
2534
+ splits={"validation": "test", "test": "test"}
2535
+ ),
2536
+ preprocessors=[nutrition_fact_preprocessor, extract_individual_vqa],
2537
+ inf_preprocessor=[
2538
+ nutrition_fact_preprocessor,
2539
+ functools.partial(flatten_parts, parts=["question", "answer"]),
2540
+ extract_individual_vqa,
2541
+ ],
2542
+ style="nutrition_fact",
2543
+ inf_only=True
2544
+ )
2545
+
2546
+ for k in ["chart_qa", "info_qa", "doc_qa", "text_vqa", "coco_2014_vqa",
2547
+ "ai2_diagram_v2_mix_transparent", "chart_qa_human"]:
2548
+ TASKS[k + "_demo"] = dataclasses.replace(TASKS[k], style="demo")
torch_util.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import logging
4
+ from typing import Optional, TypeVar, List, Tuple
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
+ def seed_all(seed: int):
16
+ """Seed all rng objects."""
17
+ import random
18
+
19
+ import numpy as np
20
+
21
+ if seed < 0 or seed > 2**32 - 1:
22
+ raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")
23
+ random.seed(seed)
24
+ np.random.seed(seed)
25
+ torch.manual_seed(seed)
26
+ # torch.manual_seed may call manual_seed_all but calling it again here
27
+ # to make sure it gets called at least once
28
+ torch.cuda.manual_seed_all(seed)
29
+
30
+
31
+ def is_distributed() -> bool:
32
+ return dist.is_available() and dist.is_initialized()
33
+
34
+
35
+ def get_node_rank() -> int:
36
+ return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())
37
+
38
+
39
+ def get_world_size() -> int:
40
+ if is_distributed():
41
+ return dist.get_world_size()
42
+ else:
43
+ return 1
44
+
45
+
46
+ def get_local_world_size() -> int:
47
+ return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)
48
+
49
+
50
+ def get_global_rank() -> int:
51
+ if is_distributed():
52
+ return int(os.environ.get("RANK") or dist.get_rank())
53
+ else:
54
+ return 0
55
+
56
+
57
+ def get_local_rank() -> int:
58
+ return int(os.environ.get("LOCAL_RANK") or 0)
59
+
60
+
61
+ def get_fs_local_rank() -> int:
62
+ """Get the local rank per filesystem, meaning that, regardless of the number of nodes,
63
+ if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
64
+ but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
65
+ """
66
+ if os.environ.get("OLMO_SHARED_FS"):
67
+ return int(os.environ.get("FS_LOCAL_RANK") or get_global_rank())
68
+ else:
69
+ return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
70
+
71
+
72
+ def move_to_device(o: T, device: torch.device) -> T:
73
+ if isinstance(o, torch.Tensor):
74
+ return o.to(device) # type: ignore[return-value]
75
+ elif isinstance(o, dict):
76
+ return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value]
77
+ elif isinstance(o, list):
78
+ return [move_to_device(x, device) for x in o] # type: ignore[return-value]
79
+ elif isinstance(o, tuple):
80
+ return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value]
81
+ else:
82
+ return o
83
+
84
+
85
+ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
86
+ """
87
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
88
+ is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
89
+ """
90
+ if check_neg_inf:
91
+ x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
92
+ if check_pos_inf:
93
+ x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
94
+
95
+
96
+ def get_default_device() -> torch.device:
97
+ if torch.cuda.is_available() and torch.cuda.is_initialized():
98
+ return torch.device("cuda")
99
+ else:
100
+ return torch.device("cpu")
101
+
102
+
103
+ def barrier() -> None:
104
+ if is_distributed():
105
+ dist.barrier()
106
+
107
+
108
+ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
109
+ """
110
+ Get the peak GPU memory usage in MB across all ranks.
111
+ Only rank 0 will get the final result.
112
+ """
113
+ if not torch.cuda.is_available():
114
+ return None
115
+
116
+ device = torch.device("cuda")
117
+ peak_mb = torch.cuda.max_memory_allocated(device) / 1000000
118
+ if is_distributed():
119
+ peak_mb_tensor = torch.tensor(peak_mb, device=device)
120
+ dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX)
121
+ peak_mb = peak_mb_tensor.item()
122
+
123
+ if reset:
124
+ # Reset peak stats.
125
+ torch.cuda.reset_max_memory_allocated(device)
126
+
127
+ return peak_mb
128
+
129
+
130
+ V = TypeVar("V", bool, int, float)
131
+
132
+
133
+ def synchronize_value(value: V, device: torch.device) -> V:
134
+ if dist.is_available() and dist.is_initialized():
135
+ value_tensor = torch.tensor(value, device=device)
136
+ dist.broadcast(value_tensor, 0)
137
+ return value_tensor.item() # type: ignore
138
+ else:
139
+ return value
140
+
141
+
142
+ def synchronize_flag(flag: bool, device: torch.device) -> bool:
143
+ return synchronize_value(flag, device)
144
+
145
+
146
+ def gc_cuda():
147
+ gc.collect()
148
+ if torch.cuda.is_available():
149
+ torch.cuda.empty_cache()
150
+
151
+
152
+ def listinstr(lst, s, delimiter=None):
153
+ assert isinstance(lst, list)
154
+ for item in lst:
155
+ if delimiter:
156
+ if all(x in s for x in item.split(delimiter)):
157
+ return True
158
+ else:
159
+ if item in s:
160
+ return True
161
+ return False
162
+
163
+
164
+ def freeze_module(module: torch.nn.Module, exclude_params: Optional[List[str]] = None):
165
+ for name, param in module.named_parameters():
166
+ if exclude_params is not None and listinstr(exclude_params, name):
167
+ continue
168
+ param.requires_grad = False
169
+
170
+
171
+ def freeze_parameters_by_name(model: torch.nn.Module, freeze_names: Tuple[str]):
172
+ for name in freeze_names:
173
+ try:
174
+ module_or_param = model.get_submodule(name)
175
+ except:
176
+ try:
177
+ module_or_param = model.get_parameter(name)
178
+ except:
179
+ log.warning(f"Could not find module or parameter with name {name}")
180
+ if isinstance(module_or_param, torch.nn.Module):
181
+ freeze_module(module_or_param)
182
+ else:
183
+ module_or_param.requires_grad = False
util.py CHANGED
@@ -33,7 +33,7 @@ from .exceptions import (
33
  OLMoNetworkError,
34
  OLMoThreadError,
35
  )
36
- from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
37
 
38
  try:
39
  from functools import cache
 
33
  OLMoNetworkError,
34
  OLMoThreadError,
35
  )
36
+ # from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
37
 
38
  try:
39
  from functools import cache
utils.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import hashlib
3
+ import sys
4
+ import typing
5
+ import warnings
6
+ import socket
7
+ from typing import Optional, Any, Dict
8
+ import os
9
+ import logging
10
+ import absl.flags
11
+ from flax.traverse_util import flatten_dict
12
+
13
+ from ml_collections import ConfigDict, config_flags
14
+ from ml_collections.config_dict import placeholder
15
+ from mlxu import function_args_to_config
16
+
17
+ _log_extra_fields: Dict[str, Any] = {}
18
+
19
+
20
+ def is_float_printable(x):
21
+ try:
22
+ f"{x:0.2f}"
23
+ return True
24
+ except (ValueError, TypeError):
25
+ return False
26
+
27
+
28
+ def compute_hash(string: str) -> str:
29
+ """Computes the hash of a string."""
30
+ return hashlib.sha256(string.encode("utf-8")).hexdigest()
31
+
32
+
33
+ def pop_metadata(data):
34
+ meta = {k: data.pop(k) for k in list(data) if k.startswith("metadata")}
35
+ return data, meta
36
+
37
+
38
+ def setup_logging():
39
+ handler: logging.Handler
40
+ handler = logging.StreamHandler(sys.stdout)
41
+ formatter = logging.Formatter(
42
+ "[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
43
+ datefmt="%H:%M:%S"
44
+ )
45
+ handler.setFormatter(formatter)
46
+ logging.basicConfig(handlers=[handler], level=logging.INFO)
47
+
48
+ logging.captureWarnings(True)
49
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
50
+
51
+
52
+ def get_maybe_optional_type(field_type):
53
+ if type(None) in typing.get_args(field_type):
54
+ # Handle optional type
55
+ args = [x for x in typing.get_args(field_type) if x != type(None)]
56
+ assert len(args) == 1
57
+ field_type = args[0]
58
+ return field_type
59
+
60
+
61
+ def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
62
+ """Build a `ConfigDict` matching the possibly nested dataclass
63
+
64
+ dataclass: A dataclass instance or a dataclass type, if an instance defaults
65
+ will be set to the values in the class, if a class defaults will be
66
+ set to the field defaults, or None if the field is required
67
+ defaults_to_none: Make all defaults None
68
+ """
69
+ out = {}
70
+ fields = dataclasses.fields(dataclass)
71
+ for field in fields:
72
+ if not field.init:
73
+ continue
74
+
75
+ if defaults_to_none:
76
+ default = None
77
+ elif hasattr(dataclass, field.name):
78
+ default = getattr(dataclass, field.name)
79
+ elif field.default is dataclasses.MISSING:
80
+ default = None
81
+ else:
82
+ default = field.default
83
+
84
+ field_type = get_maybe_optional_type(field.type)
85
+
86
+ if hasattr(field_type, "__dataclass_fields__"):
87
+ if not defaults_to_none and default is None:
88
+ pass
89
+ else:
90
+ out[field.name] = config_from_dataclass(
91
+ default or field.type, defaults_to_none=defaults_to_none)
92
+ else:
93
+ if default is None:
94
+ assert not field_type == typing.Any
95
+ origin = getattr(field_type, "__origin__", None)
96
+ if origin is not None:
97
+ field_type = origin
98
+ out[field.name] = placeholder(field_type)
99
+ else:
100
+ out[field.name] = default
101
+ return ConfigDict(out)
102
+
103
+
104
+ def dataclass_with_none(cls):
105
+ """Build an instance of possibly nested dataclass `cls` with all attributes None"""
106
+ fields = dataclasses.fields(cls)
107
+ args = {}
108
+ for field in fields:
109
+ if not field.init:
110
+ pass
111
+ elif dataclasses.is_dataclass(field.type):
112
+ args[field.name] = dataclass_with_none(field.type)
113
+ else:
114
+ args[field.name] = None
115
+ return cls(**args)
116
+
117
+
118
+ def dataclass_from_config(cls, config: Dict):
119
+ """Build an instance of `cls` with attributes from `config``"""
120
+ fields = dataclasses.fields(cls)
121
+ args = set(x.name for x in fields)
122
+ for k in config.keys():
123
+ if k not in args:
124
+ raise ValueError(f"Config has unknown arg {k} fr {cls}")
125
+ args = {}
126
+ for field in fields:
127
+ if not field.init:
128
+ continue
129
+
130
+ field_type = get_maybe_optional_type(field.type)
131
+ if hasattr(field_type, "__dataclass_fields__"):
132
+ if config.get(field.name) is None:
133
+ args[field.name] = None
134
+ elif hasattr(field_type, "from_dict"):
135
+ src = config[field.name]
136
+ if isinstance(src, ConfigDict):
137
+ src = src.to_dict()
138
+ args[field.name] = field_type.from_dict(src)
139
+ else:
140
+ args[field.name] = dataclass_from_config(field_type, config[field.name])
141
+ elif field.name in config:
142
+ if isinstance(config[field.name], ConfigDict):
143
+ args[field.name] = config[field.name].to_dict()
144
+ else:
145
+ args[field.name] = config[field.name]
146
+ return cls(**args)
147
+
148
+
149
+ def update_dataclass(obj, updates):
150
+ """Sets attributes in `obj` to match non-None fields in `updates`"""
151
+ fields = dataclasses.fields(obj)
152
+ for field in fields:
153
+ if not field.init:
154
+ continue
155
+ update = updates.get(field.name)
156
+ if update is None:
157
+ continue
158
+ current_value = getattr(obj, field.name)
159
+ if dataclasses.is_dataclass(current_value):
160
+ update_dataclass(current_value, update)
161
+ else:
162
+ if isinstance(update, (ConfigDict, dict)):
163
+ assert all(x is None for x in flatten_dict(update).values())
164
+ else:
165
+ setattr(obj, field.name, update)
166
+
167
+
168
+ def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
169
+ # Stolen from the OLMo codebase
170
+ def format_value(value: float) -> str:
171
+ if isinstance(value, str):
172
+ return value
173
+ if value < 0.0001:
174
+ return str(value) # scientific notation
175
+ elif value > 1000:
176
+ return f"{int(value):,d}"
177
+ elif value > 100:
178
+ return f"{value:.1f}"
179
+ elif value > 10:
180
+ return f"{value:.2f}"
181
+ elif value > 1:
182
+ return f"{value:.3f}"
183
+ else:
184
+ return f"{value:.4f}"
185
+
186
+ logging.info(
187
+ f"{prefix}\n"
188
+ + "\n".join(
189
+ [
190
+ f" {name}={format_value(value)}"
191
+ for name, value in metrics.items()
192
+ if not name.startswith("optim/") # there's too many optimizer metrics
193
+ ]
194
+ )
195
+ )