import hashlib
import json
import math
from functools import reduce
from typing import Mapping, Optional, Sequence
import numpy as np
import tensorflow as tf
import seqio
import gin
from .data_utils import flatten_parts, stateless_permutation, stateless_shuffle
from .. import config
def get_from_dict(data, keys):
    """Follow ``keys`` through nested dictionaries and return what is found.

    Every step uses ``dict.get``: a missing key yields ``None`` at that level,
    and a further lookup on that ``None`` raises ``TypeError``. With no keys the
    input itself is returned.
    """
    node = data
    for key in keys:
        node = dict.get(node, key)
    return node
def get_blank_image():
    """Return one all-zero 224x224x3 uint8 image with a leading axis of size 1.

    The result has shape ``[1, 224, 224, 3]``.
    """
    return tf.zeros([1, 224, 224, 3], dtype=tf.uint8)
@seqio.utils.map_over_dataset
def rekey(x, key_map=None):
    """Build a new example whose features are renamed according to ``key_map``.

    ``key_map`` maps each new key to the key it is read from. For example, with
    ``key_map = {'boo': 'foo', 'spar': 'bar'}`` the example
    ``{'foo': 'something', 'bar': 'something else'}`` becomes
    ``{'boo': 'something', 'spar': 'something else'}``. Features not named in
    ``key_map`` are dropped. A source key given as a list is a path into nested
    dictionaries (resolved with ``get_from_dict``).

    Args:
      x: an example to process.
      key_map: dictionary mapping new keys to original keys. When empty or None,
        ``x`` is returned unchanged.

    Returns:
      The re-keyed example.
    """
    if not key_map:
        return x

    def _lookup(source):
        if isinstance(source, list):
            return get_from_dict(x, source)
        return x[source]

    return {new_key: _lookup(old_key) for new_key, old_key in key_map.items()}
def rename(**kwargs):
    """Return a dataset preprocessor that moves features to new top-level keys.

    Each keyword maps a new key to the feature to move there. A string value
    names a top-level feature. A list value is a path: the first element picks
    a top-level feature, middle elements walk into it, and the last element is
    popped from the dictionary reached. Moved features are removed from their
    old location; all other features are kept.
    """
    @seqio.map_over_dataset
    def _fn(x):
        moved = {}
        for target, source in kwargs.items():
            if not isinstance(source, list):
                moved[target] = x.pop(source)
                continue
            parent = x[source[0]]
            for step in source[1:-1]:
                parent = parent[step]
            moved[target] = parent.pop(source[-1])
        x.update(moved)
        return x
    return _fn
def extract_transcripts(ds):
    """Flatten ``transcripts`` with ``flatten_parts`` and emit ``image``/``text``/``url``.

    ``text`` is the (flattened) ``transcripts`` feature. Presumably this gives
    one example per transcript; that depends on ``flatten_parts``.
    """
    flat = flatten_parts(ds, ["transcripts"])
    return flat.map(lambda ex: {
        "image": ex["image"],
        "text": ex["transcripts"],
        "url": ex["url"],
    })
@seqio.map_over_dataset
def extract_caption_and_all_transcripts(ex):
    """Pair the caption with up to three randomly chosen transcripts.

    ``text`` is ``[caption, t_1, ..., t_k]`` with ``k <= 3``. In ``text_weights``
    the caption gets weight 1.0 and each transcript gets ``1 / k``, so the
    transcripts together weigh as much as the caption.
    """
    picked = tf.random.shuffle(ex["transcripts"])[:3]
    n_picked = tf.shape(picked)[0]
    per_transcript = 1.0 / tf.cast(n_picked, tf.float32)
    texts = tf.concat([tf.expand_dims(ex["caption"], 0), picked], 0)
    weights = tf.concat(
        [tf.ones((1,), dtype=tf.float32), tf.fill((n_picked,), per_transcript)], 0)
    return {
        "image": ex["image"],
        "text": texts,
        "url": ex["url"],
        "text_weights": weights,
    }
@seqio.map_over_dataset
def extract_all_transcripts(ex):
    """Use up to three randomly chosen transcripts as ``text``.

    Each of the ``k`` chosen transcripts gets weight ``3 / k``, so the weights
    always sum to 3.
    """
    picked = tf.random.shuffle(ex["transcripts"])[:3]
    n_picked = tf.shape(picked)[0]
    weights = tf.fill((n_picked,), 3.0 / tf.cast(n_picked, tf.float32))
    return {
        "image": ex["image"],
        "text": picked,
        "url": ex["url"],
        "text_weights": weights,
    }
@seqio.map_over_dataset
def extract_transcript(ex):
    """Use one randomly chosen transcript as ``text``."""
    chosen = tf.random.shuffle(ex["transcripts"])[0]
    return {"image": ex["image"], "text": chosen, "url": ex["url"]}
@seqio.map_over_dataset
def extract_caption(ex):
    """Copy the caption into ``text``, keeping only the first one if there are several."""
    caption = ex["caption"]
    has_many = len(caption.shape) > 0
    ex["text"] = caption[0] if has_many else caption
    return ex
@seqio.map_over_dataset
def extract_joint_captions(ex):
    """Stack the (first) caption, ``mistral_caption`` and one random transcript into ``text``."""
    caption = ex["caption"]
    if len(caption.shape) > 0:
        caption = caption[0]
    transcripts = ex["transcripts"]
    n_transcripts = tf.shape(transcripts)[0]
    pick = tf.random.uniform((), 0, n_transcripts, dtype=tf.int32)
    pick = pick % n_transcripts
    texts = tf.stack([caption, ex["mistral_caption"], transcripts[pick]], 0)
    return {"image": ex["image"], "text": texts, "url": ex["url"]}
@seqio.map_over_dataset(num_seeds=1)
def extract_caption_and_transcript(ex, seed):
    """Pair the (first) caption with one transcript picked using a stateless seed."""
    caption = ex["caption"]
    if len(caption.shape) > 0:
        caption = caption[0]
    transcripts = ex["transcripts"]
    pick = tf.random.stateless_uniform(
        (), seed, 0, tf.shape(transcripts)[0], dtype=tf.int32)
    return {
        "image": ex["image"],
        "text": tf.stack([caption, transcripts[pick]], 0),
        "url": ex["url"],
    }
@seqio.map_over_dataset
def caption_transcript_augmented(ex, sequence_length):
    """Pair the caption with one random transcript, optionally augmenting the image.

    When ``sequence_length["is_training"]`` is truthy, each example
    independently gets, with probability 0.5 each, a mild color jitter and a
    mild affine transform (shear, rotation, scale; translation is disabled).
    Horizontal flips are disabled. The applied augmentations are recorded as
    space-separated tags ("[color]", "[affine]") in ``prompt``.

    Returns:
      dict with ``image``, ``text`` ([caption, transcript]), ``url`` and
      ``prompt``.
    """
    caption = ex["caption"]
    if len(caption.shape) > 0:
        caption = caption[0]
    image = ex["image"]
    do_augmentation = sequence_length["is_training"]

    # Keep this off, it screws up OCR
    # do_hflip = (tf.random.uniform(()) > 0.2 and do_augmentation)
    do_hflip = False
    if do_hflip:
        image = image[:, ::-1]

    # Mild color jitter
    do_color = (tf.random.uniform(()) > 0.5 and do_augmentation)
    if do_color:
        image = tf.image.random_hue(image, max_delta=0.05)
        image = tf.image.random_brightness(image, max_delta=0.2)
        image = tf.image.random_saturation(image, 0.7, 1.3)
        image = tf.image.random_contrast(image, 0.7, 1.3)

    # Mild affine transformation (`do_affine` already implies `do_augmentation`)
    do_affine = (tf.random.uniform(()) > 0.5 and do_augmentation)
    if do_affine:
        # Translation is disabled (multiplied by 0); the draws are kept so the
        # sequence of random ops is unchanged.
        shift_x = tf.random.uniform((), -10, 10) * 0
        shift_y = tf.random.uniform((), -10, 10) * 0
        shear_x = tf.random.uniform((), -2, 2)
        shear_y = tf.random.uniform((), -2, 2)
        rotation = tf.random.uniform((), -6, 6)
        max_scale = 1.1
        scale = tf.random.uniform((), 0.8, max_scale)
        # For a channels-last image tf.shape(image) starts with
        # [height, width], but get_affine_matrix takes the center as (x, y),
        # i.e. (width / 2, height / 2). Passing (h / 2, w / 2) moved the
        # rotation/shear pivot off-center for non-square images.
        half_size = tf.cast(tf.shape(image), tf.float32) / 2
        center_x, center_y = half_size[1], half_size[0]
        image = tf.keras.ops.image.affine_transform(
            image,
            tf.stack(get_affine_matrix(
                [center_x, center_y],
                rotation,
                [shift_x, shift_y],
                1/scale,
                [shear_x, shear_y]
            ) + [0., 0.]),
            interpolation='bilinear',
            fill_mode='constant',
            fill_value=1.,
            data_format='channels_last'
        )

    # Record which augmentations were applied, e.g. "[color] [affine]".
    properties = tf.stack([
        ("[hflip]" if do_hflip else ""),
        ("[color]" if do_color else ""),
        ("[affine]" if do_affine else "")
    ])
    properties = tf.boolean_mask(properties, tf.strings.length(properties) > 0)
    prompt = tf.strings.reduce_join(properties, separator=" ")
    ix = tf.random.uniform((), 0, tf.shape(ex["transcripts"])[0], dtype=tf.int32)
    return dict(
        image=image,
        text=tf.stack([caption, ex["transcripts"][ix]], 0),
        url=ex["url"],
        prompt=prompt,
    )
def extract_caption_and_transcript_hflip(ds):
    """Build a dataset pairing each caption with each of three transcripts.

    The result is three concatenated passes over ``ds``. Pass ``i`` pairs the
    (first) caption with ``transcripts[i]``. In the pass whose index equals the
    example's random ``hflip`` draw (0, 1 or 2), the image is reversed along
    axis 1 (a horizontal flip for height-first images) and the ``style``
    entries get a ``_flipped`` suffix.

    NOTE(review): the transcript shuffle and the ``hflip`` draw use unseeded
    stateful random ops, and each concatenated pass re-iterates ``ds``. They
    may therefore be re-drawn per pass, in which case the three passes are not
    guaranteed to cover three distinct transcripts or to flip exactly one copy.
    Confirm whether that is intended.
    NOTE(review): assumes every example has at least 3 transcripts -- TODO confirm.
    """
    # Just in case they are ordered somehow in Matt's data
    @seqio.map_over_dataset
    def _shuffle_transcripts(_ex):
        _ex["transcripts"] = tf.random.shuffle(_ex["transcripts"])
        # Index (0-2) of the pass in which this example's image is flipped.
        _ex["hflip"] = tf.random.uniform((), 0, 3, dtype=tf.int32)
        return _ex
    ds = _shuffle_transcripts(ds)
    # Build a 3x long dataset with each individual transcript so we iterate through
    # each transcript
    @seqio.map_over_dataset
    def _with_transcript(ex, _ix):
        caption = ex["caption"]
        if len(caption.shape) > 0:
            # Only the first caption is used when several are present.
            caption = caption[0]
        hflip = ex["hflip"] == _ix
        if hflip:
            ex["image"] = ex["image"][:, ::-1]
            style = ["long_caption_flipped", "transcript_flipped"]
        else:
            style = ["long_caption", "transcript"]
        return dict(
            image=ex["image"],
            text=tf.stack([caption, ex["transcripts"][_ix]], 0),
            url=ex["url"],
            style=style
        )
    joint_ds = _with_transcript(ds, 0)
    for i in range(1, 3):
        joint_ds = joint_ds.concatenate(_with_transcript(ds, i))
    return joint_ds
@seqio.map_over_dataset
def extract_llava(ex, sequence_length, output_features):
    """Turn a two-turn LLaVA conversation into ``prompt`` and ``text`` features.

    Asserts that ``conversations/value`` has exactly two entries; the first is
    the ``prompt`` and the second the ``text``. ``conversations`` is removed and
    all other features pass through. ``sequence_length`` and
    ``output_features`` are accepted but unused.
    """
    turns = ex['conversations']['value']
    tf.assert_equal(tf.shape(turns)[0], 2)
    del ex['conversations']
    ex["text"] = turns[1]
    ex["prompt"] = turns[0]
    return ex
def extract_localized_narrative(ds):
    """Join each example's ``cap/cap_caption`` entries into one newline-separated ``text``.

    Examples with no caption entries are dropped.
    """
    def _has_caption(ex):
        return tf.shape(ex["cap/cap_caption"])[0] > 0

    def _to_text(ex):
        joined = tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
        return dict(image=ex["image"], text=joined)

    return ds.filter(_has_caption).map(_to_text)
def float_to_text(val):
    """Format ``val`` as a percentage string, truncated toward zero (e.g. 0.257 -> "25")."""
    percent = tf.cast(val * 100, tf.int32)
    return tf.strings.as_string(percent)
@seqio.map_over_dataset
def extract_vqa(ex):
    """Flatten all of an example's VQA pairs into one ``text`` string.

    Each line holds one question, a space, then that question's answers joined
    with "; ". Lines are separated by newlines.
    """
    questions = ex["vqa"]["questions"]
    answers = ex["vqa"]["answers"]
    answers = tf.strings.reduce_join(answers, 1, separator="; ")
    # Join each [question, answers] row separately (axis=1). Without an axis,
    # reduce_join collapses the whole matrix into one scalar, and the newline
    # join below would have nothing left to separate.
    qas = tf.strings.reduce_join(tf.stack([questions, answers], 1), axis=1, separator=" ")
    return dict(
        image=ex["image"],
        text=tf.strings.reduce_join(qas, separator="\n")
    )
@seqio.map_over_dataset
def coco_image_id_from_path(ex):
    """Parse ``image_id`` from ``image/filename`` by dropping its last 4 characters (e.g. ".jpg")."""
    filename = ex["image/filename"]
    stem = tf.strings.substr(filename, 0, tf.strings.length(filename) - 4)
    ex["image_id"] = tf.strings.to_number(stem)
    return ex
@seqio.map_over_dataset
def add_coco_url(ex):
    """Store a public URL for a COCO image in ``metadata/image_url``, for visualizations.

    Bare filenames such as ``COCO_val2014_000000000042.jpg`` are first prefixed
    with the split directory parsed from the name (``val2014/...``). Paths that
    already contain a "/" are used as-is.
    """
    path = ex["image/filename"]
    has_directory = tf.strings.regex_full_match(path, ".*/.*")
    if not has_directory:
        split_dir = tf.strings.regex_replace(path, "COCO_", "")
        split_dir = tf.strings.regex_replace(split_dir, "_[0-9]+.jpg", "")
        path = tf.strings.join([split_dir, path], separator="/")
    # images are hosted by the COCO website here
    ex["metadata/image_url"] = tf.strings.join(
        ["https://s3.us-east-1.amazonaws.com/images.cocodataset.org/", path])
    return ex
def flatten_vqa(ds):
    """Flatten questions/answers (plus ``id``/``question_id`` when present) with ``flatten_parts``."""
    optional = [k for k in ("id", "question_id") if k in ds.element_spec]
    return flatten_parts(ds, ["questions", "answers"] + optional)
