Muennighoff
commited on
Commit
·
18652d8
1
Parent(s):
d13896f
Cp over files
Browse files- beam_search.py +1087 -0
- config_molmoe.py +9 -5
- constants.py +571 -0
- data_factory.py +222 -0
- data_utils.py +827 -0
- dataset_sizes.py +262 -0
- exceptions.py +50 -0
- iterable_dataset.py +266 -0
- modeling_molmoe.py +4 -4
- multimodal_preprocessor.py +1549 -0
- preprocesssors.py +2472 -0
- prompts.py +385 -0
- seqio_tokenizer.py +659 -0
- tasks.py +2548 -0
- torch_util.py +183 -0
- util.py +1 -1
- utils.py +195 -0
beam_search.py
ADDED
@@ -0,0 +1,1087 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is a self-contained and flexible beam search implementation adapted from
|
3 |
+
AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import warnings
|
8 |
+
from abc import abstractmethod
|
9 |
+
from inspect import signature
|
10 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
__all__ = [
|
15 |
+
"Sampler",
|
16 |
+
"DeterministicSampler",
|
17 |
+
"MultinomialSampler",
|
18 |
+
"TopKSampler",
|
19 |
+
"TopPSampler",
|
20 |
+
"GumbelSampler",
|
21 |
+
"FinalSequenceScorer",
|
22 |
+
"SequenceLogProbabilityScorer",
|
23 |
+
"LengthNormalizedSequenceLogProbabilityScorer",
|
24 |
+
"Constraint",
|
25 |
+
"RepeatedNGramBlockingConstraint",
|
26 |
+
"BeamSearch",
|
27 |
+
]
|
28 |
+
|
29 |
+
StateType = Dict[str, torch.Tensor]
|
30 |
+
StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
|
31 |
+
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
|
32 |
+
|
33 |
+
StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
|
34 |
+
"""
|
35 |
+
The type of step function that can be passed to [`BeamSearch.search`](#search).
|
36 |
+
|
37 |
+
This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
|
38 |
+
or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
|
39 |
+
"""
|
40 |
+
|
41 |
+
ConstraintStateType = List[List[Dict[str, Any]]]
|
42 |
+
|
43 |
+
|
44 |
+
class Sampler:
|
45 |
+
"""
|
46 |
+
An abstract class that can be used to sample candidates (either nodes or beams)
|
47 |
+
within `BeamSearch`.
|
48 |
+
|
49 |
+
A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
|
50 |
+
|
51 |
+
`init_state()` takes three arguments:
|
52 |
+
|
53 |
+
- a tensor of starting log probs with shape `(batch_size,, num_classes)`,
|
54 |
+
- the batch size, an int,
|
55 |
+
- and the number of classes, also an int.
|
56 |
+
|
57 |
+
It returns a state dictionary with any state tensors needed for subsequent
|
58 |
+
calls to `sample_nodes()` and `sample_beams()`.
|
59 |
+
|
60 |
+
By default this method just returns an empty dictionary.
|
61 |
+
|
62 |
+
Both `sample_nodes()` and `sample_beams()` should take three arguments:
|
63 |
+
|
64 |
+
- tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
|
65 |
+
- an integer representing the number of samples to take for each example in the batch,
|
66 |
+
- and a state dictionary which could contain any tensors needed for the `Sampler` to keep
|
67 |
+
track of state.
|
68 |
+
|
69 |
+
For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
|
70 |
+
`num_examples = beam_size * per_node_beam_size`.
|
71 |
+
|
72 |
+
The return value should be a tuple containing:
|
73 |
+
|
74 |
+
- a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
|
75 |
+
- a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
|
76 |
+
- and the updated state dictionary.
|
77 |
+
|
78 |
+
A default implementation of `sample_beams` is provided, which just deterministically
|
79 |
+
picks the `k` examples with highest log probability.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def init_state(
|
83 |
+
self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
|
84 |
+
) -> StateType:
|
85 |
+
del start_class_log_probabilities, batch_size, num_classes
|
86 |
+
return {}
|
87 |
+
|
88 |
+
@abstractmethod
|
89 |
+
def sample_nodes(
|
90 |
+
self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
|
91 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
92 |
+
raise NotImplementedError
|
93 |
+
|
94 |
+
def sample_beams(
|
95 |
+
self, log_probs: torch.Tensor, beam_size: int, state: StateType
|
96 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
97 |
+
del state
|
98 |
+
selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
|
99 |
+
return selected_log_probs, selected_indices, {}
|
100 |
+
|
101 |
+
|
102 |
+
class DeterministicSampler(Sampler):
|
103 |
+
"""
|
104 |
+
A `Sampler` that just deterministically returns the `k` nodes or beams with highest
|
105 |
+
log probability.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def sample_nodes(
|
109 |
+
self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
|
110 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
111 |
+
del state
|
112 |
+
selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
|
113 |
+
return selected_log_probs, selected_indices, {}
|
114 |
+
|
115 |
+
|
116 |
+
class MultinomialSampler(Sampler):
|
117 |
+
"""
|
118 |
+
A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
|
119 |
+
in the default, non-deterministic way.
|
120 |
+
|
121 |
+
:param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
|
122 |
+
above 1.0 produces a flatter probability distribution.
|
123 |
+
:param with_replacement: Whether to sample with replacement.
|
124 |
+
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
temperature: float = 1.0,
|
130 |
+
with_replacement: bool = False,
|
131 |
+
) -> None:
|
132 |
+
self.temperature = temperature
|
133 |
+
self.with_replacement = with_replacement
|
134 |
+
|
135 |
+
def sample_nodes(
|
136 |
+
self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
|
137 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
138 |
+
if self.temperature != 1.0:
|
139 |
+
_probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
|
140 |
+
else:
|
141 |
+
_probabilities = log_probs.exp()
|
142 |
+
|
143 |
+
selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
|
144 |
+
|
145 |
+
return torch.gather(log_probs, 1, selected_indices), selected_indices, state
|
146 |
+
|
147 |
+
|
148 |
+
class TopKSampler(Sampler):
|
149 |
+
"""
|
150 |
+
A `Sampler` which redistributes the probability mass function for nodes among the
|
151 |
+
top `k` choices, then samples from that subset after re-normalizing the probabilities.
|
152 |
+
|
153 |
+
Beams are sampled in the default, deterministic way.
|
154 |
+
|
155 |
+
:param k: The number of top choices to be selected from.
|
156 |
+
:param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
|
157 |
+
above 1.0 produces a flatter probability distribution.
|
158 |
+
:param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
k: int = 1,
|
164 |
+
temperature: float = 1.0,
|
165 |
+
with_replacement: bool = False,
|
166 |
+
):
|
167 |
+
self.k = k
|
168 |
+
self.temperature = temperature or 1.0
|
169 |
+
self.with_replacement = with_replacement
|
170 |
+
|
171 |
+
def sample_nodes(
|
172 |
+
self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
|
173 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
174 |
+
if not per_node_beam_size <= self.k <= log_probs.size()[1]:
|
175 |
+
raise ValueError(
|
176 |
+
"k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
|
177 |
+
)
|
178 |
+
|
179 |
+
# shape (both): (batch_size, k)
|
180 |
+
top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
|
181 |
+
|
182 |
+
# Apply temperature if necessary.
|
183 |
+
# shape: (batch_size, k)
|
184 |
+
if self.temperature != 1.0:
|
185 |
+
top_k_log_probs = top_k_log_probs / self.temperature
|
186 |
+
|
187 |
+
# Re-normalize the subset.
|
188 |
+
# shape: (batch_size, k)
|
189 |
+
normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
|
190 |
+
|
191 |
+
# Sample from the re-normalized subset.
|
192 |
+
# NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
|
193 |
+
# shape: (batch_size, per_node_beam_size)
|
194 |
+
sampled_indices = torch.multinomial(
|
195 |
+
normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
|
196 |
+
)
|
197 |
+
|
198 |
+
# Convert `sampled_indices` back to indices in the original `log_probs` tensor.
|
199 |
+
# shape: (batch_size, per_node_beam_size)
|
200 |
+
indices = top_k_indices.gather(-1, sampled_indices)
|
201 |
+
|
202 |
+
return log_probs.gather(1, indices), indices, state
|
203 |
+
|
204 |
+
|
205 |
+
class TopPSampler(Sampler):
|
206 |
+
"""
|
207 |
+
A `Sampler` which redistributes the probability mass function for nodes among
|
208 |
+
the top choices with a cumulative probability of at least `p`, then samples from that subset
|
209 |
+
after re-normalizing the probabilities.
|
210 |
+
|
211 |
+
Beams are sampled in the default, deterministic way.
|
212 |
+
|
213 |
+
:param p:
|
214 |
+
The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
|
215 |
+
examples to sample from. If `with_replacement` is `False` and the number of possible samples is
|
216 |
+
insufficient to sample without replacement from when calling `sample_nodes`, then the top
|
217 |
+
`per_node_beam_size` examples will be chosen.
|
218 |
+
:param temperature:
|
219 |
+
A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
|
220 |
+
above 1.0 produces a flatter probability distribution.
|
221 |
+
:param with_replacement:
|
222 |
+
If set to `True`, samples will be selected with replacement from the top choices.
|
223 |
+
|
224 |
+
"""
|
225 |
+
|
226 |
+
def __init__(
|
227 |
+
self,
|
228 |
+
p: float = 0.9,
|
229 |
+
temperature: float = 1.0,
|
230 |
+
with_replacement: bool = False,
|
231 |
+
):
|
232 |
+
if p < 0.0 or p > 1.0:
|
233 |
+
raise ValueError("p must be a positive float no greater than 1.0")
|
234 |
+
self.p = p
|
235 |
+
self.temperature = temperature or 1.0
|
236 |
+
self.with_replacement = with_replacement
|
237 |
+
|
238 |
+
def sample_nodes(
|
239 |
+
self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
|
240 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
241 |
+
if not per_node_beam_size <= log_probs.size()[1]:
|
242 |
+
raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
|
243 |
+
|
244 |
+
# First apply temperature coefficient:
|
245 |
+
if self.temperature != 1.0:
|
246 |
+
_log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
|
247 |
+
else:
|
248 |
+
_log_probs = log_probs
|
249 |
+
|
250 |
+
# Sort the probabilities in descending order to then find cumulative sum
|
251 |
+
log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
|
252 |
+
|
253 |
+
# shape: (batch_size, num_classes)
|
254 |
+
probabilities_descending = log_probs_descending.exp()
|
255 |
+
probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
|
256 |
+
|
257 |
+
# Create a mask for filtering out probabilities that don't make the top `p`.
|
258 |
+
# shape: (batch_size, num_classes)
|
259 |
+
exclusion_mask = probabilities_summed >= self.p
|
260 |
+
|
261 |
+
# We want to include the first index where probabilities_summed >= p, so we shift over one.
|
262 |
+
exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
|
263 |
+
exclusion_mask[..., 0] = False
|
264 |
+
|
265 |
+
# Make sure there's at least `per_node_beam_size` options to be selected.
|
266 |
+
if not self.with_replacement:
|
267 |
+
exclusion_mask[..., :per_node_beam_size] = False
|
268 |
+
|
269 |
+
log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
|
270 |
+
|
271 |
+
# Now re-normalized the included log probs.
|
272 |
+
# shape: (batch_size, num_classes)
|
273 |
+
filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
|
274 |
+
|
275 |
+
# Sample from the re-normalized subset.
|
276 |
+
# NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
|
277 |
+
# shape: (batch_size, per_node_beam_size)
|
278 |
+
sampled_indices = torch.multinomial(
|
279 |
+
filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
|
280 |
+
)
|
281 |
+
|
282 |
+
# Convert `sampled_indices` back to indices in the original `log_probs` tensor.
|
283 |
+
# shape: (batch_size, per_node_beam_size)
|
284 |
+
selected_indices = sorting_indices.gather(-1, sampled_indices)
|
285 |
+
|
286 |
+
# Return (selected log probabilities, selected classes)
|
287 |
+
# shape: (len(log_probs),1) , (len(log_probs), 1)
|
288 |
+
return torch.gather(log_probs, 1, selected_indices), selected_indices, state
|
289 |
+
|
290 |
+
|
291 |
+
class GumbelSampler(Sampler):
|
292 |
+
"""
|
293 |
+
A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
|
294 |
+
[*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
|
295 |
+
Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010]
|
296 |
+
(https://api.semanticscholar.org/CorpusID:76662039).
|
297 |
+
|
298 |
+
:param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
|
299 |
+
above 1.0 produces a flatter probability distribution.
|
300 |
+
"""
|
301 |
+
|
302 |
+
def __init__(self, temperature: float = 1.0):
|
303 |
+
self.temperature = temperature
|
304 |
+
|
305 |
+
def init_state(
|
306 |
+
self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
|
307 |
+
) -> StateType:
|
308 |
+
# shape: (batch_size, num_classes)
|
309 |
+
zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
|
310 |
+
|
311 |
+
# shape: (batch_size, num_classes)
|
312 |
+
G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
|
313 |
+
|
314 |
+
return {"G_phi_S": G_phi_S}
|
315 |
+
|
316 |
+
def sample_nodes(
|
317 |
+
self,
|
318 |
+
log_probs: torch.Tensor,
|
319 |
+
per_node_beam_size: int,
|
320 |
+
state: StateType,
|
321 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
322 |
+
# First apply temperature coefficient:
|
323 |
+
# shape: (batch_size * beam_size, num_classes)
|
324 |
+
if self.temperature != 1.0:
|
325 |
+
_log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
|
326 |
+
else:
|
327 |
+
_log_probs = log_probs
|
328 |
+
|
329 |
+
# shape: (group_size,)
|
330 |
+
phi_S = state["phi_S"]
|
331 |
+
|
332 |
+
# shape: (group_size, num_classes)
|
333 |
+
phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
|
334 |
+
|
335 |
+
# shape: (group_size, num_classes)
|
336 |
+
phi_S_new = phi_S + _log_probs
|
337 |
+
|
338 |
+
# shape: (group_size, 1)
|
339 |
+
G_phi_S = state["G_phi_S"].unsqueeze(-1)
|
340 |
+
|
341 |
+
# shape: (group_size, num_classes)
|
342 |
+
G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
|
343 |
+
|
344 |
+
# Replace NaNs with very negative number.
|
345 |
+
# shape: (group_size, num_classes)
|
346 |
+
# G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
|
347 |
+
|
348 |
+
# shape (both): (group_size, per_node_beam_size)
|
349 |
+
top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
|
350 |
+
|
351 |
+
# shape: (group_size, per_node_beam_size)
|
352 |
+
top_log_probs = log_probs.gather(1, top_indices)
|
353 |
+
|
354 |
+
return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
|
355 |
+
|
356 |
+
def sample_beams(
|
357 |
+
self,
|
358 |
+
log_probs: torch.Tensor,
|
359 |
+
beam_size: int,
|
360 |
+
state: StateType,
|
361 |
+
) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
|
362 |
+
"""
|
363 |
+
Returns the beams with the highest perturbed log probabilities.
|
364 |
+
"""
|
365 |
+
# shape (log_probs): (batch_size, beam_size * per_node_beam_size)
|
366 |
+
|
367 |
+
batch_size = log_probs.size()[0]
|
368 |
+
|
369 |
+
# shape: (batch_size * beam_size, per_node_beam_size)
|
370 |
+
G_phi_S = state["G_phi_S"]
|
371 |
+
|
372 |
+
# shape: (batch_size, beam_size * per_node_beam_size)
|
373 |
+
G_phi_S = G_phi_S.reshape_as(log_probs)
|
374 |
+
|
375 |
+
# shape (both): (batch_size, beam_size)
|
376 |
+
G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
|
377 |
+
|
378 |
+
# shape: (batch_size, beam_size)
|
379 |
+
selected_log_probs = log_probs.gather(1, selected_indices)
|
380 |
+
|
381 |
+
# Now sort the selected beams by their true log prob.
|
382 |
+
# shape (all): (batch_size, beam_size)
|
383 |
+
selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
|
384 |
+
selected_indices = selected_indices.gather(1, sort_indices)
|
385 |
+
G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
|
386 |
+
|
387 |
+
# shape: (batch_size * beam_size,)
|
388 |
+
G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
|
389 |
+
|
390 |
+
# shape: (batch_size * beam_size,)
|
391 |
+
phi_S = selected_log_probs.reshape(batch_size * beam_size)
|
392 |
+
|
393 |
+
return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
|
394 |
+
|
395 |
+
def gumbel(self, phi) -> torch.Tensor:
|
396 |
+
"""
|
397 |
+
Sample `Gumbel(phi)`.
|
398 |
+
|
399 |
+
`phi` should have shape `(batch_size, num_classes)`.
|
400 |
+
"""
|
401 |
+
return -torch.log(-torch.log(torch.rand_like(phi))) + phi
|
402 |
+
|
403 |
+
def gumbel_with_max(self, phi, T) -> torch.Tensor:
|
404 |
+
"""
|
405 |
+
Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
|
406 |
+
|
407 |
+
`phi` should have shape `(batch_size, num_classes)` and `T` should have
|
408 |
+
shape `(batch_size, 1)`.
|
409 |
+
"""
|
410 |
+
# Shape: (batch_size, num_classes)
|
411 |
+
G_phi = self.gumbel(phi)
|
412 |
+
|
413 |
+
# Now we find the maximum from these samples.
|
414 |
+
# Shape: (batch_size, )
|
415 |
+
Z, _ = G_phi.max(dim=-1)
|
416 |
+
|
417 |
+
# Shape: (batch_size, num_classes)
|
418 |
+
v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
|
419 |
+
|
420 |
+
# Shape: (batch_size, num_classes)
|
421 |
+
return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
|
422 |
+
|
423 |
+
|
424 |
+
class FinalSequenceScorer:
|
425 |
+
"""
|
426 |
+
An abstract class that can be used to score the final generated sequences found
|
427 |
+
by beam search. Given the predicted sequences and the corresponding log probabilities of
|
428 |
+
those sequences, the class calculates and returns the final score of the sequences.
|
429 |
+
|
430 |
+
The default implementation scores the sequences using the sum of the log probabilities of
|
431 |
+
the sequence, which is passed as input.
|
432 |
+
"""
|
433 |
+
|
434 |
+
@abstractmethod
|
435 |
+
def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
|
436 |
+
"""
|
437 |
+
Score the final predictions found by beam search.
|
438 |
+
Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
|
439 |
+
|
440 |
+
:param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
|
441 |
+
:param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
|
442 |
+
of the log probabilities per token, with shape `(batch_size, beam_size)`.
|
443 |
+
:param end_index: The index of the end symbol.
|
444 |
+
|
445 |
+
"""
|
446 |
+
raise NotImplementedError
|
447 |
+
|
448 |
+
|
449 |
+
class SequenceLogProbabilityScorer(FinalSequenceScorer):
|
450 |
+
"""
|
451 |
+
A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
|
452 |
+
across the sequence's tokens.
|
453 |
+
"""
|
454 |
+
|
455 |
+
def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
|
456 |
+
del predictions, end_index
|
457 |
+
# The sum of the sequence log probabilities is the input parameter, so just
|
458 |
+
# return it.
|
459 |
+
return log_probabilities
|
460 |
+
|
461 |
+
|
462 |
+
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
|
463 |
+
"""
|
464 |
+
A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
|
465 |
+
tokens in the sequence. It optionally includes a length penalty which promotes
|
466 |
+
or demotes sequences based on their lengths. The final score for a sequence will
|
467 |
+
be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
|
468 |
+
here includes the end token.
|
469 |
+
|
470 |
+
:param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
|
471 |
+
A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
|
472 |
+
"""
|
473 |
+
|
474 |
+
def __init__(self, length_penalty: float = 1.0):
|
475 |
+
super().__init__()
|
476 |
+
self.length_penalty = length_penalty
|
477 |
+
|
478 |
+
def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
|
479 |
+
# shape: (batch_size, beam_size)
|
480 |
+
lengths = (predictions != end_index).long().sum(dim=2)
|
481 |
+
|
482 |
+
# If the sequence ended during beam search, the `log_probabilities` will include
|
483 |
+
# the transition to the end token. Therefore, in such situations, `lengths` is
|
484 |
+
# actually off by 1. This corrects for that.
|
485 |
+
# shape: (batch_size, beam_size)
|
486 |
+
is_end_token = predictions[:, :, -1] == end_index
|
487 |
+
lengths += is_end_token.long()
|
488 |
+
|
489 |
+
# shape: (batch_size, beam_size)
|
490 |
+
average_log_probs = log_probabilities / (lengths**self.length_penalty)
|
491 |
+
return average_log_probs
|
492 |
+
|
493 |
+
|
494 |
+
class Constraint:
|
495 |
+
"""
|
496 |
+
An abstract class that can be used to enforce constraints on the output predictions
|
497 |
+
by manipulating the class log probabilities during beam search.
|
498 |
+
|
499 |
+
A `Constraint` just has three methods that need to be implemented by subclasses:
|
500 |
+
`init_state()`, `apply()` and `_update_state()`.
|
501 |
+
|
502 |
+
`init_state()` takes one argument:
|
503 |
+
|
504 |
+
- the batch size, an int
|
505 |
+
|
506 |
+
It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
|
507 |
+
calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
|
508 |
+
Each inner list should be of length 1.
|
509 |
+
|
510 |
+
`apply()` takes two arguments:
|
511 |
+
|
512 |
+
- the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
|
513 |
+
and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
|
514 |
+
- `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
|
515 |
+
log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
|
516 |
+
|
517 |
+
The `apply()` method should return new `class_log_probabilities` that enforce the constraint
|
518 |
+
for this step of beam search. For instance, it may prevent a specific class from being selected by setting
|
519 |
+
the corresponding log probability to a negligible value such as `float("-inf")` or
|
520 |
+
`torch.finfo(class_log_probabilities.dtype).min`.
|
521 |
+
|
522 |
+
`_update_state()` takes two arguments:
|
523 |
+
|
524 |
+
- the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
|
525 |
+
copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
|
526 |
+
directly edited in-place without affecting the others.
|
527 |
+
- last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
|
528 |
+
step of beam search.
|
529 |
+
|
530 |
+
The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
|
531 |
+
length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
|
532 |
+
|
533 |
+
"""
|
534 |
+
|
535 |
+
@abstractmethod
|
536 |
+
def init_state(
|
537 |
+
self,
|
538 |
+
batch_size: int,
|
539 |
+
) -> ConstraintStateType:
|
540 |
+
raise NotImplementedError
|
541 |
+
|
542 |
+
@abstractmethod
|
543 |
+
def apply(
|
544 |
+
self,
|
545 |
+
state: ConstraintStateType,
|
546 |
+
class_log_probabilities: torch.Tensor,
|
547 |
+
) -> torch.Tensor:
|
548 |
+
raise NotImplementedError
|
549 |
+
|
550 |
+
@staticmethod
|
551 |
+
def _copy_state(
|
552 |
+
state: ConstraintStateType,
|
553 |
+
batch_size: int,
|
554 |
+
beam_size: int,
|
555 |
+
last_backpointer: Optional[torch.Tensor] = None,
|
556 |
+
) -> ConstraintStateType:
|
557 |
+
"""
|
558 |
+
Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this
|
559 |
+
is not appropriate for your constraint, you will need to implement the copying yourself.
|
560 |
+
"""
|
561 |
+
new_state = []
|
562 |
+
for i in range(batch_size):
|
563 |
+
batch_state = []
|
564 |
+
for j in range(beam_size):
|
565 |
+
if last_backpointer is None:
|
566 |
+
# This is the first prediction, so the backpointer is 0
|
567 |
+
backpointer = 0
|
568 |
+
else:
|
569 |
+
backpointer = last_backpointer[i, j].item()
|
570 |
+
batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
|
571 |
+
new_state.append(batch_state)
|
572 |
+
return new_state
|
573 |
+
|
574 |
+
def update_state(
|
575 |
+
self,
|
576 |
+
state: ConstraintStateType,
|
577 |
+
last_prediction: torch.Tensor,
|
578 |
+
last_backpointer: Optional[torch.Tensor] = None,
|
579 |
+
) -> ConstraintStateType:
|
580 |
+
batch_size, beam_size = last_prediction.size()
|
581 |
+
new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
|
582 |
+
return self._update_state(new_state, last_prediction)
|
583 |
+
|
584 |
+
@abstractmethod
|
585 |
+
def _update_state(
|
586 |
+
self,
|
587 |
+
state: ConstraintStateType,
|
588 |
+
last_prediction: torch.Tensor,
|
589 |
+
) -> ConstraintStateType:
|
590 |
+
raise NotImplementedError
|
591 |
+
|
592 |
+
|
593 |
+
class RepeatedNGramBlockingConstraint(Constraint):
|
594 |
+
def __init__(self, ngram_size: int, **kwargs) -> None:
|
595 |
+
super().__init__(**kwargs)
|
596 |
+
self.ngram_size = ngram_size
|
597 |
+
|
598 |
+
def init_state(
|
599 |
+
self,
|
600 |
+
batch_size: int,
|
601 |
+
) -> ConstraintStateType:
|
602 |
+
return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
|
603 |
+
|
604 |
+
def apply(
|
605 |
+
self,
|
606 |
+
state: ConstraintStateType,
|
607 |
+
class_log_probabilities: torch.Tensor,
|
608 |
+
) -> torch.Tensor:
|
609 |
+
for i, batch in enumerate(state):
|
610 |
+
for j, beam in enumerate(batch):
|
611 |
+
current_prefix = tuple(beam["current_prefix"])
|
612 |
+
seen_ngrams = beam["seen_ngrams"]
|
613 |
+
try:
|
614 |
+
disallowed_indices = seen_ngrams[current_prefix]
|
615 |
+
class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
|
616 |
+
class_log_probabilities.dtype
|
617 |
+
).min
|
618 |
+
except KeyError:
|
619 |
+
# We have not seen this prefix before, so there is no index
|
620 |
+
# that needs to be blocked
|
621 |
+
pass
|
622 |
+
return class_log_probabilities
|
623 |
+
|
624 |
+
def _update_state(
|
625 |
+
self,
|
626 |
+
state: ConstraintStateType,
|
627 |
+
last_prediction: torch.Tensor,
|
628 |
+
) -> ConstraintStateType:
|
629 |
+
for i, batch in enumerate(state):
|
630 |
+
for j, beam in enumerate(batch):
|
631 |
+
prediction = last_prediction[i, j].item()
|
632 |
+
prefix = beam["current_prefix"]
|
633 |
+
seen_ngrams = beam["seen_ngrams"]
|
634 |
+
|
635 |
+
if len(prefix) == self.ngram_size - 1:
|
636 |
+
# This is a new ngram that we have to remember
|
637 |
+
if tuple(prefix) not in seen_ngrams:
|
638 |
+
seen_ngrams[tuple(prefix)] = []
|
639 |
+
seen_ngrams[tuple(prefix)].append(prediction)
|
640 |
+
|
641 |
+
# Create the new prefix, removing the oldest index if the prefix
|
642 |
+
# is too long
|
643 |
+
prefix.append(prediction)
|
644 |
+
if len(prefix) == self.ngram_size:
|
645 |
+
prefix.pop(0)
|
646 |
+
return state
|
647 |
+
|
648 |
+
|
649 |
+
class BeamSearch:
|
650 |
+
"""
|
651 |
+
Implements the beam search algorithm for decoding the most likely sequences.
|
652 |
+
|
653 |
+
:param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
|
654 |
+
|
655 |
+
:param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
|
656 |
+
of the predicted sequences.
|
657 |
+
|
658 |
+
:param beam_size: The width of the beam used.
|
659 |
+
|
660 |
+
:param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
|
661 |
+
If not given, this just defaults to `beam_size`. Setting this parameter
|
662 |
+
to a number smaller than `beam_size` may give better results, as it can introduce
|
663 |
+
more diversity into the search. See
|
664 |
+
[*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
|
665 |
+
(https://api.semanticscholar.org/CorpusID:2229477).
|
666 |
+
|
667 |
+
:param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
|
668 |
+
If not specified, `DeterministicSampler` will be used, which just takes the
|
669 |
+
`per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
|
670 |
+
|
671 |
+
Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
|
672 |
+
[Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
|
673 |
+
|
674 |
+
:param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
|
675 |
+
the predicted sequences. This does not include the start or end tokens. If `None`,
|
676 |
+
no minimum is enforced.
|
677 |
+
|
678 |
+
:param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
|
679 |
+
The output from this module is what is returned by the `search` method. If not
|
680 |
+
specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
|
681 |
+
by the sum of the token log probabilities.
|
682 |
+
|
683 |
+
:param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
|
684 |
+
provided, no constraints will be enforced.
|
685 |
+
|
686 |
+
"""
|
687 |
+
|
688 |
+
def __init__(
|
689 |
+
self,
|
690 |
+
end_index: int,
|
691 |
+
*,
|
692 |
+
max_steps: int = 50,
|
693 |
+
beam_size: int = 10,
|
694 |
+
per_node_beam_size: Optional[int] = None,
|
695 |
+
sampler: Optional[Sampler] = None,
|
696 |
+
min_steps: Optional[int] = None,
|
697 |
+
final_sequence_scorer: Optional[FinalSequenceScorer] = None,
|
698 |
+
constraints: Optional[List[Constraint]] = None,
|
699 |
+
distributed_model: bool = False
|
700 |
+
) -> None:
|
701 |
+
if not max_steps > 0:
|
702 |
+
raise ValueError("max_steps must be positive")
|
703 |
+
if not beam_size > 0:
|
704 |
+
raise ValueError("beam_size must be positive")
|
705 |
+
if per_node_beam_size is not None and not per_node_beam_size > 0:
|
706 |
+
raise ValueError("per_node_beam_size must be positive")
|
707 |
+
if min_steps is not None:
|
708 |
+
if not min_steps >= 0:
|
709 |
+
raise ValueError("min_steps must be non-negative")
|
710 |
+
if not min_steps <= max_steps:
|
711 |
+
raise ValueError("min_steps must be less than or equal to max_steps")
|
712 |
+
|
713 |
+
self._end_index = end_index
|
714 |
+
self.max_steps = max_steps
|
715 |
+
self.beam_size = beam_size
|
716 |
+
self.per_node_beam_size = per_node_beam_size or beam_size
|
717 |
+
self.sampler = sampler or DeterministicSampler()
|
718 |
+
self.min_steps = min_steps or 0
|
719 |
+
self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
|
720 |
+
self.constraints = constraints or []
|
721 |
+
self.distributed_model = distributed_model
|
722 |
+
|
723 |
+
@staticmethod
|
724 |
+
def _reconstruct_sequences(predictions, backpointers):
|
725 |
+
# Reconstruct the sequences.
|
726 |
+
# shape: [(batch_size, beam_size, 1)]
|
727 |
+
reconstructed_predictions = [predictions[-1].unsqueeze(2)]
|
728 |
+
|
729 |
+
if not backpointers:
|
730 |
+
return reconstructed_predictions
|
731 |
+
|
732 |
+
# shape: (batch_size, beam_size)
|
733 |
+
cur_backpointers = backpointers[-1]
|
734 |
+
|
735 |
+
for timestep in range(len(predictions) - 2, 0, -1):
|
736 |
+
# shape: (batch_size, beam_size, 1)
|
737 |
+
cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
|
738 |
+
|
739 |
+
reconstructed_predictions.append(cur_preds)
|
740 |
+
|
741 |
+
# shape: (batch_size, beam_size)
|
742 |
+
cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
|
743 |
+
|
744 |
+
# shape: (batch_size, beam_size, 1)
|
745 |
+
final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
|
746 |
+
|
747 |
+
reconstructed_predictions.append(final_preds)
|
748 |
+
|
749 |
+
return reconstructed_predictions
|
750 |
+
|
751 |
+
def search(
|
752 |
+
self,
|
753 |
+
start_predictions: torch.Tensor,
|
754 |
+
start_state: StateType,
|
755 |
+
step: StepFunctionType,
|
756 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
757 |
+
"""
|
758 |
+
Given a starting state and a step function, apply beam search to find the
|
759 |
+
most likely target sequences.
|
760 |
+
|
761 |
+
Returns a tuple of `(predictions, final_scores)`, where `predictions`
|
762 |
+
has shape `(batch_size, beam_size, max_steps)` and `final_scores`
|
763 |
+
has shape `(batch_size, beam_size)`.
|
764 |
+
|
765 |
+
.. note::
|
766 |
+
If your step function returns `-inf` for some log probabilities
|
767 |
+
(like if you're using a masked log-softmax) then some of the "best"
|
768 |
+
sequences returned may also have `-inf` log probability. Specifically
|
769 |
+
this happens when the beam size is smaller than the number of actions
|
770 |
+
with finite log probability (non-zero probability) returned by the step function.
|
771 |
+
Therefore if you're using a mask you may want to check the results from `search`
|
772 |
+
and potentially discard sequences with non-finite log probability.
|
773 |
+
|
774 |
+
:param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
|
775 |
+
Usually the initial predictions are just the index of the "start" token
|
776 |
+
in the target vocabulary.
|
777 |
+
|
778 |
+
:param start_state: The initial state passed to the `step` function. Each value of the state dict
|
779 |
+
should be a tensor of shape `(batch_size, *)`, where `*` means any other
|
780 |
+
number of dimensions.
|
781 |
+
|
782 |
+
:param step: A function that is responsible for computing the next most likely tokens,
|
783 |
+
given the current state and the predictions from the last time step.
|
784 |
+
The function should accept two or three arguments:
|
785 |
+
|
786 |
+
- a tensor of shape `(group_size,)` or representing the index of the predicted
|
787 |
+
tokens from the last time step,
|
788 |
+
- the current state, a `StateType`, and
|
789 |
+
- optionally, the timestep, an `int`.
|
790 |
+
|
791 |
+
The `group_size` will be `batch_size * beam_size`, except in the initial
|
792 |
+
step, for which it will just be `batch_size`.
|
793 |
+
|
794 |
+
The function is expected to return a tuple, where the first element
|
795 |
+
is a tensor of shape `(group_size, vocab_size)` containing
|
796 |
+
the log probabilities of the tokens for the next step, and the second
|
797 |
+
element is the updated state. The tensor in the state should have shape
|
798 |
+
`(group_size, *)`, where `*` means any other number of dimensions.
|
799 |
+
|
800 |
+
"""
|
801 |
+
step_signature = signature(step)
|
802 |
+
if len(step_signature.parameters) < 3:
|
803 |
+
# If the step function we're given does not take the time step argument, wrap it
|
804 |
+
# in one that does.
|
805 |
+
old_step = cast(StepFunctionTypeNoTimestep, step)
|
806 |
+
|
807 |
+
def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
|
808 |
+
del time_step
|
809 |
+
return old_step(last_predictions, state)
|
810 |
+
|
811 |
+
return self._search(start_predictions, start_state, new_step)
|
812 |
+
else:
|
813 |
+
return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
|
814 |
+
|
815 |
+
def _search(
|
816 |
+
self,
|
817 |
+
start_predictions: torch.Tensor,
|
818 |
+
start_state: StateType,
|
819 |
+
step: StepFunctionTypeWithTimestep,
|
820 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
821 |
+
batch_size = start_predictions.size()[0]
|
822 |
+
|
823 |
+
# List of (batch_size, beam_size) tensors. One for each time step. Does not
|
824 |
+
# include the start symbols, which are implicit.
|
825 |
+
predictions: List[torch.Tensor] = []
|
826 |
+
|
827 |
+
# List of (batch_size, beam_size) tensors. One for each time step. None for
|
828 |
+
# the first. Stores the index n for the parent prediction, i.e.
|
829 |
+
# predictions[t-1][i][n], that it came from.
|
830 |
+
backpointers: List[torch.Tensor] = []
|
831 |
+
|
832 |
+
constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
|
833 |
+
|
834 |
+
# Calculate the first timestep. This is done outside the main loop
|
835 |
+
# because we are going from a single decoder input (the output from the
|
836 |
+
# encoder) to the top `beam_size` decoder outputs. On the other hand,
|
837 |
+
# within the main loop we are going from the `beam_size` elements of the
|
838 |
+
# beam to `beam_size`^2 candidates from which we will select the top
|
839 |
+
# `beam_size` elements for the next iteration.
|
840 |
+
# shape: (batch_size, num_classes)
|
841 |
+
start_class_log_probabilities, state = step(start_predictions, start_state, 0)
|
842 |
+
|
843 |
+
num_classes = start_class_log_probabilities.size()[1]
|
844 |
+
|
845 |
+
# Make sure `per_node_beam_size` is not larger than `num_classes`.
|
846 |
+
if self.per_node_beam_size > num_classes:
|
847 |
+
raise ValueError(
|
848 |
+
f"Vocab size ({num_classes:d}) too small "
|
849 |
+
f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
|
850 |
+
f"Please decrease beam_size or per_node_beam_size."
|
851 |
+
)
|
852 |
+
|
853 |
+
sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
|
854 |
+
|
855 |
+
# Apply all constraints.
|
856 |
+
if self.constraints:
|
857 |
+
# shape: (batch_size, 1, num_classes)
|
858 |
+
expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
|
859 |
+
for constraint, constraint_state in zip(self.constraints, constraint_states):
|
860 |
+
expanded_start_class_log_probabilities = constraint.apply(
|
861 |
+
constraint_state, expanded_start_class_log_probabilities
|
862 |
+
)
|
863 |
+
start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
|
864 |
+
|
865 |
+
# Prevent selecting the end symbol if there is any min_steps constraint
|
866 |
+
if self.min_steps >= 1:
|
867 |
+
start_class_log_probabilities[:, self._end_index] = torch.finfo(
|
868 |
+
start_class_log_probabilities.dtype
|
869 |
+
).min
|
870 |
+
|
871 |
+
# Get the initial predicted classed and their log probabilities.
|
872 |
+
# shape: (batch_size, beam_size), (batch_size, beam_size)
|
873 |
+
(
|
874 |
+
start_top_log_probabilities,
|
875 |
+
start_predicted_classes,
|
876 |
+
sampler_state,
|
877 |
+
) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
|
878 |
+
|
879 |
+
if (
|
880 |
+
self.beam_size == 1 and
|
881 |
+
(start_predicted_classes == self._end_index).all() and
|
882 |
+
not self.distributed_model
|
883 |
+
):
|
884 |
+
warnings.warn(
|
885 |
+
"Empty sequences predicted. You may want to increase the beam size or ensure "
|
886 |
+
"your step function is working properly.",
|
887 |
+
RuntimeWarning,
|
888 |
+
)
|
889 |
+
return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
|
890 |
+
|
891 |
+
# The log probabilities for the last time step.
|
892 |
+
# shape: (batch_size, beam_size)
|
893 |
+
last_log_probabilities = start_top_log_probabilities
|
894 |
+
|
895 |
+
# shape: [(batch_size, beam_size)]
|
896 |
+
predictions.append(start_predicted_classes)
|
897 |
+
|
898 |
+
# Log probability tensor that mandates that the end token is selected.
|
899 |
+
# shape: (batch_size * beam_size, num_classes)
|
900 |
+
log_probs_after_end = start_class_log_probabilities.new_full(
|
901 |
+
(batch_size * self.beam_size, num_classes),
|
902 |
+
torch.finfo(start_class_log_probabilities.dtype).min,
|
903 |
+
)
|
904 |
+
log_probs_after_end[:, self._end_index] = 0.0
|
905 |
+
|
906 |
+
# Set the same state for each element in the beam.
|
907 |
+
self._update_initial_state(state, batch_size)
|
908 |
+
|
909 |
+
for i, constraint in enumerate(self.constraints):
|
910 |
+
constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
|
911 |
+
|
912 |
+
for timestep in range(self.max_steps - 1):
|
913 |
+
# shape: (batch_size * beam_size,)
|
914 |
+
last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
|
915 |
+
|
916 |
+
# If every predicted token from the last step is `self._end_index`,
|
917 |
+
# then we can stop early.
|
918 |
+
# FIXME for distributed model we cannot stop early unless all devices are done,
|
919 |
+
# for now we just always run to the max limit, ideally we should check all devices
|
920 |
+
if not self.distributed_model and (last_predictions == self._end_index).all():
|
921 |
+
# finished
|
922 |
+
break
|
923 |
+
# Take a step. This get the predicted log probs of the next classes
|
924 |
+
# and updates the state.
|
925 |
+
# shape: (batch_size * beam_size, num_classes)
|
926 |
+
class_log_probabilities, state = step(last_predictions, state, timestep + 1)
|
927 |
+
|
928 |
+
# Apply all constraints.
|
929 |
+
if self.constraints:
|
930 |
+
# shape: (batch_size, beam_size, num_classes)
|
931 |
+
reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
|
932 |
+
for constraint, constraint_state in zip(self.constraints, constraint_states):
|
933 |
+
reshaped_class_log_probabilities = constraint.apply(
|
934 |
+
constraint_state, reshaped_class_log_probabilities
|
935 |
+
)
|
936 |
+
# shape: (batch_size * beam_size, num_classes)
|
937 |
+
class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
|
938 |
+
|
939 |
+
# The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
|
940 |
+
# of the sequence (because `timestep` is 0-indexed and we generated the first token
|
941 |
+
# before the for loop). Here we block the end index if the search is not allowed to
|
942 |
+
# terminate on this iteration.
|
943 |
+
if timestep + 2 <= self.min_steps:
|
944 |
+
class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
|
945 |
+
|
946 |
+
# shape: (batch_size * beam_size, num_classes)
|
947 |
+
last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
|
948 |
+
batch_size * self.beam_size, num_classes
|
949 |
+
)
|
950 |
+
|
951 |
+
# Here we are finding any beams where we predicted the end token in
|
952 |
+
# the previous timestep and replacing the distribution with a
|
953 |
+
# one-hot distribution, forcing the beam to predict the end token
|
954 |
+
# this timestep as well.
|
955 |
+
# shape: (batch_size * beam_size, num_classes)
|
956 |
+
cleaned_log_probabilities = torch.where(
|
957 |
+
last_predictions_expanded == self._end_index,
|
958 |
+
log_probs_after_end,
|
959 |
+
class_log_probabilities,
|
960 |
+
)
|
961 |
+
|
962 |
+
# shape (both): (batch_size * beam_size, per_node_beam_size)
|
963 |
+
top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
|
964 |
+
cleaned_log_probabilities, self.per_node_beam_size, sampler_state
|
965 |
+
)
|
966 |
+
|
967 |
+
# Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
|
968 |
+
# so that we can add them to the current log probs for this timestep.
|
969 |
+
# This lets us maintain the log probability of each element on the beam.
|
970 |
+
# shape: (batch_size * beam_size, per_node_beam_size)
|
971 |
+
expanded_last_log_probabilities = (
|
972 |
+
last_log_probabilities.unsqueeze(2)
|
973 |
+
.expand(batch_size, self.beam_size, self.per_node_beam_size)
|
974 |
+
.reshape(batch_size * self.beam_size, self.per_node_beam_size)
|
975 |
+
)
|
976 |
+
|
977 |
+
# shape: (batch_size * beam_size, per_node_beam_size)
|
978 |
+
summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
|
979 |
+
|
980 |
+
# shape: (batch_size, beam_size * per_node_beam_size)
|
981 |
+
reshaped_summed = summed_top_log_probabilities.reshape(
|
982 |
+
batch_size, self.beam_size * self.per_node_beam_size
|
983 |
+
)
|
984 |
+
|
985 |
+
# shape: (batch_size, beam_size * per_node_beam_size)
|
986 |
+
reshaped_predicted_classes = predicted_classes.reshape(
|
987 |
+
batch_size, self.beam_size * self.per_node_beam_size
|
988 |
+
)
|
989 |
+
|
990 |
+
# Keep only the top `beam_size` beam indices.
|
991 |
+
# shape (both): (batch_size, beam_size)
|
992 |
+
(
|
993 |
+
restricted_beam_log_probs,
|
994 |
+
restricted_beam_indices,
|
995 |
+
sampler_state,
|
996 |
+
) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
|
997 |
+
|
998 |
+
# Use the beam indices to extract the corresponding classes.
|
999 |
+
# shape: (batch_size, beam_size)
|
1000 |
+
restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
|
1001 |
+
|
1002 |
+
predictions.append(restricted_predicted_classes)
|
1003 |
+
|
1004 |
+
# shape: (batch_size, beam_size)
|
1005 |
+
last_log_probabilities = restricted_beam_log_probs
|
1006 |
+
|
1007 |
+
# The beam indices come from a `beam_size * per_node_beam_size` dimension where the
|
1008 |
+
# indices with a common ancestor are grouped together. Hence
|
1009 |
+
# dividing by per_node_beam_size gives the ancestor. (Note that this is integer
|
1010 |
+
# division as the tensor is a LongTensor.)
|
1011 |
+
# shape: (batch_size, beam_size)
|
1012 |
+
backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
|
1013 |
+
backpointers.append(backpointer)
|
1014 |
+
|
1015 |
+
# Keep only the pieces of the state tensors corresponding to the
|
1016 |
+
# ancestors created this iteration.
|
1017 |
+
self._update_state(state, backpointer)
|
1018 |
+
|
1019 |
+
for i, constraint in enumerate(self.constraints):
|
1020 |
+
constraint_states[i] = constraint.update_state(
|
1021 |
+
constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
# Warn about "-inf" log probabilities if not using any constraints (negligible
|
1025 |
+
# log probabilities are expected when using constraints).
|
1026 |
+
if not self.constraints and (
|
1027 |
+
not torch.isfinite(last_log_probabilities).all()
|
1028 |
+
or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
|
1029 |
+
):
|
1030 |
+
warnings.warn(
|
1031 |
+
"Negligible log probabilities encountered ('-inf' or equivalent). "
|
1032 |
+
"Some final sequences may not make sense. "
|
1033 |
+
"This can happen when the beam size is larger than the number of valid (non-zero "
|
1034 |
+
"probability) transitions that the step function produces.",
|
1035 |
+
RuntimeWarning,
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
|
1039 |
+
|
1040 |
+
# shape: (batch_size, beam_size, max_steps)
|
1041 |
+
all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
|
1042 |
+
|
1043 |
+
# Calculate the final sequence scores
|
1044 |
+
# shape: (batch_size, beam_size)
|
1045 |
+
final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
|
1046 |
+
|
1047 |
+
# Sort the sequences based on the final scores so the best scoring
|
1048 |
+
# sequence is at index 0
|
1049 |
+
sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
|
1050 |
+
sorted_all_predictions = torch.gather(
|
1051 |
+
all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
|
1052 |
+
)
|
1053 |
+
|
1054 |
+
return sorted_all_predictions, sorted_final_scores
|
1055 |
+
|
1056 |
+
def _update_initial_state(self, state: StateType, batch_size: int):
|
1057 |
+
"""
|
1058 |
+
Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
|
1059 |
+
"""
|
1060 |
+
for key, state_tensor in state.items():
|
1061 |
+
if state_tensor is None:
|
1062 |
+
continue
|
1063 |
+
# shape: (batch_size * beam_size, *)
|
1064 |
+
_, *last_dims = state_tensor.size()
|
1065 |
+
state[key] = (
|
1066 |
+
state_tensor.unsqueeze(1)
|
1067 |
+
.expand(batch_size, self.beam_size, *last_dims)
|
1068 |
+
.reshape(batch_size * self.beam_size, *last_dims)
|
1069 |
+
)
|
1070 |
+
|
1071 |
+
def _update_state(self, state: StateType, backpointer: torch.Tensor):
|
1072 |
+
batch_size = backpointer.size()[0]
|
1073 |
+
|
1074 |
+
for key, state_tensor in state.items():
|
1075 |
+
if state_tensor is None:
|
1076 |
+
continue
|
1077 |
+
_, *last_dims = state_tensor.size()
|
1078 |
+
# shape: (batch_size, beam_size, *)
|
1079 |
+
expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
|
1080 |
+
batch_size, self.beam_size, *last_dims
|
1081 |
+
)
|
1082 |
+
# shape: (batch_size * beam_size, *)
|
1083 |
+
state[key] = (
|
1084 |
+
state_tensor.reshape(batch_size, self.beam_size, *last_dims)
|
1085 |
+
.gather(1, expanded_backpointer)
|
1086 |
+
.reshape(batch_size * self.beam_size, *last_dims)
|
1087 |
+
)
|
config_molmoe.py
CHANGED
@@ -27,11 +27,15 @@ import gin
|
|
27 |
|
28 |
#from olmo.aliases import PathOrStr
|
29 |
from .aliases import PathOrStr
|
30 |
-
from olmo.exceptions import OLMoConfigurationError
|
31 |
-
from
|
32 |
-
|
33 |
-
from
|
34 |
-
|
|
|
|
|
|
|
|
|
35 |
|
36 |
__all__ = [
|
37 |
"ActivationType",
|
|
|
27 |
|
28 |
#from olmo.aliases import PathOrStr
|
29 |
from .aliases import PathOrStr
|
30 |
+
#from olmo.exceptions import OLMoConfigurationError
|
31 |
+
from .exceptions import OLMoConfigurationError
|
32 |
+
#from olmo.util import StrEnum, resource_path
|
33 |
+
from .util import StrEnum, resource_path
|
34 |
+
|
35 |
+
#from olmo.mm_data.data_utils import build_tokenizer
|
36 |
+
from .data_utils import build_tokenizer
|
37 |
+
#from olmo.multimodal_preprocessor import MultiModalPreprocessor
|
38 |
+
from .multimodal_preprocessor import MultiModalPreprocessor
|
39 |
|
40 |
__all__ = [
|
41 |
"ActivationType",
|
constants.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DEFAULT_IMAGE_PATCH_TOKEN = f"<im_patch>"
|
2 |
+
DEFAULT_IM_START_TOKEN = f"<im_start>"
|
3 |
+
DEFAULT_IM_END_TOKEN = f"<im_end>"
|
4 |
+
DEFAULT_IM_COL_TOKEN = f"<im_col>"
|
5 |
+
IMAGE_PROMPT = "<|image|>"
|
6 |
+
|
7 |
+
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
|
8 |
+
|
9 |
+
|
10 |
+
VIT_STANDARD_CONFIGS = {
|
11 |
+
"dinov2-large": {
|
12 |
+
"image_emb_dim": 1024,
|
13 |
+
"image_mlp_dim": 4096,
|
14 |
+
'image_patch_size': 14,
|
15 |
+
'image_pos_patch_size': 14,
|
16 |
+
'image_num_layers': 24,
|
17 |
+
'image_num_heads': 16,
|
18 |
+
'image_num_key_value_heads': 16,
|
19 |
+
'image_head_dim': 64,
|
20 |
+
'image_mlp_activations': 'gelu',
|
21 |
+
'image_default_input_size': (224, 224),
|
22 |
+
'image_num_pos': 257,
|
23 |
+
'image_norm_eps': 1e-6,
|
24 |
+
"image_model_type": "dino"
|
25 |
+
},
|
26 |
+
"SigLIP-So400m-14-384": {
|
27 |
+
"image_emb_dim": 1152,
|
28 |
+
'image_num_layers': 27,
|
29 |
+
"image_mlp_dim": 4304,
|
30 |
+
'image_patch_size': 14,
|
31 |
+
'image_pos_patch_size': 14,
|
32 |
+
'image_num_heads': 16,
|
33 |
+
'image_num_key_value_heads': 16,
|
34 |
+
'image_head_dim': 72,
|
35 |
+
'image_mlp_activations': 'gelu',
|
36 |
+
# Although it is called "384" that seems to be an error of the author's
|
37 |
+
# part, it actually only handles 378 inputs
|
38 |
+
'image_default_input_size': (378, 378),
|
39 |
+
'image_num_pos': 729, # note not CLS token
|
40 |
+
'image_norm_eps': 1e-6,
|
41 |
+
"image_model_type": "siglip",
|
42 |
+
"resize_mode": "siglip"
|
43 |
+
},
|
44 |
+
"DFN5B-CLIP-ViT-H-14-378": {
|
45 |
+
"image_emb_dim": 1280,
|
46 |
+
'image_patch_size': 14,
|
47 |
+
'image_pos_patch_size': 14,
|
48 |
+
'image_num_layers': 32,
|
49 |
+
'image_num_heads': 16,
|
50 |
+
'image_num_key_value_heads': 16,
|
51 |
+
'image_head_dim': 80,
|
52 |
+
'image_mlp_dim': 5120,
|
53 |
+
'image_dropout_rate': 0.0,
|
54 |
+
'image_mlp_activations': 'quick_gelu',
|
55 |
+
'image_default_input_size': (378, 378),
|
56 |
+
'image_num_pos': 730,
|
57 |
+
'image_norm_eps': 1e-5,
|
58 |
+
"image_model_type": "openai",
|
59 |
+
"resize_mode": "no_aspect_ratio"
|
60 |
+
},
|
61 |
+
'ViT-L/14-336': {
|
62 |
+
'image_patch_size': 14,
|
63 |
+
'image_pos_patch_size': 14,
|
64 |
+
'image_emb_dim': 1024,
|
65 |
+
'image_num_heads': 16,
|
66 |
+
'image_num_layers': 23,
|
67 |
+
'image_head_dim': 64,
|
68 |
+
'image_mlp_dim': 4096,
|
69 |
+
'image_mlp_activations': 'quick_gelu',
|
70 |
+
'image_dropout_rate': 0.0,
|
71 |
+
'image_num_pos': 577,
|
72 |
+
'image_default_input_size': (336, 336),
|
73 |
+
'image_norm_eps': 1e-5,
|
74 |
+
'image_num_key_value_heads': 16,
|
75 |
+
"image_model_type": "openai"
|
76 |
+
},
|
77 |
+
'EVA02-L-14-336': {
|
78 |
+
'image_patch_size': 14,
|
79 |
+
'image_pos_patch_size': 14,
|
80 |
+
'image_emb_dim': 1024,
|
81 |
+
'image_num_heads': 16,
|
82 |
+
'image_num_layers': 24,
|
83 |
+
'image_head_dim': 64,
|
84 |
+
'image_mlp_dim': 2730,
|
85 |
+
'image_mlp_activations': 'silu',
|
86 |
+
'image_dropout_rate': 0.0,
|
87 |
+
'image_num_pos': 577,
|
88 |
+
'image_default_input_size': (336, 336),
|
89 |
+
'image_norm_eps': 1e-6,
|
90 |
+
'image_num_key_value_heads': 16,
|
91 |
+
"image_model_type": "eva"
|
92 |
+
},
|
93 |
+
'ViT-L/14': {
|
94 |
+
'image_patch_size': 14,
|
95 |
+
'image_pos_patch_size': 14,
|
96 |
+
'image_emb_dim': 1024,
|
97 |
+
'image_num_heads': 16,
|
98 |
+
# Note the original model has 24 layers, but we don't use the last layer
|
99 |
+
'image_num_layers': 23,
|
100 |
+
'image_head_dim': 64,
|
101 |
+
'image_mlp_dim': 4096,
|
102 |
+
'image_mlp_activations': 'quick_gelu',
|
103 |
+
'image_dropout_rate': 0.0,
|
104 |
+
'image_num_pos': 257,
|
105 |
+
'image_default_input_size': (224, 224),
|
106 |
+
'image_norm_eps': 1e-5,
|
107 |
+
'image_num_key_value_heads': 16,
|
108 |
+
"image_model_type": "openai"
|
109 |
+
},
|
110 |
+
'debug': {
|
111 |
+
'image_patch_size': 14,
|
112 |
+
'image_pos_patch_size': 14,
|
113 |
+
'image_emb_dim': 1024,
|
114 |
+
'image_num_heads': 16,
|
115 |
+
'image_num_layers': 1,
|
116 |
+
'image_head_dim': 64,
|
117 |
+
'image_mlp_dim': 1024,
|
118 |
+
'image_mlp_activations': 'quick_gelu',
|
119 |
+
'image_dropout_rate': 0.0,
|
120 |
+
'image_num_pos': 577,
|
121 |
+
'image_default_input_size': (336, 336),
|
122 |
+
'image_norm_eps': 1e-5,
|
123 |
+
'image_num_key_value_heads': 16,
|
124 |
+
"image_model_type": "openai"
|
125 |
+
}
|
126 |
+
}
|
127 |
+
|
128 |
+
OPEN_LLM_STANDARD_CONFIGS = {
|
129 |
+
"qwen1.5_7b": {
|
130 |
+
'vocab_size': 151936,
|
131 |
+
'hidden_size': 4096,
|
132 |
+
'intermediate_size': 11008,
|
133 |
+
'num_hidden_layers': 32,
|
134 |
+
'num_attention_heads': 32,
|
135 |
+
'num_key_value_heads': 32,
|
136 |
+
'max_sequence_length': 2048,
|
137 |
+
'max_position_embeddings': 32768,
|
138 |
+
'rope_theta': 1000000.0,
|
139 |
+
'initializer_range': 0.02,
|
140 |
+
'rms_norm_eps': 1e-6,
|
141 |
+
"qkv_bias": True,
|
142 |
+
'tie_word_embeddings': False,
|
143 |
+
'hidden_act': 'silu',
|
144 |
+
'norm_module': 'RMSNorm',
|
145 |
+
"tokenizer": "hf-Qwen/Qwen1.5-7B",
|
146 |
+
},
|
147 |
+
"qwen1.5_14b": {
|
148 |
+
'vocab_size': 152064,
|
149 |
+
'hidden_size': 5120,
|
150 |
+
'intermediate_size': 13696,
|
151 |
+
'num_hidden_layers': 40,
|
152 |
+
'num_attention_heads': 40,
|
153 |
+
'num_key_value_heads': 40,
|
154 |
+
'max_sequence_length': 2048,
|
155 |
+
'max_position_embeddings': 32768,
|
156 |
+
'rope_theta': 1000000.0,
|
157 |
+
'initializer_range': 0.02,
|
158 |
+
'rms_norm_eps': 1e-6,
|
159 |
+
"qkv_bias": True,
|
160 |
+
'tie_word_embeddings': False,
|
161 |
+
'hidden_act': 'silu',
|
162 |
+
'norm_module': 'RMSNorm',
|
163 |
+
"tokenizer": "hf-Qwen/Qwen1.5-14B",
|
164 |
+
},
|
165 |
+
"qwen1.5_32b": {
|
166 |
+
"vocab_size": 152064,
|
167 |
+
"hidden_size": 5120,
|
168 |
+
"intermediate_size": 27392,
|
169 |
+
"num_hidden_layers": 64,
|
170 |
+
"num_attention_heads": 40,
|
171 |
+
"num_key_value_heads": 8,
|
172 |
+
'max_sequence_length': 2048,
|
173 |
+
'max_position_embeddings': 32768,
|
174 |
+
"rope_theta": 1000000.0,
|
175 |
+
'initializer_range': 0.02,
|
176 |
+
"rms_norm_eps": 1e-6,
|
177 |
+
"qkv_bias": True,
|
178 |
+
"tie_word_embeddings": False,
|
179 |
+
'hidden_act': 'silu',
|
180 |
+
'norm_module': 'RMSNorm',
|
181 |
+
"tokenizer": "hf-Qwen/Qwen1.5-32B",
|
182 |
+
},
|
183 |
+
'llama_7b': {
|
184 |
+
'vocab_size': 32000,
|
185 |
+
'hidden_size': 4096,
|
186 |
+
'intermediate_size': 11008,
|
187 |
+
'num_hidden_layers': 32,
|
188 |
+
'num_attention_heads': 32,
|
189 |
+
'num_key_value_heads': 32,
|
190 |
+
'max_sequence_length': 2048,
|
191 |
+
'max_position_embeddings': 8192,
|
192 |
+
'rope_theta': 10000.0,
|
193 |
+
'initializer_range': 0.02,
|
194 |
+
'rms_norm_eps': 1e-5,
|
195 |
+
'tie_word_embeddings': False,
|
196 |
+
'hidden_act': 'silu',
|
197 |
+
'norm_module': 'RMSNorm',
|
198 |
+
"tokenizer": "llama"
|
199 |
+
},
|
200 |
+
'yi_6b': {
|
201 |
+
'vocab_size': 64000,
|
202 |
+
'hidden_size': 4096,
|
203 |
+
'intermediate_size': 11008,
|
204 |
+
'num_hidden_layers': 32,
|
205 |
+
'num_attention_heads': 32,
|
206 |
+
'num_key_value_heads': 4,
|
207 |
+
'max_sequence_length': 4096,
|
208 |
+
'max_position_embeddings': 4096,
|
209 |
+
'rope_theta': 5000000.0,
|
210 |
+
'initializer_range': 0.02,
|
211 |
+
'rms_norm_eps': 1e-5,
|
212 |
+
'tie_word_embeddings': False,
|
213 |
+
'hidden_act': 'silu',
|
214 |
+
'norm_module': 'RMSNorm',
|
215 |
+
"tokenizer": "yi"
|
216 |
+
},
|
217 |
+
'yi_9b': {
|
218 |
+
'vocab_size': 64000,
|
219 |
+
'hidden_size': 4096,
|
220 |
+
'intermediate_size': 11008,
|
221 |
+
'num_hidden_layers': 48,
|
222 |
+
'num_attention_heads': 32,
|
223 |
+
'num_key_value_heads': 4,
|
224 |
+
'max_sequence_length': 4096,
|
225 |
+
'max_position_embeddings': 4096,
|
226 |
+
'rope_theta': 10000,
|
227 |
+
'initializer_range': 0.02,
|
228 |
+
'rms_norm_eps': 1e-06,
|
229 |
+
'tie_word_embeddings': False,
|
230 |
+
'hidden_act': 'silu',
|
231 |
+
'norm_module': 'RMSNorm',
|
232 |
+
"tokenizer": "yi"
|
233 |
+
},
|
234 |
+
'yi_34b': {
|
235 |
+
'vocab_size': 64000,
|
236 |
+
'hidden_size': 7168,
|
237 |
+
'intermediate_size': 20480,
|
238 |
+
'num_hidden_layers': 60,
|
239 |
+
'num_attention_heads': 56,
|
240 |
+
'num_key_value_heads': 8,
|
241 |
+
'max_sequence_length': 4096,
|
242 |
+
'max_position_embeddings': 4096,
|
243 |
+
'rope_theta': 5000000.0,
|
244 |
+
'initializer_range': 0.02,
|
245 |
+
'rms_norm_eps': 1e-5,
|
246 |
+
'tie_word_embeddings': False,
|
247 |
+
'hidden_act': 'silu',
|
248 |
+
'norm_module': 'RMSNorm',
|
249 |
+
"tokenizer": "yi"
|
250 |
+
},
|
251 |
+
"olmo_1b": {
|
252 |
+
'vocab_size': 50304,
|
253 |
+
'hidden_size': 2048,
|
254 |
+
'intermediate_size': 8192,
|
255 |
+
'num_hidden_layers': 16,
|
256 |
+
'num_attention_heads': 16,
|
257 |
+
'num_key_value_heads': 16,
|
258 |
+
'max_sequence_length': 4096,
|
259 |
+
'max_position_embeddings': 32768,
|
260 |
+
'rope_theta': 10000.0,
|
261 |
+
'initializer_range': 0.02,
|
262 |
+
'rms_norm_eps': 1e-5,
|
263 |
+
'tie_word_embeddings': True,
|
264 |
+
'hidden_act': 'silu',
|
265 |
+
'norm_module': 'OlmoLayerNorm',
|
266 |
+
"tokenizer": "hf-allenai/OLMo-1B"
|
267 |
+
},
|
268 |
+
"olmo_7b": {
|
269 |
+
'vocab_size': 50304,
|
270 |
+
'hidden_size': 4096,
|
271 |
+
'intermediate_size': 22016//2,
|
272 |
+
'num_hidden_layers': 32,
|
273 |
+
'num_attention_heads': 32,
|
274 |
+
'num_key_value_heads': 32,
|
275 |
+
'max_sequence_length': 4096,
|
276 |
+
'max_position_embeddings': 32768,
|
277 |
+
'rope_theta': 10000.0,
|
278 |
+
'initializer_range': 0.02,
|
279 |
+
'rms_norm_eps': 1e-5,
|
280 |
+
'tie_word_embeddings': False,
|
281 |
+
'hidden_act': 'silu',
|
282 |
+
'norm_module': 'OlmoLayerNorm',
|
283 |
+
"tokenizer": "hf-allenai/OLMo-7B",
|
284 |
+
},
|
285 |
+
"olmo_1.7_7b": {
|
286 |
+
'vocab_size': 50304,
|
287 |
+
'hidden_size': 4096,
|
288 |
+
'intermediate_size': 22016//2,
|
289 |
+
'num_hidden_layers': 32,
|
290 |
+
'num_attention_heads': 32,
|
291 |
+
'num_key_value_heads': 32,
|
292 |
+
'max_sequence_length': 4096,
|
293 |
+
'max_position_embeddings': 32768,
|
294 |
+
'rope_theta': 10000.0,
|
295 |
+
'initializer_range': 0.02,
|
296 |
+
'rms_norm_eps': 1e-5,
|
297 |
+
'tie_word_embeddings': False,
|
298 |
+
'hidden_act': 'silu',
|
299 |
+
"qkv_clip": 8,
|
300 |
+
'norm_module': 'OlmoLayerNorm',
|
301 |
+
"tokenizer": "hf-allenai/OLMo-1.7-7B",
|
302 |
+
},
|
303 |
+
'mistral_7b': {
|
304 |
+
'vocab_size': 32000,
|
305 |
+
'hidden_size': 4096,
|
306 |
+
'intermediate_size': 14336,
|
307 |
+
'num_hidden_layers': 32,
|
308 |
+
'num_attention_heads': 32,
|
309 |
+
'num_key_value_heads': 8,
|
310 |
+
'max_sequence_length': 4096,
|
311 |
+
'max_position_embeddings': 32768,
|
312 |
+
'rope_theta': 10000.0,
|
313 |
+
'initializer_range': 0.02,
|
314 |
+
'rms_norm_eps': 1e-5,
|
315 |
+
'tie_word_embeddings': False,
|
316 |
+
'hidden_act': 'silu',
|
317 |
+
'norm_module': 'RMSNorm',
|
318 |
+
"tokenizer": "mistral"
|
319 |
+
},
|
320 |
+
'mistral0.3_7b': {
|
321 |
+
'vocab_size': 32768,
|
322 |
+
'hidden_size': 4096,
|
323 |
+
'intermediate_size': 14336,
|
324 |
+
'num_hidden_layers': 32,
|
325 |
+
'num_attention_heads': 32,
|
326 |
+
'num_key_value_heads': 8,
|
327 |
+
'max_sequence_length': 4096,
|
328 |
+
'max_position_embeddings': 32768,
|
329 |
+
'rope_theta': 1000000.0,
|
330 |
+
'initializer_range': 0.02,
|
331 |
+
'rms_norm_eps': 1e-5,
|
332 |
+
'tie_word_embeddings': False,
|
333 |
+
'hidden_act': 'silu',
|
334 |
+
'norm_module': 'RMSNorm',
|
335 |
+
"tokenizer": "mistral0.3"
|
336 |
+
},
|
337 |
+
"mistral0.2_22b": {
|
338 |
+
'vocab_size': 32000,
|
339 |
+
'hidden_size': 6144,
|
340 |
+
'intermediate_size': 16384,
|
341 |
+
'num_hidden_layers': 56,
|
342 |
+
'num_attention_heads': 48,
|
343 |
+
'num_key_value_heads': 8,
|
344 |
+
'max_sequence_length': 4096,
|
345 |
+
'max_position_embeddings': 32768,
|
346 |
+
'rope_theta': 1000000,
|
347 |
+
'initializer_range': 0.02,
|
348 |
+
'rms_norm_eps': 1e-5,
|
349 |
+
'tie_word_embeddings': False,
|
350 |
+
'hidden_act': 'silu',
|
351 |
+
'norm_module': 'RMSNorm',
|
352 |
+
"tokenizer": "mistral"
|
353 |
+
},
|
354 |
+
'llama_13b': {
|
355 |
+
'vocab_size': 32000,
|
356 |
+
'hidden_size': 5120,
|
357 |
+
'intermediate_size': 13824,
|
358 |
+
'num_hidden_layers': 40,
|
359 |
+
'num_attention_heads': 40,
|
360 |
+
'num_key_value_heads': 40,
|
361 |
+
'max_sequence_length': 2048,
|
362 |
+
'max_position_embeddings': 8192,
|
363 |
+
'initializer_range': 0.02,
|
364 |
+
'rms_norm_eps': 1e-5,
|
365 |
+
'tie_word_embeddings': False,
|
366 |
+
'hidden_act': 'silu',
|
367 |
+
"norm_module": 'RMSNorm',
|
368 |
+
'rope_theta': 10000.0,
|
369 |
+
"tokenizer": "llama"
|
370 |
+
},
|
371 |
+
'llama_70b': {
|
372 |
+
'vocab_size': 32000,
|
373 |
+
'hidden_size': 8192,
|
374 |
+
'intermediate_size': 28672,
|
375 |
+
'num_hidden_layers': 80,
|
376 |
+
'num_attention_heads': 64,
|
377 |
+
'num_key_value_heads': 8,
|
378 |
+
'max_sequence_length': 8192,
|
379 |
+
'max_position_embeddings': 8192,
|
380 |
+
'rope_theta': 10000.0,
|
381 |
+
'initializer_range': 0.02,
|
382 |
+
'rms_norm_eps': 1e-5,
|
383 |
+
'tie_word_embeddings': False,
|
384 |
+
'hidden_act': 'silu',
|
385 |
+
"tokenizer": "llama"
|
386 |
+
},
|
387 |
+
'llama_70bflash': {
|
388 |
+
'vocab_size': 32000,
|
389 |
+
'hidden_size': 8192,
|
390 |
+
'intermediate_size': 28672,
|
391 |
+
'num_hidden_layers': 80,
|
392 |
+
'num_attention_heads': 64,
|
393 |
+
'num_key_value_heads': 8,
|
394 |
+
'max_sequence_length': 8192,
|
395 |
+
'max_position_embeddings': 8192,
|
396 |
+
'rope_theta': 10000.0,
|
397 |
+
'initializer_range': 0.02,
|
398 |
+
'rms_norm_eps': 1e-5,
|
399 |
+
'tie_word_embeddings': False,
|
400 |
+
'scan_attention': True,
|
401 |
+
'scan_mlp': True,
|
402 |
+
'hidden_act': 'silu',
|
403 |
+
"tokenizer": "llama"
|
404 |
+
},
|
405 |
+
'llama3_8b': {
|
406 |
+
'vocab_size': 128256,
|
407 |
+
'hidden_size': 4096,
|
408 |
+
'intermediate_size': 14336,
|
409 |
+
'num_hidden_layers': 32,
|
410 |
+
'num_attention_heads': 32,
|
411 |
+
'num_key_value_heads': 8,
|
412 |
+
'max_sequence_length': 8192,
|
413 |
+
'max_position_embeddings': 8192,
|
414 |
+
'rope_theta': 500000.0,
|
415 |
+
'initializer_range': 0.02,
|
416 |
+
'rms_norm_eps': 1e-5,
|
417 |
+
'tie_word_embeddings': False,
|
418 |
+
'hidden_act': 'silu',
|
419 |
+
'norm_module': 'RMSNorm',
|
420 |
+
"tokenizer": "hf-meta-llama/Meta-Llama-3-8B",
|
421 |
+
|
422 |
+
},
|
423 |
+
'llama3_70b': {
|
424 |
+
'vocab_size': 128256,
|
425 |
+
'hidden_size': 8192,
|
426 |
+
'intermediate_size': 28672,
|
427 |
+
'num_hidden_layers': 80,
|
428 |
+
'num_attention_heads': 64,
|
429 |
+
'num_key_value_heads': 8,
|
430 |
+
'max_sequence_length': 8192,
|
431 |
+
'max_position_embeddings': 8192,
|
432 |
+
'rope_theta': 500000.0,
|
433 |
+
'initializer_range': 0.02,
|
434 |
+
'rms_norm_eps': 1e-5,
|
435 |
+
'tie_word_embeddings': False,
|
436 |
+
'hidden_act': 'silu',
|
437 |
+
'norm_module': 'RMSNorm',
|
438 |
+
"tokenizer": "hf-meta-llama/Meta-Llama-3-70B",
|
439 |
+
},
|
440 |
+
'open_llama_3b': {
|
441 |
+
'vocab_size': 32000,
|
442 |
+
'hidden_size': 3200,
|
443 |
+
'intermediate_size': 8640,
|
444 |
+
'num_hidden_layers': 26,
|
445 |
+
'num_attention_heads': 32,
|
446 |
+
'max_sequence_length': 2048,
|
447 |
+
'initializer_range': 0.02,
|
448 |
+
'rms_norm_eps': 1e-6,
|
449 |
+
'max_position_embeddings': 2048,
|
450 |
+
'num_key_value_heads': 32,
|
451 |
+
'rope_theta': 10000.0,
|
452 |
+
'tie_word_embeddings': False,
|
453 |
+
'hidden_act': 'silu',
|
454 |
+
'norm_module': 'RMSNorm',
|
455 |
+
"tokenizer": "llama"
|
456 |
+
},
|
457 |
+
'gemma_2b': {
|
458 |
+
'vocab_size': 256000,
|
459 |
+
'hidden_size': 2048,
|
460 |
+
'intermediate_size': 16384,
|
461 |
+
'num_hidden_layers': 18,
|
462 |
+
'num_attention_heads': 8,
|
463 |
+
'max_sequence_length': 8192,
|
464 |
+
'initializer_range': 0.02,
|
465 |
+
'rms_norm_eps': 1e-6,
|
466 |
+
'max_position_embeddings': 8192,
|
467 |
+
'num_key_value_heads': 1,
|
468 |
+
'rope_theta': 10000.0,
|
469 |
+
'tie_word_embeddings': True,
|
470 |
+
'normalize_input_embeds': True,
|
471 |
+
'norm_module': 'GemmaRMSNorm',
|
472 |
+
'hidden_act': 'gelu',
|
473 |
+
"tokenizer": "gemma"
|
474 |
+
},
|
475 |
+
'gemma_7b': {
|
476 |
+
'vocab_size': 256000,
|
477 |
+
'hidden_size': 3072,
|
478 |
+
'intermediate_size': 24576,
|
479 |
+
'num_hidden_layers': 28,
|
480 |
+
'num_attention_heads': 16,
|
481 |
+
'max_sequence_length': 8192,
|
482 |
+
'initializer_range': 0.02,
|
483 |
+
'rms_norm_eps': 1e-6,
|
484 |
+
'max_position_embeddings': 8192,
|
485 |
+
'num_key_value_heads': 16,
|
486 |
+
'rope_theta': 10000.0,
|
487 |
+
'tie_word_embeddings': True,
|
488 |
+
'normalize_input_embeds': True,
|
489 |
+
'norm_module': 'GemmaRMSNorm',
|
490 |
+
'hidden_act': 'gelu',
|
491 |
+
"tokenizer": "gemma"
|
492 |
+
},
|
493 |
+
'tiny_llama_1b': {
|
494 |
+
'vocab_size': 32000,
|
495 |
+
'hidden_size': 2048,
|
496 |
+
'intermediate_size': 5632,
|
497 |
+
'num_hidden_layers': 22,
|
498 |
+
'num_attention_heads': 32,
|
499 |
+
'max_sequence_length': 2048,
|
500 |
+
'initializer_range': 0.02,
|
501 |
+
'rms_norm_eps': 1e-5,
|
502 |
+
'max_position_embeddings': 2048,
|
503 |
+
'num_key_value_heads': 4,
|
504 |
+
'rope_theta': 10000.0,
|
505 |
+
'tie_word_embeddings': False,
|
506 |
+
'hidden_act': 'silu',
|
507 |
+
'norm_module': 'RMSNorm',
|
508 |
+
"tokenizer": "llama"
|
509 |
+
},
|
510 |
+
'debug': { # A small model for debugging
|
511 |
+
'vocab_size': 32000,
|
512 |
+
'hidden_size': 512,
|
513 |
+
'intermediate_size': 512,
|
514 |
+
'num_hidden_layers': 1,
|
515 |
+
'num_attention_heads': 8,
|
516 |
+
'max_sequence_length': 4096,
|
517 |
+
'initializer_range': 0.02,
|
518 |
+
'rms_norm_eps': 1e-5,
|
519 |
+
'max_position_embeddings': 4096,
|
520 |
+
'num_key_value_heads': 8,
|
521 |
+
'rope_theta': 10000.0,
|
522 |
+
'tie_word_embeddings': False,
|
523 |
+
'hidden_act': 'silu',
|
524 |
+
'norm_module': 'RMSNorm',
|
525 |
+
"tokenizer": "llama"
|
526 |
+
},
|
527 |
+
'gemma2_9b': {
|
528 |
+
'vocab_size': 256000,
|
529 |
+
'hidden_size': 3584,
|
530 |
+
'head_dim': 256,
|
531 |
+
'intermediate_size': 14336,
|
532 |
+
'num_hidden_layers': 42,
|
533 |
+
'num_attention_heads': 16,
|
534 |
+
'max_sequence_length': 8192,
|
535 |
+
"query_pre_attn_scalar": 224,
|
536 |
+
'initializer_range': 0.02,
|
537 |
+
'rms_norm_eps': 1e-6,
|
538 |
+
'max_position_embeddings': 8192,
|
539 |
+
'num_key_value_heads': 8,
|
540 |
+
'rope_theta': 10000.0,
|
541 |
+
'tie_word_embeddings': False,
|
542 |
+
'normalize_input_embeds': True,
|
543 |
+
'norm_module': 'GemmaRMSNorm',
|
544 |
+
'hidden_act': 'gelu_tanh',
|
545 |
+
"tokenizer": "hf-google/gemma-2-9b",
|
546 |
+
"attn_logit_softcapping": 50.0,
|
547 |
+
"final_logit_softcapping": 30.0,
|
548 |
+
},
|
549 |
+
'gemma2_27b': {
|
550 |
+
'vocab_size': 256000,
|
551 |
+
'hidden_size': 4608,
|
552 |
+
'head_dim': 128,
|
553 |
+
'intermediate_size': 36864,
|
554 |
+
'num_hidden_layers': 46,
|
555 |
+
'num_attention_heads': 32,
|
556 |
+
'max_sequence_length': 8192,
|
557 |
+
"query_pre_attn_scalar": 144,
|
558 |
+
'initializer_range': 0.02,
|
559 |
+
'rms_norm_eps': 1e-6,
|
560 |
+
'max_position_embeddings': 8192,
|
561 |
+
'num_key_value_heads': 16,
|
562 |
+
'rope_theta': 10000.0,
|
563 |
+
'tie_word_embeddings': False,
|
564 |
+
'normalize_input_embeds': True,
|
565 |
+
'norm_module': 'GemmaRMSNorm',
|
566 |
+
'hidden_act': 'gelu_tanh',
|
567 |
+
"tokenizer": "hf-google/gemma-2-27b",
|
568 |
+
"attn_logit_softcapping": 50.0,
|
569 |
+
"final_logit_softcapping": 30.0,
|
570 |
+
},
|
571 |
+
}
|
data_factory.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Dataset factory to load data from huggingface and others.
|
3 |
+
'''
|
4 |
+
import dataclasses
|
5 |
+
import logging
|
6 |
+
from typing import List, Optional
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import tensorflow as tf
|
10 |
+
|
11 |
+
from .data_utils import add_segment_ids
|
12 |
+
from .dataset_sizes import get_dataset_size
|
13 |
+
from .tasks import get_task
|
14 |
+
from .multimodal_preprocessor import MultiModalPreprocessor
|
15 |
+
import seqio
|
16 |
+
|
17 |
+
from .torch_util import get_global_rank
|
18 |
+
|
19 |
+
log = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
@dataclasses.dataclass
|
23 |
+
class SeqioDataset:
|
24 |
+
mixture_or_task_name: str
|
25 |
+
seq_len: int
|
26 |
+
global_batch_size: int
|
27 |
+
max_crops: int = None
|
28 |
+
is_training: bool = False
|
29 |
+
for_inference: bool = False
|
30 |
+
split: str = 'train'
|
31 |
+
shuffle: bool = True
|
32 |
+
num_epochs: int = None
|
33 |
+
drop_remainder: bool = True
|
34 |
+
seed: int = None
|
35 |
+
pack: bool = False
|
36 |
+
use_custom_packing_ops: bool = False
|
37 |
+
use_memory_cache: bool = False
|
38 |
+
shuffle_buffer_size: Optional[int] = None
|
39 |
+
different_host_mixture_seeds: bool = True
|
40 |
+
disable_autotune: bool = True
|
41 |
+
trim_output_features: bool = True
|
42 |
+
|
43 |
+
@classmethod
|
44 |
+
def from_dict(cls, data):
|
45 |
+
return cls(**data)
|
46 |
+
|
47 |
+
def get_task_feature_lengths_dict(self, max_crops):
|
48 |
+
if self.max_crops is not None:
|
49 |
+
assert self.max_crops >= max_crops
|
50 |
+
max_crops = self.max_crops
|
51 |
+
return dict(
|
52 |
+
target_tokens=self.seq_len,
|
53 |
+
loss_masks=self.seq_len,
|
54 |
+
images=max_crops,
|
55 |
+
image_positions=max_crops,
|
56 |
+
image_input_idx=max_crops,
|
57 |
+
is_training=self.is_training
|
58 |
+
)
|
59 |
+
|
60 |
+
def build(self, preprocessor: MultiModalPreprocessor, shard_id, num_shards):
|
61 |
+
shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)
|
62 |
+
task_feature_lengths_dict = self.get_task_feature_lengths_dict(
|
63 |
+
preprocessor.get_max_total_crops())
|
64 |
+
|
65 |
+
seed = self.seed
|
66 |
+
assert seed is not None
|
67 |
+
|
68 |
+
batch_size = self.global_batch_size // num_shards
|
69 |
+
|
70 |
+
if isinstance(self.mixture_or_task_name, (dict, list, tuple)):
|
71 |
+
if isinstance(self.mixture_or_task_name, dict):
|
72 |
+
items = self.mixture_or_task_name.items()
|
73 |
+
else:
|
74 |
+
items = self.mixture_or_task_name
|
75 |
+
task_list = []
|
76 |
+
for task, weight in items:
|
77 |
+
task = get_task(preprocessor, task, self.is_training, self.for_inference)
|
78 |
+
task_list.append((task, weight))
|
79 |
+
mixture_or_task = task_list
|
80 |
+
else:
|
81 |
+
mixture_or_task = get_task(
|
82 |
+
preprocessor, self.mixture_or_task_name, self.is_training, self.for_inference)
|
83 |
+
|
84 |
+
in_memory_shuffle = self.shuffle
|
85 |
+
if not self.drop_remainder:
|
86 |
+
# Used if we want to evaluate on an eval dataset without dropping any examples.
|
87 |
+
# To do this, we pad the dataset with dummy examples marked as invalid in their
|
88 |
+
# metadata so we can still get fixed-sized batches.
|
89 |
+
assert self.num_epochs is not None
|
90 |
+
assert not self.pack
|
91 |
+
assert not isinstance(mixture_or_task, list), "Inference datasets cannot be mixtures"
|
92 |
+
logging.info(
|
93 |
+
f"Initializing inf. dataset {mixture_or_task.name}: replica_batch_size={batch_size}"
|
94 |
+
f' seed={seed}, sharding={shard_info.index}/{shard_info.num_shards}'
|
95 |
+
)
|
96 |
+
ds = mixture_or_task.get_dataset(
|
97 |
+
sequence_length=task_feature_lengths_dict,
|
98 |
+
split=self.split,
|
99 |
+
shuffle=in_memory_shuffle,
|
100 |
+
num_epochs=self.num_epochs,
|
101 |
+
seed=seed,
|
102 |
+
try_in_mem_cache=self.use_memory_cache,
|
103 |
+
trim_output_features=self.trim_output_features
|
104 |
+
)
|
105 |
+
|
106 |
+
try:
|
107 |
+
n = len(ds)
|
108 |
+
except TypeError:
|
109 |
+
dataset_len = get_dataset_size(self.mixture_or_task_name, self.split)
|
110 |
+
logging.info(f"Setting dataset len to {dataset_len} based on DATASET_SIZES")
|
111 |
+
n = dataset_len
|
112 |
+
ds = tf.data.experimental.assert_cardinality(n)(ds)
|
113 |
+
|
114 |
+
remainder = n % self.global_batch_size
|
115 |
+
if remainder > 0:
|
116 |
+
n_to_pad = self.global_batch_size - remainder
|
117 |
+
else:
|
118 |
+
n_to_pad = 0
|
119 |
+
assert "metadata/valid" not in ds.element_spec
|
120 |
+
def add_valid(x):
|
121 |
+
x["metadata/valid"] = True
|
122 |
+
return x
|
123 |
+
def add_invalid(x):
|
124 |
+
x["metadata/valid"] = False
|
125 |
+
return x
|
126 |
+
ds = ds.map(add_valid)
|
127 |
+
if n_to_pad > 0:
|
128 |
+
to_pad = ds.take(1).map(add_invalid).cache().repeat(n_to_pad)
|
129 |
+
ds = ds.concatenate(to_pad)
|
130 |
+
|
131 |
+
# Shard after padding to ensure shards are the same length
|
132 |
+
ds = ds.shard(num_shards=num_shards, index=shard_id)
|
133 |
+
|
134 |
+
ds = preprocessor.get_post_mixing_preprocessor()(
|
135 |
+
ds, task_feature_lengths=task_feature_lengths_dict)
|
136 |
+
data_iter = ds.batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
137 |
+
# Make it possible for client to get the size of the batched/sharded dataset with `len()`
|
138 |
+
new_len = (n + n_to_pad) // self.global_batch_size
|
139 |
+
data_iter = tf.data.experimental.assert_cardinality(new_len)(data_iter)
|
140 |
+
else:
|
141 |
+
if isinstance(mixture_or_task, list):
|
142 |
+
total_rate = sum(x[1] for x in mixture_or_task)
|
143 |
+
mixture_or_task = [(task, r/total_rate) for task, r in mixture_or_task]
|
144 |
+
sorted_tasks: List[seqio.Task] = sorted(mixture_or_task, key=lambda x: -x[1])
|
145 |
+
|
146 |
+
if self.different_host_mixture_seeds and shard_info:
|
147 |
+
# If each process has the same seed they will draw from the datasets in the same
|
148 |
+
# order, which can make the global batches very non-random if there are
|
149 |
+
# many processes each with a small batch size. To fix this, we give each host
|
150 |
+
# a different seed based on its rank to use when mixing
|
151 |
+
mix_seed = seed + shard_info.index*4397
|
152 |
+
else:
|
153 |
+
mix_seed = seed
|
154 |
+
|
155 |
+
logging.info(
|
156 |
+
f"Initializing mixture: replica_batch_size={batch_size} seed={seed}, "
|
157 |
+
f"mix_seed={mix_seed}, sharding={shard_info.index}/{shard_info.num_shards} rates:"
|
158 |
+
)
|
159 |
+
for task, rate in sorted_tasks:
|
160 |
+
logging.info(f"\t{task.name}: {rate:0.4f}")
|
161 |
+
|
162 |
+
datasets = []
|
163 |
+
rates = []
|
164 |
+
for task, rate in sorted_tasks:
|
165 |
+
assert rate > 0
|
166 |
+
datasets.append(task.get_dataset(
|
167 |
+
task_feature_lengths_dict,
|
168 |
+
split=self.split,
|
169 |
+
shuffle=self.shuffle,
|
170 |
+
seed=seed,
|
171 |
+
shard_info=shard_info,
|
172 |
+
num_epochs=self.num_epochs,
|
173 |
+
try_in_mem_cache=self.use_memory_cache,
|
174 |
+
trim_output_features=self.trim_output_features
|
175 |
+
))
|
176 |
+
rates.append(rate)
|
177 |
+
|
178 |
+
# If any of the sub-tasks have subsegment_ids, we need to ensure all the tasks have
|
179 |
+
# a subsegment_ids field so they can be mixed
|
180 |
+
if any("subsegment_ids" in ds.element_spec for ds in datasets):
|
181 |
+
for ix, ds in enumerate(datasets):
|
182 |
+
if "subsegment_ids" not in ds.element_spec:
|
183 |
+
datasets[ix] = add_segment_ids(ds)
|
184 |
+
|
185 |
+
ds = tf.data.Dataset.sample_from_datasets(
|
186 |
+
datasets, rates, seed=mix_seed, stop_on_empty_dataset=False)
|
187 |
+
else:
|
188 |
+
logging.info(
|
189 |
+
f"Initializing dataset {mixture_or_task.name}: replica_batch_size={batch_size}"
|
190 |
+
f' seed={seed}, sharding={shard_info.index}/{shard_info.num_shards}'
|
191 |
+
)
|
192 |
+
ds = mixture_or_task.get_dataset(
|
193 |
+
task_feature_lengths_dict,
|
194 |
+
split=self.split,
|
195 |
+
shuffle=self.shuffle,
|
196 |
+
seed=seed,
|
197 |
+
shard_info=shard_info,
|
198 |
+
num_epochs=self.num_epochs,
|
199 |
+
try_in_mem_cache=self.use_memory_cache,
|
200 |
+
trim_output_features=self.trim_output_features
|
201 |
+
)
|
202 |
+
data_iter = preprocessor.get_post_mixing_preprocessor()(
|
203 |
+
ds, task_feature_lengths=task_feature_lengths_dict)
|
204 |
+
data_iter = data_iter.batch(batch_size, drop_remainder=True, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
205 |
+
ds = ds.prefetch(2)
|
206 |
+
|
207 |
+
# Following https://github.com/google-research/big_vision/blob/b8dab6e4de3436849415f37c591399c93b1eaf39/big_vision/input_pipeline.py#L228
|
208 |
+
# These options try to stop tf datasets from eating all our RAM if we are using a
|
209 |
+
# large mixture
|
210 |
+
# This options are used by default in some google codebases
|
211 |
+
# For example: (https://github.com/google-research/big_vision/blob/b8dab6e4de3436849415f37c591399c93b1eaf39/big_vision/input_pipeline.py#L228)
|
212 |
+
# They don't seem to harm throughput and can save RAM so we use them as well
|
213 |
+
options = tf.data.Options()
|
214 |
+
options.experimental_optimization.inject_prefetch = False
|
215 |
+
options.threading.max_intra_op_parallelism = 1
|
216 |
+
if self.disable_autotune:
|
217 |
+
# Following https://www.tensorflow.org/datasets/performances
|
218 |
+
# This reduces RAM and checkpoint size by a lot
|
219 |
+
options.autotune.enabled = False
|
220 |
+
data_iter = data_iter.with_options(options)
|
221 |
+
|
222 |
+
return data_iter
|
data_utils.py
ADDED
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import dataclasses
|
3 |
+
import functools
|
4 |
+
import os
|
5 |
+
from os import environ
|
6 |
+
from typing import Mapping, Optional, Sequence, List
|
7 |
+
from absl import logging
|
8 |
+
import clu
|
9 |
+
import gin
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import seqio
|
13 |
+
from seqio import utils
|
14 |
+
from seqio.feature_converters import _check_exact_match, _check_lengths
|
15 |
+
|
16 |
+
import tensorflow as tf
|
17 |
+
from tensorflow.python.ops import control_flow_ops
|
18 |
+
from tensorflow.python.ops.image_ops_impl import _ImageDimensions, _CheckAtLeast3DImage, _assert, _is_tensor
|
19 |
+
|
20 |
+
from tensorflow.python.framework import ops
|
21 |
+
from tensorflow.python.ops import array_ops
|
22 |
+
from transformers import PreTrainedTokenizerFast
|
23 |
+
|
24 |
+
from . import seqio_tokenizer as vocab
|
25 |
+
from .constants import *
|
26 |
+
from .utils import pop_metadata
|
27 |
+
from .util import is_url
|
28 |
+
|
29 |
+
DEFAULT_EXTRA_IDS = 0
|
30 |
+
OutputFeaturesType = Mapping[str, utils.Feature]
|
31 |
+
|
32 |
+
|
33 |
+
def build_tokenizer(
|
34 |
+
tokenizer_type, has_extra_token=True,
|
35 |
+
adds_space=False,
|
36 |
+
olmo_bos_token_id=1, olmo_eos_token_id=2,
|
37 |
+
tokenizer_dir="gs://mm-olmo/tokenizer",
|
38 |
+
pad_tokenizer_to=None, cache={},
|
39 |
+
):
|
40 |
+
cache_key = (tokenizer_type, has_extra_token, adds_space, olmo_bos_token_id,
|
41 |
+
olmo_eos_token_id, pad_tokenizer_to)
|
42 |
+
if cache_key in cache:
|
43 |
+
return cache[cache_key]
|
44 |
+
|
45 |
+
if tokenizer_type == 'llama':
|
46 |
+
tok = vocab.SentencePieceVocabulary(
|
47 |
+
os.path.join(tokenizer_dir, "llama_tokenizer.model"),
|
48 |
+
extra_ids=DEFAULT_EXTRA_IDS,
|
49 |
+
reverse_extra_ids=True,
|
50 |
+
extra_tokens=EXTRA_TOKENS if has_extra_token else None,
|
51 |
+
)
|
52 |
+
elif tokenizer_type == 'yi':
|
53 |
+
tok = vocab.SentencePieceVocabulary(
|
54 |
+
os.path.join(tokenizer_dir, "yi_tokenizer.model"),
|
55 |
+
extra_ids=DEFAULT_EXTRA_IDS,
|
56 |
+
reverse_extra_ids=True,
|
57 |
+
extra_tokens=EXTRA_TOKENS if has_extra_token else None,
|
58 |
+
)
|
59 |
+
elif tokenizer_type == 'mistral':
|
60 |
+
tok = vocab.SentencePieceVocabulary(
|
61 |
+
os.path.join(tokenizer_dir, "mistral_tokenizer.model"),
|
62 |
+
extra_ids=DEFAULT_EXTRA_IDS,
|
63 |
+
reverse_extra_ids=True,
|
64 |
+
extra_tokens=EXTRA_TOKENS if has_extra_token else None,
|
65 |
+
)
|
66 |
+
|
67 |
+
elif tokenizer_type == "mistral0.3":
|
68 |
+
tok = vocab.SentencePieceVocabulary(
|
69 |
+
os.path.join(tokenizer_dir, "mistral0.3_tokenizer.model.v3"),
|
70 |
+
extra_ids=DEFAULT_EXTRA_IDS,
|
71 |
+
reverse_extra_ids=True,
|
72 |
+
extra_tokens=EXTRA_TOKENS if has_extra_token else None,
|
73 |
+
)
|
74 |
+
elif tokenizer_type == 'gemma':
|
75 |
+
tok = vocab.SentencePieceVocabulary(
|
76 |
+
os.path.join(tokenizer_dir, "gemma_tokenizer.model"),
|
77 |
+
extra_ids=DEFAULT_EXTRA_IDS,
|
78 |
+
reverse_extra_ids=True,
|
79 |
+
extra_tokens=EXTRA_TOKENS if has_extra_token else None,
|
80 |
+
)
|
81 |
+
elif tokenizer_type.startswith("hf-"):
|
82 |
+
# FIXME When using the beaker image "sanghol/mm-olmo" for hosting endpoints,
|
83 |
+
# we should set the cache_dir, otherwise FileNotFound errors will be raised
|
84 |
+
cache_dir = None if tokenizer_dir is None or is_url(tokenizer_dir) else tokenizer_dir
|
85 |
+
from transformers import AutoTokenizer
|
86 |
+
|
87 |
+
extra_tokens = list(EXTRA_TOKENS)
|
88 |
+
if pad_tokenizer_to is not None:
|
89 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_type[3:], token=environ.get("HF_ACCESS_TOKEN"), cache_dir=cache_dir)
|
90 |
+
n_extra_tokens = pad_tokenizer_to - len(tokenizer)
|
91 |
+
# This handles a case where the LLM embedding matrix is larger than the vocab size
|
92 |
+
# We need the extra tokens in `EXTRA_TOKENS` to be assigned id's higher than the embedding
|
93 |
+
# matrix size, not the vocab size, since we will concat the embedding and matrix with
|
94 |
+
# the special token embedding matrix, so we pad the vocab with additional special tokens
|
95 |
+
if n_extra_tokens > 0:
|
96 |
+
logging.info(f"Padding tokenizer with {n_extra_tokens} tokens")
|
97 |
+
extra_tokens = [f"|<EXTRA_TOKENS_{i}>|" for i in range(n_extra_tokens)] + extra_tokens
|
98 |
+
|
99 |
+
bos_token_id = None
|
100 |
+
|
101 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
102 |
+
tokenizer_type[3:], additional_special_tokens=extra_tokens,
|
103 |
+
token=environ.get("HF_ACCESS_TOKEN"),
|
104 |
+
cache_dir=cache_dir,
|
105 |
+
)
|
106 |
+
if ("qwen2" in tokenizer_type.lower()) or ("olmo" in tokenizer_type.lower()):
|
107 |
+
# These tokenizers do not have a BOS, and instead use EOS as a generic seperator token.
|
108 |
+
# In this case we will use EOS as BOS
|
109 |
+
assert tokenizer.bos_token_id is None
|
110 |
+
bos_token_id = tokenizer.eos_token_id
|
111 |
+
|
112 |
+
if pad_tokenizer_to is not None:
|
113 |
+
for ix, tok in enumerate(EXTRA_TOKENS):
|
114 |
+
ids = tokenizer.encode(tok, add_special_tokens=False)
|
115 |
+
assert ids == [pad_tokenizer_to + ix]
|
116 |
+
|
117 |
+
tok = vocab.HfTokenizerWrapper(tokenizer, bos_token_id=bos_token_id, adds_space=adds_space)
|
118 |
+
elif tokenizer_type.startswith("olmo-"):
|
119 |
+
from olmo.tokenizer import Tokenizer
|
120 |
+
assert Path(tokenizer_type[5:]).is_file()
|
121 |
+
tokenizer = Tokenizer.from_file(
|
122 |
+
tokenizer_type[5:],
|
123 |
+
eos_token_id=olmo_eos_token_id,
|
124 |
+
pad_token_id=-1,
|
125 |
+
)
|
126 |
+
tok = vocab.OLMoTokenizerWrapper(tokenizer, bos_token_id=olmo_bos_token_id, adds_space=adds_space)
|
127 |
+
else:
|
128 |
+
raise NotImplementedError(tokenizer_type)
|
129 |
+
cache[cache_key] = tok
|
130 |
+
return tok
|
131 |
+
|
132 |
+
|
133 |
+
def get_special_token_ids(tokenizer):
|
134 |
+
if isinstance(tokenizer, (vocab.HfTokenizerWrapper, vocab.OLMoTokenizerWrapper)):
|
135 |
+
ids = tokenizer.encode("".join(EXTRA_TOKENS))
|
136 |
+
if len(ids) == len(EXTRA_TOKENS) + 1:
|
137 |
+
ids = ids[1:]
|
138 |
+
elif ("gemma_tokenizer" in tokenizer._sentencepiece_model_file or
|
139 |
+
"yi_tokenizer" in tokenizer._sentencepiece_model_file
|
140 |
+
):
|
141 |
+
# Not sure why ATM, but the LLaMa tokenizer will add an extra space token
|
142 |
+
# if this string starts with a space, while the gemma one needs the leading space
|
143 |
+
ids = tokenizer.encode(" " + " ".join(EXTRA_TOKENS))
|
144 |
+
else:
|
145 |
+
ids = tokenizer.encode(" ".join(EXTRA_TOKENS))
|
146 |
+
|
147 |
+
assert len(ids) == len(EXTRA_TOKENS)
|
148 |
+
return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
|
149 |
+
|
150 |
+
|
151 |
+
def _append_to_innermost_axis(
|
152 |
+
tensor: tf.Tensor, scalar: tf.Tensor,
|
153 |
+
) -> tf.Tensor:
|
154 |
+
"""Appends `scalar` to each slice in the innermost axis of `tensor`.
|
155 |
+
|
156 |
+
>>> _append_to_innermost_axis([1, 2, 3], -1)
|
157 |
+
[1, 2, 3, -1]
|
158 |
+
>>> _append_to_innermost_axis([[1, 2], [3, 4]], -1)
|
159 |
+
[[1, 2, -1], [3, 4, -1]]
|
160 |
+
>>> _append_to_innermost_axis(tf.ragged.constant([[1, 2], [3]]), -1)
|
161 |
+
[[1, 2, -1], [3, -1]]
|
162 |
+
|
163 |
+
Args:
|
164 |
+
tensor: The tensor that should have a value appended.
|
165 |
+
scalar: The value to append.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
A copy of `tensor` with `scalar` appended to each slice along
|
169 |
+
the innermost axis.
|
170 |
+
"""
|
171 |
+
if isinstance(tensor, tf.RaggedTensor):
|
172 |
+
if tensor.shape.rank > 2:
|
173 |
+
return tensor.with_values(
|
174 |
+
_append_to_innermost_axis(tensor.values, scalar)
|
175 |
+
)
|
176 |
+
else:
|
177 |
+
return tf.concat([tensor, tf.fill([tensor.nrows(), 1], scalar)], axis=1)
|
178 |
+
else:
|
179 |
+
ndims = tf.rank(tensor)
|
180 |
+
paddings = tf.concat(
|
181 |
+
[tf.zeros((ndims - 1, 2), dtype=tf.int32), tf.constant([[0, 1]])],
|
182 |
+
axis=0,
|
183 |
+
)
|
184 |
+
return tf.pad(tensor, paddings=paddings, constant_values=scalar)
|
185 |
+
|
186 |
+
|
187 |
+
def _shift_right_by_one(tensor: tf.Tensor, bos_id: int = 0) -> tf.Tensor:
|
188 |
+
"""Shift the input tensor to the right by one position without wrapping."""
|
189 |
+
|
190 |
+
if not (tensor.dtype.is_integer or tensor.dtype.is_floating):
|
191 |
+
raise ValueError(f"Only numeric types are supported. Got: {tensor.dtype}")
|
192 |
+
# tf.roll wraps around the axis.
|
193 |
+
rolled = tf.roll(tensor, shift=1, axis=0)
|
194 |
+
|
195 |
+
# Zero out the first position by multiplying with [0, 1, 1, ..., 1].
|
196 |
+
depth = tf.shape(tensor)[0]
|
197 |
+
mask = tf.one_hot(0, depth=depth, on_value=0, off_value=1, dtype=tensor.dtype)
|
198 |
+
|
199 |
+
# Expand dims of mask to broadcast to rolled.
|
200 |
+
dim_expansion = [slice(None, None)] + [None] * (len(rolled.shape) - 1)
|
201 |
+
mask = mask[dim_expansion]
|
202 |
+
return rolled * mask + (1 - mask) * bos_id
|
203 |
+
|
204 |
+
|
205 |
+
def make_autoregressive_inputs(
|
206 |
+
targets: tf.Tensor,
|
207 |
+
sequence_id: tf.Tensor = None,
|
208 |
+
output_dtype: Optional[tf.dtypes.DType] = None,
|
209 |
+
bos_id: int = 0,
|
210 |
+
) -> tf.Tensor:
|
211 |
+
"""Generate inputs for an autoregressive model, by shifting the targets.
|
212 |
+
|
213 |
+
Modified from mesh_tensorflow.transformer.transformer.autoregressive_inputs.
|
214 |
+
|
215 |
+
For the first element of each sequence, the returned input id is 0.
|
216 |
+
|
217 |
+
For a "packed" dataset, also pass the sequence_id tensor, which aligns
|
218 |
+
with the targets tensor and contains different values for different
|
219 |
+
concatenated examples.
|
220 |
+
|
221 |
+
Example for a packed dataset:
|
222 |
+
|
223 |
+
```
|
224 |
+
targets = [3, 8, 2, 9, 2, 5, 4, 2, -1, -1]
|
225 |
+
sequence_id = [1, 1, 1, 2, 2, 3, 3, 3, 0, 0]
|
226 |
+
inputs = [1, 3, 8, 1, 9, 1, 5, 4, -1, -1]
|
227 |
+
| | |
|
228 |
+
These positions are set to 0 if sequence_id is not
|
229 |
+
None.
|
230 |
+
```
|
231 |
+
|
232 |
+
Args:
|
233 |
+
targets: a tf.int32 tensor with shape [length].
|
234 |
+
sequence_id: an optional tensor with the same shape as targets.
|
235 |
+
output_dtype: an optional output data type.
|
236 |
+
bos_id: bos id.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
a tensor with dtype tf.int32 and the same shape as targets.
|
240 |
+
"""
|
241 |
+
output_dtype = output_dtype or targets.dtype
|
242 |
+
if sequence_id is not None and not sequence_id.dtype.is_integer:
|
243 |
+
raise ValueError(
|
244 |
+
"The sequence_id should be integer-valued tensors for a packed dataset."
|
245 |
+
)
|
246 |
+
if sequence_id is not None and len(targets.shape) > 1:
|
247 |
+
raise ValueError(
|
248 |
+
"Only 1-D sequences are supported with packing. Got a "
|
249 |
+
f"packed {len(targets.shape)}-D sequence."
|
250 |
+
)
|
251 |
+
|
252 |
+
inputs = _shift_right_by_one(targets, bos_id)
|
253 |
+
if inputs.dtype != output_dtype:
|
254 |
+
inputs = tf.cast(inputs, output_dtype)
|
255 |
+
|
256 |
+
# We should have a 0 at the beginning of each sequence rather than the
|
257 |
+
# shifted EOS (e.g. 1) from the previous sequence.
|
258 |
+
if sequence_id is not None:
|
259 |
+
not_first_in_sequence = tf.equal(
|
260 |
+
sequence_id, _shift_right_by_one(sequence_id)
|
261 |
+
)
|
262 |
+
not_first_in_sequence = tf.cast(not_first_in_sequence, output_dtype)
|
263 |
+
first_ids = tf.cast((1 - not_first_in_sequence) * bos_id, output_dtype)
|
264 |
+
inputs = inputs * not_first_in_sequence + first_ids
|
265 |
+
return inputs
|
266 |
+
|
267 |
+
|
268 |
+
@tf.function
|
269 |
+
def sum_except_first_axis(tensor):
|
270 |
+
# Compute the sum along all axes except the first
|
271 |
+
axes_to_sum = tuple(range(1, len(tensor.shape)))
|
272 |
+
return tf.reduce_sum(tensor, axis=axes_to_sum)
|
273 |
+
|
274 |
+
|
275 |
+
@seqio.map_over_dataset()
|
276 |
+
def add_segment_ids(ex):
|
277 |
+
ex["subsegment_ids"] = tf.zeros_like(ex["target_tokens"], dtype=tf.int32)
|
278 |
+
return ex
|
279 |
+
|
280 |
+
|
281 |
+
def trim_and_pad_dataset(
|
282 |
+
dataset: tf.data.Dataset, feature_lengths: Mapping[str, int]
|
283 |
+
) -> tf.data.Dataset:
|
284 |
+
"""Trim and pad first dimension of features to `feature_lengths`.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
dataset: tf.data.Dataset, the dataset to trim/pad examples in.
|
288 |
+
feature_lengths: map from feature key to final length. Other features will
|
289 |
+
be returned unchanged.
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
Trimmed/padded tf.data.Dataset.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def _trim_and_pad(k: str, t: tf.Tensor) -> tf.Tensor:
|
296 |
+
"""Trim/pad to the first axis of `t` to be of size `length`."""
|
297 |
+
if k not in feature_lengths:
|
298 |
+
return t
|
299 |
+
if isinstance(t, tf.RaggedTensor):
|
300 |
+
t = t.to_tensor()
|
301 |
+
|
302 |
+
constant_values = -1
|
303 |
+
length_k = feature_lengths[k]
|
304 |
+
if isinstance(length_k, int):
|
305 |
+
t = t[:length_k]
|
306 |
+
pad_amt = length_k - tf.shape(t)[0]
|
307 |
+
padded_t = tf.pad(t, [(0, pad_amt)] + [(0, 0)] * (len(t.shape) - 1), constant_values=constant_values)
|
308 |
+
padded_t.set_shape([length_k] + t.shape.as_list()[1:])
|
309 |
+
return padded_t
|
310 |
+
|
311 |
+
slices = tuple((slice(0, limit) for limit in length_k))
|
312 |
+
t = t[slices]
|
313 |
+
pad_amt = tf.pad((length_k - tf.shape(t))[..., None], ((0, 0), (1, 0)), constant_values=constant_values)
|
314 |
+
padded_t = tf.pad(t, pad_amt, constant_values=constant_values)
|
315 |
+
padded_t.set_shape(length_k)
|
316 |
+
return padded_t
|
317 |
+
|
318 |
+
return dataset.map(
|
319 |
+
lambda x: {k: _trim_and_pad(k, t) for k, t in x.items()},
|
320 |
+
num_parallel_calls=tf.data.experimental.AUTOTUNE,
|
321 |
+
)
|
322 |
+
|
323 |
+
|
324 |
+
def get_3d_subsegments(segmented_suffix):
|
325 |
+
q_lens, text_lens = segmented_suffix.nested_row_lengths()
|
326 |
+
text_segments = tf.range(0, tf.shape(text_lens)[0], dtype=tf.int32)
|
327 |
+
question_repeat = tf.reshape(tf.stack([tf.ones_like(q_lens), q_lens-1], 1), [-1])
|
328 |
+
question_offset = tf.range(1, tf.shape(q_lens)[0]+1, dtype=tf.int32)*200
|
329 |
+
question_offset = tf.reshape(tf.stack([question_offset, question_offset-100], 1), [-1])
|
330 |
+
text_segments = text_segments + tf.repeat(question_offset, question_repeat)
|
331 |
+
segment_ids = tf.cast(tf.repeat(text_segments, text_lens), tf.int32)
|
332 |
+
return segment_ids
|
333 |
+
|
334 |
+
|
335 |
+
def assert_not_truncated(ds, keys, max_val):
|
336 |
+
def _check(ex):
|
337 |
+
for k in keys:
|
338 |
+
tf.assert_less(tf.shape(ex[k])[0], max_val+1,
|
339 |
+
message=f"Field {k} was unexpectedly truncated max_len={max_val}")
|
340 |
+
return ex
|
341 |
+
return ds.map(_check)
|
342 |
+
|
343 |
+
|
344 |
+
def apply_with_random_selector(x, func, num_cases):
|
345 |
+
"""Computes func(x, sel), with sel sampled from [0...num_cases-1].
|
346 |
+
Args:
|
347 |
+
x: input Tensor.
|
348 |
+
func: Python function to apply.
|
349 |
+
num_cases: Python int32, number of cases to sample sel from.
|
350 |
+
Returns:
|
351 |
+
The result of func(x, sel), where func receives the value of the
|
352 |
+
selector as a python integer, but sel is sampled dynamically.
|
353 |
+
"""
|
354 |
+
sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
|
355 |
+
# Pass the real x only to one of the func calls.
|
356 |
+
return control_flow_ops.merge([
|
357 |
+
func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
|
358 |
+
for case in range(num_cases)])[0]
|
359 |
+
|
360 |
+
|
361 |
+
def denormalize_boxes(boxes, image_shape):
|
362 |
+
"""Converts boxes normalized by [height, width] to pixel coordinates.
|
363 |
+
Args:
|
364 |
+
boxes: a tensor whose last dimension is 4 representing the coordinates of
|
365 |
+
boxes in ymin, xmin, ymax, xmax order.
|
366 |
+
image_shape: a list of two integers, a two-element vector or a tensor such
|
367 |
+
that all but the last dimensions are `broadcastable` to `boxes`. The last
|
368 |
+
dimension is 2, which represents [height, width].
|
369 |
+
Returns:
|
370 |
+
denormalized_boxes: a tensor whose shape is the same as `boxes` representing
|
371 |
+
the denormalized boxes.
|
372 |
+
Raises:
|
373 |
+
ValueError: If the last dimension of boxes is not 4.
|
374 |
+
"""
|
375 |
+
with tf.name_scope('denormalize_boxes'):
|
376 |
+
if isinstance(image_shape, list) or isinstance(image_shape, tuple):
|
377 |
+
height, width = image_shape
|
378 |
+
height = tf.cast(height, dtype=boxes.dtype)
|
379 |
+
width = tf.cast(width, dtype=boxes.dtype)
|
380 |
+
else:
|
381 |
+
image_shape = tf.cast(image_shape, dtype=boxes.dtype)
|
382 |
+
height, width = tf.split(image_shape, 2, axis=-1)
|
383 |
+
|
384 |
+
ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1)
|
385 |
+
ymin = ymin * height
|
386 |
+
xmin = xmin * width
|
387 |
+
ymax = ymax * height
|
388 |
+
xmax = xmax * width
|
389 |
+
|
390 |
+
denormalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
|
391 |
+
return denormalized_boxes
|
392 |
+
|
393 |
+
def pad_to_bounding_box(image, offset_height, offset_width, target_height,
|
394 |
+
target_width, value=0):
|
395 |
+
|
396 |
+
return pad_to_bounding_box_internal(
|
397 |
+
image,
|
398 |
+
offset_height,
|
399 |
+
offset_width,
|
400 |
+
target_height,
|
401 |
+
target_width,
|
402 |
+
check_dims=True,
|
403 |
+
value=value)
|
404 |
+
|
405 |
+
def pad_to_bounding_box_internal(image, offset_height, offset_width,
|
406 |
+
target_height, target_width, check_dims, value):
|
407 |
+
|
408 |
+
with ops.name_scope(None, 'pad_to_bounding_box_with_one_internal', [image]):
|
409 |
+
image = ops.convert_to_tensor(image, name='image')
|
410 |
+
|
411 |
+
is_batch = True
|
412 |
+
image_shape = image.get_shape()
|
413 |
+
if image_shape.ndims == 3:
|
414 |
+
is_batch = False
|
415 |
+
image = array_ops.expand_dims(image, 0)
|
416 |
+
elif image_shape.ndims is None:
|
417 |
+
is_batch = False
|
418 |
+
image = array_ops.expand_dims(image, 0)
|
419 |
+
image.set_shape([None] * 4)
|
420 |
+
elif image_shape.ndims != 4:
|
421 |
+
raise ValueError(
|
422 |
+
'\'image\' (shape %s) must have either 3 or 4 dimensions.' %
|
423 |
+
image_shape)
|
424 |
+
|
425 |
+
batch, height, width, depth = _ImageDimensions(image, rank=4)
|
426 |
+
|
427 |
+
after_padding_width = target_width - offset_width - width
|
428 |
+
|
429 |
+
after_padding_height = target_height - offset_height - height
|
430 |
+
|
431 |
+
if check_dims:
|
432 |
+
assert_ops = _CheckAtLeast3DImage(image, require_static=False)
|
433 |
+
assert_ops += _assert(offset_height >= 0, ValueError,
|
434 |
+
'offset_height must be >= 0')
|
435 |
+
assert_ops += _assert(offset_width >= 0, ValueError,
|
436 |
+
'offset_width must be >= 0')
|
437 |
+
assert_ops += _assert(after_padding_width >= 0, ValueError,
|
438 |
+
'width must be <= target - offset')
|
439 |
+
assert_ops += _assert(after_padding_height >= 0, ValueError,
|
440 |
+
'height must be <= target - offset')
|
441 |
+
image = control_flow_ops.with_dependencies(assert_ops, image)
|
442 |
+
|
443 |
+
# Do not pad on the depth dimensions.
|
444 |
+
paddings = array_ops.reshape(
|
445 |
+
tf.stack([
|
446 |
+
0, 0, offset_height, after_padding_height, offset_width,
|
447 |
+
after_padding_width, 0, 0
|
448 |
+
]), [4, 2])
|
449 |
+
padded = array_ops.pad(image, paddings, constant_values=value)
|
450 |
+
|
451 |
+
padded_shape = [
|
452 |
+
None if _is_tensor(i) else i
|
453 |
+
for i in [batch, target_height, target_width, depth]
|
454 |
+
]
|
455 |
+
padded.set_shape(padded_shape)
|
456 |
+
|
457 |
+
if not is_batch:
|
458 |
+
padded = array_ops.squeeze(padded, axis=[0])
|
459 |
+
|
460 |
+
return padded
|
461 |
+
|
462 |
+
def resize_and_crop_boxes(boxes, image_scale, output_size, offset, paddings):
|
463 |
+
"""Resizes boxes to output size with scale and offset.
|
464 |
+
Args:
|
465 |
+
boxes: `Tensor` of shape [N, 4] representing ground truth boxes.
|
466 |
+
image_scale: 2D float `Tensor` representing scale factors that apply to
|
467 |
+
[height, width] of input image.
|
468 |
+
output_size: 2D `Tensor` or `int` representing [height, width] of target
|
469 |
+
output image size.
|
470 |
+
offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
|
471 |
+
boxes.
|
472 |
+
paddings: 2D `Tensor` representing top/left paddings.
|
473 |
+
Returns:
|
474 |
+
boxes: `Tensor` of shape [N, 4] representing the scaled boxes.
|
475 |
+
"""
|
476 |
+
# Adjusts box coordinates based on image_scale, offset and paddings.
|
477 |
+
boxes *= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
|
478 |
+
boxes -= tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
|
479 |
+
boxes += tf.tile(tf.expand_dims(paddings, axis=0), [1, 2])
|
480 |
+
# Clips the boxes.
|
481 |
+
boxes = clip_boxes(boxes, output_size)
|
482 |
+
return boxes
|
483 |
+
|
484 |
+
def clip_boxes(boxes, image_shape):
|
485 |
+
"""Clips boxes to image boundaries.
|
486 |
+
Args:
|
487 |
+
boxes: a tensor whose last dimension is 4 representing the coordinates of
|
488 |
+
boxes in ymin, xmin, ymax, xmax order.
|
489 |
+
image_shape: a list of two integers, a two-element vector or a tensor such
|
490 |
+
that all but the last dimensions are `broadcastable` to `boxes`. The last
|
491 |
+
dimension is 2, which represents [height, width].
|
492 |
+
Returns:
|
493 |
+
clipped_boxes: a tensor whose shape is the same as `boxes` representing the
|
494 |
+
clipped boxes.
|
495 |
+
Raises:
|
496 |
+
ValueError: If the last dimension of boxes is not 4.
|
497 |
+
"""
|
498 |
+
if boxes.shape[-1] != 4:
|
499 |
+
raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
|
500 |
+
boxes.shape[-1]))
|
501 |
+
|
502 |
+
with tf.name_scope('clip_boxes'):
|
503 |
+
if isinstance(image_shape, list) or isinstance(image_shape, tuple):
|
504 |
+
height, width = image_shape
|
505 |
+
max_length = [height, width, height, width]
|
506 |
+
else:
|
507 |
+
image_shape = tf.cast(image_shape, dtype=boxes.dtype)
|
508 |
+
height, width = tf.unstack(image_shape, axis=-1)
|
509 |
+
max_length = tf.stack(
|
510 |
+
[height, width, height, width], axis=-1)
|
511 |
+
|
512 |
+
clipped_boxes = tf.math.maximum(tf.math.minimum(boxes, max_length), 0.0)
|
513 |
+
return clipped_boxes
|
514 |
+
|
515 |
+
|
516 |
+
def get_non_empty_box_indices(boxes):
|
517 |
+
"""Get indices for non-empty boxes."""
|
518 |
+
# Selects indices if box height or width is 0.
|
519 |
+
height = boxes[:, 2] - boxes[:, 0]
|
520 |
+
width = boxes[:, 3] - boxes[:, 1]
|
521 |
+
indices = tf.where(
|
522 |
+
tf.logical_and(tf.greater(height, 0), tf.greater(width, 0)))
|
523 |
+
return indices[:, 0]
|
524 |
+
|
525 |
+
|
526 |
+
def resize_and_pad(image, desired_output_size, masks=None, boxes=None, labels=None,
|
527 |
+
random_scale_min=0.1, random_scale_max=2.0, do_random_scale=False,
|
528 |
+
shrink_both_sides=True, boxes1=None, filter_box=True,
|
529 |
+
desired_target_size=None, random_scale_ratio=0.0,
|
530 |
+
resize_method=tf.image.ResizeMethod.BILINEAR, return_outputs=True,
|
531 |
+
pad_value=0, normalize=True):
|
532 |
+
desired_height, desired_width = desired_output_size
|
533 |
+
desired_height_f = tf.cast(desired_height, dtype=tf.float32)
|
534 |
+
desired_width_f = tf.cast(desired_width, dtype=tf.float32)
|
535 |
+
|
536 |
+
height = tf.cast(tf.shape(image)[0], tf.float32)
|
537 |
+
width = tf.cast(tf.shape(image)[1], tf.float32)
|
538 |
+
|
539 |
+
if boxes is not None:
|
540 |
+
# Converts boxes from normalized coordinates to pixel coordinates.
|
541 |
+
# Now the coordinates of boxes are w.r.t. the original image.
|
542 |
+
boxes = denormalize_boxes(boxes, [height, width])
|
543 |
+
|
544 |
+
if boxes1 is not None:
|
545 |
+
boxes1 = denormalize_boxes(boxes1, [height, width])
|
546 |
+
|
547 |
+
if do_random_scale:
|
548 |
+
random_scale_factor = tf.random.uniform([], random_scale_min, random_scale_max)
|
549 |
+
if not shrink_both_sides:
|
550 |
+
# Max random is where scale * W > W_desired
|
551 |
+
# scale * H > H_desired
|
552 |
+
rsf_max = tf.maximum(desired_width_f / width, desired_height_f / height)
|
553 |
+
random_scale_factor = tf.minimum(rsf_max, random_scale_factor)
|
554 |
+
|
555 |
+
scaled_y = tf.cast(random_scale_factor * desired_height_f, tf.int32)
|
556 |
+
scaled_x = tf.cast(random_scale_factor * desired_width_f, tf.int32)
|
557 |
+
|
558 |
+
# Recompute the accurate scale_factor using rounded scaled image size.
|
559 |
+
image_scale_y = tf.cast(scaled_y, tf.float32) / height
|
560 |
+
image_scale_x = tf.cast(scaled_x, tf.float32) / width
|
561 |
+
|
562 |
+
image_scale = tf.cond(tf.less(
|
563 |
+
tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
|
564 |
+
tf.cast(random_scale_ratio, tf.float32)),
|
565 |
+
lambda: tf.maximum(image_scale_x, image_scale_y),
|
566 |
+
lambda: tf.minimum(image_scale_x, image_scale_y))
|
567 |
+
|
568 |
+
# image_scale = tf.minimum(image_scale_x, image_scale_y)
|
569 |
+
|
570 |
+
# Conceptual captions has some REALLY WIDE images I believe
|
571 |
+
# this ensures that we won't scale any side lower than to 64
|
572 |
+
image_scale = tf.maximum(image_scale, 64.0 / tf.minimum(height, width))
|
573 |
+
|
574 |
+
# Select non-zero random offset (x, y) if scaled image is larger than
|
575 |
+
# self._output_size.
|
576 |
+
scaled_height = tf.cast(height * image_scale, tf.int32)
|
577 |
+
scaled_width = tf.cast(width * image_scale, tf.int32)
|
578 |
+
offset_y = tf.cast(scaled_height - desired_height, tf.float32)
|
579 |
+
offset_x = tf.cast(scaled_width - desired_width, tf.float32)
|
580 |
+
offset_y = tf.maximum(0.0, offset_y) * tf.random.uniform([], 0, 1)
|
581 |
+
offset_x = tf.maximum(0.0, offset_x) * tf.random.uniform([], 0, 1)
|
582 |
+
offset_y = tf.cast(offset_y, tf.int32)
|
583 |
+
offset_x = tf.cast(offset_x, tf.int32)
|
584 |
+
else:
|
585 |
+
image_scale_y = desired_height_f / height
|
586 |
+
image_scale_x = desired_width_f / width
|
587 |
+
image_scale = tf.minimum(image_scale_x, image_scale_y)
|
588 |
+
scaled_height = tf.cast(height * image_scale, tf.int32)
|
589 |
+
scaled_width = tf.cast(width * image_scale, tf.int32)
|
590 |
+
offset_y = tf.constant(0)
|
591 |
+
offset_x = tf.constant(0)
|
592 |
+
|
593 |
+
# Now resize and crop
|
594 |
+
if resize_method == 'random' and do_random_scale:
|
595 |
+
resize_methods = sorted([k for k in tf.image.ResizeMethod.__dict__.keys() if k.isupper()])
|
596 |
+
image = apply_with_random_selector(
|
597 |
+
image,
|
598 |
+
lambda x, method_idx: tf.image.resize(x, [scaled_height, scaled_width],
|
599 |
+
tf.image.ResizeMethod.__dict__[resize_methods[method_idx]],
|
600 |
+
antialias=True),
|
601 |
+
num_cases=len(resize_methods))
|
602 |
+
|
603 |
+
elif resize_method != 'random':
|
604 |
+
image = tf.image.resize(image, [scaled_height, scaled_width], method=resize_method, antialias=True)
|
605 |
+
else:
|
606 |
+
image = tf.image.resize(image, [scaled_height, scaled_width],
|
607 |
+
method=tf.image.ResizeMethod.BILINEAR, antialias=True)
|
608 |
+
|
609 |
+
image = tf.clip_by_value(image, 0.0, 1.0)
|
610 |
+
|
611 |
+
# H x W x C
|
612 |
+
image = image[offset_y:offset_y + desired_height, offset_x:offset_x + desired_width, :]
|
613 |
+
|
614 |
+
H = tf.shape(image)[0]
|
615 |
+
W = tf.shape(image)[1]
|
616 |
+
|
617 |
+
top_pad = (desired_height - H) // 2
|
618 |
+
left_pad = (desired_width - W) // 2
|
619 |
+
|
620 |
+
image_mask = pad_to_bounding_box(
|
621 |
+
tf.ones_like(image, dtype=tf.bool), top_pad, left_pad, desired_height, desired_width)[:,:,0]
|
622 |
+
|
623 |
+
image = pad_to_bounding_box(image, top_pad, left_pad, desired_height, desired_width, value=pad_value)
|
624 |
+
|
625 |
+
if isinstance(desired_height, int) and isinstance(desired_width, int):
|
626 |
+
image.set_shape([desired_height, desired_width, 3])
|
627 |
+
|
628 |
+
if masks is not None and tf.size(masks) != 0:
|
629 |
+
masks = tf.image.resize(masks, [scaled_height, scaled_width],
|
630 |
+
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
631 |
+
|
632 |
+
if len(masks.shape) == 3:
|
633 |
+
masks = masks[offset_y:offset_y + desired_height, offset_x:offset_x + desired_width]
|
634 |
+
else:
|
635 |
+
masks = masks[:, offset_y:offset_y + desired_height, offset_x:offset_x + desired_width]
|
636 |
+
|
637 |
+
masks = pad_to_bounding_box(masks, top_pad, left_pad, desired_height, desired_width)
|
638 |
+
masks = tf.image.resize(masks, desired_target_size,
|
639 |
+
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
640 |
+
|
641 |
+
indices = None
|
642 |
+
if boxes is not None:
|
643 |
+
# assert ValueError("the box need to be shift which is not tested yet.")
|
644 |
+
boxes = resize_and_crop_boxes(
|
645 |
+
boxes,
|
646 |
+
tf.stack([image_scale, image_scale]),
|
647 |
+
[desired_height, desired_width],
|
648 |
+
tf.cast(tf.stack([offset_y, offset_x]), dtype=tf.float32),
|
649 |
+
tf.cast(tf.stack([top_pad, left_pad]), dtype=tf.float32))
|
650 |
+
|
651 |
+
if filter_box:
|
652 |
+
indices = get_non_empty_box_indices(boxes)
|
653 |
+
else:
|
654 |
+
indices = tf.range(tf.shape(boxes)[0])
|
655 |
+
boxes = tf.gather(boxes, indices)
|
656 |
+
|
657 |
+
if labels is not None:
|
658 |
+
labels = tf.gather(labels, indices)
|
659 |
+
|
660 |
+
if boxes1 is not None:
|
661 |
+
boxes1 = resize_and_crop_boxes(
|
662 |
+
boxes1,
|
663 |
+
tf.stack([image_scale, image_scale]),
|
664 |
+
[desired_height, desired_width],
|
665 |
+
tf.cast(tf.stack([offset_y, offset_x]), dtype=tf.float32),
|
666 |
+
tf.cast(tf.stack([top_pad, left_pad]), dtype=tf.float32))
|
667 |
+
|
668 |
+
image_info = tf.stack([
|
669 |
+
tf.cast(top_pad, tf.float32),
|
670 |
+
tf.cast(left_pad, tf.float32),
|
671 |
+
1.0 / image_scale,
|
672 |
+
height,
|
673 |
+
width,
|
674 |
+
tf.cast(offset_y, dtype=tf.float32) / height,
|
675 |
+
tf.cast(offset_x, dtype=tf.float32) / width,
|
676 |
+
tf.cast(offset_y, dtype=tf.float32),
|
677 |
+
tf.cast(offset_x, dtype=tf.float32),
|
678 |
+
tf.cast(scaled_height, dtype=tf.float32),
|
679 |
+
tf.cast(scaled_width, dtype=tf.float32),
|
680 |
+
])
|
681 |
+
|
682 |
+
if boxes1 is not None:
|
683 |
+
outputs = (image_info, masks, boxes, labels, indices, boxes1)
|
684 |
+
else:
|
685 |
+
outputs = (image_info, masks, boxes, labels, indices)
|
686 |
+
|
687 |
+
if normalize:
|
688 |
+
image = normalize_image(image)
|
689 |
+
|
690 |
+
if return_outputs:
|
691 |
+
return image, image_mask, outputs
|
692 |
+
else:
|
693 |
+
return image, image_mask
|
694 |
+
|
695 |
+
|
696 |
+
def _remove_bars_from_frames(frames, black_bar=True, threshold=32, max_perc_to_trim=0.3):
|
697 |
+
"""
|
698 |
+
:param frames: [num_frames, height, width, 3]
|
699 |
+
:param blackbar_threshold: Pixels must be this intense for us to not trim
|
700 |
+
:param max_perc_to_prim: Will trim x% by default of the image at most in each dimension
|
701 |
+
:return:
|
702 |
+
"""
|
703 |
+
# Detect black bars####################
|
704 |
+
frames_shape = tf.shape(frames)
|
705 |
+
h, w = frames_shape[1], frames_shape[2]
|
706 |
+
if black_bar:
|
707 |
+
has_content = tf.reduce_max(frames, axis=(0, -1)) >= threshold
|
708 |
+
else:
|
709 |
+
has_content = tf.reduce_min(frames, axis=(0, -1)) <= threshold
|
710 |
+
|
711 |
+
y_frames = tf.cast(tf.reshape(tf.where(tf.reduce_any(has_content, axis=1)), [-1]), tf.int32)
|
712 |
+
nhbars = tf.shape(y_frames)[0]
|
713 |
+
y_frames = tf.cond(nhbars > 0, lambda: y_frames, lambda: tf.expand_dims(tf.cast(h // 2, tf.int32), axis=0))
|
714 |
+
|
715 |
+
y1 = tf.minimum(y_frames[0], tf.cast(tf.cast(h, tf.float32) * max_perc_to_trim, tf.int32))
|
716 |
+
y2 = tf.maximum(y_frames[-1] + 1, tf.cast(tf.cast(h, tf.float32) * (1 - max_perc_to_trim), tf.int32))
|
717 |
+
|
718 |
+
x_frames = tf.cast(tf.reshape(tf.where(tf.reduce_any(has_content, axis=0)), [-1]), tf.int32)
|
719 |
+
nvbars = tf.shape(x_frames)[0]
|
720 |
+
x_frames = tf.cond(nvbars > 0, lambda: x_frames, lambda: tf.expand_dims(tf.cast(w // 2, tf.int32), axis=0))
|
721 |
+
|
722 |
+
x1 = tf.minimum(x_frames[0], tf.cast(tf.cast(w, tf.float32) * max_perc_to_trim, tf.int32))
|
723 |
+
x2 = tf.maximum(x_frames[-1] + 1, tf.cast(tf.cast(w, tf.float32) * (1 - max_perc_to_trim), tf.int32))
|
724 |
+
|
725 |
+
frames = frames[:, y1:y2, x1:x2]
|
726 |
+
return frames
|
727 |
+
|
728 |
+
def convert_video_dtype(video,dtype):
|
729 |
+
"""
|
730 |
+
Converts tensor to dtype and scales the values.
|
731 |
+
Video equivalent of tf.convert_image_dtype: https://www.tensorflow.org/api_docs/python/tf/image/convert_image_dtype
|
732 |
+
"""
|
733 |
+
return tf.map_fn(
|
734 |
+
fn=functools.partial(
|
735 |
+
tf.image.convert_image_dtype,
|
736 |
+
dtype=dtype),
|
737 |
+
elems=video,
|
738 |
+
fn_output_signature=dtype)
|
739 |
+
|
740 |
+
|
741 |
+
def stateless_shuffle(x: tf.Tensor, seed):
|
742 |
+
if hasattr(tf.random.experimental, 'stateless_shuffle'):
|
743 |
+
return tf.random.experimental.stateless_shuffle(x, seed=seed)
|
744 |
+
else:
|
745 |
+
vals = tf.random.stateless_uniform(tf.shape(x)[:1], seed)
|
746 |
+
ixs = tf.argsort(vals)
|
747 |
+
return tf.gather(x, ixs)
|
748 |
+
|
749 |
+
|
750 |
+
def stateless_permutation(n: int, seed):
|
751 |
+
if hasattr(tf.random.experimental, 'stateless_shuffle'):
|
752 |
+
ix = tf.range(0, n, dtype=tf.int32)
|
753 |
+
return tf.random.experimental.stateless_shuffle(ix, seed=seed)
|
754 |
+
else:
|
755 |
+
vals = tf.random.stateless_uniform(n, seed)
|
756 |
+
return tf.argsort(vals)
|
757 |
+
|
758 |
+
|
759 |
+
@seqio.map_over_dataset
|
760 |
+
def _strip_metadata(example):
|
761 |
+
return pop_metadata(example)[0]
|
762 |
+
|
763 |
+
|
764 |
+
def sample_patches(mask, n_patches, stateless=False, seeds=None):
|
765 |
+
input_sample_valid = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)
|
766 |
+
input_sample_masked = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask == 0)
|
767 |
+
if stateless:
|
768 |
+
encoder_pos_ids = tf.concat([
|
769 |
+
stateless_shuffle(input_sample_valid, seeds[0]),
|
770 |
+
stateless_shuffle(input_sample_masked, seeds[1])], axis=0)[:n_patches]
|
771 |
+
else:
|
772 |
+
encoder_pos_ids = tf.concat([
|
773 |
+
tf.random.shuffle(input_sample_valid),
|
774 |
+
tf.random.shuffle(input_sample_masked)], axis=0)[:n_patches]
|
775 |
+
encoder_pos_ids = tf.reshape(encoder_pos_ids, (n_patches,))
|
776 |
+
encoder_pos_ids = tf.cast(encoder_pos_ids, tf.int32)
|
777 |
+
return encoder_pos_ids
|
778 |
+
|
779 |
+
|
780 |
+
@gin.configurable()
|
781 |
+
def normalize_image(image,
|
782 |
+
offset=(0.48145466, 0.4578275, 0.40821073),
|
783 |
+
scale=(0.26862954, 0.26130258, 0.27577711)):
|
784 |
+
"""Normalizes the image to zero mean and unit variance."""
|
785 |
+
offset = tf.constant(offset)
|
786 |
+
offset = tf.expand_dims(offset, axis=0)
|
787 |
+
offset = tf.expand_dims(offset, axis=0)
|
788 |
+
image -= tf.cast(offset, image.dtype)
|
789 |
+
|
790 |
+
scale = tf.constant(scale)
|
791 |
+
scale = tf.expand_dims(scale, axis=0)
|
792 |
+
scale = tf.expand_dims(scale, axis=0)
|
793 |
+
image /= tf.cast(scale, image.dtype)
|
794 |
+
return image
|
795 |
+
|
796 |
+
|
797 |
+
def unnormalize_image(image,
|
798 |
+
offset=(0.48145466, 0.4578275, 0.40821073),
|
799 |
+
scale=(0.26862954, 0.26130258, 0.27577711)):
|
800 |
+
"""Normalizes the image to zero mean and unit variance."""
|
801 |
+
scale = tf.cast(tf.expand_dims(tf.expand_dims(tf.constant(scale), axis=0), axis=0), image.dtype)
|
802 |
+
image *= scale
|
803 |
+
|
804 |
+
offset = tf.cast(tf.expand_dims(tf.expand_dims(tf.constant(offset), axis=0), axis=0), image.dtype)
|
805 |
+
image += offset
|
806 |
+
return image
|
807 |
+
|
808 |
+
|
809 |
+
def flatten_parts(ds: tf.data.Dataset, parts: List[str], add_index=False, dataset_size=None) -> tf.data.Dataset:
|
810 |
+
def _flatten(ex):
|
811 |
+
flat_key = {k: ex[k] for k in parts}
|
812 |
+
if add_index:
|
813 |
+
flat_key['index'] = tf.range(len(ex[parts[0]]))
|
814 |
+
|
815 |
+
flat_ds = tf.data.Dataset.from_tensor_slices(flat_key)
|
816 |
+
|
817 |
+
def _merge(_flat_ex):
|
818 |
+
for k, v in ex.items():
|
819 |
+
if k not in parts:
|
820 |
+
_flat_ex[k] = v
|
821 |
+
return _flat_ex
|
822 |
+
return flat_ds.map(_merge)
|
823 |
+
|
824 |
+
ds = ds.flat_map(_flatten)
|
825 |
+
if dataset_size is not None:
|
826 |
+
ds = tf.data.experimental.assert_cardinality(dataset_size)(ds)
|
827 |
+
return ds
|
dataset_sizes.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATASET_SIZES = {
|
2 |
+
("cockatoo_qa_v2", "train"): 194820,
|
3 |
+
("user_qa", "train"): 71172,
|
4 |
+
|
5 |
+
("text_vqa", "train"): 34602,
|
6 |
+
("chart_qa", "train"): 28299,
|
7 |
+
("chart_qa_prompting", "train"): 28299,
|
8 |
+
("chart_qa_weighted", "train"): 28299,
|
9 |
+
("tally_qa", "train"): 132981,
|
10 |
+
("doc_qa", "train"): 39463,
|
11 |
+
("info_qa", "train"): 23946,
|
12 |
+
("okvqa", "train"): 9009,
|
13 |
+
("gqa", "train"): 943000,
|
14 |
+
("gqa_multi", "train"): 72140,
|
15 |
+
("coco_2014_vqa", "train"): 443757, # (82783, 443757)
|
16 |
+
("coco_captioning_karpathy", "train"): 414113, # (82783, 414113)
|
17 |
+
("coco_captioning_karpathy_multi", "train"): 82783,
|
18 |
+
("coco_2014_vqa_multi", "train"): 82783,
|
19 |
+
("science_qa_img", "train"): 6218,
|
20 |
+
("ai2_diagram", "train"): 11389,
|
21 |
+
("a_okvqa_mc", "train"): 17056,
|
22 |
+
("a_okvqa_da", "train"): 17056,
|
23 |
+
("ocr_vqa", "train"): 166043,
|
24 |
+
("st_qa", "train"): 25050,
|
25 |
+
("ocr_qa", "train"): 166043,
|
26 |
+
|
27 |
+
("dv_qa", "train"): 200000,
|
28 |
+
("tabwmp_da", "train"): 23059,
|
29 |
+
("figure_qa", "train"): 100000,
|
30 |
+
("figure_qa_zero_shot", "train"): 100000,
|
31 |
+
("plot_qa", "train"): 157070,
|
32 |
+
('clocks', 'train'): 800269,
|
33 |
+
('clocks', 'validation'): 25600,
|
34 |
+
|
35 |
+
("st_qa", "test"): 4070,
|
36 |
+
('text_vqa', "test"): 5734,
|
37 |
+
('okvqa', "test"): 5046,
|
38 |
+
('chart_qa', "test"): 1250,
|
39 |
+
('doc_qa', "test"): 5188,
|
40 |
+
('info_qa', "test"): 3288,
|
41 |
+
('gqa', "test"): 95336,
|
42 |
+
('coco_captioning_karpathy', "test"): 25010,
|
43 |
+
("science_qa_img", "test"): 2017,
|
44 |
+
("ai2_diagram", "test"): 3088,
|
45 |
+
("a_okvqa_mc_eval", "test"): 6702,
|
46 |
+
("a_okvqa_da_eval", "test"): 6109,
|
47 |
+
|
48 |
+
("ai2_diagram_v2", "train"): 10950,
|
49 |
+
("ai2_diagram_v2", "validation"): 1463,
|
50 |
+
("ai2_diagram_v2", "test"): 3088,
|
51 |
+
("vqa_v2_test", "test2015"): 555187,
|
52 |
+
|
53 |
+
("ai2_diagram_v2_transparent", "train"): 10950,
|
54 |
+
("ai2_diagram_v2_transparent", "validation"): 1463,
|
55 |
+
("ai2_diagram_v2_transparent", "test"): 3088,
|
56 |
+
|
57 |
+
# splits in mix_data include both transparent + opaque boxes
|
58 |
+
("ai2_diagram_v2_mix_transparent", "train"): 15042,
|
59 |
+
("ai2_diagram_v2_mix_transparent", "validation"): 1980,
|
60 |
+
("ai2_diagram_v2_mix_transparent", "test"): 4272,
|
61 |
+
|
62 |
+
# vaia_qa
|
63 |
+
('vaia_qa', 'train'): 477052,
|
64 |
+
('vaia_qa', 'validation'): 1024,
|
65 |
+
|
66 |
+
('vaia_qa_latex_image', 'train'): 477052,
|
67 |
+
('vaia_qa_latex_image', 'validation'): 1024,
|
68 |
+
('vaia_qa_latex_image_only', 'train'): 42605,
|
69 |
+
('vaia_qa_latex_image_only', 'validation'): 1024,
|
70 |
+
('vaia_qa_latex_all_image_only', 'train'): 154266,
|
71 |
+
('vaia_qa_latex_all_image_only', 'validation'): 1024,
|
72 |
+
|
73 |
+
("vaia_qa_latex_image_math_subset_short_answer", 'train'): 198161,
|
74 |
+
("vaia_qa_latex_image_math_subset_short_answer", 'validation'): 419,
|
75 |
+
("vaia_qa_latex_image_math_subset_mc_only_short_answer", "train"): 57568,
|
76 |
+
("vaia_qa_latex_image_math_subset_mc_only_short_answer", "validation"): 118,
|
77 |
+
("vaia_qa_latex_image_math_subset_mc_only_short_answer_first", "train"): 57568,
|
78 |
+
("vaia_qa_latex_image_math_subset_mc_only_short_answer_first", "validation"): 118,
|
79 |
+
|
80 |
+
("vaia_qa_latex_image_all_image_only_short_answer", "train"): 86752,
|
81 |
+
("vaia_qa_latex_image_all_image_only_short_answer", "validation"): 92,
|
82 |
+
("vaia_qa_latex_image_all_image_only_short_answer_first", "train"): 86752,
|
83 |
+
("vaia_qa_latex_image_all_image_only_short_answer_first", "validation"): 92,
|
84 |
+
("vaia_qa_latex_image_math_subset_image_only_short_answer", "train"): 21726,
|
85 |
+
("vaia_qa_latex_image_math_subset_image_only_short_answer", "validation"): 48,
|
86 |
+
|
87 |
+
('vqa_online', 'train'): 62722,
|
88 |
+
('vqa_online', 'validation'): 1024,
|
89 |
+
('vqa_online', 'test'): 1024,
|
90 |
+
|
91 |
+
('vqa_online_gpt_longQ_longA', 'train'): 62722,
|
92 |
+
('vqa_online_gpt_longQ_longA', 'validation'): 1024,
|
93 |
+
('vqa_online_gpt_longQ_longA', 'test'): 1024,
|
94 |
+
|
95 |
+
("tally_qa", "validation"): 38589,
|
96 |
+
('text_vqa', "validation"): 5000,
|
97 |
+
('okvqa', "validation"): 5046,
|
98 |
+
('chart_qa', "validation"): 960*2,
|
99 |
+
('chart_qa_prompting_explanation', "validation"): 960*2,
|
100 |
+
('chart_qa_ex', "validation"): 960*2,
|
101 |
+
('chart_qa_human', "validation"): 960,
|
102 |
+
('chart_qa_aug', "validation"): 960,
|
103 |
+
('doc_qa', "validation"): 5349,
|
104 |
+
('info_qa', "validation"): 2801,
|
105 |
+
('coco_2014_vqa', "validation"): 214354, # 40504 images
|
106 |
+
('coco_2014_vqa_multi', "validation"): 214354,
|
107 |
+
('coco_captioning_karpathy', "validation"): 25010,
|
108 |
+
('gqa', "validation"): 132062,
|
109 |
+
("science_qa_img", "validation"): 2097,
|
110 |
+
("ai2_diagram", "validation"): 1024,
|
111 |
+
("a_okvqa_mc", "validation"): 1145,
|
112 |
+
("a_okvqa_da", "validation"): 1075,
|
113 |
+
("charxiv_descriptive", "validation"): 1000,
|
114 |
+
("charxiv_descriptive", "test"): 1320,
|
115 |
+
("charxiv_reasoning", "validation"): 1000,
|
116 |
+
("charxiv_reasoning", "test"): 1320,
|
117 |
+
("fintabnetqa", "validation"): 125,
|
118 |
+
("fintabnetqa", "test"): 250,
|
119 |
+
("vwtq", "validation"): 125,
|
120 |
+
("vwtq", "test"): 750,
|
121 |
+
("vwtq_syn", "validation"): 125,
|
122 |
+
("vwtq_syn", "test"): 250,
|
123 |
+
("vtabfact", "validation"): 125,
|
124 |
+
("vtabfact", "test"): 250,
|
125 |
+
("nutrition_fact", "validation"): 100,
|
126 |
+
("nutrition_fact", "test"): 100,
|
127 |
+
|
128 |
+
("mmmu_test", "validation"): 900,
|
129 |
+
("count_bench", "test"): 500,
|
130 |
+
("mmmu_test", "test"): 10500,
|
131 |
+
("real_world_qa_test", "test"): 765,
|
132 |
+
("real_world_qa_no_instruction", "test"): 765,
|
133 |
+
("real_world_qa_dbg", "test"): 765,
|
134 |
+
("real_world_qa_as_user_qa", "test"): 765,
|
135 |
+
|
136 |
+
("seed_bench_test", "test"): 19241,
|
137 |
+
("pope_test", "test"): 9000,
|
138 |
+
("mme_test", "test"): 2374,
|
139 |
+
("math_vista_test", "validation"): 1000,
|
140 |
+
("math_vista_demo", "validation"): 1000,
|
141 |
+
("math_vista_v2", "validation"): 1000,
|
142 |
+
|
143 |
+
("math_vista_test", "test"): 5141,
|
144 |
+
("mmbench_test", "validation"): 4329,
|
145 |
+
("mmbench_test", "test"): 6666,
|
146 |
+
("sugar_crepe_test", "test"): 15022,
|
147 |
+
("blink_test", "validation"): 1901,
|
148 |
+
("dense_caption_eval_dbg", "validation"): 1,
|
149 |
+
|
150 |
+
("refclef_unc", "train"): 17978,
|
151 |
+
("refclef_unc", "validation"): 12029,
|
152 |
+
("refcoco_unc", "train"): 16994,
|
153 |
+
("refcoco_unc", "validation"): 10834,
|
154 |
+
("refcocoplus_unc", "train"): 16992,
|
155 |
+
("refcocoplus_unc", "validation"): 10758,
|
156 |
+
("refcocog_umd", "train"): 21899,
|
157 |
+
("refcocog_umd", "validation"): 4896,
|
158 |
+
("refclef_unc", "testA"): 3449,
|
159 |
+
("refclef_unc", "testB"): 3221,
|
160 |
+
("refclef_unc", "testC"): 2664,
|
161 |
+
("refclef_unc", "testAB"): 116,
|
162 |
+
("refclef_unc", "testBC"): 86,
|
163 |
+
("refcoco_unc", "testA"): 5657,
|
164 |
+
("refcoco_unc", "testB"): 5095,
|
165 |
+
("refcocoplus_unc", "testA"): 5726,
|
166 |
+
("refcocoplus_unc", "testB"): 4889,
|
167 |
+
("refcocog_umd", "test"): 9602,
|
168 |
+
("countbench_qa_point_count", "huggingface"): 490,
|
169 |
+
('countbench_qa', 'huggingface'): 490,
|
170 |
+
|
171 |
+
('cockatoo_712k_sept6', 'train'): 712121,
|
172 |
+
('cockatoo_712k_sept6', 'validation'): 5120,
|
173 |
+
('user_qa', 'train'): 71172,
|
174 |
+
('user_qa', 'validation'): 2048,
|
175 |
+
|
176 |
+
# pointing
|
177 |
+
("pointing_test", "test"): 436,
|
178 |
+
|
179 |
+
("fast_flickr_count_qa_point_count", "train"): 36916,
|
180 |
+
("fast_flickr_count_qa_point_count", "validation"): 163,
|
181 |
+
("fast_flickr_count_qa_point_count", "test"): 540,
|
182 |
+
("fast_flickr_count_qa_pointing", "train"): 36916,
|
183 |
+
("fast_flickr_count_qa_pointing", "validation"): 163,
|
184 |
+
("fast_flickr_count_qa_pointing", "test"): 540,
|
185 |
+
('pointing', 'train'): 309216,
|
186 |
+
('point_count', 'train'): 309216,
|
187 |
+
('pointing', 'validation'): 2054,
|
188 |
+
('point_count', 'validation'): 2054,
|
189 |
+
('point_count_high_freq', 'train'): 113840,
|
190 |
+
('point_count_high_freq', 'validation'): 3969,
|
191 |
+
('pointing_high_freq', 'train'): 113840,
|
192 |
+
('pointing_high_freq', 'validation'): 3969,
|
193 |
+
('point_qa', 'train'): 27856,
|
194 |
+
('point_qa', 'validation'): 978,
|
195 |
+
("a_okvqa_da", "test"): 6109,
|
196 |
+
("a_okvqa_mc", "test"): 6702,
|
197 |
+
("user_questions_for_elo", "test"): 14851,
|
198 |
+
("user_questions_for_elo_long", "test"): 1368,
|
199 |
+
("user_questions_for_elo_9_to_12", "test"): 3000,
|
200 |
+
|
201 |
+
("sim_point_count_qa", "train"): 522611,
|
202 |
+
("sim_point_count_qa", "validation"): 800,
|
203 |
+
("sim_point_count_qa", "test"): 800,
|
204 |
+
("sim_count_qa", "train"): 522611,
|
205 |
+
("sim_count_qa", "validation"): 800,
|
206 |
+
("sim_count_qa", "test"): 800,
|
207 |
+
|
208 |
+
("scifi_charts_qa", "validation"): 1024,
|
209 |
+
("scifi_table_qa", "validation"): 1024,
|
210 |
+
("scifi_natural_qa", "validation"): 128,
|
211 |
+
("scifi_nutrition_qa", "validation"): 128,
|
212 |
+
("scifi_document_qa", "validation"): 1024,
|
213 |
+
("scifi_diagram_qa", "validation"): 1024,
|
214 |
+
("scifi_charts_qa", "train"): 233622,
|
215 |
+
("scifi_table_qa", "train"): 93036,
|
216 |
+
("scifi_document_qa", "train"): 142559,
|
217 |
+
("scifi_diagram_qa", "train"): 33102,
|
218 |
+
|
219 |
+
("scifi_charts_qa_split", "train"): 116814,
|
220 |
+
("scifi_table_qa_split", "train"): 46518,
|
221 |
+
("scifi_document_qa_split", "train"): 71282,
|
222 |
+
("scifi_diagram_qa_split", "train"): 16551,
|
223 |
+
|
224 |
+
("scifi_charts_qa_exp_split", "train"): 116814,
|
225 |
+
("scifi_table_qa_exp_split", "train"): 46518,
|
226 |
+
("scifi_document_qa_exp_split", "train"): 71282,
|
227 |
+
("scifi_diagram_qa_exp_split", "train"): 16551,
|
228 |
+
|
229 |
+
("android_control", "train"): 74714,
|
230 |
+
("android_control", "validation"): 690,
|
231 |
+
("android_control", "test"): 3897,
|
232 |
+
|
233 |
+
("synthetic_qa_v3_multi_turn", "train"): 9824,
|
234 |
+
("synthetic_qa_v3", "train"): 162855,
|
235 |
+
("synthetic_qa_v3_style_tag", "train"): 162855,
|
236 |
+
("synthetic_qa_v3_as_user_qa", "train"): 162855,
|
237 |
+
}
|
238 |
+
|
239 |
+
|
240 |
+
for (name, split), count in list(DATASET_SIZES.items()):
|
241 |
+
if name in ["chart_qa"]:
|
242 |
+
DATASET_SIZES[(name + "_scifi", split)] = count
|
243 |
+
if name in ["android_control"]:
|
244 |
+
for k in ["ll", "hl", "hl_ll", "hl_cot"]:
|
245 |
+
DATASET_SIZES[(f"{name}_{k}", split)] = count
|
246 |
+
if name in ["scifi_charts_qa" ,"scifi_table_qa", "scifi_document_qa", "scifi_diagram_qa", "scifi_datikz_qa"]:
|
247 |
+
DATASET_SIZES[(name + "_exp", split)] = count
|
248 |
+
DATASET_SIZES[(name[:-3] + "_exp", split)] = count
|
249 |
+
DATASET_SIZES[(name[:-3] + "_demo", split)] = count
|
250 |
+
if name in ["ai2_diagram_v2_mix_transparent"]:
|
251 |
+
DATASET_SIZES[("ai2_diagram_v2_mix_transparent_one_style", split)] = count
|
252 |
+
if name in ["chart_qa", "info_qa", "doc_qa", "text_vqa", "coco_2014_vqa",
|
253 |
+
"ai2_diagram_v2_mix_transparent", "countbench_qa", "chart_qa_human"]:
|
254 |
+
DATASET_SIZES[(name + "_demo", split)] = count
|
255 |
+
|
256 |
+
|
257 |
+
def get_dataset_size(name, split):
|
258 |
+
if name.endswith("_eval"):
|
259 |
+
if (name, split) in DATASET_SIZES:
|
260 |
+
return DATASET_SIZES[(name, split)]
|
261 |
+
name = name[:-len('_eval')]
|
262 |
+
return DATASET_SIZES[(name, split)]
|
exceptions.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__all__ = [
|
2 |
+
"OLMoError",
|
3 |
+
"OLMoConfigurationError",
|
4 |
+
"OLMoCliError",
|
5 |
+
"OLMoEnvironmentError",
|
6 |
+
"OLMoNetworkError",
|
7 |
+
"OLMoCheckpointError",
|
8 |
+
]
|
9 |
+
|
10 |
+
|
11 |
+
class OLMoError(Exception):
|
12 |
+
"""
|
13 |
+
Base class for all custom OLMo exceptions.
|
14 |
+
"""
|
15 |
+
|
16 |
+
|
17 |
+
class OLMoConfigurationError(OLMoError):
|
18 |
+
"""
|
19 |
+
An error with a configuration file.
|
20 |
+
"""
|
21 |
+
|
22 |
+
|
23 |
+
class OLMoCliError(OLMoError):
|
24 |
+
"""
|
25 |
+
An error from incorrect CLI usage.
|
26 |
+
"""
|
27 |
+
|
28 |
+
|
29 |
+
class OLMoEnvironmentError(OLMoError):
|
30 |
+
"""
|
31 |
+
An error from incorrect environment variables.
|
32 |
+
"""
|
33 |
+
|
34 |
+
|
35 |
+
class OLMoNetworkError(OLMoError):
|
36 |
+
"""
|
37 |
+
An error with a network request.
|
38 |
+
"""
|
39 |
+
|
40 |
+
|
41 |
+
class OLMoCheckpointError(OLMoError):
|
42 |
+
"""
|
43 |
+
An error occurred reading or writing from a checkpoint.
|
44 |
+
"""
|
45 |
+
|
46 |
+
|
47 |
+
class OLMoThreadError(Exception):
|
48 |
+
"""
|
49 |
+
Raised when a thread fails.
|
50 |
+
"""
|
iterable_dataset.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import multiprocessing
|
4 |
+
import os
|
5 |
+
import pickle
|
6 |
+
import queue
|
7 |
+
import socket
|
8 |
+
import time
|
9 |
+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
10 |
+
from multiprocessing.managers import BaseManager
|
11 |
+
from multiprocessing.shared_memory import SharedMemory
|
12 |
+
from os.path import exists
|
13 |
+
from pathlib import Path
|
14 |
+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
|
15 |
+
|
16 |
+
import psutil
|
17 |
+
import tensorflow as tf
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.utils.data
|
21 |
+
import clu
|
22 |
+
from clu.data.dataset_iterator import Element
|
23 |
+
|
24 |
+
|
25 |
+
from .aliases import PathOrStr
|
26 |
+
from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size, get_node_rank, \
|
27 |
+
get_local_world_size, get_local_rank, move_to_device
|
28 |
+
from .util import roundrobin, threaded_generator
|
29 |
+
from .data_factory import SeqioDataset
|
30 |
+
from .multimodal_preprocessor import MultiModalPreprocessor
|
31 |
+
from .preprocesssors import rename
|
32 |
+
import torch.distributed as dist
|
33 |
+
from . import tasks
|
34 |
+
|
35 |
+
__all__ = ["MMIterableDataset"]
|
36 |
+
|
37 |
+
log = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
|
40 |
+
def batch_fn(batch, for_inference):
|
41 |
+
if for_inference:
|
42 |
+
out = {}
|
43 |
+
for k, v in batch.items():
|
44 |
+
if k.startswith("metadata/"):
|
45 |
+
out[k] = v
|
46 |
+
else:
|
47 |
+
out[k] = torch.from_numpy(v)
|
48 |
+
return out
|
49 |
+
else:
|
50 |
+
out = {k: torch.from_numpy(v) for k, v in batch.items() if not k.startswith("metadata/")}
|
51 |
+
out["metadata"] = [{} for _ in out["input_ids"]]
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class PyTorchDatasetIterator(clu.data.dataset_iterator.TfDatasetIterator):
|
56 |
+
def __init__(self, dataset, *, checkpoint: bool, for_inference: bool):
|
57 |
+
self.for_inference = for_inference
|
58 |
+
super().__init__(dataset, checkpoint=checkpoint)
|
59 |
+
|
60 |
+
def __next__(self) -> Element:
|
61 |
+
batch = {k: v.numpy() for k, v in next(self.iterator).items()}
|
62 |
+
return batch_fn(batch, self.for_inference)
|
63 |
+
|
64 |
+
def __len__(self) -> int:
|
65 |
+
return len(self._dataset)
|
66 |
+
|
67 |
+
|
68 |
+
class MMIterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
dataset: SeqioDataset,
|
72 |
+
preprocessor: MultiModalPreprocessor,
|
73 |
+
world_size: Optional[int] = None,
|
74 |
+
rank: Optional[int] = None,
|
75 |
+
):
|
76 |
+
self.preprocessor = preprocessor
|
77 |
+
self.rank = rank if rank is not None else get_global_rank()
|
78 |
+
self.world_size = world_size if world_size is not None else get_world_size()
|
79 |
+
self.dataset_config = dataset
|
80 |
+
|
81 |
+
data_iter = dataset.build(
|
82 |
+
self.preprocessor,
|
83 |
+
self.rank,
|
84 |
+
self.world_size,
|
85 |
+
)
|
86 |
+
|
87 |
+
data_iter: tf.data.Dataset = rename(input_ids="input_tokens", labels="target_tokens")(data_iter)
|
88 |
+
self.dataset = data_iter
|
89 |
+
self.data_iter = PyTorchDatasetIterator(
|
90 |
+
data_iter, checkpoint=True, for_inference=dataset.for_inference)
|
91 |
+
|
92 |
+
def reset(self):
|
93 |
+
self.data_iter.reset()
|
94 |
+
|
95 |
+
def save(self, filename: PathOrStr):
|
96 |
+
self.data_iter.save(filename)
|
97 |
+
|
98 |
+
def restore(self, filename: PathOrStr):
|
99 |
+
self.data_iter.restore(filename)
|
100 |
+
|
101 |
+
def __iter__(self) -> Iterator[Dict[str, Any]]:
|
102 |
+
return self.data_iter
|
103 |
+
|
104 |
+
|
105 |
+
def _split_batch(batch, n):
|
106 |
+
subbatches = [{} for _ in range(n)]
|
107 |
+
for k, v in batch.items():
|
108 |
+
assert len(v) % n == 0, f"n={n} but {k} has {len(v)}"
|
109 |
+
subatch_dim = len(v) // n
|
110 |
+
for i, subbatch in enumerate(subbatches):
|
111 |
+
subbatch[k] = v[i * subatch_dim:(i + 1) * subatch_dim]
|
112 |
+
return subbatches
|
113 |
+
|
114 |
+
|
115 |
+
def tf_to_torch_dtype(tf_dtype):
|
116 |
+
dtype_mapping = {
|
117 |
+
tf.float16: torch.float16,
|
118 |
+
tf.float32: torch.float32,
|
119 |
+
tf.float64: torch.float64,
|
120 |
+
tf.int8: torch.int8,
|
121 |
+
tf.uint8: torch.uint8,
|
122 |
+
tf.int16: torch.int16,
|
123 |
+
tf.int32: torch.int32,
|
124 |
+
tf.int64: torch.int64,
|
125 |
+
tf.bool: torch.bool,
|
126 |
+
}
|
127 |
+
return dtype_mapping[tf_dtype]
|
128 |
+
|
129 |
+
|
130 |
+
class PeerToPeer(torch.utils.data.IterableDataset[Dict[str, Any]]):
|
131 |
+
"""
|
132 |
+
This dataloader runs the tf.data.Dataset on one processes per a node, and then
|
133 |
+
transfers the batch to the other processes. For 7B model about a 10% performance
|
134 |
+
despite my attempts to make it asynchronous
|
135 |
+
|
136 |
+
The advantage is that it avoids the overhead of running multiple tf.data.Dataset
|
137 |
+
in one node
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
dataset: SeqioDataset,
|
143 |
+
preprocessor: MultiModalPreprocessor,
|
144 |
+
world_size: Optional[int] = None,
|
145 |
+
rank: Optional[int] = None,
|
146 |
+
device=None
|
147 |
+
):
|
148 |
+
assert get_world_size() % get_local_world_size() == 0
|
149 |
+
self.device = device
|
150 |
+
self.device_batch_size = dataset.global_batch_size // get_world_size()
|
151 |
+
|
152 |
+
self.preprocessor = preprocessor
|
153 |
+
self.seqio_dataset = dataset
|
154 |
+
|
155 |
+
lws = get_local_world_size()
|
156 |
+
|
157 |
+
if get_local_rank() == 0:
|
158 |
+
tf_dataset = dataset.build(
|
159 |
+
self.preprocessor,
|
160 |
+
get_node_rank(),
|
161 |
+
get_world_size() // lws,
|
162 |
+
)
|
163 |
+
|
164 |
+
tf_dataset = rename(input_ids="input_tokens", labels="target_tokens")(tf_dataset)
|
165 |
+
self.dataset = tf_dataset
|
166 |
+
device_spec = {k: ((v.shape[0]//lws,) + tuple(v.shape[1:]), tf_to_torch_dtype(v.dtype))
|
167 |
+
for k, v in tf_dataset.element_spec.items()}
|
168 |
+
else:
|
169 |
+
self.dataset = None
|
170 |
+
device_spec = None
|
171 |
+
|
172 |
+
broadcast = [device_spec]
|
173 |
+
torch.distributed.broadcast_object_list(broadcast)
|
174 |
+
self.device_spec = broadcast[0]
|
175 |
+
|
176 |
+
self._node_group_ranks = ranks = [(i + get_node_rank()*lws) for i in range(lws)]
|
177 |
+
if get_local_rank() == 0:
|
178 |
+
assert get_global_rank() == self._node_group_ranks[0]
|
179 |
+
self._keys = sorted(self.device_spec)
|
180 |
+
self.multithread_pin = False
|
181 |
+
|
182 |
+
def _pin(self, it, on):
|
183 |
+
batch = next(it)
|
184 |
+
batch = {k: torch.from_numpy(v) for k, v in batch.items()}
|
185 |
+
batch = _split_batch(batch, len(self._node_group_ranks))
|
186 |
+
return [{k: v.pin_memory() for k, v in subbatch.items()} for subbatch in batch]
|
187 |
+
|
188 |
+
def _send_pinned(self, batch):
|
189 |
+
requests = []
|
190 |
+
for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
|
191 |
+
for k in self._keys:
|
192 |
+
batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
|
193 |
+
requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
|
194 |
+
ops = dist.batch_isend_irecv(requests)
|
195 |
+
return batch[0], ops
|
196 |
+
|
197 |
+
def _send(self, it, on):
|
198 |
+
if get_local_rank() == 0:
|
199 |
+
try:
|
200 |
+
batch = next(it)
|
201 |
+
batch = {k: torch.from_numpy(v) for k, v in batch.items()}
|
202 |
+
batch = _split_batch(batch, len(self._node_group_ranks))
|
203 |
+
except StopIteration:
|
204 |
+
# Special batch to indicate iteration is done
|
205 |
+
batch = [
|
206 |
+
{k: torch.full(sh, -10, dtype=dtype, device=self.device)
|
207 |
+
for k, (sh, dtype) in self.device_spec.items()}
|
208 |
+
for _ in range(len(self._node_group_ranks))
|
209 |
+
]
|
210 |
+
|
211 |
+
# pin_memory so the device transfer can be non_blocking
|
212 |
+
batch = [{k: v.pin_memory() for k, v in subbatch.items()}
|
213 |
+
for subbatch in batch]
|
214 |
+
|
215 |
+
requests = []
|
216 |
+
for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
|
217 |
+
for k in self._keys:
|
218 |
+
batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
|
219 |
+
requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
|
220 |
+
ops = dist.batch_isend_irecv(requests)
|
221 |
+
batch = batch[0]
|
222 |
+
else:
|
223 |
+
batch = {k: torch.zeros(sh, dtype=dtype, device=self.device)
|
224 |
+
for k, (sh, dtype) in self.device_spec.items()}
|
225 |
+
requests = []
|
226 |
+
for k in self._keys:
|
227 |
+
requests.append(dist.P2POp(dist.irecv, batch[k], self._node_group_ranks[0]))
|
228 |
+
ops = dist.batch_isend_irecv(requests)
|
229 |
+
return batch, ops
|
230 |
+
|
231 |
+
def __iter__(self):
|
232 |
+
on = 0
|
233 |
+
if get_local_rank() == 0:
|
234 |
+
it = iter(self.dataset.as_numpy_iterator())
|
235 |
+
else:
|
236 |
+
it = None
|
237 |
+
|
238 |
+
if get_local_rank() == 0 and self.multithread_pin:
|
239 |
+
# Try to be clever and do memory pinning in a seperate thread, in practice
|
240 |
+
# didn't seem to help much so off by default for now
|
241 |
+
# Currently does not support finite dataset
|
242 |
+
with ThreadPoolExecutor(max_workers=1) as pool:
|
243 |
+
_is_sending = self._send_pinned(self._pin(it, on))
|
244 |
+
_is_pinning = pool.submit(self._pin, it, on)
|
245 |
+
on += 1
|
246 |
+
while True:
|
247 |
+
result = _is_sending
|
248 |
+
_is_sending = self._send_pinned(_is_pinning.result())
|
249 |
+
_is_pinning = pool.submit(self._pin, it, on)
|
250 |
+
on += 1
|
251 |
+
for op in result[1]:
|
252 |
+
op.wait()
|
253 |
+
yield result[0]
|
254 |
+
else:
|
255 |
+
_in_flight = self._send(it, on)
|
256 |
+
on += 1
|
257 |
+
while True:
|
258 |
+
on += 1
|
259 |
+
next_batch = self._send(it, on) # queue up the next batch
|
260 |
+
for op in _in_flight[1]: # wait for the current batch
|
261 |
+
op.wait()
|
262 |
+
if _in_flight["input_ids"][0] != -10: # indicates no more data
|
263 |
+
return
|
264 |
+
yield _in_flight[0]
|
265 |
+
_in_flight = next_batch
|
266 |
+
|
modeling_molmoe.py
CHANGED
@@ -39,14 +39,14 @@ import einops
|
|
39 |
from transformers import PreTrainedModel
|
40 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
41 |
|
42 |
-
from
|
43 |
-
from
|
44 |
BeamSearch,
|
45 |
Constraint,
|
46 |
FinalSequenceScorer,
|
47 |
Sampler
|
48 |
)
|
49 |
-
from
|
50 |
ActivationType,
|
51 |
BlockType,
|
52 |
LayerNormType,
|
@@ -56,7 +56,7 @@ from olmo.config import (
|
|
56 |
AttentionType,
|
57 |
)
|
58 |
|
59 |
-
from
|
60 |
from .config_molmoe import (
|
61 |
MolmoConfig,
|
62 |
VisionBackboneConfig
|
|
|
39 |
from transformers import PreTrainedModel
|
40 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
41 |
|
42 |
+
from .aliases import PathOrStr
|
43 |
+
from .beam_search import (
|
44 |
BeamSearch,
|
45 |
Constraint,
|
46 |
FinalSequenceScorer,
|
47 |
Sampler
|
48 |
)
|
49 |
+
from .config import (
|
50 |
ActivationType,
|
51 |
BlockType,
|
52 |
LayerNormType,
|
|
|
56 |
AttentionType,
|
57 |
)
|
58 |
|
59 |
+
from .util import resource_path
|
60 |
from .config_molmoe import (
|
61 |
MolmoConfig,
|
62 |
VisionBackboneConfig
|
multimodal_preprocessor.py
ADDED
@@ -0,0 +1,1549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import logging
|
3 |
+
import re
|
4 |
+
from collections import defaultdict
|
5 |
+
from typing import Tuple, Optional, Any, Dict, List, Union, Mapping
|
6 |
+
|
7 |
+
import einops
|
8 |
+
import seqio
|
9 |
+
import numpy as np
|
10 |
+
import tensorflow as tf
|
11 |
+
|
12 |
+
from .mm_data import seqio_tokenizer
|
13 |
+
from .data_utils import pad_to_bounding_box, \
|
14 |
+
get_3d_subsegments, _append_to_innermost_axis, resize_and_pad, \
|
15 |
+
apply_with_random_selector, get_special_token_ids, make_autoregressive_inputs, \
|
16 |
+
trim_and_pad_dataset, assert_not_truncated
|
17 |
+
from .prompts import apply_keyword_prompt, STYLE_TO_GENERAL_PROMPT, GENERAL_PROMPTS_V1
|
18 |
+
import .constants as config
|
19 |
+
|
20 |
+
|
21 |
+
def siglip_resize(src, imgsize, truncate):
|
22 |
+
"""Resize and preprocess for SigLIP ViT in the offical jax implementation"""
|
23 |
+
assert src.dtype == tf.uint8
|
24 |
+
# SigCLIP removes aspect ratio by default
|
25 |
+
resized = tf.image.resize(src, imgsize, method=tf.image.ResizeMethod.BILINEAR, antialias=False)
|
26 |
+
dtype = src.dtype
|
27 |
+
tf_dtype = tf.type_spec_from_value(src).dtype
|
28 |
+
resized = tf.cast(tf.clip_by_value(resized, tf_dtype.min, tf_dtype.max), dtype)
|
29 |
+
|
30 |
+
# Normalize between -1 and 1 without using imagenet standard mean/std
|
31 |
+
vmin=-1; vmax=1; in_min=0; in_max=255.0
|
32 |
+
in_min_t = tf.constant(in_min, tf.float32)
|
33 |
+
in_max_t = tf.constant(in_max, tf.float32)
|
34 |
+
image = tf.cast(resized, tf.float32)
|
35 |
+
image = (image - in_min_t) / (in_max_t - in_min_t)
|
36 |
+
image = vmin + image * (vmax - vmin)
|
37 |
+
if truncate:
|
38 |
+
image = image[:truncate, :truncate]
|
39 |
+
return image
|
40 |
+
|
41 |
+
|
42 |
+
def extract_bboxes(text, image_w, image_h):
|
43 |
+
points = extract_points(text, image_w, image_h)
|
44 |
+
boxes = []
|
45 |
+
for i in range(len(points)//2):
|
46 |
+
x1, y1 = points[i*2]
|
47 |
+
x2, y2 = points[i*2 + 1]
|
48 |
+
boxes.append([x1, y1, x2, y2])
|
49 |
+
return boxes
|
50 |
+
|
51 |
+
|
52 |
+
def extract_annotated_points(caption, image_w, image_h):
|
53 |
+
points = []
|
54 |
+
for match in re.finditer("<point x=\"([0-9\\.]*)\" y=\"([0-9\\.]*)\" alt=\"([^\"]*)\">", caption):
|
55 |
+
x = float(match.group(1))
|
56 |
+
y = float(match.group(2))
|
57 |
+
points.append(([[x, y]], match.group(3)))
|
58 |
+
for match in re.finditer("<points ([^<]*) alt=\"([^\"]*)\">", caption):
|
59 |
+
loc_str = match.group(1)
|
60 |
+
locations = defaultdict(dict)
|
61 |
+
if loc_str.startswith("points="):
|
62 |
+
point_grp = []
|
63 |
+
for point_match in re.finditer(r"([0-9]+\.[0-9]),? ([0-9]+\.[0-9])", loc_str):
|
64 |
+
try:
|
65 |
+
point = [float(point_match.group(i)) for i in range(1, 3)]
|
66 |
+
point_grp.append(point)
|
67 |
+
except ValueError:
|
68 |
+
pass
|
69 |
+
else:
|
70 |
+
for val in loc_str.split():
|
71 |
+
try:
|
72 |
+
key, val = val.split("=")
|
73 |
+
locations[key[1:]][key[:1]] = float(val.strip("\""))
|
74 |
+
except ValueError:
|
75 |
+
import pdb; pdb.set_trace()
|
76 |
+
logging.warning(f"Failed to parse {val} from {match.group(0)}")
|
77 |
+
point_grp = []
|
78 |
+
for key, coords in locations.items():
|
79 |
+
if sorted(coords) == ["x", "y"]:
|
80 |
+
point_grp.append([coords["x"], coords["y"]])
|
81 |
+
if point_grp:
|
82 |
+
points.append((point_grp, match.group(2)))
|
83 |
+
|
84 |
+
normalized = []
|
85 |
+
for point_grp, point_text in points:
|
86 |
+
normalized.append((
|
87 |
+
np.array(point_grp) / 100.0 * np.array([image_w, image_h]),
|
88 |
+
point_text,
|
89 |
+
))
|
90 |
+
return normalized
|
91 |
+
|
92 |
+
|
93 |
+
def extract_points(text, image_w, image_h):
|
94 |
+
all_points = []
|
95 |
+
for match in re.finditer(r"Click\(([0-9]+\.[0-9]), ?([0-9]+\.[0-9])\)", text):
|
96 |
+
try:
|
97 |
+
point = [float(match.group(i)) for i in range(1, 3)]
|
98 |
+
except ValueError:
|
99 |
+
pass
|
100 |
+
else:
|
101 |
+
point = np.array(point)
|
102 |
+
if np.max(point) > 100:
|
103 |
+
# Treat as an invalid output
|
104 |
+
continue
|
105 |
+
point /= 100.0
|
106 |
+
point = point * np.array([image_w, image_h])
|
107 |
+
all_points.append(point)
|
108 |
+
|
109 |
+
for match in re.finditer(r"\(([0-9]+\.[0-9]),? ?([0-9]+\.[0-9])\)", text):
|
110 |
+
try:
|
111 |
+
point = [float(match.group(i)) for i in range(1, 3)]
|
112 |
+
except ValueError:
|
113 |
+
pass
|
114 |
+
else:
|
115 |
+
point = np.array(point)
|
116 |
+
if np.max(point) > 100:
|
117 |
+
# Treat as an invalid output
|
118 |
+
continue
|
119 |
+
point /= 100.0
|
120 |
+
point = point * np.array([image_w, image_h])
|
121 |
+
all_points.append(point)
|
122 |
+
for match in re.finditer(r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"', text):
|
123 |
+
try:
|
124 |
+
point = [float(match.group(i)) for i in range(1, 3)]
|
125 |
+
except ValueError:
|
126 |
+
pass
|
127 |
+
else:
|
128 |
+
point = np.array(point)
|
129 |
+
if np.max(point) > 100:
|
130 |
+
# Treat as an invalid output
|
131 |
+
continue
|
132 |
+
point /= 100.0
|
133 |
+
point = point * np.array([image_w, image_h])
|
134 |
+
all_points.append(point)
|
135 |
+
for match in re.finditer(r'(?:\d+|p)\s*=\s*([0-9]{3})\s*,\s*([0-9]{3})', text):
|
136 |
+
try:
|
137 |
+
point = [int(match.group(i)) / 10.0 for i in range(1, 3)]
|
138 |
+
except ValueError:
|
139 |
+
pass
|
140 |
+
else:
|
141 |
+
point = np.array(point)
|
142 |
+
if np.max(point) > 100:
|
143 |
+
# Treat as an invalid output
|
144 |
+
continue
|
145 |
+
point /= 100.0
|
146 |
+
point = point * np.array([image_w, image_h])
|
147 |
+
all_points.append(point)
|
148 |
+
return all_points
|
149 |
+
|
150 |
+
|
151 |
+
def extract_points_from_point_count(text, image_w, image_h):
|
152 |
+
all_points = []
|
153 |
+
points = re.findall(r"(\d+\.\d+),\s*(\d+\.\d+)", text)
|
154 |
+
|
155 |
+
for match in points:
|
156 |
+
try:
|
157 |
+
point = [float(match[0]), float(match[1])]
|
158 |
+
except ValueError:
|
159 |
+
pass
|
160 |
+
else:
|
161 |
+
point = np.array(point)
|
162 |
+
if np.max(point) > 100:
|
163 |
+
# Treat as an invalid output
|
164 |
+
continue
|
165 |
+
point = point * np.array([image_w, image_h])
|
166 |
+
all_points.append(point)
|
167 |
+
return all_points
|
168 |
+
|
169 |
+
|
170 |
+
def select_tiling(h, w, patch_size, max_num_patches):
|
171 |
+
"""Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
|
172 |
+
original_size = tf.stack([h, w]) # [1, 2]
|
173 |
+
original_res = h * w
|
174 |
+
tilings = []
|
175 |
+
for i in range(1, max_num_patches+1):
|
176 |
+
for j in range(1, max_num_patches+1):
|
177 |
+
if i*j <= max_num_patches:
|
178 |
+
tilings.append((i, j))
|
179 |
+
# sort so argmin and argmax favour smaller tilings in the event of a tie
|
180 |
+
tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
|
181 |
+
candidate_tilings = tf.constant(tilings, dtype=tf.int32) # [n_resolutions, 2]
|
182 |
+
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
|
183 |
+
|
184 |
+
# How much we would need to scale the image to fit exactly in each tiling
|
185 |
+
required_scale_d = tf.cast(candidate_resolutions, tf.float32) / tf.cast(original_size[None, :], tf.float32)
|
186 |
+
required_scale = tf.reduce_min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
|
187 |
+
if tf.reduce_all(required_scale < 1):
|
188 |
+
# We are forced to downscale, so try to minimize the amount of downscaling
|
189 |
+
ix = tf.argmax(required_scale)[0]
|
190 |
+
else:
|
191 |
+
# Pick the resolution that required the least upscaling so that it most closely fits the image
|
192 |
+
required_scale = tf.where(required_scale < 1.0, 10e9, required_scale)
|
193 |
+
ix = tf.argmin(required_scale)[0]
|
194 |
+
return candidate_tilings[ix]
|
195 |
+
|
196 |
+
|
197 |
+
DEMO_STYLES = [
|
198 |
+
"point_count",
|
199 |
+
"pointing",
|
200 |
+
"user_qa",
|
201 |
+
"scifi_charts_exp",
|
202 |
+
"scifi_charts_exp",
|
203 |
+
"scifi_charts_exp",
|
204 |
+
"scifi_charts_exp",
|
205 |
+
"long_caption",
|
206 |
+
"named_entity"
|
207 |
+
]
|
208 |
+
|
209 |
+
|
210 |
+
@dataclasses.dataclass
|
211 |
+
class MultiModalPreprocessor:
|
212 |
+
"""Turns text/image inputs into tensors that can be input to the model"""
|
213 |
+
tokenizer: Any
|
214 |
+
|
215 |
+
# How to prompt the model
|
216 |
+
prompt_templates: str = "none" # How to template prompts for examples
|
217 |
+
message_format: str = "none" # How to format messages
|
218 |
+
system_prompt: Optional[str] = None # How to generate system prompts
|
219 |
+
prompt_override: Optional[str] = None # Used for setting prompt manually
|
220 |
+
always_start_with_space: bool = False # Always include a leading space for the first bit of text
|
221 |
+
default_inference_len: int = 65 # Inference len for length-conditioned prompting
|
222 |
+
|
223 |
+
# How to crops/resize images
|
224 |
+
crop_mode: str = "resize"
|
225 |
+
max_crops: int = 6
|
226 |
+
overlap_margins: Tuple[int, int] = (4, 4)
|
227 |
+
do_random_scale: Optional[bool] = False
|
228 |
+
resize: str = "default"
|
229 |
+
random_scale_max: float = 1.1
|
230 |
+
random_scale_min: float = 0.9
|
231 |
+
random_scale_ratio: float = 0.5
|
232 |
+
use_col_tokens: bool = True
|
233 |
+
|
234 |
+
# Data about the ViT and connector we need when deciding the crops
|
235 |
+
base_image_input_size: Tuple[int, int] = (336, 336)
|
236 |
+
image_token_length_w: int = 12
|
237 |
+
image_token_length_h: int = 12
|
238 |
+
image_patch_size: int = 14
|
239 |
+
image_padding_mask: bool = False
|
240 |
+
|
241 |
+
# Other settings
|
242 |
+
loss_token_weighting: Optional[str] = None
|
243 |
+
unconditioned: Union[bool, float] = False # Ignore images
|
244 |
+
fix_image_input_idx: int = 2 # backwards compatibility fix
|
245 |
+
pad_to: Optional[int] = None # experimental feature
|
246 |
+
|
247 |
+
_special_tokens: Dict[str, int] = None
|
248 |
+
split_at: Optional[int] = None
|
249 |
+
|
250 |
+
def get_max_total_crops(self):
|
251 |
+
if self.crop_mode == "resize":
|
252 |
+
return 1
|
253 |
+
elif "resize" in self.crop_mode:
|
254 |
+
return 1 + self.max_crops
|
255 |
+
else:
|
256 |
+
return self.max_crops
|
257 |
+
|
258 |
+
@property
|
259 |
+
def image_num_patch(self):
|
260 |
+
h, w = self.base_image_input_size
|
261 |
+
return h//self.image_patch_size, w//self.image_patch_size
|
262 |
+
|
263 |
+
@property
|
264 |
+
def special_token_ids(self):
|
265 |
+
if self._special_tokens is None:
|
266 |
+
self._special_tokens = get_special_token_ids(self.tokenizer)
|
267 |
+
return self._special_tokens
|
268 |
+
|
269 |
+
def image_to_patches_and_tokens(self, image, is_training):
|
270 |
+
"""Preprocesses an image
|
271 |
+
|
272 |
+
Args:
|
273 |
+
image: [h, w, 3] image to preprocessing
|
274 |
+
Returns:
|
275 |
+
crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
|
276 |
+
change between images but the other dimension are fixed
|
277 |
+
tokens: (n_tokens,) tf.int32 tokens, pad tokens indicate where to insert the
|
278 |
+
patch features, might include other special tokens as well
|
279 |
+
patch_ordering: (n_crops, n_tokens_per_crop) order image features should be inserted
|
280 |
+
into the `tokens`, negative values indicates patches features to exclude
|
281 |
+
padding_mask: (n_crops, h, w) mask of what pixels are padding, can be None
|
282 |
+
"""
|
283 |
+
do_random_scale = self.do_random_scale
|
284 |
+
if do_random_scale:
|
285 |
+
do_random_scale = is_training
|
286 |
+
|
287 |
+
base_image_input_size = self.base_image_input_size
|
288 |
+
if isinstance(base_image_input_size, int):
|
289 |
+
base_image_input_size = (base_image_input_size, base_image_input_size)
|
290 |
+
|
291 |
+
image_token_length_w, image_token_length_h = self.image_token_length_w, self.image_token_length_h
|
292 |
+
base_image_input_d = self.image_patch_size
|
293 |
+
tokens_per_image = image_token_length_w * image_token_length_h
|
294 |
+
image_base_patch_w = base_image_input_size[1] // base_image_input_d
|
295 |
+
image_base_patch_h = base_image_input_size[0] // base_image_input_d
|
296 |
+
extra_image = False
|
297 |
+
patch_ordering = None
|
298 |
+
|
299 |
+
if self.resize == "default":
|
300 |
+
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
301 |
+
def _resize(_image, sz):
|
302 |
+
return resize_and_pad(
|
303 |
+
_image, sz,
|
304 |
+
do_random_scale=do_random_scale,
|
305 |
+
random_scale_max=self.random_scale_max,
|
306 |
+
random_scale_min=self.random_scale_min,
|
307 |
+
random_scale_ratio=self.random_scale_ratio,
|
308 |
+
return_outputs=False,
|
309 |
+
resize_method='random' if is_training else tf.image.ResizeMethod.BILINEAR)
|
310 |
+
elif self.resize == "stretch":
|
311 |
+
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
|
312 |
+
assert not do_random_scale
|
313 |
+
|
314 |
+
def _resize(_image, sz):
|
315 |
+
if not is_training:
|
316 |
+
img = tf.image.resize(_image, sz, antialias=True, method=tf.image.ResizeMethod.BILINEAR)
|
317 |
+
else:
|
318 |
+
resize_methods = sorted([k for k in tf.image.ResizeMethod.__dict__.keys() if k.isupper()])
|
319 |
+
img = apply_with_random_selector(
|
320 |
+
_image,
|
321 |
+
lambda x, method_idx: tf.image.resize(x, sz,
|
322 |
+
tf.image.ResizeMethod.__dict__[resize_methods[method_idx]],
|
323 |
+
antialias=True),
|
324 |
+
num_cases=len(resize_methods))
|
325 |
+
return img, tf.ones(tf.shape(img)[:2], dtype=tf.bool)
|
326 |
+
elif self.resize in "siglip":
|
327 |
+
assert not do_random_scale
|
328 |
+
|
329 |
+
def _resize(_image, sz):
|
330 |
+
img = siglip_resize(_image, sz, truncate=None)
|
331 |
+
return img, tf.ones(tf.shape(img)[:2], dtype=tf.bool)
|
332 |
+
else:
|
333 |
+
raise NotImplementedError(self.resize)
|
334 |
+
|
335 |
+
def _img_to_patches(_img, _img_mask, dy=1, dx=1):
|
336 |
+
_img = einops.rearrange(
|
337 |
+
_img, '(dy h dh) (dx w dw) c -> (dy dx) (h w) (dh dw c)',
|
338 |
+
dh=base_image_input_d,
|
339 |
+
dw=base_image_input_d,
|
340 |
+
dy=dy,
|
341 |
+
dx=dx,
|
342 |
+
h=image_base_patch_h,
|
343 |
+
w=image_base_patch_w
|
344 |
+
)
|
345 |
+
_img_mask = einops.rearrange(
|
346 |
+
_img_mask, '(dy h dh) (dx w dw) -> (dy dx) (h w) (dh dw)',
|
347 |
+
dh=base_image_input_d,
|
348 |
+
dw=base_image_input_d,
|
349 |
+
dy=dy,
|
350 |
+
dx=dx,
|
351 |
+
h=image_base_patch_h,
|
352 |
+
w=image_base_patch_w
|
353 |
+
)
|
354 |
+
return _img, tf.reduce_mean(tf.cast(_img_mask, tf.float32), -1)
|
355 |
+
|
356 |
+
mode = self.crop_mode
|
357 |
+
if mode == "resize":
|
358 |
+
patches, img_mask = _resize(image, base_image_input_size)
|
359 |
+
patches, img_mask = _img_to_patches(patches, img_mask)
|
360 |
+
image_layout_impatch_w = 1
|
361 |
+
image_layout_impatch_h = 1
|
362 |
+
patch_ordering = tf.range(tokens_per_image)[None, :]
|
363 |
+
|
364 |
+
elif mode in ["overlap", "overlap-and-resize-c2"]:
|
365 |
+
original_image_h = tf.shape(image, out_type=tf.int32)[0]
|
366 |
+
original_image_w = tf.shape(image, out_type=tf.int32)[1]
|
367 |
+
crop_size = base_image_input_size[0]
|
368 |
+
|
369 |
+
# Discard this many patches from the (left/top, right/bottom) of crops
|
370 |
+
left_margin, right_margin = self.overlap_margins
|
371 |
+
# left_margin, right_margin = 2, 2
|
372 |
+
assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
|
373 |
+
total_margin_pixels = base_image_input_d*(right_margin + left_margin) # pixels removed per dim
|
374 |
+
crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
|
375 |
+
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
|
376 |
+
crop_window_size = crop_window_patches * base_image_input_d
|
377 |
+
tiling = select_tiling(original_image_h - total_margin_pixels, original_image_w - total_margin_pixels,
|
378 |
+
crop_window_size, self.max_crops)
|
379 |
+
src, img_mask = _resize(
|
380 |
+
image, [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels])
|
381 |
+
|
382 |
+
n_crops = tiling[0]*tiling[1]
|
383 |
+
patches_arr = tf.TensorArray(
|
384 |
+
tf.float32, n_crops, element_shape=[crop_size, crop_size, 3])
|
385 |
+
mask_arr = tf.TensorArray(
|
386 |
+
tf.bool, n_crops, element_shape=[crop_size, crop_size])
|
387 |
+
# We assume 2x2 pooling, but can allow padding the right/bottom with extra
|
388 |
+
# patches if the number of patches per side is not even
|
389 |
+
assert (crop_patches+1)//2 == image_token_length_h
|
390 |
+
assert (crop_patches+1)//2 == image_token_length_w
|
391 |
+
patch_ordering_arr = tf.TensorArray(
|
392 |
+
tf.int32, n_crops, element_shape=[image_token_length_h, image_token_length_w])
|
393 |
+
on = 0
|
394 |
+
on_patch = 0
|
395 |
+
for i in range(tiling[0]):
|
396 |
+
y0 = i*crop_window_size
|
397 |
+
if i == 0:
|
398 |
+
crop_y0 = 0
|
399 |
+
else:
|
400 |
+
crop_y0 = left_margin // 2
|
401 |
+
|
402 |
+
crop_h = image_base_patch_h - (right_margin + left_margin)
|
403 |
+
if i == 0:
|
404 |
+
crop_h += left_margin
|
405 |
+
if i == (tiling[0]-1):
|
406 |
+
crop_h += right_margin
|
407 |
+
for j in range(tiling[1]):
|
408 |
+
x0 = j*crop_window_size
|
409 |
+
if j == 0:
|
410 |
+
crop_x0 = 0
|
411 |
+
else:
|
412 |
+
crop_x0 = left_margin // 2
|
413 |
+
|
414 |
+
crop_w = image_base_patch_w - (right_margin + left_margin)
|
415 |
+
if j == 0:
|
416 |
+
crop_w += left_margin
|
417 |
+
if j == (tiling[1]-1):
|
418 |
+
crop_w += right_margin
|
419 |
+
|
420 |
+
pooled_w = (crop_w + 1) // 2
|
421 |
+
pooled_h = (crop_h + 1) // 2
|
422 |
+
patch_ordering_arr = patch_ordering_arr.write(
|
423 |
+
on_patch,
|
424 |
+
pad_to_bounding_box(
|
425 |
+
tf.reshape(tf.range(on, on+pooled_h*pooled_w, dtype=tf.int32), (pooled_h, pooled_w, 1)),
|
426 |
+
crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
|
427 |
+
)[:, :, 0]
|
428 |
+
)
|
429 |
+
patches_arr = patches_arr.write(on_patch, src[y0:y0+crop_size, x0:x0+crop_size])
|
430 |
+
mask_arr = mask_arr.write(on_patch, img_mask[y0:y0+crop_size, x0:x0+crop_size])
|
431 |
+
|
432 |
+
on += pooled_h*pooled_w
|
433 |
+
on_patch += 1
|
434 |
+
patches = patches_arr.stack()
|
435 |
+
patch_ordering = patch_ordering_arr.stack()
|
436 |
+
img_mask = mask_arr.stack()
|
437 |
+
|
438 |
+
image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
|
439 |
+
patches = einops.rearrange(
|
440 |
+
patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
|
441 |
+
dh=base_image_input_d,
|
442 |
+
dw=base_image_input_d,
|
443 |
+
h=image_base_patch_h,
|
444 |
+
w=image_base_patch_w
|
445 |
+
)
|
446 |
+
img_mask = einops.rearrange(
|
447 |
+
img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
|
448 |
+
dh=base_image_input_d,
|
449 |
+
dw=base_image_input_d,
|
450 |
+
h=image_base_patch_h,
|
451 |
+
w=image_base_patch_w
|
452 |
+
)
|
453 |
+
img_mask = tf.reduce_mean(tf.cast(img_mask, tf.float32), -1)
|
454 |
+
patch_ordering = tf.reshape(patch_ordering, [-1])
|
455 |
+
valid = patch_ordering >= 0
|
456 |
+
|
457 |
+
# Transpose, to get left-to-right order
|
458 |
+
patch_ordering_rh = tf.reshape(patch_ordering,
|
459 |
+
[tiling[0], tiling[1], image_token_length_h, image_token_length_w])
|
460 |
+
patch_ordering_rh = tf.transpose(patch_ordering_rh, [0, 2, 1, 3])
|
461 |
+
patch_ordering_rh = tf.reshape(patch_ordering_rh, [-1])
|
462 |
+
|
463 |
+
# The tranpose will screw up which patches are masked, project the
|
464 |
+
# new order into sparse structure of `patch_ordering` to fix this
|
465 |
+
patch_ordering = tf.tensor_scatter_nd_update(
|
466 |
+
patch_ordering,
|
467 |
+
tf.where(valid),
|
468 |
+
tf.boolean_mask(patch_ordering_rh, patch_ordering_rh >= 0),
|
469 |
+
name="patch_order_transpose_Scatter"
|
470 |
+
)
|
471 |
+
|
472 |
+
h = tiling[0]*crop_window_patches + (right_margin+left_margin)
|
473 |
+
w = tiling[1]*crop_window_patches + (right_margin+left_margin)
|
474 |
+
special_token_ids = self.special_token_ids
|
475 |
+
per_row = tf.fill(((w+1)//2,),
|
476 |
+
special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
|
477 |
+
if self.use_col_tokens:
|
478 |
+
per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
|
479 |
+
|
480 |
+
joint = tf.tile(per_row, [(h+1)//2])
|
481 |
+
joint = [
|
482 |
+
[special_token_ids[config.DEFAULT_IM_START_TOKEN]],
|
483 |
+
joint,
|
484 |
+
[special_token_ids[config.DEFAULT_IM_END_TOKEN]]
|
485 |
+
]
|
486 |
+
|
487 |
+
if "resize" in mode:
|
488 |
+
resized, resized_mask = _resize(image, base_image_input_size)
|
489 |
+
resized, resized_mask = _img_to_patches(resized, resized_mask)
|
490 |
+
if 'c2' in mode:
|
491 |
+
patches = tf.concat([resized, patches], 0)
|
492 |
+
image_mask = tf.concat([resized_mask, img_mask], 0)
|
493 |
+
else:
|
494 |
+
patches = tf.concat([patches, resized], 0)
|
495 |
+
image_mask = tf.concat([img_mask, resized_mask], 0)
|
496 |
+
|
497 |
+
if patch_ordering is not None:
|
498 |
+
if 'c2' in mode:
|
499 |
+
patch_ordering = tf.where(
|
500 |
+
patch_ordering >= 0,
|
501 |
+
patch_ordering + tokens_per_image,
|
502 |
+
-1
|
503 |
+
)
|
504 |
+
patch_ordering = tf.concat([tf.range(0, tokens_per_image), patch_ordering], 0)
|
505 |
+
else:
|
506 |
+
raise ValueError()
|
507 |
+
per_row = tf.fill((image_token_length_w,), special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
|
508 |
+
if self.use_col_tokens:
|
509 |
+
per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
|
510 |
+
extra_tokens = tf.tile(per_row, [image_token_length_h])
|
511 |
+
joint = [
|
512 |
+
[special_token_ids[config.DEFAULT_IM_START_TOKEN]],
|
513 |
+
extra_tokens,
|
514 |
+
[special_token_ids[config.DEFAULT_IM_END_TOKEN]],
|
515 |
+
] + joint
|
516 |
+
|
517 |
+
joint = tf.concat(joint, 0)
|
518 |
+
return patches, joint, patch_ordering, img_mask
|
519 |
+
|
520 |
+
elif mode in ["patchify", "patchify-and-resize", "patchify-v2", "patchify-v2-and-resize", "patchify-v2-and-resize-c2"]:
|
521 |
+
original_image_w = tf.shape(image, out_type=tf.int32)[0]
|
522 |
+
original_image_h = tf.shape(image, out_type=tf.int32)[1]
|
523 |
+
assert base_image_input_size[0] == base_image_input_size[1]
|
524 |
+
base_patch_size = base_image_input_size[0]
|
525 |
+
tiling = select_tiling(original_image_w, original_image_h, base_patch_size, self.max_crops)
|
526 |
+
|
527 |
+
patches, img_mask = _resize(
|
528 |
+
image, [tiling[0]*base_patch_size, tiling[1]*base_patch_size])
|
529 |
+
patches, img_mask = _img_to_patches(patches, img_mask, tiling[0], tiling[1])
|
530 |
+
if 'v2' in mode:
|
531 |
+
# Order patches left-to-right not crop-by-crop
|
532 |
+
patch_ordering = tf.reshape(
|
533 |
+
tf.range(tokens_per_image*tiling[0]*tiling[1]),
|
534 |
+
[tiling[0], tiling[1], image_token_length_w, image_token_length_h])
|
535 |
+
patch_ordering = tf.transpose(patch_ordering, [0, 2, 1, 3])
|
536 |
+
patch_ordering = tf.reshape(patch_ordering, (-1, tokens_per_image))
|
537 |
+
else:
|
538 |
+
patch_ordering = None
|
539 |
+
|
540 |
+
# given image size, determine the number of patch size.
|
541 |
+
image_layout_impatch_w = tiling[0]
|
542 |
+
image_layout_impatch_h = tiling[1]
|
543 |
+
|
544 |
+
if "resize" in mode:
|
545 |
+
extra_image = True
|
546 |
+
resized, resized_mask = _resize(image, base_image_input_size)
|
547 |
+
resized, resized_mask = _img_to_patches(resized, resized_mask)
|
548 |
+
if 'c2' in mode:
|
549 |
+
patches = tf.concat([resized, patches], 0)
|
550 |
+
image_mask = tf.concat([resized_mask, img_mask], 0)
|
551 |
+
else:
|
552 |
+
patches = tf.concat([patches, resized], 0)
|
553 |
+
image_mask = tf.concat([img_mask, resized_mask], 0)
|
554 |
+
|
555 |
+
if patch_ordering is not None:
|
556 |
+
if 'c2' in mode:
|
557 |
+
patch_ordering = tf.concat(
|
558 |
+
[tf.range(0, tokens_per_image)[None, :], patch_ordering+tokens_per_image], 0)
|
559 |
+
else:
|
560 |
+
n = tf.shape(patch_ordering)[0]
|
561 |
+
patch_ordering = tf.concat(patch_ordering, [tf.range(n, n+tokens_per_image)[None, :]], 0)
|
562 |
+
else:
|
563 |
+
raise NotImplementedError(mode)
|
564 |
+
|
565 |
+
special_token_ids = self.special_token_ids
|
566 |
+
|
567 |
+
per_row = tf.fill((image_token_length_w*image_layout_impatch_w,),
|
568 |
+
special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
|
569 |
+
if self.use_col_tokens:
|
570 |
+
per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
|
571 |
+
|
572 |
+
joint = tf.tile(per_row, [image_token_length_h * image_layout_impatch_h])
|
573 |
+
joint = [
|
574 |
+
[special_token_ids[config.DEFAULT_IM_START_TOKEN]],
|
575 |
+
joint,
|
576 |
+
[special_token_ids[config.DEFAULT_IM_END_TOKEN]]
|
577 |
+
]
|
578 |
+
if extra_image:
|
579 |
+
assert not self.image_padding_mask
|
580 |
+
per_row = tf.fill((image_token_length_w,), special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN],)
|
581 |
+
if self.use_col_tokens:
|
582 |
+
per_row = tf.concat([per_row, [special_token_ids[config.DEFAULT_IM_COL_TOKEN]]], 0)
|
583 |
+
extra_tokens = tf.tile(per_row, [image_token_length_h])
|
584 |
+
if 'c2' in mode:
|
585 |
+
joint = [
|
586 |
+
[special_token_ids[config.DEFAULT_IM_START_TOKEN]],
|
587 |
+
extra_tokens,
|
588 |
+
[special_token_ids[config.DEFAULT_IM_END_TOKEN]],
|
589 |
+
] + joint
|
590 |
+
else:
|
591 |
+
joint += [
|
592 |
+
[special_token_ids[config.DEFAULT_IM_START_TOKEN]],
|
593 |
+
extra_tokens,
|
594 |
+
[special_token_ids[config.DEFAULT_IM_END_TOKEN]]
|
595 |
+
]
|
596 |
+
if self.pad_to is not None:
|
597 |
+
n = [tf.shape(x)[0] for x in joint]
|
598 |
+
assert len(joint[-1]) == 1
|
599 |
+
to_pad = self.pad_to - tf.reduce_sum(tf.stack(n))
|
600 |
+
joint = tf.concat(joint[:-1] + [
|
601 |
+
tf.zeros(to_pad, dtype=tf.int32) - 1,
|
602 |
+
joint[-1]
|
603 |
+
], axis=0)
|
604 |
+
else:
|
605 |
+
joint = tf.concat(joint, 0)
|
606 |
+
return patches, tf.concat(joint, 0), patch_ordering, img_mask
|
607 |
+
|
608 |
+
def build_image_input_idx(self, input_tokens, patch_order, no_image=None):
|
609 |
+
"""Builds the index used to insert patch features into `input_tokens`"""
|
610 |
+
tokens_per_image = self.image_token_length_w * self.image_token_length_h
|
611 |
+
if no_image is not None and no_image:
|
612 |
+
return tf.zeros((0, tokens_per_image), tf.int32)
|
613 |
+
|
614 |
+
image_input_idx = input_tokens == self.special_token_ids[config.DEFAULT_IMAGE_PATCH_TOKEN]
|
615 |
+
image_input_idx = tf.experimental.numpy.nonzero(image_input_idx)[0]
|
616 |
+
image_input_idx = tf.cast(image_input_idx, tf.int32)
|
617 |
+
|
618 |
+
if patch_order is not None:
|
619 |
+
n_tokens = tf.shape(image_input_idx)[0]
|
620 |
+
# Item N should have the value of image_input_index[where(patch_order == n)] if >= 0 else -1
|
621 |
+
patch_order = tf.reshape(patch_order, [-1])
|
622 |
+
n_patches = tf.shape(patch_order)[0]
|
623 |
+
if n_tokens != n_patches:
|
624 |
+
# Most complex case where some patches are dropped
|
625 |
+
# First invert the valid tokens
|
626 |
+
valid = patch_order >= 0
|
627 |
+
sorted_patch_ixs = tf.scatter_nd(
|
628 |
+
tf.boolean_mask(patch_order, valid)[:, None],
|
629 |
+
tf.range(tf.reduce_sum(tf.cast(valid, tf.int32)), dtype=tf.int32),
|
630 |
+
[n_tokens],
|
631 |
+
name="valid_order_scatter"
|
632 |
+
)
|
633 |
+
|
634 |
+
# Project the inverted mapping into same sparse structure
|
635 |
+
tmp = tf.fill(tf.shape(patch_order), -1)
|
636 |
+
sorted_patch_ixs_ex = tf.tensor_scatter_nd_update(
|
637 |
+
tmp,
|
638 |
+
tf.where(valid),
|
639 |
+
sorted_patch_ixs,
|
640 |
+
name="order_with_padding_scatter"
|
641 |
+
)
|
642 |
+
|
643 |
+
# Do the gather and then re-masked outputs that were masked in `sorted_patch_ixs`
|
644 |
+
valid = tf.cast(sorted_patch_ixs_ex >= 0, tf.int32)
|
645 |
+
image_input_idx = tf.gather(image_input_idx, sorted_patch_ixs_ex*valid)
|
646 |
+
image_input_idx = image_input_idx*valid - 100*(1 - valid)
|
647 |
+
else:
|
648 |
+
sorted_patch_ixs = tf.scatter_nd(patch_order[:, None], tf.range(n_patches), [n_patches])
|
649 |
+
image_input_idx = tf.gather(tf.reshape(image_input_idx, [-1]), sorted_patch_ixs)
|
650 |
+
image_input_idx = tf.reshape(image_input_idx, [-1, tokens_per_image])
|
651 |
+
return image_input_idx
|
652 |
+
|
653 |
+
def build_multimodel_features(self, tokens, mask, subsegments, images, is_training):
|
654 |
+
"""Builds input features by pre-processing `images` and modifying `tokens`
|
655 |
+
to include image col/pad/start/end tokens instead image placeholder tokens
|
656 |
+
"""
|
657 |
+
image_token_id = self.special_token_ids[config.IMAGE_PROMPT]
|
658 |
+
image_idx = tf.experimental.numpy.nonzero(tokens == image_token_id)[0]
|
659 |
+
if images is None or tf.shape(images)[0] == 0:
|
660 |
+
tf.debugging.assert_equal(image_idx, tf.cast(0, tf.int64),
|
661 |
+
"Image placeholders in input, but no images given!")
|
662 |
+
tokens_per_image = self.image_token_length_w * self.image_token_length_h
|
663 |
+
n_pixels = self.image_patch_size ** 2 * 3
|
664 |
+
image_num_patch = np.prod(self.image_num_patch)
|
665 |
+
crops = tf.zeros((0, image_num_patch, n_pixels), dtype=tf.float32)
|
666 |
+
image_idx = tf.zeros((0, tokens_per_image), tf.int32)
|
667 |
+
out = dict(
|
668 |
+
target_tokens=tokens,
|
669 |
+
images=crops,
|
670 |
+
image_input_idx=image_idx,
|
671 |
+
loss_masks=mask
|
672 |
+
)
|
673 |
+
if self.image_padding_mask:
|
674 |
+
out["image_masks"] = tf.zeros((0, image_num_patch), dtype=tf.float32)
|
675 |
+
if subsegments is not None:
|
676 |
+
out["subsegment_ids"] = subsegments
|
677 |
+
return out
|
678 |
+
elif tf.shape(image_idx)[0] == 0 and tf.shape(images)[0] > 0:
|
679 |
+
# As a special case, no image prompt means the images are all at the start
|
680 |
+
image_idx = tf.zeros([tf.shape(images)[0]], tf.int64) - 1
|
681 |
+
else:
|
682 |
+
tf.debugging.assert_equal(
|
683 |
+
tf.shape(images)[0], tf.shape(image_idx)[0],
|
684 |
+
message="Different number of images and image placeholders")
|
685 |
+
|
686 |
+
# Each image will produce a variable number of crops/tokens, so we aggregate things
|
687 |
+
# the results tensor arrays and the concat them
|
688 |
+
tokens_per_image = self.image_token_length_w * self.image_token_length_h
|
689 |
+
n_pixels = self.image_patch_size*self.image_patch_size*3
|
690 |
+
n_patches = self.image_num_patch[0]*self.image_num_patch[1]
|
691 |
+
|
692 |
+
n = tf.shape(images)[0]
|
693 |
+
all_crops = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
|
694 |
+
element_shape=[None, n_patches, n_pixels])
|
695 |
+
all_image_idx = tf.TensorArray(dtype=tf.int32, size=n, infer_shape=False,
|
696 |
+
element_shape=[None, tokens_per_image])
|
697 |
+
out_tokens = tf.TensorArray(dtype=tf.int32, size=n, infer_shape=False,
|
698 |
+
element_shape=[None])
|
699 |
+
out_masks = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
|
700 |
+
element_shape=[None])
|
701 |
+
if self.image_padding_mask:
|
702 |
+
all_crop_masks = tf.TensorArray(dtype=tf.float32, size=n, infer_shape=False,
|
703 |
+
element_shape=[None, None])
|
704 |
+
else:
|
705 |
+
# Dummy array to keep tensorflow's control analysis happy
|
706 |
+
all_crop_masks = tf.TensorArray(dtype=tf.float32, size=0, infer_shape=False,
|
707 |
+
element_shape=[None, None])
|
708 |
+
if subsegments is not None:
|
709 |
+
out_subsegments = tf.TensorArray(dtype=tf.int32, size=n, element_shape=[None])
|
710 |
+
else:
|
711 |
+
out_subsegments = tf.TensorArray(dtype=tf.int32, size=0, element_shape=[None])
|
712 |
+
|
713 |
+
image_idx = tf.cast(image_idx, tf.int32)
|
714 |
+
for ix in range(tf.shape(image_idx)[0]):
|
715 |
+
token_ix = image_idx[ix]
|
716 |
+
crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(images[ix], is_training)
|
717 |
+
patch_idx = self.build_image_input_idx(image_tokens, patch_ordering)
|
718 |
+
|
719 |
+
if token_ix == -1: # -1 is an image inserted at the very start
|
720 |
+
start = 0
|
721 |
+
token_ix = 0
|
722 |
+
end = 0
|
723 |
+
else:
|
724 |
+
start = 0 if ix == 0 else image_idx[ix-1] + 1
|
725 |
+
end = token_ix + 1
|
726 |
+
|
727 |
+
all_image_idx = all_image_idx.write(ix, patch_idx + token_ix)
|
728 |
+
all_crops = all_crops.write(ix, crops)
|
729 |
+
image_token_mask = tf.zeros_like(image_tokens, dtype=tf.float32)
|
730 |
+
|
731 |
+
if ix == (tf.shape(images)[0] - 1):
|
732 |
+
tokens_part = tf.concat([tokens[start:token_ix], image_tokens, tokens[end:]], 0)
|
733 |
+
mask_part = tf.concat([mask[start:token_ix], image_token_mask, mask[end:]], 0)
|
734 |
+
else:
|
735 |
+
tokens_part = tf.concat([tokens[start:token_ix], image_tokens], 0)
|
736 |
+
mask_part = tf.concat([mask[start:token_ix], image_token_mask], 0)
|
737 |
+
|
738 |
+
out_tokens = out_tokens.write(ix, tokens_part)
|
739 |
+
out_masks = out_masks.write(ix, mask_part)
|
740 |
+
if self.image_padding_mask:
|
741 |
+
all_crop_masks = all_crop_masks.write(ix, img_mask)
|
742 |
+
if subsegments is not None:
|
743 |
+
parts = tf.fill([tf.shape(image_tokens)[0]], subsegments[token_ix])
|
744 |
+
if ix == (tf.shape(images)[0] - 1):
|
745 |
+
seg = tf.concat([subsegments[start:token_ix], parts, subsegments[end:]], 0)
|
746 |
+
else:
|
747 |
+
seg = tf.concat([subsegments[start:token_ix], parts], 0)
|
748 |
+
out_subsegments = out_subsegments.write(ix, seg)
|
749 |
+
|
750 |
+
out = dict(
|
751 |
+
target_tokens=out_tokens.concat(),
|
752 |
+
images=all_crops.concat(),
|
753 |
+
image_input_idx=all_image_idx.concat(),
|
754 |
+
loss_masks=out_masks.concat()
|
755 |
+
)
|
756 |
+
if self.image_padding_mask:
|
757 |
+
out["image_masks"] = all_crop_masks.concat()
|
758 |
+
if subsegments is not None:
|
759 |
+
out["subsegment_ids"] = out_subsegments.concat()
|
760 |
+
return out
|
761 |
+
|
762 |
+
def _format_message(self, args):
|
763 |
+
message, ix = args
|
764 |
+
return self.format_message(message, ix)
|
765 |
+
|
766 |
+
def format_message(self, message, ix):
|
767 |
+
"""Applies system formatting to ith message from a sequence of messages"""
|
768 |
+
# If the image placeholder text is not preceded by space it will not get tokenized
|
769 |
+
# correctly by some tokenizers, so double check it here
|
770 |
+
assert config.IMAGE_PROMPT == "<|image|>"
|
771 |
+
tf.debugging.assert_equal(
|
772 |
+
tf.strings.regex_full_match(message, r".*[^ ]<\|image\|>.*"),
|
773 |
+
False,
|
774 |
+
message="Image token must always be preceded by a space"
|
775 |
+
)
|
776 |
+
is_user = ix % 2 == 0
|
777 |
+
if self.message_format == "none" or self.message_format is None:
|
778 |
+
pass
|
779 |
+
elif self.message_format == "role":
|
780 |
+
if is_user:
|
781 |
+
# We put the "System:" prefix here since it doesn't need a loss
|
782 |
+
message = tf.strings.join(["User: ", message, " Assistant:"])
|
783 |
+
elif self.message_format == "cleanup":
|
784 |
+
if is_user:
|
785 |
+
# We put the "System:" prefix here since it doesn't need a loss
|
786 |
+
message = tf.strings.join(
|
787 |
+
[
|
788 |
+
"[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} ",
|
789 |
+
message,
|
790 |
+
"\n[[Assistant]]: {after}"
|
791 |
+
]
|
792 |
+
)
|
793 |
+
elif self.message_format == "mistral":
|
794 |
+
if is_user:
|
795 |
+
message = tf.strings.join(["[INST] ", message, " [/INST]"])
|
796 |
+
else:
|
797 |
+
raise NotImplementedError(self.message_format)
|
798 |
+
|
799 |
+
# For now assume a space will be used to separate the messages
|
800 |
+
if not self.tokenizer.adds_space:
|
801 |
+
if ix != 0 or self.always_start_with_space:
|
802 |
+
message = tf.strings.join([" ", message])
|
803 |
+
# Else space added automatically by the tokenizer
|
804 |
+
|
805 |
+
return message
|
806 |
+
|
807 |
+
def get_multi_message_token_input(self, conversations, text_weights=None):
|
808 |
+
"""Build inputs for a ragged tensor of conversations, where each row of the tensor,
|
809 |
+
is a different conversation"""
|
810 |
+
tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(
|
811 |
+
conversations.values, re.escape(config.IMAGE_PROMPT))), False, "Segmented prompts must start with the image")
|
812 |
+
|
813 |
+
n_conversation = tf.shape(conversations)[0]
|
814 |
+
ar = tf.TensorArray(dtype=tf.int32, infer_shape=False, element_shape=[None],
|
815 |
+
size=n_conversation)
|
816 |
+
n_messages_per_conversation = conversations.row_lengths()
|
817 |
+
for ix in range(n_conversation):
|
818 |
+
ar = ar.write(ix, tf.range(n_messages_per_conversation[ix], dtype=tf.int32))
|
819 |
+
message_ix = ar.concat()
|
820 |
+
messages = tf.map_fn(
|
821 |
+
self._format_message, elems=(conversations.values, message_ix), fn_output_signature=tf.string)
|
822 |
+
messages = self.tokenizer.encode_tf(messages)
|
823 |
+
|
824 |
+
# Append EOS
|
825 |
+
is_response = message_ix % 2 == 1
|
826 |
+
is_response_int = tf.cast(is_response, tf.int32)
|
827 |
+
eos = tf.RaggedTensor.from_row_lengths(
|
828 |
+
tf.fill([tf.reduce_sum(is_response_int)], self.tokenizer.eos_token_id),
|
829 |
+
tf.cast(is_response_int, messages.row_splits.dtype)
|
830 |
+
)
|
831 |
+
messages = tf.concat([messages, eos], axis=1)
|
832 |
+
|
833 |
+
# Build mask over system responses
|
834 |
+
mask = tf.ones_like(messages) * tf.cast(tf.expand_dims(is_response, axis=1), tf.int32)
|
835 |
+
decoder_loss_weights = tf.cast(mask.values, tf.float32)
|
836 |
+
|
837 |
+
# Build subsegment ids for each conversation
|
838 |
+
tokens_per_message = tf.RaggedTensor.from_row_splits(
|
839 |
+
row_splits=conversations.row_splits,
|
840 |
+
values=messages.row_lengths()
|
841 |
+
)
|
842 |
+
token_per_conversation = tf.reduce_sum(tokens_per_message, axis=1)
|
843 |
+
subsegment_ids = tf.repeat(tf.range(n_conversation, dtype=tf.int32)+1, token_per_conversation)
|
844 |
+
|
845 |
+
image_ix = self.special_token_ids[config.IMAGE_PROMPT]
|
846 |
+
messages = tf.concat([[image_ix], messages.values], axis=0)
|
847 |
+
decoder_loss_weights = tf.concat([[0], decoder_loss_weights], axis=0)
|
848 |
+
subsegment_ids = tf.concat([[10000], subsegment_ids], axis=0)
|
849 |
+
return messages, decoder_loss_weights, subsegment_ids
|
850 |
+
|
851 |
+
def get_multi_response_token_input(self, user_prompt, text, text_weights=None):
|
852 |
+
"""Build tokens for a multi-response-per-image example"""
|
853 |
+
# FIXME this could be relaxed to just having the same prefix
|
854 |
+
tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(
|
855 |
+
user_prompt, re.escape(config.IMAGE_PROMPT))), False, "Segmented prompts must start with the image")
|
856 |
+
user_prompt = self.format_message(user_prompt, 0)
|
857 |
+
vocab = self.tokenizer
|
858 |
+
prompts = vocab.encode_tf(user_prompt)
|
859 |
+
response = self.format_message(text, 1)
|
860 |
+
responses = vocab.encode_tf(response)
|
861 |
+
responses = _append_to_innermost_axis(responses, vocab.eos_token_id)
|
862 |
+
response_mask = tf.ones_like(responses, dtype=tf.float32)
|
863 |
+
if text_weights is not None:
|
864 |
+
response_mask *= text_weights
|
865 |
+
image_tokens = tf.constant([self.special_token_ids[config.IMAGE_PROMPT]])
|
866 |
+
|
867 |
+
if len(responses.shape) == 3:
|
868 |
+
# Tricky case where we have multiple questions, each of which has multiple answers
|
869 |
+
assert len(prompts.shape) == 2
|
870 |
+
|
871 |
+
# Also shift the last tokens to the response segment since that tokens will
|
872 |
+
# have multiple possible target tokens to predict
|
873 |
+
last_prompt_tokens = prompts[:, -1:]
|
874 |
+
last_prompt_tokens = tf.repeat(last_prompt_tokens, responses.row_lengths())
|
875 |
+
last_prompt_tokens = tf.RaggedTensor.from_row_splits(
|
876 |
+
values=tf.RaggedTensor.from_row_lengths(
|
877 |
+
values=last_prompt_tokens,
|
878 |
+
row_lengths=tf.ones_like(last_prompt_tokens, dtype=responses.row_splits.dtype)
|
879 |
+
),
|
880 |
+
row_splits=responses.row_splits
|
881 |
+
)
|
882 |
+
responses = tf.concat([last_prompt_tokens, responses], 2)
|
883 |
+
prompts = prompts[:, :-1]
|
884 |
+
|
885 |
+
shared_prefix = image_tokens
|
886 |
+
segmented_suffix = tf.concat([tf.expand_dims(prompts, 1), responses], 1)
|
887 |
+
targets = tf.concat([shared_prefix, segmented_suffix.values.values], 0)
|
888 |
+
|
889 |
+
segmented_mask = tf.concat([
|
890 |
+
tf.zeros_like(tf.expand_dims(prompts, 1), dtype=tf.float32),
|
891 |
+
tf.concat([
|
892 |
+
tf.zeros_like(last_prompt_tokens, dtype=tf.float32),
|
893 |
+
response_mask
|
894 |
+
], 2)
|
895 |
+
], 1).values.values
|
896 |
+
decoder_loss_weights = tf.concat(
|
897 |
+
[tf.zeros_like(shared_prefix, dtype=tf.float32), segmented_mask], 0)
|
898 |
+
|
899 |
+
text_segment_ids = get_3d_subsegments(segmented_suffix)
|
900 |
+
subsegment_ids = tf.concat([
|
901 |
+
tf.zeros_like(shared_prefix) + tf.reduce_max(text_segment_ids)+1,
|
902 |
+
text_segment_ids], 0)
|
903 |
+
subsegment_ids = tf.cast(subsegment_ids, tf.int32)
|
904 |
+
else:
|
905 |
+
if len(prompts.shape) == 1:
|
906 |
+
# One prompt for all responses, we use the last token of the prompt as the
|
907 |
+
# first token of each response segment since there will be multiple targets
|
908 |
+
# for that token, the remaining targets are part of the prefix
|
909 |
+
shared_prefix = tf.concat([image_tokens, prompts[:-1]], 0)
|
910 |
+
prompts = prompts[-1:]
|
911 |
+
prompts = tf.tile(tf.expand_dims(prompts, axis=0), [tf.shape(text)[0], 1])
|
912 |
+
else:
|
913 |
+
shared_prefix = image_tokens
|
914 |
+
|
915 |
+
# Separate prompt for each response
|
916 |
+
segmented_suffix = tf.concat([prompts, responses], 1)
|
917 |
+
segmented_mask = tf.concat([tf.zeros_like(prompts, dtype=tf.float32), response_mask], 1).values
|
918 |
+
|
919 |
+
targets = tf.concat([shared_prefix, segmented_suffix.values], 0)
|
920 |
+
decoder_loss_weights = tf.concat(
|
921 |
+
[tf.zeros_like(shared_prefix, dtype=tf.float32), segmented_mask], 0)
|
922 |
+
subsegments = tf.ragged.row_splits_to_segment_ids(segmented_suffix.row_splits) + 1
|
923 |
+
subsegment_ids = tf.concat([tf.zeros_like(shared_prefix)+10000,
|
924 |
+
tf.cast(subsegments, tf.int32)], 0)
|
925 |
+
return targets, decoder_loss_weights, subsegment_ids
|
926 |
+
|
927 |
+
def get_tokens_input(self, messages, for_inference=False, text_weights=None):
|
928 |
+
"""Gets the token input for an example, using image placeholder tokens to
|
929 |
+
indicate where images features should be inserted
|
930 |
+
|
931 |
+
inputs
|
932 |
+
messages: List or tensor users/system text messages, can have image placeholder tokens
|
933 |
+
for_inference: bool, if true truncate the messages if it is a system message
|
934 |
+
text_weights: Weights per a system message
|
935 |
+
|
936 |
+
returns
|
937 |
+
tokens: [n_tokens] tf.int32 token inputs with image placeholder tokens
|
938 |
+
loss_mask: [n_tokens] tf.float32 token weights for loss
|
939 |
+
subsegment: [n_tokens] tf.int32 or None, subsegment ids used to build more complex
|
940 |
+
attention masks if needed
|
941 |
+
"""
|
942 |
+
if isinstance(messages, tf.RaggedTensor):
|
943 |
+
assert not for_inference, "Cannot have multiple target messages for inference"
|
944 |
+
return self.get_multi_message_token_input(messages, text_weights)
|
945 |
+
elif len(tf.shape(messages[-1])) > 0:
|
946 |
+
assert not for_inference, "Cannot have multiple target messages for inference"
|
947 |
+
assert len(messages) == 2
|
948 |
+
prompt = messages[0]
|
949 |
+
response = messages[1]
|
950 |
+
return self.get_multi_response_token_input(prompt, response, text_weights)
|
951 |
+
else:
|
952 |
+
messages = tf.convert_to_tensor(messages)
|
953 |
+
if for_inference:
|
954 |
+
if tf.shape(messages) % 2 == 0:
|
955 |
+
# Remove the last message since the model should predict it
|
956 |
+
messages = messages[:-1]
|
957 |
+
|
958 |
+
# Apply system formatting
|
959 |
+
ix = tf.range(tf.shape(messages)[0])
|
960 |
+
is_response = ix % 2 == 1
|
961 |
+
messages = tf.map_fn(
|
962 |
+
self._format_message, elems=(messages, ix), fn_output_signature=tf.string)
|
963 |
+
|
964 |
+
# Tokenize
|
965 |
+
messages = self.tokenizer.encode_tf(messages)
|
966 |
+
|
967 |
+
# Add EOS to system messages
|
968 |
+
is_response_int = tf.cast(is_response, tf.int32)
|
969 |
+
eos = tf.RaggedTensor.from_row_lengths(
|
970 |
+
tf.fill([tf.reduce_sum(is_response_int)], self.tokenizer.eos_token_id),
|
971 |
+
tf.cast(is_response_int, messages.row_splits.dtype)
|
972 |
+
)
|
973 |
+
messages = tf.concat([messages, eos], axis=1)
|
974 |
+
targets = messages.values
|
975 |
+
|
976 |
+
# Build mask over system responses
|
977 |
+
mask = tf.ones_like(messages) * tf.cast(tf.expand_dims(is_response, axis=1), tf.int32)
|
978 |
+
decoder_loss_weights = tf.cast(mask.values, tf.float32)
|
979 |
+
if text_weights is not None:
|
980 |
+
decoder_loss_weights = decoder_loss_weights * text_weights
|
981 |
+
return messages.values, decoder_loss_weights, None
|
982 |
+
|
983 |
+
def preprocess(self, image, input_text, is_training=False,
|
984 |
+
seq_len=None, pad_images=1, style=None, for_inference=True):
|
985 |
+
"""Get input tensors for the given image/text data
|
986 |
+
|
987 |
+
image: [h, w, 3] numpy uint8 array of image pixels
|
988 |
+
input_text: string input text, a list of text for a multi-turn conversation or dictionary
|
989 |
+
of inputs to use to build the prompt from a template
|
990 |
+
is_training: allow training-time preprocessing (e.g., image augmentation)
|
991 |
+
seq_len: pad input tokens to `seq_len`
|
992 |
+
pad_images: pad input images to `self.get_max_total_crops()`
|
993 |
+
style: Style to use for prompt templating
|
994 |
+
"""
|
995 |
+
if image is not None and len(tf.shape(image)) == 3:
|
996 |
+
image = tf.expand_dims(image, axis=0)
|
997 |
+
|
998 |
+
messages = self.get_messages(input_text, style, is_training, for_inference=for_inference, user_prompt_seed=None, system_prompt_seed=None)
|
999 |
+
targets, loss_masks, subsegments = self.get_tokens_input(messages, for_inference=for_inference)
|
1000 |
+
batch = self.build_multimodel_features(
|
1001 |
+
targets, loss_masks, subsegments, image, is_training)
|
1002 |
+
|
1003 |
+
# Optionally padding to get constant sized arrays
|
1004 |
+
if pad_images:
|
1005 |
+
max_crops = self.get_max_total_crops() * pad_images
|
1006 |
+
image = batch["images"]
|
1007 |
+
n = max_crops - tf.shape(batch["images"])[0]
|
1008 |
+
batch["images"] = tf.pad(image, [[0, n], [0, 0], [0, 0]], constant_values=-1)
|
1009 |
+
if self.image_padding_mask:
|
1010 |
+
m = max_crops - tf.shape(batch["image_masks"])[0]
|
1011 |
+
batch["image_masks"] = tf.pad(batch["image_masks"], [[0, m], [0, 0]], constant_values=-1)
|
1012 |
+
batch["image_input_idx"] = tf.pad(batch["image_input_idx"], [[0, n], [0, 0]], constant_values=-1)
|
1013 |
+
|
1014 |
+
if seq_len is not None:
|
1015 |
+
targets = batch["target_tokens"]
|
1016 |
+
if seq_len < len(targets):
|
1017 |
+
raise ValueError("Sequence length too short")
|
1018 |
+
n = seq_len - len(targets)
|
1019 |
+
batch["target_tokens"] = tf.pad(targets, [[0, n]], constant_values=-1)
|
1020 |
+
batch["loss_masks"] = tf.pad(batch["loss_masks"], [[0, n]], constant_values=-1)
|
1021 |
+
|
1022 |
+
batch = self.get_post_mixing_preprocessor(pack=False)._convert_example(batch)
|
1023 |
+
return batch
|
1024 |
+
|
1025 |
+
def get_user_prompt(self, style, example, is_training=True, for_inference=False, seed=None):
|
1026 |
+
"""Build a list of strings of what a user might type in to the model for the given example,
|
1027 |
+
and its responses, by applying a prompt template to the fields in `example`
|
1028 |
+
|
1029 |
+
Can return multiple strings for one message for multi-response examples
|
1030 |
+
"""
|
1031 |
+
if "style" in example:
|
1032 |
+
style = example["style"]
|
1033 |
+
|
1034 |
+
if "prompt" in example:
|
1035 |
+
# Examples have a complete user prompt pre-specified, usually for eval sets
|
1036 |
+
prompt = example["prompt"]
|
1037 |
+
|
1038 |
+
elif self.prompt_templates == "none":
|
1039 |
+
# Bare-bone prompt with not templating of instructions
|
1040 |
+
if "prompt" in example:
|
1041 |
+
prompt = example["prompt"]
|
1042 |
+
elif "refexp" in example:
|
1043 |
+
prompt = example["refexp"]
|
1044 |
+
elif "question" in example and "options" in example:
|
1045 |
+
prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
|
1046 |
+
elif "question" in example:
|
1047 |
+
prompt = example["question"]
|
1048 |
+
else:
|
1049 |
+
prompt = ""
|
1050 |
+
|
1051 |
+
elif self.prompt_templates == "uber_model":
|
1052 |
+
if not isinstance(style, str):
|
1053 |
+
tf.debugging.assert_equal(tf.logical_or(
|
1054 |
+
style == "ai2_diagram_no_letter",
|
1055 |
+
style == "ai2_diagram",
|
1056 |
+
), True)
|
1057 |
+
prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
|
1058 |
+
else:
|
1059 |
+
# We template long captions and pointing since they are "demo" tasks, and use
|
1060 |
+
# plain text for everything else
|
1061 |
+
if style == "long_caption":
|
1062 |
+
prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], example, seed)
|
1063 |
+
elif style == "pointing":
|
1064 |
+
prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["pointing"], example, seed)
|
1065 |
+
elif style == "point_count":
|
1066 |
+
prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["point_count"], example, seed)
|
1067 |
+
elif "prompt" in example:
|
1068 |
+
prompt = example["prompt"]
|
1069 |
+
elif "refexp" in example:
|
1070 |
+
prompt = example["refexp"]
|
1071 |
+
elif "question" in example and "options" in example:
|
1072 |
+
prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
|
1073 |
+
elif "question" in example:
|
1074 |
+
prompt = example["question"]
|
1075 |
+
else:
|
1076 |
+
prompt = ""
|
1077 |
+
|
1078 |
+
elif self.prompt_templates == "uber_model_pointing":
|
1079 |
+
if style == "long_caption":
|
1080 |
+
long_captions = GENERAL_PROMPTS_V1["long_caption_no_pointing"]
|
1081 |
+
prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], example, seed)
|
1082 |
+
elif style == "pointing":
|
1083 |
+
prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1["pointing"], example, seed)
|
1084 |
+
elif style in [
|
1085 |
+
"scifi_charts_explanation",
|
1086 |
+
"scifi_table_explanation",
|
1087 |
+
"scifi_document_explanation",
|
1088 |
+
"scifi_diagram_explanation",
|
1089 |
+
"user_qa",
|
1090 |
+
"long_caption",
|
1091 |
+
]:
|
1092 |
+
raise NotImplementedError()
|
1093 |
+
if style == "long_caption":
|
1094 |
+
prompts = GENERAL_PROMPTS_V1["long_caption"]
|
1095 |
+
elif "prompt" in example:
|
1096 |
+
prompts = tf.expand_dims(example["prompt"], axis=0)
|
1097 |
+
else:
|
1098 |
+
prompts = tf.expand_dims(example["question"], axis=0)
|
1099 |
+
suffixes = []
|
1100 |
+
for suffix in GENERAL_PROMPTS_V1["no_pointing_suffix"]:
|
1101 |
+
if not suffix[0].isspace():
|
1102 |
+
suffix = " " + suffix
|
1103 |
+
suffixes.append(suffix)
|
1104 |
+
no_point_prompts = tf.reshape(tf.strings.join([
|
1105 |
+
tf.tile(tf.expand_dims(suffixes, 1), [1, tf.shape(prompts)[1]]),
|
1106 |
+
tf.tile(prompts, [len(suffixes), 1]),
|
1107 |
+
]), [-1])
|
1108 |
+
# prefixes = []
|
1109 |
+
# for prefix in GENERAL_PROMPTS_V1["no_pointing_prefix"]:
|
1110 |
+
# if not prefix[0].isspace():
|
1111 |
+
# prefix = prefix + " "
|
1112 |
+
# prefixes.append(prompts + prefix)
|
1113 |
+
prompt = apply_keyword_prompt(no_point_prompts, example, seed, keywords=[])
|
1114 |
+
elif "prompt" in example:
|
1115 |
+
prompt = example["prompt"]
|
1116 |
+
elif "refexp" in example:
|
1117 |
+
prompt = example["refexp"]
|
1118 |
+
elif "question" in example and "options" in example:
|
1119 |
+
prompt = tf.strings.join([example["question"], "\n", example["options"], "\n"])
|
1120 |
+
elif "question" in example:
|
1121 |
+
prompt = example["question"]
|
1122 |
+
else:
|
1123 |
+
prompt = ""
|
1124 |
+
|
1125 |
+
elif self.prompt_templates == "general_instructions_v1":
|
1126 |
+
if isinstance(style, str):
|
1127 |
+
prompt = apply_keyword_prompt(GENERAL_PROMPTS_V1[STYLE_TO_GENERAL_PROMPT[style]], example, seed)
|
1128 |
+
elif isinstance(style, list):
|
1129 |
+
# This ia bit of hack to allow apply prompts to joint caption/transcript data
|
1130 |
+
# FIXME ideally we can apply the templating to multiple styles more generally
|
1131 |
+
def _apply(_style, ix):
|
1132 |
+
tmp = dict(example)
|
1133 |
+
# prevent apply_keyword_prompt for generating multiple templates
|
1134 |
+
tmp["text"] = tmp["text"][0]
|
1135 |
+
if _style == "long_caption":
|
1136 |
+
return apply_keyword_prompt(GENERAL_PROMPTS_V1["long_caption"], tmp, seed)
|
1137 |
+
elif _style == "transcript":
|
1138 |
+
return apply_keyword_prompt(GENERAL_PROMPTS_V1["transcript"], tmp, seed)
|
1139 |
+
else:
|
1140 |
+
raise NotImplementedError(_style)
|
1141 |
+
prompt = [_apply(x, ix) for ix, x in enumerate(style)]
|
1142 |
+
else:
|
1143 |
+
raise NotImplementedError()
|
1144 |
+
|
1145 |
+
elif self.prompt_templates == "zero_shot_v1":
|
1146 |
+
assert style is not None
|
1147 |
+
if not isinstance(style, str):
|
1148 |
+
# FIXME can we handle tensor style's in a better way?
|
1149 |
+
if style == "ai2_diagram":
|
1150 |
+
prompt = "Question: {question}\nAnswer with correct answer option letter only\nOptions: {options}\nAnswer:"
|
1151 |
+
prompt = apply_keyword_prompt([prompt], example, seed)
|
1152 |
+
elif style == "ai2_diagram_no_letter":
|
1153 |
+
prompt = "Question: {question}\nAnswer with correct answer option only\nOptions: {options}\nAnswer:"
|
1154 |
+
prompt = apply_keyword_prompt([prompt], example, seed)
|
1155 |
+
else:
|
1156 |
+
prompt = ""
|
1157 |
+
tf.debugging.assert_equal(prompt != "", True)
|
1158 |
+
else:
|
1159 |
+
general_style = STYLE_TO_GENERAL_PROMPT[style]
|
1160 |
+
if general_style == "short_answer":
|
1161 |
+
prompt = apply_keyword_prompt(["Question: {question} Answer with as few words as possible. Answer:"], example, seed)
|
1162 |
+
elif general_style == "multiple_choice":
|
1163 |
+
prompt = apply_keyword_prompt(["Question: {question}\nAnswer with correct answer option letter only\nOptions: {options}\nAnswer:"], example, seed)
|
1164 |
+
elif general_style == "count_bench":
|
1165 |
+
prompt = apply_keyword_prompt(["Question: How many {object} are there?\nRespond with only a number.\nAnswer:"], example, seed)
|
1166 |
+
else:
|
1167 |
+
raise NotImplementedError(general_style)
|
1168 |
+
|
1169 |
+
elif self.prompt_templates == "zero_shot_v2":
|
1170 |
+
assert style is not None
|
1171 |
+
|
1172 |
+
if self.prompt_override:
|
1173 |
+
prompt = apply_keyword_prompt([self.prompt_override], example, seed)
|
1174 |
+
elif not isinstance(style, str):
|
1175 |
+
if style == "ai2_diagram":
|
1176 |
+
prompt = "{question} Answer with correct answer option letter only. Options: {options}"
|
1177 |
+
prompt = apply_keyword_prompt([prompt], example, seed)
|
1178 |
+
elif style == "ai2_diagram_no_letter":
|
1179 |
+
prompt = "{question} Answer with correct answer option only. Options: {options}"
|
1180 |
+
prompt = apply_keyword_prompt([prompt], example, seed)
|
1181 |
+
else:
|
1182 |
+
prompt = ""
|
1183 |
+
tf.debugging.assert_equal(prompt != "", True)
|
1184 |
+
else:
|
1185 |
+
if style in ["vqa2", "gqa", "tally_qa", "okvqa", "a_okvqa_da"]:
|
1186 |
+
prompt = "Answer with a single word. {question}"
|
1187 |
+
elif style in ["text_vqa", "doc_qa", "info_qa", "chart_qa", "st_qa", "ocr_vqa", "dv_qa", "tabwmp_da", "figure_qa", "figure_qa_zero_shot", "plot_qa"]:
|
1188 |
+
prompt = "{question}\nRespond as concisely as possible, do not output anything other than the answer."
|
1189 |
+
elif STYLE_TO_GENERAL_PROMPT[style] == "multiple_choice":
|
1190 |
+
prompt = "{question} Answer with correct answer option letter only. Options: {options}"
|
1191 |
+
elif STYLE_TO_GENERAL_PROMPT[style] == "short_answer":
|
1192 |
+
prompt = "{question} Answer with as few words as possible."
|
1193 |
+
elif style == "vtabfact":
|
1194 |
+
prompt = "{question}"
|
1195 |
+
elif style == "count_bench":
|
1196 |
+
prompt = "How many {object} are there?\nRespond with only a number."
|
1197 |
+
else:
|
1198 |
+
raise NotImplementedError(style)
|
1199 |
+
prompt = apply_keyword_prompt([prompt], example, seed)
|
1200 |
+
else:
|
1201 |
+
raise NotImplementedError(self.prompt_templates)
|
1202 |
+
|
1203 |
+
if for_inference:
|
1204 |
+
return [prompt]
|
1205 |
+
else:
|
1206 |
+
return [prompt, example["text"]]
|
1207 |
+
|
1208 |
+
def get_system_prompt(self, style, example, for_inference,
|
1209 |
+
messages, seed=None):
|
1210 |
+
if isinstance(style, str) and style == "count_bench":
|
1211 |
+
style = "ok_vqa"
|
1212 |
+
|
1213 |
+
if self.system_prompt == "style":
|
1214 |
+
if isinstance(style, str):
|
1215 |
+
prefix = style + ":"
|
1216 |
+
else:
|
1217 |
+
prefix = tf.strings.join([style, ":"])
|
1218 |
+
|
1219 |
+
elif self.system_prompt == "demo_or_style":
|
1220 |
+
if isinstance(style, str):
|
1221 |
+
if style == "android_control" or style == "demo":
|
1222 |
+
# android is a special case since I hacked in prefix in the preprocessor
|
1223 |
+
prefix = ""
|
1224 |
+
elif style in ["scifi_demo", "synthetic_qa"] or style in DEMO_STYLES:
|
1225 |
+
if style == "scifi_demo":
|
1226 |
+
p_no_prompt = 0.2
|
1227 |
+
elif style == "synthetic_qa":
|
1228 |
+
p_no_prompt = 0.25
|
1229 |
+
else:
|
1230 |
+
p_no_prompt = 0.9
|
1231 |
+
if len(tf.shape(messages)) > 1:
|
1232 |
+
n_messages = tf.shape(messages)[1]
|
1233 |
+
style = tf.tile(tf.expand_dims(style, axis=0), [n_messages])
|
1234 |
+
r = tf.random.stateless_uniform([n_messages], seed, 0, 1)
|
1235 |
+
else:
|
1236 |
+
r = tf.random.stateless_uniform((), seed, 0, 1)
|
1237 |
+
prefix = tf.where(r < p_no_prompt, "", tf.strings.join([style + ":"]))
|
1238 |
+
else:
|
1239 |
+
prefix = style + ":"
|
1240 |
+
else:
|
1241 |
+
if tf.reduce_any(style == tf.constant(DEMO_STYLES + ["scifi_demo", "android_control", "demo"])):
|
1242 |
+
prefix = ""
|
1243 |
+
else:
|
1244 |
+
prefix = tf.strings.join([style, ":"])
|
1245 |
+
|
1246 |
+
elif self.system_prompt in ["long_caption_length_hint", "style_long_caption_length_hint"]:
|
1247 |
+
if seed is not None:
|
1248 |
+
raise NotImplementedError("Determinism")
|
1249 |
+
std = 25
|
1250 |
+
use_hint = tf.logical_or(
|
1251 |
+
tf.equal(style, "long_caption"), tf.equal(style, "transcript"))
|
1252 |
+
if self.system_prompt == "style_long_caption_length_hint":
|
1253 |
+
default = tf.strings.join([style, ": "])
|
1254 |
+
else:
|
1255 |
+
default = ""
|
1256 |
+
if for_inference:
|
1257 |
+
assert len(tf.shape(use_hint)) == 0
|
1258 |
+
if self.default_inference_len and use_hint:
|
1259 |
+
prefix = tf.strings.join([style, " ", str(self.default_inference_len), ": "])
|
1260 |
+
else:
|
1261 |
+
prefix = default
|
1262 |
+
else:
|
1263 |
+
std = 25
|
1264 |
+
n = tf.strings.length(messages[-1])
|
1265 |
+
n += tf.cast(tf.random.normal(n.shape)*std, tf.int32)
|
1266 |
+
hint = tf.strings.join([style, " ", tf.strings.as_string(n//15), ": "])
|
1267 |
+
use_hint = tf.logical_and(use_hint, tf.random.uniform(tf.shape(hint)) > 0.1)
|
1268 |
+
prefix = tf.where(use_hint, hint, default)
|
1269 |
+
|
1270 |
+
elif for_inference and self.system_prompt in ["style_and_length", "style_and_length_v2"]:
|
1271 |
+
v2 = self.system_prompt == "style_and_length_v2"
|
1272 |
+
if example.get("length_cond") is not None:
|
1273 |
+
# Examples have individual length conditioning
|
1274 |
+
n = tf.strings.as_string(example["length_cond"])
|
1275 |
+
else:
|
1276 |
+
inference_len = self.default_inference_len
|
1277 |
+
n = None if inference_len is None else str(inference_len)
|
1278 |
+
logging.warning(f"eval len: {n}")
|
1279 |
+
if n is not None and tf.strings.length(n) > 0: # allow empty string to signal unconditioned
|
1280 |
+
prefix = tf.strings.join([style, " ", n, ":"])
|
1281 |
+
else:
|
1282 |
+
prefix = tf.strings.join([style, ":" if v2 else " :"])
|
1283 |
+
elif self.system_prompt in ["style_and_length", "style_and_length_v2"]:
|
1284 |
+
v2 = self.system_prompt == "style_and_length_v2"
|
1285 |
+
std = 25
|
1286 |
+
logging.info(f"style prompt std={std}, percent=10")
|
1287 |
+
if seed is not None:
|
1288 |
+
seeds = tf.random.split(seed)
|
1289 |
+
p = tf.random.stateless_uniform((), seed=seeds[0])
|
1290 |
+
else:
|
1291 |
+
p = tf.random.uniform(())
|
1292 |
+
if p > 0.10:
|
1293 |
+
n = tf.strings.length(messages[-1])
|
1294 |
+
if seed is not None:
|
1295 |
+
n += tf.cast(tf.random.stateless_normal(n.shape, seed=seeds[1])*std, tf.int32)
|
1296 |
+
else:
|
1297 |
+
n += tf.cast(tf.random.normal(n.shape)*std, tf.int32)
|
1298 |
+
n = tf.strings.as_string(n//15)
|
1299 |
+
prefix = tf.strings.join([style, " ", n, ":"])
|
1300 |
+
else:
|
1301 |
+
prefix = tf.strings.join([style, ":" if v2 else " :"])
|
1302 |
+
else:
|
1303 |
+
raise NotImplementedError(self.system_prompt)
|
1304 |
+
|
1305 |
+
return prefix
|
1306 |
+
|
1307 |
+
def preprend_system_prompt(self, style, example, for_inference, messages, seed=None):
|
1308 |
+
prefix = self.get_system_prompt(style, example, for_inference, messages, seed=seed)
|
1309 |
+
separator = tf.where(tf.logical_and(
|
1310 |
+
tf.strings.length(prefix) > 0, tf.strings.length(messages[0]) > 0), " ", "")
|
1311 |
+
with_system_prompt = tf.strings.join([prefix, separator, messages[0]])
|
1312 |
+
if isinstance(messages, list):
|
1313 |
+
messages = [with_system_prompt] + messages[1:]
|
1314 |
+
else:
|
1315 |
+
messages = tf.concat([tf.expand_dims(with_system_prompt, 0), messages[1:]], axis=0)
|
1316 |
+
return messages
|
1317 |
+
|
1318 |
+
def get_messages(self, ex, style, is_training, for_inference, user_prompt_seed, system_prompt_seed):
|
1319 |
+
if isinstance(ex, list):
|
1320 |
+
messages = ex
|
1321 |
+
elif isinstance(ex, str):
|
1322 |
+
messages = [ex]
|
1323 |
+
elif "messages" in ex:
|
1324 |
+
messages = ex["messages"]
|
1325 |
+
else:
|
1326 |
+
# Apply a prompt template
|
1327 |
+
messages = self.get_user_prompt(style, ex, is_training, for_inference=for_inference, seed=user_prompt_seed)
|
1328 |
+
|
1329 |
+
# Maybe add a system prompt. The system prompt gets concatenated with the first user input
|
1330 |
+
if self.system_prompt and self.system_prompt != "none":
|
1331 |
+
if isinstance(ex, dict):
|
1332 |
+
style = ex.get("style", style)
|
1333 |
+
|
1334 |
+
if isinstance(messages, tf.RaggedTensor):
|
1335 |
+
n = tf.shape(messages)[0]
|
1336 |
+
message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
|
1337 |
+
seeds = tf.random.split(system_prompt_seed, n)
|
1338 |
+
for i in range(n):
|
1339 |
+
message_arr = message_arr.write(i, self.preprend_system_prompt(style, None, for_inference, messages[i], seed=seeds[i]))
|
1340 |
+
messages = tf.RaggedTensor.from_row_splits(
|
1341 |
+
values=message_arr.concat(), row_splits=messages.row_splits)
|
1342 |
+
else:
|
1343 |
+
messages = self.preprend_system_prompt(style, ex, for_inference, messages, seed=system_prompt_seed)
|
1344 |
+
|
1345 |
+
return messages
|
1346 |
+
|
1347 |
+
def get_preprocessor(self, is_training, for_inference, style=None, include_metadata=None):
|
1348 |
+
"""Build a preprocessing function that can be applied ot a tf.data.Dataset"""
|
1349 |
+
vocab = self.tokenizer
|
1350 |
+
include_response = not for_inference
|
1351 |
+
if include_metadata is None:
|
1352 |
+
include_metadata = for_inference
|
1353 |
+
|
1354 |
+
@seqio.map_over_dataset(num_seeds=2)
|
1355 |
+
def to_inputs_and_targets(ex, seeds):
|
1356 |
+
if "unconditioned" in ex:
|
1357 |
+
raise NotImplementedError()
|
1358 |
+
if "image" not in ex:
|
1359 |
+
image = None
|
1360 |
+
elif ex['image'].dtype == tf.string:
|
1361 |
+
image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1362 |
+
else:
|
1363 |
+
image = ex['image']
|
1364 |
+
raw_image = image
|
1365 |
+
if image is not None and len(tf.shape(image)) == 3:
|
1366 |
+
image = tf.expand_dims(image, axis=0)
|
1367 |
+
|
1368 |
+
unconditioned = self.unconditioned
|
1369 |
+
if unconditioned and isinstance(unconditioned, float):
|
1370 |
+
assert image is not None
|
1371 |
+
if is_training and tf.random.uniform((), 0, 1, dtype=tf.float32) < unconditioned:
|
1372 |
+
image = image[:0]
|
1373 |
+
elif unconditioned:
|
1374 |
+
image = None
|
1375 |
+
|
1376 |
+
messages = self.get_messages(ex, style, is_training, for_inference, seeds[0], seeds[1])
|
1377 |
+
targets, loss_masks, subsegments = self.get_tokens_input(
|
1378 |
+
messages, for_inference, ex.get("text_weights"))
|
1379 |
+
# if "scifi" in style and style.endswith("_explanation"):
|
1380 |
+
# logging.warning(f"No loss on EOS for {style}")
|
1381 |
+
# loss_masks = tf.where(targets == self.tokenizer.eos_token_id, tf.zeros_like(loss_masks), loss_masks)
|
1382 |
+
out = self.build_multimodel_features(targets, loss_masks, subsegments, image, is_training)
|
1383 |
+
|
1384 |
+
if include_metadata:
|
1385 |
+
# FIXME remove these special cases
|
1386 |
+
if "text" in ex:
|
1387 |
+
if len(ex["text"].shape) > 0:
|
1388 |
+
# FIXME can this be variable lengths after all?
|
1389 |
+
out["metadata/captions"] = tf.strings.reduce_join(
|
1390 |
+
tf.strings.regex_replace(ex['text'], "\\s+", " "),
|
1391 |
+
separator="\n"
|
1392 |
+
)
|
1393 |
+
else:
|
1394 |
+
out["metadata/captions"] = ex["text"]
|
1395 |
+
|
1396 |
+
if "image_url" in ex:
|
1397 |
+
out["metadata/image_url"] = ex["image_url"]
|
1398 |
+
elif "url" in ex:
|
1399 |
+
out["metadata/image_url"] = ex["url"]
|
1400 |
+
if "image_id" in ex:
|
1401 |
+
out["metadata/image_id"] = ex["image_id"]
|
1402 |
+
for k, v in ex.items():
|
1403 |
+
if k.startswith("metadata"):
|
1404 |
+
out[k] = v
|
1405 |
+
if raw_image is not None and "metadata/image_size" not in out:
|
1406 |
+
img_h = tf.shape(raw_image)[0]
|
1407 |
+
img_w = tf.shape(raw_image)[1]
|
1408 |
+
out["metadata/image_size"] = [img_w, img_h]
|
1409 |
+
if "metadata/image_url" not in out and raw_image is not None:
|
1410 |
+
if len(ex["image"].shape) < 4:
|
1411 |
+
# For visualizations FIXME can we make this variable length
|
1412 |
+
out["metadata/image"] = tf.io.encode_jpeg(
|
1413 |
+
tf.image.convert_image_dtype(raw_image, tf.uint8))
|
1414 |
+
return out
|
1415 |
+
return to_inputs_and_targets
|
1416 |
+
|
1417 |
+
def get_post_mixing_preprocessor(self, pack=False):
|
1418 |
+
"""Build a feature conversion function that can be applied ot a tf.data.Dataset
|
1419 |
+
|
1420 |
+
This function applies a second stage of pre-processing, but unlike `self.get_preprocessor`
|
1421 |
+
this stage can be applied after mixing tf.data.Datasets into a mixture
|
1422 |
+
"""
|
1423 |
+
return MultiModalLMFeatureConverter(
|
1424 |
+
loss_token_weighting=self.loss_token_weighting,
|
1425 |
+
bos_id=self.tokenizer.bos_token_id,
|
1426 |
+
fix_image_input_idx=self.fix_image_input_idx,
|
1427 |
+
pack=pack,
|
1428 |
+
special_tokens=list(self.special_token_ids.values()),
|
1429 |
+
)
|
1430 |
+
|
1431 |
+
|
1432 |
+
class MultiModalLMFeatureConverter:
|
1433 |
+
|
1434 |
+
def __init__(
|
1435 |
+
self, pack: bool = False, loss_token_weighting: str=None, bos_id: int = 1,
|
1436 |
+
special_tokens=None, fix_image_input_idx=2
|
1437 |
+
):
|
1438 |
+
self.pack = pack
|
1439 |
+
self.bos_id = bos_id
|
1440 |
+
self.fix_image_input_idx = fix_image_input_idx
|
1441 |
+
self.special_tokens = tf.constant(special_tokens) if special_tokens else None
|
1442 |
+
self.loss_token_weighting = loss_token_weighting
|
1443 |
+
|
1444 |
+
def _convert_example(
|
1445 |
+
self, features: Mapping[str, tf.Tensor]
|
1446 |
+
) -> Mapping[str, tf.Tensor]:
|
1447 |
+
"""Convert an LM example into an example with model features."""
|
1448 |
+
# targets_segment_id is present only for a packed dataset.
|
1449 |
+
decoder_input_tokens = make_autoregressive_inputs(
|
1450 |
+
features["target_tokens"],
|
1451 |
+
sequence_id=features.get("targets_segment_ids", None),
|
1452 |
+
bos_id=self.bos_id,
|
1453 |
+
)
|
1454 |
+
|
1455 |
+
tf.assert_equal(
|
1456 |
+
True,
|
1457 |
+
tf.reduce_all(decoder_input_tokens[-1] != self.special_tokens),
|
1458 |
+
message="An input ends with an image special token",
|
1459 |
+
)
|
1460 |
+
|
1461 |
+
image_input_idx = features["image_input_idx"]
|
1462 |
+
if self.fix_image_input_idx == 2:
|
1463 |
+
# plus one sine we have added BOS to the inputs
|
1464 |
+
image_input_idx = tf.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
|
1465 |
+
else:
|
1466 |
+
# Some old models trained like this, sometimes image_input_idx will go from -1 -> 0 didn't
|
1467 |
+
# effect performance but keep this code path for backwards compatiblity with those checkpoints
|
1468 |
+
image_input_idx = image_input_idx + 1
|
1469 |
+
|
1470 |
+
d = {
|
1471 |
+
"target_tokens": features["target_tokens"],
|
1472 |
+
"input_tokens": decoder_input_tokens,
|
1473 |
+
"loss_masks": features["loss_masks"],
|
1474 |
+
"images": features["images"],
|
1475 |
+
"image_input_idx": image_input_idx
|
1476 |
+
}
|
1477 |
+
if "image_masks" in features:
|
1478 |
+
d["image_masks"] = features["image_masks"]
|
1479 |
+
|
1480 |
+
has_custom_text_weight = features.get("has_custom_loss_weight", False)
|
1481 |
+
|
1482 |
+
if "subsegment_ids" in features:
|
1483 |
+
subsegment_ids = make_autoregressive_inputs(
|
1484 |
+
features["subsegment_ids"],
|
1485 |
+
sequence_id=features.get("targets_segment_ids", None),
|
1486 |
+
bos_id=features["subsegment_ids"][0],
|
1487 |
+
)
|
1488 |
+
|
1489 |
+
# Subsegment have a position based on the sum of previous positions they can attend to
|
1490 |
+
position_ids = tf.zeros_like(subsegment_ids)
|
1491 |
+
unique_segments = tf.unique(subsegment_ids)[0]
|
1492 |
+
for i in unique_segments:
|
1493 |
+
segment_position_ids = tf.cumsum(tf.cast(subsegment_ids >= i, tf.int32)) - 1
|
1494 |
+
position_ids = tf.where(subsegment_ids == i, segment_position_ids, position_ids)
|
1495 |
+
|
1496 |
+
# Apply loss weighting, this is done here so it occurs after truncation
|
1497 |
+
if has_custom_text_weight:
|
1498 |
+
pass
|
1499 |
+
elif self.loss_token_weighting in ["subsegments", "root_subsegments"]:
|
1500 |
+
n_loss_segments = tf.shape(tf.unique(tf.boolean_mask(subsegment_ids, d["loss_masks"] > 0))[0])[0]
|
1501 |
+
n_loss_segments = tf.maximum(tf.cast(n_loss_segments, tf.float32), 1)
|
1502 |
+
weight = 1/n_loss_segments if self.loss_token_weighting == "subsegments" else tf.math.rsqrt(n_loss_segments)
|
1503 |
+
d["loss_masks"] = tf.where(d["loss_masks"] > 0, d["loss_masks"]*weight, d["loss_masks"])
|
1504 |
+
elif self.loss_token_weighting is not None:
|
1505 |
+
raise NotImplementedError(self.loss_token_weighting)
|
1506 |
+
|
1507 |
+
d["subsegment_ids"] = subsegment_ids
|
1508 |
+
d["position_ids"] = position_ids
|
1509 |
+
else:
|
1510 |
+
if self.loss_token_weighting not in [None, "subsegments", "root_subsegments"] and not has_custom_text_weight:
|
1511 |
+
raise NotImplementedError(self.loss_token_weighting)
|
1512 |
+
if self.pack:
|
1513 |
+
d["decoder_segment_ids"] = features["targets_segment_ids"]
|
1514 |
+
d["decoder_positions"] = features["targets_positions"]
|
1515 |
+
|
1516 |
+
for k in features:
|
1517 |
+
if k.startswith("metadata/"):
|
1518 |
+
d[k] = features[k]
|
1519 |
+
return d
|
1520 |
+
|
1521 |
+
def _pack_or_pad(self, ds, task_feature_lengths):
|
1522 |
+
if self.pack:
|
1523 |
+
raise NotImplementedError()
|
1524 |
+
else:
|
1525 |
+
return trim_and_pad_dataset(ds, task_feature_lengths)
|
1526 |
+
|
1527 |
+
def __call__(self, ds: tf.data.Dataset, task_feature_lengths: Mapping[str, int]) -> tf.data.Dataset:
|
1528 |
+
"""Convert the dataset to be fed to a language model."""
|
1529 |
+
task_feature_lengths = dict(task_feature_lengths)
|
1530 |
+
|
1531 |
+
if "images" in ds.element_spec and "images" in task_feature_lengths:
|
1532 |
+
# Images should never be truncated
|
1533 |
+
ds = assert_not_truncated(ds, ["images", "image_input_idx"], task_feature_lengths["images"])
|
1534 |
+
|
1535 |
+
if any(x.startswith("metadata/") for x in ds.element_spec):
|
1536 |
+
# Metadata indicates the dataset is being used for inference, inference datasets
|
1537 |
+
# should not be truncated
|
1538 |
+
ds = assert_not_truncated(ds, ["target_tokens"], task_feature_lengths["target_tokens"])
|
1539 |
+
|
1540 |
+
if "image_masks" in ds.element_spec and "images" in task_feature_lengths:
|
1541 |
+
task_feature_lengths["image_masks"] = task_feature_lengths["images"]
|
1542 |
+
if "subsegment_ids" in ds.element_spec and "target_tokens" in task_feature_lengths:
|
1543 |
+
task_feature_lengths["subsegment_ids"] = task_feature_lengths["target_tokens"]
|
1544 |
+
if "loss_masks" not in task_feature_lengths and "target_tokens" in task_feature_lengths:
|
1545 |
+
task_feature_lengths["loss_masks"] = task_feature_lengths["target_tokens"]
|
1546 |
+
ds = self._pack_or_pad(ds, task_feature_lengths)
|
1547 |
+
|
1548 |
+
return ds.map(
|
1549 |
+
self._convert_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
preprocesssors.py
ADDED
@@ -0,0 +1,2472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
from functools import reduce
|
5 |
+
from typing import Mapping, Optional, Sequence
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import tensorflow as tf
|
9 |
+
import seqio
|
10 |
+
import gin
|
11 |
+
|
12 |
+
from .data_utils import flatten_parts, stateless_permutation, stateless_shuffle
|
13 |
+
from .. import config
|
14 |
+
|
15 |
+
|
16 |
+
def get_from_dict(data, keys):
|
17 |
+
"""Iterate nested dictionary"""
|
18 |
+
return reduce(dict.get, keys, data)
|
19 |
+
|
20 |
+
def get_blank_image():
|
21 |
+
image = tf.zeros([224, 224, 3], dtype=tf.uint8)
|
22 |
+
image = tf.expand_dims(image, 0)[:1]
|
23 |
+
return image
|
24 |
+
|
25 |
+
|
26 |
+
@seqio.utils.map_over_dataset
|
27 |
+
def rekey(x, key_map=None):
|
28 |
+
"""Replace the feature keys according to the mapping in `key_map`.
|
29 |
+
For example, if the dataset returns examples of the format:
|
30 |
+
{'foo': 'something', 'bar': 'something else'}
|
31 |
+
and key_map = {'boo': 'foo', 'spar': 'bar'} then this function will return
|
32 |
+
examples with the format
|
33 |
+
{'boo': 'something', 'spar': 'something else'}
|
34 |
+
If a mapping is to an empty key or None, set the new key to an empty string.
|
35 |
+
Args:
|
36 |
+
x: an example to process.
|
37 |
+
key_map: dictionary mapping new keys to original keys
|
38 |
+
Returns:
|
39 |
+
A preprocessed example with the format listed above.
|
40 |
+
"""
|
41 |
+
if key_map:
|
42 |
+
out = {}
|
43 |
+
for new_key, old_key in key_map.items():
|
44 |
+
if isinstance(old_key, list):
|
45 |
+
out[new_key] = get_from_dict(x, old_key)
|
46 |
+
else:
|
47 |
+
out[new_key] = x[old_key]
|
48 |
+
return out
|
49 |
+
return x
|
50 |
+
|
51 |
+
|
52 |
+
def rename(**kwargs):
|
53 |
+
@seqio.map_over_dataset
|
54 |
+
def _fn(x):
|
55 |
+
updates = {}
|
56 |
+
for new_key, old_key in kwargs.items():
|
57 |
+
if isinstance(old_key, list):
|
58 |
+
val = x[old_key[0]]
|
59 |
+
for k in old_key[1:-1]:
|
60 |
+
val = val[k]
|
61 |
+
updates[new_key] = val.pop(old_key[-1])
|
62 |
+
else:
|
63 |
+
updates[new_key] = x.pop(old_key)
|
64 |
+
x.update(updates)
|
65 |
+
return x
|
66 |
+
return _fn
|
67 |
+
|
68 |
+
|
69 |
+
def extract_transcripts(ds):
|
70 |
+
ds = flatten_parts(ds, ["transcripts"])
|
71 |
+
def _map(ex):
|
72 |
+
return dict(
|
73 |
+
image=ex["image"],
|
74 |
+
text=ex["transcripts"],
|
75 |
+
url=ex["url"]
|
76 |
+
)
|
77 |
+
return ds.map(_map)
|
78 |
+
|
79 |
+
|
80 |
+
@seqio.map_over_dataset
|
81 |
+
def extract_caption_and_all_transcripts(ex):
|
82 |
+
transcripts = tf.random.shuffle(ex["transcripts"])[:3]
|
83 |
+
weight = 1.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
|
84 |
+
return dict(
|
85 |
+
image=ex["image"],
|
86 |
+
text=tf.concat([tf.expand_dims(ex["caption"], 0), transcripts], 0),
|
87 |
+
url=ex["url"],
|
88 |
+
text_weights=tf.pad(
|
89 |
+
tf.ones((1,), dtype=tf.float32), [[0, tf.shape(transcripts)[0]]],
|
90 |
+
constant_values=weight),
|
91 |
+
)
|
92 |
+
|
93 |
+
|
94 |
+
@seqio.map_over_dataset
|
95 |
+
def extract_all_transcripts(ex):
|
96 |
+
transcripts = tf.random.shuffle(ex["transcripts"])[:3]
|
97 |
+
weight = 3.0 / tf.cast(tf.shape(transcripts)[0], tf.float32)
|
98 |
+
return dict(
|
99 |
+
image=ex["image"],
|
100 |
+
text=transcripts,
|
101 |
+
url=ex["url"],
|
102 |
+
text_weights=tf.fill((tf.shape(transcripts)[0],), weight),
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
@seqio.map_over_dataset
|
107 |
+
def extract_transcript(ex):
|
108 |
+
transcripts = tf.random.shuffle(ex["transcripts"])
|
109 |
+
return dict(
|
110 |
+
image=ex["image"],
|
111 |
+
text=transcripts[0],
|
112 |
+
url=ex["url"],
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
@seqio.map_over_dataset
|
117 |
+
def extract_caption(ex):
|
118 |
+
caption = ex["caption"]
|
119 |
+
if len(caption.shape) > 0:
|
120 |
+
ex["text"] = caption[0]
|
121 |
+
else:
|
122 |
+
ex["text"] = caption
|
123 |
+
return ex
|
124 |
+
|
125 |
+
|
126 |
+
@seqio.map_over_dataset
|
127 |
+
def extract_joint_captions(ex):
|
128 |
+
caption = ex["caption"]
|
129 |
+
if len(caption.shape) > 0:
|
130 |
+
caption = caption[0]
|
131 |
+
_ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
|
132 |
+
_ix = _ix % tf.shape(ex["transcripts"])[0]
|
133 |
+
return dict(
|
134 |
+
image=ex["image"],
|
135 |
+
text=tf.stack([caption, ex["mistral_caption"], ex["transcripts"][_ix]], 0),
|
136 |
+
url=ex["url"]
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
@seqio.map_over_dataset(num_seeds=1)
|
141 |
+
def extract_caption_and_transcript(ex, seed):
|
142 |
+
caption = ex["caption"]
|
143 |
+
if len(caption.shape) > 0:
|
144 |
+
caption = caption[0]
|
145 |
+
_ix = tf.random.stateless_uniform((), seed, 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
|
146 |
+
return dict(
|
147 |
+
image=ex["image"],
|
148 |
+
text=tf.stack([caption, ex["transcripts"][_ix]], 0),
|
149 |
+
url=ex["url"]
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
@seqio.map_over_dataset
|
154 |
+
def caption_transcript_augmented(ex, sequence_length):
|
155 |
+
caption = ex["caption"]
|
156 |
+
if len(caption.shape) > 0:
|
157 |
+
caption = caption[0]
|
158 |
+
image = ex["image"]
|
159 |
+
properties = []
|
160 |
+
|
161 |
+
do_augmentation = sequence_length["is_training"]
|
162 |
+
# do_augmentation = False
|
163 |
+
|
164 |
+
# Keep this off, it screws up OCR
|
165 |
+
# do_hflip = (tf.random.uniform(()) > 0.2 and do_augmentation)
|
166 |
+
do_hflip = False
|
167 |
+
if do_hflip:
|
168 |
+
image = image[:, ::-1]
|
169 |
+
|
170 |
+
# Mild color jitter
|
171 |
+
do_color = (tf.random.uniform(()) > 0.5 and do_augmentation)
|
172 |
+
if do_color:
|
173 |
+
image = tf.image.random_hue(image, max_delta=0.05)
|
174 |
+
image = tf.image.random_brightness(image, max_delta=0.2)
|
175 |
+
image = tf.image.random_saturation(image, 0.7, 1.3)
|
176 |
+
image = tf.image.random_contrast(image, 0.7, 1.3)
|
177 |
+
|
178 |
+
# Mild affine transformation
|
179 |
+
do_affine = (tf.random.uniform(()) > 0.5 and do_augmentation)
|
180 |
+
if do_affine and do_augmentation:
|
181 |
+
shift_x = tf.random.uniform((), -10, 10) * 0
|
182 |
+
shift_y = tf.random.uniform((), -10, 10) * 0
|
183 |
+
shear_x = tf.random.uniform((), -2, 2)
|
184 |
+
shear_y = tf.random.uniform((), -2, 2)
|
185 |
+
rotation = tf.random.uniform((), -6, 6)
|
186 |
+
max_scale = 1.1
|
187 |
+
scale = tf.random.uniform((), 0.8, max_scale)
|
188 |
+
center = tf.cast(tf.shape(image), tf.float32)/2
|
189 |
+
|
190 |
+
image = tf.keras.ops.image.affine_transform(
|
191 |
+
image,
|
192 |
+
tf.stack(get_affine_matrix(
|
193 |
+
[center[0], center[1]],
|
194 |
+
rotation,
|
195 |
+
[shift_x, shift_y],
|
196 |
+
1/scale,
|
197 |
+
[shear_x, shear_y]
|
198 |
+
) + [0., 0.]),
|
199 |
+
interpolation='bilinear',
|
200 |
+
fill_mode='constant',
|
201 |
+
fill_value=1.,
|
202 |
+
data_format='channels_last'
|
203 |
+
)
|
204 |
+
|
205 |
+
properties = tf.stack([
|
206 |
+
("[hflip]" if do_hflip else ""),
|
207 |
+
("[color]" if do_color else ""),
|
208 |
+
("[affine]" if do_affine else "")
|
209 |
+
])
|
210 |
+
properties = tf.boolean_mask(properties, tf.strings.length(properties) > 0)
|
211 |
+
prompt = tf.strings.reduce_join(properties, separator=" ")
|
212 |
+
ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
|
213 |
+
out = dict(
|
214 |
+
image=image,
|
215 |
+
text=tf.stack([caption, ex["transcripts"][ix]], 0),
|
216 |
+
url=ex["url"],
|
217 |
+
prompt=prompt,
|
218 |
+
)
|
219 |
+
# out["metadata/unaugmented_image"] = image
|
220 |
+
return out
|
221 |
+
|
222 |
+
|
223 |
+
def extract_caption_and_transcript_hflip(ds):
|
224 |
+
|
225 |
+
# Just in case they are ordered somehow in Matt's data
|
226 |
+
@seqio.map_over_dataset
|
227 |
+
def _shuffle_transcripts(_ex):
|
228 |
+
_ex["transcripts"] = tf.random.shuffle(_ex["transcripts"])
|
229 |
+
_ex["hflip"] = tf.random.uniform((), 0, 3, dtype=tf.int32)
|
230 |
+
return _ex
|
231 |
+
|
232 |
+
ds = _shuffle_transcripts(ds)
|
233 |
+
|
234 |
+
# Build a 3x long dataset with each individual transcript so we iterate through
|
235 |
+
# each transcript
|
236 |
+
@seqio.map_over_dataset
|
237 |
+
def _with_transcript(ex, _ix):
|
238 |
+
caption = ex["caption"]
|
239 |
+
if len(caption.shape) > 0:
|
240 |
+
caption = caption[0]
|
241 |
+
hflip = ex["hflip"] == _ix
|
242 |
+
if hflip:
|
243 |
+
ex["image"] = ex["image"][:, ::-1]
|
244 |
+
style = ["long_caption_flipped", "transcript_flipped"]
|
245 |
+
else:
|
246 |
+
style = ["long_caption", "transcript"]
|
247 |
+
return dict(
|
248 |
+
image=ex["image"],
|
249 |
+
text=tf.stack([caption, ex["transcripts"][_ix]], 0),
|
250 |
+
url=ex["url"],
|
251 |
+
style=style
|
252 |
+
)
|
253 |
+
|
254 |
+
joint_ds = _with_transcript(ds, 0)
|
255 |
+
for i in range(1, 3):
|
256 |
+
joint_ds = joint_ds.concatenate(_with_transcript(ds, i))
|
257 |
+
|
258 |
+
return joint_ds
|
259 |
+
|
260 |
+
|
261 |
+
@seqio.map_over_dataset
|
262 |
+
def extract_llava(ex, sequence_length, output_features):
|
263 |
+
tf.assert_equal(tf.shape(ex['conversations']['value'])[0], 2)
|
264 |
+
prompt = ex['conversations']['value'][0]
|
265 |
+
text = ex['conversations']['value'][1]
|
266 |
+
ex.pop('conversations')
|
267 |
+
ex["text"] = text
|
268 |
+
ex["prompt"] = prompt
|
269 |
+
return ex
|
270 |
+
|
271 |
+
|
272 |
+
def extract_localized_narrative(ds):
|
273 |
+
ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)
|
274 |
+
def _map(ex):
|
275 |
+
return dict(
|
276 |
+
image=ex["image"],
|
277 |
+
text=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
|
278 |
+
)
|
279 |
+
return ds.map(_map)
|
280 |
+
|
281 |
+
|
282 |
+
def float_to_text(val):
|
283 |
+
return tf.strings.as_string(tf.cast(val * 100, tf.int32))
|
284 |
+
|
285 |
+
|
286 |
+
@seqio.map_over_dataset
|
287 |
+
def extract_vqa(ex):
|
288 |
+
questions = ex["vqa"]["questions"]
|
289 |
+
answers = ex["vqa"]["answers"]
|
290 |
+
answers = tf.strings.reduce_join(answers, 1, separator="; ")
|
291 |
+
qas = tf.strings.reduce_join(tf.stack([questions, answers], 1), separator=" ")
|
292 |
+
return dict(
|
293 |
+
image=ex["image"],
|
294 |
+
text=tf.strings.reduce_join(qas, separator="\n")
|
295 |
+
)
|
296 |
+
|
297 |
+
|
298 |
+
@seqio.map_over_dataset
|
299 |
+
def coco_image_id_from_path(ex):
|
300 |
+
image_id = tf.strings.substr(ex["image/filename"], 0, tf.strings.length(ex["image/filename"])-4)
|
301 |
+
ex["image_id"] = tf.strings.to_number(image_id)
|
302 |
+
return ex
|
303 |
+
|
304 |
+
|
305 |
+
@seqio.map_over_dataset
|
306 |
+
def add_coco_url(ex):
|
307 |
+
"""Turns a COCO path into a URL, which can then be used in visualizations"""
|
308 |
+
path = ex["image/filename"]
|
309 |
+
if not tf.strings.regex_full_match(path, ".*/.*"):
|
310 |
+
prefix = tf.strings.regex_replace(path, "COCO_", "")
|
311 |
+
prefix = tf.strings.regex_replace(prefix, "_[0-9]+.jpg", "")
|
312 |
+
path = tf.strings.join([prefix, path], separator="/")
|
313 |
+
|
314 |
+
# images are hosted by the COCO website here
|
315 |
+
url = tf.strings.join(["https://s3.us-east-1.amazonaws.com/images.cocodataset.org/", path])
|
316 |
+
ex["metadata/image_url"] = url
|
317 |
+
return ex
|
318 |
+
|
319 |
+
|
320 |
+
def flatten_vqa(ds):
|
321 |
+
parts = ["questions", "answers"]
|
322 |
+
for k in ["id", "question_id"]:
|
323 |
+
if k in ds.element_spec:
|
324 |
+
parts.append(k)
|
325 |
+
return flatten_parts(ds, parts)
|
326 |
+
|
327 |
+
|
328 |
+
def format_gqa(ds, is_balanced=True, flatten=True):
|
329 |
+
if is_balanced:
|
330 |
+
ds = ds.filter(lambda x: tf.reduce_any(x["questions"]["is_balanced"]))
|
331 |
+
def _filter_qs(ex):
|
332 |
+
qs = ex["questions"]
|
333 |
+
mask = qs["is_balanced"]
|
334 |
+
qs = {k: tf.boolean_mask(v, mask) for k, v in qs.items()}
|
335 |
+
ex["questions"] = qs
|
336 |
+
return ex
|
337 |
+
ds = ds.map(_filter_qs)
|
338 |
+
|
339 |
+
if flatten:
|
340 |
+
ds = flatten_parts(ds, ["questions"])
|
341 |
+
|
342 |
+
def _rename(ex):
|
343 |
+
out = ex["questions"]
|
344 |
+
out["image"] = ex["image"]
|
345 |
+
out["image_id"] = ex["image_id"]
|
346 |
+
return out
|
347 |
+
return ds.map(_rename)
|
348 |
+
|
349 |
+
|
350 |
+
@seqio.map_over_dataset
|
351 |
+
def fix_doqa_url(x):
|
352 |
+
x["image_url"] = tf.strings.regex_replace(x["image_url"], "gs://", "")
|
353 |
+
return x
|
354 |
+
|
355 |
+
|
356 |
+
def _add_metadata(ex):
|
357 |
+
out = {}
|
358 |
+
if "id" in ex:
|
359 |
+
out["metadata/example_id"] = ex["id"]
|
360 |
+
elif "example_id" in ex:
|
361 |
+
out["metadata/example_id"] = ex["example_id"]
|
362 |
+
elif "question_id" in ex:
|
363 |
+
out["metadata/example_id"] = ex["question_id"]
|
364 |
+
if "image_url" in ex:
|
365 |
+
out["metadata/image_url"] = ex["image_url"]
|
366 |
+
for k, v in ex.items():
|
367 |
+
if k.startswith("metadata/"):
|
368 |
+
out[k] = v
|
369 |
+
return out
|
370 |
+
|
371 |
+
|
372 |
+
def image_only(ds):
|
373 |
+
return ds.filter(lambda x: x["has_image"])
|
374 |
+
|
375 |
+
|
376 |
+
def filter_difficult_direct_answer(ds):
|
377 |
+
return ds.filter(lambda x: not x["difficult_direct_answer"])
|
378 |
+
|
379 |
+
|
380 |
+
@seqio.map_over_dataset()
|
381 |
+
def format_ai2d(ex, variable_style=True):
|
382 |
+
abc = tf.constant(list("abcdefg".upper()))
|
383 |
+
out = dict(image=ex["image"])
|
384 |
+
out.update(_add_metadata(ex))
|
385 |
+
|
386 |
+
options = ex["choices"]
|
387 |
+
# >= 3 in case of none of the above like answers
|
388 |
+
n_options = tf.shape(ex["option_is_abc"])[0]
|
389 |
+
if ex["abc_label"] and tf.reduce_sum(tf.cast(ex["option_is_abc"], tf.int32)) >= (n_options - 1):
|
390 |
+
# The image labels are always upper, so use upper in the answer ptions
|
391 |
+
options = tf.where(
|
392 |
+
ex["option_is_abc"],
|
393 |
+
tf.strings.upper(options),
|
394 |
+
options
|
395 |
+
)
|
396 |
+
short_options = options
|
397 |
+
style = "ai2_diagram_no_letter"
|
398 |
+
else:
|
399 |
+
short_options = abc[:tf.shape(options)[0]]
|
400 |
+
options = tf.stack([short_options, options,], 1)
|
401 |
+
options = tf.strings.reduce_join(options, axis=-1, separator=": ")
|
402 |
+
style = "ai2_diagram"
|
403 |
+
|
404 |
+
options = tf.strings.reduce_join(options, separator="\n")
|
405 |
+
out["question"] = ex["question"]
|
406 |
+
out["options"] = options
|
407 |
+
if variable_style:
|
408 |
+
out["style"] = style
|
409 |
+
if ex["answer_idx"] < 0:
|
410 |
+
out["text"] = "?"
|
411 |
+
else:
|
412 |
+
out["text"] = short_options[ex["answer_idx"]]
|
413 |
+
out["metadata/answer_idx"] = ex["answer_idx"]
|
414 |
+
tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
|
415 |
+
out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
|
416 |
+
out["metadata/has_transparent_box"] = ex.get("has_transparent_box", tf.constant(False))
|
417 |
+
out["metadata/abc_label"] = ex["abc_label"]
|
418 |
+
return out
|
419 |
+
|
420 |
+
|
421 |
+
@gin.configurable()
|
422 |
+
@seqio.map_over_dataset()
|
423 |
+
def format_multiple_choice_qa(ex, option_format="abc"):
|
424 |
+
assert option_format == "abc"
|
425 |
+
abc = tf.constant(list("abcdefg".upper()))
|
426 |
+
out = dict(image=ex["image"])
|
427 |
+
out.update(_add_metadata(ex))
|
428 |
+
options = ex["choices"]
|
429 |
+
short_options = abc[:tf.shape(options)[0]]
|
430 |
+
options = tf.stack([short_options, options,], 1)
|
431 |
+
options = tf.strings.reduce_join(options, axis=-1, separator=": ")
|
432 |
+
options = tf.strings.reduce_join(options, separator="\n")
|
433 |
+
out["question"] = ex["question"]
|
434 |
+
out["options"] = options
|
435 |
+
if ex["answer_idx"] < 0:
|
436 |
+
out["text"] = "?"
|
437 |
+
else:
|
438 |
+
out["text"] = short_options[ex["answer_idx"]]
|
439 |
+
out["metadata/answer_idx"] = ex["answer_idx"]
|
440 |
+
tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
|
441 |
+
out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
|
442 |
+
# out["metadata/option_names"] = tf.RaggedTensor.from_row_lengths(short_options, tf.shape(short_options))
|
443 |
+
# out["metadata/option_names"] = short_options
|
444 |
+
return out
|
445 |
+
|
446 |
+
|
447 |
+
@seqio.map_over_dataset()
|
448 |
+
def output_options(ex):
|
449 |
+
ex["metadata/options"] = ex["options"]
|
450 |
+
return ex
|
451 |
+
|
452 |
+
|
453 |
+
@seqio.map_over_dataset()
|
454 |
+
def extract_tally_qa(ex):
|
455 |
+
questions = ex.pop("questions")
|
456 |
+
ex["questions"] = questions["question"]
|
457 |
+
ex["answers"] = tf.strings.as_string(questions["answer"])
|
458 |
+
ex["question_id"] = questions["question_id"]
|
459 |
+
return ex
|
460 |
+
|
461 |
+
|
462 |
+
@seqio.map_over_dataset()
|
463 |
+
def count_bench_preprocessor(ex):
|
464 |
+
return {
|
465 |
+
"image": ex["image"],
|
466 |
+
"text": tf.strings.as_string(ex["number"]),
|
467 |
+
"object": ex["noun"],
|
468 |
+
"question": tf.strings.join([
|
469 |
+
"How many ", ex["noun"], " are there?"
|
470 |
+
]),
|
471 |
+
"metadata/count": ex["number"],
|
472 |
+
}
|
473 |
+
|
474 |
+
|
475 |
+
def filter_human(ds):
|
476 |
+
return ds.filter(lambda x: x["is_human"])
|
477 |
+
|
478 |
+
|
479 |
+
def filter_aug(ds):
|
480 |
+
return ds.filter(lambda x: not x["is_human"])
|
481 |
+
|
482 |
+
|
483 |
+
@seqio.map_over_dataset()
|
484 |
+
def reweight_chartqa(ex, human, aug):
|
485 |
+
is_human = ex["metadata/is_human"]
|
486 |
+
ex["text_weights"] = human if is_human else aug
|
487 |
+
return ex
|
488 |
+
|
489 |
+
|
490 |
+
@seqio.map_over_dataset()
|
491 |
+
def chartqa_prompting(ex):
|
492 |
+
question = tf.strings.join([ex["question"], " Answer:"])
|
493 |
+
return dict(
|
494 |
+
image=ex["image"],
|
495 |
+
question=question,
|
496 |
+
answer=ex["answer"]
|
497 |
+
)
|
498 |
+
|
499 |
+
|
500 |
+
@seqio.map_over_dataset()
|
501 |
+
def chartqa_explanation(ex):
|
502 |
+
question = tf.strings.join([ex["question"], " Explanation:"])
|
503 |
+
out = {
|
504 |
+
"image": ex["image"],
|
505 |
+
"question": question,
|
506 |
+
"answer": ex["answer"],
|
507 |
+
}
|
508 |
+
out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
|
509 |
+
return out
|
510 |
+
|
511 |
+
|
512 |
+
@seqio.map_over_dataset(num_seeds=1)
|
513 |
+
def _preprocess_scifi(ex, seed):
|
514 |
+
if "qa_pairs" in ex:
|
515 |
+
q = ex["qa_pairs"]
|
516 |
+
else:
|
517 |
+
q = ex["qa"]
|
518 |
+
ix = stateless_permutation(tf.shape(q["question"])[0], seed)
|
519 |
+
return dict(
|
520 |
+
image=ex["image"],
|
521 |
+
question=tf.gather(q["question"], ix),
|
522 |
+
explanation=tf.gather(q["explanation"], ix),
|
523 |
+
answer=tf.gather(q["answer"], ix),
|
524 |
+
)
|
525 |
+
|
526 |
+
@seqio.map_over_dataset
|
527 |
+
def scifi_explanation_only(ex):
|
528 |
+
return dict(
|
529 |
+
image=ex["image"],
|
530 |
+
question=ex["question"],
|
531 |
+
answer=ex["explanation"],
|
532 |
+
)
|
533 |
+
|
534 |
+
|
535 |
+
def filter_named_entity(ds):
|
536 |
+
@seqio.map_over_dataset
|
537 |
+
def _load_image(ex):
|
538 |
+
ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
539 |
+
return ex
|
540 |
+
|
541 |
+
ds = _load_image(ds)
|
542 |
+
return ds.filter(lambda x: tf.reduce_min(tf.shape(x["image"])[:2]) >= 32)
|
543 |
+
|
544 |
+
|
545 |
+
@seqio.map_over_dataset()
|
546 |
+
def extract_named_entity(ex):
|
547 |
+
qs = ex["questions"]
|
548 |
+
return {
|
549 |
+
"image": ex["image"],
|
550 |
+
"metadata/image_url": ex["url"],
|
551 |
+
"metadata/entity": ex["entity"],
|
552 |
+
"questions": qs["question"],
|
553 |
+
"answers": qs["answer"],
|
554 |
+
}
|
555 |
+
|
556 |
+
@gin.configurable()
|
557 |
+
def extract_individual_vqa(ds, test=False, answer_mode="best"):
|
558 |
+
|
559 |
+
@seqio.map_over_dataset(num_seeds=1)
|
560 |
+
def _extract(ex, seed):
|
561 |
+
if "questions" in ex:
|
562 |
+
question = ex["questions"]
|
563 |
+
else:
|
564 |
+
question = ex["question"]
|
565 |
+
out = dict(
|
566 |
+
image=ex["image"],
|
567 |
+
question=question,
|
568 |
+
)
|
569 |
+
out.update(_add_metadata(ex))
|
570 |
+
out["metadata/question"] = question
|
571 |
+
if ex.get("answers") is not None:
|
572 |
+
out["metadata/references"] = tf.strings.reduce_join(ex["answers"], separator="\n")
|
573 |
+
elif ex.get("answer") is not None:
|
574 |
+
out["metadata/references"] = ex["answer"]
|
575 |
+
|
576 |
+
if not test:
|
577 |
+
if "answer" in ex:
|
578 |
+
answer = ex["answer"]
|
579 |
+
else:
|
580 |
+
answer = ex["answers"]
|
581 |
+
if answer.dtype in [tf.int32, tf.int64]:
|
582 |
+
answer = tf.strings.as_string(answer)
|
583 |
+
if len(answer.shape) == 1 and tf.shape(answer)[0] == 0:
|
584 |
+
answer = tf.expand_dims("", 0)
|
585 |
+
if len(answer.shape) == len(question.shape):
|
586 |
+
pass
|
587 |
+
# Handle questions with multiple answers
|
588 |
+
elif answer_mode == "random":
|
589 |
+
assert len(answer.shape) == 1
|
590 |
+
answer = answer[tf.random.stateless_uniform((), seed, 0, tf.shape(answer)[0], dtype=tf.int32)]
|
591 |
+
elif answer_mode == "best":
|
592 |
+
def _get_best(_answer):
|
593 |
+
vals, _, counts = tf.unique_with_counts(_answer)
|
594 |
+
count_thresh = tf.reduce_max(counts)
|
595 |
+
vals = tf.boolean_mask(vals, counts >= count_thresh)
|
596 |
+
return vals[tf.random.stateless_uniform((), seed, 0, tf.shape(vals)[0], dtype=tf.int32)]
|
597 |
+
if len(answer.shape) == 1:
|
598 |
+
answer = _get_best(answer)
|
599 |
+
elif isinstance(answer, tf.RaggedTensor):
|
600 |
+
n = tf.shape(answer)[0]
|
601 |
+
answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
|
602 |
+
for i in range(n):
|
603 |
+
answer_arr = answer_arr.write(i, _get_best(answer[i]))
|
604 |
+
answer = answer_arr.stack()
|
605 |
+
else:
|
606 |
+
answer = tf.map_fn(_get_best, answer)
|
607 |
+
elif answer_mode == "all_segments":
|
608 |
+
out["text"] = answer
|
609 |
+
elif answer_mode == "all_segments_weighted":
|
610 |
+
out["text"] = answer
|
611 |
+
out["text_weights"] = 1.0 / tf.cast(tf.shape(answer)[-1], tf.float32)
|
612 |
+
elif answer_mode == "all":
|
613 |
+
if len(answer.shape) == 1:
|
614 |
+
answer = stateless_shuffle(answer, seed)
|
615 |
+
answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
|
616 |
+
elif isinstance(answer, tf.RaggedTensor):
|
617 |
+
n = tf.shape(answer)[0]
|
618 |
+
answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
|
619 |
+
for i in range(n):
|
620 |
+
answer_arr = answer_arr.write(i, tf.strings.reduce_join(tf.random.shuffle(answer[i]), separator="\n", axis=-1))
|
621 |
+
answer = answer_arr.stack()
|
622 |
+
else:
|
623 |
+
answer = tf.map_fn(tf.random.shuffle, answer)
|
624 |
+
answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
|
625 |
+
else:
|
626 |
+
raise NotImplementedError()
|
627 |
+
out["text"] = answer
|
628 |
+
return out
|
629 |
+
return _extract(ds)
|
630 |
+
|
631 |
+
|
632 |
+
@seqio.map_over_dataset()
|
633 |
+
def extract_khan_academy(ex):
|
634 |
+
return dict(
|
635 |
+
image=ex["image"],
|
636 |
+
image_url=ex["image_url"],
|
637 |
+
prompt="Answer this question",
|
638 |
+
text=ex["gptResponse"]
|
639 |
+
)
|
640 |
+
|
641 |
+
@seqio.map_over_dataset()
|
642 |
+
def extract_vaia_qa_latex_image(ex, add_short_answer=False, set_short_answer_first=False):
|
643 |
+
if ex["has_image"]:
|
644 |
+
image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
645 |
+
image = tf.expand_dims(image, 0)[:1]
|
646 |
+
else:
|
647 |
+
# image = get_blank_image() # blank image
|
648 |
+
image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
649 |
+
image = tf.expand_dims(image, 0)[:0]
|
650 |
+
img_h = tf.shape(image)[1]
|
651 |
+
img_w = tf.shape(image)[2]
|
652 |
+
|
653 |
+
if add_short_answer:
|
654 |
+
if set_short_answer_first:
|
655 |
+
answer = tf.strings.join(["Answer: ", ex["short_answer"], "\n\n", ex["answer"]])
|
656 |
+
else:
|
657 |
+
answer = tf.strings.join([ex["answer"], "\n\n", "Answer: ", ex["short_answer"]])
|
658 |
+
else:
|
659 |
+
answer = ex["answer"]
|
660 |
+
out = dict(
|
661 |
+
image=image, # 4-d tensor
|
662 |
+
text=answer,
|
663 |
+
prompt=tf.strings.join([ex["latex_question"], "\n"]),
|
664 |
+
)
|
665 |
+
out["metadata/images"] = image
|
666 |
+
out.update(_add_metadata(ex))
|
667 |
+
out["metadata/batch_id"] = ex["batch_id"]
|
668 |
+
out["metadata/image_size"] = [img_w, img_h]
|
669 |
+
return out
|
670 |
+
|
671 |
+
@seqio.map_over_dataset()
|
672 |
+
def extract_vqa_online(ex):
|
673 |
+
out = dict(
|
674 |
+
image=ex["image"],
|
675 |
+
prompt=tf.strings.join([ex["question"], "\n"]),
|
676 |
+
text=ex["answer"]
|
677 |
+
)
|
678 |
+
out.update(_add_metadata(ex))
|
679 |
+
out["metadata/row_id"] = ex["row_id"]
|
680 |
+
return out
|
681 |
+
|
682 |
+
|
683 |
+
@seqio.map_over_dataset()
|
684 |
+
def extract_scifi_joint(ex):
|
685 |
+
if "qa_pairs" in ex:
|
686 |
+
q = ex["qa_pairs"]
|
687 |
+
else:
|
688 |
+
q = ex["qa"]
|
689 |
+
prompts = tf.concat([["Describe this image in detail."], q["question"]], 0)
|
690 |
+
responses = tf.concat([ex["summary"][None], q["answer"]], 0)
|
691 |
+
return dict(
|
692 |
+
image=ex["image"],
|
693 |
+
prompt=prompts,
|
694 |
+
text=responses,
|
695 |
+
)
|
696 |
+
|
697 |
+
|
698 |
+
def remove_no_qa(ds):
|
699 |
+
def _filter(ex):
|
700 |
+
if "qa_pairs" in ex:
|
701 |
+
q = ex["qa_pairs"]
|
702 |
+
else:
|
703 |
+
q = ex["qa"]
|
704 |
+
return tf.shape(q["question"])[0] > 0
|
705 |
+
return ds.filter(_filter)
|
706 |
+
|
707 |
+
|
708 |
+
@seqio.map_over_dataset()
|
709 |
+
def extract_scifi_qa_exp(ex):
|
710 |
+
return dict(
|
711 |
+
image=ex["image"],
|
712 |
+
question=ex["question"], # Array of questions
|
713 |
+
answer=tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]]),
|
714 |
+
)
|
715 |
+
|
716 |
+
|
717 |
+
@seqio.map_over_dataset(num_seeds=1)
|
718 |
+
def extract_scifi_qa_demo(ex, seed):
|
719 |
+
# if tf.random.stateless_uniform((), 0, 1) > 0.5:
|
720 |
+
answer = tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]])
|
721 |
+
# else:
|
722 |
+
# answer = ex["explanation"]
|
723 |
+
return dict(
|
724 |
+
image=ex["image"],
|
725 |
+
question=ex["question"], # Array of questions
|
726 |
+
answer=answer,
|
727 |
+
)
|
728 |
+
|
729 |
+
|
730 |
+
@seqio.map_over_dataset()
|
731 |
+
def clock_bench_preprocessor(ex):
|
732 |
+
out = dict(
|
733 |
+
image=ex["image"],
|
734 |
+
prompt="What time is being shown?",
|
735 |
+
)
|
736 |
+
for k in ["hour", "minute", "second", "answerable"]:
|
737 |
+
out[f"metadata/{k}"] = ex[k]
|
738 |
+
return out
|
739 |
+
|
740 |
+
|
741 |
+
def deg2rad(x):
|
742 |
+
return x*math.pi/180.0
|
743 |
+
|
744 |
+
|
745 |
+
def get_affine_matrix(center, angle, translate, scale, shear):
|
746 |
+
# From https://github.com/pytorch/vision/blob/f96c42fca53230057b16941b078a0a9eee06e20f/torchvision/transforms/functional.py#L1006
|
747 |
+
rot = deg2rad(angle)
|
748 |
+
sx = deg2rad(shear[0])
|
749 |
+
sy = deg2rad(shear[1])
|
750 |
+
|
751 |
+
cx, cy = center
|
752 |
+
tx, ty = translate
|
753 |
+
|
754 |
+
# RSS without scaling
|
755 |
+
a = tf.cos(rot - sy) / tf.cos(sy)
|
756 |
+
b = -tf.cos(rot - sy) * tf.tan(sx) / tf.cos(sy) - tf.sin(rot)
|
757 |
+
c = tf.sin(rot - sy) / tf.cos(sy)
|
758 |
+
d = -tf.sin(rot - sy) * tf.tan(sx) / tf.cos(sy) + tf.cos(rot)
|
759 |
+
|
760 |
+
matrix = [a, b, 0.0, c, d, 0.0]
|
761 |
+
matrix = [x * scale for x in matrix]
|
762 |
+
# Apply inverse of center translation: RSS * C^-1
|
763 |
+
matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
|
764 |
+
matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
|
765 |
+
# Apply translation and center : T * C * RSS * C^-1
|
766 |
+
matrix[2] += cx + tx
|
767 |
+
matrix[5] += cy + ty
|
768 |
+
return matrix
|
769 |
+
|
770 |
+
|
771 |
+
def quantize_point(coor, max_dim, mode="percent-precision-1"):
|
772 |
+
max_dim = tf.cast(max_dim, tf.float32)
|
773 |
+
coor = tf.cast(coor, tf.float32)
|
774 |
+
x = (coor / max_dim)
|
775 |
+
if mode == "percent-precision-1":
|
776 |
+
return tf.strings.as_string(x*100, precision=1)
|
777 |
+
elif mode == "zero_to_one":
|
778 |
+
return tf.strings.as_string(x, precision=3)
|
779 |
+
elif mode == "1k":
|
780 |
+
return tf.strings.as_string(x*1000, precision=0)
|
781 |
+
else:
|
782 |
+
raise NotImplementedError(mode)
|
783 |
+
|
784 |
+
|
785 |
+
def construct_pointing_format(label_text, alt_text, x_str, y_str):
|
786 |
+
if alt_text is None:
|
787 |
+
alt_text = label_text
|
788 |
+
np = tf.shape(x_str)[0]
|
789 |
+
if np == 0:
|
790 |
+
output = ""
|
791 |
+
elif np == 1:
|
792 |
+
output = tf.strings.join([
|
793 |
+
'<point x="', x_str[0], '" y="', y_str[0], '" alt="',
|
794 |
+
alt_text, '">', label_text, '</point>'
|
795 |
+
])
|
796 |
+
else:
|
797 |
+
ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
|
798 |
+
xs = tf.strings.join(["x", ids, '="', x_str, '"'])
|
799 |
+
ys = tf.strings.join(["y", ids, '="', y_str, '"'])
|
800 |
+
points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
|
801 |
+
output = tf.strings.join(
|
802 |
+
["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
|
803 |
+
return output
|
804 |
+
|
805 |
+
|
806 |
+
def order_points(x, y, seed, point_order):
|
807 |
+
if point_order == "natural":
|
808 |
+
return x, y
|
809 |
+
|
810 |
+
if point_order == "random":
|
811 |
+
ix = stateless_permutation(tf.shape(x)[0], seed)
|
812 |
+
elif point_order == "xy":
|
813 |
+
x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
|
814 |
+
ix = tf.argsort(x_float*100000 + y_float)
|
815 |
+
elif point_order == "yx":
|
816 |
+
x_float, y_float = tf.strings.to_number(x), tf.strings.to_number(y)
|
817 |
+
ix = tf.argsort(y_float*100000 + x_float)
|
818 |
+
else:
|
819 |
+
raise NotImplementedError(point_order)
|
820 |
+
return tf.gather(x, ix), tf.gather(y, ix)
|
821 |
+
|
822 |
+
|
823 |
+
@gin.configurable()
|
824 |
+
def points_to_text(x, y, w, h, seed, label=None, alt_text=None, point_mode="percent-precision-1",
|
825 |
+
point_order="xy", point_list_mode="tag"):
|
826 |
+
"""Returns a string encoding of a list of points"""
|
827 |
+
x = quantize_point(x, w, point_mode)
|
828 |
+
y = quantize_point(y, h, point_mode)
|
829 |
+
# Order the quantized points to make the order matches what was generated, this can matter
|
830 |
+
# when points have the same quantized value e.g, (10.001, 20) (10.002, 10) should be
|
831 |
+
# represented (10, 10), (10, 20), but if we sort before quantization we get (10, 20), (10, 10)
|
832 |
+
x, y = order_points(x, y, seed, point_order)
|
833 |
+
if point_list_mode == "tag":
|
834 |
+
return construct_pointing_format(label, alt_text, x, y)
|
835 |
+
elif point_list_mode == "paren":
|
836 |
+
n = tf.shape(x)[0]
|
837 |
+
return tf.strings.reduce_join(tf.strings.join([
|
838 |
+
"(", x, ", ", y, ")"
|
839 |
+
]), separator=", ")
|
840 |
+
# if n == 0:
|
841 |
+
# output = ""
|
842 |
+
# else:
|
843 |
+
# ids = tf.strings.as_string(tf.range(1, np + 1, dtype=tf.int32))
|
844 |
+
# xs = tf.strings.join(["x", ids, '="', x_str, '"'])
|
845 |
+
# ys = tf.strings.join(["y", ids, '="', y_str, '"'])
|
846 |
+
# points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
|
847 |
+
# output = tf.strings.join(
|
848 |
+
# ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
|
849 |
+
# return output
|
850 |
+
else:
|
851 |
+
raise NotImplementedError(point_list_mode)
|
852 |
+
|
853 |
+
|
854 |
+
def points_to_answer(x, y, w, h, seed, label, is_counting, alt_text=None):
|
855 |
+
count = tf.shape(x)[0]
|
856 |
+
if is_counting:
|
857 |
+
if count == 0:
|
858 |
+
return "There are none."
|
859 |
+
else:
|
860 |
+
point_text = points_to_text(x, y, w, h, seed, label, alt_text)
|
861 |
+
return tf.strings.join([
|
862 |
+
"Counting the ", point_text,
|
863 |
+
" shows a total of ",
|
864 |
+
tf.strings.as_string(count),
|
865 |
+
"."
|
866 |
+
])
|
867 |
+
else:
|
868 |
+
if count == 0:
|
869 |
+
return "There are none."
|
870 |
+
else:
|
871 |
+
return points_to_text(x, y, w, h, seed, label, alt_text)
|
872 |
+
|
873 |
+
|
874 |
+
@seqio.map_over_dataset(num_seeds=2)
|
875 |
+
def extract_point_qa(ex, seeds, answer_type="y_major"):
|
876 |
+
ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
877 |
+
img_h = tf.shape(ex["image"])[0]
|
878 |
+
img_w = tf.shape(ex["image"])[1]
|
879 |
+
|
880 |
+
questions = ex["questions"]
|
881 |
+
question = questions["question"]
|
882 |
+
n = tf.shape(question)[0]
|
883 |
+
answers = tf.TensorArray(tf.string, size=n, element_shape=())
|
884 |
+
point_text = questions["annotations"]["point_text"]
|
885 |
+
point_seeds = tf.RaggedTensor.from_row_splits(
|
886 |
+
row_splits=point_text.row_splits,
|
887 |
+
values=tf.random.split(seeds[0], num=tf.shape(point_text.values)[0])
|
888 |
+
)
|
889 |
+
for question_ix in range(n):
|
890 |
+
anno = questions["annotations"]
|
891 |
+
answer = questions["answer_with_placeholders"][question_ix]
|
892 |
+
n_anno = tf.shape(anno["point_text"][question_ix])[0]
|
893 |
+
for anno_ix in range(n_anno):
|
894 |
+
points = anno["points"][question_ix, anno_ix]
|
895 |
+
point_text = points_to_answer(
|
896 |
+
points[:, 0], points[:, 1], 100, 100,
|
897 |
+
point_seeds[question_ix, anno_ix],
|
898 |
+
anno["point_text"][question_ix, anno_ix],
|
899 |
+
False,
|
900 |
+
alt_text=anno["alt_text"][question_ix, anno_ix],
|
901 |
+
)
|
902 |
+
answer_split = tf.strings.split(answer, sep="<|POINT|>", maxsplit=1)
|
903 |
+
answer = tf.strings.join([answer_split[0], point_text, answer_split[1]])
|
904 |
+
# Make sure all placeholders where used
|
905 |
+
tf.debugging.assert_equal(tf.shape(tf.strings.split(answer, sep="<|POINT|>"))[0], 1)
|
906 |
+
answers = answers.write(question_ix, answer)
|
907 |
+
|
908 |
+
messages = tf.stack([question, answers.stack()], axis=1)
|
909 |
+
messages = tf.reshape(messages, [-1])
|
910 |
+
conversation_ids = tf.range(tf.shape(messages)[0] // 2, dtype=tf.int32)
|
911 |
+
conversation_ids = tf.repeat(conversation_ids, 2)
|
912 |
+
out = dict(
|
913 |
+
image=ex["image"],
|
914 |
+
messages=tf.RaggedTensor.from_value_rowids(messages, conversation_ids)
|
915 |
+
)
|
916 |
+
ix = stateless_permutation(tf.shape(messages)[0], seeds[1])
|
917 |
+
messages = tf.gather(messages, ix)
|
918 |
+
out.update(_add_metadata(ex))
|
919 |
+
out["metadata/image_size"] = [img_w, img_h]
|
920 |
+
return out
|
921 |
+
|
922 |
+
|
923 |
+
def select_point(mask):
|
924 |
+
bs = tf.shape(mask)[0]
|
925 |
+
valid = tf.cast(mask, tf.float32)
|
926 |
+
h, w = tf.shape(mask)[1], tf.shape(mask)[2]
|
927 |
+
ys = tf.range(h, dtype=tf.int32)
|
928 |
+
xs = tf.range(w, dtype=tf.int32)
|
929 |
+
|
930 |
+
n = tf.reduce_sum(valid, [1, 2])
|
931 |
+
cy = tf.reduce_sum(tf.cast(ys[None, :, None], tf.float32) * valid, [1, 2]) / n # [bs]
|
932 |
+
cx = tf.reduce_sum(tf.cast(xs[None, None, :], tf.float32) * valid, [1, 2]) / n # [bs]
|
933 |
+
|
934 |
+
dist_y = tf.square(tf.range(h, dtype=tf.float32)[None, :] - cy[:, None]) # [bs, h]
|
935 |
+
dist_x = tf.square(tf.range(w, dtype=tf.float32)[None, :] - cx[:, None]) # [bs, w]
|
936 |
+
dist = dist_y[:, :, None] + dist_x[:, None, :] # [batch, h, w]
|
937 |
+
dist = dist + (1 - valid) * 1e12
|
938 |
+
min_dist = tf.argmin(tf.reshape(dist, [bs, -1]), axis=-1) # [batch]
|
939 |
+
w = tf.cast(w, min_dist.dtype)
|
940 |
+
cy = tf.cast(min_dist // w, tf.float32)
|
941 |
+
cx = tf.cast(min_dist % w, tf.float32)
|
942 |
+
return cx, cy
|
943 |
+
|
944 |
+
|
945 |
+
@seqio.map_over_dataset
|
946 |
+
def refexp_pointing(ex):
|
947 |
+
img_h = tf.shape(ex["image"])[0]
|
948 |
+
img_w = tf.shape(ex["image"])[1]
|
949 |
+
objects = ex["objects"]
|
950 |
+
|
951 |
+
# Shuffle objects so what object gets truncated if the sequence gets truncated is randomized
|
952 |
+
refexps = objects['refexp']['raw']
|
953 |
+
bbox = objects["bbox"]
|
954 |
+
mask = tf.squeeze(objects["mask"], -1)
|
955 |
+
|
956 |
+
ix = tf.range(0, tf.shape(refexps)[0], dtype=tf.int32)
|
957 |
+
ix = tf.random.shuffle(ix)
|
958 |
+
refexps = tf.gather(refexps, ix)
|
959 |
+
bbox = tf.gather(bbox, ix)
|
960 |
+
mask = tf.gather(mask, ix)
|
961 |
+
|
962 |
+
cx, cy = select_point(mask)
|
963 |
+
answers = points_to_text(img_h, img_w, cx, cy)
|
964 |
+
|
965 |
+
out = {
|
966 |
+
"image": ex["image"],
|
967 |
+
"refexp": refexps.values,
|
968 |
+
"metadata/image_size": tf.stack([img_w, img_h,]),
|
969 |
+
"text": tf.repeat(answers, refexps.row_lengths()),
|
970 |
+
}
|
971 |
+
if "image_url" in ex:
|
972 |
+
out["metadata/image_url"] = ex["image_url"]
|
973 |
+
return out
|
974 |
+
|
975 |
+
|
976 |
+
@seqio.map_over_dataset
|
977 |
+
def refexp_pointing_inf(ex):
|
978 |
+
img_h = tf.shape(ex["image"])[0]
|
979 |
+
img_w = tf.shape(ex["image"])[1]
|
980 |
+
|
981 |
+
objects = ex["objects"]
|
982 |
+
mask = tf.squeeze(objects["mask"], -1)
|
983 |
+
cx, cy = select_point(mask)
|
984 |
+
answers = points_to_text(img_h, img_w, cx, cy)
|
985 |
+
|
986 |
+
refexps = objects["refexp"]["raw"]
|
987 |
+
|
988 |
+
# We can't use `mask` directly since it is variable size, and thus it
|
989 |
+
# will break batching. Here we serialize it instead
|
990 |
+
serialized_masks = tf.map_fn(tf.io.serialize_tensor, mask, fn_output_signature=tf.string)
|
991 |
+
out = {
|
992 |
+
"image": ex["image"],
|
993 |
+
"refexp": refexps,
|
994 |
+
"metadata/bbox": objects["bbox"],
|
995 |
+
"metadata/answer": answers,
|
996 |
+
"metadata/mask": serialized_masks,
|
997 |
+
"metadata/image_size": tf.stack([img_w, img_h]),
|
998 |
+
}
|
999 |
+
out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
|
1000 |
+
return out
|
1001 |
+
|
1002 |
+
@seqio.map_over_dataset
|
1003 |
+
def extract_andriod_control_inf(ex, mode):
|
1004 |
+
if mode == "ll":
|
1005 |
+
prompt = tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]])
|
1006 |
+
elif mode == "hl_ll":
|
1007 |
+
prompt = tf.strings.join([
|
1008 |
+
"high_level: ", ex["metadata/hl_instruction"],
|
1009 |
+
" low_level: ", ex["metadata/ll_instruction"]
|
1010 |
+
])
|
1011 |
+
elif mode == "hl":
|
1012 |
+
prompt = tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]])
|
1013 |
+
elif mode == "hl_cot":
|
1014 |
+
prompt = tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]])
|
1015 |
+
else:
|
1016 |
+
raise NotImplementedError()
|
1017 |
+
|
1018 |
+
out = dict(
|
1019 |
+
image=ex["image"],
|
1020 |
+
prompt=prompt,
|
1021 |
+
text=ex["metadata/target_action"]
|
1022 |
+
)
|
1023 |
+
out.update(_add_metadata(ex))
|
1024 |
+
return out
|
1025 |
+
|
1026 |
+
@seqio.map_over_dataset
|
1027 |
+
def extract_android_control(ex):
|
1028 |
+
# Each image has three tasks:
|
1029 |
+
# low level -> action
|
1030 |
+
# high+low level -> action
|
1031 |
+
# high level -> action
|
1032 |
+
# high level -> low level + action (CoT)
|
1033 |
+
out = dict(
|
1034 |
+
image=ex["image"],
|
1035 |
+
prompt=tf.stack([
|
1036 |
+
tf.strings.join(["low_level: ", ex["metadata/ll_instruction"]]),
|
1037 |
+
tf.strings.join([
|
1038 |
+
"high_level: ", ex["metadata/hl_instruction"],
|
1039 |
+
" low_level: ", ex["metadata/ll_instruction"]
|
1040 |
+
]),
|
1041 |
+
tf.strings.join(["high_level: ", ex["metadata/hl_instruction"]]),
|
1042 |
+
tf.strings.join(["high_level_cot: ", ex["metadata/hl_instruction"]]),
|
1043 |
+
]),
|
1044 |
+
text=tf.stack([
|
1045 |
+
ex["metadata/target_action"],
|
1046 |
+
ex["metadata/target_action"],
|
1047 |
+
ex["metadata/target_action"],
|
1048 |
+
tf.strings.join(["Plan: ", ex["metadata/ll_instruction"], " Action: ", ex["metadata/target_action"]]),
|
1049 |
+
])
|
1050 |
+
)
|
1051 |
+
# Only needed if visualizing
|
1052 |
+
# ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1053 |
+
# img_h = tf.shape(ex["image"])[0]
|
1054 |
+
# img_w = tf.shape(ex["image"])[1]
|
1055 |
+
# out["metadata/image_size"] = tf.stack([img_w, img_h,])
|
1056 |
+
out.update(_add_metadata(ex))
|
1057 |
+
return out
|
1058 |
+
|
1059 |
+
|
1060 |
+
@seqio.map_over_dataset(num_seeds=1)
|
1061 |
+
def refexp(ex, seed):
|
1062 |
+
img_h = tf.shape(ex["image"])[0]
|
1063 |
+
img_w = tf.shape(ex["image"])[1]
|
1064 |
+
objects = ex["objects"]
|
1065 |
+
|
1066 |
+
# Shuffle objects so what object gets truncated if the sequence gets truncated is randomized
|
1067 |
+
refexps = objects['refexp']['raw']
|
1068 |
+
bbox = objects["bbox"]
|
1069 |
+
ix = stateless_permutation(tf.shape(refexps)[0], seed)
|
1070 |
+
refexps = tf.gather(refexps, ix)
|
1071 |
+
bbox = tf.gather(bbox, ix)
|
1072 |
+
|
1073 |
+
x2 = bbox[:, 0] + bbox[:, 2]
|
1074 |
+
y2 = bbox[:, 1] + bbox[:, 3]
|
1075 |
+
with tf.control_dependencies([
|
1076 |
+
tf.debugging.assert_equal(tf.reduce_any(x2 <= tf.cast(img_w, tf.float32)), True),
|
1077 |
+
tf.debugging.assert_equal(tf.reduce_any(y2 <= tf.cast(img_h, tf.float32)), True)
|
1078 |
+
]):
|
1079 |
+
answers = points_to_text(
|
1080 |
+
img_h, img_w,
|
1081 |
+
tf.reshape(tf.stack([bbox[:, 0], x2], 1), [-1]),
|
1082 |
+
tf.reshape(tf.stack([bbox[:, 1], y2], 1), [-1]))
|
1083 |
+
answers = tf.strings.reduce_join(tf.reshape(answers, [-1, 2]), separator=" ", axis=1)
|
1084 |
+
|
1085 |
+
out = {
|
1086 |
+
"image": ex["image"],
|
1087 |
+
"refexp": refexps.values,
|
1088 |
+
"metadata/bbox": bbox,
|
1089 |
+
"metadata/image_size": tf.stack([img_w, img_h,]),
|
1090 |
+
"text": tf.repeat(answers, refexps.row_lengths()),
|
1091 |
+
}
|
1092 |
+
|
1093 |
+
if "image_url" in ex:
|
1094 |
+
out["image_url"] = ex["image_url"]
|
1095 |
+
return out
|
1096 |
+
|
1097 |
+
|
1098 |
+
@seqio.map_over_dataset
|
1099 |
+
def refexp_inf(ex):
|
1100 |
+
img_h = tf.shape(ex["image"])[0]
|
1101 |
+
img_w = tf.shape(ex["image"])[1]
|
1102 |
+
out = {
|
1103 |
+
"image": ex["image"],
|
1104 |
+
"refexp": ex["objects"]["refexp"]["raw"],
|
1105 |
+
"metadata/bbox": ex["objects"]["bbox"],
|
1106 |
+
"metadata/image_size": tf.stack([img_w, img_h,]),
|
1107 |
+
}
|
1108 |
+
out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
|
1109 |
+
return out
|
1110 |
+
|
1111 |
+
|
1112 |
+
def point_text_interleaved(*args):
|
1113 |
+
raise NotImplementedError()
|
1114 |
+
|
1115 |
+
|
1116 |
+
@seqio.map_over_dataset
|
1117 |
+
def web_pointing_preprocessor(ex):
|
1118 |
+
img_h = tf.shape(ex["image"])[0]
|
1119 |
+
img_w = tf.shape(ex["image"])[1]
|
1120 |
+
|
1121 |
+
question = point_text_interleaved(
|
1122 |
+
img_h, img_w, ex["question"], ex["question_points"]["x"], ex["question_points"]["y"])
|
1123 |
+
answer = point_text_interleaved(
|
1124 |
+
img_h, img_w, ex["answer"], ex["answer_points"]["x"], ex["answer_points"]["y"])
|
1125 |
+
answer_points = tf.stack([ex["answer_points"]["x"], ex["answer_points"]["y"]], axis=1)
|
1126 |
+
return {
|
1127 |
+
"question": question,
|
1128 |
+
"answer": answer,
|
1129 |
+
"image": ex["image"],
|
1130 |
+
"metadata/image_size": [img_w, img_h],
|
1131 |
+
"metadata/question_type": ex["question_type"],
|
1132 |
+
"metadata/answer_points": tf.io.serialize_tensor(answer_points),
|
1133 |
+
"metadata/answer": answer,
|
1134 |
+
}
|
1135 |
+
|
1136 |
+
|
1137 |
+
def filter_pointing(ds):
|
1138 |
+
return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] >= 1)
|
1139 |
+
|
1140 |
+
|
1141 |
+
def filter_qa(ds):
|
1142 |
+
return ds.filter(lambda ex: tf.shape(ex["answer_points"]["x"])[0] == 0)
|
1143 |
+
|
1144 |
+
# vaia filtering
|
1145 |
+
def filter_image_only(ds):
|
1146 |
+
return ds.filter(lambda ex: ex["has_image"])
|
1147 |
+
|
1148 |
+
def filter_mc(ds):
|
1149 |
+
return ds.filter(lambda ex: ex["is_mc"])
|
1150 |
+
|
1151 |
+
def remove_is_long(ds):
|
1152 |
+
return ds.filter(lambda ex: not ex["is_long"])
|
1153 |
+
|
1154 |
+
def remove_has_multiple_parts(ds):
|
1155 |
+
return ds.filter(lambda ex: not ex["has_multiple_parts"])
|
1156 |
+
|
1157 |
+
|
1158 |
+
def _split(ds: tf.data.Dataset, keys, n_splits=2):
|
1159 |
+
def _map(ex):
|
1160 |
+
n = tf.shape(ex[keys[0]])[0]
|
1161 |
+
if n < n_splits:
|
1162 |
+
return tf.data.Dataset.from_tensors(ex)
|
1163 |
+
else:
|
1164 |
+
# import pdb; pdb.set_trace()
|
1165 |
+
bs = n // n_splits
|
1166 |
+
remainder = n - bs*n_splits
|
1167 |
+
lens = tf.concat([
|
1168 |
+
tf.ones([remainder], dtype=tf.int32),
|
1169 |
+
tf.zeros([n_splits-remainder], dtype=tf.int32),
|
1170 |
+
], axis=0) + bs
|
1171 |
+
tf.debugging.assert_equal(tf.reduce_sum(lens), n)
|
1172 |
+
ends = tf.cumsum(lens)
|
1173 |
+
|
1174 |
+
parts = []
|
1175 |
+
for split_ix in range(n_splits):
|
1176 |
+
part_ex = dict(ex)
|
1177 |
+
e = ends[split_ix]
|
1178 |
+
s = e - lens[split_ix]
|
1179 |
+
for k in keys:
|
1180 |
+
if isinstance(k, tuple):
|
1181 |
+
assert len(k) == 2
|
1182 |
+
part_ex[k[0]][k[1]] = ex[k[0]][k[1]][s:e]
|
1183 |
+
else:
|
1184 |
+
part_ex[k] = ex[k][s:e]
|
1185 |
+
parts.append(part_ex)
|
1186 |
+
|
1187 |
+
ds = tf.data.Dataset.from_tensors(parts[0])
|
1188 |
+
for sub_ds in parts[1:]:
|
1189 |
+
sub_ds = tf.data.Dataset.from_tensors(sub_ds)
|
1190 |
+
ds = ds.concatenate(sub_ds)
|
1191 |
+
return ds
|
1192 |
+
|
1193 |
+
return ds.flat_map(_map)
|
1194 |
+
|
1195 |
+
|
1196 |
+
|
1197 |
+
def split(ds, n=2):
|
1198 |
+
# return ds
|
1199 |
+
return _split(ds, [k for k in [
|
1200 |
+
"question",
|
1201 |
+
"label",
|
1202 |
+
"text",
|
1203 |
+
"entity",
|
1204 |
+
"messages"
|
1205 |
+
] if k in ds.element_spec], n_splits=n)
|
1206 |
+
|
1207 |
+
|
1208 |
+
def split_points(ds, max_points=50):
|
1209 |
+
label = "question" if "question" in ds.element_spec else "label"
|
1210 |
+
return _split(ds, [
|
1211 |
+
"question", label, "notInImage",
|
1212 |
+
("answer_points", "x"),
|
1213 |
+
("answer_points", "y"),
|
1214 |
+
])
|
1215 |
+
|
1216 |
+
|
1217 |
+
@seqio.map_over_dataset
|
1218 |
+
def fix_count_qa(ex):
|
1219 |
+
ex["label"] = ex["label"][::2]
|
1220 |
+
tf.debugging.assert_equal(tf.shape(ex["answer_points"]["x"])[0], tf.shape(ex["label"])[0])
|
1221 |
+
return ex
|
1222 |
+
|
1223 |
+
|
1224 |
+
def filter_points(ds, max_number=40):
|
1225 |
+
|
1226 |
+
def _add_valid(ex):
|
1227 |
+
valid = (
|
1228 |
+
tf.reduce_all(ex["answer_points"]["x"] >= 0.0, axis=-1) &
|
1229 |
+
tf.reduce_all(ex["answer_points"]["x"] <= 100.0, axis=-1) &
|
1230 |
+
tf.reduce_all(ex["answer_points"]["y"] >= 0.0, axis=-1) &
|
1231 |
+
tf.reduce_all(ex["answer_points"]["y"] <= 100.0, axis=-1) &
|
1232 |
+
(ex["answer_points"]["y"].row_lengths() <= max_number)
|
1233 |
+
)
|
1234 |
+
ex["valid"] = valid
|
1235 |
+
return ex
|
1236 |
+
ds = ds.map(_add_valid)
|
1237 |
+
ds = ds.filter(lambda ex: tf.reduce_any(ex["valid"]))
|
1238 |
+
return ds
|
1239 |
+
|
1240 |
+
|
1241 |
+
# def filter_points(ds, max_number=30):
|
1242 |
+
# n_points = ds["answer_points"]["x"].row_lengths()
|
1243 |
+
# parts = tf.TensorArray(tf.int32, size=tf.shape(n_points[0]), element_shape=tf.TensorShape([None]))
|
1244 |
+
# total = 0
|
1245 |
+
# on_row = 0
|
1246 |
+
# for i in range(n_points):
|
1247 |
+
# n = n_points[i]
|
1248 |
+
# if n > max_number:
|
1249 |
+
# continue
|
1250 |
+
# if n + total > max_number:
|
1251 |
+
#
|
1252 |
+
# return ds
|
1253 |
+
|
1254 |
+
|
1255 |
+
@seqio.map_over_dataset(num_seeds=2)
|
1256 |
+
def pointing_preprocessor(ex, sequence_length, seeds, with_count=False):
|
1257 |
+
image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1258 |
+
img_h = tf.shape(image)[0]
|
1259 |
+
img_w = tf.shape(image)[1]
|
1260 |
+
|
1261 |
+
ix = tf.where(ex["valid"])[:, 0]
|
1262 |
+
ix = stateless_shuffle(ix, seeds[0])
|
1263 |
+
if "label" in ex:
|
1264 |
+
question = tf.strings.lower(ex["label"])
|
1265 |
+
else:
|
1266 |
+
question = ex["question"]
|
1267 |
+
question = tf.gather(question, ix) # [n_question]
|
1268 |
+
points_x = tf.gather(ex["answer_points"]["x"], ix) # [n_question, n_points[ragged]]]
|
1269 |
+
points_y = tf.gather(ex["answer_points"]["y"], ix)
|
1270 |
+
not_in_image = tf.gather(ex["notInImage"], ix) # [n_question]
|
1271 |
+
|
1272 |
+
n = tf.shape(points_x)[0]
|
1273 |
+
point_text = tf.TensorArray(dtype=tf.string, size=n, element_shape=()) # [n_question]
|
1274 |
+
point_seeds = tf.random.split(seeds[1], n)
|
1275 |
+
for i in range(n):
|
1276 |
+
answer = points_to_answer(points_x[i], points_y[i], 100, 100, point_seeds[i], question[i], with_count)
|
1277 |
+
point_text = point_text.write(i, answer)
|
1278 |
+
return {
|
1279 |
+
"image": image,
|
1280 |
+
"metadata/image_size": [img_w, img_h],
|
1281 |
+
"entity": question,
|
1282 |
+
"question": question,
|
1283 |
+
"text": point_text.stack(),
|
1284 |
+
}
|
1285 |
+
|
1286 |
+
|
1287 |
+
@seqio.map_over_dataset
|
1288 |
+
def pointing_inf_preprocessor(ex):
|
1289 |
+
ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1290 |
+
img_h = tf.shape(ex["image"])[0]
|
1291 |
+
img_w = tf.shape(ex["image"])[1]
|
1292 |
+
|
1293 |
+
question = ex["question"]
|
1294 |
+
not_in_image = tf.shape(ex["answer_points"]["x"])[0] == 0
|
1295 |
+
|
1296 |
+
# points are stored in normalized format, de-normalize here
|
1297 |
+
points_x = ex["answer_points"]["x"] * tf.cast(img_w, tf.float32) / 100.0
|
1298 |
+
points_y = ex["answer_points"]["y"] * tf.cast(img_h, tf.float32) / 100.0
|
1299 |
+
|
1300 |
+
out = dict(
|
1301 |
+
image=ex["image"],
|
1302 |
+
question=question,
|
1303 |
+
entity=question,
|
1304 |
+
)
|
1305 |
+
out.update(_add_metadata(ex))
|
1306 |
+
out["metadata/not_in_image"] = not_in_image
|
1307 |
+
# We can't use `mask` directly since it is variable size, and thus it
|
1308 |
+
# will break batching. Here we serialize it instead
|
1309 |
+
serialized_masks = tf.map_fn(tf.io.serialize_tensor, ex["masks"], fn_output_signature=tf.string)
|
1310 |
+
serialized_masks = tf.strings.reduce_join(serialized_masks, separator="|||")
|
1311 |
+
out["metadata/mask"] = serialized_masks
|
1312 |
+
out["metadata/question"] = question
|
1313 |
+
out["metadata/answer_points"] = tf.io.serialize_tensor(tf.stack([points_x, points_y], 1))
|
1314 |
+
out["metadata/image_size"] = [img_w, img_h]
|
1315 |
+
|
1316 |
+
return out
|
1317 |
+
|
1318 |
+
|
1319 |
+
@seqio.map_over_dataset(num_seeds=1)
|
1320 |
+
def count_qa_preprocessor_inf(ex, sequence_length, seed):
|
1321 |
+
image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1322 |
+
img_h = tf.shape(image)[0]
|
1323 |
+
img_w = tf.shape(image)[1]
|
1324 |
+
|
1325 |
+
entity = tf.strings.substr(
|
1326 |
+
ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
|
1327 |
+
entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
|
1328 |
+
entity = tf.strings.lower(entity)
|
1329 |
+
tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
|
1330 |
+
|
1331 |
+
return {
|
1332 |
+
"image": image,
|
1333 |
+
"metadata/image_size": [img_w, img_h],
|
1334 |
+
"metadata/count": tf.strings.to_number(ex["answer"]),
|
1335 |
+
"question": ex["question"],
|
1336 |
+
"entity": entity,
|
1337 |
+
}
|
1338 |
+
|
1339 |
+
|
1340 |
+
@seqio.map_over_dataset(num_seeds=1)
|
1341 |
+
def count_qa_preprocessor(ex, sequence_length, seed, with_count=False,
|
1342 |
+
for_inference=False):
|
1343 |
+
point_answer = ex["point_answer"]
|
1344 |
+
numbers_str = tf.strings.regex_replace(point_answer, r'\.$', '')
|
1345 |
+
numbers_str = tf.strings.regex_replace(numbers_str, r'[^\d\.\s]+', '')
|
1346 |
+
numbers_str = tf.strings.strip(numbers_str)
|
1347 |
+
numbers = tf.strings.split(numbers_str)
|
1348 |
+
float_numbers = tf.strings.to_number(numbers, out_type=tf.float32)
|
1349 |
+
coordinates = tf.reshape(float_numbers, (-1, 3))
|
1350 |
+
points_x = coordinates[:, 1]
|
1351 |
+
points_y = coordinates[:, 2]
|
1352 |
+
|
1353 |
+
image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1354 |
+
img_h = tf.shape(image)[0]
|
1355 |
+
img_w = tf.shape(image)[1]
|
1356 |
+
entity = tf.strings.substr(
|
1357 |
+
ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
|
1358 |
+
entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
|
1359 |
+
entity = tf.strings.lower(entity)
|
1360 |
+
tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
|
1361 |
+
count = tf.strings.to_number(ex["answer"], out_type=tf.int32)
|
1362 |
+
if for_inference:
|
1363 |
+
return {
|
1364 |
+
"image": image,
|
1365 |
+
"metadata/image_size": [img_w, img_h],
|
1366 |
+
"metadata/count": count,
|
1367 |
+
"question": ex["question"],
|
1368 |
+
"entity": entity,
|
1369 |
+
}
|
1370 |
+
else:
|
1371 |
+
tf.debugging.assert_equal(count, tf.shape(points_x)[0])
|
1372 |
+
# points are already normalized so use w=1, h=1
|
1373 |
+
answer = points_to_answer(points_x, points_y, 1, 1, seed, entity, with_count)
|
1374 |
+
return {
|
1375 |
+
"image": image,
|
1376 |
+
"metadata/image_size": [img_w, img_h],
|
1377 |
+
"metadata/count": count,
|
1378 |
+
"question": ex["question"],
|
1379 |
+
"entity": entity,
|
1380 |
+
"text": answer,
|
1381 |
+
}
|
1382 |
+
|
1383 |
+
|
1384 |
+
@gin.configurable()
|
1385 |
+
@seqio.map_over_dataset
|
1386 |
+
def cleanup_preprocessor(ex, preprocess=False):
|
1387 |
+
if preprocess:
|
1388 |
+
ex["prompt"] = tf.strings.join(
|
1389 |
+
[
|
1390 |
+
"[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} ",
|
1391 |
+
ex["prompt"],
|
1392 |
+
"\n[[Assistant]]: {after}"
|
1393 |
+
]
|
1394 |
+
)
|
1395 |
+
return ex
|
1396 |
+
else:
|
1397 |
+
return ex
|
1398 |
+
|
1399 |
+
|
1400 |
+
@gin.configurable()
|
1401 |
+
@seqio.map_over_dataset
|
1402 |
+
def random_text_preprocessor(ex, preprocess=False):
|
1403 |
+
ex["prompt"] = "What does the text say in this image?"
|
1404 |
+
if preprocess:
|
1405 |
+
ex["prompt"] = tf.strings.join(["[[User]]: ", ex["prompt"], "\n[[Assistant]]:"])
|
1406 |
+
return ex
|
1407 |
+
else:
|
1408 |
+
return ex
|
1409 |
+
|
1410 |
+
|
1411 |
+
@seqio.map_over_dataset(num_seeds=25)
|
1412 |
+
def clock_augmentation(ex, seeds):
|
1413 |
+
seeds = list(seeds)
|
1414 |
+
image = ex["image"]
|
1415 |
+
|
1416 |
+
# Apply shear, rotation, and scale through one affine matrix
|
1417 |
+
height = tf.cast(tf.shape(image)[0], tf.float32)
|
1418 |
+
width = tf.cast(tf.shape(image)[1], tf.float32)
|
1419 |
+
|
1420 |
+
_call_id = [0]
|
1421 |
+
|
1422 |
+
def _rng(_minval=0, _maxval=1, shape=(), dtype=tf.float32):
|
1423 |
+
return tf.random.stateless_uniform(shape, seeds.pop(), _minval, _maxval, dtype=dtype)
|
1424 |
+
|
1425 |
+
sel = _rng(0, 1)
|
1426 |
+
if sel < 0.1:
|
1427 |
+
# Straight on
|
1428 |
+
shear_x = 0.
|
1429 |
+
shear_y = 0.
|
1430 |
+
rotation = 0.
|
1431 |
+
elif sel < 0.5:
|
1432 |
+
# Normal looking
|
1433 |
+
shear_x = _rng(-10, 10)
|
1434 |
+
shear_y = _rng(-10, 10)
|
1435 |
+
rotation = _rng(-25, 25)
|
1436 |
+
else:
|
1437 |
+
# Allowed to be very wonky
|
1438 |
+
# if tf.random.stateless_uniform((), seeds.pop(), 0, 1) > 0.8:
|
1439 |
+
# image = image[:, ::-1]
|
1440 |
+
|
1441 |
+
if _rng() > 0.5:
|
1442 |
+
shear_x = _rng( -30, 30)
|
1443 |
+
shear_y = _rng( -30, 30)
|
1444 |
+
else:
|
1445 |
+
shear_x = _rng( -10, 10)
|
1446 |
+
shear_y = _rng( -10, 10)
|
1447 |
+
rng = _rng( 0, 1)
|
1448 |
+
if rng < 0.2:
|
1449 |
+
rotation = _rng( -25, 25)
|
1450 |
+
elif rng < 0.6:
|
1451 |
+
rotation = _rng( -80, 80)
|
1452 |
+
else:
|
1453 |
+
rotation = _rng( -180, 180)
|
1454 |
+
|
1455 |
+
if _rng() > 0.5:
|
1456 |
+
scale = _rng( 0.3, 2)
|
1457 |
+
else:
|
1458 |
+
scale = _rng( 0.3, 1)
|
1459 |
+
# Pad so upscaling/rotation will not move the image out of bounds
|
1460 |
+
pad = tf.cast(tf.maximum(height, width)*0.5, tf.int32)
|
1461 |
+
image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
|
1462 |
+
height = tf.cast(tf.shape(image)[0], tf.float32)
|
1463 |
+
width = tf.cast(tf.shape(image)[1], tf.float32)
|
1464 |
+
|
1465 |
+
image = tf.keras.ops.image.affine_transform(
|
1466 |
+
image,
|
1467 |
+
tf.stack(get_affine_matrix(
|
1468 |
+
[height/2, width/2],
|
1469 |
+
rotation,
|
1470 |
+
[0, 0],
|
1471 |
+
1/scale,
|
1472 |
+
[shear_x, shear_y]
|
1473 |
+
) + [0., 0.]),
|
1474 |
+
interpolation='bilinear',
|
1475 |
+
fill_mode='constant',
|
1476 |
+
fill_value=1.,
|
1477 |
+
data_format='channels_last'
|
1478 |
+
)
|
1479 |
+
|
1480 |
+
# Crop, otherwise it would be impossible to put the image at the corner of the image
|
1481 |
+
not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
|
1482 |
+
no_white_ix = tf.where(not_white)
|
1483 |
+
top_left = tf.reduce_min(no_white_ix, axis=0)
|
1484 |
+
bottom_right = tf.reduce_max(no_white_ix, axis=0)
|
1485 |
+
image = tf.image.crop_to_bounding_box(
|
1486 |
+
image,
|
1487 |
+
offset_height=tf.cast(top_left[0], tf.int32),
|
1488 |
+
offset_width=tf.cast(top_left[1], tf.int32),
|
1489 |
+
target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
|
1490 |
+
target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
|
1491 |
+
)
|
1492 |
+
|
1493 |
+
# Translate
|
1494 |
+
height, width = tf.shape(image)[0], tf.shape(image)[1]
|
1495 |
+
translation_seed = _rng(0, 1)
|
1496 |
+
if translation_seed < 0.2:
|
1497 |
+
h_pad = _rng(0, height//2, (2,), dtype=tf.int32)
|
1498 |
+
w_pad = _rng(0, width//2, (2,), dtype=tf.int32)
|
1499 |
+
else:
|
1500 |
+
h_pad = _rng(0, height*2, (2,), dtype=tf.int32)
|
1501 |
+
w_pad = _rng(0, width*2, (2,), dtype=tf.int32)
|
1502 |
+
image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
|
1503 |
+
constant_values=1)
|
1504 |
+
|
1505 |
+
# Random background color
|
1506 |
+
# color_rng = tf.random.stateless_uniform((4,), seeds.pop(), 0, 1)
|
1507 |
+
# random_color = color_rng[:3]
|
1508 |
+
# valid = tf.reduce_all(tf.reduce_sum(tf.abs(random_color[None, None, :] - image), -1) > 0.03)
|
1509 |
+
# if color_rng[0] < 0.2 and valid:
|
1510 |
+
# image = tf.where(tf.reduce_all(image < 0.99, axis=-1, keepdims=True),
|
1511 |
+
# image, image * 0 + random_color[None, None, :])
|
1512 |
+
|
1513 |
+
# Mild color hitter
|
1514 |
+
image = tf.image.stateless_random_hue(image, max_delta=0.05, seed=seeds.pop())
|
1515 |
+
image = tf.image.stateless_random_brightness(image, max_delta=0.15, seed=seeds.pop())
|
1516 |
+
image = tf.image.stateless_random_saturation(image, 0.8, 1.2, seed=seeds.pop())
|
1517 |
+
image = tf.image.stateless_random_contrast(image, 0.8, 1.2, seed=seeds.pop())
|
1518 |
+
|
1519 |
+
# ex["metadata/unaugmented_image"] = ex["image"]
|
1520 |
+
ex["image"] = image
|
1521 |
+
return ex
|
1522 |
+
|
1523 |
+
|
1524 |
+
@seqio.map_over_dataset
|
1525 |
+
def clocks_preprocessor(ex):
|
1526 |
+
time_format = ex["time_format"]
|
1527 |
+
shows_seconds = ex["shows_seconds"]
|
1528 |
+
hour, minute, second = [tf.cast(ex[k], tf.int32) for k in ["hour", "minute", "second"]]
|
1529 |
+
if hour == 0: # Midnight of the previous day
|
1530 |
+
am_pm = "PM"
|
1531 |
+
hour_str = 12
|
1532 |
+
hour = 24
|
1533 |
+
elif hour > 12:
|
1534 |
+
am_pm = "PM"
|
1535 |
+
hour_str = hour - 12
|
1536 |
+
else:
|
1537 |
+
hour_str = hour
|
1538 |
+
am_pm = "AM"
|
1539 |
+
hour_str = tf.strings.as_string(hour_str)
|
1540 |
+
minute_str = tf.strings.as_string(minute)
|
1541 |
+
if tf.strings.length(minute_str) == 1:
|
1542 |
+
minute_str = tf.strings.join(["0", minute_str])
|
1543 |
+
|
1544 |
+
second_str = tf.strings.as_string(second)
|
1545 |
+
if tf.strings.length(second_str) == 1:
|
1546 |
+
second_str = tf.strings.join(["0", second_str])
|
1547 |
+
|
1548 |
+
prefix = "The time shown is "
|
1549 |
+
|
1550 |
+
if time_format == "The time is not shown":
|
1551 |
+
text = "The time is not shown in the image."
|
1552 |
+
hour, minute, second = -1, -1, -1
|
1553 |
+
else:
|
1554 |
+
if not shows_seconds:
|
1555 |
+
second = -1
|
1556 |
+
if time_format == "12 hour clock (without AM/PM)" and shows_seconds:
|
1557 |
+
if hour > 12:
|
1558 |
+
hour = hour - 12
|
1559 |
+
time = tf.strings.join([hour_str, ":", minute_str, ":", second_str])
|
1560 |
+
elif time_format == "12 hour clock (with AM/PM)" and shows_seconds:
|
1561 |
+
time = tf.strings.join([hour_str, ":", minute_str, ":", second_str, " ", am_pm])
|
1562 |
+
elif time_format == "12 hour clock (with AM/PM)" and not shows_seconds:
|
1563 |
+
time = tf.strings.join([hour_str, ":", minute_str, " ", am_pm])
|
1564 |
+
elif time_format == "12 hour clock (without AM/PM)" and not shows_seconds:
|
1565 |
+
if hour > 12:
|
1566 |
+
hour = hour - 12
|
1567 |
+
time = tf.strings.join([hour_str, ":", minute_str])
|
1568 |
+
else:
|
1569 |
+
time = "" # Should never occur, but needed for tf analysis
|
1570 |
+
tf.debugging.assert_equal(tf.strings.length(time) > 0, True)
|
1571 |
+
text = tf.strings.join(["The time shown is ", time])
|
1572 |
+
image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1573 |
+
image = tf.image.convert_image_dtype(image, tf.float32)[:-120] # remove the black shadow at the bottom
|
1574 |
+
return {
|
1575 |
+
"image": image,
|
1576 |
+
"prompt": "What time is being shown?",
|
1577 |
+
"text": text,
|
1578 |
+
"metadata/time_format": time_format,
|
1579 |
+
"metadata/hour": hour,
|
1580 |
+
"metadata/minute": minute,
|
1581 |
+
"metadata/text": text,
|
1582 |
+
"metadata/second": second,
|
1583 |
+
}
|
1584 |
+
|
1585 |
+
|
1586 |
+
@seqio.map_over_dataset()
|
1587 |
+
def atlas_obscura_preprocessor(ex):
|
1588 |
+
out = dict(
|
1589 |
+
image=ex["image"],
|
1590 |
+
prompt="Where was this picture taken?",
|
1591 |
+
text=tf.strings.join([
|
1592 |
+
ex["place"],
|
1593 |
+
" in ",
|
1594 |
+
ex["city"]
|
1595 |
+
])
|
1596 |
+
)
|
1597 |
+
out["metadata/image_url"] = ex["image_url"]
|
1598 |
+
out["metadata/references"] = out["text"]
|
1599 |
+
return out
|
1600 |
+
|
1601 |
+
|
1602 |
+
@seqio.map_over_dataset()
|
1603 |
+
def famous_birthdays_preprocessor(ex):
|
1604 |
+
out = dict(
|
1605 |
+
image=ex["image"],
|
1606 |
+
image_url=ex["image_url"],
|
1607 |
+
prompt="Who is this?",
|
1608 |
+
text=ex["name"]
|
1609 |
+
)
|
1610 |
+
out["metadata/references"] = out["text"]
|
1611 |
+
return out
|
1612 |
+
|
1613 |
+
|
1614 |
+
@seqio.map_over_dataset()
|
1615 |
+
def mild_color_aug_preprocessor(ex):
|
1616 |
+
if "image_url" in ex: # URL won't show the augmentations
|
1617 |
+
del ex["image_url"]
|
1618 |
+
# ex["metadata/unaugmented_image"] = ex["image"]
|
1619 |
+
ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1620 |
+
ex["image"] = mild_color_aug(ex["image"])
|
1621 |
+
return ex
|
1622 |
+
|
1623 |
+
|
1624 |
+
def build_text_with_points(text, points, img_h, img_w):
|
1625 |
+
points = points_to_text(img_h, img_w, points[:, 0], points[:, 1])
|
1626 |
+
parts = tf.strings.split(text, sep="<ANS>")
|
1627 |
+
with_points = tf.strings.reduce_join(tf.reshape(tf.stack([
|
1628 |
+
parts,
|
1629 |
+
tf.pad(points, [[0, 1]], constant_values=""),
|
1630 |
+
], 1), [-1]), separator="")
|
1631 |
+
return tf.strings.split(with_points, "\n\n")
|
1632 |
+
|
1633 |
+
|
1634 |
+
@seqio.map_over_dataset()
|
1635 |
+
def synth_count_preprocessor(example):
|
1636 |
+
image_shape = tf.shape(example["image"])
|
1637 |
+
h, w = image_shape[0], image_shape[1]
|
1638 |
+
questions = build_text_with_points(example["questions"], example["question_points"], h, w)
|
1639 |
+
answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
|
1640 |
+
keep_q = tf.strings.regex_full_match(questions, "How many.*")
|
1641 |
+
keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
|
1642 |
+
keep = tf.logical_and(keep_q, keep_ans)
|
1643 |
+
questions = tf.boolean_mask(questions, keep)
|
1644 |
+
answers = tf.boolean_mask(answers, keep)
|
1645 |
+
ix = tf.range(0, tf.shape(answers)[0], dtype=tf.int32)
|
1646 |
+
ix = tf.random.shuffle(ix)
|
1647 |
+
return dict(
|
1648 |
+
image=example["image"],
|
1649 |
+
prompt=tf.gather(questions, ix),
|
1650 |
+
text=tf.gather(answers, ix),
|
1651 |
+
)
|
1652 |
+
|
1653 |
+
|
1654 |
+
def synth_count_inf_preprocessor(ds):
|
1655 |
+
|
1656 |
+
@seqio.map_over_dataset(num_seeds=1)
|
1657 |
+
def get_two(example, seed):
|
1658 |
+
image_shape = tf.shape(example["image"])
|
1659 |
+
h, w = image_shape[0], image_shape[1]
|
1660 |
+
questions = build_text_with_points(example["questions"], example["question_points"], h, w)
|
1661 |
+
answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
|
1662 |
+
keep_q = tf.strings.regex_full_match(questions, "How many.*")
|
1663 |
+
keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
|
1664 |
+
keep = tf.logical_and(keep_q, keep_ans)
|
1665 |
+
questions = tf.boolean_mask(questions, keep)
|
1666 |
+
answers = tf.boolean_mask(answers, keep)
|
1667 |
+
|
1668 |
+
ix = stateless_permutation(tf.shape(answers)[0], seed)[:2]
|
1669 |
+
return {
|
1670 |
+
"image": example["image"],
|
1671 |
+
"prompt": tf.gather(questions, ix),
|
1672 |
+
"metadata/references": tf.gather(answers, ix),
|
1673 |
+
}
|
1674 |
+
|
1675 |
+
ds = get_two(ds)
|
1676 |
+
return flatten_parts(ds, ["prompt", "metadata/references"])
|
1677 |
+
|
1678 |
+
|
1679 |
+
def mild_color_aug(image):
|
1680 |
+
image = tf.image.random_hue(image, max_delta=0.05)
|
1681 |
+
image = tf.image.random_brightness(image, max_delta=0.15)
|
1682 |
+
image = tf.image.random_saturation(image, 0.7, 1.3)
|
1683 |
+
image = tf.image.random_contrast(image, 0.8, 1.2)
|
1684 |
+
return image
|
1685 |
+
|
1686 |
+
|
1687 |
+
@seqio.map_over_dataset()
|
1688 |
+
def name_entity_augmentation(ex, p_high_color=0.7):
|
1689 |
+
ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
1690 |
+
image = ex["image"]
|
1691 |
+
image = tf.image.convert_image_dtype(image, tf.float32)
|
1692 |
+
|
1693 |
+
# Horizontal flip
|
1694 |
+
if tf.random.uniform((), 0, 1) > 0.85:
|
1695 |
+
image = image[:, ::-1]
|
1696 |
+
|
1697 |
+
# Random crop
|
1698 |
+
height = tf.cast(tf.shape(image)[0], tf.float32)
|
1699 |
+
width = tf.cast(tf.shape(image)[1], tf.float32)
|
1700 |
+
crop_rng = tf.random.uniform((), 0, 1)
|
1701 |
+
if crop_rng < 0.2:
|
1702 |
+
pass
|
1703 |
+
else:
|
1704 |
+
if crop_rng < 0.4:
|
1705 |
+
h_crop = height * 0.15
|
1706 |
+
w_crop = width * 0.15
|
1707 |
+
else:
|
1708 |
+
h_crop = height * 0.4
|
1709 |
+
w_crop = width * 0.4
|
1710 |
+
crop_h = tf.cast(tf.random.uniform((2,), 0, h_crop/2), tf.int32)
|
1711 |
+
crop_w = tf.cast(tf.random.uniform((2,), 0, w_crop/2), tf.int32)
|
1712 |
+
image = image[crop_h[0]:-crop_h[1]-1, crop_w[0]:-crop_w[1]-1]
|
1713 |
+
height = tf.cast(tf.shape(image)[0], tf.float32)
|
1714 |
+
width = tf.cast(tf.shape(image)[1], tf.float32)
|
1715 |
+
|
1716 |
+
if tf.random.uniform(()) > p_high_color:
|
1717 |
+
image = tf.image.random_hue(image, max_delta=0.05)
|
1718 |
+
image = tf.image.random_brightness(image, max_delta=0.15)
|
1719 |
+
image = tf.image.random_saturation(image, 0.7, 1.3)
|
1720 |
+
image = tf.image.random_contrast(image, 0.8, 1.2)
|
1721 |
+
else:
|
1722 |
+
image = tf.image.random_hue(image, max_delta=0.1)
|
1723 |
+
image = tf.image.random_brightness(image, max_delta=0.3)
|
1724 |
+
image = tf.image.random_saturation(image, 0.0, 2.0)
|
1725 |
+
image = tf.image.random_contrast(image, 0.2, 1.5)
|
1726 |
+
|
1727 |
+
# Apply shear, rotation, and scale through one affine matrix
|
1728 |
+
sel = tf.random.uniform((), 0, 1)
|
1729 |
+
if sel < 0.1:
|
1730 |
+
pass
|
1731 |
+
else:
|
1732 |
+
if sel < 0.15: # Scale only
|
1733 |
+
shear_x = 0
|
1734 |
+
shear_y = 0
|
1735 |
+
rotation = 0
|
1736 |
+
if sel < 0.7: # Mild
|
1737 |
+
shear_x = tf.random.uniform((), -2, 2)
|
1738 |
+
shear_y = tf.random.uniform((), -2, 2)
|
1739 |
+
rotation = tf.random.uniform((), -5, 5)
|
1740 |
+
else: # Severe
|
1741 |
+
shear_x = tf.random.uniform((), -10, 10)
|
1742 |
+
shear_y = tf.random.uniform((), -10, 10)
|
1743 |
+
rotation = tf.random.uniform((), -20, 20)
|
1744 |
+
|
1745 |
+
max_scale = 1.2
|
1746 |
+
scale = tf.random.uniform((), 0.4, max_scale)
|
1747 |
+
|
1748 |
+
# Pad so upscaling/rotation will not move the image out of bounds
|
1749 |
+
pad = tf.cast(tf.maximum(height, width)*0.2, tf.int32)
|
1750 |
+
image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
|
1751 |
+
|
1752 |
+
image = tf.keras.ops.image.affine_transform(
|
1753 |
+
image,
|
1754 |
+
tf.stack(get_affine_matrix(
|
1755 |
+
[height/2, width/2],
|
1756 |
+
rotation,
|
1757 |
+
[0, 0],
|
1758 |
+
1/scale,
|
1759 |
+
[shear_x, shear_y]
|
1760 |
+
) + [0., 0.]),
|
1761 |
+
interpolation='bilinear',
|
1762 |
+
fill_mode='constant',
|
1763 |
+
fill_value=1.,
|
1764 |
+
data_format='channels_last'
|
1765 |
+
)
|
1766 |
+
|
1767 |
+
# Crop, otherwise it would be impossible to put the image at the corner of the image
|
1768 |
+
not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
|
1769 |
+
no_white_ix = tf.where(not_white)
|
1770 |
+
top_left = tf.reduce_min(no_white_ix, axis=0)
|
1771 |
+
bottom_right = tf.reduce_max(no_white_ix, axis=0)
|
1772 |
+
|
1773 |
+
# Very low chance center crop will get nothing but white space, we just skip
|
1774 |
+
if (
|
1775 |
+
(bottom_right[0] - top_left[0]) > 1 and (bottom_right[1] - top_left[1]) > 1
|
1776 |
+
):
|
1777 |
+
image = tf.image.crop_to_bounding_box(
|
1778 |
+
image,
|
1779 |
+
offset_height=tf.cast(top_left[0], tf.int32),
|
1780 |
+
offset_width=tf.cast(top_left[1], tf.int32),
|
1781 |
+
target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
|
1782 |
+
target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
|
1783 |
+
)
|
1784 |
+
|
1785 |
+
# Translate
|
1786 |
+
height, width = tf.shape(image)[0], tf.shape(image)[1]
|
1787 |
+
if tf.random.uniform((), 0, 1) < 0.1:
|
1788 |
+
h_pad = tf.zeros((2,), dtype=tf.int32)
|
1789 |
+
w_pad = tf.zeros((2,), dtype=tf.int32)
|
1790 |
+
elif tf.random.uniform((), 0, 1) < 0.8:
|
1791 |
+
h_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
|
1792 |
+
w_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
|
1793 |
+
else:
|
1794 |
+
pad = tf.cast(tf.maximum(height, width), tf.int32)
|
1795 |
+
h_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
|
1796 |
+
w_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
|
1797 |
+
image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
|
1798 |
+
constant_values=1)
|
1799 |
+
|
1800 |
+
if "image_url" in ex: # URL won't show the augmentations
|
1801 |
+
del ex["image_url"]
|
1802 |
+
# ex["metadata/unaugmented_image"] = ex["image"]
|
1803 |
+
ex["image"] = image
|
1804 |
+
return ex
|
1805 |
+
|
1806 |
+
|
1807 |
+
@seqio.map_over_dataset()
|
1808 |
+
def wiki_art_preprocessor(ex):
|
1809 |
+
out = dict(
|
1810 |
+
image=ex["image"],
|
1811 |
+
prompt="What is this?",
|
1812 |
+
text=ex["question"]
|
1813 |
+
)
|
1814 |
+
out["metadata/title"] = ex["title"]
|
1815 |
+
out["metadata/gt"] = ex["question"]
|
1816 |
+
out["metadata/artist"] = ex["artist"]
|
1817 |
+
out["metadata/painting_url"] = ex["painting_url"]
|
1818 |
+
# if "metadata/unaugmented_image" in ex:
|
1819 |
+
# out["metadata/unaugmented_image"] = ex["metadata/unaugmented_image"]
|
1820 |
+
return out
|
1821 |
+
|
1822 |
+
@seqio.map_over_dataset()
|
1823 |
+
def oscar_preprocessor(ex):
|
1824 |
+
out = dict(
|
1825 |
+
image=ex["image"],
|
1826 |
+
prompt=ex["question"]
|
1827 |
+
)
|
1828 |
+
out.update(_add_metadata(ex))
|
1829 |
+
out["metadata/question"] = ex["question"]
|
1830 |
+
out["metadata/answer"] = ex["answer"]
|
1831 |
+
out["metadata/category"] = ex["category"]
|
1832 |
+
return out
|
1833 |
+
|
1834 |
+
|
1835 |
+
@seqio.map_over_dataset()
|
1836 |
+
def tulu_preprocessor(ex):
|
1837 |
+
return {
|
1838 |
+
"messages": ex["messages"]["content"],
|
1839 |
+
}
|
1840 |
+
# logging.info("Debugging tulue")
|
1841 |
+
# return {"messages": ex["messages"]["content"], "text_weights": 1e-6}
|
1842 |
+
|
1843 |
+
|
1844 |
+
WIKI_DATA_QUESTION = "What is this? Respond with just a proper name."
|
1845 |
+
|
1846 |
+
|
1847 |
+
@seqio.map_over_dataset()
|
1848 |
+
def extract_wiki_data(ex):
|
1849 |
+
return dict(
|
1850 |
+
image=ex["image"],
|
1851 |
+
image_url=ex["image_url"],
|
1852 |
+
prompt=[
|
1853 |
+
WIKI_DATA_QUESTION,
|
1854 |
+
"What is this? Respond with the proper name of the main focus of the image and a few details about it."
|
1855 |
+
],
|
1856 |
+
text=[
|
1857 |
+
tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", "")),
|
1858 |
+
ex["gptResponse"],
|
1859 |
+
]
|
1860 |
+
)
|
1861 |
+
|
1862 |
+
|
1863 |
+
@seqio.map_over_dataset()
|
1864 |
+
def extract_wiki_data_name(ex):
|
1865 |
+
target = tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", ""))
|
1866 |
+
out = dict(
|
1867 |
+
image=ex["image"],
|
1868 |
+
image_url=ex["image_url"],
|
1869 |
+
prompt=WIKI_DATA_QUESTION,
|
1870 |
+
text=target,
|
1871 |
+
)
|
1872 |
+
out["metadata/references"] = target
|
1873 |
+
return out
|
1874 |
+
|
1875 |
+
|
1876 |
+
@seqio.map_over_dataset()
|
1877 |
+
def extract_wiki_data_describe(ex):
|
1878 |
+
out = dict(
|
1879 |
+
image=ex["image"],
|
1880 |
+
image_url=ex["image_url"],
|
1881 |
+
prompt="What is this? Respond with the proper name of the main focus of the image and a few details about it.",
|
1882 |
+
)
|
1883 |
+
out["metadata/references"] = ex["gptResponse"]
|
1884 |
+
return out
|
1885 |
+
|
1886 |
+
|
1887 |
+
@gin.configurable()
|
1888 |
+
def format_multiple_style_qa(ds, types=['multiple_choice', 'short_answer'], styles=['ai2_diagram', 'vqa2'], default_style='vqa2',
|
1889 |
+
strip_instruction=False):
|
1890 |
+
def _extract(ex):
|
1891 |
+
prompt = ex["question"]
|
1892 |
+
out = dict(image=ex["image"])
|
1893 |
+
out.update(_add_metadata(ex))
|
1894 |
+
|
1895 |
+
out["text"] = ex["answer"]
|
1896 |
+
out["metadata/references"] = ex["answer"]
|
1897 |
+
|
1898 |
+
if ex["metadata/question_type"] == 'multiple_choice':
|
1899 |
+
style = styles[0]
|
1900 |
+
else:
|
1901 |
+
style = styles[1]
|
1902 |
+
if strip_instruction:
|
1903 |
+
if ex["metadata/question_type"] == "multiple_choice":
|
1904 |
+
# parts = tf.strings.split(prompt, "\n")
|
1905 |
+
# parts 1 is blank and part -1 is the instruction
|
1906 |
+
# prompt = tf.strings.reduce_join(tf.concat([parts[:1], parts[2:-1]], 0), separator="\n")
|
1907 |
+
prompt = prompt
|
1908 |
+
else:
|
1909 |
+
prompt = tf.strings.split(prompt, "\n")[0]
|
1910 |
+
|
1911 |
+
out["style"] = style
|
1912 |
+
out["prompt"] = prompt
|
1913 |
+
return out
|
1914 |
+
ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
1915 |
+
return ds
|
1916 |
+
|
1917 |
+
|
1918 |
+
@gin.configurable()
|
1919 |
+
def extract_mmmu(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
|
1920 |
+
assert option_format == "abc"
|
1921 |
+
keys_tensor = tf.constant(types, dtype=tf.string)
|
1922 |
+
values_tensor = tf.constant(styles, dtype=tf.string)
|
1923 |
+
table = tf.lookup.StaticHashTable(
|
1924 |
+
tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
|
1925 |
+
default_value=tf.constant(default_style, dtype=tf.string),
|
1926 |
+
)
|
1927 |
+
def _extract(ex):
|
1928 |
+
out = dict(image=tf.expand_dims(ex["image_1"], 0))
|
1929 |
+
out.update(_add_metadata(ex))
|
1930 |
+
style = table.lookup(ex["metadata/question_type"])
|
1931 |
+
out["style"] = style
|
1932 |
+
out["text"] = ex["answer"]
|
1933 |
+
out["metadata/references"] = ex["answer"]
|
1934 |
+
|
1935 |
+
if style == styles[0]:
|
1936 |
+
abc = tf.constant(list("abcdefghi".upper()))
|
1937 |
+
options = ex["options"]
|
1938 |
+
num_options = tf.shape(options)[0]
|
1939 |
+
dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
|
1940 |
+
out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
|
1941 |
+
out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
|
1942 |
+
|
1943 |
+
short_options = abc[:num_options]
|
1944 |
+
options = tf.stack([short_options, options,], 1)
|
1945 |
+
options = tf.strings.reduce_join(options, axis=-1, separator=": ")
|
1946 |
+
options = tf.strings.reduce_join(options, separator="\n")
|
1947 |
+
out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
|
1948 |
+
if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
|
1949 |
+
# Following LLaVa, don't use any images if there are multiple images paths
|
1950 |
+
# I think the rationale is that this means the image are answer-options
|
1951 |
+
out["image"] = out["image"][:0]
|
1952 |
+
else:
|
1953 |
+
out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
|
1954 |
+
out["prompt"] = ex["question"]
|
1955 |
+
out["image"] = out["image"][:0]
|
1956 |
+
return out
|
1957 |
+
ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
1958 |
+
return ds
|
1959 |
+
|
1960 |
+
@gin.configurable()
|
1961 |
+
def extract_mmmu_cot(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
|
1962 |
+
assert option_format == "abc"
|
1963 |
+
keys_tensor = tf.constant(types, dtype=tf.string)
|
1964 |
+
values_tensor = tf.constant(styles, dtype=tf.string)
|
1965 |
+
table = tf.lookup.StaticHashTable(
|
1966 |
+
tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
|
1967 |
+
default_value=tf.constant(default_style, dtype=tf.string),
|
1968 |
+
)
|
1969 |
+
def _extract(ex):
|
1970 |
+
# out = dict(image=tf.expand_dims(ex["image_with_question"], 0))
|
1971 |
+
out = dict(image=tf.expand_dims(ex["image_1"], 0))
|
1972 |
+
out.update(_add_metadata(ex))
|
1973 |
+
style = table.lookup(ex["metadata/question_type"])
|
1974 |
+
# out["style"] = style
|
1975 |
+
out["text"] = ex["answer"]
|
1976 |
+
out["metadata/question"] = ex["question"]
|
1977 |
+
out["metadata/references"] = ex["answer"]
|
1978 |
+
|
1979 |
+
if style == styles[0]:
|
1980 |
+
abc = tf.constant(list("abcdefghi".upper()))
|
1981 |
+
options = ex["options"]
|
1982 |
+
num_options = tf.shape(options)[0]
|
1983 |
+
dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
|
1984 |
+
out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
|
1985 |
+
out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
|
1986 |
+
|
1987 |
+
short_options = abc[:num_options]
|
1988 |
+
options = tf.stack([short_options, options,], 1)
|
1989 |
+
options = tf.strings.reduce_join(options, axis=-1, separator=": ")
|
1990 |
+
options = tf.strings.reduce_join(options, separator="\n")
|
1991 |
+
out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
|
1992 |
+
# out["prompt"] = ex["question"]
|
1993 |
+
if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, "<img='(.*?)'>"), tf.int32)) > 1:
|
1994 |
+
# Following LLaVa, don't use any images if there are multiple images paths
|
1995 |
+
# I think the rationale is that this means the image are answer-options
|
1996 |
+
out["image"] = out["image"][:0]
|
1997 |
+
else:
|
1998 |
+
out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
|
1999 |
+
out["prompt"] = ex["question"]
|
2000 |
+
# out["image"] = out["image"][:0]
|
2001 |
+
return out
|
2002 |
+
ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
2003 |
+
return ds
|
2004 |
+
|
2005 |
+
|
2006 |
+
@seqio.map_over_dataset
|
2007 |
+
def reformat_math_vista(ex):
|
2008 |
+
query = ex["query"]
|
2009 |
+
query = tf.strings.split(query, sep="Question:")[-1]
|
2010 |
+
query = tf.strings.strip(tf.strings.split(query, sep="Hint:")[0])
|
2011 |
+
ex["query"] = query
|
2012 |
+
return ex
|
2013 |
+
|
2014 |
+
|
2015 |
+
@seqio.map_over_dataset
|
2016 |
+
def extract_math_vista(ex, styles=['ai2_diagram', 'vqa2']):
|
2017 |
+
out = dict(image=ex["image"])
|
2018 |
+
out.update(_add_metadata(ex))
|
2019 |
+
|
2020 |
+
is_mc = ex["metadata/question_type"] == 'multi_choice'
|
2021 |
+
if is_mc:
|
2022 |
+
style = styles[0]
|
2023 |
+
abc = tf.constant(list("abcdefghi".upper()))
|
2024 |
+
options = ex["choices"]
|
2025 |
+
num_options = tf.shape(options)[0]
|
2026 |
+
dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
|
2027 |
+
out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
|
2028 |
+
out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
|
2029 |
+
|
2030 |
+
if ex["metadata/split"] != "test":
|
2031 |
+
short_options = abc[:num_options]
|
2032 |
+
answer_short_option = tf.boolean_mask(short_options, options == ex["answer"])[0]
|
2033 |
+
out["text"] = answer_short_option
|
2034 |
+
else:
|
2035 |
+
out["text"] = ex["answer"]
|
2036 |
+
else:
|
2037 |
+
style = styles[1]
|
2038 |
+
out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
|
2039 |
+
out["text"] = ex["answer"]
|
2040 |
+
out["style"] = style
|
2041 |
+
out["prompt"] = ex["query"]
|
2042 |
+
out["metadata/query"] = ex["query"]
|
2043 |
+
out["metadata/references"] = ex["answer"]
|
2044 |
+
return out
|
2045 |
+
|
2046 |
+
|
2047 |
+
NO_POINT_PREFIX = [
|
2048 |
+
"No pointing: ",
|
2049 |
+
"No pointing: ",
|
2050 |
+
"no pointing:\n",
|
2051 |
+
"No pointing:\n",
|
2052 |
+
"Not pointing:\n",
|
2053 |
+
"No Points: ",
|
2054 |
+
"No Points: ",
|
2055 |
+
"NO POINTING\n",
|
2056 |
+
"No pontiing\n",
|
2057 |
+
"No Points:\n ",
|
2058 |
+
"No pointing\n",
|
2059 |
+
"Do not point. ",
|
2060 |
+
"Refrain from pointing. ",
|
2061 |
+
"Avoid generating points . ",
|
2062 |
+
"For this question, do not use points. ",
|
2063 |
+
"Refrain from using points:\n",
|
2064 |
+
"Don't include points in your response. ",
|
2065 |
+
"Don't point. ",
|
2066 |
+
"Don't use points. ",
|
2067 |
+
"Please don't use points.\n\n",
|
2068 |
+
"Please don't use points.\n\n",
|
2069 |
+
"Respond without using points. ",
|
2070 |
+
"Respond without pointing:\n",
|
2071 |
+
"Do not generate ponits: ",
|
2072 |
+
"Do not point. ",
|
2073 |
+
"Do not point\n",
|
2074 |
+
"no pointing\n\n",
|
2075 |
+
"Answer without points: ",
|
2076 |
+
"Answer this question without pointing: ",
|
2077 |
+
"Answer without poiints. ",
|
2078 |
+
"answer without points: ",
|
2079 |
+
"answer with text only, do not points\n"
|
2080 |
+
]
|
2081 |
+
assert all(x[-1].isspace() for x in NO_POINT_PREFIX)
|
2082 |
+
NO_POINT_PREFIX_TF = tf.constant(NO_POINT_PREFIX)
|
2083 |
+
|
2084 |
+
|
2085 |
+
def prefix_how_many(messages, seed):
|
2086 |
+
question = messages[0]
|
2087 |
+
if tf.strings.regex_full_match(tf.strings.lower(question), "how many.*"):
|
2088 |
+
ix = tf.random.stateless_uniform((), seed, 0, len(NO_POINT_PREFIX), tf.int32)
|
2089 |
+
question = tf.strings.join([NO_POINT_PREFIX_TF[ix], question])
|
2090 |
+
return tf.concat([tf.expand_dims(question, 0), messages[1:]], axis=0)
|
2091 |
+
else:
|
2092 |
+
return messages
|
2093 |
+
|
2094 |
+
|
2095 |
+
@seqio.map_over_dataset(num_seeds=1)
|
2096 |
+
def prefix_how_many_messages(ex, seed):
|
2097 |
+
messages = ex["messages"]
|
2098 |
+
n = tf.shape(messages)[0]
|
2099 |
+
seeds = tf.random.split(seed, n)
|
2100 |
+
message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
|
2101 |
+
for i in range(n):
|
2102 |
+
message_arr = message_arr.write(i, prefix_how_many(messages[i], seeds[i]))
|
2103 |
+
ex["messages"] = tf.RaggedTensor.from_row_splits(
|
2104 |
+
values=message_arr.concat(), row_splits=messages.row_splits)
|
2105 |
+
return ex
|
2106 |
+
|
2107 |
+
|
2108 |
+
def filter_single_turn(ds):
|
2109 |
+
@seqio.map_over_dataset
|
2110 |
+
def _filter(ex):
|
2111 |
+
multi_turn = ex["messages"].row_lengths() > 2
|
2112 |
+
ex["messages"] = tf.ragged.boolean_mask(ex["messages"], multi_turn)
|
2113 |
+
return ex
|
2114 |
+
|
2115 |
+
ds = _filter(ds)
|
2116 |
+
ds = ds.filter(lambda x: tf.shape(x["messages"])[0] > 0)
|
2117 |
+
return ds
|
2118 |
+
|
2119 |
+
|
2120 |
+
@seqio.map_over_dataset(num_seeds=1)
|
2121 |
+
def extract_cockatoo_qa_v2(ex, seed):
|
2122 |
+
messages = tf.RaggedTensor.from_value_rowids(ex["messages"], ex["conversation_ids"])
|
2123 |
+
ix = stateless_permutation(tf.shape(messages)[0], seed)
|
2124 |
+
messages = tf.gather(messages, ix)
|
2125 |
+
out = dict(
|
2126 |
+
image=ex["image"],
|
2127 |
+
messages=messages
|
2128 |
+
)
|
2129 |
+
out.update(_add_metadata(ex))
|
2130 |
+
return out
|
2131 |
+
|
2132 |
+
|
2133 |
+
def format_mmbench(ds):
|
2134 |
+
|
2135 |
+
def _trim(ex):
|
2136 |
+
num_passes = tf.shape(ex["id"])[0]
|
2137 |
+
ex["choices"] = ex["choices"][:num_passes, :num_passes]
|
2138 |
+
ex["answer"] = ex["answer"][:num_passes]
|
2139 |
+
return ex
|
2140 |
+
|
2141 |
+
ds = ds.map(_trim)
|
2142 |
+
ds = flatten_parts(ds, ["id", "query", "choices", "answer"])
|
2143 |
+
|
2144 |
+
def _extract(ex):
|
2145 |
+
out = dict(image=ex["image"])
|
2146 |
+
out.update(_add_metadata(ex))
|
2147 |
+
out["prompt"] = ex["query"]
|
2148 |
+
out["text"] = ex["answer"]
|
2149 |
+
options = ex["choices"]
|
2150 |
+
tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
|
2151 |
+
out["metadata/options"] = tf.strings.reduce_join(options, separator="|||")
|
2152 |
+
out["metadata/question"] = ex["question"]
|
2153 |
+
out["metadata/references"] = ex["answer"]
|
2154 |
+
return out
|
2155 |
+
|
2156 |
+
ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
2157 |
+
return ds
|
2158 |
+
|
2159 |
+
|
2160 |
+
@seqio.map_over_dataset
|
2161 |
+
def extract_lvis(ex, class_name_file="gs://oe-training-chrisc/cockatoo/data/lvis_class_names.json"):
|
2162 |
+
with tf.io.gfile.GFile(class_name_file) as f:
|
2163 |
+
class_names = json.load(f)
|
2164 |
+
class_names_arr = [None]*len(class_names)
|
2165 |
+
for k, v in class_names.items():
|
2166 |
+
class_names_arr[int(k)] = v
|
2167 |
+
assert all(x is not None for x in class_names_arr)
|
2168 |
+
class_names_arr = tf.constant(class_names_arr)
|
2169 |
+
|
2170 |
+
return dict(
|
2171 |
+
image=ex["image"],
|
2172 |
+
bbox=ex["objects"]["bbox"],
|
2173 |
+
label=tf.gather(class_names_arr, ex["objects"]["label"]),
|
2174 |
+
)
|
2175 |
+
|
2176 |
+
|
2177 |
+
def extract_open_images_boxes(ds):
|
2178 |
+
# ds = ds.filter(lambda ex: tf.logical_or(
|
2179 |
+
# tf.shape(ex["cap/cap_caption"])[0] > 0,
|
2180 |
+
# tf.shape(ex["detection/bbox"])[0] > 0
|
2181 |
+
# ))
|
2182 |
+
ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)
|
2183 |
+
|
2184 |
+
@seqio.map_over_dataset
|
2185 |
+
def _map(ex):
|
2186 |
+
bbox = tf.reshape(ex["detection/bbox"], (-1, 4))
|
2187 |
+
bbox = tf.stack([
|
2188 |
+
bbox[:, 2],
|
2189 |
+
bbox[:, 0],
|
2190 |
+
bbox[:, 3],
|
2191 |
+
bbox[:, 1]
|
2192 |
+
], 1)
|
2193 |
+
return dict(
|
2194 |
+
image=tf.image.decode_jpeg(ex["image"]),
|
2195 |
+
bbox=bbox,
|
2196 |
+
label=ex["detection/label"],
|
2197 |
+
caption=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
|
2198 |
+
)
|
2199 |
+
|
2200 |
+
return _map(ds)
|
2201 |
+
|
2202 |
+
|
2203 |
+
@seqio.map_over_dataset
|
2204 |
+
def region_captions_to_dense(ex):
|
2205 |
+
if "captions" in ex:
|
2206 |
+
captions = ex["captions"]["text"]
|
2207 |
+
boxes = ex["captions"]["bbox"]
|
2208 |
+
else:
|
2209 |
+
captions = ex["label"]
|
2210 |
+
boxes = ex["bbox"]
|
2211 |
+
|
2212 |
+
|
2213 |
+
sh = tf.cast(tf.shape(ex["image"])[:2], tf.float32)
|
2214 |
+
# image_h, image_w = sh[0], sh[1]
|
2215 |
+
w = boxes[:, 2] - boxes[:, 0]
|
2216 |
+
h = boxes[:, 3] - boxes[:, 1]
|
2217 |
+
|
2218 |
+
cx = tf.cast(boxes[:, 0] + w/2, tf.float32)
|
2219 |
+
cy = tf.cast(boxes[:, 1] + h/2, tf.float32)
|
2220 |
+
# w = w / image_w
|
2221 |
+
# h = h / image_h
|
2222 |
+
coor = tf.strings.reduce_join(
|
2223 |
+
float_to_text(tf.stack([cx, cy, w, h], 1)), separator=",", axis=1)
|
2224 |
+
|
2225 |
+
area = w*h
|
2226 |
+
if tf.random.uniform(()) < 0.5:
|
2227 |
+
coor_text = "before"
|
2228 |
+
captions = tf.strings.join([coor, captions], separator=": ")
|
2229 |
+
else:
|
2230 |
+
coor_text = "after"
|
2231 |
+
captions = tf.strings.join([captions, coor], separator=": ")
|
2232 |
+
|
2233 |
+
ix = tf.random.uniform((), 0, 6, tf.int32)
|
2234 |
+
center = boxes
|
2235 |
+
if ix == 0:
|
2236 |
+
order_text = "left"
|
2237 |
+
sort_by = boxes[:, 0]
|
2238 |
+
elif ix == 1:
|
2239 |
+
order_text = "right"
|
2240 |
+
sort_by = -boxes[:, 2]
|
2241 |
+
elif ix == 2:
|
2242 |
+
order_text = "top"
|
2243 |
+
sort_by = boxes[:, 1]
|
2244 |
+
elif ix == 3:
|
2245 |
+
order_text = "bottom"
|
2246 |
+
sort_by = -boxes[:, 3]
|
2247 |
+
elif ix == 4:
|
2248 |
+
order_text = "largest"
|
2249 |
+
sort_by = area
|
2250 |
+
else:
|
2251 |
+
order_text = "smallest"
|
2252 |
+
sort_by = -area
|
2253 |
+
ixs = tf.argsort(sort_by)
|
2254 |
+
captions = tf.gather(captions, ixs)
|
2255 |
+
text = tf.strings.join([
|
2256 |
+
order_text,
|
2257 |
+
coor_text,
|
2258 |
+
tf.strings.reduce_join(captions, separator="\n")
|
2259 |
+
], separator="; ")
|
2260 |
+
|
2261 |
+
if "caption" in ex:
|
2262 |
+
if tf.random.uniform(()) > 0.5:
|
2263 |
+
text = tf.strings.join([text, "\ncaption: ", ex["caption"]])
|
2264 |
+
else:
|
2265 |
+
text = tf.strings.join(["caption: ", ex["caption"], "\n", text])
|
2266 |
+
|
2267 |
+
return dict(
|
2268 |
+
image=ex["image"],
|
2269 |
+
text=text
|
2270 |
+
)
|
2271 |
+
|
2272 |
+
|
2273 |
+
@seqio.map_over_dataset()
|
2274 |
+
def join_captions(ex):
|
2275 |
+
text = tf.random.shuffle(ex['text'])
|
2276 |
+
ex["text"] = tf.strings.reduce_join(text, separator="\n")
|
2277 |
+
return ex
|
2278 |
+
|
2279 |
+
|
2280 |
+
@seqio.map_over_dataset(num_seeds=1)
|
2281 |
+
def extract_figureqa(ex, seed):
|
2282 |
+
questions = ex["questions"]
|
2283 |
+
n = stateless_permutation(tf.shape(questions["question"])[0], seed)
|
2284 |
+
return dict(
|
2285 |
+
image=ex["image"],
|
2286 |
+
questions=tf.gather(questions["question"], n),
|
2287 |
+
question_id=tf.gather(questions["question_id"], n),
|
2288 |
+
answer=tf.gather(tf.strings.as_string(questions["answer"]), n)
|
2289 |
+
)
|
2290 |
+
|
2291 |
+
|
2292 |
+
@seqio.map_over_dataset
|
2293 |
+
def convert_figureqa_answer(ex):
|
2294 |
+
keys_tensor = tf.constant(["0", "1"])
|
2295 |
+
values_tensor = tf.constant(["no", "yes"])
|
2296 |
+
table = tf.lookup.StaticHashTable(
|
2297 |
+
tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
|
2298 |
+
default_value=tf.constant("nan", dtype=tf.string),
|
2299 |
+
)
|
2300 |
+
answer = table.lookup(ex["answer"])
|
2301 |
+
ex["answer"] = answer
|
2302 |
+
return ex
|
2303 |
+
|
2304 |
+
|
2305 |
+
@seqio.map_over_dataset()
|
2306 |
+
def build_question_with_hint(ex):
|
2307 |
+
hint = ex["hint"]
|
2308 |
+
if tf.strings.length(hint) > 0:
|
2309 |
+
ex["question"] = tf.strings.join([hint, ex["question"]], separator="\n")
|
2310 |
+
return ex
|
2311 |
+
|
2312 |
+
@seqio.map_over_dataset()
|
2313 |
+
def build_question_with_context(ex):
|
2314 |
+
context = ex["context"]
|
2315 |
+
if tf.strings.length(context) > 0:
|
2316 |
+
ex["question"] = tf.strings.join([context, ex["question"]], separator="\n")
|
2317 |
+
return ex
|
2318 |
+
|
2319 |
+
|
2320 |
+
def max_words(ds, max_words):
|
2321 |
+
return ds.filter(lambda x: x["n_words"] <= max_words)
|
2322 |
+
|
2323 |
+
|
2324 |
+
@seqio.map_over_dataset
|
2325 |
+
def format_pdfa_eng_wds(example):
|
2326 |
+
return dict(
|
2327 |
+
image=example["image"],
|
2328 |
+
text=tf.strings.reduce_join(example["lines"]["text"], separator="\n"),
|
2329 |
+
)
|
2330 |
+
|
2331 |
+
|
2332 |
+
@gin.configurable()
|
2333 |
+
def accuracy_conditioned_joint(ds, sequence_length, is_eval=False, eval_quality=17,
|
2334 |
+
transcript_quality=None):
|
2335 |
+
# v2: Transcripts no longer get a quality score
|
2336 |
+
is_training = sequence_length.get('is_training', True)
|
2337 |
+
if not is_training:
|
2338 |
+
if is_eval:
|
2339 |
+
prompt = f"quality {eval_quality}:"
|
2340 |
+
else:
|
2341 |
+
prompt = f"quality 17:"
|
2342 |
+
|
2343 |
+
@seqio.map_over_dataset
|
2344 |
+
def _with_prompt(ex):
|
2345 |
+
out = dict(
|
2346 |
+
image=ex["image"],
|
2347 |
+
url=ex["url"],
|
2348 |
+
prompt=prompt,
|
2349 |
+
)
|
2350 |
+
if "text" in ex:
|
2351 |
+
out["text"] = ex["text"]
|
2352 |
+
elif "caption" in ex:
|
2353 |
+
out["text"] = ex["caption"]
|
2354 |
+
return out
|
2355 |
+
return _with_prompt(ds)
|
2356 |
+
|
2357 |
+
elif is_eval:
|
2358 |
+
raise ValueError("is_eval=True and is_training=False")
|
2359 |
+
|
2360 |
+
# each transcript
|
2361 |
+
@seqio.map_over_dataset
|
2362 |
+
def _with_transcript(ex):
|
2363 |
+
if tf.shape(ex["edited_captions"]["caption"])[0] > 0:
|
2364 |
+
edited_caption = ex["edited_captions"]["caption"][0]
|
2365 |
+
n = ex["edited_captions"]["n_edits"][0]
|
2366 |
+
else:
|
2367 |
+
edited_caption = ""
|
2368 |
+
n = 0
|
2369 |
+
text = [
|
2370 |
+
ex["caption"],
|
2371 |
+
ex["transcripts"][tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)],
|
2372 |
+
edited_caption
|
2373 |
+
]
|
2374 |
+
edit_quality = 17 - n
|
2375 |
+
prompt = [
|
2376 |
+
"quality 17:",
|
2377 |
+
"" if transcript_quality is None else f"quality: {edit_quality}:",
|
2378 |
+
tf.strings.join(["quality ", tf.strings.as_string(edit_quality), ":"])
|
2379 |
+
]
|
2380 |
+
return dict(
|
2381 |
+
image=ex["image"],
|
2382 |
+
text=tf.stack(text, 0),
|
2383 |
+
url=ex["url"],
|
2384 |
+
prompt=tf.stack(prompt, 0),
|
2385 |
+
style=["long_caption", "transcript", "long_caption"]
|
2386 |
+
)
|
2387 |
+
return _with_transcript(ds)
|
2388 |
+
|
2389 |
+
|
2390 |
+
def select_dense_caption_sample(ds, samples=200):
|
2391 |
+
def compute_hash(string: str) -> str:
|
2392 |
+
return hashlib.sha256(string.encode("utf-8")).hexdigest()
|
2393 |
+
|
2394 |
+
with tf.io.gfile.GFile("gs://oe-training-chrisc/cockatoo/data/dense-caption-eval-v0-final-data.json") as f:
|
2395 |
+
data = json.load(f)
|
2396 |
+
for ex in data:
|
2397 |
+
ex["image_id"] = compute_hash(ex["image"])
|
2398 |
+
data.sort(key=lambda x: x["image_id"])
|
2399 |
+
np.random.RandomState(12312).shuffle(data)
|
2400 |
+
keep = tf.constant([x["image"] for x in data[:samples]])
|
2401 |
+
|
2402 |
+
def _keep(ex):
|
2403 |
+
return tf.reduce_any(ex["url"] == keep)
|
2404 |
+
ds = ds.filter(_keep)
|
2405 |
+
ds = tf.data.experimental.assert_cardinality(samples)(ds)
|
2406 |
+
return ds
|
2407 |
+
|
2408 |
+
@seqio.map_over_dataset()
|
2409 |
+
def charxiv_preprocessor(ex):
|
2410 |
+
question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4", "reasoning_q"]
|
2411 |
+
answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4", "reasoning_a"]
|
2412 |
+
|
2413 |
+
questions = [ex[name] for name in question_names]
|
2414 |
+
answers = [ex[name] for name in answer_names]
|
2415 |
+
|
2416 |
+
return dict(
|
2417 |
+
image=ex["image"],
|
2418 |
+
question=tf.stack(questions, 0),
|
2419 |
+
answer=tf.stack(answers, 0)
|
2420 |
+
)
|
2421 |
+
|
2422 |
+
@seqio.map_over_dataset()
|
2423 |
+
def charxiv_descriptive_preprocessor(ex):
|
2424 |
+
question_names = ["descriptive_q1", "descriptive_q2", "descriptive_q3", "descriptive_q4"]
|
2425 |
+
answer_names = ["descriptive_a1", "descriptive_a2", "descriptive_a3", "descriptive_a4"]
|
2426 |
+
|
2427 |
+
questions = [ex[name] for name in question_names]
|
2428 |
+
answers = [ex[name] for name in answer_names]
|
2429 |
+
|
2430 |
+
return dict(
|
2431 |
+
image=ex["image"],
|
2432 |
+
question=tf.stack(questions, 0),
|
2433 |
+
answer=tf.stack(answers, 0)
|
2434 |
+
)
|
2435 |
+
|
2436 |
+
@seqio.map_over_dataset()
|
2437 |
+
def charxiv_reasoning_preprocessor(ex):
|
2438 |
+
return dict(
|
2439 |
+
image=ex["image"],
|
2440 |
+
question=ex["reasoning_q"],
|
2441 |
+
answer=ex["reasoning_a"]
|
2442 |
+
)
|
2443 |
+
|
2444 |
+
@seqio.map_over_dataset()
|
2445 |
+
def tablevqa_preprocessor(ex):
|
2446 |
+
return dict(
|
2447 |
+
image=ex["image"],
|
2448 |
+
question=ex["question"],
|
2449 |
+
answer=ex["gt"]
|
2450 |
+
)
|
2451 |
+
|
2452 |
+
@seqio.map_over_dataset()
|
2453 |
+
def vtabfact_preprocessor(ex):
|
2454 |
+
return dict(
|
2455 |
+
image=ex["image"],
|
2456 |
+
question=tf.strings.join([ex["question"], "Answer with yes or no."], separator="\n"),
|
2457 |
+
answer=ex["gt"]
|
2458 |
+
)
|
2459 |
+
|
2460 |
+
@seqio.map_over_dataset()
|
2461 |
+
def nutrition_fact_preprocessor(ex):
|
2462 |
+
question_names = ["descriptive_q", "reasoning_q"]
|
2463 |
+
answer_names = ["descriptive_a", "reasoning_a"]
|
2464 |
+
|
2465 |
+
questions = [ex[name] for name in question_names]
|
2466 |
+
answers = [ex[name] for name in answer_names]
|
2467 |
+
|
2468 |
+
return dict(
|
2469 |
+
image=ex["image"],
|
2470 |
+
question=tf.stack(questions, 0),
|
2471 |
+
answer=tf.stack(answers, 0)
|
2472 |
+
)
|
prompts.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
|
5 |
+
IMAGE_PROMPT = "<|image|>"
|
6 |
+
|
7 |
+
|
8 |
+
GENERAL_PROMPTS_V1 = {
|
9 |
+
"short_answer": [
|
10 |
+
"Answer this question very briefly\n{question}",
|
11 |
+
"{question} Answer with a few words",
|
12 |
+
"{question} Response very briefly",
|
13 |
+
"{question} Answer directly without any details, explanation, or elaboration",
|
14 |
+
"I have a question about this image, please answer it very briefly: {question}",
|
15 |
+
"Question: {question} Short Answer:",
|
16 |
+
"Question: {question}\nShort Answer:",
|
17 |
+
'{question}\nAnswer the question as briefly as possible.',
|
18 |
+
'Answer very briefly:\n{question}',
|
19 |
+
'The question "{question}" can be answered using the image. A short answer is',
|
20 |
+
"{question} Based on the image, respond to this question with a short answer:",
|
21 |
+
"{question} Short answer:",
|
22 |
+
"{question} A short answer to the question is",
|
23 |
+
"Give a short, matter-of-fact answer to this question: {question}",
|
24 |
+
"Give me a simple, direct answer to this question, do not elaborate or explain your answer:\n{question}"
|
25 |
+
],
|
26 |
+
"short_caption": [
|
27 |
+
'Caption the image with 1 or two sentences',
|
28 |
+
'Write a very short description of this image.',
|
29 |
+
'Briefly describe the image.',
|
30 |
+
'Look and this image, and then summarize it in a sentence or two.',
|
31 |
+
'Write a brief caption describing the image',
|
32 |
+
'Brief Caption:'
|
33 |
+
'A short image caption:',
|
34 |
+
'A short image description',
|
35 |
+
'Briefly describe the content of the image.',
|
36 |
+
'Can you give me one sentence summary of the picture?',
|
37 |
+
'How would you describe this image in a sentence or two?',
|
38 |
+
],
|
39 |
+
"long_caption": [
|
40 |
+
'Describe this image.',
|
41 |
+
'Describe this image',
|
42 |
+
'describe the image',
|
43 |
+
'Write a long description of this image.',
|
44 |
+
'caption the picture',
|
45 |
+
'Caption',
|
46 |
+
'caption',
|
47 |
+
'Construct a long caption for this image',
|
48 |
+
'Generate a caption',
|
49 |
+
'Create a detailed caption',
|
50 |
+
'Write a long caption',
|
51 |
+
'Describe this image in detail',
|
52 |
+
'Describe this',
|
53 |
+
'describe this',
|
54 |
+
'Caption this',
|
55 |
+
'What can be seen in this image?',
|
56 |
+
'What do you see in the image?',
|
57 |
+
'Look at this photo carefully and then tell me about it in detail',
|
58 |
+
'Write a long description of this image',
|
59 |
+
'Tell me about this picture.',
|
60 |
+
'Write a paragraph about this image.',
|
61 |
+
'Look at this image carefully and then describe it in detail',
|
62 |
+
'Generate a long caption about this image.'
|
63 |
+
],
|
64 |
+
"long_caption_no_pointing": [
|
65 |
+
'Describe this image in detail, but without any pointing.',
|
66 |
+
'Write a long description of this image, do not produce any points.',
|
67 |
+
'Tell me about this picture, use plain text only.',
|
68 |
+
'Generate a plain text description of this caption',
|
69 |
+
"What is in this image?\nNo pointing\nGive lots of detail"
|
70 |
+
"Write a long caption.\nDo not use image coordinates\nOutput a full paragraph"
|
71 |
+
],
|
72 |
+
"transcript": [
|
73 |
+
'Describe this image as if you are a person speaking',
|
74 |
+
'Imagine you are a person talking about this image. Generate a transcript of what you would say.',
|
75 |
+
"Generate an audio transcript of a person describing this image",
|
76 |
+
"Create a transcript of a human describing this image out load",
|
77 |
+
"Describe this in this style of a human talking",
|
78 |
+
],
|
79 |
+
"refexp": [
|
80 |
+
'What region does \"{refexp}\" refer to?',
|
81 |
+
],
|
82 |
+
"count_bench": [
|
83 |
+
'How many {object} are there?',
|
84 |
+
],
|
85 |
+
"refexp_pointing": [
|
86 |
+
'Where is the \"{refexp}\"?',
|
87 |
+
'Point to {refexp}',
|
88 |
+
'point at {refexp}',
|
89 |
+
'Find the {refexp}.',
|
90 |
+
'Which object in the image does \"{refexp}\" refer to?',
|
91 |
+
'Locate the object \"{refexp}\" refers to.',
|
92 |
+
'Point to the object that best matches the expression:\n{refexp}\n',
|
93 |
+
'What object could be described as: {refexp}.\nPoint:',
|
94 |
+
'Referring Expression: {refexp}.\nPoint:',
|
95 |
+
'Expression: {refexp}\nPoint to the refexp',
|
96 |
+
'Task: Point to the object that best matches the expression.\nExpression: {refexp}\nPoint:',
|
97 |
+
'Instruction: Locate the object that matches the expression by returning a point.\nReferring Expression: {refexp}\n',
|
98 |
+
'Help me find an object in this image by pointing to the {refexp}',
|
99 |
+
'What point of the image might the expression \'{refexp}\' refer to?',
|
100 |
+
],
|
101 |
+
"plain": ["{question}"],
|
102 |
+
"multiple_choice": [
|
103 |
+
"{question}\n{options}\nReturn only the letter of the best answer option",
|
104 |
+
"Answer this question by naming one of the provided options:\n{question}\n{options}",
|
105 |
+
"{question}\n{options}\nWhat option best answers the question?",
|
106 |
+
"{question}\n{options}\nReturn the best answer option",
|
107 |
+
"Look at the options, then return the letter of the option that best answers the question.\nQuesiton: {question}\nOptions: {options}",
|
108 |
+
"{question}? Select an answer option from:\n{options}",
|
109 |
+
"{question}\nSelect an answer option from:\n{options}\n\n",
|
110 |
+
"Question: {question}? Options: {options} Answer:",
|
111 |
+
"Answer the question by selecting an answer options\nQuestion: {question}\nOptions: {options}",
|
112 |
+
"{question}?\n{options}\nReturn only the letter of the correct answer",
|
113 |
+
"Help me answer this question: \"{question}\", by stating which of the following options is correct\n{options}."
|
114 |
+
],
|
115 |
+
"binary": ["{question}\nAnswer with 'yes' or 'no'"],
|
116 |
+
"pointing": [
|
117 |
+
"Point to {entity}\nPlease say 'This isn't in the image.' if it is not in the image.",
|
118 |
+
"Point to all occurrences of \"{entity}\"",
|
119 |
+
"Point to any {entity} in the image",
|
120 |
+
"Point to any {entity} in the image.",
|
121 |
+
"Point: Where are the {entity}",
|
122 |
+
"Show me where the {entity} are",
|
123 |
+
"Can you show me where the {entity} are?",
|
124 |
+
"Show me where the {entity} are",
|
125 |
+
"Show me where a {entity} is",
|
126 |
+
"Show me where a {entity} is.",
|
127 |
+
"If there are any {entity} in the image? Show me where they are.",
|
128 |
+
"Where are the {entity}?",
|
129 |
+
"Generate a list of points showing where the {entity} are.",
|
130 |
+
"Find the \"{entity}\".",
|
131 |
+
"Find a \"{entity}\".",
|
132 |
+
"Locate all {entity}.",
|
133 |
+
"Locate an {entity}.",
|
134 |
+
"Locate a {entity}.",
|
135 |
+
"Locate every {entity}.",
|
136 |
+
"Locate {entity}.",
|
137 |
+
"Locate the {entity}.",
|
138 |
+
"Object: {entity}\nInstruction: Point to the object.",
|
139 |
+
"find {entity}",
|
140 |
+
"find {entity}.",
|
141 |
+
"Point to every {entity}",
|
142 |
+
"find any {entity} in the picture",
|
143 |
+
"Find the {entity}",
|
144 |
+
"Find any {entity}",
|
145 |
+
"Point to a {entity}",
|
146 |
+
"Point to an {entity}",
|
147 |
+
"Look for {entity} in the image and show me where they are.",
|
148 |
+
"Help me find an object in the image by pointing to them.\nObject: {entity}.",
|
149 |
+
"I am looking for {entity}, where can they be found in the image?",
|
150 |
+
"Can you see any {entity} in the image? Point to them.",
|
151 |
+
"Point out each {entity} in the image.",
|
152 |
+
"Point out every {entity} in the image.",
|
153 |
+
"Point to the {entity} in the image.",
|
154 |
+
"Locate each {entity} in the image.",
|
155 |
+
"Can you point out all {entity} in this image?",
|
156 |
+
"Please find {entity} and show me where they are.",
|
157 |
+
"If there are any {entity} present, indicate their positions.",
|
158 |
+
"If there is a {entity} present, indicate its positions.",
|
159 |
+
"show me all visible {entity}",
|
160 |
+
],
|
161 |
+
"point_count": [
|
162 |
+
"How many {entity} are there?",
|
163 |
+
"How many {entity}?",
|
164 |
+
"How many {entity}.",
|
165 |
+
"how many {entity}.",
|
166 |
+
"how many {entity}?",
|
167 |
+
"How many {entity} are there in the image?",
|
168 |
+
"Tell me how many {entity} there are",
|
169 |
+
"Tell me how many {entity} there are and point to them.",
|
170 |
+
"how many {entity}",
|
171 |
+
"Tell me where each {entity} is.",
|
172 |
+
"Tell me how many {entity} are in the image",
|
173 |
+
"count {entity}",
|
174 |
+
"count every {entity}",
|
175 |
+
"count each {entity}",
|
176 |
+
"count {entity}.",
|
177 |
+
"Count the {entity}.",
|
178 |
+
"How many {entity} do you see?",
|
179 |
+
"How many {entity} are visible?",
|
180 |
+
"Count all the {entity}",
|
181 |
+
"how mmny {entity}?",
|
182 |
+
"Count every {entity} in the picture.",
|
183 |
+
"Count all the {entity}",
|
184 |
+
"Count each {entity}",
|
185 |
+
"Point to and count the {entity} in the picture.",
|
186 |
+
"Point and count {entity}",
|
187 |
+
"Point to every {entity}",
|
188 |
+
"Locate the {entity} and count them",
|
189 |
+
"Locate every {entity} and count them",
|
190 |
+
"Find all the {entity}. How many are there?",
|
191 |
+
"Find each {entity}. How many are there?",
|
192 |
+
"Point at {entity} and then tell me the count.",
|
193 |
+
"What is the total number of {entity} in the image?",
|
194 |
+
"In all the picture, how many {entity} are there?",
|
195 |
+
"Point at the {entity} and then count them.",
|
196 |
+
"Point to all the visible {entity} output the total count.",
|
197 |
+
"Point to all the {entity} visible and output the total count. \nPlease say 'This isn't in the image.' if it is not in the image.",
|
198 |
+
"Point to all occurrences of \"{entity}\" and output the total count.",
|
199 |
+
"Show me where the {entity} are and output the total count.",
|
200 |
+
"Where are the {entity}? How many are there?",
|
201 |
+
"Generate list of points showing where the {entity} are and output the total count.",
|
202 |
+
"Object: {entity}\nInstruction: Point to the object and output the total count.",
|
203 |
+
"find any {entity} in the picture and output the total count.",
|
204 |
+
"Can you see any {entity} in the image? Point to them and output the total count.",
|
205 |
+
"Can you point out all {entity} in this image? How many are there?",
|
206 |
+
"If there are any {entity} present, indicate their positions and output the total count.",
|
207 |
+
"How many {entity} are there in the image? Point to them and output the total count.",
|
208 |
+
"How many {entity} are there in the image?",
|
209 |
+
"Give me the count of {entity} in the image.",
|
210 |
+
"How many {entity} are visible in the image?",
|
211 |
+
"How many {entity} are there?",
|
212 |
+
"In the image, how many {entity} are there?",
|
213 |
+
"Can you count the number of {entity} in the image?",
|
214 |
+
"Can you count every {entity} in the picture?",
|
215 |
+
"Can you see any {entity} in the image? How many are there?",
|
216 |
+
"Are there any {entity} in the image? How many are there?",
|
217 |
+
"If you see any {entity} in the image, give me the count. Otherwise, say 'This isn't in the image.'",
|
218 |
+
"Object: {entity}\nInstruction: How many are there?",
|
219 |
+
],
|
220 |
+
|
221 |
+
# vaia
|
222 |
+
"detailed_solution": [
|
223 |
+
"Answer the question providing a step by step solution and answer in the end.\n"
|
224 |
+
"Provide a step-by-step solution to the question, ending with your final answer.\n",
|
225 |
+
"Please provide a step-by-step solution to the question shown in the image.\n",
|
226 |
+
"Give a detailed explanation for the question, concluding with your final answer.\n",
|
227 |
+
"Solve the problem presented in the question with a thorough explanation. Give me your final answer at the end.\n",
|
228 |
+
"Please analyze the question and provide a complete solution, finishing with your final answer.\n",
|
229 |
+
"Work through the problem, offering detailed reasoning before stating your final answer.\n",
|
230 |
+
"Interpret the question and guide me through the solution, concluding with your answer.\n",
|
231 |
+
"Review the question and deliver a well-explained solution, making sure to include your final answer.\n",
|
232 |
+
"Examine the question: provide a detailed explanation followed by your final answer.\n"
|
233 |
+
],
|
234 |
+
|
235 |
+
# vaia first answer with short_answer
|
236 |
+
"detailed_solution_answer_first": [
|
237 |
+
"Answer the question directly, then provide a step-by-step solution.\n",
|
238 |
+
"Please provide the answer first, followed by a step-by-step solution to the question shown in the image.\n",
|
239 |
+
"Give the final answer first, then provide a detailed explanation for the question.\n",
|
240 |
+
"Provide the final answer, then solve the problem presented in the question with a thorough explanation.\n",
|
241 |
+
"First, give the final answer, then analyze the question and provide a complete solution.\n",
|
242 |
+
"State the final answer first, then work through the problem, offering detailed reasoning.\n",
|
243 |
+
"Provide the final answer, then interpret the question and guide me through the solution.\n",
|
244 |
+
"Give the final answer first, then review the question and deliver a well-explained solution.\n",
|
245 |
+
"First, provide the final answer, then examine the question and give a detailed explanation.\n"
|
246 |
+
],
|
247 |
+
|
248 |
+
# vqa_online
|
249 |
+
"detailed_answer": [
|
250 |
+
"Answer the question providing a step-by-step explanation and answer in the end.\n",
|
251 |
+
"Provide a step-by-step explanation to the question, ending with your final answer.\n",
|
252 |
+
"Please provide a step-by-step explanation to the question shown in the image.\n",
|
253 |
+
"Give a detailed explanation for the question, concluding with your final answer.\n",
|
254 |
+
"Address the problem presented in the question with a thorough explanation. Give me your final answer at the end.\n",
|
255 |
+
"Please analyze the question and provide a complete explanation, finishing with your final answer.\n",
|
256 |
+
"Work through the problem, offering detailed reasoning before stating your final answer.\n",
|
257 |
+
"Interpret the question and guide me through the explanation, concluding with your answer.\n",
|
258 |
+
"Review the question and deliver a well-explained answer, making sure to include your final answer.\n",
|
259 |
+
"Examine the question: provide a detailed explanation followed by your final answer.\n"
|
260 |
+
],
|
261 |
+
}
|
262 |
+
|
263 |
+
GENERAL_PROMPTS_V1["pointing_tag"] = [txt + " Make the alt text and the inside of the tag the target label." for txt in GENERAL_PROMPTS_V1["pointing"]]
|
264 |
+
|
265 |
+
STYLE_TO_GENERAL_PROMPT = {
|
266 |
+
"vqa2": "short_answer",
|
267 |
+
"coco_captioning": "short_caption",
|
268 |
+
"gqa": "short_answer",
|
269 |
+
"ocr_vqa": "short_answer",
|
270 |
+
"tally_qa": "short_answer",
|
271 |
+
"text_vqa": "short_answer",
|
272 |
+
"okvqa": "short_answer",
|
273 |
+
"chart_qa": "short_answer",
|
274 |
+
"doc_qa": "short_answer",
|
275 |
+
"info_qa": "short_answer",
|
276 |
+
"science_qa": "multiple_choice",
|
277 |
+
"ai2_diagram": "multiple_choice",
|
278 |
+
"a_okvqa_mc": "multiple_choice",
|
279 |
+
"a_okvqa_da": "short_answer",
|
280 |
+
"long_caption": "long_caption",
|
281 |
+
"web_pointing": "plain",
|
282 |
+
"count_bench": "count_bench",
|
283 |
+
"refexp": "refexp",
|
284 |
+
"refexp_pointing": "refexp_pointing",
|
285 |
+
"vtabfact": "binary",
|
286 |
+
"vwtq": "short_answer",
|
287 |
+
"vwtq_syn": "short_answer",
|
288 |
+
"fintabnetqa": "short_answer",
|
289 |
+
"scifi_charts": "short_answer",
|
290 |
+
"scifi_charts_qa": "short_answer",
|
291 |
+
"charxiv_descriptive": "short_answer",
|
292 |
+
"charxiv_reasoning": "short_answer",
|
293 |
+
"pointing": "pointing",
|
294 |
+
"pointing_tag": "pointing_tag",
|
295 |
+
"point_count": "point_count",
|
296 |
+
"plain": "plain",
|
297 |
+
}
|
298 |
+
|
299 |
+
|
300 |
+
# def maybe_format_options(example, option_style="basic"):
|
301 |
+
# abc = tf.constant(list("abcdefg".upper()))
|
302 |
+
# if option_style == "random-v1":
|
303 |
+
# letter_option_sep = [": ", ". ", ")"]
|
304 |
+
# option_sep = ["\n", "\n", "\n", " ", ". ", ".\n", "; ", ", "]
|
305 |
+
# option_sep = tf.constant(option_sep)[tf.random.uniform((), 0, len(option_sep), tf.int32)]
|
306 |
+
# elif option_style == "basic":
|
307 |
+
# letter_option_sep = ": "
|
308 |
+
# option_sep = "\n"
|
309 |
+
# else:
|
310 |
+
# raise NotImplementedError(option_style)
|
311 |
+
#
|
312 |
+
# options = example["options"]
|
313 |
+
# short_options = abc[:tf.shape(options)[0]]
|
314 |
+
# sep = tf.constant(letter_option_sep)[tf.random.uniform((), 0, len(letter_option_sep), tf.int32)]
|
315 |
+
#
|
316 |
+
# options = tf.stack([short_options, options,], 1)
|
317 |
+
#
|
318 |
+
# options = tf.strings.reduce_join(options, axis=-1, separator=sep)
|
319 |
+
#
|
320 |
+
# options = tf.strings.reduce_join(options, separator=option_sep)
|
321 |
+
# example["options"] = options
|
322 |
+
# tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, ".*\|\|\|.*")), False)
|
323 |
+
# example["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
|
324 |
+
#
|
325 |
+
# if "answer_idx" in example:
|
326 |
+
# if example["answer_idx"] < 0:
|
327 |
+
# example["text"] = "?"
|
328 |
+
# else:
|
329 |
+
# example["text"] = short_options[example["answer_idx"]]
|
330 |
+
# example["metadata/answer_idx"] = example["answer_idx"]
|
331 |
+
# return example
|
332 |
+
|
333 |
+
|
334 |
+
def apply_keyword_prompt(prompts, example, seed=None, weights=None, keywords=None):
|
335 |
+
if isinstance(prompts, list):
|
336 |
+
assert keywords is None
|
337 |
+
all_keywords = [sorted(re.findall("{([^{}]+)}", x)) for x in prompts]
|
338 |
+
keywords = all_keywords[0]
|
339 |
+
assert len(keywords) == len(set(keywords)), f"Repeated keywords in {keywords}"
|
340 |
+
assert all(keywords == x for x in all_keywords), f"Inconsistent keywords in prompts {all_keywords}"
|
341 |
+
assert not any("{" not in word[1:-1] and "}" in word[1:-1] for word in keywords)
|
342 |
+
|
343 |
+
for k in keywords:
|
344 |
+
assert k in example, f"Example missing expected field {k}, example={example}"
|
345 |
+
prompts = tf.constant(prompts)
|
346 |
+
|
347 |
+
multiple = False
|
348 |
+
if "text" in example and len(example["text"].shape) > 0:
|
349 |
+
multiple = True
|
350 |
+
|
351 |
+
if weights is not None:
|
352 |
+
weights = tf.expand_dims(tf.math.log(weights), 0)
|
353 |
+
|
354 |
+
if seed is None:
|
355 |
+
raise ValueError()
|
356 |
+
|
357 |
+
if not multiple:
|
358 |
+
if weights is None:
|
359 |
+
prompt = prompts[tf.random.stateless_uniform((), seed, 0, len(prompts), dtype=tf.int32)]
|
360 |
+
else:
|
361 |
+
prompt = prompts[tf.random.stateless_categorical(weights, 1, seed, 0, len(prompts), dtype=tf.int32)][0, 0]
|
362 |
+
for keyword in keywords:
|
363 |
+
# We use split not regex_replace because regex_replace has issues with
|
364 |
+
# value strings with backslashes
|
365 |
+
res = tf.strings.split(prompt, "{"+keyword+"}", maxsplit=2)
|
366 |
+
prompt = tf.strings.join([res[0], example[keyword], res[1]])
|
367 |
+
return prompt
|
368 |
+
else:
|
369 |
+
n_prompts = tf.shape(example["text"])[0]
|
370 |
+
if weights is None:
|
371 |
+
ix = tf.random.stateless_uniform(
|
372 |
+
(n_prompts,), seed, 0, tf.shape(prompts)[0], dtype=tf.int32)
|
373 |
+
else:
|
374 |
+
ix = tf.random.stateless_categorical(
|
375 |
+
weights, tf.shape(prompts)[0], seed, 0, len(prompts), dtype=tf.int32)[0]
|
376 |
+
prompt = tf.gather(prompts, ix)
|
377 |
+
out = tf.TensorArray(dtype=tf.string, size=n_prompts, element_shape=())
|
378 |
+
for i in range(n_prompts):
|
379 |
+
modified = prompt[i]
|
380 |
+
for keyword in keywords:
|
381 |
+
res = tf.strings.split(modified, "{"+keyword+"}", maxsplit=2)
|
382 |
+
modified = tf.strings.join([res[0], example[keyword][i], res[1]])
|
383 |
+
out = out.write(i, modified)
|
384 |
+
return out.stack()
|
385 |
+
|
seqio_tokenizer.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The SeqIO Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Vocabularies."""
|
16 |
+
|
17 |
+
import abc
|
18 |
+
import dataclasses
|
19 |
+
import functools
|
20 |
+
import hashlib
|
21 |
+
import threading
|
22 |
+
from typing import Any, ClassVar, Dict, Iterable, Optional, Sequence, Union, List, Tuple
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
from absl import logging
|
26 |
+
import tensorflow.compat.v2 as tf
|
27 |
+
|
28 |
+
from sentencepiece import sentencepiece_model_pb2
|
29 |
+
import sentencepiece as sentencepiece_processor
|
30 |
+
|
31 |
+
PAD_ID = -1 # -1 for llama tokenizer
|
32 |
+
|
33 |
+
|
34 |
+
class Vocabulary(metaclass=abc.ABCMeta):
|
35 |
+
"""Abstract class for all vocabularies.
|
36 |
+
|
37 |
+
Subclasses must implement methods for converting between strings and tokens
|
38 |
+
both in pure python (`_encode`/`_decode`) and in TensorFlow
|
39 |
+
(`_encode_tf`/`_decode_tf`).
|
40 |
+
|
41 |
+
Subclasses are responsible for reserving PAD_ID=0 as well as optionally
|
42 |
+
reserving EOS_ID and UNK_ID
|
43 |
+
|
44 |
+
`_base_vocab_size` should account for PAD, EOS, and UNK but not `extra_ids`.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(self, extra_ids: int = 0):
|
48 |
+
"""Vocabulary constructor.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
extra_ids: The number of extra IDs to reserve.
|
52 |
+
"""
|
53 |
+
self._extra_ids = extra_ids or 0
|
54 |
+
|
55 |
+
@property
|
56 |
+
def bos_token_id(self) -> Optional[int]:
|
57 |
+
raise NotImplementedError("need to implement bos_id")
|
58 |
+
|
59 |
+
@property
|
60 |
+
@abc.abstractmethod
|
61 |
+
def eos_token_id(self) -> Optional[int]:
|
62 |
+
raise NotImplementedError("need to implement eos_id")
|
63 |
+
|
64 |
+
@property
|
65 |
+
def pad_id(self) -> int:
|
66 |
+
return PAD_ID
|
67 |
+
|
68 |
+
@property
|
69 |
+
@abc.abstractmethod
|
70 |
+
def unk_id(self) -> Optional[int]:
|
71 |
+
raise NotImplementedError("need to implement unk_id")
|
72 |
+
|
73 |
+
@property
|
74 |
+
def extra_ids(self) -> int:
|
75 |
+
return self._extra_ids
|
76 |
+
|
77 |
+
@property
|
78 |
+
def vocab_size(self) -> int:
|
79 |
+
"""Vocabulary size, including extra ids."""
|
80 |
+
return self._base_vocab_size + self.extra_ids
|
81 |
+
|
82 |
+
@property
|
83 |
+
@abc.abstractmethod
|
84 |
+
def _base_vocab_size(self) -> int:
|
85 |
+
"""Vocabulary size, excluding extra ids but including PAD/EOS/UNK."""
|
86 |
+
# TODO(fjord): add a check that pad_id and unk_id (if present)
|
87 |
+
# are less than _base_vocab_size.
|
88 |
+
raise NotImplementedError
|
89 |
+
|
90 |
+
@abc.abstractmethod
|
91 |
+
def _encode(self, s: str) -> Sequence[int]:
|
92 |
+
raise NotImplementedError
|
93 |
+
|
94 |
+
def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]:
|
95 |
+
"""Tokenizes string to an int sequence, without adding EOS."""
|
96 |
+
return self._encode(s)
|
97 |
+
|
98 |
+
@abc.abstractmethod
|
99 |
+
def _decode(self, ids):
|
100 |
+
raise NotImplementedError
|
101 |
+
|
102 |
+
def decode(self, ids: Iterable[int], truncate_at_eos=True):
|
103 |
+
"""Detokenizes int32 iterable to a string, up through first EOS."""
|
104 |
+
clean_ids = list(ids)
|
105 |
+
|
106 |
+
if self.unk_id is not None:
|
107 |
+
vocab_size = self._base_vocab_size
|
108 |
+
clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids]
|
109 |
+
|
110 |
+
if truncate_at_eos and (self.eos_token_id is not None and self.eos_token_id in clean_ids):
|
111 |
+
clean_ids = clean_ids[: clean_ids.index(self.eos_token_id) + 1]
|
112 |
+
|
113 |
+
return self._decode(clean_ids)
|
114 |
+
|
115 |
+
@abc.abstractmethod
|
116 |
+
def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
|
117 |
+
raise NotImplementedError
|
118 |
+
|
119 |
+
def encode_tf(self, s: tf.Tensor) -> tf.Tensor:
|
120 |
+
"""Tokenizes string Scalar to an int32 Tensor, without adding EOS."""
|
121 |
+
return self._encode_tf(s)
|
122 |
+
|
123 |
+
@abc.abstractmethod
|
124 |
+
def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
|
125 |
+
raise NotImplementedError
|
126 |
+
|
127 |
+
def decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
|
128 |
+
"""Detokenizes int32 batched Tensor through first EOS."""
|
129 |
+
clean_ids = ids
|
130 |
+
|
131 |
+
if self.unk_id is not None:
|
132 |
+
base_vocab_size = self._base_vocab_size
|
133 |
+
clean_ids = tf.where(
|
134 |
+
tf.less(clean_ids, base_vocab_size), clean_ids, self.unk_id
|
135 |
+
)
|
136 |
+
|
137 |
+
if self.eos_id is not None:
|
138 |
+
# Replace everything after the first eos_id with pad_id.
|
139 |
+
after_eos = tf.cumsum(
|
140 |
+
tf.cast(tf.equal(clean_ids, self.eos_id), tf.int32),
|
141 |
+
exclusive=True,
|
142 |
+
axis=-1,
|
143 |
+
)
|
144 |
+
clean_ids = tf.where(tf.cast(after_eos, tf.bool), self.pad_id, clean_ids)
|
145 |
+
|
146 |
+
return self._decode_tf(clean_ids)
|
147 |
+
|
148 |
+
|
149 |
+
class PassThroughVocabulary(Vocabulary):
|
150 |
+
"""Vocabulary that passes through inputs unchanged."""
|
151 |
+
|
152 |
+
def __init__(self, size: int, eos_id: Optional[Any] = None):
|
153 |
+
"""PassThroughVocabulary constructor.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
size: the full size of the vocabulary.
|
157 |
+
eos_id: the end-of-sequence token.
|
158 |
+
"""
|
159 |
+
self._size = size
|
160 |
+
self._eos_id = eos_id
|
161 |
+
super().__init__()
|
162 |
+
|
163 |
+
@property
|
164 |
+
def _base_vocab_size(self):
|
165 |
+
return self._size
|
166 |
+
|
167 |
+
def _encode(self, s: Sequence[Any]) -> Sequence[Any]:
|
168 |
+
return s
|
169 |
+
|
170 |
+
def _decode(self, ids: Sequence[Any]) -> Sequence[Any]:
|
171 |
+
return ids
|
172 |
+
|
173 |
+
def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
|
174 |
+
return s
|
175 |
+
|
176 |
+
def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
|
177 |
+
return ids
|
178 |
+
|
179 |
+
@property
|
180 |
+
def eos_id(self) -> Optional[Any]:
|
181 |
+
return self._eos_id
|
182 |
+
|
183 |
+
@property
|
184 |
+
def unk_id(self) -> Optional[Any]:
|
185 |
+
return None
|
186 |
+
|
187 |
+
def __eq__(self, other):
|
188 |
+
if not isinstance(other, PassThroughVocabulary):
|
189 |
+
return False
|
190 |
+
return self._size == other._size and self.eos_id == other.eos_id
|
191 |
+
|
192 |
+
def __str__(self) -> str:
|
193 |
+
return f"PassThroughVocabulary(size={self._size}, eos_id={self.eos_id})"
|
194 |
+
|
195 |
+
|
196 |
+
class UnigramVocabulary(Vocabulary):
|
197 |
+
"""Vocabulary that does table-lookup of unigrams."""
|
198 |
+
|
199 |
+
def __init__(self, unigrams: Sequence[str]):
|
200 |
+
"""UnigramVocabulary constructor.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
unigrams: the collection of in-vocabulary tokens. This collection should
|
204 |
+
not include PAD or UNK, which are automatically assigned ids and managed
|
205 |
+
as possible decode tokens.
|
206 |
+
"""
|
207 |
+
|
208 |
+
super().__init__()
|
209 |
+
unigrams_as_list = list(unigrams)
|
210 |
+
self._unigram_by_id = ["PAD"] + unigrams_as_list + ["UNK"]
|
211 |
+
self._id_by_unigram = {u: i for i, u in enumerate(self._unigram_by_id)}
|
212 |
+
initializer = tf.lookup.KeyValueTensorInitializer(
|
213 |
+
keys=tf.constant(["PAD"] + unigrams_as_list),
|
214 |
+
# One extra value because the leading 0 corresponds to PAD
|
215 |
+
values=tf.constant(range(len(unigrams) + 1), dtype=tf.int64),
|
216 |
+
)
|
217 |
+
self._id_by_unigram_tf = tf.lookup.StaticVocabularyTable(
|
218 |
+
initializer, num_oov_buckets=1
|
219 |
+
)
|
220 |
+
self._unigram_by_id_tf = tf.constant(self._unigram_by_id)
|
221 |
+
|
222 |
+
def _encode(self, s: str) -> Sequence[int]:
|
223 |
+
return [self._id_by_unigram.get(s, self.unk_id)]
|
224 |
+
|
225 |
+
def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
|
226 |
+
tf_ids = self._id_by_unigram_tf.lookup(s)
|
227 |
+
return tf.expand_dims(tf.dtypes.cast(tf_ids, tf.int32), -1)
|
228 |
+
|
229 |
+
def _decode(self, ids: Sequence[int]) -> str:
|
230 |
+
return " ".join(self._unigram_by_id[id] for id in ids)
|
231 |
+
|
232 |
+
def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
|
233 |
+
return self._unigram_by_id_tf[ids[0]]
|
234 |
+
|
235 |
+
@property
|
236 |
+
def _base_vocab_size(self):
|
237 |
+
return len(self._unigram_by_id)
|
238 |
+
|
239 |
+
@property
|
240 |
+
def eos_id(self):
|
241 |
+
return None
|
242 |
+
|
243 |
+
@property
|
244 |
+
def unk_id(self):
|
245 |
+
return self._base_vocab_size - 1
|
246 |
+
|
247 |
+
|
248 |
+
class SentencePieceVocabulary(Vocabulary):
|
249 |
+
"""Wrapper for nlp/sentencepiece encoder.
|
250 |
+
|
251 |
+
Assumes the model was built using flags to reserve ID=0 for padding, ID=1 for
|
252 |
+
EOS, and ID=2 for UNK.
|
253 |
+
|
254 |
+
If using extra ids, you can represent them in string-form as `<extra_id_0>`,
|
255 |
+
`<extra_id_1>`, etc. They will be indexed starting from the end of the
|
256 |
+
vocabulary to match how the masking preprocessors are set up.
|
257 |
+
|
258 |
+
IMPORTANT NOTE: these placeholders only work properly when they are used at
|
259 |
+
word starts (e.g., "I like peanut butter and <extra_id_0> sandwiches." or
|
260 |
+
"I like peanut butter and <extra_id_0>ly sandwiches" are both okay, but
|
261 |
+
"I like peanut butter and jel<extra_id_0> sandwiches" is not.).
|
262 |
+
"""
|
263 |
+
|
264 |
+
@dataclasses.dataclass
|
265 |
+
class _ModelContext:
|
266 |
+
tokenizer: sentencepiece_processor.SentencePieceProcessor
|
267 |
+
sp_model: bytes
|
268 |
+
|
269 |
+
_load_model_lock: ClassVar[threading.Lock] = threading.Lock()
|
270 |
+
|
271 |
+
def __init__(
|
272 |
+
self,
|
273 |
+
sentencepiece_model_file: str,
|
274 |
+
extra_ids: int = 0,
|
275 |
+
normalizer_spec_overrides: Optional[
|
276 |
+
sentencepiece_model_pb2.NormalizerSpec
|
277 |
+
] = None,
|
278 |
+
reverse_extra_ids: bool = False,
|
279 |
+
extra_tokens: Tuple[str] = None,
|
280 |
+
hack_to_t5_start_tokens: bool = False,
|
281 |
+
):
|
282 |
+
"""Create a SentencePieceVocabulary.
|
283 |
+
|
284 |
+
Optionally, specify a number of extra ids to add to the end of the
|
285 |
+
vocabulary for use as sentinels.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
sentencepiece_model_file: path of the sentence piece model.
|
289 |
+
extra_ids: number of extra ids to include.
|
290 |
+
normalizer_spec_overrides: If not None, this proto will be merged into the
|
291 |
+
model's normalizer and denormalizer specs. Thus, any options set on this
|
292 |
+
object will override the values of those options in the loaded model.
|
293 |
+
reverse_extra_ids: if True, extra_ids are numbered in descending order, so
|
294 |
+
the first extra_id has the highest number. This is done for
|
295 |
+
compatibility with span_corruption mask generation in T5.
|
296 |
+
"""
|
297 |
+
self._sentencepiece_model_file = sentencepiece_model_file
|
298 |
+
self._normalizer_spec_overrides = normalizer_spec_overrides
|
299 |
+
self._reverse_extra_ids = reverse_extra_ids
|
300 |
+
self._model: Optional[SentencePieceVocabulary._ModelContext] = None
|
301 |
+
self._extra_tokens = extra_tokens
|
302 |
+
self._hack_to_t5_start_tokens = hack_to_t5_start_tokens
|
303 |
+
super().__init__(extra_ids=extra_ids)
|
304 |
+
|
305 |
+
def __getstate__(self):
|
306 |
+
state = self.__dict__.copy()
|
307 |
+
# Gin config makes a deep copy of the keyword arguments of configurables.
|
308 |
+
# When a SentencePieceVocabulary vocabulary is used as a keyword argument
|
309 |
+
# in a Gin configurable, it must be picklable. We therefore remove
|
310 |
+
# _model; will be initialized lazily as needed.
|
311 |
+
del state["_model"]
|
312 |
+
return state
|
313 |
+
|
314 |
+
def __setstate__(self, state):
|
315 |
+
self.__dict__.update(state)
|
316 |
+
self._model = None
|
317 |
+
|
318 |
+
def load_model(self) -> None:
|
319 |
+
_ = self._model_context()
|
320 |
+
|
321 |
+
def _model_context(
|
322 |
+
self,
|
323 |
+
) -> _ModelContext:
|
324 |
+
"""Loads model if not yet loaded and returns the model context.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
The model context as a tuple of (tokenizer, sp_model).
|
328 |
+
"""
|
329 |
+
if self._model:
|
330 |
+
return self._model
|
331 |
+
|
332 |
+
normalizer_spec_overrides_serialized = (
|
333 |
+
self._normalizer_spec_overrides.SerializeToString(deterministic=True)
|
334 |
+
if self._normalizer_spec_overrides
|
335 |
+
else None
|
336 |
+
)
|
337 |
+
|
338 |
+
self._model = self._load_model(
|
339 |
+
self._sentencepiece_model_file,
|
340 |
+
self._extra_ids,
|
341 |
+
normalizer_spec_overrides_serialized,
|
342 |
+
self._reverse_extra_ids,
|
343 |
+
extra_tokens=self._extra_tokens,
|
344 |
+
hack_to_t5_start_tokens=self._hack_to_t5_start_tokens,
|
345 |
+
)
|
346 |
+
return self._model
|
347 |
+
|
348 |
+
@classmethod
|
349 |
+
@functools.lru_cache(maxsize=None)
|
350 |
+
def _load_model(
|
351 |
+
cls,
|
352 |
+
sentencepiece_model_file: str,
|
353 |
+
extra_ids: int,
|
354 |
+
normalizer_spec_overrides_serialized: Optional[bytes] = None,
|
355 |
+
reverse_extra_ids: bool = True,
|
356 |
+
extra_tokens: Tuple[str] = None,
|
357 |
+
hack_to_t5_start_tokens=False,
|
358 |
+
) -> _ModelContext:
|
359 |
+
"""Load SPM, Python tokenizer, and cache results to the class definition."""
|
360 |
+
# SentencePieceProcessor::LoadFromSerializedProto is not thread-safe.
|
361 |
+
# Without a lock, users may randomly see SIGSEGV on
|
362 |
+
# sentencepiece::ModelInterface::pad_piece when using the vocabulary in
|
363 |
+
# SeqIO preprocessors.
|
364 |
+
with cls._load_model_lock:
|
365 |
+
# Handle cases where SP can't load the file, but gfile can.
|
366 |
+
with tf.io.gfile.GFile(sentencepiece_model_file, "rb") as f:
|
367 |
+
sp_model = f.read()
|
368 |
+
model = sentencepiece_model_pb2.ModelProto.FromString(sp_model)
|
369 |
+
|
370 |
+
if hack_to_t5_start_tokens:
|
371 |
+
# PAD token would still be 0 same as BOS for consistency as previous!
|
372 |
+
unk = model.pieces[0]
|
373 |
+
bos = model.pieces[1]
|
374 |
+
eos = model.pieces[2]
|
375 |
+
model.pieces.remove(unk)
|
376 |
+
model.pieces.remove(bos)
|
377 |
+
model.pieces.remove(eos)
|
378 |
+
model.pieces.insert(0, bos) # BOS is token 0
|
379 |
+
model.pieces.insert(1, eos) # EOS is token 1
|
380 |
+
model.pieces.insert(2, unk) # UNK is token 2
|
381 |
+
|
382 |
+
# Add placeholder strings for extra IDs.
|
383 |
+
if extra_ids:
|
384 |
+
# By default, we them in reverse order to match span corruption.
|
385 |
+
if reverse_extra_ids:
|
386 |
+
extra_id_tokens = reversed(range(extra_ids))
|
387 |
+
else:
|
388 |
+
extra_id_tokens = range(extra_ids)
|
389 |
+
|
390 |
+
for i in extra_id_tokens:
|
391 |
+
model.pieces.add(
|
392 |
+
piece=f"▁<extra_id_{i}>",
|
393 |
+
score=0.0,
|
394 |
+
type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED,
|
395 |
+
)
|
396 |
+
|
397 |
+
if extra_tokens:
|
398 |
+
for s in extra_tokens:
|
399 |
+
model.pieces.add(
|
400 |
+
piece=f"▁"+s,
|
401 |
+
score=0.0,
|
402 |
+
type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED,
|
403 |
+
)
|
404 |
+
|
405 |
+
if normalizer_spec_overrides_serialized is not None:
|
406 |
+
normalizer_spec_overrides = (
|
407 |
+
sentencepiece_model_pb2.NormalizerSpec.FromString(
|
408 |
+
normalizer_spec_overrides_serialized
|
409 |
+
)
|
410 |
+
)
|
411 |
+
|
412 |
+
model.normalizer_spec.MergeFrom(normalizer_spec_overrides)
|
413 |
+
model.denormalizer_spec.MergeFrom(normalizer_spec_overrides)
|
414 |
+
sp_model = model.SerializeToString()
|
415 |
+
# Load Python tokenizer and ensure the EOS and PAD IDs are correct.
|
416 |
+
tokenizer = sentencepiece_processor.SentencePieceProcessor()
|
417 |
+
tokenizer.LoadFromSerializedProto(sp_model)
|
418 |
+
if tokenizer.pad_id() != PAD_ID:
|
419 |
+
logging.warning(
|
420 |
+
(
|
421 |
+
"T5 library uses PAD_ID=%s, which is different from the "
|
422 |
+
"sentencepiece vocabulary, which defines pad_id=%s"
|
423 |
+
),
|
424 |
+
PAD_ID,
|
425 |
+
tokenizer.pad_id(),
|
426 |
+
)
|
427 |
+
|
428 |
+
return cls._ModelContext(tokenizer=tokenizer, sp_model=sp_model)
|
429 |
+
|
430 |
+
@property
|
431 |
+
def num_extra_tokens(self):
|
432 |
+
if self._extra_tokens:
|
433 |
+
return len(self._extra_tokens)
|
434 |
+
return 0
|
435 |
+
|
436 |
+
@property
|
437 |
+
def bos_id(self) -> Optional[int]:
|
438 |
+
return self.tokenizer.bos_id()
|
439 |
+
|
440 |
+
@property
|
441 |
+
def bos_token_id(self) -> Optional[int]:
|
442 |
+
return self.tokenizer.bos_id()
|
443 |
+
|
444 |
+
@property
|
445 |
+
def eos_token_id(self) -> Optional[int]:
|
446 |
+
return self.tokenizer.eos_id()
|
447 |
+
|
448 |
+
@property
|
449 |
+
def eos_id(self) -> Optional[int]:
|
450 |
+
return self.tokenizer.eos_id()
|
451 |
+
|
452 |
+
@property
|
453 |
+
def unk_id(self) -> Optional[int]:
|
454 |
+
return self.tokenizer.unk_id()
|
455 |
+
|
456 |
+
@property
|
457 |
+
def sp_model(self) -> Optional[bytes]:
|
458 |
+
"""Retrieve the SPM."""
|
459 |
+
return self._model_context().sp_model
|
460 |
+
|
461 |
+
@property
|
462 |
+
def sentencepiece_model_file(self) -> str:
|
463 |
+
return self._sentencepiece_model_file
|
464 |
+
|
465 |
+
@property
|
466 |
+
def tokenizer(self) -> sentencepiece_processor.SentencePieceProcessor:
|
467 |
+
"""Returns the Python tokenizer."""
|
468 |
+
return self._model_context().tokenizer
|
469 |
+
|
470 |
+
@property
|
471 |
+
def tf_tokenizer(self):
|
472 |
+
"""Instantiate and return a TF tokenizer."""
|
473 |
+
import tensorflow_text as tf_text # import here to keep the dependency optional
|
474 |
+
return tf_text.SentencepieceTokenizer(model=self.sp_model)
|
475 |
+
|
476 |
+
@property
|
477 |
+
def vocab_size(self):
|
478 |
+
return self._base_vocab_size
|
479 |
+
|
480 |
+
@property
|
481 |
+
def _base_vocab_size(self):
|
482 |
+
"""Number of ids (including 0=PAD, 1=EOS, and 2=UNK).
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
an integer, the vocabulary size
|
486 |
+
"""
|
487 |
+
return self.tokenizer.GetPieceSize()
|
488 |
+
|
489 |
+
def _encode(self, s):
|
490 |
+
"""Encode a python string as a list of integers.
|
491 |
+
|
492 |
+
Args:
|
493 |
+
s: a string
|
494 |
+
|
495 |
+
Returns:
|
496 |
+
a list of integers (not terminated by EOS)
|
497 |
+
"""
|
498 |
+
return self.tokenizer.EncodeAsIds(s)
|
499 |
+
|
500 |
+
def _decode(self, ids):
|
501 |
+
"""Decode a list of integers to a python string.
|
502 |
+
|
503 |
+
Args:
|
504 |
+
ids: a list of integers (not terminated by EOS)
|
505 |
+
|
506 |
+
Returns:
|
507 |
+
a string
|
508 |
+
"""
|
509 |
+
# convert all the extra ids (sentinels) to UNK=2
|
510 |
+
unk_id = self.tokenizer.unk_id()
|
511 |
+
piece_size = self.tokenizer.GetPieceSize()
|
512 |
+
ids = [unk_id if i >= piece_size else int(i) for i in ids]
|
513 |
+
return self.tokenizer.DecodeIds(ids)
|
514 |
+
|
515 |
+
def _encode_tf(self, s):
|
516 |
+
"""Encode a tf.Scalar string to a tf.Tensor.
|
517 |
+
|
518 |
+
This will be necessary for on-the-fly tokenization.
|
519 |
+
|
520 |
+
Args:
|
521 |
+
s: a tf.Scalar with dtype tf.string
|
522 |
+
|
523 |
+
Returns:
|
524 |
+
a 1d tf.Tensor with dtype tf.int32
|
525 |
+
"""
|
526 |
+
return self.tf_tokenizer.tokenize(s)
|
527 |
+
|
528 |
+
def _decode_tf(self, ids):
|
529 |
+
"""Decode in TensorFlow.
|
530 |
+
|
531 |
+
Args:
|
532 |
+
ids: a 1d or 2d tf.Tensor with dtype tf.int32
|
533 |
+
|
534 |
+
Returns:
|
535 |
+
a 1d or 2d tf.Tensor with dtype tf.string
|
536 |
+
"""
|
537 |
+
return self.tf_tokenizer.detokenize(ids)
|
538 |
+
|
539 |
+
def __eq__(self, other):
|
540 |
+
if not isinstance(other, SentencePieceVocabulary):
|
541 |
+
return False
|
542 |
+
try:
|
543 |
+
their_md5 = hashlib.md5(other.sp_model).hexdigest()
|
544 |
+
# If other has no sp_model attribute, we can't test for equality
|
545 |
+
except AttributeError:
|
546 |
+
return False
|
547 |
+
if self.sp_model is None:
|
548 |
+
return False
|
549 |
+
our_md5 = hashlib.md5(self.sp_model).hexdigest()
|
550 |
+
return our_md5 == their_md5
|
551 |
+
|
552 |
+
def __str__(self) -> str:
|
553 |
+
return (
|
554 |
+
f"SentencePieceVocabulary(file={self.sentencepiece_model_file}, "
|
555 |
+
f"extra_ids={self._extra_ids}, "
|
556 |
+
f"spm_md5={hashlib.md5(self.sp_model).hexdigest()})"
|
557 |
+
)
|
558 |
+
|
559 |
+
@property
|
560 |
+
def adds_space(self):
|
561 |
+
return True
|
562 |
+
|
563 |
+
|
564 |
+
class HfTokenizerWrapper:
|
565 |
+
def __init__(self, tokenizer, bos_token_id=None, adds_space=False):
|
566 |
+
"""
|
567 |
+
tokenizer: Tokenizer to wrap
|
568 |
+
bos_token_id: BOS token id to use if not `tokenizer.bos_token_id`
|
569 |
+
adds_space: If concatenating interdependently tokenized pieces of text, will the tokens
|
570 |
+
already including a seerating space?
|
571 |
+
"""
|
572 |
+
self.adds_space = adds_space
|
573 |
+
self.tokenizer = tokenizer
|
574 |
+
if bos_token_id is None:
|
575 |
+
self.bos_token_id = tokenizer.bos_token_id
|
576 |
+
else:
|
577 |
+
self.bos_token_id = bos_token_id
|
578 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
579 |
+
self.pad_id = -1
|
580 |
+
|
581 |
+
def encode(self, x: str):
|
582 |
+
return self.tokenizer.encode(x, add_special_tokens=False)
|
583 |
+
|
584 |
+
def decode(self, x: List[int], truncate_at_eos=True):
|
585 |
+
x = [int(t) for t in x]
|
586 |
+
|
587 |
+
if self.eos_token_id == self.bos_token_id and (len(x) > 0 and x[0] == self.eos_token_id):
|
588 |
+
# Assume an EOS at the start is functioning as BOS
|
589 |
+
x = x[1:]
|
590 |
+
|
591 |
+
if truncate_at_eos:
|
592 |
+
# Follow seqio and automatically cut off at EOS
|
593 |
+
try:
|
594 |
+
eos_ix = x.index(self.eos_token_id)
|
595 |
+
x = x[:eos_ix]
|
596 |
+
except ValueError:
|
597 |
+
pass
|
598 |
+
return self.tokenizer.decode(x, skip_special_tokens=True)
|
599 |
+
|
600 |
+
|
601 |
+
def vocab_size(self):
|
602 |
+
return len(self.tokenizer)
|
603 |
+
|
604 |
+
def encode_tf(self, x):
|
605 |
+
if isinstance(x, str) or len(x.shape) == 0:
|
606 |
+
def _enc(_data):
|
607 |
+
_data = _data.item() if isinstance(_data, np.ndarray) else _data
|
608 |
+
return self.tokenizer.encode(_data.decode("utf-8"), add_special_tokens=False, return_tensors="np")[0].astype(np.int32)
|
609 |
+
return tf.ensure_shape(tf.numpy_function(_enc, [x], tf.int32, stateful=False), [None])
|
610 |
+
|
611 |
+
flattened = tf.reshape(x, [-1])
|
612 |
+
|
613 |
+
def _enc(_data):
|
614 |
+
tokens = [self.tokenizer.encode(x.decode("utf-8"), add_special_tokens=False, return_tensors="np")[0].astype(np.int32)
|
615 |
+
for x in _data]
|
616 |
+
if len(tokens) == 0:
|
617 |
+
return np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32)
|
618 |
+
else:
|
619 |
+
return np.concatenate(tokens, 0), np.array([len(x) for x in tokens]).astype(np.int32)
|
620 |
+
if not (isinstance(x, str) or x.dtype == tf.string):
|
621 |
+
raise ValueError("Input be a string or a string numpy array")
|
622 |
+
text, lens = tf.numpy_function(_enc, [flattened], (tf.int32, tf.int32), stateful=False)
|
623 |
+
lens = tf.ensure_shape(lens, [None])
|
624 |
+
text = tf.ensure_shape(text, [None])
|
625 |
+
if len(x.shape) == 2:
|
626 |
+
n = x.shape[1]
|
627 |
+
assert n is not None
|
628 |
+
return tf.RaggedTensor.from_nested_row_lengths(
|
629 |
+
text,
|
630 |
+
[tf.ones(tf.shape(x)[0], dtype=lens.dtype)*n, lens]
|
631 |
+
)
|
632 |
+
else:
|
633 |
+
return tf.RaggedTensor.from_row_lengths(text, lens)
|
634 |
+
|
635 |
+
|
636 |
+
class OLMoTokenizerWrapper(HfTokenizerWrapper):
|
637 |
+
|
638 |
+
def encode(self, x: str):
|
639 |
+
return self.tokenizer.encode(x, add_special_tokens=False)
|
640 |
+
|
641 |
+
def encode_tf(self, x):
|
642 |
+
if isinstance(x, str) or len(x.shape) == 0:
|
643 |
+
def _enc(_data):
|
644 |
+
return np.asarray(self.tokenizer.encode(_data.numpy().decode("utf-8"), add_special_tokens=False), dtype=np.int32)
|
645 |
+
out = tf.py_function(_enc, (x,), tf.int32)
|
646 |
+
return tf.ensure_shape(out, [None])
|
647 |
+
else:
|
648 |
+
def _enc(_data):
|
649 |
+
tokens = [self.tokenizer.encode(x.decode("utf-8"), add_special_tokens=False)
|
650 |
+
for x in _data.numpy()]
|
651 |
+
if len(tokens) == 0:
|
652 |
+
return np.zeros((0,), dtype=np.int32), np.zeros((0,), dtype=np.int32)
|
653 |
+
else:
|
654 |
+
return np.concatenate(tokens, 0), np.array([len(x) for x in tokens])
|
655 |
+
text, lens = tf.py_function(_enc, (x,), (tf.int32, tf.int32))
|
656 |
+
lens = tf.ensure_shape(lens, [None])
|
657 |
+
text = tf.ensure_shape(text, [None])
|
658 |
+
return tf.RaggedTensor.from_row_lengths(text, lens)
|
659 |
+
|
tasks.py
ADDED
@@ -0,0 +1,2548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Module that can be imported to register all tasks
|
2 |
+
import dataclasses
|
3 |
+
import functools
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from collections import OrderedDict
|
7 |
+
from typing import List, Dict, Any
|
8 |
+
|
9 |
+
import seqio
|
10 |
+
from seqio import dataset_providers
|
11 |
+
import tensorflow_datasets as tfds
|
12 |
+
|
13 |
+
from .data_utils import _strip_metadata, build_tokenizer
|
14 |
+
from .preprocesssors import *
|
15 |
+
from .preprocesssors import _preprocess_scifi
|
16 |
+
|
17 |
+
|
18 |
+
@dataclasses.dataclass
|
19 |
+
class TaskSpec:
|
20 |
+
name: str
|
21 |
+
source: seqio.DataSourceInterface
|
22 |
+
preprocessors: List
|
23 |
+
style: str
|
24 |
+
inference_preprocessors: List = None
|
25 |
+
inference_only: bool = False
|
26 |
+
decode_image: bool = False
|
27 |
+
shuffle_after: Optional[int] = None
|
28 |
+
ignore_errors: bool = False
|
29 |
+
|
30 |
+
|
31 |
+
MULTITASK_TFDS_DATA_DIR = "/weka/oe-training-default/mm-olmo/tensorflow_datasets"
|
32 |
+
|
33 |
+
TASKS: Dict[str, TaskSpec] = {}
|
34 |
+
|
35 |
+
|
36 |
+
def add_task(
|
37 |
+
name,
|
38 |
+
source: seqio.DataSourceInterface,
|
39 |
+
preprocessors: List,
|
40 |
+
style: str,
|
41 |
+
inf_preprocessor=None,
|
42 |
+
inf_only=False,
|
43 |
+
decode_image=False,
|
44 |
+
shuffle_after=None,
|
45 |
+
ignore_errors=False
|
46 |
+
):
|
47 |
+
TASKS[name] = TaskSpec(
|
48 |
+
name, source, preprocessors, style, inf_preprocessor, inf_only, decode_image,
|
49 |
+
shuffle_after, ignore_errors)
|
50 |
+
|
51 |
+
|
52 |
+
@seqio.map_over_dataset
|
53 |
+
def add_image_size(ex):
|
54 |
+
if ex["image"].dtype == tf.string:
|
55 |
+
ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
|
56 |
+
img_h = tf.shape(ex["image"])[0]
|
57 |
+
img_w = tf.shape(ex["image"])[1]
|
58 |
+
ex["metadata/image_size"] = [img_w, img_h]
|
59 |
+
|
60 |
+
|
61 |
+
@dataclasses.dataclass
|
62 |
+
class TaskDatasetBuilder:
|
63 |
+
"""tf.data.Dataset builder for task after shuffling, sharding, and initial model pre-processing
|
64 |
+
have been applied"""
|
65 |
+
# This class is a simplified and customized version of seqio.Task
|
66 |
+
#
|
67 |
+
# The main differences are:
|
68 |
+
# 1: Does not prefetch by default, which wastes a small amount of RAM if we are using the
|
69 |
+
# dataset in a mixture which can just have its own top-level prefetch
|
70 |
+
# 2: Reduce threshold for memory caching which is way too high for image datasets by default
|
71 |
+
# 3: Can customize when shuffling occurs to help minimizes RAM usage, in general shuffling
|
72 |
+
# should happen before building image crops and tokenization so the shuffle and
|
73 |
+
# dataset checkpoint take less memory
|
74 |
+
# 4: Don't decoding images until after shuffling for the same reason
|
75 |
+
# 5: Support splitting with tfds.map_split so we never have to fall back to example sharding
|
76 |
+
# not default at the moment since its not well tested
|
77 |
+
# 6: Removes caching/output feature spec stuff from seqio that we don't need
|
78 |
+
|
79 |
+
name: str
|
80 |
+
source: Any
|
81 |
+
preprocessors: List
|
82 |
+
keep_metadata: bool
|
83 |
+
shuffle_after: int
|
84 |
+
sharding: str = "tfds_split"
|
85 |
+
decode_image: bool = False
|
86 |
+
ignore_errors: bool = False
|
87 |
+
|
88 |
+
def get_dataset(
|
89 |
+
self, # pytype: disable=signature-mismatch # overriding-default-value-checks
|
90 |
+
sequence_length: Optional[Mapping[str, int]] = None,
|
91 |
+
split: str = tfds.Split.TRAIN,
|
92 |
+
shuffle: bool = True,
|
93 |
+
shuffle_buffer_size: Optional[int] = 1000,
|
94 |
+
seed: Optional[int] = None,
|
95 |
+
shard_info: Optional[seqio.ShardInfo] = None,
|
96 |
+
num_epochs: Optional[int] = 1,
|
97 |
+
try_in_mem_cache: bool = True,
|
98 |
+
trim_output_features: bool=True
|
99 |
+
) -> tf.data.Dataset:
|
100 |
+
source = self.source
|
101 |
+
|
102 |
+
if self.sharding == "seqio":
|
103 |
+
if source.supports_arbitrary_sharding:
|
104 |
+
shard_data_source = True
|
105 |
+
elif shard_info:
|
106 |
+
# Whether we should shard at source or on the examples from the source.
|
107 |
+
shard_data_source = (
|
108 |
+
len(source.list_shards(split=split)) >= shard_info.num_shards
|
109 |
+
)
|
110 |
+
logging.info(
|
111 |
+
"Sharding at the %s: %d of %d",
|
112 |
+
"data source" if shard_data_source else "examples",
|
113 |
+
shard_info.index + 1,
|
114 |
+
shard_info.num_shards,
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
# Call get_dataset on the source without a shard_info.
|
118 |
+
shard_data_source = True
|
119 |
+
shard_info = None
|
120 |
+
|
121 |
+
if "image" in source.tfds_dataset.info.features:
|
122 |
+
if not self.decode_image:
|
123 |
+
source.tfds_dataset._decoders = dict(image=tfds.decode.SkipDecoding())
|
124 |
+
|
125 |
+
if shard_data_source:
|
126 |
+
ds = source.get_dataset(
|
127 |
+
split=split, shuffle=shuffle, seed=seed, shard_info=shard_info)
|
128 |
+
else:
|
129 |
+
ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
|
130 |
+
ds = ds.shard(shard_info.num_shards, shard_info.index)
|
131 |
+
elif self.sharding == "tfds_split":
|
132 |
+
# Shard with `tfds.even_splits`, which is seems to be recommended for mult-host training
|
133 |
+
# https://github.com/tensorflow/datasets/blob/master/docs/splits.md#tfdseven_splits--multi-host-training
|
134 |
+
assert isinstance(self.source, seqio.TfdsDataSource)
|
135 |
+
loader: seqio.LazyTfdsLoader = self.source.tfds_dataset
|
136 |
+
dataset, data_dir = loader.get_split_params(split)
|
137 |
+
shard_split = loader._map_split(split)
|
138 |
+
if shard_info and shard_info.num_shards > 1:
|
139 |
+
shard_split = tfds.even_splits(shard_split, n=shard_info.num_shards, drop_remainder=False)[shard_info.index]
|
140 |
+
else:
|
141 |
+
shard_split = shard_split
|
142 |
+
read_config = loader.read_config
|
143 |
+
read_config.shuffle_seed = seed
|
144 |
+
read_config.skip_prefetch = True
|
145 |
+
read_config.input_context = None
|
146 |
+
# Don't decode images until after shuffling to save RAM
|
147 |
+
if "image" in loader.info.features:
|
148 |
+
decoders = dict(image=tfds.decode.SkipDecoding())
|
149 |
+
else:
|
150 |
+
decoders = None
|
151 |
+
ds = tfds.load(
|
152 |
+
dataset,
|
153 |
+
split=shard_split,
|
154 |
+
data_dir=data_dir,
|
155 |
+
shuffle_files=shuffle,
|
156 |
+
download=True,
|
157 |
+
try_gcs=True,
|
158 |
+
read_config=read_config,
|
159 |
+
decoders=decoders
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
raise NotImplementedError(self.sharding)
|
163 |
+
|
164 |
+
num_shards = shard_info.num_shards if shard_info else 1
|
165 |
+
if try_in_mem_cache and (
|
166 |
+
source.num_input_examples(split)
|
167 |
+
and source.num_input_examples(split)
|
168 |
+
< 10000 * num_shards
|
169 |
+
):
|
170 |
+
logging.info(f"Automatically caching small dataset in memory: {self.name}:{split}")
|
171 |
+
ds = ds.cache()
|
172 |
+
|
173 |
+
# We repeat before calling any (potentially) stochastic
|
174 |
+
# preprocessors in order to take new samples each epoch.
|
175 |
+
if num_epochs != 1:
|
176 |
+
ds = ds.repeat(num_epochs)
|
177 |
+
|
178 |
+
preprocessors = [
|
179 |
+
seqio.add_kwargs_to_transform(
|
180 |
+
_fn,
|
181 |
+
sequence_length=sequence_length,
|
182 |
+
output_features=None,
|
183 |
+
) for _fn in self.preprocessors
|
184 |
+
]
|
185 |
+
|
186 |
+
with seqio.utils.map_seed_manager(seed):
|
187 |
+
for fn in preprocessors[:self.shuffle_after]:
|
188 |
+
ds = fn(ds)
|
189 |
+
|
190 |
+
# Strip metadata before shuffling if possible so its doesn't waste space
|
191 |
+
if not self.keep_metadata:
|
192 |
+
ds = _strip_metadata(ds)
|
193 |
+
|
194 |
+
if shuffle:
|
195 |
+
if shuffle_buffer_size is None:
|
196 |
+
raise ValueError("Shuffle is true, but shuffle_buffer_size is None")
|
197 |
+
else:
|
198 |
+
ds = ds.shuffle(shuffle_buffer_size, seed=seed)
|
199 |
+
|
200 |
+
for fn in preprocessors[self.shuffle_after:]:
|
201 |
+
ds = fn(ds)
|
202 |
+
|
203 |
+
if self.ignore_errors:
|
204 |
+
ds = ds.ignore_errors(log_warning=True)
|
205 |
+
|
206 |
+
if trim_output_features:
|
207 |
+
ds = seqio.trim_dataset(ds, sequence_length, sequence_length)
|
208 |
+
|
209 |
+
return ds
|
210 |
+
|
211 |
+
|
212 |
+
def get_task(preprocessor, name, is_training, for_inference,
|
213 |
+
include_metadata=None, style_override=None) -> TaskDatasetBuilder:
|
214 |
+
"""Get a builder for task `name` that is pre-processed by `preprocessor`"""
|
215 |
+
|
216 |
+
task_spec = TASKS[name]
|
217 |
+
if for_inference is None:
|
218 |
+
for_inference = task_spec.inference_only
|
219 |
+
elif task_spec.inference_only and not for_inference:
|
220 |
+
raise ValueError(f"Inference=only task {task_spec.name} can only be used in inference mode")
|
221 |
+
|
222 |
+
if include_metadata is None:
|
223 |
+
include_metadata = for_inference
|
224 |
+
|
225 |
+
if preprocessor is not None:
|
226 |
+
style = style_override if style_override else task_spec.style
|
227 |
+
preprocessor = preprocessor.get_preprocessor(
|
228 |
+
is_training, for_inference, style, include_metadata)
|
229 |
+
preprocessor = [preprocessor]
|
230 |
+
else:
|
231 |
+
preprocessor = []
|
232 |
+
task_preprocessors = task_spec.preprocessors
|
233 |
+
if for_inference and task_spec.inference_preprocessors is not None:
|
234 |
+
task_preprocessors = task_spec.inference_preprocessors
|
235 |
+
if isinstance(task_spec.source, seqio.TfdsDataSource):
|
236 |
+
from seqio.utils import _TFDS_DATA_DIR_OVERRIDE
|
237 |
+
if _TFDS_DATA_DIR_OVERRIDE:
|
238 |
+
# Stop annoying override warnings flooding the log
|
239 |
+
task_spec.source.tfds_dataset._data_dir = None
|
240 |
+
|
241 |
+
return TaskDatasetBuilder(
|
242 |
+
task_spec.name,
|
243 |
+
task_spec.source,
|
244 |
+
task_preprocessors + preprocessor,
|
245 |
+
keep_metadata=include_metadata,
|
246 |
+
shuffle_after=(task_spec.shuffle_after if task_spec.shuffle_after
|
247 |
+
else len(task_spec.preprocessors)),
|
248 |
+
sharding="seqio",
|
249 |
+
decode_image=task_spec.decode_image,
|
250 |
+
ignore_errors=task_spec.ignore_errors,
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
add_task(
|
255 |
+
"coco_caption_2017",
|
256 |
+
source=seqio.TfdsDataSource(
|
257 |
+
tfds_name="coco_all:1.0.1",
|
258 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
259 |
+
),
|
260 |
+
preprocessors=[
|
261 |
+
functools.partial(rekey, key_map={
|
262 |
+
"image/filename": ["image/filename"],
|
263 |
+
"image": ["image"],
|
264 |
+
"text": ["captions", "text"]
|
265 |
+
}),
|
266 |
+
functools.partial(flatten_parts, parts=["text"]),
|
267 |
+
],
|
268 |
+
inf_preprocessor=[
|
269 |
+
functools.partial(rekey, key_map={
|
270 |
+
"image/filename": ["image/filename"],
|
271 |
+
"image": ["image"],
|
272 |
+
"text": ["captions", "text"]
|
273 |
+
})
|
274 |
+
],
|
275 |
+
style="coco_captioning",
|
276 |
+
)
|
277 |
+
|
278 |
+
|
279 |
+
add_task(
|
280 |
+
"coco_captioning_karpathy",
|
281 |
+
source=seqio.TfdsDataSource(
|
282 |
+
tfds_name="coco_captioning_karpathy:1.0.2",
|
283 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
284 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
285 |
+
),
|
286 |
+
preprocessors=[
|
287 |
+
rename(text="captions"),
|
288 |
+
functools.partial(flatten_parts, parts=["text"]),
|
289 |
+
],
|
290 |
+
inf_preprocessor=[add_coco_url],
|
291 |
+
style="coco_captioning",
|
292 |
+
)
|
293 |
+
|
294 |
+
|
295 |
+
add_task(
|
296 |
+
"synth_counting",
|
297 |
+
source=seqio.TfdsDataSource(
|
298 |
+
tfds_name="synth_counting:0.0.3",
|
299 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
300 |
+
splits={"train": "train[5120:]", "validation": "train[:5120]"}
|
301 |
+
),
|
302 |
+
preprocessors=[synth_count_preprocessor],
|
303 |
+
inf_preprocessor=[synth_count_inf_preprocessor],
|
304 |
+
style="synth_counting",
|
305 |
+
)
|
306 |
+
|
307 |
+
|
308 |
+
add_task(
|
309 |
+
"khan_academy",
|
310 |
+
source=seqio.TfdsDataSource(
|
311 |
+
tfds_name="khan_academy:1.0.0",
|
312 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
313 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]"}
|
314 |
+
),
|
315 |
+
preprocessors=[extract_khan_academy],
|
316 |
+
style="khan_academy",
|
317 |
+
)
|
318 |
+
|
319 |
+
for name, src in [
|
320 |
+
("vaia_qa_latex_image_math_subset", seqio.TfdsDataSource(
|
321 |
+
tfds_name=f"vaia_qa_latex_image_short_answer:0.1.2",
|
322 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
323 |
+
splits={"train": "train", "validation": "validation"}
|
324 |
+
)),
|
325 |
+
("vaia_qa_latex_image_all", seqio.TfdsDataSource(
|
326 |
+
tfds_name=f"vaia_qa_latex_image_short_answer:0.1.3",
|
327 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
328 |
+
splits={"train": "train", "validation": "validation"}
|
329 |
+
)),
|
330 |
+
]:
|
331 |
+
add_task(
|
332 |
+
f"{name}_short_answer",
|
333 |
+
source=src,
|
334 |
+
preprocessors=[
|
335 |
+
remove_is_long,
|
336 |
+
remove_has_multiple_parts,
|
337 |
+
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
|
338 |
+
],
|
339 |
+
style="vaia_qa",
|
340 |
+
)
|
341 |
+
add_task(
|
342 |
+
f"{name}_short_answer_first",
|
343 |
+
source=src,
|
344 |
+
preprocessors=[
|
345 |
+
remove_is_long,
|
346 |
+
remove_has_multiple_parts,
|
347 |
+
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
|
348 |
+
],
|
349 |
+
style="vaia_qa_short_answer_first",
|
350 |
+
)
|
351 |
+
add_task(
|
352 |
+
f"{name}_mc_only_short_answer",
|
353 |
+
source=src,
|
354 |
+
preprocessors=[
|
355 |
+
remove_is_long,
|
356 |
+
remove_has_multiple_parts,
|
357 |
+
filter_mc,
|
358 |
+
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
|
359 |
+
],
|
360 |
+
style="vaia_qa_short_answer",
|
361 |
+
)
|
362 |
+
add_task(
|
363 |
+
f"{name}_mc_only_short_answer_first",
|
364 |
+
source=src,
|
365 |
+
preprocessors=[
|
366 |
+
remove_is_long,
|
367 |
+
remove_has_multiple_parts,
|
368 |
+
filter_mc,
|
369 |
+
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
|
370 |
+
],
|
371 |
+
style="vaia_qa_short_answer_first",
|
372 |
+
)
|
373 |
+
add_task(
|
374 |
+
f"{name}_image_only_short_answer",
|
375 |
+
source=src,
|
376 |
+
preprocessors=[
|
377 |
+
image_only,
|
378 |
+
remove_is_long,
|
379 |
+
remove_has_multiple_parts,
|
380 |
+
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
|
381 |
+
],
|
382 |
+
style="vaia_qa_short_answer",
|
383 |
+
)
|
384 |
+
add_task(
|
385 |
+
f"{name}_image_only_short_answer_first",
|
386 |
+
source=src,
|
387 |
+
preprocessors=[
|
388 |
+
image_only,
|
389 |
+
remove_is_long,
|
390 |
+
remove_has_multiple_parts,
|
391 |
+
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
|
392 |
+
],
|
393 |
+
style="vaia_qa_short_answer_first",
|
394 |
+
)
|
395 |
+
|
396 |
+
add_task(
|
397 |
+
"vqa_online",
|
398 |
+
source=seqio.TfdsDataSource(
|
399 |
+
tfds_name="vqa_online:1.0.1",
|
400 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
401 |
+
splits={"train": "train", "validation": "validation", "test": "validation"}
|
402 |
+
),
|
403 |
+
preprocessors=[
|
404 |
+
build_question_with_context,
|
405 |
+
extract_vqa_online,
|
406 |
+
],
|
407 |
+
style="vqa_online",
|
408 |
+
)
|
409 |
+
|
410 |
+
add_task(
|
411 |
+
"vqa_online_gpt_longQ_longA",
|
412 |
+
source=seqio.TfdsDataSource(
|
413 |
+
tfds_name="vqa_online_gpt_parsed:1.1.0",
|
414 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
415 |
+
splits={"train": "train", "validation": "validation", "test": "validation"}
|
416 |
+
),
|
417 |
+
preprocessors=[
|
418 |
+
rename(question="question_long", answer="answer_long"),
|
419 |
+
extract_vqa_online,
|
420 |
+
],
|
421 |
+
style="vqa_online",
|
422 |
+
)
|
423 |
+
|
424 |
+
|
425 |
+
add_task(
|
426 |
+
"famous_birthdays",
|
427 |
+
source=seqio.TfdsDataSource(
|
428 |
+
tfds_name="famous_birth_days:1.0.0",
|
429 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
430 |
+
splits={"train": "train[5120:]", "validation": "train[:5120]"}
|
431 |
+
),
|
432 |
+
preprocessors=[
|
433 |
+
famous_birthdays_preprocessor,
|
434 |
+
functools.partial(name_entity_augmentation, p_high_color=0.0),
|
435 |
+
],
|
436 |
+
style="famous_birthdays",
|
437 |
+
)
|
438 |
+
|
439 |
+
|
440 |
+
add_task(
|
441 |
+
"wiki_art",
|
442 |
+
source=seqio.TfdsDataSource(
|
443 |
+
tfds_name="wiki_art:1.0.0",
|
444 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
445 |
+
splits={"train": "train[5120:]", "validation": "train[:5120]"}
|
446 |
+
),
|
447 |
+
preprocessors=[name_entity_augmentation, wiki_art_preprocessor],
|
448 |
+
style="wiki_art",
|
449 |
+
)
|
450 |
+
|
451 |
+
add_task(
|
452 |
+
"wiki_art_no_aug",
|
453 |
+
source=seqio.TfdsDataSource(
|
454 |
+
tfds_name="wiki_art:1.0.0",
|
455 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
456 |
+
splits={"train": "train[5120:]", "validation": "train[:5120]"}
|
457 |
+
),
|
458 |
+
preprocessors=[wiki_art_preprocessor],
|
459 |
+
style="wiki_art",
|
460 |
+
)
|
461 |
+
|
462 |
+
add_task(
|
463 |
+
"atlas_obscura",
|
464 |
+
source=seqio.TfdsDataSource(
|
465 |
+
tfds_name="atlas_obscura:1.0.0",
|
466 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
467 |
+
splits={"train": "train[5120:]", "validation": "train[:5120]"}
|
468 |
+
),
|
469 |
+
preprocessors=[
|
470 |
+
atlas_obscura_preprocessor,
|
471 |
+
mild_color_aug_preprocessor
|
472 |
+
],
|
473 |
+
style="atlas_obscura",
|
474 |
+
)
|
475 |
+
|
476 |
+
|
477 |
+
add_task(
|
478 |
+
"clocks",
|
479 |
+
source=seqio.TfdsDataSource(
|
480 |
+
tfds_name="clocks:1.0.1",
|
481 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
482 |
+
),
|
483 |
+
preprocessors=[
|
484 |
+
clocks_preprocessor,
|
485 |
+
clock_augmentation
|
486 |
+
],
|
487 |
+
style="clocks",
|
488 |
+
shuffle_after=0
|
489 |
+
)
|
490 |
+
|
491 |
+
|
492 |
+
add_task(
|
493 |
+
"count_bench",
|
494 |
+
source=seqio.TfdsDataSource(
|
495 |
+
tfds_name="count_bench:1.0.0",
|
496 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
497 |
+
),
|
498 |
+
preprocessors=[
|
499 |
+
count_bench_preprocessor,
|
500 |
+
],
|
501 |
+
style="count_bench",
|
502 |
+
)
|
503 |
+
|
504 |
+
|
505 |
+
add_task(
|
506 |
+
"tulu_v2_sft",
|
507 |
+
source=seqio.TfdsDataSource(
|
508 |
+
tfds_name="allenai__tulu_v2_sft_mixture:1.0.0",
|
509 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
510 |
+
),
|
511 |
+
preprocessors=[tulu_preprocessor],
|
512 |
+
style="tulu_v2",
|
513 |
+
)
|
514 |
+
|
515 |
+
|
516 |
+
# Pointing / Point+Count datasets
|
517 |
+
for is_count in [True, False]:
|
518 |
+
if is_count:
|
519 |
+
task = "point_count"
|
520 |
+
else:
|
521 |
+
task = "pointing"
|
522 |
+
add_task(
|
523 |
+
task,
|
524 |
+
source=seqio.TfdsDataSource(
|
525 |
+
tfds_name="pointing:1.0.1",
|
526 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
527 |
+
splits={"train": "train", "validation": "validation"}
|
528 |
+
),
|
529 |
+
preprocessors=[
|
530 |
+
filter_points,
|
531 |
+
functools.partial(pointing_preprocessor, with_count=is_count),
|
532 |
+
split
|
533 |
+
],
|
534 |
+
style=task,
|
535 |
+
)
|
536 |
+
add_task(
|
537 |
+
task + "_eval", # pointing validation set
|
538 |
+
source=seqio.TfdsDataSource(
|
539 |
+
tfds_name="pointing:1.0.2",
|
540 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
541 |
+
),
|
542 |
+
preprocessors=[
|
543 |
+
filter_points,
|
544 |
+
functools.partial(pointing_preprocessor, with_count=is_count),
|
545 |
+
split
|
546 |
+
],
|
547 |
+
style=task,
|
548 |
+
)
|
549 |
+
add_task(
|
550 |
+
task + "_high_freq",
|
551 |
+
source=seqio.TfdsDataSource(
|
552 |
+
tfds_name="count_qa:0.0.2",
|
553 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
554 |
+
splits=dict(
|
555 |
+
train="train[2048:]",
|
556 |
+
validation="train[:2048]"
|
557 |
+
)
|
558 |
+
),
|
559 |
+
preprocessors=[
|
560 |
+
filter_points,
|
561 |
+
fix_count_qa, # Fix a tfrecord bug TODO fix the underlying records
|
562 |
+
functools.partial(pointing_preprocessor, with_count=is_count),
|
563 |
+
split,
|
564 |
+
],
|
565 |
+
style=task,
|
566 |
+
)
|
567 |
+
add_task(
|
568 |
+
"fast_flickr_count_qa_" + task,
|
569 |
+
source=seqio.TfdsDataSource(
|
570 |
+
tfds_name="fast_flickr_count_qa:1.0.4",
|
571 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
572 |
+
),
|
573 |
+
preprocessors=[
|
574 |
+
functools.partial(count_qa_preprocessor, with_count=is_count),
|
575 |
+
],
|
576 |
+
inf_preprocessor=[
|
577 |
+
functools.partial(count_qa_preprocessor, with_count=is_count, for_inference=True),
|
578 |
+
],
|
579 |
+
style=task,
|
580 |
+
)
|
581 |
+
|
582 |
+
|
583 |
+
add_task(
|
584 |
+
"countbench_qa",
|
585 |
+
source=seqio.TfdsDataSource(
|
586 |
+
tfds_name="countbench_qa:1.2.0",
|
587 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
588 |
+
),
|
589 |
+
inf_only=True,
|
590 |
+
preprocessors=[
|
591 |
+
count_qa_preprocessor_inf,
|
592 |
+
],
|
593 |
+
style="point_count",
|
594 |
+
)
|
595 |
+
|
596 |
+
|
597 |
+
add_task(
|
598 |
+
f"pointing_test", # pointing set with segmentation ground truths
|
599 |
+
source=seqio.TfdsDataSource(
|
600 |
+
tfds_name="pointing:1.0.3",
|
601 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
602 |
+
),
|
603 |
+
preprocessors=[
|
604 |
+
pointing_inf_preprocessor
|
605 |
+
],
|
606 |
+
style=task,
|
607 |
+
inf_only=True,
|
608 |
+
)
|
609 |
+
|
610 |
+
|
611 |
+
add_task(
|
612 |
+
"point_qa",
|
613 |
+
source=seqio.TfdsDataSource(
|
614 |
+
tfds_name="point_qa:0.0.5",
|
615 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
616 |
+
splits=dict(
|
617 |
+
train="train[512:]",
|
618 |
+
validation="train[:512]"
|
619 |
+
)
|
620 |
+
),
|
621 |
+
preprocessors=[extract_point_qa, split],
|
622 |
+
style="point_qa",
|
623 |
+
)
|
624 |
+
|
625 |
+
add_task(
|
626 |
+
"clocks_no_aug",
|
627 |
+
source=seqio.TfdsDataSource(
|
628 |
+
tfds_name="clocks:1.0.1",
|
629 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
630 |
+
),
|
631 |
+
preprocessors=[
|
632 |
+
clocks_preprocessor
|
633 |
+
],
|
634 |
+
style="clocks",
|
635 |
+
)
|
636 |
+
|
637 |
+
|
638 |
+
add_task(
|
639 |
+
"clock_bench",
|
640 |
+
source=seqio.TfdsDataSource(
|
641 |
+
tfds_name="clock_bench:1.0.0",
|
642 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
643 |
+
),
|
644 |
+
preprocessors=[
|
645 |
+
clock_bench_preprocessor
|
646 |
+
],
|
647 |
+
inf_only=True,
|
648 |
+
style="clocks",
|
649 |
+
)
|
650 |
+
|
651 |
+
add_task(
|
652 |
+
"wiki_data",
|
653 |
+
source=seqio.TfdsDataSource(
|
654 |
+
tfds_name="cockatoo_wiki:1.0.0",
|
655 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
656 |
+
splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
|
657 |
+
),
|
658 |
+
preprocessors=[extract_wiki_data],
|
659 |
+
style="wiki_data",
|
660 |
+
)
|
661 |
+
|
662 |
+
|
663 |
+
add_task(
|
664 |
+
"wiki_data_name",
|
665 |
+
source=seqio.TfdsDataSource(
|
666 |
+
tfds_name="cockatoo_wiki:1.0.0",
|
667 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
668 |
+
splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
|
669 |
+
),
|
670 |
+
preprocessors=[
|
671 |
+
extract_wiki_data_name,
|
672 |
+
mild_color_aug_preprocessor
|
673 |
+
],
|
674 |
+
style="wiki_data",
|
675 |
+
)
|
676 |
+
|
677 |
+
add_task(
|
678 |
+
"wiki_data_describe",
|
679 |
+
source=seqio.TfdsDataSource(
|
680 |
+
tfds_name="cockatoo_wiki:1.0.0",
|
681 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
682 |
+
splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
|
683 |
+
),
|
684 |
+
preprocessors=[extract_wiki_data_describe],
|
685 |
+
inf_only=True,
|
686 |
+
style="wiki_data",
|
687 |
+
)
|
688 |
+
|
689 |
+
add_task(
|
690 |
+
"wiki_data_describe",
|
691 |
+
source=seqio.TfdsDataSource(
|
692 |
+
tfds_name="cockatoo_wiki:1.0.0",
|
693 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
694 |
+
splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
|
695 |
+
),
|
696 |
+
preprocessors=[extract_wiki_data_describe],
|
697 |
+
inf_only=True,
|
698 |
+
style="wiki_data",
|
699 |
+
)
|
700 |
+
|
701 |
+
|
702 |
+
for name, src in [
|
703 |
+
("scifi_charts", seqio.TfdsDataSource(
|
704 |
+
tfds_name="sci_fi_charts:1.0.6",
|
705 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
706 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]"}
|
707 |
+
)),
|
708 |
+
("scifi_table", seqio.TfdsDataSource(
|
709 |
+
tfds_name="sci_fi_table:1.0.3",
|
710 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
711 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]"}
|
712 |
+
)),
|
713 |
+
("scifi_document", seqio.TfdsDataSource(
|
714 |
+
tfds_name="sci_fi_document:1.0.3",
|
715 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
716 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]"}
|
717 |
+
)),
|
718 |
+
("scifi_diagram", seqio.TfdsDataSource(
|
719 |
+
tfds_name="sci_fi_diagram:1.0.0",
|
720 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
721 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]"}
|
722 |
+
)),
|
723 |
+
("scifi_natural", seqio.TfdsDataSource(
|
724 |
+
tfds_name="sci_fi_natural:1.0.1",
|
725 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
726 |
+
splits={"train": "train[128:]", "validation": "train[:128]"}
|
727 |
+
)),
|
728 |
+
("scifi_nutrition", seqio.TfdsDataSource(
|
729 |
+
tfds_name="sci_fi_nutrition:1.0.0",
|
730 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
731 |
+
splits={"train": "train[128:]", "validation": "train[:128]"}
|
732 |
+
))
|
733 |
+
]:
|
734 |
+
add_task(
|
735 |
+
name + "_qa",
|
736 |
+
source=src,
|
737 |
+
preprocessors=[
|
738 |
+
remove_no_qa,
|
739 |
+
_preprocess_scifi,
|
740 |
+
extract_individual_vqa,
|
741 |
+
],
|
742 |
+
inf_preprocessor=[
|
743 |
+
remove_no_qa, _preprocess_scifi,
|
744 |
+
functools.partial(flatten_parts, parts=["question", "answer"]),
|
745 |
+
extract_individual_vqa,
|
746 |
+
],
|
747 |
+
style=name,
|
748 |
+
)
|
749 |
+
add_task(
|
750 |
+
name + "_qa_split",
|
751 |
+
source=src,
|
752 |
+
preprocessors=[
|
753 |
+
remove_no_qa,
|
754 |
+
_preprocess_scifi,
|
755 |
+
extract_individual_vqa,
|
756 |
+
split
|
757 |
+
],
|
758 |
+
inf_preprocessor=[
|
759 |
+
remove_no_qa, _preprocess_scifi,
|
760 |
+
functools.partial(flatten_parts, parts=["question", "answer"]),
|
761 |
+
extract_individual_vqa,
|
762 |
+
],
|
763 |
+
style=name,
|
764 |
+
)
|
765 |
+
add_task(
|
766 |
+
name + "_qa_exp",
|
767 |
+
source=src,
|
768 |
+
preprocessors=[
|
769 |
+
remove_no_qa,
|
770 |
+
_preprocess_scifi,
|
771 |
+
extract_scifi_qa_exp,
|
772 |
+
extract_individual_vqa,
|
773 |
+
],
|
774 |
+
inf_preprocessor=[
|
775 |
+
remove_no_qa, _preprocess_scifi,
|
776 |
+
extract_scifi_qa_exp,
|
777 |
+
functools.partial(flatten_parts, parts=["question", "answer"]),
|
778 |
+
extract_individual_vqa,
|
779 |
+
],
|
780 |
+
style=name + "_qa_exp",
|
781 |
+
)
|
782 |
+
add_task(
|
783 |
+
name + "_qa_exp_split",
|
784 |
+
source=src,
|
785 |
+
preprocessors=[
|
786 |
+
remove_no_qa,
|
787 |
+
_preprocess_scifi,
|
788 |
+
extract_scifi_qa_exp,
|
789 |
+
extract_individual_vqa,
|
790 |
+
split,
|
791 |
+
],
|
792 |
+
inf_preprocessor=[
|
793 |
+
remove_no_qa, _preprocess_scifi,
|
794 |
+
extract_scifi_qa_exp,
|
795 |
+
functools.partial(flatten_parts, parts=["question", "answer"]),
|
796 |
+
extract_individual_vqa,
|
797 |
+
],
|
798 |
+
style=name + "_qa_exp",
|
799 |
+
)
|
800 |
+
add_task(
|
801 |
+
name + "_exp",
|
802 |
+
source=src,
|
803 |
+
preprocessors=[
|
804 |
+
remove_no_qa,
|
805 |
+
_preprocess_scifi,
|
806 |
+
scifi_explanation_only,
|
807 |
+
extract_individual_vqa,
|
808 |
+
split
|
809 |
+
],
|
810 |
+
style=name + "_exp"
|
811 |
+
)
|
812 |
+
add_task(
|
813 |
+
name + "_demo",
|
814 |
+
source=src,
|
815 |
+
preprocessors=[
|
816 |
+
remove_no_qa,
|
817 |
+
_preprocess_scifi,
|
818 |
+
extract_scifi_qa_demo,
|
819 |
+
extract_individual_vqa,
|
820 |
+
split
|
821 |
+
],
|
822 |
+
style="scifi_demo"
|
823 |
+
)
|
824 |
+
|
825 |
+
|
826 |
+
add_task(
|
827 |
+
"chart_qa_scifi",
|
828 |
+
source=seqio.TfdsDataSource(
|
829 |
+
tfds_name="chart_qa:1.0.2",
|
830 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
831 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
832 |
+
),
|
833 |
+
preprocessors=[
|
834 |
+
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
|
835 |
+
extract_individual_vqa,
|
836 |
+
],
|
837 |
+
style="scifi_charts_qa_exp",
|
838 |
+
)
|
839 |
+
|
840 |
+
|
841 |
+
add_task(
|
842 |
+
"chart_qa_prompting",
|
843 |
+
source=seqio.TfdsDataSource(
|
844 |
+
tfds_name="chart_qa:1.0.2",
|
845 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
846 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
847 |
+
),
|
848 |
+
preprocessors=[
|
849 |
+
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
|
850 |
+
chartqa_prompting,
|
851 |
+
extract_individual_vqa,
|
852 |
+
],
|
853 |
+
style="chart_qa",
|
854 |
+
)
|
855 |
+
|
856 |
+
|
857 |
+
add_task(
|
858 |
+
"chart_qa_prompting_explanation",
|
859 |
+
source=seqio.TfdsDataSource(
|
860 |
+
tfds_name="chart_qa:1.0.2",
|
861 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
862 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
863 |
+
),
|
864 |
+
preprocessors=[
|
865 |
+
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
|
866 |
+
chartqa_explanation,
|
867 |
+
extract_individual_vqa,
|
868 |
+
],
|
869 |
+
style="chart_qa",
|
870 |
+
)
|
871 |
+
|
872 |
+
|
873 |
+
|
874 |
+
add_task(
|
875 |
+
"coco_captioning_karpathy_multi",
|
876 |
+
source=seqio.TfdsDataSource(
|
877 |
+
tfds_name="coco_captioning_karpathy:1.0.2",
|
878 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
879 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
880 |
+
),
|
881 |
+
preprocessors=[rename(text="captions")],
|
882 |
+
inf_preprocessor=[add_coco_url],
|
883 |
+
style="coco_captioning",
|
884 |
+
)
|
885 |
+
|
886 |
+
|
887 |
+
add_task(
|
888 |
+
"coco_caption_2017_grouped",
|
889 |
+
source=seqio.TfdsDataSource(
|
890 |
+
tfds_name="coco_all:1.0.1",
|
891 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
892 |
+
),
|
893 |
+
preprocessors=[
|
894 |
+
functools.partial(
|
895 |
+
rekey, key_map={
|
896 |
+
"image/filename": ["image/filename"],
|
897 |
+
"image": ["image"],
|
898 |
+
"text": ["captions", "text"]
|
899 |
+
}),
|
900 |
+
join_captions
|
901 |
+
],
|
902 |
+
style="coco_captioning_multiple",
|
903 |
+
)
|
904 |
+
|
905 |
+
|
906 |
+
add_task(
|
907 |
+
"llava_pretrain",
|
908 |
+
source=seqio.TfdsDataSource(
|
909 |
+
tfds_name="llava_pretrain:1.0.0",
|
910 |
+
tfds_data_dir="gs://mm-olmo-datasets/",
|
911 |
+
splits=dict(
|
912 |
+
train="train[4096:]",
|
913 |
+
validation="train[:4096]"
|
914 |
+
)
|
915 |
+
),
|
916 |
+
preprocessors=[extract_llava],
|
917 |
+
style="web_caption"
|
918 |
+
)
|
919 |
+
|
920 |
+
|
921 |
+
add_task(
|
922 |
+
"rohun_images",
|
923 |
+
source=seqio.TfdsDataSource(
|
924 |
+
tfds_name="rohun_images:1.0.0",
|
925 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
926 |
+
),
|
927 |
+
preprocessors=[],
|
928 |
+
style="long_caption",
|
929 |
+
inf_only=True
|
930 |
+
)
|
931 |
+
|
932 |
+
|
933 |
+
add_task(
|
934 |
+
"dense_caption_eval",
|
935 |
+
source=seqio.TfdsDataSource(
|
936 |
+
tfds_name="dense_captioning_eval:1.0.0",
|
937 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
938 |
+
splits=dict(validation="train")
|
939 |
+
),
|
940 |
+
preprocessors=[],
|
941 |
+
style="long_caption",
|
942 |
+
inf_only=True
|
943 |
+
)
|
944 |
+
|
945 |
+
|
946 |
+
add_task(
|
947 |
+
"dense_caption_eval_dbg",
|
948 |
+
source=seqio.TfdsDataSource(
|
949 |
+
tfds_name="dense_captioning_eval:1.0.0",
|
950 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
951 |
+
splits=dict(validation="train")
|
952 |
+
),
|
953 |
+
preprocessors=[
|
954 |
+
lambda ds: ds.filter(lambda x: x["url"] == "https://explore-multimodal-datasets.s3.us-west-2.amazonaws.com/eval-set/v0/eval-set/a211be07e2c9c722ef75093026a608856bd07ad935ebdedea6f2944b1f2d2b0e.jpg")
|
955 |
+
],
|
956 |
+
style="long_caption",
|
957 |
+
inf_only=True
|
958 |
+
)
|
959 |
+
|
960 |
+
|
961 |
+
add_task(
|
962 |
+
"dense_caption_sample",
|
963 |
+
source=seqio.TfdsDataSource(
|
964 |
+
tfds_name="dense_captioning_eval:1.0.0",
|
965 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
966 |
+
splits=dict(
|
967 |
+
validation="train"
|
968 |
+
)
|
969 |
+
),
|
970 |
+
preprocessors=[select_dense_caption_sample],
|
971 |
+
style="long_caption",
|
972 |
+
)
|
973 |
+
|
974 |
+
|
975 |
+
add_task(
|
976 |
+
"cockatoo_1per_caption_287k",
|
977 |
+
source=seqio.TfdsDataSource(
|
978 |
+
tfds_name="cockatoo_1per_caption_287k:1.0.5",
|
979 |
+
tfds_data_dir="gs://mm-olmo-data/",
|
980 |
+
splits=dict(
|
981 |
+
train="train[5120:]",
|
982 |
+
validation="train[:5120]"
|
983 |
+
)
|
984 |
+
),
|
985 |
+
preprocessors=[
|
986 |
+
rename(text="caption"),
|
987 |
+
],
|
988 |
+
style="long_caption"
|
989 |
+
)
|
990 |
+
|
991 |
+
|
992 |
+
def _filter_large_ratio(ds):
|
993 |
+
return ds.filter(
|
994 |
+
lambda x: tf.shape(x["image"])[0] > tf.shape(x["image"])[1]*2
|
995 |
+
)
|
996 |
+
|
997 |
+
|
998 |
+
add_task(
|
999 |
+
f"cockatoo_dbg",
|
1000 |
+
source= seqio.TfdsDataSource(
|
1001 |
+
tfds_name="cockatoo_476k:1.0.5",
|
1002 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1003 |
+
splits=dict(
|
1004 |
+
train="train[5120:]",
|
1005 |
+
validation="train[:5120]"
|
1006 |
+
)
|
1007 |
+
)
|
1008 |
+
,
|
1009 |
+
preprocessors=[
|
1010 |
+
_filter_large_ratio,
|
1011 |
+
extract_caption_and_transcript
|
1012 |
+
],
|
1013 |
+
style=["long_caption", "transcript"]
|
1014 |
+
)
|
1015 |
+
|
1016 |
+
|
1017 |
+
for name, src in [
|
1018 |
+
("712k_sept6", seqio.TfdsDataSource(
|
1019 |
+
tfds_name="cockatoo_712k_sept6:1.0.5",
|
1020 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1021 |
+
splits=dict(
|
1022 |
+
train="train[5120:]",
|
1023 |
+
validation="train[:5120]"
|
1024 |
+
)
|
1025 |
+
)),
|
1026 |
+
("476k", seqio.TfdsDataSource(
|
1027 |
+
tfds_name="cockatoo_476k:1.0.5",
|
1028 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1029 |
+
splits=dict(
|
1030 |
+
train="train[5120:]",
|
1031 |
+
validation="train[:5120]"
|
1032 |
+
)
|
1033 |
+
)),
|
1034 |
+
("476k_gpt_captions", seqio.TfdsDataSource(
|
1035 |
+
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
|
1036 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1037 |
+
splits=dict(
|
1038 |
+
train="train[5120:]",
|
1039 |
+
validation="train[:5120]"
|
1040 |
+
)
|
1041 |
+
)),
|
1042 |
+
("100k_of_476k_gpt_captions", seqio.TfdsDataSource(
|
1043 |
+
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
|
1044 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1045 |
+
splits=dict(
|
1046 |
+
train="train[5120:105120]",
|
1047 |
+
validation="train[:5120]"
|
1048 |
+
)
|
1049 |
+
)),
|
1050 |
+
("200k_of_476k_gpt_captions", seqio.TfdsDataSource(
|
1051 |
+
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
|
1052 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1053 |
+
splits=dict(
|
1054 |
+
train="train[5120:205120]",
|
1055 |
+
validation="train[:5120]"
|
1056 |
+
)
|
1057 |
+
)),
|
1058 |
+
("300k_of_476k_gpt_captions", seqio.TfdsDataSource(
|
1059 |
+
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
|
1060 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1061 |
+
splits=dict(
|
1062 |
+
train="train[5120:305120]",
|
1063 |
+
validation="train[:5120]"
|
1064 |
+
)
|
1065 |
+
)),
|
1066 |
+
("400k_of_476k_gpt_captions", seqio.TfdsDataSource(
|
1067 |
+
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
|
1068 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1069 |
+
splits=dict(
|
1070 |
+
train="train[5120:405120]",
|
1071 |
+
validation="train[:5120]"
|
1072 |
+
)
|
1073 |
+
)),
|
1074 |
+
("400k_of_476k", seqio.TfdsDataSource(
|
1075 |
+
tfds_name="cockatoo_476k:1.0.5",
|
1076 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1077 |
+
splits=dict(
|
1078 |
+
train="train[5120:405120]",
|
1079 |
+
validation="train[:5120]"
|
1080 |
+
)
|
1081 |
+
)),
|
1082 |
+
("300k_of_476k", seqio.TfdsDataSource(
|
1083 |
+
tfds_name="cockatoo_476k:1.0.5",
|
1084 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1085 |
+
splits=dict(
|
1086 |
+
train="train[5120:305120]",
|
1087 |
+
validation="train[:5120]"
|
1088 |
+
)
|
1089 |
+
)),
|
1090 |
+
("200k_of_476k", seqio.TfdsDataSource(
|
1091 |
+
tfds_name="cockatoo_476k:1.0.5",
|
1092 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1093 |
+
splits=dict(
|
1094 |
+
train="train[5120:205120]",
|
1095 |
+
validation="train[:5120]"
|
1096 |
+
)
|
1097 |
+
)),
|
1098 |
+
("100k_of_476k", seqio.TfdsDataSource(
|
1099 |
+
tfds_name="cockatoo_476k:1.0.5",
|
1100 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1101 |
+
splits=dict(
|
1102 |
+
train="train[5120:105120]",
|
1103 |
+
validation="train[:5120]"
|
1104 |
+
)
|
1105 |
+
)),
|
1106 |
+
("276k", seqio.TfdsDataSource(
|
1107 |
+
tfds_name="cockatoo:1.0.5",
|
1108 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1109 |
+
splits=dict(
|
1110 |
+
train="train[5120:]",
|
1111 |
+
validation="train[:5120]"
|
1112 |
+
)
|
1113 |
+
)),
|
1114 |
+
("180k", seqio.TfdsDataSource(
|
1115 |
+
tfds_name="cockatoo:1.0.3",
|
1116 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1117 |
+
splits=dict(
|
1118 |
+
train="train[4096:]",
|
1119 |
+
validation="train[:4096]"
|
1120 |
+
)
|
1121 |
+
)),
|
1122 |
+
("84k_claude_captions", seqio.TfdsDataSource(
|
1123 |
+
tfds_name="cockatoo_84k_claude_captions:1.0.0",
|
1124 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1125 |
+
splits=dict(
|
1126 |
+
train="train[1000:]",
|
1127 |
+
validation="train[:1000]"
|
1128 |
+
)
|
1129 |
+
)),
|
1130 |
+
]:
|
1131 |
+
add_task(
|
1132 |
+
f"cockatoo_{name}",
|
1133 |
+
source=src,
|
1134 |
+
preprocessors=[extract_caption],
|
1135 |
+
style="long_caption"
|
1136 |
+
)
|
1137 |
+
|
1138 |
+
add_task(
|
1139 |
+
f"cockatoo_and_transcript_{name}",
|
1140 |
+
source=src,
|
1141 |
+
preprocessors=[extract_caption_and_transcript],
|
1142 |
+
style=["long_caption", "transcript"]
|
1143 |
+
)
|
1144 |
+
|
1145 |
+
add_task(
|
1146 |
+
f"cockatoo_and_transcript_stratified_{name}",
|
1147 |
+
source=src,
|
1148 |
+
preprocessors=[
|
1149 |
+
extract_caption_and_transcript,
|
1150 |
+
# put this here to hack seqio into repeating the dataset after
|
1151 |
+
# `extract_caption_and_transcript` which will properly stratify the transcripts
|
1152 |
+
seqio.CacheDatasetPlaceholder(),
|
1153 |
+
],
|
1154 |
+
style=["long_caption", "transcript"]
|
1155 |
+
)
|
1156 |
+
add_task(
|
1157 |
+
f"cockatoo_and_all_transcripts_{name}",
|
1158 |
+
source=src,
|
1159 |
+
preprocessors=[extract_caption_and_all_transcripts],
|
1160 |
+
style=["long_caption", "transcript", "transcript", "transcript"]
|
1161 |
+
)
|
1162 |
+
|
1163 |
+
add_task(
|
1164 |
+
f"cockatoo_all_transcripts_{name}",
|
1165 |
+
source=src,
|
1166 |
+
preprocessors=[extract_all_transcripts],
|
1167 |
+
style="transcript"
|
1168 |
+
)
|
1169 |
+
add_task(
|
1170 |
+
f"cockatoo_transcripts_{name}",
|
1171 |
+
source=src,
|
1172 |
+
preprocessors=[extract_transcript],
|
1173 |
+
style="transcript"
|
1174 |
+
)
|
1175 |
+
|
1176 |
+
|
1177 |
+
TFRECORD_IMAGE_TEXT_FEATURES = {
|
1178 |
+
'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
|
1179 |
+
'text':tf.io.FixedLenFeature(shape=(), dtype=tf.string),
|
1180 |
+
}
|
1181 |
+
|
1182 |
+
|
1183 |
+
add_task(
|
1184 |
+
"laion400m",
|
1185 |
+
source=seqio.TFExampleDataSource(
|
1186 |
+
split_to_filepattern={
|
1187 |
+
"train": os.path.join("gs://unified-io-2-us-east/", "pretrain-datasets", "laion400m", "1.0.0", "laion400m-train*"),
|
1188 |
+
},
|
1189 |
+
feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
|
1190 |
+
),
|
1191 |
+
preprocessors=[
|
1192 |
+
functools.partial(rekey, key_map={
|
1193 |
+
"image": ["image"],
|
1194 |
+
"text": ["text"]
|
1195 |
+
}),
|
1196 |
+
],
|
1197 |
+
style="laion",
|
1198 |
+
)
|
1199 |
+
|
1200 |
+
|
1201 |
+
add_task(
|
1202 |
+
"laion_2B",
|
1203 |
+
source=seqio.TFExampleDataSource(
|
1204 |
+
split_to_filepattern={
|
1205 |
+
"train": os.path.join(MULTITASK_TFDS_DATA_DIR, "laion2b_en", "1.0.0", "laion2b_en-train*"),
|
1206 |
+
},
|
1207 |
+
feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
|
1208 |
+
),
|
1209 |
+
preprocessors=[
|
1210 |
+
functools.partial(rekey, key_map={
|
1211 |
+
"image": ["image"],
|
1212 |
+
"text": ["text"]
|
1213 |
+
}),
|
1214 |
+
],
|
1215 |
+
style="laion",
|
1216 |
+
)
|
1217 |
+
|
1218 |
+
|
1219 |
+
add_task(
|
1220 |
+
"region_caption_vg",
|
1221 |
+
source=seqio.TfdsDataSource(
|
1222 |
+
tfds_name="vg:1.0.1",
|
1223 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1224 |
+
),
|
1225 |
+
preprocessors=[region_captions_to_dense],
|
1226 |
+
style="region_captions",
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
|
1230 |
+
add_task(
|
1231 |
+
"pdfa_eng_wds",
|
1232 |
+
source=seqio.TfdsDataSource(
|
1233 |
+
tfds_name="pdfa_eng_wds:1.0.0",
|
1234 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1235 |
+
),
|
1236 |
+
preprocessors=[
|
1237 |
+
functools.partial(max_words, max_words=400),
|
1238 |
+
format_pdfa_eng_wds
|
1239 |
+
],
|
1240 |
+
style="pdfa_eng_wds",
|
1241 |
+
)
|
1242 |
+
|
1243 |
+
|
1244 |
+
add_task(
|
1245 |
+
"idl_words",
|
1246 |
+
source=seqio.TfdsDataSource(
|
1247 |
+
tfds_name="idl_words:1.0.0",
|
1248 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1249 |
+
),
|
1250 |
+
preprocessors=[],
|
1251 |
+
style="idl_words",
|
1252 |
+
)
|
1253 |
+
|
1254 |
+
|
1255 |
+
|
1256 |
+
open_image_v6_keys_to_features = {
|
1257 |
+
'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
|
1258 |
+
'image_id': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
|
1259 |
+
'detection/label':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
|
1260 |
+
'detection/bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
|
1261 |
+
'detection/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
|
1262 |
+
'vrd/sub_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
|
1263 |
+
'vrd/obj_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
|
1264 |
+
'vrd/sub_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
|
1265 |
+
'vrd/obj_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
|
1266 |
+
'vrd/relation': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
|
1267 |
+
'vrd/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
|
1268 |
+
'cap/cap_caption': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
|
1269 |
+
'cap/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
|
1270 |
+
'seg/masks': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
|
1271 |
+
'seg/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
|
1272 |
+
'seg/label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
|
1273 |
+
'seg/bbox': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
|
1274 |
+
}
|
1275 |
+
|
1276 |
+
|
1277 |
+
add_task(
|
1278 |
+
"localized_narratives_v6",
|
1279 |
+
source=seqio.TFExampleDataSource(
|
1280 |
+
split_to_filepattern={
|
1281 |
+
"train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"),
|
1282 |
+
},
|
1283 |
+
feature_description=open_image_v6_keys_to_features,
|
1284 |
+
),
|
1285 |
+
preprocessors=[extract_localized_narrative],
|
1286 |
+
style="localized_narratives",
|
1287 |
+
)
|
1288 |
+
|
1289 |
+
|
1290 |
+
add_task(
|
1291 |
+
"lvis_objects",
|
1292 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
1293 |
+
source=seqio.TfdsDataSource(
|
1294 |
+
tfds_name="lvis:1.2.0",
|
1295 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1296 |
+
),
|
1297 |
+
preprocessors=[
|
1298 |
+
extract_lvis,
|
1299 |
+
region_captions_to_dense,
|
1300 |
+
],
|
1301 |
+
style="lvis_objects",
|
1302 |
+
)
|
1303 |
+
|
1304 |
+
|
1305 |
+
add_task(
|
1306 |
+
"open_images_with_objects",
|
1307 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
1308 |
+
source=seqio.TFExampleDataSource(
|
1309 |
+
split_to_filepattern={
|
1310 |
+
"train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"),
|
1311 |
+
},
|
1312 |
+
feature_description=open_image_v6_keys_to_features,
|
1313 |
+
),
|
1314 |
+
preprocessors=[
|
1315 |
+
extract_open_images_boxes,
|
1316 |
+
region_captions_to_dense,
|
1317 |
+
],
|
1318 |
+
style="visual_narratives_with_objects",
|
1319 |
+
)
|
1320 |
+
|
1321 |
+
|
1322 |
+
add_task(
|
1323 |
+
"cockatoo_with_acc_476k_gpt_captions",
|
1324 |
+
source=seqio.TfdsDataSource(
|
1325 |
+
tfds_name="cockatoo_with_acc_476k_gpt_captions:1.0.0",
|
1326 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1327 |
+
splits=dict(
|
1328 |
+
train="train[5120:]",
|
1329 |
+
validation="train[:5120]"
|
1330 |
+
)
|
1331 |
+
),
|
1332 |
+
preprocessors=[accuracy_conditioned_joint],
|
1333 |
+
inf_preprocessor=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
|
1334 |
+
style=None
|
1335 |
+
)
|
1336 |
+
|
1337 |
+
|
1338 |
+
add_task(
|
1339 |
+
"dense_caption_eval_with_acc",
|
1340 |
+
source=seqio.TfdsDataSource(
|
1341 |
+
tfds_name="dense_captioning_eval:1.0.0",
|
1342 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1343 |
+
splits=dict(validation="train")
|
1344 |
+
),
|
1345 |
+
preprocessors=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
|
1346 |
+
style="long_caption",
|
1347 |
+
inf_only=True
|
1348 |
+
)
|
1349 |
+
|
1350 |
+
# ************************
|
1351 |
+
# VQA Datasets
|
1352 |
+
# ************************
|
1353 |
+
|
1354 |
+
add_task(
|
1355 |
+
"science_qa_img",
|
1356 |
+
source=seqio.TfdsDataSource(
|
1357 |
+
tfds_name="science_qa:1.0.0",
|
1358 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1359 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1360 |
+
),
|
1361 |
+
preprocessors=[
|
1362 |
+
image_only,
|
1363 |
+
rename(answer_idx="answer"),
|
1364 |
+
build_question_with_hint,
|
1365 |
+
format_multiple_choice_qa
|
1366 |
+
],
|
1367 |
+
style="science_qa",
|
1368 |
+
)
|
1369 |
+
|
1370 |
+
|
1371 |
+
add_task(
|
1372 |
+
"tabwmp_da",
|
1373 |
+
source=seqio.TfdsDataSource(
|
1374 |
+
tfds_name="tab_mwp:1.0.0",
|
1375 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1376 |
+
splits={"train": "train", "validation": "dev", "test": "test"}
|
1377 |
+
),
|
1378 |
+
preprocessors=[
|
1379 |
+
rename(text="answer")
|
1380 |
+
],
|
1381 |
+
style="tabwmp_da",
|
1382 |
+
)
|
1383 |
+
|
1384 |
+
|
1385 |
+
add_task(
|
1386 |
+
"figure_qa",
|
1387 |
+
source=seqio.TfdsDataSource(
|
1388 |
+
tfds_name="figure_qa:1.0.2",
|
1389 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1390 |
+
splits={"train": "train1", "validation": "validation1", "test": "no_annot_test1"}
|
1391 |
+
),
|
1392 |
+
preprocessors=[extract_figureqa, extract_individual_vqa],
|
1393 |
+
style="figure_qa",
|
1394 |
+
)
|
1395 |
+
|
1396 |
+
add_task(
|
1397 |
+
"figure_qa_zero_shot",
|
1398 |
+
source=seqio.TfdsDataSource(
|
1399 |
+
tfds_name="figure_qa:1.0.2",
|
1400 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1401 |
+
splits={"train": "train1", "validation": "validation1", "test": "no_annot_test1"}
|
1402 |
+
),
|
1403 |
+
preprocessors=[extract_figureqa, convert_figureqa_answer, extract_individual_vqa],
|
1404 |
+
style="figure_qa",
|
1405 |
+
)
|
1406 |
+
|
1407 |
+
|
1408 |
+
add_task(
|
1409 |
+
"plot_qa",
|
1410 |
+
source=seqio.TfdsDataSource(
|
1411 |
+
tfds_name="plot_qa:1.0.0",
|
1412 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1413 |
+
),
|
1414 |
+
preprocessors=[extract_figureqa, extract_individual_vqa],
|
1415 |
+
inf_preprocessor=[
|
1416 |
+
extract_figureqa,
|
1417 |
+
functools.partial(flatten_parts, parts=["questions", "answer", "question_id"]),
|
1418 |
+
extract_individual_vqa
|
1419 |
+
],
|
1420 |
+
style="plot_qa",
|
1421 |
+
)
|
1422 |
+
|
1423 |
+
|
1424 |
+
add_task(
|
1425 |
+
"ai2_diagram",
|
1426 |
+
source=seqio.TfdsDataSource(
|
1427 |
+
tfds_name="ai2_diagram:1.0.2",
|
1428 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1429 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
|
1430 |
+
),
|
1431 |
+
preprocessors=[
|
1432 |
+
rename(choices="answer_texts", answer_idx="correct_answer"),
|
1433 |
+
format_multiple_choice_qa
|
1434 |
+
],
|
1435 |
+
style="ai2_diagram",
|
1436 |
+
)
|
1437 |
+
|
1438 |
+
|
1439 |
+
add_task(
|
1440 |
+
"ai2_diagram_v2",
|
1441 |
+
source=seqio.TfdsDataSource(
|
1442 |
+
tfds_name="ai2_diagram_v2:1.0.1",
|
1443 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1444 |
+
),
|
1445 |
+
preprocessors=[
|
1446 |
+
rename(choices="answer_texts", answer_idx="correct_answer"),
|
1447 |
+
format_ai2d
|
1448 |
+
],
|
1449 |
+
style="ai2_diagram",
|
1450 |
+
)
|
1451 |
+
|
1452 |
+
|
1453 |
+
add_task(
|
1454 |
+
"ai2_diagram_v2_transparent",
|
1455 |
+
source=seqio.TfdsDataSource(
|
1456 |
+
tfds_name="ai2_diagram_v2_transparent:1.0.5",
|
1457 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1458 |
+
),
|
1459 |
+
preprocessors=[
|
1460 |
+
rename(choices="answer_texts", answer_idx="correct_answer"),
|
1461 |
+
format_ai2d
|
1462 |
+
],
|
1463 |
+
style="ai2_diagram",
|
1464 |
+
)
|
1465 |
+
|
1466 |
+
# ai2_diagram_v2 mixed with addiitonal abc label questions with transparent box.
|
1467 |
+
# Shares the same image split as ai2_diagram_v2.
|
1468 |
+
add_task(
|
1469 |
+
"ai2_diagram_v2_mix_transparent",
|
1470 |
+
source=seqio.TfdsDataSource(
|
1471 |
+
tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
|
1472 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1473 |
+
splits={
|
1474 |
+
"train": "train_mix",
|
1475 |
+
"validation": "validation_mix",
|
1476 |
+
"test": "test_mix", # test should only use either transparent or opaque
|
1477 |
+
# "test": "test_opaque",
|
1478 |
+
}
|
1479 |
+
),
|
1480 |
+
preprocessors=[
|
1481 |
+
rename(choices="answer_texts", answer_idx="correct_answer"),
|
1482 |
+
format_ai2d
|
1483 |
+
],
|
1484 |
+
style="ai2_diagram",
|
1485 |
+
)
|
1486 |
+
|
1487 |
+
add_task(
|
1488 |
+
"ai2_diagram_v2_mix_transparent_one_style",
|
1489 |
+
source=seqio.TfdsDataSource(
|
1490 |
+
tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
|
1491 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1492 |
+
splits={
|
1493 |
+
"train": "train_mix",
|
1494 |
+
"validation": "validation_mix",
|
1495 |
+
"test": "test_mix", # test should only use either transparent or opaque
|
1496 |
+
# "test": "test_opaque",
|
1497 |
+
}
|
1498 |
+
),
|
1499 |
+
preprocessors=[
|
1500 |
+
rename(choices="answer_texts", answer_idx="correct_answer"),
|
1501 |
+
functools.partial(format_ai2d, variable_style=False),
|
1502 |
+
],
|
1503 |
+
style="ai2_diagram",
|
1504 |
+
)
|
1505 |
+
|
1506 |
+
|
1507 |
+
for src, test_sets in [
|
1508 |
+
["refclef_unc", ["testA", "testB", "testC", "testAB", "testBC"]],
|
1509 |
+
["refcoco_unc", ["testA", "testB"]],
|
1510 |
+
["refcocoplus_unc", ["testA", "testB"]],
|
1511 |
+
["refcocog_umd", ["test"]],
|
1512 |
+
]:
|
1513 |
+
if "coco" in src:
|
1514 |
+
add_url = [add_coco_url]
|
1515 |
+
else:
|
1516 |
+
add_url = []
|
1517 |
+
splits = {x: x for x in test_sets}
|
1518 |
+
splits.update({"train": "train", "validation": "val"})
|
1519 |
+
add_task(
|
1520 |
+
src,
|
1521 |
+
source=seqio.TfdsDataSource(
|
1522 |
+
tfds_name=f"{src}:1.0.2",
|
1523 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1524 |
+
splits=splits
|
1525 |
+
),
|
1526 |
+
preprocessors=[refexp],
|
1527 |
+
inf_preprocessor=add_url + [
|
1528 |
+
refexp_inf,
|
1529 |
+
# Flatten objects
|
1530 |
+
functools.partial(flatten_parts, parts=["refexp", "metadata/bbox"]),
|
1531 |
+
# Flatten expressions
|
1532 |
+
functools.partial(flatten_parts, parts=["refexp"])
|
1533 |
+
],
|
1534 |
+
style="refexp",
|
1535 |
+
decode_image=True,
|
1536 |
+
)
|
1537 |
+
add_task(
|
1538 |
+
src + "_pointing",
|
1539 |
+
source=seqio.TfdsDataSource(
|
1540 |
+
tfds_name=f"{src}:1.0.2",
|
1541 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1542 |
+
splits=splits
|
1543 |
+
),
|
1544 |
+
preprocessors=[refexp_pointing],
|
1545 |
+
inf_preprocessor=add_url + [
|
1546 |
+
refexp_pointing_inf,
|
1547 |
+
functools.partial(flatten_parts, parts=["refexp", "metadata/bbox", "metadata/mask", "metadata/answer"]),
|
1548 |
+
functools.partial(flatten_parts, parts=["refexp"])
|
1549 |
+
],
|
1550 |
+
decode_image=True,
|
1551 |
+
style="refexp_pointing",
|
1552 |
+
)
|
1553 |
+
|
1554 |
+
|
1555 |
+
# FIXME
|
1556 |
+
add_task(
|
1557 |
+
"ai2_diagram_test",
|
1558 |
+
source=seqio.TfdsDataSource(
|
1559 |
+
tfds_name="ai2_diagram:1.0.2",
|
1560 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1561 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
|
1562 |
+
),
|
1563 |
+
preprocessors=[
|
1564 |
+
rename(choices="answer_texts", answer_idx="correct_answer"),
|
1565 |
+
format_multiple_choice_qa
|
1566 |
+
],
|
1567 |
+
style="ai2_diagram",
|
1568 |
+
)
|
1569 |
+
|
1570 |
+
|
1571 |
+
add_task(
|
1572 |
+
"gqa",
|
1573 |
+
source=seqio.TfdsDataSource(
|
1574 |
+
tfds_name="gqa:1.0.1",
|
1575 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1576 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1577 |
+
),
|
1578 |
+
preprocessors=[
|
1579 |
+
functools.partial(format_gqa, is_balanced=True),
|
1580 |
+
extract_individual_vqa,
|
1581 |
+
],
|
1582 |
+
inf_preprocessor=[
|
1583 |
+
functools.partial(format_gqa, is_balanced=True),
|
1584 |
+
extract_individual_vqa,
|
1585 |
+
],
|
1586 |
+
style="gqa",
|
1587 |
+
)
|
1588 |
+
|
1589 |
+
|
1590 |
+
add_task(
|
1591 |
+
"gqa_multi",
|
1592 |
+
source=seqio.TfdsDataSource(
|
1593 |
+
tfds_name="gqa:1.0.1",
|
1594 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1595 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1596 |
+
),
|
1597 |
+
preprocessors=[
|
1598 |
+
functools.partial(format_gqa, is_balanced=True, flatten=False),
|
1599 |
+
extract_individual_vqa,
|
1600 |
+
],
|
1601 |
+
inf_preprocessor=[
|
1602 |
+
functools.partial(format_gqa, is_balanced=True, flatten=False),
|
1603 |
+
extract_individual_vqa,
|
1604 |
+
],
|
1605 |
+
style="gqa",
|
1606 |
+
)
|
1607 |
+
|
1608 |
+
|
1609 |
+
add_task(
|
1610 |
+
"text_vqa",
|
1611 |
+
source=seqio.TfdsDataSource(
|
1612 |
+
tfds_name="text_vqa:1.0.3",
|
1613 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1614 |
+
),
|
1615 |
+
preprocessors=[
|
1616 |
+
functools.partial(
|
1617 |
+
rekey, key_map={
|
1618 |
+
"image": ["image"],
|
1619 |
+
"questions": ["question"],
|
1620 |
+
"answers": ["answers"],
|
1621 |
+
"id": ["question_id"]
|
1622 |
+
}),
|
1623 |
+
extract_individual_vqa,
|
1624 |
+
],
|
1625 |
+
style="text_vqa",
|
1626 |
+
)
|
1627 |
+
|
1628 |
+
|
1629 |
+
add_task(
|
1630 |
+
"okvqa",
|
1631 |
+
source=seqio.TfdsDataSource(
|
1632 |
+
tfds_name="ok_vqa:1.0.2",
|
1633 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1634 |
+
),
|
1635 |
+
preprocessors=[
|
1636 |
+
rename(example_id="question_id"),
|
1637 |
+
add_coco_url,
|
1638 |
+
extract_individual_vqa,
|
1639 |
+
],
|
1640 |
+
style="okvqa",
|
1641 |
+
)
|
1642 |
+
|
1643 |
+
add_task(
|
1644 |
+
"a_okvqa_da",
|
1645 |
+
source=seqio.TfdsDataSource(
|
1646 |
+
tfds_name="a_ok_vqa:1.0.2",
|
1647 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1648 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1649 |
+
),
|
1650 |
+
preprocessors=[
|
1651 |
+
rename(**{
|
1652 |
+
"example_id": "question_id",
|
1653 |
+
"answers": "direct_answers",
|
1654 |
+
"metadata/difficult_direct_answer": "difficult_direct_answer"
|
1655 |
+
}),
|
1656 |
+
extract_individual_vqa,
|
1657 |
+
],
|
1658 |
+
inf_preprocessor=[
|
1659 |
+
filter_difficult_direct_answer,
|
1660 |
+
rename(**{
|
1661 |
+
"example_id": "question_id",
|
1662 |
+
"answers": "direct_answers",
|
1663 |
+
"metadata/difficult_direct_answer": "difficult_direct_answer"
|
1664 |
+
}),
|
1665 |
+
add_coco_url,
|
1666 |
+
extract_individual_vqa,
|
1667 |
+
],
|
1668 |
+
style="a_okvqa_da",
|
1669 |
+
)
|
1670 |
+
|
1671 |
+
|
1672 |
+
add_task(
|
1673 |
+
"a_okvqa_mc",
|
1674 |
+
source=seqio.TfdsDataSource(
|
1675 |
+
tfds_name="a_ok_vqa:1.0.2",
|
1676 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1677 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1678 |
+
),
|
1679 |
+
preprocessors=[
|
1680 |
+
rename(**{
|
1681 |
+
"example_id": "question_id",
|
1682 |
+
"metadata/difficult_direct_answer": "difficult_direct_answer",
|
1683 |
+
"answer_idx": "correct_choice_idx"
|
1684 |
+
}),
|
1685 |
+
add_coco_url,
|
1686 |
+
format_multiple_choice_qa,
|
1687 |
+
],
|
1688 |
+
style="a_okvqa_mc",
|
1689 |
+
)
|
1690 |
+
|
1691 |
+
|
1692 |
+
add_task(
|
1693 |
+
"dv_qa",
|
1694 |
+
source=seqio.TfdsDataSource(
|
1695 |
+
tfds_name="dv_qa:1.0.0",
|
1696 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1697 |
+
splits={"train": "train", "validation": "val_easy"}
|
1698 |
+
),
|
1699 |
+
preprocessors=[
|
1700 |
+
extract_figureqa,
|
1701 |
+
extract_individual_vqa,
|
1702 |
+
],
|
1703 |
+
inf_preprocessor=[
|
1704 |
+
extract_figureqa,
|
1705 |
+
flatten_vqa,
|
1706 |
+
extract_individual_vqa
|
1707 |
+
],
|
1708 |
+
style="dv_qa",
|
1709 |
+
)
|
1710 |
+
|
1711 |
+
|
1712 |
+
@seqio.map_over_dataset
|
1713 |
+
def add_image_question_example_id(ex):
|
1714 |
+
key = tf.strings.join([ex["question"], "\n\n", ex["image"]])
|
1715 |
+
ex["metadata/example_id"] = tf.strings.to_hash_bucket(key, 2**30)
|
1716 |
+
return ex
|
1717 |
+
|
1718 |
+
|
1719 |
+
add_task(
|
1720 |
+
"chart_qa",
|
1721 |
+
source=seqio.TfdsDataSource(
|
1722 |
+
tfds_name="chart_qa:1.0.2",
|
1723 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1724 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1725 |
+
),
|
1726 |
+
preprocessors=[
|
1727 |
+
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
|
1728 |
+
add_image_question_example_id,
|
1729 |
+
extract_individual_vqa,
|
1730 |
+
],
|
1731 |
+
style="chart_qa",
|
1732 |
+
)
|
1733 |
+
|
1734 |
+
|
1735 |
+
add_task(
|
1736 |
+
"chart_qa_ex",
|
1737 |
+
source=seqio.TfdsDataSource(
|
1738 |
+
tfds_name="chart_qa:1.0.2",
|
1739 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1740 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1741 |
+
),
|
1742 |
+
preprocessors=[
|
1743 |
+
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
|
1744 |
+
extract_individual_vqa,
|
1745 |
+
],
|
1746 |
+
style="scifi_charts_qa_exp",
|
1747 |
+
)
|
1748 |
+
|
1749 |
+
|
1750 |
+
add_task(
|
1751 |
+
"chart_qa_weighted",
|
1752 |
+
source=seqio.TfdsDataSource(
|
1753 |
+
tfds_name="chart_qa:1.0.2",
|
1754 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1755 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1756 |
+
),
|
1757 |
+
preprocessors=[
|
1758 |
+
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
|
1759 |
+
extract_individual_vqa,
|
1760 |
+
functools.partial(reweight_chartqa, human=2*20901/(20901+7398), aug=2*7398/(20901+7398)),
|
1761 |
+
],
|
1762 |
+
style="chart_qa",
|
1763 |
+
)
|
1764 |
+
|
1765 |
+
|
1766 |
+
add_task(
|
1767 |
+
"chart_qa_human",
|
1768 |
+
source=seqio.TfdsDataSource(
|
1769 |
+
tfds_name="chart_qa:1.0.2",
|
1770 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1771 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1772 |
+
),
|
1773 |
+
preprocessors=[
|
1774 |
+
rename(question="query", answer="label"),
|
1775 |
+
add_image_question_example_id,
|
1776 |
+
filter_human,
|
1777 |
+
extract_individual_vqa,
|
1778 |
+
],
|
1779 |
+
style="chart_qa",
|
1780 |
+
)
|
1781 |
+
|
1782 |
+
|
1783 |
+
add_task(
|
1784 |
+
"chart_qa_aug",
|
1785 |
+
source=seqio.TfdsDataSource(
|
1786 |
+
tfds_name="chart_qa:1.0.2",
|
1787 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1788 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1789 |
+
),
|
1790 |
+
preprocessors=[
|
1791 |
+
rename(question="query", answer="label"),
|
1792 |
+
filter_aug,
|
1793 |
+
extract_individual_vqa,
|
1794 |
+
],
|
1795 |
+
style="chart_qa",
|
1796 |
+
)
|
1797 |
+
|
1798 |
+
|
1799 |
+
add_task(
|
1800 |
+
"doc_qa",
|
1801 |
+
source=seqio.TfdsDataSource(
|
1802 |
+
tfds_name="doc_qa:1.0.1",
|
1803 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1804 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1805 |
+
),
|
1806 |
+
preprocessors=[fix_doqa_url, extract_individual_vqa],
|
1807 |
+
style="doc_qa",
|
1808 |
+
)
|
1809 |
+
|
1810 |
+
|
1811 |
+
add_task(
|
1812 |
+
"ocr_qa",
|
1813 |
+
source=seqio.TfdsDataSource(
|
1814 |
+
tfds_name="ocr_vqa:1.0.0",
|
1815 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1816 |
+
),
|
1817 |
+
preprocessors=[extract_individual_vqa],
|
1818 |
+
inf_preprocessor=[flatten_vqa, extract_individual_vqa],
|
1819 |
+
style="ocr_vqa",
|
1820 |
+
)
|
1821 |
+
|
1822 |
+
|
1823 |
+
add_task(
|
1824 |
+
"st_qa",
|
1825 |
+
source=seqio.TfdsDataSource(
|
1826 |
+
tfds_name="st_vqa:1.0.2",
|
1827 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1828 |
+
splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
|
1829 |
+
),
|
1830 |
+
preprocessors=[extract_individual_vqa],
|
1831 |
+
inf_preprocessor=[extract_individual_vqa],
|
1832 |
+
style="st_qa",
|
1833 |
+
)
|
1834 |
+
|
1835 |
+
|
1836 |
+
add_task(
|
1837 |
+
"tally_qa",
|
1838 |
+
source=seqio.TfdsDataSource(
|
1839 |
+
tfds_name="tally_qa:1.0.2",
|
1840 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1841 |
+
splits={"train": "train", "validation": "test"}
|
1842 |
+
),
|
1843 |
+
preprocessors=[
|
1844 |
+
extract_tally_qa,
|
1845 |
+
extract_individual_vqa
|
1846 |
+
],
|
1847 |
+
inf_preprocessor=[
|
1848 |
+
extract_tally_qa,
|
1849 |
+
flatten_vqa,
|
1850 |
+
extract_individual_vqa
|
1851 |
+
],
|
1852 |
+
style="tally_qa",
|
1853 |
+
)
|
1854 |
+
|
1855 |
+
|
1856 |
+
add_task(
|
1857 |
+
"info_qa",
|
1858 |
+
source=seqio.TfdsDataSource(
|
1859 |
+
tfds_name="info_qa:1.0.0",
|
1860 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1861 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1862 |
+
),
|
1863 |
+
preprocessors=[extract_individual_vqa],
|
1864 |
+
style="info_qa",
|
1865 |
+
)
|
1866 |
+
|
1867 |
+
add_task(
|
1868 |
+
"android_control",
|
1869 |
+
source=seqio.TfdsDataSource(
|
1870 |
+
tfds_name="android_control:2.0.0",
|
1871 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1872 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1873 |
+
),
|
1874 |
+
preprocessors=[extract_android_control],
|
1875 |
+
style="android_control",
|
1876 |
+
)
|
1877 |
+
|
1878 |
+
for mode in ["ll", "hl", "hl_ll", "hl_cot"]:
|
1879 |
+
add_task(
|
1880 |
+
f"android_control_{mode}",
|
1881 |
+
source=seqio.TfdsDataSource(
|
1882 |
+
tfds_name="android_control:2.0.0",
|
1883 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1884 |
+
splits={"train": "train", "validation": "val", "test": "test"}
|
1885 |
+
),
|
1886 |
+
preprocessors=[functools.partial(extract_andriod_control_inf, mode=mode)],
|
1887 |
+
style="android_control",
|
1888 |
+
)
|
1889 |
+
|
1890 |
+
|
1891 |
+
map_coco_vqa = functools.partial(rekey, key_map={
|
1892 |
+
"image": ["image"],
|
1893 |
+
"questions": ["vqa", "questions"],
|
1894 |
+
"answers": ["vqa", "answers"],
|
1895 |
+
"id": ["vqa", "id"],
|
1896 |
+
"metadata/image_url": ["metadata/image_url"],
|
1897 |
+
})
|
1898 |
+
|
1899 |
+
|
1900 |
+
add_task(
|
1901 |
+
"coco_2017_vqa",
|
1902 |
+
source=seqio.TfdsDataSource(
|
1903 |
+
tfds_name="coco_all:1.0.1",
|
1904 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1905 |
+
),
|
1906 |
+
preprocessors=[
|
1907 |
+
add_coco_url,
|
1908 |
+
map_coco_vqa,
|
1909 |
+
flatten_vqa,
|
1910 |
+
extract_individual_vqa
|
1911 |
+
],
|
1912 |
+
style="vqa2",
|
1913 |
+
)
|
1914 |
+
|
1915 |
+
|
1916 |
+
add_task(
|
1917 |
+
"cockatoo_qa",
|
1918 |
+
source=seqio.TfdsDataSource(
|
1919 |
+
tfds_name="cockatoo_qa:1.0.0",
|
1920 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1921 |
+
splits=dict(
|
1922 |
+
train="train[5120:]",
|
1923 |
+
validation="train[:5120]"
|
1924 |
+
)
|
1925 |
+
),
|
1926 |
+
preprocessors=[rename(text="answer")],
|
1927 |
+
style=None,
|
1928 |
+
)
|
1929 |
+
|
1930 |
+
|
1931 |
+
add_task(
|
1932 |
+
"synthetic_qa_v3",
|
1933 |
+
source=seqio.TfdsDataSource(
|
1934 |
+
tfds_name="synthetic_qa_v3:0.0.4",
|
1935 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1936 |
+
splits=dict(
|
1937 |
+
train="train[2048:]",
|
1938 |
+
validation="train[:2048]"
|
1939 |
+
)
|
1940 |
+
),
|
1941 |
+
preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
|
1942 |
+
style="synthetic_qa",
|
1943 |
+
)
|
1944 |
+
|
1945 |
+
|
1946 |
+
add_task(
|
1947 |
+
"synthetic_qa_v3_style_tag",
|
1948 |
+
source=seqio.TfdsDataSource(
|
1949 |
+
tfds_name="synthetic_qa_v3:0.0.4",
|
1950 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1951 |
+
splits=dict(
|
1952 |
+
train="train[2048:]",
|
1953 |
+
validation="train[:2048]"
|
1954 |
+
)
|
1955 |
+
),
|
1956 |
+
preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
|
1957 |
+
style="llm_qa",
|
1958 |
+
)
|
1959 |
+
|
1960 |
+
|
1961 |
+
add_task(
|
1962 |
+
"synthetic_qa_v3_as_user_qa",
|
1963 |
+
source=seqio.TfdsDataSource(
|
1964 |
+
tfds_name="synthetic_qa_v3:0.0.4",
|
1965 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1966 |
+
splits=dict(
|
1967 |
+
train="train[2048:]",
|
1968 |
+
validation="train[:2048]"
|
1969 |
+
)
|
1970 |
+
),
|
1971 |
+
preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
|
1972 |
+
style="user_qa",
|
1973 |
+
)
|
1974 |
+
|
1975 |
+
|
1976 |
+
add_task(
|
1977 |
+
"synthetic_qa_v3_multi_turn",
|
1978 |
+
source=seqio.TfdsDataSource(
|
1979 |
+
tfds_name="synthetic_qa_v3:0.0.4",
|
1980 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1981 |
+
splits=dict(
|
1982 |
+
train="train[2048:]",
|
1983 |
+
validation="train[:2048]"
|
1984 |
+
)
|
1985 |
+
),
|
1986 |
+
preprocessors=[extract_cockatoo_qa_v2, filter_single_turn, prefix_how_many_messages],
|
1987 |
+
style="synthetic_qa",
|
1988 |
+
)
|
1989 |
+
|
1990 |
+
|
1991 |
+
NE_SHARDS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
|
1992 |
+
|
1993 |
+
for i in NE_SHARDS:
|
1994 |
+
add_task(
|
1995 |
+
f"named_entity{i}",
|
1996 |
+
source=seqio.TfdsDataSource(
|
1997 |
+
tfds_name=f"named_entities_qa_{i}_of_18:1.0.0",
|
1998 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
1999 |
+
splits=dict(
|
2000 |
+
train="train[1024:]",
|
2001 |
+
validation="train[:1024]"
|
2002 |
+
)
|
2003 |
+
),
|
2004 |
+
preprocessors=[filter_named_entity, extract_named_entity, extract_individual_vqa],
|
2005 |
+
inf_preprocessor=[
|
2006 |
+
filter_named_entity,
|
2007 |
+
extract_named_entity,
|
2008 |
+
flatten_vqa,
|
2009 |
+
extract_individual_vqa
|
2010 |
+
],
|
2011 |
+
style="named_entity",
|
2012 |
+
ignore_errors=True
|
2013 |
+
)
|
2014 |
+
|
2015 |
+
|
2016 |
+
add_task(
|
2017 |
+
"user_qa",
|
2018 |
+
source=seqio.TfdsDataSource(
|
2019 |
+
tfds_name="user_qa:0.0.1",
|
2020 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2021 |
+
splits=dict(
|
2022 |
+
train="train[2048:]",
|
2023 |
+
validation="train[:2048]"
|
2024 |
+
)
|
2025 |
+
),
|
2026 |
+
preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
|
2027 |
+
style="user_qa",
|
2028 |
+
)
|
2029 |
+
|
2030 |
+
add_task(
|
2031 |
+
"user_questions_for_elo",
|
2032 |
+
source=seqio.TfdsDataSource(
|
2033 |
+
tfds_name="user_questions_for_elo:0.0.3",
|
2034 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2035 |
+
),
|
2036 |
+
preprocessors=[functools.partial(extract_individual_vqa, test=True)],
|
2037 |
+
inf_only=True,
|
2038 |
+
style="demo",
|
2039 |
+
)
|
2040 |
+
|
2041 |
+
|
2042 |
+
def _filter_by_id(ds, prediction_file, max_seq_len):
|
2043 |
+
with open(prediction_file) as f:
|
2044 |
+
predictions = json.load(f)
|
2045 |
+
is_long = []
|
2046 |
+
lens = []
|
2047 |
+
tokenizer = build_tokenizer("hf-Qwen/Qwen2-7B")
|
2048 |
+
for pred in predictions:
|
2049 |
+
n_tokens = len(tokenizer.encode(pred["prediction"]))
|
2050 |
+
lens.append(n_tokens)
|
2051 |
+
if n_tokens >= max_seq_len:
|
2052 |
+
is_long.append(pred["example_id"])
|
2053 |
+
is_long = tf.constant(is_long)
|
2054 |
+
logging.info(f"Filtering for {len(is_long)} ids")
|
2055 |
+
return ds.filter(lambda ex: tf.reduce_any(ex["example_id"] == is_long))
|
2056 |
+
|
2057 |
+
|
2058 |
+
|
2059 |
+
add_task(
|
2060 |
+
"user_questions_for_elo",
|
2061 |
+
source=seqio.TfdsDataSource(
|
2062 |
+
tfds_name="user_questions_for_elo:0.0.3",
|
2063 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2064 |
+
),
|
2065 |
+
preprocessors=[functools.partial(extract_individual_vqa, test=True)],
|
2066 |
+
inf_only=True,
|
2067 |
+
style="demo",
|
2068 |
+
)
|
2069 |
+
|
2070 |
+
|
2071 |
+
add_task(
|
2072 |
+
"user_questions_for_elo_long",
|
2073 |
+
source=seqio.TfdsDataSource(
|
2074 |
+
tfds_name="user_questions_for_elo:0.0.3",
|
2075 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2076 |
+
),
|
2077 |
+
preprocessors=[
|
2078 |
+
functools.partial(_filter_by_id, prediction_file="/weka/oe-training-default/chrisc/cockatoo/models/uber-model-v11/70b-335-30k-3.2-resume8k-noopt/predictions-ck20000-user_questions_for_elo-test/predictions.json", max_seq_len=230),
|
2079 |
+
functools.partial(extract_individual_vqa, test=True)
|
2080 |
+
],
|
2081 |
+
inf_only=True,
|
2082 |
+
style="demo",
|
2083 |
+
)
|
2084 |
+
|
2085 |
+
|
2086 |
+
add_task(
|
2087 |
+
"coco_2014_vqa",
|
2088 |
+
source=seqio.TfdsDataSource(
|
2089 |
+
tfds_name="coco_2014_all:1.0.1",
|
2090 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2091 |
+
),
|
2092 |
+
preprocessors=[
|
2093 |
+
add_coco_url,
|
2094 |
+
map_coco_vqa,
|
2095 |
+
flatten_vqa,
|
2096 |
+
extract_individual_vqa
|
2097 |
+
],
|
2098 |
+
inf_preprocessor=[
|
2099 |
+
add_coco_url,
|
2100 |
+
map_coco_vqa,
|
2101 |
+
flatten_vqa,
|
2102 |
+
extract_individual_vqa
|
2103 |
+
],
|
2104 |
+
style="vqa2",
|
2105 |
+
)
|
2106 |
+
|
2107 |
+
|
2108 |
+
add_task(
|
2109 |
+
"coco_2014_vqa_multi",
|
2110 |
+
source=seqio.TfdsDataSource(
|
2111 |
+
tfds_name="coco_2014_all:1.0.1",
|
2112 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2113 |
+
),
|
2114 |
+
preprocessors=[
|
2115 |
+
add_coco_url,
|
2116 |
+
map_coco_vqa,
|
2117 |
+
extract_individual_vqa
|
2118 |
+
],
|
2119 |
+
inf_preprocessor=[
|
2120 |
+
add_coco_url,
|
2121 |
+
map_coco_vqa,
|
2122 |
+
flatten_vqa,
|
2123 |
+
extract_individual_vqa
|
2124 |
+
],
|
2125 |
+
style="vqa2",
|
2126 |
+
)
|
2127 |
+
|
2128 |
+
|
2129 |
+
add_task(
|
2130 |
+
"coco_2017_vqa_multi",
|
2131 |
+
source=seqio.TfdsDataSource(
|
2132 |
+
tfds_name="coco_all:1.0.1",
|
2133 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2134 |
+
),
|
2135 |
+
preprocessors=[
|
2136 |
+
add_coco_url,
|
2137 |
+
map_coco_vqa,
|
2138 |
+
extract_individual_vqa
|
2139 |
+
],
|
2140 |
+
inf_preprocessor=[
|
2141 |
+
add_coco_url,
|
2142 |
+
map_coco_vqa,
|
2143 |
+
flatten_vqa,
|
2144 |
+
extract_individual_vqa
|
2145 |
+
],
|
2146 |
+
style="vqa2",
|
2147 |
+
)
|
2148 |
+
|
2149 |
+
|
2150 |
+
add_task(
|
2151 |
+
"vqa_v2_test",
|
2152 |
+
source=seqio.TfdsDataSource(
|
2153 |
+
tfds_name="coco_test_all:1.0.1",
|
2154 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2155 |
+
),
|
2156 |
+
preprocessors=[
|
2157 |
+
functools.partial(rekey, key_map={
|
2158 |
+
"image": ["image"],
|
2159 |
+
"questions": ["vqa", "questions"],
|
2160 |
+
"answers": ["vqa", "answers"],
|
2161 |
+
"id": ["vqa", "id"],
|
2162 |
+
}),
|
2163 |
+
flatten_vqa,
|
2164 |
+
functools.partial(extract_individual_vqa, test=True)
|
2165 |
+
],
|
2166 |
+
style="vqa2",
|
2167 |
+
inf_only=True
|
2168 |
+
)
|
2169 |
+
|
2170 |
+
# ************************
|
2171 |
+
# Eval-only Datasets
|
2172 |
+
# ************************
|
2173 |
+
|
2174 |
+
add_task(
|
2175 |
+
"seed_bench_test",
|
2176 |
+
source=seqio.TfdsDataSource(
|
2177 |
+
tfds_name="seed_bench:1.0.0",
|
2178 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2179 |
+
),
|
2180 |
+
preprocessors=[
|
2181 |
+
format_multiple_choice_qa,
|
2182 |
+
],
|
2183 |
+
style="a_okvqa_mc",
|
2184 |
+
inf_only=True
|
2185 |
+
)
|
2186 |
+
|
2187 |
+
|
2188 |
+
add_task(
|
2189 |
+
"pope_test",
|
2190 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2191 |
+
source=seqio.TfdsDataSource(
|
2192 |
+
tfds_name="pope:1.0.0",
|
2193 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2194 |
+
),
|
2195 |
+
preprocessors=[
|
2196 |
+
add_coco_url,
|
2197 |
+
extract_individual_vqa
|
2198 |
+
],
|
2199 |
+
style="vqa2",
|
2200 |
+
inf_only=True
|
2201 |
+
)
|
2202 |
+
|
2203 |
+
|
2204 |
+
MME_SOURCE = seqio.TfdsDataSource(
|
2205 |
+
tfds_name="mme:1.0.0",
|
2206 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2207 |
+
)
|
2208 |
+
|
2209 |
+
|
2210 |
+
add_task(
|
2211 |
+
"mme_test",
|
2212 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2213 |
+
source=MME_SOURCE,
|
2214 |
+
preprocessors=[
|
2215 |
+
functools.partial(flatten_parts, parts=["questions", "answers"]),
|
2216 |
+
rename(question="questions", answer="answers"),
|
2217 |
+
extract_individual_vqa,
|
2218 |
+
],
|
2219 |
+
style="vqa2",
|
2220 |
+
inf_only=True
|
2221 |
+
)
|
2222 |
+
|
2223 |
+
add_task(
|
2224 |
+
"real_world_qa_test",
|
2225 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2226 |
+
source=seqio.TfdsDataSource(
|
2227 |
+
tfds_name="real_world_qa:1.0.0",
|
2228 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2229 |
+
),
|
2230 |
+
preprocessors=[
|
2231 |
+
functools.partial(
|
2232 |
+
format_multiple_style_qa,
|
2233 |
+
types=['multiple_choice', 'short_answer'],
|
2234 |
+
styles=['a_okvqa_mc', 'vqa2'],
|
2235 |
+
default_style="a_okvqa_mc",
|
2236 |
+
),
|
2237 |
+
],
|
2238 |
+
style=None,
|
2239 |
+
inf_only=True
|
2240 |
+
)
|
2241 |
+
|
2242 |
+
add_task(
|
2243 |
+
"real_world_qa_no_instruction",
|
2244 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2245 |
+
source=seqio.TfdsDataSource(
|
2246 |
+
tfds_name="real_world_qa:1.0.0",
|
2247 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2248 |
+
),
|
2249 |
+
preprocessors=[
|
2250 |
+
functools.partial(
|
2251 |
+
functools.partial(format_multiple_style_qa, strip_instruction=True),
|
2252 |
+
types=['multiple_choice', 'short_answer'],
|
2253 |
+
styles=['a_okvqa_mc', 'vqa2'],
|
2254 |
+
default_style="a_okvqa_mc",
|
2255 |
+
),
|
2256 |
+
],
|
2257 |
+
style=None,
|
2258 |
+
inf_only=True
|
2259 |
+
)
|
2260 |
+
|
2261 |
+
add_task(
|
2262 |
+
"real_world_qa_dbg",
|
2263 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2264 |
+
source=seqio.TfdsDataSource(
|
2265 |
+
tfds_name="real_world_qa:1.0.0",
|
2266 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2267 |
+
),
|
2268 |
+
preprocessors=[
|
2269 |
+
functools.partial(
|
2270 |
+
format_multiple_style_qa,
|
2271 |
+
types=['multiple_choice', 'short_answer'],
|
2272 |
+
styles=['user_qa', 'user_qa'],
|
2273 |
+
default_style="user_qa",
|
2274 |
+
),
|
2275 |
+
],
|
2276 |
+
style=None,
|
2277 |
+
inf_only=True
|
2278 |
+
)
|
2279 |
+
|
2280 |
+
|
2281 |
+
add_task(
|
2282 |
+
"mmmu",
|
2283 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2284 |
+
source=seqio.TfdsDataSource(
|
2285 |
+
tfds_name="mmmu:1.0.0",
|
2286 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2287 |
+
splits={"train": "dev"},
|
2288 |
+
),
|
2289 |
+
preprocessors=[
|
2290 |
+
rename(img_type="metadata/img_type"),
|
2291 |
+
functools.partial(
|
2292 |
+
extract_mmmu,
|
2293 |
+
types=['multiple-choice', 'open'],
|
2294 |
+
styles=['a_okvqa_mc', 'vqa2'],
|
2295 |
+
default_style="a_okvqa_mc",
|
2296 |
+
),
|
2297 |
+
],
|
2298 |
+
style=None,
|
2299 |
+
)
|
2300 |
+
|
2301 |
+
|
2302 |
+
add_task(
|
2303 |
+
"mmmu_test",
|
2304 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2305 |
+
source=seqio.TfdsDataSource(
|
2306 |
+
tfds_name="mmmu:1.0.0",
|
2307 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2308 |
+
splits={"validation": "validation", "test": "test"},
|
2309 |
+
),
|
2310 |
+
preprocessors=[
|
2311 |
+
rename(img_type="metadata/img_type"),
|
2312 |
+
extract_mmmu,
|
2313 |
+
],
|
2314 |
+
style=None,
|
2315 |
+
inf_only=True
|
2316 |
+
)
|
2317 |
+
|
2318 |
+
for style in ["vaia_qa", "vaia_qa_short_answer_first", "vqa_online", ]:
|
2319 |
+
add_task(
|
2320 |
+
f"mmmu_test_{style}",
|
2321 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2322 |
+
source=seqio.TfdsDataSource(
|
2323 |
+
tfds_name="mmmu:1.0.0",
|
2324 |
+
# tfds_name="mmmu_khan_academy:1.0.1",
|
2325 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2326 |
+
splits={"validation": "validation", "test": "test", "dev": "dev"},
|
2327 |
+
),
|
2328 |
+
preprocessors=[
|
2329 |
+
rename(img_type="metadata/img_type"),
|
2330 |
+
extract_mmmu_cot,
|
2331 |
+
],
|
2332 |
+
style=style,
|
2333 |
+
inf_only=True
|
2334 |
+
)
|
2335 |
+
|
2336 |
+
|
2337 |
+
add_task(
|
2338 |
+
"math_vista_test",
|
2339 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2340 |
+
source=seqio.TfdsDataSource(
|
2341 |
+
tfds_name="math_vista:1.0.0",
|
2342 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2343 |
+
splits={"validation": "testmini", "test": "test"},
|
2344 |
+
),
|
2345 |
+
preprocessors=[
|
2346 |
+
functools.partial(rekey, key_map={
|
2347 |
+
"id": ["id"],
|
2348 |
+
"query": ["query"],
|
2349 |
+
"image": ["image"],
|
2350 |
+
"choices": ["choices"],
|
2351 |
+
"answer": ["answer"],
|
2352 |
+
"metadata/question_type": ["question_type"],
|
2353 |
+
"metadata/answer_type": ["answer_type"],
|
2354 |
+
"metadata/precision": ["precision"],
|
2355 |
+
"metadata/split": ["metadata/split"],
|
2356 |
+
}),
|
2357 |
+
functools.partial(extract_math_vista, styles=['a_okvqa_mc', 'vqa2']),
|
2358 |
+
],
|
2359 |
+
style=None,
|
2360 |
+
inf_only=True
|
2361 |
+
)
|
2362 |
+
|
2363 |
+
|
2364 |
+
add_task(
|
2365 |
+
"math_vista_v2",
|
2366 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2367 |
+
source=seqio.TfdsDataSource(
|
2368 |
+
tfds_name="math_vista:1.0.0",
|
2369 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2370 |
+
splits={"validation": "testmini", "test": "test"},
|
2371 |
+
),
|
2372 |
+
preprocessors=[
|
2373 |
+
functools.partial(rekey, key_map={
|
2374 |
+
"id": ["id"],
|
2375 |
+
"query": ["query"],
|
2376 |
+
"image": ["image"],
|
2377 |
+
"choices": ["choices"],
|
2378 |
+
"answer": ["answer"],
|
2379 |
+
"metadata/question_type": ["question_type"],
|
2380 |
+
"metadata/answer_type": ["answer_type"],
|
2381 |
+
"metadata/precision": ["precision"],
|
2382 |
+
"metadata/split": ["metadata/split"],
|
2383 |
+
}),
|
2384 |
+
reformat_math_vista,
|
2385 |
+
functools.partial(
|
2386 |
+
extract_math_vista,
|
2387 |
+
styles=['a_okvqa_mc', 'vqa2'],
|
2388 |
+
),
|
2389 |
+
],
|
2390 |
+
style=None,
|
2391 |
+
inf_only=True
|
2392 |
+
)
|
2393 |
+
|
2394 |
+
|
2395 |
+
MM_BENCH_SRC = seqio.TfdsDataSource(
|
2396 |
+
tfds_name="mmbench:1.0.0",
|
2397 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2398 |
+
splits={"validation": "dev", "test": "test"},
|
2399 |
+
)
|
2400 |
+
|
2401 |
+
add_task(
|
2402 |
+
"mmbench_test",
|
2403 |
+
source=MM_BENCH_SRC,
|
2404 |
+
preprocessors=[format_mmbench],
|
2405 |
+
style="a_okvqa_mc",
|
2406 |
+
inf_only=True
|
2407 |
+
)
|
2408 |
+
|
2409 |
+
|
2410 |
+
add_task(
|
2411 |
+
"sugar_crepe_test",
|
2412 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2413 |
+
source=seqio.TfdsDataSource(
|
2414 |
+
tfds_name="sugar_crepe:1.0.0",
|
2415 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2416 |
+
),
|
2417 |
+
preprocessors=[
|
2418 |
+
add_coco_url,
|
2419 |
+
functools.partial(flatten_parts, parts=["choices", "answer_idx", "metadata/answer_type"]),
|
2420 |
+
format_multiple_choice_qa,
|
2421 |
+
],
|
2422 |
+
style="a_okvqa_mc",
|
2423 |
+
inf_only=True
|
2424 |
+
)
|
2425 |
+
|
2426 |
+
|
2427 |
+
add_task(
|
2428 |
+
"blink_test",
|
2429 |
+
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
|
2430 |
+
source=seqio.TfdsDataSource(
|
2431 |
+
tfds_name="blink:1.0.0",
|
2432 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2433 |
+
),
|
2434 |
+
preprocessors=[
|
2435 |
+
functools.partial(rekey, key_map={
|
2436 |
+
"id": ["id"],
|
2437 |
+
"question": ["prompt"],
|
2438 |
+
"image": ["image_concat"],
|
2439 |
+
"choices": ["choices"],
|
2440 |
+
"answer_idx": ["answer_idx"],
|
2441 |
+
"metadata/subtask": ["metadata/subtask"],
|
2442 |
+
"metadata/question": ["question"],
|
2443 |
+
}),
|
2444 |
+
format_multiple_choice_qa,
|
2445 |
+
output_options,
|
2446 |
+
],
|
2447 |
+
style="a_okvqa_mc",
|
2448 |
+
inf_only=True
|
2449 |
+
)
|
2450 |
+
|
2451 |
+
add_task(
|
2452 |
+
"oscarbench_qa",
|
2453 |
+
source=seqio.TfdsDataSource(
|
2454 |
+
tfds_name="oscarbench_qa:1.0.0",
|
2455 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2456 |
+
splits={"validation": "val"}
|
2457 |
+
),
|
2458 |
+
preprocessors=[oscar_preprocessor],
|
2459 |
+
style="oscarbench_qa"
|
2460 |
+
|
2461 |
+
)
|
2462 |
+
|
2463 |
+
add_task(
|
2464 |
+
"charxiv",
|
2465 |
+
source=seqio.TfdsDataSource(
|
2466 |
+
tfds_name="charxiv:1.0.0",
|
2467 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2468 |
+
splits={"validation": "validation", "test": "test"}
|
2469 |
+
),
|
2470 |
+
preprocessors=[charxiv_preprocessor, extract_individual_vqa],
|
2471 |
+
inf_preprocessor=[
|
2472 |
+
charxiv_preprocessor,
|
2473 |
+
functools.partial(flatten_parts, parts=["question", "answer"]),
|
2474 |
+
extract_individual_vqa,
|
2475 |
+
],
|
2476 |
+
style="charxiv",
|
2477 |
+
)
|
2478 |
+
|
2479 |
+
add_task(
|
2480 |
+
"charxiv_descriptive",
|
2481 |
+
source=seqio.TfdsDataSource(
|
2482 |
+
tfds_name="charxiv:1.0.0",
|
2483 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2484 |
+
splits={"validation": "validation", "test": "test"}
|
2485 |
+
),
|
2486 |
+
preprocessors=[charxiv_descriptive_preprocessor, extract_individual_vqa],
|
2487 |
+
inf_preprocessor=[
|
2488 |
+
charxiv_descriptive_preprocessor,
|
2489 |
+
functools.partial(flatten_parts, parts=["question", "answer"]),
|
2490 |
+
extract_individual_vqa,
|
2491 |
+
],
|
2492 |
+
style="charxiv_descriptive",
|
2493 |
+
)
|
2494 |
+
|
2495 |
+
add_task(
|
2496 |
+
"charxiv_reasoning",
|
2497 |
+
source=seqio.TfdsDataSource(
|
2498 |
+
tfds_name="charxiv:1.0.0",
|
2499 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2500 |
+
splits={"validation": "validation", "test": "test"}
|
2501 |
+
),
|
2502 |
+
preprocessors=[charxiv_reasoning_preprocessor, extract_individual_vqa],
|
2503 |
+
style="charxiv_reasoning",
|
2504 |
+
)
|
2505 |
+
|
2506 |
+
for tablevqa_name in ["fintabnetqa", "vwtq", "vwtq_syn"]:
|
2507 |
+
add_task(
|
2508 |
+
tablevqa_name,
|
2509 |
+
source=seqio.TfdsDataSource(
|
2510 |
+
tfds_name=f"{tablevqa_name}:1.0.0",
|
2511 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2512 |
+
splits={"validation": "test[:125]", "test": "test"}
|
2513 |
+
),
|
2514 |
+
preprocessors=[tablevqa_preprocessor, extract_individual_vqa],
|
2515 |
+
style=tablevqa_name,
|
2516 |
+
)
|
2517 |
+
|
2518 |
+
add_task(
|
2519 |
+
"vtabfact",
|
2520 |
+
source=seqio.TfdsDataSource(
|
2521 |
+
tfds_name="vtabfact:1.0.0",
|
2522 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2523 |
+
splits={"validation": "test[:125]", "test": "test"}
|
2524 |
+
),
|
2525 |
+
preprocessors=[vtabfact_preprocessor, extract_individual_vqa],
|
2526 |
+
style="vtabfact",
|
2527 |
+
)
|
2528 |
+
|
2529 |
+
add_task(
|
2530 |
+
"nutrition_fact",
|
2531 |
+
source=seqio.TfdsDataSource(
|
2532 |
+
tfds_name="nutrition_fact:1.0.0",
|
2533 |
+
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
|
2534 |
+
splits={"validation": "test", "test": "test"}
|
2535 |
+
),
|
2536 |
+
preprocessors=[nutrition_fact_preprocessor, extract_individual_vqa],
|
2537 |
+
inf_preprocessor=[
|
2538 |
+
nutrition_fact_preprocessor,
|
2539 |
+
functools.partial(flatten_parts, parts=["question", "answer"]),
|
2540 |
+
extract_individual_vqa,
|
2541 |
+
],
|
2542 |
+
style="nutrition_fact",
|
2543 |
+
inf_only=True
|
2544 |
+
)
|
2545 |
+
|
2546 |
+
for k in ["chart_qa", "info_qa", "doc_qa", "text_vqa", "coco_2014_vqa",
|
2547 |
+
"ai2_diagram_v2_mix_transparent", "chart_qa_human"]:
|
2548 |
+
TASKS[k + "_demo"] = dataclasses.replace(TASKS[k], style="demo")
|
torch_util.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
from typing import Optional, TypeVar, List, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
|
9 |
+
T = TypeVar("T")
|
10 |
+
|
11 |
+
|
12 |
+
log = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
def seed_all(seed: int):
|
16 |
+
"""Seed all rng objects."""
|
17 |
+
import random
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
if seed < 0 or seed > 2**32 - 1:
|
22 |
+
raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")
|
23 |
+
random.seed(seed)
|
24 |
+
np.random.seed(seed)
|
25 |
+
torch.manual_seed(seed)
|
26 |
+
# torch.manual_seed may call manual_seed_all but calling it again here
|
27 |
+
# to make sure it gets called at least once
|
28 |
+
torch.cuda.manual_seed_all(seed)
|
29 |
+
|
30 |
+
|
31 |
+
def is_distributed() -> bool:
|
32 |
+
return dist.is_available() and dist.is_initialized()
|
33 |
+
|
34 |
+
|
35 |
+
def get_node_rank() -> int:
|
36 |
+
return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())
|
37 |
+
|
38 |
+
|
39 |
+
def get_world_size() -> int:
|
40 |
+
if is_distributed():
|
41 |
+
return dist.get_world_size()
|
42 |
+
else:
|
43 |
+
return 1
|
44 |
+
|
45 |
+
|
46 |
+
def get_local_world_size() -> int:
|
47 |
+
return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)
|
48 |
+
|
49 |
+
|
50 |
+
def get_global_rank() -> int:
|
51 |
+
if is_distributed():
|
52 |
+
return int(os.environ.get("RANK") or dist.get_rank())
|
53 |
+
else:
|
54 |
+
return 0
|
55 |
+
|
56 |
+
|
57 |
+
def get_local_rank() -> int:
|
58 |
+
return int(os.environ.get("LOCAL_RANK") or 0)
|
59 |
+
|
60 |
+
|
61 |
+
def get_fs_local_rank() -> int:
|
62 |
+
"""Get the local rank per filesystem, meaning that, regardless of the number of nodes,
|
63 |
+
if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
|
64 |
+
but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
|
65 |
+
"""
|
66 |
+
if os.environ.get("OLMO_SHARED_FS"):
|
67 |
+
return int(os.environ.get("FS_LOCAL_RANK") or get_global_rank())
|
68 |
+
else:
|
69 |
+
return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
|
70 |
+
|
71 |
+
|
72 |
+
def move_to_device(o: T, device: torch.device) -> T:
|
73 |
+
if isinstance(o, torch.Tensor):
|
74 |
+
return o.to(device) # type: ignore[return-value]
|
75 |
+
elif isinstance(o, dict):
|
76 |
+
return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value]
|
77 |
+
elif isinstance(o, list):
|
78 |
+
return [move_to_device(x, device) for x in o] # type: ignore[return-value]
|
79 |
+
elif isinstance(o, tuple):
|
80 |
+
return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value]
|
81 |
+
else:
|
82 |
+
return o
|
83 |
+
|
84 |
+
|
85 |
+
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
|
86 |
+
"""
|
87 |
+
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
|
88 |
+
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
|
89 |
+
"""
|
90 |
+
if check_neg_inf:
|
91 |
+
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
|
92 |
+
if check_pos_inf:
|
93 |
+
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
|
94 |
+
|
95 |
+
|
96 |
+
def get_default_device() -> torch.device:
|
97 |
+
if torch.cuda.is_available() and torch.cuda.is_initialized():
|
98 |
+
return torch.device("cuda")
|
99 |
+
else:
|
100 |
+
return torch.device("cpu")
|
101 |
+
|
102 |
+
|
103 |
+
def barrier() -> None:
|
104 |
+
if is_distributed():
|
105 |
+
dist.barrier()
|
106 |
+
|
107 |
+
|
108 |
+
def peak_gpu_memory(reset: bool = False) -> Optional[float]:
|
109 |
+
"""
|
110 |
+
Get the peak GPU memory usage in MB across all ranks.
|
111 |
+
Only rank 0 will get the final result.
|
112 |
+
"""
|
113 |
+
if not torch.cuda.is_available():
|
114 |
+
return None
|
115 |
+
|
116 |
+
device = torch.device("cuda")
|
117 |
+
peak_mb = torch.cuda.max_memory_allocated(device) / 1000000
|
118 |
+
if is_distributed():
|
119 |
+
peak_mb_tensor = torch.tensor(peak_mb, device=device)
|
120 |
+
dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX)
|
121 |
+
peak_mb = peak_mb_tensor.item()
|
122 |
+
|
123 |
+
if reset:
|
124 |
+
# Reset peak stats.
|
125 |
+
torch.cuda.reset_max_memory_allocated(device)
|
126 |
+
|
127 |
+
return peak_mb
|
128 |
+
|
129 |
+
|
130 |
+
V = TypeVar("V", bool, int, float)
|
131 |
+
|
132 |
+
|
133 |
+
def synchronize_value(value: V, device: torch.device) -> V:
|
134 |
+
if dist.is_available() and dist.is_initialized():
|
135 |
+
value_tensor = torch.tensor(value, device=device)
|
136 |
+
dist.broadcast(value_tensor, 0)
|
137 |
+
return value_tensor.item() # type: ignore
|
138 |
+
else:
|
139 |
+
return value
|
140 |
+
|
141 |
+
|
142 |
+
def synchronize_flag(flag: bool, device: torch.device) -> bool:
|
143 |
+
return synchronize_value(flag, device)
|
144 |
+
|
145 |
+
|
146 |
+
def gc_cuda():
|
147 |
+
gc.collect()
|
148 |
+
if torch.cuda.is_available():
|
149 |
+
torch.cuda.empty_cache()
|
150 |
+
|
151 |
+
|
152 |
+
def listinstr(lst, s, delimiter=None):
|
153 |
+
assert isinstance(lst, list)
|
154 |
+
for item in lst:
|
155 |
+
if delimiter:
|
156 |
+
if all(x in s for x in item.split(delimiter)):
|
157 |
+
return True
|
158 |
+
else:
|
159 |
+
if item in s:
|
160 |
+
return True
|
161 |
+
return False
|
162 |
+
|
163 |
+
|
164 |
+
def freeze_module(module: torch.nn.Module, exclude_params: Optional[List[str]] = None):
|
165 |
+
for name, param in module.named_parameters():
|
166 |
+
if exclude_params is not None and listinstr(exclude_params, name):
|
167 |
+
continue
|
168 |
+
param.requires_grad = False
|
169 |
+
|
170 |
+
|
171 |
+
def freeze_parameters_by_name(model: torch.nn.Module, freeze_names: Tuple[str]):
|
172 |
+
for name in freeze_names:
|
173 |
+
try:
|
174 |
+
module_or_param = model.get_submodule(name)
|
175 |
+
except:
|
176 |
+
try:
|
177 |
+
module_or_param = model.get_parameter(name)
|
178 |
+
except:
|
179 |
+
log.warning(f"Could not find module or parameter with name {name}")
|
180 |
+
if isinstance(module_or_param, torch.nn.Module):
|
181 |
+
freeze_module(module_or_param)
|
182 |
+
else:
|
183 |
+
module_or_param.requires_grad = False
|
util.py
CHANGED
@@ -33,7 +33,7 @@ from .exceptions import (
|
|
33 |
OLMoNetworkError,
|
34 |
OLMoThreadError,
|
35 |
)
|
36 |
-
from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
|
37 |
|
38 |
try:
|
39 |
from functools import cache
|
|
|
33 |
OLMoNetworkError,
|
34 |
OLMoThreadError,
|
35 |
)
|
36 |
+
# from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
|
37 |
|
38 |
try:
|
39 |
from functools import cache
|
utils.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import hashlib
|
3 |
+
import sys
|
4 |
+
import typing
|
5 |
+
import warnings
|
6 |
+
import socket
|
7 |
+
from typing import Optional, Any, Dict
|
8 |
+
import os
|
9 |
+
import logging
|
10 |
+
import absl.flags
|
11 |
+
from flax.traverse_util import flatten_dict
|
12 |
+
|
13 |
+
from ml_collections import ConfigDict, config_flags
|
14 |
+
from ml_collections.config_dict import placeholder
|
15 |
+
from mlxu import function_args_to_config
|
16 |
+
|
17 |
+
_log_extra_fields: Dict[str, Any] = {}
|
18 |
+
|
19 |
+
|
20 |
+
def is_float_printable(x):
|
21 |
+
try:
|
22 |
+
f"{x:0.2f}"
|
23 |
+
return True
|
24 |
+
except (ValueError, TypeError):
|
25 |
+
return False
|
26 |
+
|
27 |
+
|
28 |
+
def compute_hash(string: str) -> str:
|
29 |
+
"""Computes the hash of a string."""
|
30 |
+
return hashlib.sha256(string.encode("utf-8")).hexdigest()
|
31 |
+
|
32 |
+
|
33 |
+
def pop_metadata(data):
|
34 |
+
meta = {k: data.pop(k) for k in list(data) if k.startswith("metadata")}
|
35 |
+
return data, meta
|
36 |
+
|
37 |
+
|
38 |
+
def setup_logging():
|
39 |
+
handler: logging.Handler
|
40 |
+
handler = logging.StreamHandler(sys.stdout)
|
41 |
+
formatter = logging.Formatter(
|
42 |
+
"[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
|
43 |
+
datefmt="%H:%M:%S"
|
44 |
+
)
|
45 |
+
handler.setFormatter(formatter)
|
46 |
+
logging.basicConfig(handlers=[handler], level=logging.INFO)
|
47 |
+
|
48 |
+
logging.captureWarnings(True)
|
49 |
+
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
50 |
+
|
51 |
+
|
52 |
+
def get_maybe_optional_type(field_type):
|
53 |
+
if type(None) in typing.get_args(field_type):
|
54 |
+
# Handle optional type
|
55 |
+
args = [x for x in typing.get_args(field_type) if x != type(None)]
|
56 |
+
assert len(args) == 1
|
57 |
+
field_type = args[0]
|
58 |
+
return field_type
|
59 |
+
|
60 |
+
|
61 |
+
def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
|
62 |
+
"""Build a `ConfigDict` matching the possibly nested dataclass
|
63 |
+
|
64 |
+
dataclass: A dataclass instance or a dataclass type, if an instance defaults
|
65 |
+
will be set to the values in the class, if a class defaults will be
|
66 |
+
set to the field defaults, or None if the field is required
|
67 |
+
defaults_to_none: Make all defaults None
|
68 |
+
"""
|
69 |
+
out = {}
|
70 |
+
fields = dataclasses.fields(dataclass)
|
71 |
+
for field in fields:
|
72 |
+
if not field.init:
|
73 |
+
continue
|
74 |
+
|
75 |
+
if defaults_to_none:
|
76 |
+
default = None
|
77 |
+
elif hasattr(dataclass, field.name):
|
78 |
+
default = getattr(dataclass, field.name)
|
79 |
+
elif field.default is dataclasses.MISSING:
|
80 |
+
default = None
|
81 |
+
else:
|
82 |
+
default = field.default
|
83 |
+
|
84 |
+
field_type = get_maybe_optional_type(field.type)
|
85 |
+
|
86 |
+
if hasattr(field_type, "__dataclass_fields__"):
|
87 |
+
if not defaults_to_none and default is None:
|
88 |
+
pass
|
89 |
+
else:
|
90 |
+
out[field.name] = config_from_dataclass(
|
91 |
+
default or field.type, defaults_to_none=defaults_to_none)
|
92 |
+
else:
|
93 |
+
if default is None:
|
94 |
+
assert not field_type == typing.Any
|
95 |
+
origin = getattr(field_type, "__origin__", None)
|
96 |
+
if origin is not None:
|
97 |
+
field_type = origin
|
98 |
+
out[field.name] = placeholder(field_type)
|
99 |
+
else:
|
100 |
+
out[field.name] = default
|
101 |
+
return ConfigDict(out)
|
102 |
+
|
103 |
+
|
104 |
+
def dataclass_with_none(cls):
|
105 |
+
"""Build an instance of possibly nested dataclass `cls` with all attributes None"""
|
106 |
+
fields = dataclasses.fields(cls)
|
107 |
+
args = {}
|
108 |
+
for field in fields:
|
109 |
+
if not field.init:
|
110 |
+
pass
|
111 |
+
elif dataclasses.is_dataclass(field.type):
|
112 |
+
args[field.name] = dataclass_with_none(field.type)
|
113 |
+
else:
|
114 |
+
args[field.name] = None
|
115 |
+
return cls(**args)
|
116 |
+
|
117 |
+
|
118 |
+
def dataclass_from_config(cls, config: Dict):
|
119 |
+
"""Build an instance of `cls` with attributes from `config``"""
|
120 |
+
fields = dataclasses.fields(cls)
|
121 |
+
args = set(x.name for x in fields)
|
122 |
+
for k in config.keys():
|
123 |
+
if k not in args:
|
124 |
+
raise ValueError(f"Config has unknown arg {k} fr {cls}")
|
125 |
+
args = {}
|
126 |
+
for field in fields:
|
127 |
+
if not field.init:
|
128 |
+
continue
|
129 |
+
|
130 |
+
field_type = get_maybe_optional_type(field.type)
|
131 |
+
if hasattr(field_type, "__dataclass_fields__"):
|
132 |
+
if config.get(field.name) is None:
|
133 |
+
args[field.name] = None
|
134 |
+
elif hasattr(field_type, "from_dict"):
|
135 |
+
src = config[field.name]
|
136 |
+
if isinstance(src, ConfigDict):
|
137 |
+
src = src.to_dict()
|
138 |
+
args[field.name] = field_type.from_dict(src)
|
139 |
+
else:
|
140 |
+
args[field.name] = dataclass_from_config(field_type, config[field.name])
|
141 |
+
elif field.name in config:
|
142 |
+
if isinstance(config[field.name], ConfigDict):
|
143 |
+
args[field.name] = config[field.name].to_dict()
|
144 |
+
else:
|
145 |
+
args[field.name] = config[field.name]
|
146 |
+
return cls(**args)
|
147 |
+
|
148 |
+
|
149 |
+
def update_dataclass(obj, updates):
|
150 |
+
"""Sets attributes in `obj` to match non-None fields in `updates`"""
|
151 |
+
fields = dataclasses.fields(obj)
|
152 |
+
for field in fields:
|
153 |
+
if not field.init:
|
154 |
+
continue
|
155 |
+
update = updates.get(field.name)
|
156 |
+
if update is None:
|
157 |
+
continue
|
158 |
+
current_value = getattr(obj, field.name)
|
159 |
+
if dataclasses.is_dataclass(current_value):
|
160 |
+
update_dataclass(current_value, update)
|
161 |
+
else:
|
162 |
+
if isinstance(update, (ConfigDict, dict)):
|
163 |
+
assert all(x is None for x in flatten_dict(update).values())
|
164 |
+
else:
|
165 |
+
setattr(obj, field.name, update)
|
166 |
+
|
167 |
+
|
168 |
+
def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
|
169 |
+
# Stolen from the OLMo codebase
|
170 |
+
def format_value(value: float) -> str:
|
171 |
+
if isinstance(value, str):
|
172 |
+
return value
|
173 |
+
if value < 0.0001:
|
174 |
+
return str(value) # scientific notation
|
175 |
+
elif value > 1000:
|
176 |
+
return f"{int(value):,d}"
|
177 |
+
elif value > 100:
|
178 |
+
return f"{value:.1f}"
|
179 |
+
elif value > 10:
|
180 |
+
return f"{value:.2f}"
|
181 |
+
elif value > 1:
|
182 |
+
return f"{value:.3f}"
|
183 |
+
else:
|
184 |
+
return f"{value:.4f}"
|
185 |
+
|
186 |
+
logging.info(
|
187 |
+
f"{prefix}\n"
|
188 |
+
+ "\n".join(
|
189 |
+
[
|
190 |
+
f" {name}={format_value(value)}"
|
191 |
+
for name, value in metrics.items()
|
192 |
+
if not name.startswith("optim/") # there's too many optimizer metrics
|
193 |
+
]
|
194 |
+
)
|
195 |
+
)
|