|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""AI2D TFDS converter. |
|
|
|
|
|
It's a small dataset, so can be built locally. Copy the data to local disk: |
|
|
|
mkdir -p /tmp/data/ai2d && cd /tmp/data/ai2d |
|
wget https://ai2-public-datasets.s3.amazonaws.com/diagrams/ai2d-all.zip |
|
wget https://s3-us-east-2.amazonaws.com/prior-datasets/ai2d_test_ids.csv |
|
wget https://github.com/googlefonts/dm-fonts/raw/main/Sans/fonts/ttf/DMSans-Regular.ttf |
|
unzip ai2d-all.zip |
|
|
|
Also download a font for rendering, set the location in the flag font_path. |
|
|
|
Then, run conversion locally (make sure to install tensorflow-datasets for the `tfds` util): |
|
|
|
cd third_party/py/big_vision/datasets |
|
env TFDS_DATA_DIR=/tmp/tfds tfds build --datasets=ai2d |
|
|
|
Example to load: |
|
|
|
import tensorflow_datasets as tfds |
|
dataset = tfds.load(ai2d', split='train', data_dir='/tmp/tfds') |
|
""" |
|
|
|
import functools |
|
import glob |
|
import io |
|
import json |
|
import os |
|
from typing import Any, Dict |
|
|
|
from absl import flags |
|
import numpy as np |
|
from PIL import Image |
|
from PIL import ImageDraw |
|
from PIL import ImageFont |
|
import tensorflow_datasets as tfds |
|
|
|
|
|
_DESCRIPTION = """AI2D dataset.""" |
|
|
|
|
|
_CITATION = """ |
|
@inproceedings{kembhavi2016eccv, |
|
author = {Aniruddha Kembhavi, Mike Salvato, Eric Kolve, Minjoon Seo, Hannaneh Hajishirzi, Ali Farhadi}, |
|
title = {A Diagram Is Worth A Dozen Images}, |
|
booktitle = {European Conference on Computer Vision (ECCV)}, |
|
year = {2016} |
|
url={https://api.semanticscholar.org/CorpusID:2682274} |
|
} |
|
""" |
|
|
|
|
|
|
|
_INPUT_PATH = flags.DEFINE_string( |
|
'input_path', '/tmp/data/ai2d/', 'Downloaded AI2D data.' |
|
) |
|
_FONT_PATH = flags.DEFINE_string( |
|
'font_path', '/tmp/data/ai2d/DMSans-Regular.ttf', |
|
'Font for rendering annotations.' |
|
) |
|
|
|
|
|
class Ai2d(tfds.core.GeneratorBasedBuilder): |
|
"""DatasetBuilder for AI2D dataset.""" |
|
|
|
VERSION = tfds.core.Version('1.1.0') |
|
RELEASE_NOTES = {'1.1.0': 'Re-create from scratch + more fields.'} |
|
|
|
def _info(self): |
|
"""Returns the metadata.""" |
|
return tfds.core.DatasetInfo( |
|
builder=self, |
|
description=_DESCRIPTION, |
|
features=tfds.features.FeaturesDict({ |
|
'id': tfds.features.Text(), |
|
'question': tfds.features.Text(), |
|
'label': tfds.features.Scalar(np.int32), |
|
'answer': tfds.features.Text(), |
|
'possible_answers': tfds.features.Sequence(tfds.features.Text()), |
|
'abc_label': tfds.features.Scalar(np.bool_), |
|
'image_name': tfds.features.Text(), |
|
'image': tfds.features.Image(encoding_format='png'), |
|
}), |
|
homepage='https://allenai.org/data/diagrams', |
|
citation=_CITATION, |
|
) |
|
|
|
def _split_generators(self, dl_manager: tfds.download.DownloadManager): |
|
"""Returns SplitGenerators.""" |
|
return {split: self._generate_examples(split) |
|
for split in ('test', 'train')} |
|
|
|
def _generate_examples(self, split: str): |
|
"""Yields (key, example) tuples.""" |
|
with open( |
|
os.path.join(_INPUT_PATH.value, 'ai2d_test_ids.csv'), 'r' |
|
) as f: |
|
all_test_ids = f.readlines() |
|
all_test_ids = [line.strip() for line in all_test_ids] |
|
|
|
all_annotation_paths = glob.glob( |
|
os.path.join(_INPUT_PATH.value, 'ai2d/questions', '*.json')) |
|
for annotation_path in all_annotation_paths: |
|
basename = os.path.basename(annotation_path) |
|
image_id = basename.split('.')[0] |
|
if image_id in all_test_ids and split == 'train': |
|
continue |
|
elif image_id not in all_test_ids and split == 'test': |
|
continue |
|
|
|
text_annotation_path = os.path.join( |
|
_INPUT_PATH.value, 'ai2d/annotations', basename |
|
) |
|
with open(annotation_path, 'r') as f: |
|
with open(text_annotation_path, 'r') as g: |
|
question_json = json.load(f) |
|
text_annotation_json = json.load(g) |
|
for question in question_json['questions']: |
|
label_id = int( |
|
question_json['questions'][question]['correctAnswer'] |
|
) |
|
choices = question_json['questions'][question]['answerTexts'] |
|
abc_label = question_json['questions'][question]['abcLabel'] |
|
annotation = { |
|
'id': question_json['questions'][question]['questionId'], |
|
'question': question, |
|
'label': label_id, |
|
'answer': choices[label_id], |
|
'possible_answers': tuple(choices), |
|
'abc_label': abc_label, |
|
'image_name': question_json['imageName'], |
|
} |
|
annotation['image'] = _create_image( |
|
annotation, text_annotation_json['text'] |
|
) |
|
yield annotation['id'], annotation |
|
|
|
|
|
@functools.cache |
|
def Font( |
|
size: int, |
|
) -> ImageFont.FreeTypeFont: |
|
"""Loads the font from in the specified style. |
|
|
|
Args: |
|
size: The size of the returned font. |
|
|
|
Returns: |
|
The loaded font. |
|
""" |
|
return ImageFont.truetype(_FONT_PATH.value, size=size) |
|
|
|
|
|
def _create_image( |
|
annotation: Dict[str, Any], text_annotation: Dict[str, Any] |
|
) -> bytes: |
|
"""Adds image to one annotation.""" |
|
img_path = os.path.join(_INPUT_PATH.value, 'ai2d/images', |
|
annotation['image_name']) |
|
with open(img_path, 'rb') as f: |
|
if annotation['abc_label']: |
|
raw_image = _draw_text(f, text_annotation) |
|
else: |
|
raw_image = f.read() |
|
return raw_image |
|
|
|
|
|
def _draw_text(image, text_annotations) -> bytes: |
|
"""Replaces text in image by the correct replacement letter from AI2D.""" |
|
image = Image.open(image) |
|
draw = ImageDraw.Draw(image) |
|
for annotation in text_annotations: |
|
current_annotation = text_annotations[annotation] |
|
rectangle = current_annotation['rectangle'] |
|
box = [tuple(rectangle[0]), tuple(rectangle[1]),] |
|
text = current_annotation['replacementText'] |
|
position = box[0] |
|
draw.rectangle(box, fill='white') |
|
font_size = 100 |
|
x_diff = box[1][0] - box[0][0] |
|
y_diff = box[1][1] - box[0][1] |
|
font = Font(font_size) |
|
size = font.getbbox(text) |
|
while (size[2] > x_diff or size[3] > y_diff) and font_size > 0: |
|
font = Font(font_size) |
|
size = font.getbbox(text) |
|
font_size -= 1 |
|
delta = (x_diff - size[2]) // 2 |
|
position = (position[0] + delta, position[1]) |
|
draw.text(position, text, fill='black', font=font) |
|
new_image_bytes = io.BytesIO() |
|
image.save(new_image_bytes, format='PNG') |
|
return new_image_bytes.getvalue() |
|
|