def format_gqa(ds, is_balanced=True, flatten=True):
    """Prepare GQA examples, optionally keeping only balanced questions.

    With ``is_balanced``, examples without any balanced question are dropped
    and only balanced questions are kept in the rest. With ``flatten``, the
    ``questions`` feature is flattened via ``flatten_parts``. Finally the
    question features become the top-level example, with ``image`` and
    ``image_id`` attached.
    """
    def _keep_balanced(ex):
        questions = ex["questions"]
        keep = questions["is_balanced"]
        ex["questions"] = {name: tf.boolean_mask(values, keep)
                           for name, values in questions.items()}
        return ex

    def _hoist_questions(ex):
        flat = ex["questions"]
        flat["image"] = ex["image"]
        flat["image_id"] = ex["image_id"]
        return flat

    if is_balanced:
        ds = ds.filter(lambda x: tf.reduce_any(x["questions"]["is_balanced"]))
        ds = ds.map(_keep_balanced)
    if flatten:
        ds = flatten_parts(ds, ["questions"])
    return ds.map(_hoist_questions)
@seqio.map_over_dataset
def fix_doqa_url(x):
    """Remove every "gs://" occurrence from ``image_url``."""
    url = x["image_url"]
    x["image_url"] = tf.strings.regex_replace(url, "gs://", "")
    return x
def _add_metadata(ex):
    """Collect metadata features from ``ex`` into a new dict.

    ``metadata/example_id`` comes from the first of ``id``, ``example_id``,
    ``question_id`` that is present. ``image_url`` becomes
    ``metadata/image_url``. Every feature already named ``metadata/...`` is
    copied unchanged and wins on a name collision.
    """
    out = {}
    id_keys = [k for k in ("id", "example_id", "question_id") if k in ex]
    if id_keys:
        out["metadata/example_id"] = ex[id_keys[0]]
    if "image_url" in ex:
        out["metadata/image_url"] = ex["image_url"]
    out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
    return out
def image_only(ds):
    """Keep only examples whose ``has_image`` flag is true."""
    def _has_image(x):
        return x["has_image"]
    return ds.filter(_has_image)
def filter_difficult_direct_answer(ds):
    """Drop examples flagged ``difficult_direct_answer``."""
    return ds.filter(lambda x: tf.logical_not(x["difficult_direct_answer"]))
@seqio.map_over_dataset()
def format_ai2d(ex, variable_style=True):
    """Format an AI2D multiple-choice example.

    If the diagram uses letter labels (``abc_label``) and all but at most one
    option are letter labels, the options are upper-cased and used directly as
    answer names (style ``ai2_diagram_no_letter``). Otherwise the options are
    prefixed "A: ", "B: ", ... and the letter is the answer name (style
    ``ai2_diagram``). ``text`` is the correct answer name, or "?" when
    ``answer_idx`` is negative. Answer names are also stored, joined with
    "|||", in ``metadata/option_names``.

    Args:
      ex: an AI2D example.
      variable_style: if True, emit ``style`` so the two formats can be
        prompted differently.
    """
    abc = tf.constant(list("abcdefg".upper()))
    out = dict(image=ex["image"])
    out.update(_add_metadata(ex))
    options = ex["choices"]
    n_options = tf.shape(ex["option_is_abc"])[0]
    # Allow one non-letter option (e.g. a "none of the above"-like answer).
    n_abc_options = tf.reduce_sum(tf.cast(ex["option_is_abc"], tf.int32))
    if ex["abc_label"] and n_abc_options >= (n_options - 1):
        # The image labels are always upper case, so use upper case in the answer options
        options = tf.where(
            ex["option_is_abc"],
            tf.strings.upper(options),
            options
        )
        short_options = options
        style = "ai2_diagram_no_letter"
    else:
        short_options = abc[:tf.shape(options)[0]]
        options = tf.stack([short_options, options,], 1)
        options = tf.strings.reduce_join(options, axis=-1, separator=": ")
        style = "ai2_diagram"
    options = tf.strings.reduce_join(options, separator="\n")
    out["question"] = ex["question"]
    out["options"] = options
    if variable_style:
        out["style"] = style
    if ex["answer_idx"] < 0:
        out["text"] = "?"
    else:
        out["text"] = short_options[ex["answer_idx"]]
    out["metadata/answer_idx"] = ex["answer_idx"]
    # "|||" separates option names in the metadata below, so it must not occur
    # in the options. `options` is now one newline-separated string; "(?s)"
    # lets "." match newlines, otherwise the full match could never succeed
    # once there is more than one option and the check would be a no-op.
    tf.debugging.assert_equal(
        tf.reduce_any(tf.strings.regex_full_match(options, r"(?s).*\|\|\|.*")), False)
    out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
    out["metadata/has_transparent_box"] = ex.get("has_transparent_box", tf.constant(False))
    out["metadata/abc_label"] = ex["abc_label"]
    return out
@gin.configurable()
@seqio.map_over_dataset()
def format_multiple_choice_qa(ex, option_format="abc"):
    """Format a multiple-choice example with lettered options.

    Options are rendered one per line as "A: <option>", "B: <option>", ....
    ``text`` is the letter of the correct option, or "?" when ``answer_idx`` is
    negative. The letters are also stored, joined with "|||", in
    ``metadata/option_names``.

    NOTE(review): only 7 letters (A-G) are available; examples with more
    choices will fail in ``tf.stack`` -- confirm this never happens.

    Args:
      ex: example with ``image``, ``question``, ``choices`` and ``answer_idx``.
      option_format: only "abc" is supported.
    """
    assert option_format == "abc"
    abc = tf.constant(list("abcdefg".upper()))
    out = dict(image=ex["image"])
    out.update(_add_metadata(ex))
    options = ex["choices"]
    short_options = abc[:tf.shape(options)[0]]
    options = tf.stack([short_options, options,], 1)
    options = tf.strings.reduce_join(options, axis=-1, separator=": ")
    options = tf.strings.reduce_join(options, separator="\n")
    out["question"] = ex["question"]
    out["options"] = options
    if ex["answer_idx"] < 0:
        out["text"] = "?"
    else:
        out["text"] = short_options[ex["answer_idx"]]
    out["metadata/answer_idx"] = ex["answer_idx"]
    # "(?s)" lets "." match newlines: `options` is one newline-separated
    # string, so without it the full match could never succeed once there is
    # more than one option and the check would be a no-op.
    tf.debugging.assert_equal(
        tf.reduce_any(tf.strings.regex_full_match(options, r"(?s).*\|\|\|.*")), False)
    out["metadata/option_names"] = tf.strings.reduce_join(short_options, separator="|||")
    return out
@seqio.map_over_dataset()
def output_options(ex):
    """Copy the rendered ``options`` into ``metadata/options``."""
    ex.update({"metadata/options": ex["options"]})
    return ex
@seqio.map_over_dataset()
def extract_tally_qa(ex):
    """Replace the nested ``questions`` dict with flat question/answer/id features.

    Numeric answers are converted to strings.
    """
    qs = ex.pop("questions")
    ex.update(
        questions=qs["question"],
        answers=tf.strings.as_string(qs["answer"]),
        question_id=qs["question_id"],
    )
    return ex
@seqio.map_over_dataset()
def count_bench_preprocessor(ex):
    """Build a "How many <noun> are there?" question whose answer is ``number``."""
    noun = ex["noun"]
    count = ex["number"]
    return {
        "image": ex["image"],
        "text": tf.strings.as_string(count),
        "object": noun,
        "question": tf.strings.join(["How many ", noun, " are there?"]),
        "metadata/count": count,
    }
def filter_human(ds):
    """Keep only examples with ``is_human`` true."""
    def _is_human(x):
        return x["is_human"]
    return ds.filter(_is_human)
def filter_aug(ds):
    """Keep only examples with ``is_human`` false."""
    return ds.filter(lambda x: tf.logical_not(x["is_human"]))
@seqio.map_over_dataset()
def reweight_chartqa(ex, human, aug):
    """Set ``text_weights`` to ``human`` or ``aug`` depending on ``metadata/is_human``."""
    weight = human if ex["metadata/is_human"] else aug
    ex["text_weights"] = weight
    return ex
@seqio.map_over_dataset()
def chartqa_prompting(ex):
    """Append " Answer:" to the question; keep only image/question/answer."""
    prompted = tf.strings.join([ex["question"], " Answer:"])
    return {"image": ex["image"], "question": prompted, "answer": ex["answer"]}
@seqio.map_over_dataset()
def chartqa_explanation(ex):
    """Append " Explanation:" to the question, keeping image/answer and all metadata."""
    out = dict(
        image=ex["image"],
        question=tf.strings.join([ex["question"], " Explanation:"]),
        answer=ex["answer"],
    )
    for key, value in ex.items():
        if key.startswith("metadata/"):
            out[key] = value
    return out
@seqio.map_over_dataset(num_seeds=1)
def _preprocess_scifi(ex, seed):
    """Shuffle the QA triples (question/explanation/answer) with one stateless permutation.

    QA data is read from ``qa_pairs`` when present, otherwise from ``qa``.
    """
    qa = ex["qa_pairs"] if "qa_pairs" in ex else ex["qa"]
    order = stateless_permutation(tf.shape(qa["question"])[0], seed)
    shuffled = {k: tf.gather(qa[k], order) for k in ("question", "explanation", "answer")}
    return dict(image=ex["image"], **shuffled)
@seqio.map_over_dataset
def scifi_explanation_only(ex):
    """Use the explanation as the ``answer`` target."""
    return {
        "image": ex["image"],
        "question": ex["question"],
        "answer": ex["explanation"],
    }
def filter_named_entity(ds):
    """Decode images to 3 channels and drop those under 32 px along either of the first two axes."""
    @seqio.map_over_dataset
    def _decode(ex):
        ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
        return ex

    def _big_enough(ex):
        leading_dims = tf.shape(ex["image"])[:2]
        return tf.reduce_min(leading_dims) >= 32

    return _decode(ds).filter(_big_enough)
@seqio.map_over_dataset()
def extract_named_entity(ex):
    """Flatten the nested ``questions`` dict and move url/entity into metadata."""
    questions = ex["questions"]
    out = {"image": ex["image"]}
    out["metadata/image_url"] = ex["url"]
    out["metadata/entity"] = ex["entity"]
    out["questions"] = questions["question"]
    out["answers"] = questions["answer"]
    return out
@gin.configurable()
def extract_individual_vqa(ds, test=False, answer_mode="best"):
    """Convert VQA examples into ``question``/``text`` examples.

    Each output example has ``image``, ``question``, the fields from
    ``_add_metadata``, ``metadata/question`` and, when the input has
    ``answers`` (preferred) or ``answer``, ``metadata/references`` (multiple
    answers joined with newlines). Unless ``test`` is set, ``text`` holds a
    target built from ``answer`` (preferred) or ``answers``:

      * If the answer already has the same rank as the question, it is used
        unchanged.
      * "random": one answer drawn with a stateless seed (1-D answers only).
      * "best": the most frequent answer, with ties broken by a seeded draw;
        applied per row when there are several questions.
      * "all_segments": all answers, unchanged.
      * "all_segments_weighted": all answers, plus ``text_weights`` of
        ``1 / number_of_answers``.
      * "all": answers shuffled and joined with newlines.

    Integer answers are converted to strings first. An empty 1-D answer list is
    replaced by a single empty string.

    Args:
      ds: dataset to convert.
      test: if True, do not produce ``text``.
      answer_mode: how to turn multiple reference answers into a target.

    Raises:
      NotImplementedError: for an unknown ``answer_mode`` when the answers need
        reducing.
    """
    @seqio.map_over_dataset(num_seeds=1)
    def _extract(ex, seed):
        # Datasets name the question feature either "questions" or "question".
        if "questions" in ex:
            question = ex["questions"]
        else:
            question = ex["question"]
        out = dict(
            image=ex["image"],
            question=question,
        )
        out.update(_add_metadata(ex))
        out["metadata/question"] = question
        if ex.get("answers") is not None:
            out["metadata/references"] = tf.strings.reduce_join(ex["answers"], separator="\n")
        elif ex.get("answer") is not None:
            out["metadata/references"] = ex["answer"]
        if not test:
            if "answer" in ex:
                answer = ex["answer"]
            else:
                answer = ex["answers"]
            if answer.dtype in [tf.int32, tf.int64]:
                answer = tf.strings.as_string(answer)
            if len(answer.shape) == 1 and tf.shape(answer)[0] == 0:
                answer = tf.expand_dims("", 0)
            if len(answer.shape) == len(question.shape):
                # Already one answer per question; used as-is.
                pass
            # Handle questions with multiple answers
            elif answer_mode == "random":
                assert len(answer.shape) == 1
                answer = answer[tf.random.stateless_uniform((), seed, 0, tf.shape(answer)[0], dtype=tf.int32)]
            elif answer_mode == "best":
                def _get_best(_answer):
                    # Keep every answer tied for the highest count, then pick one
                    # of them with the stateless seed.
                    vals, _, counts = tf.unique_with_counts(_answer)
                    count_thresh = tf.reduce_max(counts)
                    vals = tf.boolean_mask(vals, counts >= count_thresh)
                    return vals[tf.random.stateless_uniform((), seed, 0, tf.shape(vals)[0], dtype=tf.int32)]
                if len(answer.shape) == 1:
                    answer = _get_best(answer)
                elif isinstance(answer, tf.RaggedTensor):
                    # One ragged row of answers per question.
                    n = tf.shape(answer)[0]
                    answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
                    for i in range(n):
                        answer_arr = answer_arr.write(i, _get_best(answer[i]))
                    answer = answer_arr.stack()
                else:
                    answer = tf.map_fn(_get_best, answer)
            elif answer_mode == "all_segments":
                out["text"] = answer
            elif answer_mode == "all_segments_weighted":
                out["text"] = answer
                out["text_weights"] = 1.0 / tf.cast(tf.shape(answer)[-1], tf.float32)
            elif answer_mode == "all":
                # NOTE(review): only the 1-D branch uses the seeded
                # stateless_shuffle; the ragged and dense branches use unseeded
                # tf.random.shuffle, so their order is not reproducible.
                if len(answer.shape) == 1:
                    answer = stateless_shuffle(answer, seed)
                    answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
                elif isinstance(answer, tf.RaggedTensor):
                    n = tf.shape(answer)[0]
                    answer_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=())
                    for i in range(n):
                        answer_arr = answer_arr.write(i, tf.strings.reduce_join(tf.random.shuffle(answer[i]), separator="\n", axis=-1))
                    answer = answer_arr.stack()
                else:
                    answer = tf.map_fn(tf.random.shuffle, answer)
                    answer = tf.strings.reduce_join(answer, separator="\n", axis=-1)
            else:
                raise NotImplementedError()
            out["text"] = answer
        return out
    return _extract(ds)
@seqio.map_over_dataset()
def extract_khan_academy(ex):
    """Use a fixed "Answer this question" prompt with ``gptResponse`` as the target."""
    return {
        "image": ex["image"],
        "image_url": ex["image_url"],
        "prompt": "Answer this question",
        "text": ex["gptResponse"],
    }
@seqio.map_over_dataset()
def extract_vaia_qa_latex_image(ex, add_short_answer=False, set_short_answer_first=False):
    """Format a VAIA QA example with a LaTeX question and an optional image.

    ``image`` always has a leading image axis (rank 4). It holds the one
    decoded image when ``has_image`` is true, and zero images otherwise.

    Args:
      ex: a VAIA example.
      add_short_answer: if True, add "Answer: <short_answer>" to ``text``.
      set_short_answer_first: with ``add_short_answer``, put the short answer
        before the full answer instead of after it.

    Returns:
      dict with ``image``, ``text``, ``prompt`` (the LaTeX question plus a
      newline) and metadata, including ``metadata/image_size`` as ``[w, h]``.
    """
    if ex["has_image"]:
        image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
        image = tf.expand_dims(image, 0)[:1]
    else:
        # image = get_blank_image() # blank image
        # Presumably decoded anyway so both branches yield a tensor of the
        # same rank/dtype; [:0] keeps zero images.
        # NOTE(review): assumes ex['image'] holds decodable bytes even when
        # has_image is false -- confirm.
        image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
        image = tf.expand_dims(image, 0)[:0]
    img_h = tf.shape(image)[1]
    img_w = tf.shape(image)[2]
    if add_short_answer:
        if set_short_answer_first:
            answer = tf.strings.join(["Answer: ", ex["short_answer"], "\n\n", ex["answer"]])
        else:
            answer = tf.strings.join([ex["answer"], "\n\n", "Answer: ", ex["short_answer"]])
    else:
        answer = ex["answer"]
    out = dict(
        image=image, # 4-d tensor
        text=answer,
        prompt=tf.strings.join([ex["latex_question"], "\n"]),
    )
    out["metadata/images"] = image
    out.update(_add_metadata(ex))
    out["metadata/batch_id"] = ex["batch_id"]
    out["metadata/image_size"] = [img_w, img_h]
    return out
@seqio.map_over_dataset()
def extract_vqa_online(ex):
    """Use the question plus a newline as ``prompt`` and the answer as ``text``; keep ``row_id``."""
    prompt = tf.strings.join([ex["question"], "\n"])
    out = {"image": ex["image"], "prompt": prompt, "text": ex["answer"]}
    out.update(_add_metadata(ex))
    out["metadata/row_id"] = ex["row_id"]
    return out
@seqio.map_over_dataset()
def extract_scifi_joint(ex):
    """Build a multi-turn example: a detailed-description turn, then every QA pair.

    The first turn asks "Describe this image in detail." and is answered by
    ``summary``. QA data comes from ``qa_pairs`` when present, otherwise ``qa``.
    """
    qa = ex["qa_pairs"] if "qa_pairs" in ex else ex["qa"]
    prompts = tf.concat([["Describe this image in detail."], qa["question"]], 0)
    responses = tf.concat([tf.expand_dims(ex["summary"], 0), qa["answer"]], 0)
    return {"image": ex["image"], "prompt": prompts, "text": responses}
def remove_no_qa(ds):
    """Drop examples with no questions in ``qa_pairs`` (or ``qa`` when that is absent)."""
    def _has_qa(ex):
        qa = ex["qa_pairs"] if "qa_pairs" in ex else ex["qa"]
        return tf.shape(qa["question"])[0] > 0
    return ds.filter(_has_qa)
@seqio.map_over_dataset()
def extract_scifi_qa_exp(ex):
    """Target "<explanation> Answer: <answer>" for each question."""
    answer = tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]])
    return {
        "image": ex["image"],
        "question": ex["question"],  # Array of questions
        "answer": answer,
    }
@seqio.map_over_dataset(num_seeds=1)
def extract_scifi_qa_demo(ex, seed):
    """Target "<explanation> Answer: <answer>" for each question; ``seed`` is unused."""
    answer = tf.strings.join([ex["explanation"], " Answer: ", ex["answer"]])
    return {
        "image": ex["image"],
        "question": ex["question"],  # Array of questions
        "answer": answer,
    }
@seqio.map_over_dataset()
def clock_bench_preprocessor(ex):
    """Ask "What time is being shown?" and keep the ground-truth time fields as metadata."""
    out = {"image": ex["image"], "prompt": "What time is being shown?"}
    out.update({f"metadata/{k}": ex[k] for k in ("hour", "minute", "second", "answerable")})
    return out
def deg2rad(x):
    """Convert an angle from degrees to radians (works for numbers and tensors)."""
    return (x * math.pi) / 180.0
def get_affine_matrix(center, angle, translate, scale, shear):
    """Build a 2D affine matrix for rotation, scale and shear about ``center``, plus translation.

    The matrix is ``T * C * RSS * C^-1``: ``C`` translates by ``center``, ``T``
    by ``translate``, and ``RSS`` combines rotation, shear and ``scale``.

    Args:
      center: ``[cx, cy]`` pivot of the rotation/shear/scale.
      angle: rotation in degrees.
      translate: ``[tx, ty]`` translation.
      scale: factor multiplied into every entry of the rotation/shear part.
      shear: ``[sx, sy]`` shear angles in degrees.

    Returns:
      A Python list ``[a, b, tx', c, d, ty']``: the first two rows of the 3x3
      matrix, flattened row-major (entries 2 and 5 are the translation terms).
    """
    # From https://github.com/pytorch/vision/blob/f96c42fca53230057b16941b078a0a9eee06e20f/torchvision/transforms/functional.py#L1006
    rot = deg2rad(angle)
    sx = deg2rad(shear[0])
    sy = deg2rad(shear[1])
    cx, cy = center
    tx, ty = translate
    # RSS without scaling
    a = tf.cos(rot - sy) / tf.cos(sy)
    b = -tf.cos(rot - sy) * tf.tan(sx) / tf.cos(sy) - tf.sin(rot)
    c = tf.sin(rot - sy) / tf.cos(sy)
    d = -tf.sin(rot - sy) * tf.tan(sx) / tf.cos(sy) + tf.cos(rot)
    matrix = [a, b, 0.0, c, d, 0.0]
    matrix = [x * scale for x in matrix]
    # Apply inverse of center translation: RSS * C^-1
    matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
    matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
    # Apply translation and center : T * C * RSS * C^-1
    matrix[2] += cx + tx
    matrix[5] += cy + ty
    return matrix
def quantize_point(coor, max_dim, mode="percent-precision-1"):
    """Normalize ``coor`` by ``max_dim`` and format the result as a string.

    Modes:
      * "percent-precision-1": percentage with one decimal place.
      * "zero_to_one": fraction with three decimal places.
      * "1k": fraction times 1000 with no decimal places.

    Raises:
      NotImplementedError: for an unknown ``mode``.
    """
    fraction = tf.cast(coor, tf.float32) / tf.cast(max_dim, tf.float32)
    if mode == "percent-precision-1":
        return tf.strings.as_string(fraction * 100, precision=1)
    if mode == "zero_to_one":
        return tf.strings.as_string(fraction, precision=3)
    if mode == "1k":
        return tf.strings.as_string(fraction * 1000, precision=0)
    raise NotImplementedError(mode)
def construct_pointing_format(label_text, alt_text, x_str, y_str):
    """Render already-formatted point coordinates as a pointing tag.

    * no points: ``""``
    * one point: ``<point x="X" y="Y" alt="ALT">LABEL</point>``
    * several: ``<points x1="X1" y1="Y1" x2="X2" y2="Y2" ... alt="ALT">LABEL</points>``

    Args:
      label_text: text placed inside the tag.
      alt_text: value of the ``alt`` attribute; ``None`` means use ``label_text``.
      x_str: 1-D string tensor of x coordinates.
      y_str: 1-D string tensor of y coordinates, aligned with ``x_str``.
    """
    if alt_text is None:
        alt_text = label_text
    # Named `n_points` (not `np`) so the module-level numpy import is not shadowed.
    n_points = tf.shape(x_str)[0]
    if n_points == 0:
        output = ""
    elif n_points == 1:
        output = tf.strings.join([
            '<point x="', x_str[0], '" y="', y_str[0], '" alt="',
            alt_text, '">', label_text, '</point>'
        ])
    else:
        ids = tf.strings.as_string(tf.range(1, n_points + 1, dtype=tf.int32))
        xs = tf.strings.join(["x", ids, '="', x_str, '"'])
        ys = tf.strings.join(["y", ids, '="', y_str, '"'])
        # Interleave as: x1="..." y1="..." x2="..." y2="..." ...
        points = tf.strings.reduce_join(tf.reshape(tf.stack([xs, ys], 1), [-1]), separator=' ', axis=-1)
        output = tf.strings.join(
            ["<points ", points, ' alt="', alt_text, '">', label_text, "</points>"])
    return output
def order_points(x, y, seed, point_order):
    """Reorder parallel coordinate vectors ``x`` and ``y``.

    ``point_order`` is one of:
      * "natural": unchanged.
      * "random": a stateless permutation drawn from ``seed``.
      * "xy" / "yx": sorted by x then y / y then x. The coordinates are parsed
        with ``tf.strings.to_number``, so they must be numeric strings here.

    Raises:
      NotImplementedError: for an unknown ``point_order``.
    """
    if point_order == "natural":
        return x, y
    if point_order == "random":
        ix = stateless_permutation(tf.shape(x)[0], seed)
    elif point_order in ("xy", "yx"):
        x_num, y_num = tf.strings.to_number(x), tf.strings.to_number(y)
        major, minor = (x_num, y_num) if point_order == "xy" else (y_num, x_num)
        ix = tf.argsort(major*100000 + minor)
    else:
        raise NotImplementedError(point_order)
    return tf.gather(x, ix), tf.gather(y, ix)
@gin.configurable()
def points_to_text(x, y, w, h, seed, label=None, alt_text=None, point_mode="percent-precision-1",
                   point_order="xy", point_list_mode="tag"):
    """Returns a string encoding of a list of points.

    Coordinates are normalized by ``w``/``h`` and quantized with
    ``quantize_point`` (``point_mode``), ordered with ``order_points``
    (``point_order``), then rendered as a pointing tag
    (``point_list_mode="tag"``, see ``construct_pointing_format``) or as
    ``"(x1, y1), (x2, y2), ..."`` (``point_list_mode="paren"``).

    Args:
      x, y: 1-D tensors of point coordinates, in the same units as ``w``/``h``.
      w, h: extents used to normalize ``x`` and ``y``.
      seed: stateless seed; only used when ``point_order == "random"``.
      label, alt_text: tag label and alt text for the "tag" format.

    Raises:
      NotImplementedError: for an unknown ``point_list_mode`` (the helpers
        raise it for unknown ``point_mode``/``point_order``).
    """
    x = quantize_point(x, w, point_mode)
    y = quantize_point(y, h, point_mode)
    # Order the quantized points so the order matches what was generated. This
    # matters when points have the same quantized value, e.g. (10.001, 20) and
    # (10.002, 10) should be represented as (10, 10), (10, 20), but sorting
    # before quantization gives (10, 20), (10, 10).
    x, y = order_points(x, y, seed, point_order)
    if point_list_mode == "tag":
        return construct_pointing_format(label, alt_text, x, y)
    elif point_list_mode == "paren":
        return tf.strings.reduce_join(tf.strings.join([
            "(", x, ", ", y, ")"
        ]), separator=", ")
    else:
        raise NotImplementedError(point_list_mode)
def points_to_answer(x, y, w, h, seed, label, is_counting, alt_text=None):
    """Render points as an answer string.

    With no points the answer is "There are none.". Otherwise it is the
    ``points_to_text`` encoding, wrapped as "Counting the <points> shows a total
    of <n>." when ``is_counting`` is set.
    """
    n_points = tf.shape(x)[0]
    if not is_counting:
        if n_points == 0:
            return "There are none."
        else:
            return points_to_text(x, y, w, h, seed, label, alt_text)
    if n_points == 0:
        return "There are none."
    else:
        listed = points_to_text(x, y, w, h, seed, label, alt_text)
        return tf.strings.join([
            "Counting the ", listed,
            " shows a total of ",
            tf.strings.as_string(n_points),
            "."
        ])
@seqio.map_over_dataset(num_seeds=2)
def extract_point_qa(ex, seeds, answer_type="y_major"):
    """Build pointing QA conversations for one image.

    Each question's ``answer_with_placeholders`` has one ``<|POINT|>``
    placeholder per annotation. Placeholders are replaced in order by that
    annotation's pointing text (see ``points_to_answer``), and an assertion
    checks that none remain. Points are passed with width/height 100, i.e.
    presumably already on a 0-100 scale.

    Args:
      ex: example with encoded ``image`` bytes and a ``questions`` dict holding
        ``question``, ``answer_with_placeholders`` and ragged ``annotations``
        (``points``, ``point_text``, ``alt_text``).
      seeds: two stateless seeds. ``seeds[0]`` seeds the per-annotation point
        ordering; ``seeds[1]`` is currently unused.
      answer_type: unused; kept for backward compatibility.

    Returns:
      dict with the decoded ``image``, ragged ``messages`` (one
      [question, answer] row per question), metadata from ``_add_metadata``,
      and ``metadata/image_size`` as ``[w, h]``.
    """
    ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    questions = ex["questions"]
    question = questions["question"]
    anno = questions["annotations"]
    n = tf.shape(question)[0]
    answers = tf.TensorArray(tf.string, size=n, element_shape=())
    # One seed per annotation, with the same ragged layout as the annotations
    # so it can be indexed by [question_ix, anno_ix].
    all_point_text = anno["point_text"]
    point_seeds = tf.RaggedTensor.from_row_splits(
        row_splits=all_point_text.row_splits,
        values=tf.random.split(seeds[0], num=tf.shape(all_point_text.values)[0])
    )
    for question_ix in range(n):
        answer = questions["answer_with_placeholders"][question_ix]
        n_anno = tf.shape(anno["point_text"][question_ix])[0]
        for anno_ix in range(n_anno):
            points = anno["points"][question_ix, anno_ix]
            point_text = points_to_answer(
                points[:, 0], points[:, 1], 100, 100,
                point_seeds[question_ix, anno_ix],
                anno["point_text"][question_ix, anno_ix],
                False,
                alt_text=anno["alt_text"][question_ix, anno_ix],
            )
            # Replace only the first remaining placeholder.
            answer_split = tf.strings.split(answer, sep="<|POINT|>", maxsplit=1)
            answer = tf.strings.join([answer_split[0], point_text, answer_split[1]])
        # Make sure all placeholders were used
        tf.debugging.assert_equal(tf.shape(tf.strings.split(answer, sep="<|POINT|>"))[0], 1)
        answers = answers.write(question_ix, answer)
    # Interleave as [q0, a0, q1, a1, ...] and group each pair into one row.
    messages = tf.stack([question, answers.stack()], axis=1)
    messages = tf.reshape(messages, [-1])
    conversation_ids = tf.range(tf.shape(messages)[0] // 2, dtype=tf.int32)
    conversation_ids = tf.repeat(conversation_ids, 2)
    out = dict(
        image=ex["image"],
        messages=tf.RaggedTensor.from_value_rowids(messages, conversation_ids)
    )
    out.update(_add_metadata(ex))
    out["metadata/image_size"] = [img_w, img_h]
    return out
def select_point(mask):
    """For each mask, return the mask pixel closest to the mask's centroid.

    Args:
      mask: ``[batch, h, w]`` tensor. Values are cast to float32 and used as
        per-pixel weights (presumably 0/1 -- confirm).

    Returns:
      ``(cx, cy)``: float32 tensors of shape ``[batch]`` holding the column and
      row index of the chosen pixel.

    NOTE(review): an all-zero mask gives a zero pixel count, so the centroid is
    NaN and the returned point is not meaningful -- confirm callers never pass
    empty masks.
    """
    bs = tf.shape(mask)[0]
    valid = tf.cast(mask, tf.float32)
    h, w = tf.shape(mask)[1], tf.shape(mask)[2]
    ys = tf.range(h, dtype=tf.int32)
    xs = tf.range(w, dtype=tf.int32)
    n = tf.reduce_sum(valid, [1, 2])
    # Weighted centroid of each mask.
    cy = tf.reduce_sum(tf.cast(ys[None, :, None], tf.float32) * valid, [1, 2]) / n # [bs]
    cx = tf.reduce_sum(tf.cast(xs[None, None, :], tf.float32) * valid, [1, 2]) / n # [bs]
    dist_y = tf.square(tf.range(h, dtype=tf.float32)[None, :] - cy[:, None]) # [bs, h]
    dist_x = tf.square(tf.range(w, dtype=tf.float32)[None, :] - cx[:, None]) # [bs, w]
    dist = dist_y[:, :, None] + dist_x[:, None, :] # [batch, h, w]
    # Push pixels outside the mask far away so the nearest pixel lies inside
    # it; this matters when the centroid falls outside a non-convex mask.
    dist = dist + (1 - valid) * 1e12
    min_dist = tf.argmin(tf.reshape(dist, [bs, -1]), axis=-1) # [batch]
    # Convert the flat index back to (row, column).
    w = tf.cast(w, min_dist.dtype)
    cy = tf.cast(min_dist // w, tf.float32)
    cx = tf.cast(min_dist % w, tf.float32)
    return cx, cy
@seqio.map_over_dataset
def refexp_pointing(ex):
    """Build pointing targets for referring expressions.

    Objects are shuffled (unseeded), one point per object is chosen from its
    mask with ``select_point``, and every referring expression of an object
    becomes one ``refexp`` entry, with that object's answer repeated in ``text``.

    NOTE(review): the ``points_to_text`` call below does not match the
    ``points_to_text(x, y, w, h, seed, ...)`` signature defined in this file.
    The arguments are passed in (h, w, x, y) order and no ``seed`` is given,
    so tracing this should raise a ``TypeError``. The code also expects one
    answer per object (see ``tf.repeat``), whereas ``points_to_text`` returns a
    single string for all points. Needs fixing before use.
    """
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    objects = ex["objects"]
    # Shuffle objects so what object gets truncated if the sequence gets truncated is randomized
    refexps = objects['refexp']['raw']
    # NOTE(review): bbox is shuffled below but never used in the output.
    bbox = objects["bbox"]
    mask = tf.squeeze(objects["mask"], -1)
    ix = tf.range(0, tf.shape(refexps)[0], dtype=tf.int32)
    ix = tf.random.shuffle(ix)
    refexps = tf.gather(refexps, ix)
    bbox = tf.gather(bbox, ix)
    mask = tf.gather(mask, ix)
    cx, cy = select_point(mask)
    answers = points_to_text(img_h, img_w, cx, cy)
    out = {
        "image": ex["image"],
        "refexp": refexps.values,
        "metadata/image_size": tf.stack([img_w, img_h,]),
        "text": tf.repeat(answers, refexps.row_lengths()),
    }
    if "image_url" in ex:
        out["metadata/image_url"] = ex["image_url"]
    return out
@seqio.map_over_dataset
def refexp_pointing_inf(ex):
    """Build an evaluation example for referring-expression pointing.

    Keeps all referring expressions, boxes, serialized masks, the
    ``select_point``-based answer and any existing ``metadata/...`` features.

    NOTE(review): the ``points_to_text`` call below does not match the
    ``points_to_text(x, y, w, h, seed, ...)`` signature defined in this file:
    the arguments are in (h, w, x, y) order and ``seed`` is missing, so tracing
    this should raise a ``TypeError``. Needs fixing before use.
    """
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    objects = ex["objects"]
    mask = tf.squeeze(objects["mask"], -1)
    cx, cy = select_point(mask)
    answers = points_to_text(img_h, img_w, cx, cy)
    refexps = objects["refexp"]["raw"]
    # We can't use `mask` directly since it is variable size, and thus it
    # will break batching. Here we serialize it instead
    serialized_masks = tf.map_fn(tf.io.serialize_tensor, mask, fn_output_signature=tf.string)
    out = {
        "image": ex["image"],
        "refexp": refexps,
        "metadata/bbox": objects["bbox"],
        "metadata/answer": answers,
        "metadata/mask": serialized_masks,
        "metadata/image_size": tf.stack([img_w, img_h]),
    }
    out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
    return out
@seqio.map_over_dataset
def extract_andriod_control_inf(ex, mode):
    """Build a single Android-control evaluation prompt for ``mode``.

    ``mode`` is one of "ll" (low-level instruction), "hl_ll" (high- and
    low-level), "hl" (high-level) or "hl_cot" (high-level, chain-of-thought
    prompt). The target is ``metadata/target_action``.

    Raises:
      NotImplementedError: for an unknown ``mode``.
    """
    templates = {
        "ll": [("low_level: ", "metadata/ll_instruction")],
        "hl_ll": [("high_level: ", "metadata/hl_instruction"),
                  (" low_level: ", "metadata/ll_instruction")],
        "hl": [("high_level: ", "metadata/hl_instruction")],
        "hl_cot": [("high_level_cot: ", "metadata/hl_instruction")],
    }
    if mode not in templates:
        raise NotImplementedError()
    pieces = []
    for literal, key in templates[mode]:
        pieces += [literal, ex[key]]
    out = {
        "image": ex["image"],
        "prompt": tf.strings.join(pieces),
        "text": ex["metadata/target_action"],
    }
    out.update(_add_metadata(ex))
    return out
@seqio.map_over_dataset
def extract_android_control(ex):
    """Turn one Android-control example into four prompt/target pairs.

    All four share the same image:
      1. low-level instruction -> action
      2. high-level + low-level instruction -> action
      3. high-level instruction -> action
      4. high-level instruction -> "Plan: <low-level> Action: <action>" (CoT)
    """
    hl = ex["metadata/hl_instruction"]
    ll = ex["metadata/ll_instruction"]
    action = ex["metadata/target_action"]
    prompts = tf.stack([
        tf.strings.join(["low_level: ", ll]),
        tf.strings.join(["high_level: ", hl, " low_level: ", ll]),
        tf.strings.join(["high_level: ", hl]),
        tf.strings.join(["high_level_cot: ", hl]),
    ])
    targets = tf.stack([
        action,
        action,
        action,
        tf.strings.join(["Plan: ", ll, " Action: ", action]),
    ])
    out = dict(image=ex["image"], prompt=prompts, text=targets)
    out.update(_add_metadata(ex))
    return out
@seqio.map_over_dataset(num_seeds=1)
def refexp(ex, seed):
    """Build box-localization targets for referring expressions.

    Objects are permuted with a stateless seed. Every referring expression of
    an object becomes one ``refexp`` entry, with that object's box answer
    (top-left and bottom-right corners) repeated in ``text``.

    NOTE(review): the ``points_to_text`` call below does not match the
    ``points_to_text(x, y, w, h, seed, ...)`` signature defined in this file:
    the arguments are in (h, w, x, y) order and ``seed`` is missing, so tracing
    this should raise a ``TypeError``. The reshape to ``[-1, 2]`` also expects
    one string per point, whereas ``points_to_text`` returns a single string.
    Needs fixing before use.
    """
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    objects = ex["objects"]
    # Shuffle objects so what object gets truncated if the sequence gets truncated is randomized
    refexps = objects['refexp']['raw']
    bbox = objects["bbox"]
    ix = stateless_permutation(tf.shape(refexps)[0], seed)
    refexps = tf.gather(refexps, ix)
    bbox = tf.gather(bbox, ix)
    # Boxes appear to be stored as [x, y, width, height]; compute the far corner.
    x2 = bbox[:, 0] + bbox[:, 2]
    y2 = bbox[:, 1] + bbox[:, 3]
    # NOTE(review): reduce_any only checks that at least one box fits inside
    # the image; reduce_all may have been intended.
    with tf.control_dependencies([
        tf.debugging.assert_equal(tf.reduce_any(x2 <= tf.cast(img_w, tf.float32)), True),
        tf.debugging.assert_equal(tf.reduce_any(y2 <= tf.cast(img_h, tf.float32)), True)
    ]):
        answers = points_to_text(
            img_h, img_w,
            tf.reshape(tf.stack([bbox[:, 0], x2], 1), [-1]),
            tf.reshape(tf.stack([bbox[:, 1], y2], 1), [-1]))
        answers = tf.strings.reduce_join(tf.reshape(answers, [-1, 2]), separator=" ", axis=1)
    out = {
        "image": ex["image"],
        "refexp": refexps.values,
        "metadata/bbox": bbox,
        "metadata/image_size": tf.stack([img_w, img_h,]),
        "text": tf.repeat(answers, refexps.row_lengths()),
    }
    if "image_url" in ex:
        out["image_url"] = ex["image_url"]
    return out
@seqio.map_over_dataset
def refexp_inf(ex):
    """Build an evaluation example with all referring expressions and boxes, keeping metadata."""
    objects = ex["objects"]
    image_shape = tf.shape(ex["image"])
    out = {
        "image": ex["image"],
        "refexp": objects["refexp"]["raw"],
        "metadata/bbox": objects["bbox"],
        "metadata/image_size": tf.stack([image_shape[1], image_shape[0]]),
    }
    out.update({k: v for k, v in ex.items() if k.startswith("metadata/")})
    return out
def point_text_interleaved(*args):
    """Placeholder for rendering text with interleaved points; not implemented.

    Raises:
      NotImplementedError: on every call, whatever the arguments.
    """
    raise NotImplementedError()
@seqio.map_over_dataset
def web_pointing_preprocessor(ex):
    """Format a web-pointing example with points interleaved into question/answer text.

    The answer points are also serialized into ``metadata/answer_points``.

    NOTE(review): ``point_text_interleaved`` in this file always raises
    ``NotImplementedError``, so this preprocessor currently cannot be used.
    """
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    question = point_text_interleaved(
        img_h, img_w, ex["question"], ex["question_points"]["x"], ex["question_points"]["y"])
    answer = point_text_interleaved(
        img_h, img_w, ex["answer"], ex["answer_points"]["x"], ex["answer_points"]["y"])
    answer_points = tf.stack([ex["answer_points"]["x"], ex["answer_points"]["y"]], axis=1)
    return {
        "question": question,
        "answer": answer,
        "image": ex["image"],
        "metadata/image_size": [img_w, img_h],
        "metadata/question_type": ex["question_type"],
        # Serialized so the variable-length point list does not break batching
        # (presumably, as done for masks elsewhere in this file).
        "metadata/answer_points": tf.io.serialize_tensor(answer_points),
        "metadata/answer": answer,
    }
def filter_pointing(ds):
    """Keep examples that have at least one answer point."""
    def _has_points(ex):
        return tf.shape(ex["answer_points"]["x"])[0] >= 1
    return ds.filter(_has_points)
def filter_qa(ds):
    """Keep examples that have no answer points."""
    def _no_points(ex):
        n_points = tf.shape(ex["answer_points"]["x"])[0]
        return n_points == 0
    return ds.filter(_no_points)
# vaia filtering
def filter_image_only(ds):
    """Keep only examples whose ``has_image`` flag is true."""
    def _has_image(ex):
        return ex["has_image"]
    return ds.filter(_has_image)
def filter_mc(ds):
    """Keep only multiple-choice examples (``is_mc`` set)."""
    def _is_multiple_choice(example):
        return example["is_mc"]
    return ds.filter(_is_multiple_choice)
def remove_is_long(ds):
    """Drop examples whose boolean ``is_long`` flag is set.

    Uses ``tf.logical_not`` instead of Python ``not``. A bare ``not`` on a
    symbolic tensor only works if AutoGraph successfully converts the lambda;
    otherwise it raises in graph mode.
    """
    return ds.filter(lambda ex: tf.logical_not(ex["is_long"]))
def remove_has_multiple_parts(ds):
    """Drop examples whose boolean ``has_multiple_parts`` flag is set.

    Uses ``tf.logical_not`` instead of Python ``not``. A bare ``not`` on a
    symbolic tensor only works if AutoGraph successfully converts the lambda;
    otherwise it raises in graph mode.
    """
    return ds.filter(lambda ex: tf.logical_not(ex["has_multiple_parts"]))
def _split(ds: tf.data.Dataset, keys, n_splits=2):
    """Split every example into `n_splits` examples along the leading axis of `keys`.

    The fields named in `keys` are sliced into `n_splits` contiguous chunks. The
    first `n % n_splits` chunks get one extra element. All other fields are
    copied unchanged into every chunk. Examples whose first key has fewer than
    `n_splits` entries are passed through unsplit.

    Args:
      ds: dataset of dict examples.
      keys: field names to slice. A key may be an `(outer, inner)` tuple to slice
        a field of a nested dict, e.g. `("answer_points", "x")`. The keyed fields
        are assumed to share the leading length of `keys[0]`.
      n_splits: number of chunks per example.

    Returns:
      The flattened dataset of chunks.
    """
    def _get(ex, key):
        # Resolve a plain key or an (outer, inner) nested key.
        if isinstance(key, tuple):
            assert len(key) == 2
            return ex[key[0]][key[1]]
        return ex[key]

    def _map(ex):
        n = tf.shape(_get(ex, keys[0]))[0]
        if n < n_splits:
            return tf.data.Dataset.from_tensors(ex)
        else:
            bs = n // n_splits
            remainder = n - bs*n_splits
            # Chunk sizes: the first `remainder` chunks hold bs + 1 elements, the rest bs.
            lens = tf.concat([
                tf.ones([remainder], dtype=tf.int32),
                tf.zeros([n_splits-remainder], dtype=tf.int32),
            ], axis=0) + bs
            tf.debugging.assert_equal(tf.reduce_sum(lens), n)
            ends = tf.cumsum(lens)
            parts = []
            for split_ix in range(n_splits):
                part_ex = dict(ex)
                e = ends[split_ix]
                s = e - lens[split_ix]
                for k in keys:
                    if isinstance(k, tuple):
                        assert len(k) == 2
                        # `dict(ex)` is shallow. Copy the nested dict before writing
                        # so the slice neither overwrites `ex` (which later chunks
                        # still slice from) nor leaks into the other chunks.
                        part_ex[k[0]] = dict(part_ex[k[0]])
                        part_ex[k[0]][k[1]] = ex[k[0]][k[1]][s:e]
                    else:
                        part_ex[k] = ex[k][s:e]
                parts.append(part_ex)
            out = tf.data.Dataset.from_tensors(parts[0])
            for part in parts[1:]:
                out = out.concatenate(tf.data.Dataset.from_tensors(part))
            return out
    return ds.flat_map(_map)
def split(ds, n=2):
    """Split examples into ``n`` pieces along the leading axis of their per-item fields.

    Only the fields among ``question``, ``label``, ``text``, ``entity`` and
    ``messages`` that exist in ``ds.element_spec`` are sliced; see ``_split``.
    """
    candidate_keys = ("question", "label", "text", "entity", "messages")
    present_keys = [key for key in candidate_keys if key in ds.element_spec]
    return _split(ds, present_keys, n_splits=n)
def split_points(ds, max_points=50):
    """Split pointing examples in two along their question axis.

    The question text field ("question" if present, otherwise "label"),
    "notInImage" and the nested ``answer_points`` x/y coordinates are sliced
    together.

    Args:
      ds: dataset of pointing examples.
      max_points: currently unused; kept for interface compatibility.
    """
    label = "question" if "question" in ds.element_spec else "label"
    # Only the chosen text field is listed. Always listing "question" raised
    # KeyError for datasets that only have "label".
    return _split(ds, [
        label, "notInImage",
        ("answer_points", "x"),
        ("answer_points", "y"),
    ])
@seqio.map_over_dataset
def fix_count_qa(ex):
    """Keep every other ``label`` entry and check it matches the answer-point count.

    NOTE(review): presumably ``label`` stores two entries per question. The
    assertion checks that the reduced length equals the leading dimension of
    ``answer_points["x"]``.
    """
    deduped_labels = ex["label"][::2]
    ex["label"] = deduped_labels
    num_point_sets = tf.shape(ex["answer_points"]["x"])[0]
    tf.debugging.assert_equal(num_point_sets, tf.shape(deduped_labels)[0])
    return ex
def filter_points(ds, max_number=40):
    """Mark usable per-question point sets and drop examples that have none.

    A row of ``answer_points`` is valid when every x and y lies in [0, 100] and
    it has at most ``max_number`` points. The per-row mask is stored as
    ``ex["valid"]``. Examples where no row is valid are removed.
    """
    def _in_range(values):
        return (tf.reduce_all(values >= 0.0, axis=-1) &
                tf.reduce_all(values <= 100.0, axis=-1))

    def _add_valid(ex):
        xs = ex["answer_points"]["x"]
        ys = ex["answer_points"]["y"]
        ex["valid"] = _in_range(xs) & _in_range(ys) & (ys.row_lengths() <= max_number)
        return ex

    return ds.map(_add_valid).filter(lambda ex: tf.reduce_any(ex["valid"]))
# def filter_points(ds, max_number=30):
# n_points = ds["answer_points"]["x"].row_lengths()
# parts = tf.TensorArray(tf.int32, size=tf.shape(n_points[0]), element_shape=tf.TensorShape([None]))
# total = 0
# on_row = 0
# for i in range(n_points):
# n = n_points[i]
# if n > max_number:
# continue
# if n + total > max_number:
#
# return ds
@seqio.map_over_dataset(num_seeds=2)
def pointing_preprocessor(ex, sequence_length, seeds, with_count=False):
    """Build a multi-question pointing training example.

    Only rows flagged in `ex["valid"]` (as produced by `filter_points`) are kept,
    and their order is shuffled with `seeds[0]`. For each kept question, the
    ragged x/y points are rendered to answer text with `points_to_answer`. The
    width and height passed in are 100, 100 (presumably because points are on a
    0-100 scale). Each question gets its own seed split from `seeds[1]`.

    Args:
      ex: example with an encoded "image", "valid", either "label" (lower-cased
        here) or "question", "answer_points" {"x", "y"} and "notInImage".
      sequence_length: unused; part of the seqio preprocessor interface.
      seeds: two stateless seeds.
      with_count: forwarded to `points_to_answer`.

    Returns:
      Dict with the decoded image, its [width, height], the kept questions (as
      both "entity" and "question") and one answer string per question in "text".
    """
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(image)[0]
    img_w = tf.shape(image)[1]
    ix = tf.where(ex["valid"])[:, 0]
    ix = stateless_shuffle(ix, seeds[0])
    if "label" in ex:
        question = tf.strings.lower(ex["label"])
    else:
        question = ex["question"]
    question = tf.gather(question, ix)  # [n_question]
    points_x = tf.gather(ex["answer_points"]["x"], ix)  # [n_question, n_points[ragged]]]
    points_y = tf.gather(ex["answer_points"]["y"], ix)
    # NOTE(review): `not_in_image` is computed but never used below.
    not_in_image = tf.gather(ex["notInImage"], ix)  # [n_question]
    n = tf.shape(points_x)[0]
    point_text = tf.TensorArray(dtype=tf.string, size=n, element_shape=())  # [n_question]
    point_seeds = tf.random.split(seeds[1], n)
    # `n` is a tensor, so AutoGraph presumably turns this into a graph loop.
    for i in range(n):
        answer = points_to_answer(points_x[i], points_y[i], 100, 100, point_seeds[i], question[i], with_count)
        point_text = point_text.write(i, answer)
    return {
        "image": image,
        "metadata/image_size": [img_w, img_h],
        "entity": question,
        "question": question,
        "text": point_text.stack(),
    }
@seqio.map_over_dataset
def pointing_inf_preprocessor(ex):
    """Build a pointing inference example with serialized ground truth.

    Decodes the image (also writing it back to ``ex["image"]``), uses the
    question as both prompt and entity, and stores as metadata:
    whether there are no answer points, the serialized masks, the question,
    the answer points converted to pixels, and the [width, height] image size.
    """
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    ex["image"] = image
    shape = tf.shape(image)
    img_h, img_w = shape[0], shape[1]
    question = ex["question"]
    xs = ex["answer_points"]["x"]
    ys = ex["answer_points"]["y"]

    out = dict(image=image, question=question, entity=question)
    out.update(_add_metadata(ex))
    out["metadata/not_in_image"] = tf.shape(xs)[0] == 0
    # `masks` is variable-sized and would break batching, so each mask is
    # serialized and the results are joined into one "|||"-separated string.
    mask_strings = tf.map_fn(tf.io.serialize_tensor, ex["masks"], fn_output_signature=tf.string)
    out["metadata/mask"] = tf.strings.reduce_join(mask_strings, separator="|||")
    out["metadata/question"] = question
    # Points are stored scaled so that 100 spans the image; convert to pixels.
    pixel_x = xs * tf.cast(img_w, tf.float32) / 100.0
    pixel_y = ys * tf.cast(img_h, tf.float32) / 100.0
    out["metadata/answer_points"] = tf.io.serialize_tensor(tf.stack([pixel_x, pixel_y], 1))
    out["metadata/image_size"] = [img_w, img_h]
    return out
@seqio.map_over_dataset(num_seeds=1)
def count_qa_preprocessor_inf(ex, sequence_length, seed):
    """Build an inference example for a "How many ... are ..." counting question.

    The entity is the lower-cased text between the leading "How many " and the
    first " are ", and it must be non-empty. ``metadata/count`` is
    ``ex["answer"]`` parsed with ``tf.strings.to_number`` at its default dtype.
    ``sequence_length`` and ``seed`` are accepted for the seqio interface but
    unused.
    """
    question = ex["question"]
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    shape = tf.shape(image)
    prefix_len = len("How many ")
    remainder = tf.strings.substr(question, prefix_len, tf.strings.length(question) - prefix_len)
    entity = tf.strings.lower(tf.strings.split(remainder, sep=" are ", maxsplit=1)[0])
    tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
    return {
        "image": image,
        "metadata/image_size": [shape[1], shape[0]],
        "metadata/count": tf.strings.to_number(ex["answer"]),
        "question": question,
        "entity": entity,
    }
@seqio.map_over_dataset(num_seeds=1)
def count_qa_preprocessor(ex, sequence_length, seed, with_count=False,
                          for_inference=False):
    """Convert a "How many ... are ..." counting example into a pointing example.

    Args:
      ex: example with "image" (encoded bytes), "question", "answer" (the count
        as a string) and, when not `for_inference`, "point_answer" (text with
        the points).
      sequence_length: unused; part of the seqio preprocessor interface.
      seed: stateless seed passed to `points_to_answer`.
      with_count: forwarded to `points_to_answer`.
      for_inference: if True, return without a target "text" and without
        reading or parsing "point_answer", so that field may be absent.

    Returns:
      Dict with the decoded image, its [width, height], the int32 count, the
      question, the entity and (unless `for_inference`) the target "text".
    """
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(image)[0]
    img_w = tf.shape(image)[1]
    # The entity is the lower-cased text between "How many " and the first " are ".
    entity = tf.strings.substr(
        ex["question"], len("How many "), tf.strings.length(ex["question"]) - len("How many "))
    entity = tf.strings.split(entity, sep=" are ", maxsplit=1)[0]
    entity = tf.strings.lower(entity)
    tf.debugging.assert_equal(tf.strings.length(entity) != 0, True)
    count = tf.strings.to_number(ex["answer"], out_type=tf.int32)
    out = {
        "image": image,
        "metadata/image_size": [img_w, img_h],
        "metadata/count": count,
        "question": ex["question"],
        "entity": entity,
    }
    if for_inference:
        # Target points are not needed at inference, so they are not parsed.
        return out

    # "point_answer" is read as whitespace-separated number triples. A trailing
    # period is dropped, and every character other than digits, '.' and
    # whitespace is removed. The first number of each triple is ignored
    # (presumably an index); the others are x and y.
    numbers_str = tf.strings.regex_replace(ex["point_answer"], r'\.$', '')
    numbers_str = tf.strings.regex_replace(numbers_str, r'[^\d\.\s]+', '')
    numbers_str = tf.strings.strip(numbers_str)
    numbers = tf.strings.split(numbers_str)
    float_numbers = tf.strings.to_number(numbers, out_type=tf.float32)
    coordinates = tf.reshape(float_numbers, (-1, 3))
    points_x = coordinates[:, 1]
    points_y = coordinates[:, 2]
    tf.debugging.assert_equal(count, tf.shape(points_x)[0])
    # points are already normalized so use w=1, h=1
    out["text"] = points_to_answer(points_x, points_y, 1, 1, seed, entity, with_count)
    return out
@gin.configurable()
@seqio.map_over_dataset
def cleanup_preprocessor(ex, preprocess=False):
    """Optionally wrap ``ex["prompt"]`` in the transcript-cleanup chat template.

    With ``preprocess`` False the example is returned untouched.
    """
    if not preprocess:
        return ex
    template_head = "[[User]]: Correct the spelling and punctuation mistakes on the following transcript based on what appears in the image.\n\n{before} "
    template_tail = "\n[[Assistant]]: {after}"
    ex["prompt"] = tf.strings.join([template_head, ex["prompt"], template_tail])
    return ex
@gin.configurable()
@seqio.map_over_dataset
def random_text_preprocessor(ex, preprocess=False):
    """Replace the prompt with a fixed "read the text" question.

    With ``preprocess`` True the question is also wrapped in the
    ``[[User]]: ... [[Assistant]]:`` chat template.
    """
    question = "What does the text say in this image?"
    if preprocess:
        ex["prompt"] = tf.strings.join(["[[User]]: ", question, "\n[[Assistant]]:"])
    else:
        ex["prompt"] = question
    return ex
@seqio.map_over_dataset(num_seeds=25)
def clock_augmentation(ex, seeds):
    """Randomly shear/rotate/scale, crop, translate and colour-jitter a clock image.

    All randomness is stateless and comes from `seeds`. `_rng` pops one seed per
    draw, and each of the four colour ops pops one more.

    NOTE(review): most draws sit inside tensor-valued `if`s. AutoGraph
    presumably turns these into `tf.cond`s that trace *both* branches, so every
    branch's `seeds.pop()` happens at trace time. Counted that way, the function
    pops exactly 25 seeds, matching `num_seeds=25`. Adding a draw anywhere
    likely requires raising `num_seeds`.

    The white fill (`constant_values=1`, `fill_value=1.`) and the `> 0.99`
    whiteness test assume `ex["image"]` is float in [0, 1]. TODO: confirm
    upstream.

    Returns:
      `ex` with "image" replaced by the augmented image.
    """
    seeds = list(seeds)
    image = ex["image"]
    # Apply shear, rotation, and scale through one affine matrix
    height = tf.cast(tf.shape(image)[0], tf.float32)
    width = tf.cast(tf.shape(image)[1], tf.float32)
    _call_id = [0]  # NOTE(review): unused

    def _rng(_minval=0, _maxval=1, shape=(), dtype=tf.float32):
        # Uniform draw in [_minval, _maxval) that consumes the last remaining seed.
        return tf.random.stateless_uniform(shape, seeds.pop(), _minval, _maxval, dtype=dtype)

    sel = _rng(0, 1)
    if sel < 0.1:
        # Straight on
        shear_x = 0.
        shear_y = 0.
        rotation = 0.
    elif sel < 0.5:
        # Normal looking
        shear_x = _rng(-10, 10)
        shear_y = _rng(-10, 10)
        rotation = _rng(-25, 25)
    else:
        # Allowed to be very wonky
        # if tf.random.stateless_uniform((), seeds.pop(), 0, 1) > 0.8:
        #   image = image[:, ::-1]
        if _rng() > 0.5:
            shear_x = _rng( -30, 30)
            shear_y = _rng( -30, 30)
        else:
            shear_x = _rng( -10, 10)
            shear_y = _rng( -10, 10)
        rng = _rng( 0, 1)
        if rng < 0.2:
            rotation = _rng( -25, 25)
        elif rng < 0.6:
            rotation = _rng( -80, 80)
        else:
            rotation = _rng( -180, 180)
    if _rng() > 0.5:
        scale = _rng( 0.3, 2)
    else:
        scale = _rng( 0.3, 1)
    # Pad so upscaling/rotation will not move the image out of bounds
    pad = tf.cast(tf.maximum(height, width)*0.5, tf.int32)
    image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
    # Use the padded size so the transform is centred on the padded image.
    height = tf.cast(tf.shape(image)[0], tf.float32)
    width = tf.cast(tf.shape(image)[1], tf.float32)
    image = tf.keras.ops.image.affine_transform(
        image,
        tf.stack(get_affine_matrix(
            [height/2, width/2],
            rotation,
            [0, 0],
            1/scale,
            [shear_x, shear_y]
        ) + [0., 0.]),
        interpolation='bilinear',
        fill_mode='constant',
        fill_value=1.,
        data_format='channels_last'
    )
    # Crop to the non-white content; otherwise the image could never end up at a corner of the output
    not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
    no_white_ix = tf.where(not_white)
    top_left = tf.reduce_min(no_white_ix, axis=0)
    bottom_right = tf.reduce_max(no_white_ix, axis=0)
    image = tf.image.crop_to_bounding_box(
        image,
        offset_height=tf.cast(top_left[0], tf.int32),
        offset_width=tf.cast(top_left[1], tf.int32),
        target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
        target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
    )
    # Translate by padding with white. With p=0.2 the pads are up to half the
    # image size, otherwise up to twice the image size.
    height, width = tf.shape(image)[0], tf.shape(image)[1]
    translation_seed = _rng(0, 1)
    if translation_seed < 0.2:
        h_pad = _rng(0, height//2, (2,), dtype=tf.int32)
        w_pad = _rng(0, width//2, (2,), dtype=tf.int32)
    else:
        h_pad = _rng(0, height*2, (2,), dtype=tf.int32)
        w_pad = _rng(0, width*2, (2,), dtype=tf.int32)
    # NOTE(review): the h/w pads are interleaved across axes (axis 0 gets
    # h_pad[0] and w_pad[0]); harmless for random amounts but looks unintended.
    image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
                   constant_values=1)
    # Random background color
    # color_rng = tf.random.stateless_uniform((4,), seeds.pop(), 0, 1)
    # random_color = color_rng[:3]
    # valid = tf.reduce_all(tf.reduce_sum(tf.abs(random_color[None, None, :] - image), -1) > 0.03)
    # if color_rng[0] < 0.2 and valid:
    #   image = tf.where(tf.reduce_all(image < 0.99, axis=-1, keepdims=True),
    #                    image, image * 0 + random_color[None, None, :])
    # Mild color jitter
    image = tf.image.stateless_random_hue(image, max_delta=0.05, seed=seeds.pop())
    image = tf.image.stateless_random_brightness(image, max_delta=0.15, seed=seeds.pop())
    image = tf.image.stateless_random_saturation(image, 0.8, 1.2, seed=seeds.pop())
    image = tf.image.stateless_random_contrast(image, 0.8, 1.2, seed=seeds.pop())
    # ex["metadata/unaugmented_image"] = ex["image"]
    ex["image"] = image
    return ex
@seqio.map_over_dataset
def clocks_preprocessor(ex):
    """Turn a rendered-clock example into a "What time is being shown?" example.

    The target text is "The time shown is <time>". <time> follows
    `time_format`: "H:MM" or "H:MM:SS", optionally followed by " AM"/" PM".
    When the format is "The time is not shown", the target is
    "The time is not shown in the image." instead.

    12-hour conversion follows the standard convention: 00:xx is "12:xx AM" and
    12:xx is "12:xx PM".

    Metadata notes:
      - `metadata/hour` reports midnight as 24.
      - For the formats without AM/PM, hours above 12 are reduced by 12.
      - hour/minute/second are -1 when no time is shown.
      - second is -1 when seconds are not shown.

    The bottom 120 rows of the decoded image are removed (a black shadow).
    """
    time_format = ex["time_format"]
    shows_seconds = ex["shows_seconds"]
    hour, minute, second = [tf.cast(ex[k], tf.int32) for k in ["hour", "minute", "second"]]
    if hour == 0:  # Midnight: "12:xx AM"; metadata keeps reporting it as hour 24
        am_pm = "AM"
        hour_str = 12
        hour = 24
    elif hour > 12:
        am_pm = "PM"
        hour_str = hour - 12
    elif hour == 12:  # Noon: "12:xx PM"
        am_pm = "PM"
        hour_str = hour
    else:
        am_pm = "AM"
        hour_str = hour
    hour_str = tf.strings.as_string(hour_str)
    # Zero-pad minutes and seconds to two digits.
    minute_str = tf.strings.as_string(minute)
    if tf.strings.length(minute_str) == 1:
        minute_str = tf.strings.join(["0", minute_str])
    second_str = tf.strings.as_string(second)
    if tf.strings.length(second_str) == 1:
        second_str = tf.strings.join(["0", second_str])
    if time_format == "The time is not shown":
        text = "The time is not shown in the image."
        hour, minute, second = -1, -1, -1
    else:
        if not shows_seconds:
            second = -1
        if time_format == "12 hour clock (without AM/PM)" and shows_seconds:
            if hour > 12:
                hour = hour - 12
            time = tf.strings.join([hour_str, ":", minute_str, ":", second_str])
        elif time_format == "12 hour clock (with AM/PM)" and shows_seconds:
            time = tf.strings.join([hour_str, ":", minute_str, ":", second_str, " ", am_pm])
        elif time_format == "12 hour clock (with AM/PM)" and not shows_seconds:
            time = tf.strings.join([hour_str, ":", minute_str, " ", am_pm])
        elif time_format == "12 hour clock (without AM/PM)" and not shows_seconds:
            if hour > 12:
                hour = hour - 12
            time = tf.strings.join([hour_str, ":", minute_str])
        else:
            time = ""  # Should never occur, but needed for tf analysis
        tf.debugging.assert_equal(tf.strings.length(time) > 0, True)
        text = tf.strings.join(["The time shown is ", time])
    image = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)[:-120]  # remove the black shadow at the bottom
    return {
        "image": image,
        "prompt": "What time is being shown?",
        "text": text,
        "metadata/time_format": time_format,
        "metadata/hour": hour,
        "metadata/minute": minute,
        "metadata/text": text,
        "metadata/second": second,
    }
@seqio.map_over_dataset()
def atlas_obscura_preprocessor(ex):
    """Ask where a photo was taken; the target is "<place> in <city>".

    The target also serves as ``metadata/references``, and the image URL is kept
    as metadata.
    """
    location = tf.strings.join([ex["place"], " in ", ex["city"]])
    return {
        "image": ex["image"],
        "prompt": "Where was this picture taken?",
        "text": location,
        "metadata/image_url": ex["image_url"],
        "metadata/references": location,
    }
@seqio.map_over_dataset()
def famous_birthdays_preprocessor(ex):
    """Ask "Who is this?" with the person's name as both target and reference."""
    name = ex["name"]
    return {
        "image": ex["image"],
        "image_url": ex["image_url"],
        "prompt": "Who is this?",
        "text": name,
        "metadata/references": name,
    }
@seqio.map_over_dataset()
def mild_color_aug_preprocessor(ex):
    """Decode ``ex["image"]`` to 3 channels and apply ``mild_color_aug``.

    ``image_url`` is dropped because the URL would show the un-augmented image.
    """
    ex.pop("image_url", None)
    decoded = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    ex["image"] = mild_color_aug(decoded)
    return ex
def build_text_with_points(text, points, img_h, img_w):
    """Insert point strings into `text` and split the result into segments.

    `points[:, 0]` and `points[:, 1]` are converted with `points_to_text`.
    `text` is split into pieces with `tf.strings.split(text, sep="")`. Piece i
    is followed by point string i, and the last piece gets an empty string from
    the padding; this only works if there is exactly one more piece than
    points. Everything is concatenated and split on "\\n\\n".

    NOTE(review): `tf.strings.split` treats an empty `sep` as "split on
    whitespace", which would drop the spaces and newlines the final "\\n\\n"
    split relies on. The separator was probably a placeholder character that
    got lost; confirm against the data format.

    Returns:
      1-D string tensor of segments.
    """
    points = points_to_text(img_h, img_w, points[:, 0], points[:, 1])
    parts = tf.strings.split(text, sep="")
    # Interleave piece i with point string i; padding gives the final piece an empty partner.
    with_points = tf.strings.reduce_join(tf.reshape(tf.stack([
        parts,
        tf.pad(points, [[0, 1]], constant_values=""),
    ], 1), [-1]), separator="")
    return tf.strings.split(with_points, "\n\n")
@seqio.map_over_dataset()
def synth_count_preprocessor(example):
    """Build shuffled counting QA pairs with interleaved points.

    Questions and answers are rendered with `build_text_with_points`. A pair is
    kept only if its question fully matches "How many.*" and its answer fully
    matches "There are [0-9]+.*". The same keep mask is applied to both, so they
    must have equal length and presumably correspond position by position.

    NOTE(review): `tf.random.shuffle` is unseeded here, unlike the stateless
    permutation in `synth_count_inf_preprocessor`, so the order is not
    reproducible across runs.
    """
    image_shape = tf.shape(example["image"])
    h, w = image_shape[0], image_shape[1]
    questions = build_text_with_points(example["questions"], example["question_points"], h, w)
    answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
    keep_q = tf.strings.regex_full_match(questions, "How many.*")
    keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
    keep = tf.logical_and(keep_q, keep_ans)
    questions = tf.boolean_mask(questions, keep)
    answers = tf.boolean_mask(answers, keep)
    ix = tf.range(0, tf.shape(answers)[0], dtype=tf.int32)
    ix = tf.random.shuffle(ix)
    return dict(
        image=example["image"],
        prompt=tf.gather(questions, ix),
        text=tf.gather(answers, ix),
    )
def synth_count_inf_preprocessor(ds):
    """Build counting inference examples, at most two per input image.

    Applies the same rendering and "How many" / "There are N" filtering as
    `synth_count_preprocessor`. It then keeps up to two pairs chosen by a
    stateless permutation and flattens them into one example per pair. The
    answer becomes `metadata/references`; no "text" target is produced.
    """
    @seqio.map_over_dataset(num_seeds=1)
    def get_two(example, seed):
        image_shape = tf.shape(example["image"])
        h, w = image_shape[0], image_shape[1]
        questions = build_text_with_points(example["questions"], example["question_points"], h, w)
        answers = build_text_with_points(example["answers"], example["answer_points"], h, w)
        keep_q = tf.strings.regex_full_match(questions, "How many.*")
        keep_ans = tf.strings.regex_full_match(answers, "There are [0-9]+.*")
        keep = tf.logical_and(keep_q, keep_ans)
        questions = tf.boolean_mask(questions, keep)
        answers = tf.boolean_mask(answers, keep)
        # Deterministic choice of (up to) two of the surviving pairs.
        ix = stateless_permutation(tf.shape(answers)[0], seed)[:2]
        return {
            "image": example["image"],
            "prompt": tf.gather(questions, ix),
            "metadata/references": tf.gather(answers, ix),
        }
    ds = get_two(ds)
    return flatten_parts(ds, ["prompt", "metadata/references"])
def mild_color_aug(image):
    """Apply random hue, brightness, saturation and contrast jitter, in that order.

    Uses the stateful ``tf.image.random_*`` ops, so results are not seeded.
    """
    jitters = (
        lambda img: tf.image.random_hue(img, max_delta=0.05),
        lambda img: tf.image.random_brightness(img, max_delta=0.15),
        lambda img: tf.image.random_saturation(img, 0.7, 1.3),
        lambda img: tf.image.random_contrast(img, 0.8, 1.2),
    )
    for jitter in jitters:
        image = jitter(image)
    return image
@seqio.map_over_dataset()
def name_entity_augmentation(ex, p_high_color=0.7):
    """Heavily augment an image for named-entity ("what is this?") data.

    Steps, each randomized independently with unseeded ``tf.random`` ops:
      1. Decode to 3 channels and convert to float32. Flip horizontally with
         probability 0.15.
      2. Random border crop: none with p=0.2; up to 7.5% of height/width per
         side with p=0.2; up to 20% per side otherwise.
      3. Colour jitter: strong with probability ``p_high_color``, mild otherwise.
      4. Affine warp on a white-padded copy:
         - none with p=0.1;
         - scale only with p=0.05;
         - mild shear/rotation with p=0.55;
         - severe otherwise.
      5. Crop to the non-white content (skipped if it is degenerate).
      6. Random white padding (translation).

    ``image_url`` is removed because it would show the un-augmented image.
    The white fill assumes pixel values in [0, 1], which holds after
    ``convert_image_dtype``.
    """
    ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    image = ex["image"]
    image = tf.image.convert_image_dtype(image, tf.float32)
    # Horizontal flip
    if tf.random.uniform((), 0, 1) > 0.85:
        image = image[:, ::-1]
    # Random crop
    height = tf.cast(tf.shape(image)[0], tf.float32)
    width = tf.cast(tf.shape(image)[1], tf.float32)
    crop_rng = tf.random.uniform((), 0, 1)
    if crop_rng < 0.2:
        pass
    else:
        if crop_rng < 0.4:
            h_crop = height * 0.15
            w_crop = width * 0.15
        else:
            h_crop = height * 0.4
            w_crop = width * 0.4
        crop_h = tf.cast(tf.random.uniform((2,), 0, h_crop/2), tf.int32)
        crop_w = tf.cast(tf.random.uniform((2,), 0, w_crop/2), tf.int32)
        # The `-1` means at least one row/column is always removed at the bottom/right.
        image = image[crop_h[0]:-crop_h[1]-1, crop_w[0]:-crop_w[1]-1]
        height = tf.cast(tf.shape(image)[0], tf.float32)
        width = tf.cast(tf.shape(image)[1], tf.float32)
    if tf.random.uniform(()) > p_high_color:
        image = tf.image.random_hue(image, max_delta=0.05)
        image = tf.image.random_brightness(image, max_delta=0.15)
        image = tf.image.random_saturation(image, 0.7, 1.3)
        image = tf.image.random_contrast(image, 0.8, 1.2)
    else:
        image = tf.image.random_hue(image, max_delta=0.1)
        image = tf.image.random_brightness(image, max_delta=0.3)
        image = tf.image.random_saturation(image, 0.0, 2.0)
        image = tf.image.random_contrast(image, 0.2, 1.5)
    # Apply shear, rotation, and scale through one affine matrix
    sel = tf.random.uniform((), 0, 1)
    if sel < 0.1:
        pass
    else:
        # `elif` so the scale-only case is not overwritten by the mild one.
        # Floats keep the dtypes consistent across AutoGraph's branches.
        if sel < 0.15:  # Scale only
            shear_x = 0.
            shear_y = 0.
            rotation = 0.
        elif sel < 0.7:  # Mild
            shear_x = tf.random.uniform((), -2, 2)
            shear_y = tf.random.uniform((), -2, 2)
            rotation = tf.random.uniform((), -5, 5)
        else:  # Severe
            shear_x = tf.random.uniform((), -10, 10)
            shear_y = tf.random.uniform((), -10, 10)
            rotation = tf.random.uniform((), -20, 20)
        max_scale = 1.2
        scale = tf.random.uniform((), 0.4, max_scale)
        # Pad so upscaling/rotation will not move the image out of bounds
        pad = tf.cast(tf.maximum(height, width)*0.2, tf.int32)
        image = tf.pad(image, [[pad, pad], [pad, pad], [0, 0]], constant_values=1)
        # Centre the transform on the padded image (as clock_augmentation does).
        height = tf.cast(tf.shape(image)[0], tf.float32)
        width = tf.cast(tf.shape(image)[1], tf.float32)
        image = tf.keras.ops.image.affine_transform(
            image,
            tf.stack(get_affine_matrix(
                [height/2, width/2],
                rotation,
                [0, 0],
                1/scale,
                [shear_x, shear_y]
            ) + [0., 0.]),
            interpolation='bilinear',
            fill_mode='constant',
            fill_value=1.,
            data_format='channels_last'
        )
    # Crop to the non-white content; otherwise the image could never end up at a corner
    not_white = tf.logical_not(tf.reduce_all(image > 0.99, -1))
    no_white_ix = tf.where(not_white)
    top_left = tf.reduce_min(no_white_ix, axis=0)
    bottom_right = tf.reduce_max(no_white_ix, axis=0)
    # Very low chance the crop finds nothing but white space; skip the crop then
    if (
        (bottom_right[0] - top_left[0]) > 1 and (bottom_right[1] - top_left[1]) > 1
    ):
        image = tf.image.crop_to_bounding_box(
            image,
            offset_height=tf.cast(top_left[0], tf.int32),
            offset_width=tf.cast(top_left[1], tf.int32),
            target_height=tf.cast(bottom_right[0] - top_left[0] + 1, tf.int32),
            target_width=tf.cast(bottom_right[1] - top_left[1] + 1, tf.int32),
        )
    # Translate by padding with white
    height, width = tf.shape(image)[0], tf.shape(image)[1]
    if tf.random.uniform((), 0, 1) < 0.1:
        h_pad = tf.zeros((2,), dtype=tf.int32)
        w_pad = tf.zeros((2,), dtype=tf.int32)
    elif tf.random.uniform((), 0, 1) < 0.8:
        h_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
        w_pad = tf.random.uniform((2,), 0, 50, dtype=tf.int32)
    else:
        pad = tf.cast(tf.maximum(height, width), tf.int32)
        h_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
        w_pad = tf.random.uniform((2,), 0, pad, dtype=tf.int32)
    image = tf.pad(image, [[h_pad[0], w_pad[0]], [h_pad[1], w_pad[1]], [0, 0]],
                   constant_values=1)
    if "image_url" in ex:  # URL won't show the augmentations
        del ex["image_url"]
    ex["image"] = image
    return ex
@seqio.map_over_dataset()
def wiki_art_preprocessor(ex):
    """Ask "What is this?" about an artwork; the target is ``ex["question"]``.

    The title, the ground-truth answer, the artist and the painting URL are kept
    as metadata.
    """
    answer = ex["question"]
    return {
        "image": ex["image"],
        "prompt": "What is this?",
        "text": answer,
        "metadata/title": ex["title"],
        "metadata/gt": answer,
        "metadata/artist": ex["artist"],
        "metadata/painting_url": ex["painting_url"],
    }
@seqio.map_over_dataset()
def oscar_preprocessor(ex):
    """Use the question as the prompt; keep question, answer and category as metadata."""
    question = ex["question"]
    out = {"image": ex["image"], "prompt": question}
    out.update(_add_metadata(ex))
    out.update({
        "metadata/question": question,
        "metadata/answer": ex["answer"],
        "metadata/category": ex["category"],
    })
    return out
@seqio.map_over_dataset()
def tulu_preprocessor(ex):
    """Keep only the message contents of a Tulu chat example, under ``messages``."""
    contents = ex["messages"]["content"]
    return {"messages": contents}
# Prompt asking only for the proper name; used by `extract_wiki_data` and
# `extract_wiki_data_name`.
WIKI_DATA_QUESTION = "What is this? Respond with just a proper name."
@seqio.map_over_dataset()
def extract_wiki_data(ex):
    """Produce two prompt/target pairs for a wiki image: a bare name and a description.

    The name target is ``ex["question"]`` with text matched by the regex
    ``\\(.*\\)`` removed and surrounding whitespace stripped. The description
    target is ``ex["gptResponse"]``.
    """
    name = tf.strings.strip(tf.strings.regex_replace(ex["question"], r"\(.*\)", ""))
    describe_prompt = "What is this? Respond with the proper name of the main focus of the image and a few details about it."
    return {
        "image": ex["image"],
        "image_url": ex["image_url"],
        "prompt": [WIKI_DATA_QUESTION, describe_prompt],
        "text": [name, ex["gptResponse"]],
    }
@seqio.map_over_dataset()
def extract_wiki_data_name(ex):
    """Ask for just the proper name; the target is the cleaned ``ex["question"]``.

    Cleaning removes text matched by the regex ``\\(.*\\)`` and strips
    surrounding whitespace. The same string is stored as
    ``metadata/references``.
    """
    cleaned = tf.strings.regex_replace(ex["question"], r"\(.*\)", "")
    target = tf.strings.strip(cleaned)
    return {
        "image": ex["image"],
        "image_url": ex["image_url"],
        "prompt": WIKI_DATA_QUESTION,
        "text": target,
        "metadata/references": target,
    }
@seqio.map_over_dataset()
def extract_wiki_data_describe(ex):
    """Ask for a name plus details, with ``ex["gptResponse"]`` as the reference.

    No "text" target is produced.
    """
    prompt = "What is this? Respond with the proper name of the main focus of the image and a few details about it."
    out = {"image": ex["image"], "image_url": ex["image_url"], "prompt": prompt}
    out["metadata/references"] = ex["gptResponse"]
    return out
@gin.configurable()
def format_multiple_style_qa(ds, types=['multiple_choice', 'short_answer'], styles=['ai2_diagram', 'vqa2'], default_style='vqa2',
                             strip_instruction=False):
    """Format QA examples whose style depends on `metadata/question_type`.

    Examples with question type 'multiple_choice' get `styles[0]`; all others
    get `styles[1]`. The answer is both the target text and
    `metadata/references`.

    Args:
      ds: dataset with "image", "question", "answer" and
        "metadata/question_type" fields.
      types: NOTE(review): unused in this function.
      styles: [multiple-choice style, other style]. Only read, never mutated.
      default_style: NOTE(review): unused in this function.
      strip_instruction: if True, non-multiple-choice prompts are cut to their
        first line; multiple-choice prompts are left unchanged.
    """
    def _extract(ex):
        prompt = ex["question"]
        out = dict(image=ex["image"])
        out.update(_add_metadata(ex))
        out["text"] = ex["answer"]
        out["metadata/references"] = ex["answer"]
        if ex["metadata/question_type"] == 'multiple_choice':
            style = styles[0]
        else:
            style = styles[1]
        if strip_instruction:
            if ex["metadata/question_type"] == "multiple_choice":
                # parts = tf.strings.split(prompt, "\n")
                # parts 1 is blank and part -1 is the instruction
                # prompt = tf.strings.reduce_join(tf.concat([parts[:1], parts[2:-1]], 0), separator="\n")
                prompt = prompt  # no-op: stripping is disabled for multiple choice
            else:
                prompt = tf.strings.split(prompt, "\n")[0]
        out["style"] = style
        out["prompt"] = prompt
        return out
    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds
@gin.configurable()
def extract_mmmu(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
    """Format MMMU examples as prompts with lettered options.

    `metadata/question_type` is mapped to a style through a table from `types`
    to `styles`; unknown types get `default_style`.

    When the style equals `styles[0]`:
      - the options are appended to the question as "A: ...", "B: ...", one
        per line;
      - `metadata/options` holds them padded with "" to exactly 9 entries.

    For any other style:
      - the prompt is the bare question;
      - `metadata/options` is nine empty strings;
      - the image tensor is sliced to zero images.

    Only `option_format="abc"` is supported.
    """
    assert option_format == "abc"
    # Question type -> style; unknown types fall back to `default_style`.
    keys_tensor = tf.constant(types, dtype=tf.string)
    values_tensor = tf.constant(styles, dtype=tf.string)
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
        default_value=tf.constant(default_style, dtype=tf.string),
    )
    def _extract(ex):
        # Only "image_1" is used, with a leading image axis of size 1.
        out = dict(image=tf.expand_dims(ex["image_1"], 0))
        out.update(_add_metadata(ex))
        style = table.lookup(ex["metadata/question_type"])
        out["style"] = style
        out["text"] = ex["answer"]
        out["metadata/references"] = ex["answer"]
        if style == styles[0]:
            abc = tf.constant(list("abcdefghi".upper()))
            options = ex["options"]
            num_options = tf.shape(options)[0]
            # Pad to a fixed 9 options so `metadata/options` has a static shape
            # (tf.tile fails if there are more than 9 options).
            dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
            out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
            out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
            short_options = abc[:num_options]
            options = tf.stack([short_options, options,], 1)
            options = tf.strings.reduce_join(options, axis=-1, separator=": ")
            options = tf.strings.reduce_join(options, separator="\n")
            out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
            # NOTE(review): `options` is a single joined string here, and the
            # pattern "" only matches an empty string. The sum is therefore at
            # most 1, so this branch can never run. The intended pattern
            # (matching image placeholders in the options) seems to be missing.
            if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, ""), tf.int32)) > 1:
                # Following LLaVa, don't use any images if there are multiple images paths
                # I think the rationale is that this means the image are answer-options
                out["image"] = out["image"][:0]
        else:
            out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
            out["prompt"] = ex["question"]
            out["image"] = out["image"][:0]
        return out
    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds
@gin.configurable()
def extract_mmmu_cot(ds, types=['multiple-choice', 'open'], styles=['ai2_diagram', 'vqa2'], default_style='ai2_diagram', option_format="abc"):
    """MMMU formatting variant that sets no "style" and keeps the image for open questions.

    The question type is mapped to a style as in `extract_mmmu`, but here the
    style only chooses the branch. For `styles[0]`, the lettered options are
    appended to the question and padded into a 9-entry `metadata/options`.
    Otherwise the prompt is the bare question. The raw question is also stored
    as `metadata/question`.

    Only `option_format="abc"` is supported.
    """
    assert option_format == "abc"
    # Question type -> style; unknown types fall back to `default_style`.
    keys_tensor = tf.constant(types, dtype=tf.string)
    values_tensor = tf.constant(styles, dtype=tf.string)
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor),
        default_value=tf.constant(default_style, dtype=tf.string),
    )
    def _extract(ex):
        # out = dict(image=tf.expand_dims(ex["image_with_question"], 0))
        out = dict(image=tf.expand_dims(ex["image_1"], 0))
        out.update(_add_metadata(ex))
        style = table.lookup(ex["metadata/question_type"])
        # out["style"] = style
        out["text"] = ex["answer"]
        out["metadata/question"] = ex["question"]
        out["metadata/references"] = ex["answer"]
        if style == styles[0]:
            abc = tf.constant(list("abcdefghi".upper()))
            options = ex["options"]
            num_options = tf.shape(options)[0]
            # Pad to a fixed 9 options (tf.tile fails if there are more than 9).
            dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
            out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
            out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
            short_options = abc[:num_options]
            options = tf.stack([short_options, options,], 1)
            options = tf.strings.reduce_join(options, axis=-1, separator=": ")
            options = tf.strings.reduce_join(options, separator="\n")
            out["prompt"] = tf.strings.join([ex["question"], "\n", options, "\n"])
            # out["prompt"] = ex["question"]
            # NOTE(review): as in `extract_mmmu`, `options` is one joined string
            # and "" only matches an empty string, so this branch never runs.
            if tf.reduce_sum(tf.cast(tf.strings.regex_full_match(options, ""), tf.int32)) > 1:
                # Following LLaVa, don't use any images if there are multiple images paths
                # I think the rationale is that this means the image are answer-options
                out["image"] = out["image"][:0]
        else:
            out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
            out["prompt"] = ex["question"]
            # out["image"] = out["image"][:0]
        return out
    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds
@seqio.map_over_dataset
def reformat_math_vista(ex):
    """Reduce a MathVista query to the bare question text.

    Keeps what follows the last "Question:" marker, cuts it at the first
    "Hint:", and strips surrounding whitespace.
    """
    after_question = tf.strings.split(ex["query"], sep="Question:")[-1]
    before_hint = tf.strings.split(after_question, sep="Hint:")[0]
    ex["query"] = tf.strings.strip(before_hint)
    return ex
@seqio.map_over_dataset
def extract_math_vista(ex, styles=['ai2_diagram', 'vqa2']):
    """Format a MathVista example.

    Multiple-choice examples (`metadata/question_type == 'multi_choice'`):
      - use `styles[0]`;
      - their choices are padded with "" to nine entries in
        `metadata/options`;
      - outside the "test" split, the target text is the letter ("A".."I") of
        the choice equal to `ex["answer"]`.

    Other examples use `styles[1]` and keep the raw answer as the target. The
    query is both the prompt and `metadata/query`; the raw answer is
    `metadata/references`.

    NOTE(review): if `ex["answer"]` matches none of the choices, the `[0]` on
    the empty mask result fails at runtime. More than nine choices makes the
    padding `tf.tile` fail.
    """
    out = dict(image=ex["image"])
    out.update(_add_metadata(ex))
    is_mc = ex["metadata/question_type"] == 'multi_choice'
    if is_mc:
        style = styles[0]
        abc = tf.constant(list("abcdefghi".upper()))
        options = ex["choices"]
        num_options = tf.shape(options)[0]
        dummy_options = tf.tile(tf.constant([""], dtype=tf.string), [9 - num_options])
        out["metadata/options"] = tf.concat([options, dummy_options], axis=0)
        out["metadata/options"] = tf.ensure_shape(out["metadata/options"], [9])
        if ex["metadata/split"] != "test":
            # Target is the letter of the option equal to the answer.
            short_options = abc[:num_options]
            answer_short_option = tf.boolean_mask(short_options, options == ex["answer"])[0]
            out["text"] = answer_short_option
        else:
            out["text"] = ex["answer"]
    else:
        style = styles[1]
        out["metadata/options"] = tf.constant([""] * 9, dtype=tf.string)
        out["text"] = ex["answer"]
    out["style"] = style
    out["prompt"] = ex["query"]
    out["metadata/query"] = ex["query"]
    out["metadata/references"] = ex["answer"]
    return out
# Instructions telling the model to answer without emitting points; one is
# sampled uniformly by `prefix_how_many`. Misspellings and repeated entries
# appear throughout. Presumably they are deliberate prompt noise; repeats also
# raise an entry's sampling weight. Do not "fix" them without checking with
# the data owners.
NO_POINT_PREFIX = [
    "No pointing: ",
    "No pointing: ",
    "no pointing:\n",
    "No pointing:\n",
    "Not pointing:\n",
    "No Points: ",
    "No Points: ",
    "NO POINTING\n",
    "No pontiing\n",
    "No Points:\n ",
    "No pointing\n",
    "Do not point. ",
    "Refrain from pointing. ",
    "Avoid generating points . ",
    "For this question, do not use points. ",
    "Refrain from using points:\n",
    "Don't include points in your response. ",
    "Don't point. ",
    "Don't use points. ",
    "Please don't use points.\n\n",
    "Please don't use points.\n\n",
    "Respond without using points. ",
    "Respond without pointing:\n",
    "Do not generate ponits: ",
    "Do not point. ",
    "Do not point\n",
    "no pointing\n\n",
    "Answer without points: ",
    "Answer this question without pointing: ",
    "Answer without poiints. ",
    "answer without points: ",
    "answer with text only, do not points\n"
]
# Every prefix must end in whitespace so it can be prepended directly to a question.
assert all(x[-1].isspace() for x in NO_POINT_PREFIX)
# Tensor copy of the prefixes, indexed inside graph code by `prefix_how_many`.
NO_POINT_PREFIX_TF = tf.constant(NO_POINT_PREFIX)
def prefix_how_many(messages, seed):
    """Prepend a random "don't point" instruction to a "how many" question.

    `messages[0]` is treated as the question. If its lower-cased form fully
    matches "how many.*", a prefix from `NO_POINT_PREFIX_TF` is chosen with a
    stateless draw from `seed` and prepended. Otherwise `messages` is returned
    unchanged. The number of messages never changes.

    NOTE(review): TF regexes use RE2, where `.` does not match newlines by
    default, so questions spanning several lines presumably are not prefixed.
    """
    question = messages[0]
    if tf.strings.regex_full_match(tf.strings.lower(question), "how many.*"):
        ix = tf.random.stateless_uniform((), seed, 0, len(NO_POINT_PREFIX), tf.int32)
        question = tf.strings.join([NO_POINT_PREFIX_TF[ix], question])
        return tf.concat([tf.expand_dims(question, 0), messages[1:]], axis=0)
    else:
        return messages
@seqio.map_over_dataset(num_seeds=1)
def prefix_how_many_messages(ex, seed):
    """Apply `prefix_how_many` to every row of the ragged `ex["messages"]`.

    Each row gets its own seed split from `seed`. `prefix_how_many` only edits a
    row's first message and never changes the row's length, so the rewritten
    messages can be reassembled with the original `row_splits`.
    """
    messages = ex["messages"]
    n = tf.shape(messages)[0]
    seeds = tf.random.split(seed, n)
    # Rows have different lengths, so each TensorArray element is a 1-D tensor of unknown size.
    message_arr = tf.TensorArray(dtype=tf.string, size=n, element_shape=(None,))
    for i in range(n):
        message_arr = message_arr.write(i, prefix_how_many(messages[i], seeds[i]))
    ex["messages"] = tf.RaggedTensor.from_row_splits(
        values=message_arr.concat(), row_splits=messages.row_splits)
    return ex
def filter_single_turn(ds):
    """Keep only multi-turn conversations.

    Within each example, rows of ``messages`` (presumably one conversation
    each) with two or fewer entries are dropped. Examples left with no rows are
    removed.
    """
    @seqio.map_over_dataset
    def _keep_multi_turn(ex):
        messages = ex["messages"]
        ex["messages"] = tf.ragged.boolean_mask(messages, messages.row_lengths() > 2)
        return ex

    return _keep_multi_turn(ds).filter(lambda x: tf.shape(x["messages"])[0] > 0)
@seqio.map_over_dataset(num_seeds=1)
def extract_cockatoo_qa_v2(ex, seed):
    """Group flat messages into conversations and shuffle the conversation order.

    ``ex["conversation_ids"]`` gives the row id of each message. The resulting
    rows are reordered by a stateless permutation derived from ``seed``.
    """
    conversations = tf.RaggedTensor.from_value_rowids(ex["messages"], ex["conversation_ids"])
    order = stateless_permutation(tf.shape(conversations)[0], seed)
    out = {"image": ex["image"], "messages": tf.gather(conversations, order)}
    out.update(_add_metadata(ex))
    return out
def format_mmbench(ds):
    """Flatten MMBench examples into one example per evaluation pass.

    Each input holds several passes (one per entry of "id"). "choices" and
    "answer" are trimmed to that count, then "id", "query", "choices" and
    "answer" are flattened into separate examples. The options of each pass are
    joined with "|||" into `metadata/options`, so no option may contain "|||".
    """
    def _trim(ex):
        num_passes = tf.shape(ex["id"])[0]
        # NOTE(review): the second axis (the options) is also cut to
        # `num_passes`; confirm this is intended.
        ex["choices"] = ex["choices"][:num_passes, :num_passes]
        ex["answer"] = ex["answer"][:num_passes]
        return ex
    ds = ds.map(_trim)
    ds = flatten_parts(ds, ["id", "query", "choices", "answer"])
    def _extract(ex):
        out = dict(image=ex["image"])
        out.update(_add_metadata(ex))
        out["prompt"] = ex["query"]
        out["text"] = ex["answer"]
        options = ex["choices"]
        # Raw string: "\|" is an invalid escape in a normal literal (SyntaxWarning
        # on Python 3.12+). The regex itself is unchanged.
        tf.debugging.assert_equal(tf.reduce_any(tf.strings.regex_full_match(options, r".*\|\|\|.*")), False)
        out["metadata/options"] = tf.strings.reduce_join(options, separator="|||")
        out["metadata/question"] = ex["question"]
        out["metadata/references"] = ex["answer"]
        return out
    ds = ds.map(_extract, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds
@seqio.map_over_dataset
def extract_lvis(ex, class_name_file="gs://oe-training-chrisc/cockatoo/data/lvis_class_names.json"):
    """Map LVIS integer labels to class-name strings.

    `class_name_file` is a JSON object mapping stringified class indices to
    names. It is read with `tf.io.gfile` whenever this Python function runs;
    for a `tf.data` map that is presumably once at trace time, not per example.
    Every index from 0 to len-1 must be present: a larger index raises
    IndexError and a missing one fails the assertion.

    Returns:
      dict with "image", "bbox" (`ex["objects"]["bbox"]`) and "label" (class
      names gathered by `ex["objects"]["label"]`).
    """
    with tf.io.gfile.GFile(class_name_file) as f:
        class_names = json.load(f)
    # Dense list indexed by class id; the JSON keys are strings.
    class_names_arr = [None]*len(class_names)
    for k, v in class_names.items():
        class_names_arr[int(k)] = v
    assert all(x is not None for x in class_names_arr)
    class_names_arr = tf.constant(class_names_arr)
    return dict(
        image=ex["image"],
        bbox=ex["objects"]["bbox"],
        label=tf.gather(class_names_arr, ex["objects"]["label"]),
    )
def extract_open_images_boxes(ds):
    """Keep Open Images examples that have captions and convert them to box/caption form.

    Returns a dataset of dicts with:
      image: the decoded JPEG, always with 3 channels.
      bbox: `detection/bbox` reshaped to (n, 4), with columns reordered from
        [0, 1, 2, 3] to [2, 0, 3, 1].
      label: `detection/label`.
      caption: all `cap/cap_caption` strings joined with newlines.
    """
    ds = ds.filter(lambda ex: tf.shape(ex["cap/cap_caption"])[0] > 0)

    @seqio.map_over_dataset
    def _map(ex):
        bbox = tf.reshape(ex["detection/bbox"], (-1, 4))
        # NOTE(review): if the stored layout is (xmin, xmax, ymin, ymax), this
        # yields (ymin, xmin, ymax, xmax); confirm against the source data.
        bbox = tf.stack([
            bbox[:, 2],
            bbox[:, 0],
            bbox[:, 3],
            bbox[:, 1]
        ], 1)
        return dict(
            # channels=3 so grayscale JPEGs also decode to RGB, like the other
            # decoders in this module; the default keeps the file's channel count.
            image=tf.image.decode_jpeg(ex["image"], channels=3),
            bbox=bbox,
            label=ex["detection/label"],
            caption=tf.strings.reduce_join(ex["cap/cap_caption"], separator="\n")
        )
    return _map(ds)
@seqio.map_over_dataset
def region_captions_to_dense(ex):
    """Render region captions as one dense-caption text target.

    Captions and boxes come from `ex["captions"]` ("text"/"bbox") if present,
    otherwise from `ex["label"]`/`ex["bbox"]`. Boxes are treated as
    [x0, y0, x1, y1].

    Steps (all random choices use unseeded `tf.random.uniform`):
      1. Attach each box's "cx,cy,w,h" (via `float_to_text`) to its caption,
         before or after the caption at random.
      2. Order the captions by a random criterion: left, right, top, bottom,
         largest or smallest (e.g. "largest" lists the largest box first).
      3. Join the captions with newlines, preceded by "<order>; <before|after>; ".
      4. If `ex` has a "caption", append or prepend it at random.

    Returns:
      dict with "image" and the rendered "text".
    """
    if "captions" in ex:
        captions = ex["captions"]["text"]
        boxes = ex["captions"]["bbox"]
    else:
        captions = ex["label"]
        boxes = ex["bbox"]
    w = boxes[:, 2] - boxes[:, 0]
    h = boxes[:, 3] - boxes[:, 1]
    cx = tf.cast(boxes[:, 0] + w/2, tf.float32)
    cy = tf.cast(boxes[:, 1] + h/2, tf.float32)
    coor = tf.strings.reduce_join(
        float_to_text(tf.stack([cx, cy, w, h], 1)), separator=",", axis=1)
    area = w*h
    if tf.random.uniform(()) < 0.5:
        coor_text = "before"
        captions = tf.strings.join([coor, captions], separator=": ")
    else:
        coor_text = "after"
        captions = tf.strings.join([captions, coor], separator=": ")
    # argsort is ascending, so negate a key to list its largest values first.
    ix = tf.random.uniform((), 0, 6, tf.int32)
    if ix == 0:
        order_text = "left"
        sort_by = boxes[:, 0]
    elif ix == 1:
        order_text = "right"
        sort_by = -boxes[:, 2]
    elif ix == 2:
        order_text = "top"
        sort_by = boxes[:, 1]
    elif ix == 3:
        order_text = "bottom"
        sort_by = -boxes[:, 3]
    elif ix == 4:
        order_text = "largest"
        sort_by = -area
    else:
        order_text = "smallest"
        sort_by = area
    ixs = tf.argsort(sort_by)
    captions = tf.gather(captions, ixs)
    text = tf.strings.join([
        order_text,
        coor_text,
        tf.strings.reduce_join(captions, separator="\n")
    ], separator="; ")
    if "caption" in ex:
        if tf.random.uniform(()) > 0.5:
            text = tf.strings.join([text, "\ncaption: ", ex["caption"]])
        else:
            text = tf.strings.join(["caption: ", ex["caption"], "\n", text])
    return dict(
        image=ex["image"],
        text=text
    )
@seqio.map_over_dataset()
def join_captions(ex):
    """Shuffle the captions in ``ex["text"]`` and merge them into one string.

    Captions are newline-separated in the result; ``ex`` is updated in place
    and returned.
    """
    ex["text"] = tf.strings.reduce_join(
        tf.random.shuffle(ex["text"]), separator="\n")
    return ex
@seqio.map_over_dataset(num_seeds=1)
def extract_figureqa(ex, seed):
    """Return the image plus its questions, ids and stringified answers, all
    reordered by one seeded permutation so they stay aligned."""
    qa = ex["questions"]
    order = stateless_permutation(tf.shape(qa["question"])[0], seed)

    def _permuted(values):
        return tf.gather(values, order)

    return dict(
        image=ex["image"],
        questions=_permuted(qa["question"]),
        question_id=_permuted(qa["question_id"]),
        answer=_permuted(tf.strings.as_string(qa["answer"])),
    )
@seqio.map_over_dataset
def convert_figureqa_answer(ex):
    """Rewrite ``ex["answer"]`` from "0"/"1" to "no"/"yes"; anything else -> "nan"."""
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(["0", "1"]),
        values=tf.constant(["no", "yes"]),
    )
    yes_no_table = tf.lookup.StaticHashTable(
        initializer,
        default_value=tf.constant("nan", dtype=tf.string),
    )
    ex["answer"] = yes_no_table.lookup(ex["answer"])
    return ex
@seqio.map_over_dataset()
def build_question_with_hint(ex):
    """Prefix ``ex["question"]`` with ``ex["hint"]`` (newline-separated) when the
    hint is non-empty; otherwise leave the question unchanged."""
    question = ex["question"]
    if tf.strings.length(ex["hint"]) > 0:
        question = tf.strings.join([ex["hint"], question], separator="\n")
    ex["question"] = question
    return ex
@seqio.map_over_dataset()
def build_question_with_context(ex):
    """Prefix ``ex["question"]`` with ``ex["context"]`` (newline-separated) when
    the context is non-empty; otherwise leave the question unchanged."""
    question = ex["question"]
    if tf.strings.length(ex["context"]) > 0:
        question = tf.strings.join([ex["context"], question], separator="\n")
    ex["question"] = question
    return ex
def max_words(ds, max_words):
    """Keep only examples whose ``"n_words"`` value is at most ``max_words``."""
    limit = max_words

    def _short_enough(example):
        return example["n_words"] <= limit

    return ds.filter(_short_enough)
@seqio.map_over_dataset
def format_pdfa_eng_wds(example):
    """Return the page image and its text lines joined with newlines."""
    page_lines = example["lines"]["text"]
    full_text = tf.strings.reduce_join(page_lines, separator="\n")
    return {"image": example["image"], "text": full_text}
@gin.configurable()
def accuracy_conditioned_joint(ds, sequence_length, is_eval=False, eval_quality=17,
                               transcript_quality=None):
    """Attach "quality N:" prompts so outputs can be conditioned on a quality score.

    Outside training (``sequence_length.get("is_training", True)`` is falsy) each
    example keeps ``image``/``url`` (plus ``text`` taken from ``text`` or
    ``caption`` when present) and gets the prompt ``"quality {eval_quality}:"``
    if ``is_eval`` else ``"quality 17:"``.

    During training each example becomes three stacked (text, prompt, style) entries:
      * ``ex["caption"]`` prompted with ``"quality 17:"`` (style ``long_caption``);
      * one randomly chosen transcript, with an empty prompt unless
        ``transcript_quality`` is given (style ``transcript``);
      * the first edited caption (``""`` if there is none) prompted with
        ``"quality {17 - n_edits}:"`` (style ``long_caption``).

    Args:
      ds: dataset to map over.
      sequence_length: dict; only its ``"is_training"`` entry (default True) is read.
      is_eval: use ``eval_quality`` in the inference prompt; only valid when not training.
      eval_quality: quality score used in the inference prompt when ``is_eval``.
      transcript_quality: optional quality score for the transcript prompt.

    Raises:
      ValueError: if ``is_eval`` is True while training.
    """
    # v2: Transcripts no longer get a quality score (unless transcript_quality is set)
    is_training = sequence_length.get('is_training', True)
    if not is_training:
        prompt = f"quality {eval_quality}:" if is_eval else "quality 17:"

        @seqio.map_over_dataset
        def _with_prompt(ex):
            out = dict(
                image=ex["image"],
                url=ex["url"],
                prompt=prompt,
            )
            if "text" in ex:
                out["text"] = ex["text"]
            elif "caption" in ex:
                out["text"] = ex["caption"]
            return out

        return _with_prompt(ds)

    if is_eval:
        raise ValueError("is_eval=True requires is_training=False, but is_training=True")

    # Built once at trace time from a Python value. The previous version formatted
    # a graph Tensor (edit_quality) into an f-string, yielding "quality: Tensor(...)".
    transcript_prompt = "" if transcript_quality is None else f"quality {transcript_quality}:"

    @seqio.map_over_dataset
    def _with_transcript(ex):
        if tf.shape(ex["edited_captions"]["caption"])[0] > 0:
            edited_caption = ex["edited_captions"]["caption"][0]
            n = ex["edited_captions"]["n_edits"][0]
        else:
            edited_caption = ""
            n = 0
        num_transcripts = tf.shape(ex["transcripts"])[0]
        transcript = ex["transcripts"][
            tf.random.uniform((), 0, num_transcripts, dtype=tf.int32)]
        text = [
            ex["caption"],
            transcript,
            edited_caption
        ]
        # Each edit lowers the score by one, starting from 17.
        edit_quality = 17 - n
        prompt = [
            "quality 17:",
            transcript_prompt,
            tf.strings.join(["quality ", tf.strings.as_string(edit_quality), ":"])
        ]
        return dict(
            image=ex["image"],
            text=tf.stack(text, 0),
            url=ex["url"],
            prompt=tf.stack(prompt, 0),
            style=["long_caption", "transcript", "long_caption"]
        )

    return _with_transcript(ds)
def select_dense_caption_sample(ds, samples=200):
    """Filter ``ds`` to a fixed pseudo-random subset of the dense-caption eval images.

    The eval JSON is loaded, entries are ordered by the SHA-256 of their
    ``"image"`` value and shuffled with a fixed seed, and the first ``samples``
    image URLs are kept. ``ds`` is filtered to examples whose ``"url"`` is one
    of them and its cardinality is asserted to equal the number kept.
    """
    def compute_hash(string: str) -> str:
        return hashlib.sha256(string.encode("utf-8")).hexdigest()

    with tf.io.gfile.GFile("gs://oe-training-chrisc/cockatoo/data/dense-caption-eval-v0-final-data.json") as f:
        data = json.load(f)
    for ex in data:
        ex["image_id"] = compute_hash(ex["image"])
    # Sorting by a content hash first makes the seeded shuffle independent of
    # the order in which entries appear in the JSON file.
    data.sort(key=lambda x: x["image_id"])
    np.random.RandomState(12312).shuffle(data)
    keep_urls = [x["image"] for x in data[:samples]]
    keep = tf.constant(keep_urls)

    def _keep(ex):
        return tf.reduce_any(ex["url"] == keep)

    ds = ds.filter(_keep)
    # data[:samples] can hold fewer than `samples` entries; assert the real count
    # so a small eval file does not trip the cardinality check.
    ds = tf.data.experimental.assert_cardinality(len(keep_urls))(ds)
    return ds
@seqio.map_over_dataset()
def charxiv_preprocessor(ex):
    """Stack CharXiv's four descriptive and one reasoning QA pair into
    aligned ``question`` / ``answer`` tensors."""
    qa_keys = (
        ("descriptive_q1", "descriptive_a1"),
        ("descriptive_q2", "descriptive_a2"),
        ("descriptive_q3", "descriptive_a3"),
        ("descriptive_q4", "descriptive_a4"),
        ("reasoning_q", "reasoning_a"),
    )
    return {
        "image": ex["image"],
        "question": tf.stack([ex[q] for q, _ in qa_keys], 0),
        "answer": tf.stack([ex[a] for _, a in qa_keys], 0),
    }
@seqio.map_over_dataset()
def charxiv_descriptive_preprocessor(ex):
    """Stack CharXiv's four descriptive QA pairs into aligned
    ``question`` / ``answer`` tensors."""
    qa_keys = (
        ("descriptive_q1", "descriptive_a1"),
        ("descriptive_q2", "descriptive_a2"),
        ("descriptive_q3", "descriptive_a3"),
        ("descriptive_q4", "descriptive_a4"),
    )
    return {
        "image": ex["image"],
        "question": tf.stack([ex[q] for q, _ in qa_keys], 0),
        "answer": tf.stack([ex[a] for _, a in qa_keys], 0),
    }
@seqio.map_over_dataset()
def charxiv_reasoning_preprocessor(ex):
    """Keep only CharXiv's reasoning question/answer alongside the image."""
    return {
        "image": ex["image"],
        "question": ex["reasoning_q"],
        "answer": ex["reasoning_a"],
    }
@seqio.map_over_dataset()
def tablevqa_preprocessor(ex):
    """Return image/question/answer, taking the answer from the ``gt`` field."""
    return {
        "image": ex["image"],
        "question": ex["question"],
        "answer": ex["gt"],
    }
@seqio.map_over_dataset()
def vtabfact_preprocessor(ex):
    """Append a yes/no answer instruction to the question; answer comes from ``gt``."""
    question = tf.strings.join(
        [ex["question"], "Answer with yes or no."], separator="\n")
    return {
        "image": ex["image"],
        "question": question,
        "answer": ex["gt"],
    }
@seqio.map_over_dataset()
def nutrition_fact_preprocessor(ex):
    """Stack the descriptive and reasoning QA pairs into aligned
    ``question`` / ``answer`` tensors."""
    qa_keys = (
        ("descriptive_q", "descriptive_a"),
        ("reasoning_q", "reasoning_a"),
    )
    return {
        "image": ex["image"],
        "question": tf.stack([ex[q] for q, _ in qa_keys], 0),
        "answer": tf.stack([ex[a] for _, a in qa_keys], 0),
    }