Spaces:
Running
Running
Merge pull request #112 from borisdayma/’cleanup’
Browse files- .github/workflows/style.yml +20 -0
- Makefile +5 -0
- README.md +8 -6
- app/gradio/app_gradio.py +68 -37
- app/gradio/app_gradio_ngrok.py +0 -89
- app/{app.py → streamlit/app.py} +3 -2
- {dalle_mini → app/streamlit}/backend.py +6 -6
- app/{img → streamlit/img}/loading.gif +0 -0
- dalle_mini/data.py +4 -2
- dalle_mini/dataset.py +0 -122
- dalle_mini/helpers.py +0 -14
- dalle_mini/model.py +11 -8
- dalle_mini/text.py +5 -3
- dev/README.md +0 -122
- dev/data/CC12M_downloader.py +0 -91
- dev/data/CC3M_downloader.py +0 -62
- dev/data/README.md +0 -3
- dev/encoding/vqgan-jax-encoding-streaming.ipynb +0 -562
- dev/encoding/vqgan-jax-encoding-with-captions.ipynb +0 -355
- dev/encoding/vqgan-jax-encoding-yfcc100m.ipynb +0 -1129
- dev/encoding/vqgan-jax-encoding.ipynb +0 -0
- dev/environment.yaml +0 -10
- dev/requirements.txt +0 -14
- dev/seq2seq/do_big_run.sh +0 -21
- dev/seq2seq/do_small_run.sh +0 -19
- dev/vqgan/JAX_VQGAN_f16_16384_Reconstruction.ipynb +0 -0
- pyproject.toml +2 -0
- setup.cfg +8 -0
- setup.py +1 -1
- dev/encoding/vqgan-jax-encoding-webdataset.ipynb → tools/dataset/encode_dataset.ipynb +133 -223
- tools/inference/inference_pipeline.ipynb +0 -0
- tools/inference/log_inference_samples.ipynb +121 -46
- {dev/seq2seq → tools/train}/sweep.yaml +0 -0
- dev/seq2seq/run_seq2seq_flax.py → tools/train/train.py +16 -22
.github/workflows/style.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Lint
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [main]
|
6 |
+
pull_request:
|
7 |
+
branches: [main]
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
lint:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v2
|
14 |
+
- uses: psf/black@stable
|
15 |
+
- uses: actions/setup-python@v2
|
16 |
+
with:
|
17 |
+
python-version: 3.9
|
18 |
+
- name: Install requirements
|
19 |
+
run: pip install ".[dev]"
|
20 |
+
- uses: jamescurtin/isort-action@master
|
Makefile
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: style
|
2 |
+
|
3 |
+
style:
|
4 |
+
black .
|
5 |
+
isort .
|
README.md
CHANGED
@@ -4,8 +4,8 @@ emoji: 🥑
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: green
|
6 |
sdk: streamlit
|
7 |
-
app_file: app/app.py
|
8 |
-
pinned:
|
9 |
---
|
10 |
|
11 |
# DALL·E Mini
|
@@ -28,7 +28,9 @@ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini
|
|
28 |
|
29 |
### Dependencies Installation
|
30 |
|
31 |
-
For
|
|
|
|
|
32 |
|
33 |
### Training of VQGAN
|
34 |
|
@@ -42,15 +44,15 @@ Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
|
|
42 |
|
43 |
### Training of Seq2Seq
|
44 |
|
45 |
-
|
46 |
|
47 |
You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
|
48 |
|
49 |
### Inference Pipeline
|
50 |
|
51 |
-
To generate sample predictions and understand the inference pipeline step by step, refer to [`
|
52 |
|
53 |
-
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/
|
54 |
|
55 |
## FAQ
|
56 |
|
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: green
|
6 |
sdk: streamlit
|
7 |
+
app_file: app/streamlit/app.py
|
8 |
+
pinned: True
|
9 |
---
|
10 |
|
11 |
# DALL·E Mini
|
|
|
28 |
|
29 |
### Dependencies Installation
|
30 |
|
31 |
+
For inference only, use `pip install git+https://github.com/borisdayma/dalle-mini.git`.
|
32 |
+
|
33 |
+
For development, clone the repo and use `pip install -e ".[dev]"`.
|
34 |
|
35 |
### Training of VQGAN
|
36 |
|
|
|
44 |
|
45 |
### Training of Seq2Seq
|
46 |
|
47 |
+
Use [`tools/train/train.py`](tools/train/train.py).
|
48 |
|
49 |
You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
|
50 |
|
51 |
### Inference Pipeline
|
52 |
|
53 |
+
To generate sample predictions and understand the inference pipeline step by step, refer to [`tools/inference/inference_pipeline.ipynb`](tools/inference/inference_pipeline.ipynb).
|
54 |
|
55 |
+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb)
|
56 |
|
57 |
## FAQ
|
58 |
|
app/gradio/app_gradio.py
CHANGED
@@ -2,51 +2,62 @@
|
|
2 |
# coding: utf-8
|
3 |
|
4 |
# Uncomment to run on cpu
|
5 |
-
#import os
|
6 |
-
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
7 |
|
8 |
import random
|
9 |
|
|
|
10 |
import jax
|
11 |
-
import flax.linen as nn
|
12 |
-
from flax.training.common_utils import shard
|
13 |
-
from flax.jax_utils import replicate, unreplicate
|
14 |
-
|
15 |
-
from transformers import BartTokenizer, FlaxBartForConditionalGeneration
|
16 |
-
|
17 |
-
from PIL import Image
|
18 |
import numpy as np
|
19 |
-
|
|
|
|
|
20 |
|
|
|
|
|
21 |
from vqgan_jax.modeling_flax_vqgan import VQModel
|
|
|
22 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
23 |
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
|
|
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
-
tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
|
39 |
-
model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
|
40 |
-
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
|
41 |
|
42 |
def custom_to_pil(x):
|
43 |
-
x = np.clip(x, 0
|
44 |
-
x = (255*x).astype(np.uint8)
|
45 |
x = Image.fromarray(x)
|
46 |
if not x.mode == "RGB":
|
47 |
x = x.convert("RGB")
|
48 |
return x
|
49 |
|
|
|
50 |
def generate(input, rng, params):
|
51 |
return model.generate(
|
52 |
**input,
|
@@ -59,9 +70,11 @@ def generate(input, rng, params):
|
|
59 |
params=params,
|
60 |
)
|
61 |
|
|
|
62 |
def get_images(indices, params):
|
63 |
return vqgan.decode_code(indices, params=params)
|
64 |
|
|
|
65 |
p_generate = jax.pmap(generate, "batch")
|
66 |
p_get_images = jax.pmap(get_images, "batch")
|
67 |
|
@@ -73,9 +86,16 @@ print("Initialize FlaxCLIPModel")
|
|
73 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
74 |
print("Initialize CLIPProcessor")
|
75 |
|
|
|
76 |
def hallucinate(prompt, num_images=64):
|
77 |
prompt = [prompt] * jax.device_count()
|
78 |
-
inputs = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
inputs = shard(inputs)
|
80 |
|
81 |
all_images = []
|
@@ -92,6 +112,7 @@ def hallucinate(prompt, num_images=64):
|
|
92 |
all_images.append(custom_to_pil(image))
|
93 |
return all_images
|
94 |
|
|
|
95 |
def clip_top_k(prompt, images, k=8):
|
96 |
inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
|
97 |
outputs = clip(**inputs)
|
@@ -99,24 +120,29 @@ def clip_top_k(prompt, images, k=8):
|
|
99 |
scores = np.array(logits[0]).argsort()[-k:][::-1]
|
100 |
return [images[score] for score in scores]
|
101 |
|
|
|
102 |
def compose_predictions(images, caption=None):
|
103 |
increased_h = 0 if caption is None else 48
|
104 |
w, h = images[0].size[0], images[0].size[1]
|
105 |
-
img = Image.new("RGB", (len(images)*w, h + increased_h))
|
106 |
for i, img_ in enumerate(images):
|
107 |
-
img.paste(img_, (i*w, increased_h))
|
108 |
|
109 |
if caption is not None:
|
110 |
draw = ImageDraw.Draw(img)
|
111 |
-
font = ImageFont.truetype(
|
112 |
-
|
|
|
|
|
113 |
return img
|
114 |
|
|
|
115 |
def top_k_predictions(prompt, num_candidates=32, k=8):
|
116 |
images = hallucinate(prompt, num_images=num_candidates)
|
117 |
images = clip_top_k(prompt, images, k=k)
|
118 |
return images
|
119 |
|
|
|
120 |
def run_inference(prompt, num_images=32, num_preds=8):
|
121 |
images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
|
122 |
predictions = captioned_strip(images)
|
@@ -125,23 +151,28 @@ def run_inference(prompt, num_images=32, num_preds=8):
|
|
125 |
"""
|
126 |
return (output_title, predictions)
|
127 |
|
|
|
128 |
outputs = [
|
129 |
-
gr.outputs.HTML(label=""),
|
130 |
-
gr.outputs.Image(label=
|
131 |
]
|
132 |
|
133 |
description = """
|
134 |
DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
|
135 |
"""
|
136 |
-
gr.Interface(
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
140 |
description=description,
|
141 |
article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
|
142 |
-
layout=
|
143 |
-
theme=
|
144 |
-
examples=[
|
|
|
|
|
|
|
145 |
allow_flagging=False,
|
146 |
live=False,
|
147 |
# server_port=8999
|
|
|
2 |
# coding: utf-8
|
3 |
|
4 |
# Uncomment to run on cpu
|
5 |
+
# import os
|
6 |
+
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
7 |
|
8 |
import random
|
9 |
|
10 |
+
import gradio as gr
|
11 |
import jax
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
import numpy as np
|
13 |
+
from flax.jax_utils import replicate
|
14 |
+
from flax.training.common_utils import shard
|
15 |
+
from PIL import Image, ImageDraw, ImageFont
|
16 |
|
17 |
+
# ## CLIP Scoring
|
18 |
+
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
|
19 |
from vqgan_jax.modeling_flax_vqgan import VQModel
|
20 |
+
|
21 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
22 |
|
23 |
+
DALLE_REPO = "flax-community/dalle-mini"
|
24 |
+
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
|
25 |
|
26 |
+
VQGAN_REPO = "flax-community/vqgan_f16_16384"
|
27 |
+
VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
|
28 |
|
29 |
+
tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
|
30 |
+
model = CustomFlaxBartForConditionalGeneration.from_pretrained(
|
31 |
+
DALLE_REPO, revision=DALLE_COMMIT_ID
|
32 |
+
)
|
33 |
+
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
|
34 |
|
35 |
|
36 |
+
def captioned_strip(images, caption=None, rows=1):
|
37 |
+
increased_h = 0 if caption is None else 48
|
38 |
+
w, h = images[0].size[0], images[0].size[1]
|
39 |
+
img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
|
40 |
+
for i, img_ in enumerate(images):
|
41 |
+
img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
|
42 |
|
43 |
+
if caption is not None:
|
44 |
+
draw = ImageDraw.Draw(img)
|
45 |
+
font = ImageFont.truetype(
|
46 |
+
"/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
|
47 |
+
)
|
48 |
+
draw.text((20, 3), caption, (255, 255, 255), font=font)
|
49 |
+
return img
|
50 |
|
|
|
|
|
|
|
51 |
|
52 |
def custom_to_pil(x):
|
53 |
+
x = np.clip(x, 0.0, 1.0)
|
54 |
+
x = (255 * x).astype(np.uint8)
|
55 |
x = Image.fromarray(x)
|
56 |
if not x.mode == "RGB":
|
57 |
x = x.convert("RGB")
|
58 |
return x
|
59 |
|
60 |
+
|
61 |
def generate(input, rng, params):
|
62 |
return model.generate(
|
63 |
**input,
|
|
|
70 |
params=params,
|
71 |
)
|
72 |
|
73 |
+
|
74 |
def get_images(indices, params):
|
75 |
return vqgan.decode_code(indices, params=params)
|
76 |
|
77 |
+
|
78 |
p_generate = jax.pmap(generate, "batch")
|
79 |
p_get_images = jax.pmap(get_images, "batch")
|
80 |
|
|
|
86 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
87 |
print("Initialize CLIPProcessor")
|
88 |
|
89 |
+
|
90 |
def hallucinate(prompt, num_images=64):
|
91 |
prompt = [prompt] * jax.device_count()
|
92 |
+
inputs = tokenizer(
|
93 |
+
prompt,
|
94 |
+
return_tensors="jax",
|
95 |
+
padding="max_length",
|
96 |
+
truncation=True,
|
97 |
+
max_length=128,
|
98 |
+
).data
|
99 |
inputs = shard(inputs)
|
100 |
|
101 |
all_images = []
|
|
|
112 |
all_images.append(custom_to_pil(image))
|
113 |
return all_images
|
114 |
|
115 |
+
|
116 |
def clip_top_k(prompt, images, k=8):
|
117 |
inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
|
118 |
outputs = clip(**inputs)
|
|
|
120 |
scores = np.array(logits[0]).argsort()[-k:][::-1]
|
121 |
return [images[score] for score in scores]
|
122 |
|
123 |
+
|
124 |
def compose_predictions(images, caption=None):
|
125 |
increased_h = 0 if caption is None else 48
|
126 |
w, h = images[0].size[0], images[0].size[1]
|
127 |
+
img = Image.new("RGB", (len(images) * w, h + increased_h))
|
128 |
for i, img_ in enumerate(images):
|
129 |
+
img.paste(img_, (i * w, increased_h))
|
130 |
|
131 |
if caption is not None:
|
132 |
draw = ImageDraw.Draw(img)
|
133 |
+
font = ImageFont.truetype(
|
134 |
+
"/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
|
135 |
+
)
|
136 |
+
draw.text((20, 3), caption, (255, 255, 255), font=font)
|
137 |
return img
|
138 |
|
139 |
+
|
140 |
def top_k_predictions(prompt, num_candidates=32, k=8):
|
141 |
images = hallucinate(prompt, num_images=num_candidates)
|
142 |
images = clip_top_k(prompt, images, k=k)
|
143 |
return images
|
144 |
|
145 |
+
|
146 |
def run_inference(prompt, num_images=32, num_preds=8):
|
147 |
images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
|
148 |
predictions = captioned_strip(images)
|
|
|
151 |
"""
|
152 |
return (output_title, predictions)
|
153 |
|
154 |
+
|
155 |
outputs = [
|
156 |
+
gr.outputs.HTML(label=""), # To be used as title
|
157 |
+
gr.outputs.Image(label=""),
|
158 |
]
|
159 |
|
160 |
description = """
|
161 |
DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
|
162 |
"""
|
163 |
+
gr.Interface(
|
164 |
+
run_inference,
|
165 |
+
inputs=[gr.inputs.Textbox(label="What do you want to see?")],
|
166 |
+
outputs=outputs,
|
167 |
+
title="DALL·E mini",
|
168 |
description=description,
|
169 |
article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
|
170 |
+
layout="vertical",
|
171 |
+
theme="huggingface",
|
172 |
+
examples=[
|
173 |
+
["an armchair in the shape of an avocado"],
|
174 |
+
["snowy mountains by the sea"],
|
175 |
+
],
|
176 |
allow_flagging=False,
|
177 |
live=False,
|
178 |
# server_port=8999
|
app/gradio/app_gradio_ngrok.py
DELETED
@@ -1,89 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
# coding: utf-8
|
3 |
-
|
4 |
-
import requests
|
5 |
-
from PIL import Image
|
6 |
-
import numpy as np
|
7 |
-
import matplotlib.pyplot as plt
|
8 |
-
from io import BytesIO
|
9 |
-
import base64
|
10 |
-
import os
|
11 |
-
|
12 |
-
import gradio as gr
|
13 |
-
|
14 |
-
from dalle_mini.helpers import captioned_strip
|
15 |
-
|
16 |
-
|
17 |
-
backend_url = os.environ["BACKEND_SERVER"]
|
18 |
-
|
19 |
-
|
20 |
-
class ServiceError(Exception):
|
21 |
-
def __init__(self, status_code):
|
22 |
-
self.status_code = status_code
|
23 |
-
|
24 |
-
def get_images_from_ngrok(prompt):
|
25 |
-
r = requests.post(
|
26 |
-
backend_url,
|
27 |
-
json={"prompt": prompt}
|
28 |
-
)
|
29 |
-
if r.status_code == 200:
|
30 |
-
images = r.json()["images"]
|
31 |
-
images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
|
32 |
-
return images
|
33 |
-
else:
|
34 |
-
raise ServiceError(r.status_code)
|
35 |
-
|
36 |
-
def run_inference(prompt):
|
37 |
-
try:
|
38 |
-
images = get_images_from_ngrok(prompt)
|
39 |
-
predictions = captioned_strip(images)
|
40 |
-
output_title = f"""
|
41 |
-
<p style="font-size:22px; font-style:bold">Best predictions</p>
|
42 |
-
<p>We asked our model to generate 128 candidates for your prompt:</p>
|
43 |
-
|
44 |
-
<pre>
|
45 |
-
|
46 |
-
<b>{prompt}</b>
|
47 |
-
</pre>
|
48 |
-
<p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
|
49 |
-
similarity of the text and the image representations.</p>
|
50 |
-
|
51 |
-
<p>This is the result:</p>
|
52 |
-
"""
|
53 |
-
|
54 |
-
output_description = """
|
55 |
-
<p>Read our <a style="color:blue;" href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">full report</a> for more details on how this works.<p>
|
56 |
-
<p style='text-align: center'>Created with <a style="color:blue;" href="https://github.com/borisdayma/dalle-mini">DALL·E mini</a></p>
|
57 |
-
"""
|
58 |
-
|
59 |
-
except ServiceError:
|
60 |
-
output_title = f"""
|
61 |
-
Sorry, there was an error retrieving the images. Please, try again later or <a href="mailto:[email protected]">contact us here</a>.
|
62 |
-
"""
|
63 |
-
predictions = None
|
64 |
-
output_description = ""
|
65 |
-
|
66 |
-
return (output_title, predictions, output_description)
|
67 |
-
|
68 |
-
outputs = [
|
69 |
-
gr.outputs.HTML(label=""), # To be used as title
|
70 |
-
gr.outputs.Image(label=''),
|
71 |
-
gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
|
72 |
-
]
|
73 |
-
|
74 |
-
description = """
|
75 |
-
Welcome to DALL·E-mini, a text-to-image generation model.
|
76 |
-
"""
|
77 |
-
gr.Interface(run_inference,
|
78 |
-
inputs=[gr.inputs.Textbox(label='Prompt')],
|
79 |
-
outputs=outputs,
|
80 |
-
title='DALL·E mini',
|
81 |
-
description=description,
|
82 |
-
article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
|
83 |
-
layout='vertical',
|
84 |
-
theme='huggingface',
|
85 |
-
examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
|
86 |
-
allow_flagging=False,
|
87 |
-
live=False,
|
88 |
-
# server_name="0.0.0.0", # Bind to all interfaces
|
89 |
-
).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/{app.py → streamlit/app.py}
RENAMED
@@ -1,9 +1,10 @@
|
|
1 |
#!/usr/bin/env python
|
2 |
# coding: utf-8
|
3 |
|
4 |
-
from dalle_mini.backend import ServiceError, get_images_from_backend
|
5 |
import streamlit as st
|
6 |
|
|
|
|
|
7 |
st.sidebar.markdown(
|
8 |
"""
|
9 |
<style>
|
@@ -50,7 +51,7 @@ if prompt != "":
|
|
50 |
<div class="st-b7">
|
51 |
<div class="css-whx05o e13vu3m50">
|
52 |
<div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
|
53 |
-
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/img/loading.gif" width="30"/>
|
54 |
Generating predictions for: <b>{prompt}</b>
|
55 |
</div>
|
56 |
</div>
|
|
|
1 |
#!/usr/bin/env python
|
2 |
# coding: utf-8
|
3 |
|
|
|
4 |
import streamlit as st
|
5 |
|
6 |
+
from .backend import ServiceError, get_images_from_backend
|
7 |
+
|
8 |
st.sidebar.markdown(
|
9 |
"""
|
10 |
<style>
|
|
|
51 |
<div class="st-b7">
|
52 |
<div class="css-whx05o e13vu3m50">
|
53 |
<div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
|
54 |
+
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
|
55 |
Generating predictions for: <b>{prompt}</b>
|
56 |
</div>
|
57 |
</div>
|
{dalle_mini → app/streamlit}/backend.py
RENAMED
@@ -1,17 +1,17 @@
|
|
1 |
-
import requests
|
2 |
-
from io import BytesIO
|
3 |
import base64
|
|
|
|
|
|
|
4 |
from PIL import Image
|
5 |
|
|
|
6 |
class ServiceError(Exception):
|
7 |
def __init__(self, status_code):
|
8 |
self.status_code = status_code
|
9 |
|
|
|
10 |
def get_images_from_backend(prompt, backend_url):
|
11 |
-
r = requests.post(
|
12 |
-
backend_url,
|
13 |
-
json={"prompt": prompt}
|
14 |
-
)
|
15 |
if r.status_code == 200:
|
16 |
images = r.json()["images"]
|
17 |
images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
|
|
|
|
|
|
|
1 |
import base64
|
2 |
+
from io import BytesIO
|
3 |
+
|
4 |
+
import requests
|
5 |
from PIL import Image
|
6 |
|
7 |
+
|
8 |
class ServiceError(Exception):
|
9 |
def __init__(self, status_code):
|
10 |
self.status_code = status_code
|
11 |
|
12 |
+
|
13 |
def get_images_from_backend(prompt, backend_url):
|
14 |
+
r = requests.post(backend_url, json={"prompt": prompt})
|
|
|
|
|
|
|
15 |
if r.status_code == 200:
|
16 |
images = r.json()["images"]
|
17 |
images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
|
app/{img → streamlit/img}/loading.gif
RENAMED
File without changes
|
dalle_mini/data.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
from dataclasses import dataclass, field
|
2 |
-
from datasets import load_dataset, Dataset
|
3 |
from functools import partial
|
4 |
-
|
5 |
import jax
|
6 |
import jax.numpy as jnp
|
|
|
|
|
7 |
from flax.training.common_utils import shard
|
|
|
8 |
from .text import TextNormalizer
|
9 |
|
10 |
|
|
|
1 |
from dataclasses import dataclass, field
|
|
|
2 |
from functools import partial
|
3 |
+
|
4 |
import jax
|
5 |
import jax.numpy as jnp
|
6 |
+
import numpy as np
|
7 |
+
from datasets import Dataset, load_dataset
|
8 |
from flax.training.common_utils import shard
|
9 |
+
|
10 |
from .text import TextNormalizer
|
11 |
|
12 |
|
dalle_mini/dataset.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
An image-caption dataset dataloader.
|
3 |
-
Luke Melas-Kyriazi, 2021
|
4 |
-
"""
|
5 |
-
import warnings
|
6 |
-
from typing import Optional, Callable
|
7 |
-
from pathlib import Path
|
8 |
-
import numpy as np
|
9 |
-
import torch
|
10 |
-
import pandas as pd
|
11 |
-
from torch.utils.data import Dataset
|
12 |
-
from torchvision.datasets.folder import default_loader
|
13 |
-
from PIL import ImageFile
|
14 |
-
from PIL.Image import DecompressionBombWarning
|
15 |
-
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
16 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
17 |
-
warnings.filterwarnings("ignore", category=DecompressionBombWarning)
|
18 |
-
|
19 |
-
|
20 |
-
class CaptionDataset(Dataset):
|
21 |
-
"""
|
22 |
-
A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
|
23 |
-
returns the raw text rather than tokens. This is done on purpose, because
|
24 |
-
it's easy to tokenize a batch of text after loading it from this dataset.
|
25 |
-
"""
|
26 |
-
|
27 |
-
def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
|
28 |
-
image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
|
29 |
-
include_captions: bool = True):
|
30 |
-
"""
|
31 |
-
:param images_root: folder where images are stored
|
32 |
-
:param captions_path: path to csv that maps image filenames to captions
|
33 |
-
:param image_transform: image transform pipeline
|
34 |
-
:param text_transform: image transform pipeline
|
35 |
-
:param image_transform_type: image transform type, either `torchvision` or `albumentations`
|
36 |
-
:param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
|
37 |
-
"""
|
38 |
-
|
39 |
-
# Base path for images
|
40 |
-
self.images_root = Path(images_root)
|
41 |
-
|
42 |
-
# Load captions as DataFrame
|
43 |
-
self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
|
44 |
-
self.captions['image_file'] = self.captions['image_file'].astype(str)
|
45 |
-
|
46 |
-
# PyTorch transformation pipeline for the image (normalizing, etc.)
|
47 |
-
self.text_transform = text_transform
|
48 |
-
self.image_transform = image_transform
|
49 |
-
self.image_transform_type = image_transform_type.lower()
|
50 |
-
assert self.image_transform_type in ['torchvision', 'albumentations']
|
51 |
-
|
52 |
-
# Total number of datapoints
|
53 |
-
self.size = len(self.captions)
|
54 |
-
|
55 |
-
# Return image+captions or just images
|
56 |
-
self.include_captions = include_captions
|
57 |
-
|
58 |
-
def verify_that_all_images_exist(self):
|
59 |
-
for image_file in self.captions['image_file']:
|
60 |
-
p = self.images_root / image_file
|
61 |
-
if not p.is_file():
|
62 |
-
print(f'file does not exist: {p}')
|
63 |
-
|
64 |
-
def _get_raw_image(self, i):
|
65 |
-
image_file = self.captions.iloc[i]['image_file']
|
66 |
-
image_path = self.images_root / image_file
|
67 |
-
image = default_loader(image_path)
|
68 |
-
return image
|
69 |
-
|
70 |
-
def _get_raw_text(self, i):
|
71 |
-
return self.captions.iloc[i]['caption']
|
72 |
-
|
73 |
-
def __getitem__(self, i):
|
74 |
-
image = self._get_raw_image(i)
|
75 |
-
caption = self._get_raw_text(i)
|
76 |
-
if self.image_transform is not None:
|
77 |
-
if self.image_transform_type == 'torchvision':
|
78 |
-
image = self.image_transform(image)
|
79 |
-
elif self.image_transform_type == 'albumentations':
|
80 |
-
image = self.image_transform(image=np.array(image))['image']
|
81 |
-
else:
|
82 |
-
raise NotImplementedError(f"{self.image_transform_type=}")
|
83 |
-
return {'image': image, 'text': caption} if self.include_captions else image
|
84 |
-
|
85 |
-
def __len__(self):
|
86 |
-
return self.size
|
87 |
-
|
88 |
-
|
89 |
-
if __name__ == "__main__":
|
90 |
-
import albumentations as A
|
91 |
-
from albumentations.pytorch import ToTensorV2
|
92 |
-
from transformers import AutoTokenizer
|
93 |
-
|
94 |
-
# Paths
|
95 |
-
images_root = './images'
|
96 |
-
captions_path = './images-list-clean.tsv'
|
97 |
-
|
98 |
-
# Create transforms
|
99 |
-
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
|
100 |
-
def tokenize(text):
|
101 |
-
return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
|
102 |
-
image_transform = A.Compose([
|
103 |
-
A.Resize(256, 256), A.CenterCrop(256, 256),
|
104 |
-
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
|
105 |
-
|
106 |
-
# Create dataset
|
107 |
-
dataset = CaptionDataset(
|
108 |
-
images_root=images_root,
|
109 |
-
captions_path=captions_path,
|
110 |
-
image_transform=image_transform,
|
111 |
-
text_transform=tokenize,
|
112 |
-
image_transform_type='albumentations')
|
113 |
-
|
114 |
-
# Create dataloader
|
115 |
-
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
|
116 |
-
batch = next(iter(dataloader))
|
117 |
-
print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
|
118 |
-
|
119 |
-
# # (Optional) Check that all the images exist
|
120 |
-
# dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
|
121 |
-
# dataset.verify_that_all_images_exist()
|
122 |
-
# print('Done')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dalle_mini/helpers.py
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
from PIL import Image, ImageDraw, ImageFont
|
2 |
-
|
3 |
-
def captioned_strip(images, caption=None, rows=1):
|
4 |
-
increased_h = 0 if caption is None else 48
|
5 |
-
w, h = images[0].size[0], images[0].size[1]
|
6 |
-
img = Image.new("RGB", (len(images)*w//rows, h*rows + increased_h))
|
7 |
-
for i, img_ in enumerate(images):
|
8 |
-
img.paste(img_, (i//rows*w, increased_h + (i % rows) * h))
|
9 |
-
|
10 |
-
if caption is not None:
|
11 |
-
draw = ImageDraw.Draw(img)
|
12 |
-
font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
|
13 |
-
draw.text((20, 3), caption, (255,255,255), font=font)
|
14 |
-
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dalle_mini/model.py
CHANGED
@@ -1,16 +1,14 @@
|
|
1 |
-
import jax
|
2 |
import flax.linen as nn
|
3 |
-
|
|
|
4 |
from transformers.models.bart.modeling_flax_bart import (
|
5 |
-
FlaxBartModule,
|
6 |
-
FlaxBartForConditionalGenerationModule,
|
7 |
-
FlaxBartForConditionalGeneration,
|
8 |
-
FlaxBartEncoder,
|
9 |
FlaxBartDecoder,
|
|
|
|
|
|
|
|
|
10 |
)
|
11 |
|
12 |
-
from transformers import BartConfig
|
13 |
-
|
14 |
|
15 |
class CustomFlaxBartModule(FlaxBartModule):
|
16 |
def setup(self):
|
@@ -46,6 +44,11 @@ class CustomFlaxBartForConditionalGenerationModule(
|
|
46 |
FlaxBartForConditionalGenerationModule
|
47 |
):
|
48 |
def setup(self):
|
|
|
|
|
|
|
|
|
|
|
49 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
50 |
self.lm_head = nn.Dense(
|
51 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
|
|
|
|
1 |
import flax.linen as nn
|
2 |
+
import jax
|
3 |
+
from transformers import BartConfig
|
4 |
from transformers.models.bart.modeling_flax_bart import (
|
|
|
|
|
|
|
|
|
5 |
FlaxBartDecoder,
|
6 |
+
FlaxBartEncoder,
|
7 |
+
FlaxBartForConditionalGeneration,
|
8 |
+
FlaxBartForConditionalGenerationModule,
|
9 |
+
FlaxBartModule,
|
10 |
)
|
11 |
|
|
|
|
|
12 |
|
13 |
class CustomFlaxBartModule(FlaxBartModule):
|
14 |
def setup(self):
|
|
|
44 |
FlaxBartForConditionalGenerationModule
|
45 |
):
|
46 |
def setup(self):
|
47 |
+
# set default config
|
48 |
+
self.config.normalize_text = getattr(self.config, "normalize_text", False)
|
49 |
+
self.config.image_length = getattr(self.config, "image_length", 256)
|
50 |
+
self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384)
|
51 |
+
|
52 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
53 |
self.lm_head = nn.Dense(
|
54 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
dalle_mini/text.py
CHANGED
@@ -2,13 +2,15 @@
|
|
2 |
Utilities for processing text.
|
3 |
"""
|
4 |
|
|
|
|
|
|
|
|
|
5 |
from pathlib import Path
|
6 |
-
from unidecode import unidecode
|
7 |
|
8 |
-
import re, math, random, html
|
9 |
import ftfy
|
10 |
-
|
11 |
from huggingface_hub import hf_hub_download
|
|
|
12 |
|
13 |
# based on wiki word occurence
|
14 |
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
|
|
2 |
Utilities for processing text.
|
3 |
"""
|
4 |
|
5 |
+
import html
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import re
|
9 |
from pathlib import Path
|
|
|
10 |
|
|
|
11 |
import ftfy
|
|
|
12 |
from huggingface_hub import hf_hub_download
|
13 |
+
from unidecode import unidecode
|
14 |
|
15 |
# based on wiki word occurence
|
16 |
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
dev/README.md
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
# Development Instructions for TPU
|
2 |
-
|
3 |
-
## Setup
|
4 |
-
|
5 |
-
- Apply to the [TRC program](https://sites.research.google/trc/) for free TPU credits if you're elligible.
|
6 |
-
- Follow the [Cloud TPU VM User's Guide](https://cloud.google.com/tpu/docs/users-guide-tpu-vm) to set up gcloud.
|
7 |
-
- Verify `gcloud config list`, in particular account, project & zone.
|
8 |
-
- Create a TPU VM per the guide and connect to it.
|
9 |
-
|
10 |
-
When needing a larger disk:
|
11 |
-
|
12 |
-
- Create a balanced persistent disk (SSD, so pricier than default HDD but much faster): `gcloud compute disks create DISK_NAME --size SIZE_IN_GB --type pd-balanced`
|
13 |
-
- Attach the disk to your instance by adding `--data-disk source=REF` per ["Adding a persistent disk to a TPU VM" guide](https://cloud.google.com/tpu/docs/setup-persistent-disk), eg `gcloud alpha compute tpus tpu-vm create INSTANCE_NAME --accelerator-type=v3-8 --version=v2-alpha --data-disk source=projects/tpu-toys/zones/europe-west4-a/disks/DISK_NAME`
|
14 |
-
- Format the partition as described in the guide.
|
15 |
-
- Make sure to set up automatic remount of disk at restart.
|
16 |
-
|
17 |
-
## Connect VS Code
|
18 |
-
|
19 |
-
- Find external IP in the UI or with `gcloud alpha compute tpus tpu-vm describe INSTANCE_NAME`
|
20 |
-
- Verify you can connect in terminal with `ssh EXTERNAL_IP -i ~/.ssh/google_compute_engine`
|
21 |
-
- Add the same command as ssh host in VS Code.
|
22 |
-
- Check config file
|
23 |
-
|
24 |
-
```
|
25 |
-
Host INSTANCE_NAME
|
26 |
-
HostName EXTERNAL_IP
|
27 |
-
IdentityFile ~/.ssh/google_compute_engine
|
28 |
-
```
|
29 |
-
|
30 |
-
## Environment configuration
|
31 |
-
|
32 |
-
### Use virtual environments (optional)
|
33 |
-
|
34 |
-
We recommend using virtual environments (such as conda, venv or pyenv-virtualenv).
|
35 |
-
|
36 |
-
If you want to use `pyenv` and `pyenv-virtualenv`:
|
37 |
-
|
38 |
-
- Installation
|
39 |
-
|
40 |
-
- [Set up build environment](https://github.com/pyenv/pyenv/wiki#suggested-build-environment)
|
41 |
-
- Use [pyenv-installer](https://github.com/pyenv/pyenv-installer): `curl https://pyenv.run | bash`
|
42 |
-
- bash set-up:
|
43 |
-
|
44 |
-
```bash
|
45 |
-
echo '\n'\
|
46 |
-
'# pyenv setup \n'\
|
47 |
-
'export PYENV_ROOT="$HOME/.pyenv" \n'\
|
48 |
-
'export PATH="$PYENV_ROOT/bin:$PATH" \n'\
|
49 |
-
'eval "$(pyenv init --path)" \n'\
|
50 |
-
'eval "$(pyenv init -)" \n'\
|
51 |
-
'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
|
52 |
-
```
|
53 |
-
|
54 |
-
- Usage
|
55 |
-
|
56 |
-
- Install a python version: `pyenv install X.X.X`
|
57 |
-
- Create a virtual environment: `pyenv virtualenv 3.9.6 dalle_env`
|
58 |
-
- Activate: `pyenv activate dalle_env`
|
59 |
-
|
60 |
-
Note: you can auto-activate your environment at a location with `echo dalle_env >> .python-version`
|
61 |
-
|
62 |
-
### Tools
|
63 |
-
|
64 |
-
- Git
|
65 |
-
|
66 |
-
- `git config --global user.email "[email protected]"
|
67 |
-
- `git config --global user.name "First Last"
|
68 |
-
|
69 |
-
- Github CLI
|
70 |
-
|
71 |
-
- See [installation instructions](https://github.com/cli/cli/blob/trunk/docs/install_linux.md)
|
72 |
-
- `gh auth login`
|
73 |
-
|
74 |
-
- Direnv
|
75 |
-
|
76 |
-
- Install direnv: `sudo apt-get update && sudo apt-get install direnv`
|
77 |
-
- bash set-up:
|
78 |
-
|
79 |
-
```bash
|
80 |
-
echo -e '\n'\
|
81 |
-
'# direnv setup \n'\
|
82 |
-
'eval "$(direnv hook bash)" \n' >> ~/.bashrc
|
83 |
-
```
|
84 |
-
|
85 |
-
### Set up repo
|
86 |
-
|
87 |
-
- Clone repo: `gh repo clone borisdayma/dalle-mini`
|
88 |
-
- If using `pyenv-virtualenv`, auto-activate env: `echo dalle_env >> .python-version`
|
89 |
-
|
90 |
-
## Environment
|
91 |
-
|
92 |
-
- Install the following (use it later to update our dev requirements.txt)
|
93 |
-
|
94 |
-
```
|
95 |
-
requests
|
96 |
-
pillow
|
97 |
-
jupyterlab
|
98 |
-
ipywidgets
|
99 |
-
|
100 |
-
-e ../datasets[streaming]
|
101 |
-
-e ../transformers
|
102 |
-
-e ../webdataset
|
103 |
-
|
104 |
-
# JAX
|
105 |
-
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
106 |
-
jax[tpu]>=0.2.16
|
107 |
-
flax
|
108 |
-
```
|
109 |
-
|
110 |
-
- `transformers-cli login`
|
111 |
-
|
112 |
-
---
|
113 |
-
|
114 |
-
- set `HF_HOME="/mnt/disks/persist/cache/huggingface"` in `/etc/environment` and ensure you have required permissions, then restart.
|
115 |
-
|
116 |
-
## Working with datasets or models
|
117 |
-
|
118 |
-
- Install [Git LFS](https://github.com/git-lfs/git-lfs/wiki/Installation)
|
119 |
-
- Clone a dataset without large files: `GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/.../...`
|
120 |
-
- Use a local [credential store](https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage) for caching credentials
|
121 |
-
- Track specific extentions: `git lfs track "*.ext"`
|
122 |
-
- See files tracked with LFS with `git lfs ls-files`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/data/CC12M_downloader.py
DELETED
@@ -1,91 +0,0 @@
|
|
1 |
-
# Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
|
2 |
-
|
3 |
-
#%%
|
4 |
-
import sys
|
5 |
-
import os
|
6 |
-
from datetime import datetime
|
7 |
-
import pandas as pd
|
8 |
-
import contexttimer
|
9 |
-
from urllib.request import urlopen
|
10 |
-
import requests
|
11 |
-
from PIL import Image
|
12 |
-
import torch
|
13 |
-
from torchvision.transforms import functional as TF
|
14 |
-
from multiprocessing import Pool
|
15 |
-
from tqdm import tqdm
|
16 |
-
import logging
|
17 |
-
|
18 |
-
# Setup
|
19 |
-
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
|
20 |
-
requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
|
21 |
-
|
22 |
-
|
23 |
-
# # For downloading SVG images (I can't get this to work)
|
24 |
-
# from io import BytesIO
|
25 |
-
# import cairosvg
|
26 |
-
|
27 |
-
#%%
|
28 |
-
# Load data
|
29 |
-
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
|
30 |
-
with contexttimer.Timer(prefix="Loading from tsv"):
|
31 |
-
df = pd.read_csv('./cc12m.tsv', delimiter='\t', header=None)
|
32 |
-
|
33 |
-
url_to_idx_map = {url: index for index, url, caption in df.itertuples()}
|
34 |
-
print(f'Loaded {len(url_to_idx_map)} urls')
|
35 |
-
|
36 |
-
#%%
|
37 |
-
df.head()
|
38 |
-
|
39 |
-
#%%
|
40 |
-
|
41 |
-
# Note: it seems that there are no SVG images
|
42 |
-
df.sample(10000)[1].str.contains('.svg').sum()
|
43 |
-
|
44 |
-
#%%
|
45 |
-
# Resize function
|
46 |
-
def resize(img):
|
47 |
-
max_size_of_short_side = 512
|
48 |
-
if min(img.size) > max_size_of_short_side:
|
49 |
-
img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
|
50 |
-
return img
|
51 |
-
|
52 |
-
base_dir = os.path.join(os.getcwd(), 'images')
|
53 |
-
|
54 |
-
def process(item):
|
55 |
-
url, image_id = item
|
56 |
-
try:
|
57 |
-
base_url = os.path.basename(url) # extract base url
|
58 |
-
stem, ext = os.path.splitext(base_url) # split into stem and extension
|
59 |
-
filename = f'{image_id:08d}---{stem}.jpg' # create filename
|
60 |
-
filepath = os.path.join(base_dir, filename) # concat to get filepath
|
61 |
-
if not os.path.isfile(filepath):
|
62 |
-
# if filepath.endswith('.svg'):
|
63 |
-
# raise NotImplementedError()
|
64 |
-
# image_bytes = BytesIO() # create a bytestream
|
65 |
-
# cairosvg.svg2png(url=url, write_to=image_bytes) # convert svg into image
|
66 |
-
# else:
|
67 |
-
req = requests.get(url, stream=True, timeout=1, verify=False).raw
|
68 |
-
image = Image.open(req).convert('RGB')
|
69 |
-
if min(image.size) > 512:
|
70 |
-
image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
|
71 |
-
# image = resize(image) # resize PIL image
|
72 |
-
image.save(filepath) # save PIL image
|
73 |
-
except Exception as e:
|
74 |
-
logging.info(" ".join(repr(e).splitlines()))
|
75 |
-
logging.error(url)
|
76 |
-
|
77 |
-
#%%
|
78 |
-
#for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))):
|
79 |
-
# process(item)
|
80 |
-
# if i > 100:
|
81 |
-
# break
|
82 |
-
|
83 |
-
# Use multiprocessing for speed
|
84 |
-
list_of_items = list(url_to_idx_map.items())
|
85 |
-
print(len(list_of_items))
|
86 |
-
list_of_items = list_of_items[10_000_000:]
|
87 |
-
print(len(list_of_items))
|
88 |
-
with Pool(128) as p:
|
89 |
-
r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
|
90 |
-
print('DONE')
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/data/CC3M_downloader.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
This script was adapted from Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
|
3 |
-
Few changes were made for the particular dataset. You're required to have the `.tsv` file downloaded in your directory.
|
4 |
-
Find them here- [https://github.com/google-research-datasets/conceptual-captions]
|
5 |
-
'''
|
6 |
-
|
7 |
-
import sys
|
8 |
-
import os
|
9 |
-
from datetime import datetime
|
10 |
-
import pandas as pd
|
11 |
-
import contexttimer
|
12 |
-
from urllib.request import urlopen
|
13 |
-
import requests
|
14 |
-
from PIL import Image
|
15 |
-
import torch
|
16 |
-
from torchvision.transforms import functional as TF
|
17 |
-
from multiprocessing import Pool
|
18 |
-
from tqdm import tqdm
|
19 |
-
import logging
|
20 |
-
import sys
|
21 |
-
|
22 |
-
# Setup
|
23 |
-
logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
|
24 |
-
requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
|
25 |
-
|
26 |
-
if len(sys.argv) != 3:
|
27 |
-
print("Provide .tsv file name & output directory. e.g. python downloader.py Train-GCC-training.tsv training")
|
28 |
-
exit(1)
|
29 |
-
|
30 |
-
# Load data
|
31 |
-
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
|
32 |
-
with contexttimer.Timer(prefix="Loading from tsv"):
|
33 |
-
df = pd.read_csv(sys.argv[1], delimiter='\t', header=None)
|
34 |
-
|
35 |
-
url_to_idx_map = {url: index for index, caption, url in df.itertuples()}
|
36 |
-
print(f'Loaded {len(url_to_idx_map)} urls')
|
37 |
-
|
38 |
-
base_dir = os.path.join(os.getcwd(), sys.argv[2])
|
39 |
-
|
40 |
-
def process(item):
|
41 |
-
url, image_id = item
|
42 |
-
try:
|
43 |
-
base_url = os.path.basename(url) # extract base url
|
44 |
-
stem, ext = os.path.splitext(base_url) # split into stem and extension
|
45 |
-
filename = f'{image_id:08d}---{stem}.jpg' # create filename
|
46 |
-
filepath = os.path.join(base_dir, filename) # concat to get filepath
|
47 |
-
if not os.path.isfile(filepath):
|
48 |
-
req = requests.get(url, stream=True, timeout=1, verify=False).raw
|
49 |
-
image = Image.open(req).convert('RGB')
|
50 |
-
if min(image.size) > 512:
|
51 |
-
image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
|
52 |
-
image.save(filepath) # save PIL image
|
53 |
-
except Exception as e:
|
54 |
-
logging.info(" ".join(repr(e).splitlines()))
|
55 |
-
logging.error(url)
|
56 |
-
|
57 |
-
list_of_items = list(url_to_idx_map.items())
|
58 |
-
print(len(list_of_items))
|
59 |
-
|
60 |
-
with Pool(128) as p:
|
61 |
-
r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
|
62 |
-
print('DONE')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/data/README.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# Data
|
2 |
-
|
3 |
-
Utility scripts for downloading CC3M and CC12M.
|
|
|
|
|
|
|
|
dev/encoding/vqgan-jax-encoding-streaming.ipynb
DELETED
@@ -1,562 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "d0b72877",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# VQGAN JAX Encoding for 🤗 Datasets in streaming mode"
|
9 |
-
]
|
10 |
-
},
|
11 |
-
{
|
12 |
-
"cell_type": "markdown",
|
13 |
-
"id": "ba7b31e6",
|
14 |
-
"metadata": {},
|
15 |
-
"source": [
|
16 |
-
"This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and 🤗 Datasets in streaming mode.\n",
|
17 |
-
"\n",
|
18 |
-
"This example uses our YFCC100M dataset, but it should be easy to adapt to any other image/caption dataset in the huggingface hub."
|
19 |
-
]
|
20 |
-
},
|
21 |
-
{
|
22 |
-
"cell_type": "code",
|
23 |
-
"execution_count": null,
|
24 |
-
"id": "3b59489e",
|
25 |
-
"metadata": {},
|
26 |
-
"outputs": [],
|
27 |
-
"source": [
|
28 |
-
"import io\n",
|
29 |
-
"\n",
|
30 |
-
"import requests\n",
|
31 |
-
"from PIL import Image\n",
|
32 |
-
"import numpy as np\n",
|
33 |
-
"from tqdm import tqdm\n",
|
34 |
-
"\n",
|
35 |
-
"import torch\n",
|
36 |
-
"import torchvision.transforms as T\n",
|
37 |
-
"import torchvision.transforms.functional as TF\n",
|
38 |
-
"from torchvision.transforms import InterpolationMode\n",
|
39 |
-
"import os\n",
|
40 |
-
"\n",
|
41 |
-
"import jax\n",
|
42 |
-
"from jax import pmap"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "markdown",
|
47 |
-
"id": "c7c4c1e6",
|
48 |
-
"metadata": {},
|
49 |
-
"source": [
|
50 |
-
"## Dataset and Parameters"
|
51 |
-
]
|
52 |
-
},
|
53 |
-
{
|
54 |
-
"cell_type": "code",
|
55 |
-
"execution_count": null,
|
56 |
-
"id": "d45a289e",
|
57 |
-
"metadata": {},
|
58 |
-
"outputs": [],
|
59 |
-
"source": [
|
60 |
-
"import datasets\n",
|
61 |
-
"from datasets import Dataset, load_dataset"
|
62 |
-
]
|
63 |
-
},
|
64 |
-
{
|
65 |
-
"cell_type": "markdown",
|
66 |
-
"id": "f26e4f18",
|
67 |
-
"metadata": {},
|
68 |
-
"source": [
|
69 |
-
"We'll use the `validation` set for testing. Adjust accordingly."
|
70 |
-
]
|
71 |
-
},
|
72 |
-
{
|
73 |
-
"cell_type": "code",
|
74 |
-
"execution_count": null,
|
75 |
-
"id": "28893c3e",
|
76 |
-
"metadata": {},
|
77 |
-
"outputs": [],
|
78 |
-
"source": [
|
79 |
-
"dataset = load_dataset('dalle-mini/YFCC100M_OpenAI_subset', use_auth_token=True, streaming=True, split='validation')"
|
80 |
-
]
|
81 |
-
},
|
82 |
-
{
|
83 |
-
"cell_type": "code",
|
84 |
-
"execution_count": null,
|
85 |
-
"id": "33861477",
|
86 |
-
"metadata": {},
|
87 |
-
"outputs": [],
|
88 |
-
"source": [
|
89 |
-
"from pathlib import Path\n",
|
90 |
-
"\n",
|
91 |
-
"yfcc100m = Path.home()/'data'/'YFCC100M_OpenAI_subset'\n",
|
92 |
-
"yfcc100m_output = yfcc100m/'encoded' # Output directory for encoded files"
|
93 |
-
]
|
94 |
-
},
|
95 |
-
{
|
96 |
-
"cell_type": "code",
|
97 |
-
"execution_count": null,
|
98 |
-
"id": "6e7b71c4",
|
99 |
-
"metadata": {},
|
100 |
-
"outputs": [],
|
101 |
-
"source": [
|
102 |
-
"batch_size = 128 # Per device\n",
|
103 |
-
"num_workers = 16 # Unused in streaming mode"
|
104 |
-
]
|
105 |
-
},
|
106 |
-
{
|
107 |
-
"cell_type": "markdown",
|
108 |
-
"id": "0793c26a",
|
109 |
-
"metadata": {},
|
110 |
-
"source": [
|
111 |
-
"### Data preparation"
|
112 |
-
]
|
113 |
-
},
|
114 |
-
{
|
115 |
-
"cell_type": "markdown",
|
116 |
-
"id": "86415769",
|
117 |
-
"metadata": {},
|
118 |
-
"source": [
|
119 |
-
"* Images: we transform them so they are center-cropped and square, all of the same size so we can build batches for TPU/GPU processing.\n",
|
120 |
-
"* Captions: we extract a single `caption` column from the source data, by concatenating the cleaned title and description.\n",
|
121 |
-
"\n",
|
122 |
-
"These transformations are done using the Datasets `map` function. In the case of streaming datasets, transformations will run as needed instead of pre-processing the dataset at once."
|
123 |
-
]
|
124 |
-
},
|
125 |
-
{
|
126 |
-
"cell_type": "markdown",
|
127 |
-
"id": "0fdf1851",
|
128 |
-
"metadata": {},
|
129 |
-
"source": [
|
130 |
-
"This helper function is used to decode images from the bytes retrieved in `streaming` mode."
|
131 |
-
]
|
132 |
-
},
|
133 |
-
{
|
134 |
-
"cell_type": "code",
|
135 |
-
"execution_count": null,
|
136 |
-
"id": "5bbca804",
|
137 |
-
"metadata": {},
|
138 |
-
"outputs": [],
|
139 |
-
"source": [
|
140 |
-
"from PIL import Image\n",
|
141 |
-
"import io\n",
|
142 |
-
"\n",
|
143 |
-
"def get_image(byte_stream):\n",
|
144 |
-
" image = Image.open(io.BytesIO(byte_stream))\n",
|
145 |
-
" return image.convert('RGB')"
|
146 |
-
]
|
147 |
-
},
|
148 |
-
{
|
149 |
-
"cell_type": "markdown",
|
150 |
-
"id": "b435290b",
|
151 |
-
"metadata": {},
|
152 |
-
"source": [
|
153 |
-
"Image processing"
|
154 |
-
]
|
155 |
-
},
|
156 |
-
{
|
157 |
-
"cell_type": "code",
|
158 |
-
"execution_count": null,
|
159 |
-
"id": "7e73dfa3",
|
160 |
-
"metadata": {},
|
161 |
-
"outputs": [],
|
162 |
-
"source": [
|
163 |
-
"def center_crop(image, max_size=256):\n",
|
164 |
-
" # Note: we allow upscaling too. We should exclude small images. \n",
|
165 |
-
" image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
|
166 |
-
" image = TF.center_crop(image, output_size=2 * [max_size])\n",
|
167 |
-
" return image\n",
|
168 |
-
"\n",
|
169 |
-
"preprocess_image = T.Compose([\n",
|
170 |
-
" get_image,\n",
|
171 |
-
" center_crop,\n",
|
172 |
-
" T.ToTensor(),\n",
|
173 |
-
" lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
|
174 |
-
"])"
|
175 |
-
]
|
176 |
-
},
|
177 |
-
{
|
178 |
-
"cell_type": "markdown",
|
179 |
-
"id": "1e3ac8de",
|
180 |
-
"metadata": {},
|
181 |
-
"source": [
|
182 |
-
"Caption preparation"
|
183 |
-
]
|
184 |
-
},
|
185 |
-
{
|
186 |
-
"cell_type": "code",
|
187 |
-
"execution_count": null,
|
188 |
-
"id": "aadb4d23",
|
189 |
-
"metadata": {},
|
190 |
-
"outputs": [],
|
191 |
-
"source": [
|
192 |
-
"import string\n",
|
193 |
-
"\n",
|
194 |
-
"def create_caption(title, description):\n",
|
195 |
-
" title = title.strip()\n",
|
196 |
-
" description = description.strip()\n",
|
197 |
-
" if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
|
198 |
-
" return f'{title} {description}'"
|
199 |
-
]
|
200 |
-
},
|
201 |
-
{
|
202 |
-
"cell_type": "markdown",
|
203 |
-
"id": "3c4522b9",
|
204 |
-
"metadata": {},
|
205 |
-
"source": [
|
206 |
-
"And this is the basic transformation function to use in `map`. We don't really need the `key`, but we'll keep it for reference. Since we are returning a new dictionary (as opposed to adding entries to the input), this also removes any metadata columns we don't need."
|
207 |
-
]
|
208 |
-
},
|
209 |
-
{
|
210 |
-
"cell_type": "code",
|
211 |
-
"execution_count": null,
|
212 |
-
"id": "2566ff68",
|
213 |
-
"metadata": {},
|
214 |
-
"outputs": [],
|
215 |
-
"source": [
|
216 |
-
"def prepare_item(item):\n",
|
217 |
-
" return {\n",
|
218 |
-
" 'key': item['key'],\n",
|
219 |
-
" 'caption': create_caption(item['title_clean'], item['description_clean']),\n",
|
220 |
-
" 'image': preprocess_image(item['img'])\n",
|
221 |
-
" }"
|
222 |
-
]
|
223 |
-
},
|
224 |
-
{
|
225 |
-
"cell_type": "markdown",
|
226 |
-
"id": "e519e475",
|
227 |
-
"metadata": {},
|
228 |
-
"source": [
|
229 |
-
"Unlike when using non-streaming datasets, the following operation completes immediately in streaming mode. In streaming mode, `num_proc` is not supported."
|
230 |
-
]
|
231 |
-
},
|
232 |
-
{
|
233 |
-
"cell_type": "code",
|
234 |
-
"execution_count": null,
|
235 |
-
"id": "10d7750e",
|
236 |
-
"metadata": {},
|
237 |
-
"outputs": [],
|
238 |
-
"source": [
|
239 |
-
"prepared_dataset = dataset.map(prepare_item, batched=False)"
|
240 |
-
]
|
241 |
-
},
|
242 |
-
{
|
243 |
-
"cell_type": "code",
|
244 |
-
"execution_count": null,
|
245 |
-
"id": "a8595539",
|
246 |
-
"metadata": {},
|
247 |
-
"outputs": [],
|
248 |
-
"source": [
|
249 |
-
"%%time\n",
|
250 |
-
"item = next(iter(prepared_dataset))"
|
251 |
-
]
|
252 |
-
},
|
253 |
-
{
|
254 |
-
"cell_type": "code",
|
255 |
-
"execution_count": null,
|
256 |
-
"id": "04a6eeb4",
|
257 |
-
"metadata": {},
|
258 |
-
"outputs": [],
|
259 |
-
"source": [
|
260 |
-
"assert(list(item.keys()) == ['key', 'caption', 'image'])"
|
261 |
-
]
|
262 |
-
},
|
263 |
-
{
|
264 |
-
"cell_type": "code",
|
265 |
-
"execution_count": null,
|
266 |
-
"id": "40d3115f",
|
267 |
-
"metadata": {},
|
268 |
-
"outputs": [],
|
269 |
-
"source": [
|
270 |
-
"item['image'].shape"
|
271 |
-
]
|
272 |
-
},
|
273 |
-
{
|
274 |
-
"cell_type": "code",
|
275 |
-
"execution_count": null,
|
276 |
-
"id": "dd844e1c",
|
277 |
-
"metadata": {},
|
278 |
-
"outputs": [],
|
279 |
-
"source": [
|
280 |
-
"T.ToPILImage()(item['image'].permute(2, 0, 1))"
|
281 |
-
]
|
282 |
-
},
|
283 |
-
{
|
284 |
-
"cell_type": "markdown",
|
285 |
-
"id": "44d50a51",
|
286 |
-
"metadata": {},
|
287 |
-
"source": [
|
288 |
-
"### Torch DataLoader"
|
289 |
-
]
|
290 |
-
},
|
291 |
-
{
|
292 |
-
"cell_type": "markdown",
|
293 |
-
"id": "17a4bbc6",
|
294 |
-
"metadata": {},
|
295 |
-
"source": [
|
296 |
-
"We'll create a PyTorch DataLoader for convenience. This allows us to easily take batches of our desired size.\n",
|
297 |
-
"\n",
|
298 |
-
"We won't be using parallel processing of the DataLoader for now, as the items will be retrieved on the fly. We could attempt to do it using these recommendations: https://pytorch.org/docs/stable/data.html#multi-process-data-loading. For performance considerations, please refer to this thread: https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13"
|
299 |
-
]
|
300 |
-
},
|
301 |
-
{
|
302 |
-
"cell_type": "code",
|
303 |
-
"execution_count": null,
|
304 |
-
"id": "e1c08b7e",
|
305 |
-
"metadata": {},
|
306 |
-
"outputs": [],
|
307 |
-
"source": [
|
308 |
-
"import torch\n",
|
309 |
-
"from torch.utils.data import DataLoader"
|
310 |
-
]
|
311 |
-
},
|
312 |
-
{
|
313 |
-
"cell_type": "code",
|
314 |
-
"execution_count": null,
|
315 |
-
"id": "6a296677",
|
316 |
-
"metadata": {},
|
317 |
-
"outputs": [],
|
318 |
-
"source": [
|
319 |
-
"torch_dataset = prepared_dataset.with_format(\"torch\")"
|
320 |
-
]
|
321 |
-
},
|
322 |
-
{
|
323 |
-
"cell_type": "markdown",
|
324 |
-
"id": "29ab13bc",
|
325 |
-
"metadata": {},
|
326 |
-
"source": [
|
327 |
-
"**Note**: according to my tests, `num_workers` is not compatible with Datasets in streaming mode. Processes deadlock and there's no progress."
|
328 |
-
]
|
329 |
-
},
|
330 |
-
{
|
331 |
-
"cell_type": "code",
|
332 |
-
"execution_count": null,
|
333 |
-
"id": "e2df5e13",
|
334 |
-
"metadata": {},
|
335 |
-
"outputs": [],
|
336 |
-
"source": [
|
337 |
-
"dataloader = DataLoader(torch_dataset, batch_size=batch_size * jax.device_count())"
|
338 |
-
]
|
339 |
-
},
|
340 |
-
{
|
341 |
-
"cell_type": "code",
|
342 |
-
"execution_count": null,
|
343 |
-
"id": "c15e3783",
|
344 |
-
"metadata": {},
|
345 |
-
"outputs": [],
|
346 |
-
"source": [
|
347 |
-
"batch = next(iter(dataloader))"
|
348 |
-
]
|
349 |
-
},
|
350 |
-
{
|
351 |
-
"cell_type": "code",
|
352 |
-
"execution_count": null,
|
353 |
-
"id": "71d027fe",
|
354 |
-
"metadata": {},
|
355 |
-
"outputs": [],
|
356 |
-
"source": [
|
357 |
-
"batch['image'].shape"
|
358 |
-
]
|
359 |
-
},
|
360 |
-
{
|
361 |
-
"cell_type": "markdown",
|
362 |
-
"id": "a354472b",
|
363 |
-
"metadata": {},
|
364 |
-
"source": [
|
365 |
-
"## VQGAN-JAX model"
|
366 |
-
]
|
367 |
-
},
|
368 |
-
{
|
369 |
-
"cell_type": "code",
|
370 |
-
"execution_count": null,
|
371 |
-
"id": "2fcf01d7",
|
372 |
-
"metadata": {},
|
373 |
-
"outputs": [],
|
374 |
-
"source": [
|
375 |
-
"from vqgan_jax.modeling_flax_vqgan import VQModel"
|
376 |
-
]
|
377 |
-
},
|
378 |
-
{
|
379 |
-
"cell_type": "markdown",
|
380 |
-
"id": "9daa636d",
|
381 |
-
"metadata": {},
|
382 |
-
"source": [
|
383 |
-
"We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
|
384 |
-
]
|
385 |
-
},
|
386 |
-
{
|
387 |
-
"cell_type": "code",
|
388 |
-
"execution_count": null,
|
389 |
-
"id": "47a8b818",
|
390 |
-
"metadata": {
|
391 |
-
"scrolled": true
|
392 |
-
},
|
393 |
-
"outputs": [],
|
394 |
-
"source": [
|
395 |
-
"model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
|
396 |
-
]
|
397 |
-
},
|
398 |
-
{
|
399 |
-
"cell_type": "markdown",
|
400 |
-
"id": "62ad01c3",
|
401 |
-
"metadata": {},
|
402 |
-
"source": [
|
403 |
-
"## Encoding"
|
404 |
-
]
|
405 |
-
},
|
406 |
-
{
|
407 |
-
"cell_type": "markdown",
|
408 |
-
"id": "20357f74",
|
409 |
-
"metadata": {},
|
410 |
-
"source": [
|
411 |
-
"Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, that will be jitted on first use."
|
412 |
-
]
|
413 |
-
},
|
414 |
-
{
|
415 |
-
"cell_type": "code",
|
416 |
-
"execution_count": null,
|
417 |
-
"id": "6686b004",
|
418 |
-
"metadata": {},
|
419 |
-
"outputs": [],
|
420 |
-
"source": [
|
421 |
-
"from flax.training.common_utils import shard\n",
|
422 |
-
"from functools import partial"
|
423 |
-
]
|
424 |
-
},
|
425 |
-
{
|
426 |
-
"cell_type": "code",
|
427 |
-
"execution_count": null,
|
428 |
-
"id": "322a4619",
|
429 |
-
"metadata": {},
|
430 |
-
"outputs": [],
|
431 |
-
"source": [
|
432 |
-
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
433 |
-
"def encode(batch):\n",
|
434 |
-
" # Not sure if we should `replicate` params, does not seem to have any effect\n",
|
435 |
-
" _, indices = model.encode(batch)\n",
|
436 |
-
" return indices"
|
437 |
-
]
|
438 |
-
},
|
439 |
-
{
|
440 |
-
"cell_type": "markdown",
|
441 |
-
"id": "14375a41",
|
442 |
-
"metadata": {},
|
443 |
-
"source": [
|
444 |
-
"### Encoding loop"
|
445 |
-
]
|
446 |
-
},
|
447 |
-
{
|
448 |
-
"cell_type": "code",
|
449 |
-
"execution_count": null,
|
450 |
-
"id": "ff6c10d4",
|
451 |
-
"metadata": {},
|
452 |
-
"outputs": [],
|
453 |
-
"source": [
|
454 |
-
"import os\n",
|
455 |
-
"import pandas as pd\n",
|
456 |
-
"\n",
|
457 |
-
"def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
|
458 |
-
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
459 |
-
" \n",
|
460 |
-
" # Saving strategy:\n",
|
461 |
-
" # - Create a new file every so often to prevent excessive file seeking.\n",
|
462 |
-
" # - Save each batch after processing.\n",
|
463 |
-
" # - Keep the file open until we are done with it.\n",
|
464 |
-
" file = None \n",
|
465 |
-
" for n, batch in enumerate(tqdm(iter(dataloader))):\n",
|
466 |
-
" if (n % save_every == 0):\n",
|
467 |
-
" if file is not None:\n",
|
468 |
-
" file.close()\n",
|
469 |
-
" split_num = n // save_every\n",
|
470 |
-
" file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
|
471 |
-
"\n",
|
472 |
-
" images = batch[\"image\"].numpy()\n",
|
473 |
-
" images = shard(images.squeeze())\n",
|
474 |
-
" encoded = encode(images)\n",
|
475 |
-
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
476 |
-
"\n",
|
477 |
-
" keys = batch[\"key\"]\n",
|
478 |
-
" captions = batch[\"caption\"]\n",
|
479 |
-
"\n",
|
480 |
-
" encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
|
481 |
-
" batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
|
482 |
-
" batch_df.to_json(file, orient='records', lines=True)"
|
483 |
-
]
|
484 |
-
},
|
485 |
-
{
|
486 |
-
"cell_type": "markdown",
|
487 |
-
"id": "09ff75a3",
|
488 |
-
"metadata": {},
|
489 |
-
"source": [
|
490 |
-
"Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
|
491 |
-
]
|
492 |
-
},
|
493 |
-
{
|
494 |
-
"cell_type": "code",
|
495 |
-
"execution_count": null,
|
496 |
-
"id": "96222bb4",
|
497 |
-
"metadata": {},
|
498 |
-
"outputs": [],
|
499 |
-
"source": [
|
500 |
-
"save_every = 318"
|
501 |
-
]
|
502 |
-
},
|
503 |
-
{
|
504 |
-
"cell_type": "code",
|
505 |
-
"execution_count": null,
|
506 |
-
"id": "7704863d",
|
507 |
-
"metadata": {},
|
508 |
-
"outputs": [
|
509 |
-
{
|
510 |
-
"name": "stderr",
|
511 |
-
"output_type": "stream",
|
512 |
-
"text": [
|
513 |
-
"28it [01:17, 1.60s/it]"
|
514 |
-
]
|
515 |
-
}
|
516 |
-
],
|
517 |
-
"source": [
|
518 |
-
"encode_captioned_dataset(dataloader, yfcc100m_output, save_every=save_every)"
|
519 |
-
]
|
520 |
-
},
|
521 |
-
{
|
522 |
-
"cell_type": "markdown",
|
523 |
-
"id": "e266a70a",
|
524 |
-
"metadata": {},
|
525 |
-
"source": [
|
526 |
-
"This is ~10-15 slower than local encoding from an SSD. For performance considerations, see the discussion at https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13."
|
527 |
-
]
|
528 |
-
},
|
529 |
-
{
|
530 |
-
"cell_type": "markdown",
|
531 |
-
"id": "8953dd84",
|
532 |
-
"metadata": {},
|
533 |
-
"source": [
|
534 |
-
"----"
|
535 |
-
]
|
536 |
-
}
|
537 |
-
],
|
538 |
-
"metadata": {
|
539 |
-
"interpreter": {
|
540 |
-
"hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
|
541 |
-
},
|
542 |
-
"kernelspec": {
|
543 |
-
"display_name": "Python 3 (ipykernel)",
|
544 |
-
"language": "python",
|
545 |
-
"name": "python3"
|
546 |
-
},
|
547 |
-
"language_info": {
|
548 |
-
"codemirror_mode": {
|
549 |
-
"name": "ipython",
|
550 |
-
"version": 3
|
551 |
-
},
|
552 |
-
"file_extension": ".py",
|
553 |
-
"mimetype": "text/x-python",
|
554 |
-
"name": "python",
|
555 |
-
"nbconvert_exporter": "python",
|
556 |
-
"pygments_lexer": "ipython3",
|
557 |
-
"version": "3.8.10"
|
558 |
-
}
|
559 |
-
},
|
560 |
-
"nbformat": 4,
|
561 |
-
"nbformat_minor": 5
|
562 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/encoding/vqgan-jax-encoding-with-captions.ipynb
DELETED
@@ -1,355 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "d0b72877",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# vqgan-jax-encoding-with-captions"
|
9 |
-
]
|
10 |
-
},
|
11 |
-
{
|
12 |
-
"cell_type": "markdown",
|
13 |
-
"id": "875c82b3",
|
14 |
-
"metadata": {},
|
15 |
-
"source": [
|
16 |
-
"Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n",
|
17 |
-
"\n",
|
18 |
-
"We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model."
|
19 |
-
]
|
20 |
-
},
|
21 |
-
{
|
22 |
-
"cell_type": "code",
|
23 |
-
"execution_count": 1,
|
24 |
-
"id": "3b59489e",
|
25 |
-
"metadata": {},
|
26 |
-
"outputs": [],
|
27 |
-
"source": [
|
28 |
-
"import io\n",
|
29 |
-
"\n",
|
30 |
-
"import requests\n",
|
31 |
-
"from PIL import Image\n",
|
32 |
-
"import numpy as np\n",
|
33 |
-
"from tqdm import tqdm\n",
|
34 |
-
"\n",
|
35 |
-
"import torch\n",
|
36 |
-
"import torchvision.transforms as T\n",
|
37 |
-
"import torchvision.transforms.functional as TF\n",
|
38 |
-
"from torchvision.transforms import InterpolationMode\n",
|
39 |
-
"from torch.utils.data import Dataset, DataLoader\n",
|
40 |
-
"\n",
|
41 |
-
"import jax\n",
|
42 |
-
"from jax import pmap"
|
43 |
-
]
|
44 |
-
},
|
45 |
-
{
|
46 |
-
"cell_type": "markdown",
|
47 |
-
"id": "511c3b9e",
|
48 |
-
"metadata": {},
|
49 |
-
"source": [
|
50 |
-
"## VQGAN-JAX model"
|
51 |
-
]
|
52 |
-
},
|
53 |
-
{
|
54 |
-
"cell_type": "code",
|
55 |
-
"execution_count": 2,
|
56 |
-
"id": "2ca50dc7",
|
57 |
-
"metadata": {},
|
58 |
-
"outputs": [],
|
59 |
-
"source": [
|
60 |
-
"from vqgan_jax.modeling_flax_vqgan import VQModel"
|
61 |
-
]
|
62 |
-
},
|
63 |
-
{
|
64 |
-
"cell_type": "markdown",
|
65 |
-
"id": "7b60da9a",
|
66 |
-
"metadata": {},
|
67 |
-
"source": [
|
68 |
-
"We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
|
69 |
-
]
|
70 |
-
},
|
71 |
-
{
|
72 |
-
"cell_type": "code",
|
73 |
-
"execution_count": 3,
|
74 |
-
"id": "29ce8b15",
|
75 |
-
"metadata": {},
|
76 |
-
"outputs": [
|
77 |
-
{
|
78 |
-
"data": {
|
79 |
-
"application/vnd.jupyter.widget-view+json": {
|
80 |
-
"model_id": "db406bdfc5d5428eaeae1631a04989dd",
|
81 |
-
"version_major": 2,
|
82 |
-
"version_minor": 0
|
83 |
-
},
|
84 |
-
"text/plain": [
|
85 |
-
"Downloading: 0%| | 0.00/433 [00:00<?, ?B/s]"
|
86 |
-
]
|
87 |
-
},
|
88 |
-
"metadata": {},
|
89 |
-
"output_type": "display_data"
|
90 |
-
},
|
91 |
-
{
|
92 |
-
"data": {
|
93 |
-
"application/vnd.jupyter.widget-view+json": {
|
94 |
-
"model_id": "3e37f07fba6d48fca70313ae1fa8cc32",
|
95 |
-
"version_major": 2,
|
96 |
-
"version_minor": 0
|
97 |
-
},
|
98 |
-
"text/plain": [
|
99 |
-
"Downloading: 0%| | 0.00/304M [00:00<?, ?B/s]"
|
100 |
-
]
|
101 |
-
},
|
102 |
-
"metadata": {},
|
103 |
-
"output_type": "display_data"
|
104 |
-
},
|
105 |
-
{
|
106 |
-
"name": "stderr",
|
107 |
-
"output_type": "stream",
|
108 |
-
"text": [
|
109 |
-
"INFO:absl:Starting the local TPU driver.\n",
|
110 |
-
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
111 |
-
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n"
|
112 |
-
]
|
113 |
-
},
|
114 |
-
{
|
115 |
-
"name": "stdout",
|
116 |
-
"output_type": "stream",
|
117 |
-
"text": [
|
118 |
-
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
|
119 |
-
]
|
120 |
-
}
|
121 |
-
],
|
122 |
-
"source": [
|
123 |
-
"model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
|
124 |
-
]
|
125 |
-
},
|
126 |
-
{
|
127 |
-
"cell_type": "markdown",
|
128 |
-
"id": "c7c4c1e6",
|
129 |
-
"metadata": {},
|
130 |
-
"source": [
|
131 |
-
"## Dataset"
|
132 |
-
]
|
133 |
-
},
|
134 |
-
{
|
135 |
-
"cell_type": "markdown",
|
136 |
-
"id": "7014a7ce",
|
137 |
-
"metadata": {},
|
138 |
-
"source": [
|
139 |
-
"We use Luke Melas-Kyriazi's `dataset.py` which reads image paths and captions from a tsv file that contains both. We only need the images for encoding."
|
140 |
-
]
|
141 |
-
},
|
142 |
-
{
|
143 |
-
"cell_type": "code",
|
144 |
-
"execution_count": 4,
|
145 |
-
"id": "85832702",
|
146 |
-
"metadata": {},
|
147 |
-
"outputs": [],
|
148 |
-
"source": [
|
149 |
-
"from dalle_mini.dataset import *"
|
150 |
-
]
|
151 |
-
},
|
152 |
-
{
|
153 |
-
"cell_type": "code",
|
154 |
-
"execution_count": 5,
|
155 |
-
"id": "81b19eca",
|
156 |
-
"metadata": {},
|
157 |
-
"outputs": [],
|
158 |
-
"source": [
|
159 |
-
"cc12m_images = '/data/CC12M/images'\n",
|
160 |
-
"cc12m_list = '/data/CC12M/images-list-clean.tsv'\n",
|
161 |
-
"# cc12m_list = '/data/CC12M/images-10000.tsv'\n",
|
162 |
-
"cc12m_output = '/data/CC12M/images-encoded.tsv'"
|
163 |
-
]
|
164 |
-
},
|
165 |
-
{
|
166 |
-
"cell_type": "code",
|
167 |
-
"execution_count": 6,
|
168 |
-
"id": "fecc9a00",
|
169 |
-
"metadata": {},
|
170 |
-
"outputs": [],
|
171 |
-
"source": [
|
172 |
-
"image_size = 256\n",
|
173 |
-
"def image_transform(image):\n",
|
174 |
-
" s = min(image.size)\n",
|
175 |
-
" r = image_size / s\n",
|
176 |
-
" s = (round(r * image.size[1]), round(r * image.size[0]))\n",
|
177 |
-
" image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
|
178 |
-
" image = TF.center_crop(image, output_size = 2 * [image_size])\n",
|
179 |
-
" image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
|
180 |
-
" image = image.permute(0, 2, 3, 1).numpy()\n",
|
181 |
-
" return image"
|
182 |
-
]
|
183 |
-
},
|
184 |
-
{
|
185 |
-
"cell_type": "code",
|
186 |
-
"execution_count": 7,
|
187 |
-
"id": "4ce2211f",
|
188 |
-
"metadata": {},
|
189 |
-
"outputs": [],
|
190 |
-
"source": [
|
191 |
-
"dataset = CaptionDataset(\n",
|
192 |
-
" images_root=cc12m_images,\n",
|
193 |
-
" captions_path=cc12m_list,\n",
|
194 |
-
" image_transform=image_transform,\n",
|
195 |
-
" image_transform_type='torchvision',\n",
|
196 |
-
" include_captions=False\n",
|
197 |
-
")"
|
198 |
-
]
|
199 |
-
},
|
200 |
-
{
|
201 |
-
"cell_type": "code",
|
202 |
-
"execution_count": 8,
|
203 |
-
"id": "cc922704",
|
204 |
-
"metadata": {},
|
205 |
-
"outputs": [
|
206 |
-
{
|
207 |
-
"data": {
|
208 |
-
"text/plain": [
|
209 |
-
"8592141"
|
210 |
-
]
|
211 |
-
},
|
212 |
-
"execution_count": 8,
|
213 |
-
"metadata": {},
|
214 |
-
"output_type": "execute_result"
|
215 |
-
}
|
216 |
-
],
|
217 |
-
"source": [
|
218 |
-
"len(dataset)"
|
219 |
-
]
|
220 |
-
},
|
221 |
-
{
|
222 |
-
"cell_type": "markdown",
|
223 |
-
"id": "62ad01c3",
|
224 |
-
"metadata": {},
|
225 |
-
"source": [
|
226 |
-
"## Encoding"
|
227 |
-
]
|
228 |
-
},
|
229 |
-
{
|
230 |
-
"cell_type": "code",
|
231 |
-
"execution_count": 9,
|
232 |
-
"id": "88f36d0b",
|
233 |
-
"metadata": {},
|
234 |
-
"outputs": [],
|
235 |
-
"source": [
|
236 |
-
"def encode(model, batch):\n",
|
237 |
-
"# print(\"jitting encode function\")\n",
|
238 |
-
" _, indices = model.encode(batch)\n",
|
239 |
-
" return indices"
|
240 |
-
]
|
241 |
-
},
|
242 |
-
{
|
243 |
-
"cell_type": "code",
|
244 |
-
"execution_count": 10,
|
245 |
-
"id": "1f35f0cb",
|
246 |
-
"metadata": {},
|
247 |
-
"outputs": [],
|
248 |
-
"source": [
|
249 |
-
"def superbatch_generator(dataloader, num_tpus):\n",
|
250 |
-
" iter_loader = iter(dataloader)\n",
|
251 |
-
" for batch in iter_loader:\n",
|
252 |
-
" superbatch = [batch.squeeze(1)]\n",
|
253 |
-
" try:\n",
|
254 |
-
" for b in range(num_tpus-1):\n",
|
255 |
-
" batch = next(iter_loader)\n",
|
256 |
-
" if batch is None:\n",
|
257 |
-
" break\n",
|
258 |
-
" # Skip incomplete last batch\n",
|
259 |
-
" if batch.shape[0] == dataloader.batch_size:\n",
|
260 |
-
" superbatch.append(batch.squeeze(1))\n",
|
261 |
-
" except StopIteration:\n",
|
262 |
-
" pass\n",
|
263 |
-
" superbatch = torch.stack(superbatch, axis=0)\n",
|
264 |
-
" yield superbatch"
|
265 |
-
]
|
266 |
-
},
|
267 |
-
{
|
268 |
-
"cell_type": "code",
|
269 |
-
"execution_count": 11,
|
270 |
-
"id": "2210705b",
|
271 |
-
"metadata": {},
|
272 |
-
"outputs": [],
|
273 |
-
"source": [
|
274 |
-
"import os\n",
|
275 |
-
"\n",
|
276 |
-
"def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
|
277 |
-
" if os.path.isfile(output_tsv):\n",
|
278 |
-
" print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
|
279 |
-
" return\n",
|
280 |
-
" \n",
|
281 |
-
" num_tpus = 8 \n",
|
282 |
-
" dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
|
283 |
-
" superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
|
284 |
-
" \n",
|
285 |
-
" p_encoder = pmap(lambda batch: encode(model, batch))\n",
|
286 |
-
"\n",
|
287 |
-
" # We save each superbatch to avoid reallocation of buffers as we process them.\n",
|
288 |
-
" # We keep the file open to prevent excessive file seeks.\n",
|
289 |
-
" with open(output_tsv, \"w\") as file:\n",
|
290 |
-
" iterations = len(dataset) // (batch_size * num_tpus)\n",
|
291 |
-
" for n in tqdm(range(iterations)):\n",
|
292 |
-
" superbatch = next(superbatches)\n",
|
293 |
-
" encoded = p_encoder(superbatch.numpy())\n",
|
294 |
-
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
295 |
-
"\n",
|
296 |
-
" # Extract fields from the dataset internal `captions` property, and save to disk\n",
|
297 |
-
" start_index = n * batch_size * num_tpus\n",
|
298 |
-
" end_index = (n+1) * batch_size * num_tpus\n",
|
299 |
-
" paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
|
300 |
-
" captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
|
301 |
-
" encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
|
302 |
-
" batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
|
303 |
-
" batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)\n",
|
304 |
-
" "
|
305 |
-
]
|
306 |
-
},
|
307 |
-
{
|
308 |
-
"cell_type": "code",
|
309 |
-
"execution_count": null,
|
310 |
-
"id": "7704863d",
|
311 |
-
"metadata": {},
|
312 |
-
"outputs": [
|
313 |
-
{
|
314 |
-
"name": "stderr",
|
315 |
-
"output_type": "stream",
|
316 |
-
"text": [
|
317 |
-
" 4%|██▋ | 621/16781 [07:09<3:02:46, 1.47it/s]"
|
318 |
-
]
|
319 |
-
}
|
320 |
-
],
|
321 |
-
"source": [
|
322 |
-
"encode_captioned_dataset(dataset, cc12m_output, batch_size=64, num_workers=16)"
|
323 |
-
]
|
324 |
-
},
|
325 |
-
{
|
326 |
-
"cell_type": "markdown",
|
327 |
-
"id": "8953dd84",
|
328 |
-
"metadata": {},
|
329 |
-
"source": [
|
330 |
-
"----"
|
331 |
-
]
|
332 |
-
}
|
333 |
-
],
|
334 |
-
"metadata": {
|
335 |
-
"kernelspec": {
|
336 |
-
"display_name": "Python 3 (ipykernel)",
|
337 |
-
"language": "python",
|
338 |
-
"name": "python3"
|
339 |
-
},
|
340 |
-
"language_info": {
|
341 |
-
"codemirror_mode": {
|
342 |
-
"name": "ipython",
|
343 |
-
"version": 3
|
344 |
-
},
|
345 |
-
"file_extension": ".py",
|
346 |
-
"mimetype": "text/x-python",
|
347 |
-
"name": "python",
|
348 |
-
"nbconvert_exporter": "python",
|
349 |
-
"pygments_lexer": "ipython3",
|
350 |
-
"version": "3.8.10"
|
351 |
-
}
|
352 |
-
},
|
353 |
-
"nbformat": 4,
|
354 |
-
"nbformat_minor": 5
|
355 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/encoding/vqgan-jax-encoding-yfcc100m.ipynb
DELETED
@@ -1,1129 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "markdown",
|
5 |
-
"id": "d0b72877",
|
6 |
-
"metadata": {},
|
7 |
-
"source": [
|
8 |
-
"# vqgan-jax-encoding-yfcc100m"
|
9 |
-
]
|
10 |
-
},
|
11 |
-
{
|
12 |
-
"cell_type": "markdown",
|
13 |
-
"id": "ba7b31e6",
|
14 |
-
"metadata": {},
|
15 |
-
"source": [
|
16 |
-
"Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n",
|
17 |
-
"\n",
|
18 |
-
"This dataset was prepared by @borisdayma in Json lines format."
|
19 |
-
]
|
20 |
-
},
|
21 |
-
{
|
22 |
-
"cell_type": "code",
|
23 |
-
"execution_count": 92,
|
24 |
-
"id": "3b59489e",
|
25 |
-
"metadata": {},
|
26 |
-
"outputs": [],
|
27 |
-
"source": [
|
28 |
-
"import io\n",
|
29 |
-
"\n",
|
30 |
-
"import requests\n",
|
31 |
-
"from PIL import Image\n",
|
32 |
-
"import numpy as np\n",
|
33 |
-
"from tqdm import tqdm\n",
|
34 |
-
"\n",
|
35 |
-
"import torch\n",
|
36 |
-
"import torchvision.transforms as T\n",
|
37 |
-
"import torchvision.transforms.functional as TF\n",
|
38 |
-
"from torchvision.transforms import InterpolationMode\n",
|
39 |
-
"from torch.utils.data import Dataset, DataLoader\n",
|
40 |
-
"from torchvision.datasets.folder import default_loader\n",
|
41 |
-
"import os\n",
|
42 |
-
"\n",
|
43 |
-
"import jax\n",
|
44 |
-
"from jax import pmap"
|
45 |
-
]
|
46 |
-
},
|
47 |
-
{
|
48 |
-
"cell_type": "markdown",
|
49 |
-
"id": "511c3b9e",
|
50 |
-
"metadata": {},
|
51 |
-
"source": [
|
52 |
-
"## VQGAN-JAX model"
|
53 |
-
]
|
54 |
-
},
|
55 |
-
{
|
56 |
-
"cell_type": "code",
|
57 |
-
"execution_count": 93,
|
58 |
-
"id": "2ca50dc7",
|
59 |
-
"metadata": {},
|
60 |
-
"outputs": [],
|
61 |
-
"source": [
|
62 |
-
"from vqgan_jax.modeling_flax_vqgan import VQModel"
|
63 |
-
]
|
64 |
-
},
|
65 |
-
{
|
66 |
-
"cell_type": "markdown",
|
67 |
-
"id": "7b60da9a",
|
68 |
-
"metadata": {},
|
69 |
-
"source": [
|
70 |
-
"We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
|
71 |
-
]
|
72 |
-
},
|
73 |
-
{
|
74 |
-
"cell_type": "code",
|
75 |
-
"execution_count": 167,
|
76 |
-
"id": "29ce8b15",
|
77 |
-
"metadata": {},
|
78 |
-
"outputs": [
|
79 |
-
{
|
80 |
-
"name": "stdout",
|
81 |
-
"output_type": "stream",
|
82 |
-
"text": [
|
83 |
-
"Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
|
84 |
-
]
|
85 |
-
}
|
86 |
-
],
|
87 |
-
"source": [
|
88 |
-
"model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
|
89 |
-
]
|
90 |
-
},
|
91 |
-
{
|
92 |
-
"cell_type": "markdown",
|
93 |
-
"id": "c7c4c1e6",
|
94 |
-
"metadata": {},
|
95 |
-
"source": [
|
96 |
-
"## Dataset"
|
97 |
-
]
|
98 |
-
},
|
99 |
-
{
|
100 |
-
"cell_type": "code",
|
101 |
-
"execution_count": 94,
|
102 |
-
"id": "33861477",
|
103 |
-
"metadata": {},
|
104 |
-
"outputs": [],
|
105 |
-
"source": [
|
106 |
-
"import pandas as pd\n",
|
107 |
-
"from pathlib import Path"
|
108 |
-
]
|
109 |
-
},
|
110 |
-
{
|
111 |
-
"cell_type": "code",
|
112 |
-
"execution_count": 134,
|
113 |
-
"id": "81b19eca",
|
114 |
-
"metadata": {},
|
115 |
-
"outputs": [],
|
116 |
-
"source": [
|
117 |
-
"yfcc100m = Path('/home/khali/TPU-Test/YFCC100M_OpenAI_subset')\n",
|
118 |
-
"# Images are 'sharded' from the following directory\n",
|
119 |
-
"yfcc100m_images = yfcc100m/'data'/'data'/'images'\n",
|
120 |
-
"yfcc100m_metadata = yfcc100m/'metadata_YFCC100M.jsonl'\n",
|
121 |
-
"yfcc100m_output = yfcc100m/'metadata_encoded.tsv'"
|
122 |
-
]
|
123 |
-
},
|
124 |
-
{
|
125 |
-
"cell_type": "markdown",
|
126 |
-
"id": "1c58bb4a",
|
127 |
-
"metadata": {},
|
128 |
-
"source": [
|
129 |
-
"### Cleanup"
|
130 |
-
]
|
131 |
-
},
|
132 |
-
{
|
133 |
-
"cell_type": "markdown",
|
134 |
-
"id": "1a14ae3d",
|
135 |
-
"metadata": {},
|
136 |
-
"source": [
|
137 |
-
"We need to select entries with images that exist. Otherwise we can't build batches because `Dataloader` does not support `None` in batches. We use Huggingface Datasets, I understand they support threaded reading of jsonl files, and I was running out of memory when using pandas."
|
138 |
-
]
|
139 |
-
},
|
140 |
-
{
|
141 |
-
"cell_type": "code",
|
142 |
-
"execution_count": 96,
|
143 |
-
"id": "7811648c",
|
144 |
-
"metadata": {},
|
145 |
-
"outputs": [],
|
146 |
-
"source": [
|
147 |
-
"import datasets\n",
|
148 |
-
"from datasets import Dataset, load_dataset"
|
149 |
-
]
|
150 |
-
},
|
151 |
-
{
|
152 |
-
"cell_type": "code",
|
153 |
-
"execution_count": 10,
|
154 |
-
"id": "4811a230",
|
155 |
-
"metadata": {},
|
156 |
-
"outputs": [
|
157 |
-
{
|
158 |
-
"name": "stderr",
|
159 |
-
"output_type": "stream",
|
160 |
-
"text": [
|
161 |
-
"tcmalloc: large alloc 1254047744 bytes == 0xb2b08000 @ 0x7f9e78632680 0x7f9e78653824 0x585b92 0x504d56 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
|
162 |
-
"tcmalloc: large alloc 1254047744 bytes == 0xfd74e000 @ 0x7f9e78632680 0x7f9e78653824 0x590214 0x586f90 0x56e1f3 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8 0x5f2702 0x56c332\n",
|
163 |
-
"tcmalloc: large alloc 5016190976 bytes == 0x148b42000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
164 |
-
"tcmalloc: large alloc 5019099136 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
165 |
-
"tcmalloc: large alloc 5019811840 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
166 |
-
"tcmalloc: large alloc 5024571392 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
167 |
-
"tcmalloc: large alloc 5021097984 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
168 |
-
"tcmalloc: large alloc 5022818304 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
169 |
-
"tcmalloc: large alloc 5020794880 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
170 |
-
"tcmalloc: large alloc 5019451392 bytes == 0x39f9a8000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
171 |
-
"tcmalloc: large alloc 5020565504 bytes == 0x4cb4ec000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
172 |
-
"tcmalloc: large alloc 5012561920 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
173 |
-
"tcmalloc: large alloc 5021835264 bytes == 0x5f6cba000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n",
|
174 |
-
"tcmalloc: large alloc 5017436160 bytes == 0x273f12000 @ 0x7f9e78632680 0x7f9e78653824 0x5b9144 0x7f9b2929127e 0x7f9b29291a19 0x7f9b29291886 0x7f9b29291cef 0x7f9b2928f204 0x5f2cc9 0x5f30ff 0x5705f6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x56acb6 0x5f5956 0x5a8cb3 0x56ae94 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56\n"
|
175 |
-
]
|
176 |
-
}
|
177 |
-
],
|
178 |
-
"source": [
|
179 |
-
"# The metadata is too bog to load into memory at once, so chopping it into chunks\n",
|
180 |
-
"chunk_size=1000000\n",
|
181 |
-
"batch_no=1\n",
|
182 |
-
"for chunk in pd.read_json(yfcc100m_metadata, orient=\"records\", lines=True,chunksize=chunk_size):\n",
|
183 |
-
" chunk.to_csv('./chunks/chunk'+str(batch_no)+'.tsv', sep=\"\\t\", index=False)\n",
|
184 |
-
" batch_no+=1"
|
185 |
-
]
|
186 |
-
},
|
187 |
-
{
|
188 |
-
"cell_type": "code",
|
189 |
-
"execution_count": 25,
|
190 |
-
"id": "46b2f083",
|
191 |
-
"metadata": {},
|
192 |
-
"outputs": [
|
193 |
-
{
|
194 |
-
"data": {
|
195 |
-
"text/html": [
|
196 |
-
"<div>\n",
|
197 |
-
"<style scoped>\n",
|
198 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
199 |
-
" vertical-align: middle;\n",
|
200 |
-
" }\n",
|
201 |
-
"\n",
|
202 |
-
" .dataframe tbody tr th {\n",
|
203 |
-
" vertical-align: top;\n",
|
204 |
-
" }\n",
|
205 |
-
"\n",
|
206 |
-
" .dataframe thead th {\n",
|
207 |
-
" text-align: right;\n",
|
208 |
-
" }\n",
|
209 |
-
"</style>\n",
|
210 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
211 |
-
" <thead>\n",
|
212 |
-
" <tr style=\"text-align: right;\">\n",
|
213 |
-
" <th></th>\n",
|
214 |
-
" <th>photoid</th>\n",
|
215 |
-
" <th>uid</th>\n",
|
216 |
-
" <th>unickname</th>\n",
|
217 |
-
" <th>datetaken</th>\n",
|
218 |
-
" <th>dateuploaded</th>\n",
|
219 |
-
" <th>capturedevice</th>\n",
|
220 |
-
" <th>title</th>\n",
|
221 |
-
" <th>description</th>\n",
|
222 |
-
" <th>usertags</th>\n",
|
223 |
-
" <th>machinetags</th>\n",
|
224 |
-
" <th>...</th>\n",
|
225 |
-
" <th>licenseurl</th>\n",
|
226 |
-
" <th>serverid</th>\n",
|
227 |
-
" <th>farmid</th>\n",
|
228 |
-
" <th>secret</th>\n",
|
229 |
-
" <th>secretoriginal</th>\n",
|
230 |
-
" <th>ext</th>\n",
|
231 |
-
" <th>marker</th>\n",
|
232 |
-
" <th>key</th>\n",
|
233 |
-
" <th>title_clean</th>\n",
|
234 |
-
" <th>description_clean</th>\n",
|
235 |
-
" </tr>\n",
|
236 |
-
" </thead>\n",
|
237 |
-
" <tbody>\n",
|
238 |
-
" <tr>\n",
|
239 |
-
" <th>0</th>\n",
|
240 |
-
" <td>137943</td>\n",
|
241 |
-
" <td>48600072071@N01</td>\n",
|
242 |
-
" <td>doctor+paradox</td>\n",
|
243 |
-
" <td>2004-08-01 18:13:06.0</td>\n",
|
244 |
-
" <td>1091409186</td>\n",
|
245 |
-
" <td>NaN</td>\n",
|
246 |
-
" <td>A+Picture+Share%21</td>\n",
|
247 |
-
" <td>Antenna</td>\n",
|
248 |
-
" <td>cameraphone,cayugaheights,green,hydrant,ithaca...</td>\n",
|
249 |
-
" <td>NaN</td>\n",
|
250 |
-
" <td>...</td>\n",
|
251 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
252 |
-
" <td>1</td>\n",
|
253 |
-
" <td>1</td>\n",
|
254 |
-
" <td>1650c7cdc6</td>\n",
|
255 |
-
" <td>1650c7cdc6</td>\n",
|
256 |
-
" <td>jpg</td>\n",
|
257 |
-
" <td>0</td>\n",
|
258 |
-
" <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
|
259 |
-
" <td>A Picture Share!</td>\n",
|
260 |
-
" <td>Antenna</td>\n",
|
261 |
-
" </tr>\n",
|
262 |
-
" <tr>\n",
|
263 |
-
" <th>1</th>\n",
|
264 |
-
" <td>1246361</td>\n",
|
265 |
-
" <td>44124324682@N01</td>\n",
|
266 |
-
" <td>mharrsch</td>\n",
|
267 |
-
" <td>2004-11-03 23:04:02.0</td>\n",
|
268 |
-
" <td>1099523042</td>\n",
|
269 |
-
" <td>NaN</td>\n",
|
270 |
-
" <td>An+ornate+Roman+urn</td>\n",
|
271 |
-
" <td>Photographed+at+the+%3Ca+href%3D%22http%3A%2F%...</td>\n",
|
272 |
-
" <td>ancient,baltimore,burial,death,empire,funeral,...</td>\n",
|
273 |
-
" <td>NaN</td>\n",
|
274 |
-
" <td>...</td>\n",
|
275 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
276 |
-
" <td>1</td>\n",
|
277 |
-
" <td>1</td>\n",
|
278 |
-
" <td>cf37054610</td>\n",
|
279 |
-
" <td>cf37054610</td>\n",
|
280 |
-
" <td>jpg</td>\n",
|
281 |
-
" <td>0</td>\n",
|
282 |
-
" <td>d29f01b149167d683f9ddde464bb3db</td>\n",
|
283 |
-
" <td>An ornate Roman urn</td>\n",
|
284 |
-
" <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
|
285 |
-
" </tr>\n",
|
286 |
-
" <tr>\n",
|
287 |
-
" <th>2</th>\n",
|
288 |
-
" <td>1251599</td>\n",
|
289 |
-
" <td>51035803024@N01</td>\n",
|
290 |
-
" <td>bmitd67</td>\n",
|
291 |
-
" <td>2004-10-30 17:09:32.0</td>\n",
|
292 |
-
" <td>1099538888</td>\n",
|
293 |
-
" <td>Canon+PowerShot+S30</td>\n",
|
294 |
-
" <td>Jai+%26+Tara+on+the+Cumberland</td>\n",
|
295 |
-
" <td>Another+trip+for+the+happy+couple.</td>\n",
|
296 |
-
" <td>blue+heron,cumberland+river,jai,tara,tennessee</td>\n",
|
297 |
-
" <td>NaN</td>\n",
|
298 |
-
" <td>...</td>\n",
|
299 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
300 |
-
" <td>1</td>\n",
|
301 |
-
" <td>1</td>\n",
|
302 |
-
" <td>4a4234e32c</td>\n",
|
303 |
-
" <td>4a4234e32c</td>\n",
|
304 |
-
" <td>jpg</td>\n",
|
305 |
-
" <td>0</td>\n",
|
306 |
-
" <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
|
307 |
-
" <td>Jai & Tara on the Cumberland</td>\n",
|
308 |
-
" <td>Another trip for the happy couple.</td>\n",
|
309 |
-
" </tr>\n",
|
310 |
-
" <tr>\n",
|
311 |
-
" <th>3</th>\n",
|
312 |
-
" <td>2348587</td>\n",
|
313 |
-
" <td>73621375@N00</td>\n",
|
314 |
-
" <td>Thom+Watson</td>\n",
|
315 |
-
" <td>2004-12-18 21:08:09.0</td>\n",
|
316 |
-
" <td>1103497228</td>\n",
|
317 |
-
" <td>SONY+DSC-W1</td>\n",
|
318 |
-
" <td>Castle+gate+-+%22lite-brited%22</td>\n",
|
319 |
-
" <td>Taken+at+the+Miracle+of+Lights+display+in+Cent...</td>\n",
|
320 |
-
" <td>bullrunpark,castle,centreville,christmas,decor...</td>\n",
|
321 |
-
" <td>NaN</td>\n",
|
322 |
-
" <td>...</td>\n",
|
323 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
324 |
-
" <td>2</td>\n",
|
325 |
-
" <td>1</td>\n",
|
326 |
-
" <td>7162c974c3</td>\n",
|
327 |
-
" <td>7162c974c3</td>\n",
|
328 |
-
" <td>jpg</td>\n",
|
329 |
-
" <td>0</td>\n",
|
330 |
-
" <td>d29ce96395848478b1e8396e44899</td>\n",
|
331 |
-
" <td>Castle gate - \"lite-brited\"</td>\n",
|
332 |
-
" <td>Taken at the Miracle of Lights display in Cent...</td>\n",
|
333 |
-
" </tr>\n",
|
334 |
-
" <tr>\n",
|
335 |
-
" <th>4</th>\n",
|
336 |
-
" <td>3516047</td>\n",
|
337 |
-
" <td>48600072071@N01</td>\n",
|
338 |
-
" <td>doctor+paradox</td>\n",
|
339 |
-
" <td>2005-01-18 16:44:18.0</td>\n",
|
340 |
-
" <td>1106084658</td>\n",
|
341 |
-
" <td>NaN</td>\n",
|
342 |
-
" <td>A+Picture+Share%21</td>\n",
|
343 |
-
" <td>Tabular</td>\n",
|
344 |
-
" <td>cameraphone,moblog,unfound</td>\n",
|
345 |
-
" <td>NaN</td>\n",
|
346 |
-
" <td>...</td>\n",
|
347 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
348 |
-
" <td>3</td>\n",
|
349 |
-
" <td>1</td>\n",
|
350 |
-
" <td>663e0d8b3d</td>\n",
|
351 |
-
" <td>663e0d8b3d</td>\n",
|
352 |
-
" <td>jpg</td>\n",
|
353 |
-
" <td>0</td>\n",
|
354 |
-
" <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
|
355 |
-
" <td>A Picture Share!</td>\n",
|
356 |
-
" <td>Tabular</td>\n",
|
357 |
-
" </tr>\n",
|
358 |
-
" <tr>\n",
|
359 |
-
" <th>...</th>\n",
|
360 |
-
" <td>...</td>\n",
|
361 |
-
" <td>...</td>\n",
|
362 |
-
" <td>...</td>\n",
|
363 |
-
" <td>...</td>\n",
|
364 |
-
" <td>...</td>\n",
|
365 |
-
" <td>...</td>\n",
|
366 |
-
" <td>...</td>\n",
|
367 |
-
" <td>...</td>\n",
|
368 |
-
" <td>...</td>\n",
|
369 |
-
" <td>...</td>\n",
|
370 |
-
" <td>...</td>\n",
|
371 |
-
" <td>...</td>\n",
|
372 |
-
" <td>...</td>\n",
|
373 |
-
" <td>...</td>\n",
|
374 |
-
" <td>...</td>\n",
|
375 |
-
" <td>...</td>\n",
|
376 |
-
" <td>...</td>\n",
|
377 |
-
" <td>...</td>\n",
|
378 |
-
" <td>...</td>\n",
|
379 |
-
" <td>...</td>\n",
|
380 |
-
" <td>...</td>\n",
|
381 |
-
" </tr>\n",
|
382 |
-
" <tr>\n",
|
383 |
-
" <th>999995</th>\n",
|
384 |
-
" <td>4648651054</td>\n",
|
385 |
-
" <td>24511045@N04</td>\n",
|
386 |
-
" <td>mtfrazier</td>\n",
|
387 |
-
" <td>2010-05-02 15:47:45.0</td>\n",
|
388 |
-
" <td>1275083371</td>\n",
|
389 |
-
" <td>Canon+EOS+50D</td>\n",
|
390 |
-
" <td>U.S.+Navy+Blue+Angels%3A+2010</td>\n",
|
391 |
-
" <td>2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri</td>\n",
|
392 |
-
" <td>NaN</td>\n",
|
393 |
-
" <td>NaN</td>\n",
|
394 |
-
" <td>...</td>\n",
|
395 |
-
" <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
|
396 |
-
" <td>4072</td>\n",
|
397 |
-
" <td>5</td>\n",
|
398 |
-
" <td>2d12d73fb0</td>\n",
|
399 |
-
" <td>dd5856ea42</td>\n",
|
400 |
-
" <td>jpg</td>\n",
|
401 |
-
" <td>0</td>\n",
|
402 |
-
" <td>60fa2911cb81eb25b356e9fee978aef</td>\n",
|
403 |
-
" <td>U.S. Navy Blue Angels: 2010</td>\n",
|
404 |
-
" <td>2 May 2010 Sunday St. Joseph, Missouri</td>\n",
|
405 |
-
" </tr>\n",
|
406 |
-
" <tr>\n",
|
407 |
-
" <th>999996</th>\n",
|
408 |
-
" <td>4652130996</td>\n",
|
409 |
-
" <td>21963865@N04</td>\n",
|
410 |
-
" <td>GRAB1.0</td>\n",
|
411 |
-
" <td>2010-05-29 19:23:10.0</td>\n",
|
412 |
-
" <td>1275200833</td>\n",
|
413 |
-
" <td>SONY+DSLR-A230</td>\n",
|
414 |
-
" <td>Attempts+on+Her+Life</td>\n",
|
415 |
-
" <td>BAPA+1+production+of+Martin+Crimp%27s+Attempts...</td>\n",
|
416 |
-
" <td>NaN</td>\n",
|
417 |
-
" <td>NaN</td>\n",
|
418 |
-
" <td>...</td>\n",
|
419 |
-
" <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
|
420 |
-
" <td>4003</td>\n",
|
421 |
-
" <td>5</td>\n",
|
422 |
-
" <td>8889121579</td>\n",
|
423 |
-
" <td>2f46599456</td>\n",
|
424 |
-
" <td>jpg</td>\n",
|
425 |
-
" <td>0</td>\n",
|
426 |
-
" <td>60f5ef5ce4c2d24566226abebd67d4</td>\n",
|
427 |
-
" <td>Attempts on Her Life</td>\n",
|
428 |
-
" <td>BAPA 1 production of Martin Crimp's Attempts o...</td>\n",
|
429 |
-
" </tr>\n",
|
430 |
-
" <tr>\n",
|
431 |
-
" <th>999997</th>\n",
|
432 |
-
" <td>4652568339</td>\n",
|
433 |
-
" <td>64025277@N00</td>\n",
|
434 |
-
" <td>1Sock</td>\n",
|
435 |
-
" <td>2010-05-13 15:38:37.0</td>\n",
|
436 |
-
" <td>1275234267</td>\n",
|
437 |
-
" <td>Canon+EOS+DIGITAL+REBEL+XT</td>\n",
|
438 |
-
" <td>Carlsbad+Caverns+3</td>\n",
|
439 |
-
" <td>%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%...</td>\n",
|
440 |
-
" <td>carlsbad,carlsbad+caverns,cave,faa,new+mexico,...</td>\n",
|
441 |
-
" <td>NaN</td>\n",
|
442 |
-
" <td>...</td>\n",
|
443 |
-
" <td>http://creativecommons.org/licenses/by-nc-nd/2.0/</td>\n",
|
444 |
-
" <td>4010</td>\n",
|
445 |
-
" <td>5</td>\n",
|
446 |
-
" <td>0a1808a69e</td>\n",
|
447 |
-
" <td>cf6d348e3d</td>\n",
|
448 |
-
" <td>jpg</td>\n",
|
449 |
-
" <td>0</td>\n",
|
450 |
-
" <td>60f029482d1d1028fda5281daf498f</td>\n",
|
451 |
-
" <td>Carlsbad Caverns 3</td>\n",
|
452 |
-
" <td>♥♥♥♥♥♥♥ Interested in purchasing this photogra...</td>\n",
|
453 |
-
" </tr>\n",
|
454 |
-
" <tr>\n",
|
455 |
-
" <th>999998</th>\n",
|
456 |
-
" <td>4653110895</td>\n",
|
457 |
-
" <td>20483509@N00</td>\n",
|
458 |
-
" <td>subberculture</td>\n",
|
459 |
-
" <td>2010-05-30 15:37:05.0</td>\n",
|
460 |
-
" <td>1275245596</td>\n",
|
461 |
-
" <td>Canon+DIGITAL+IXUS+40</td>\n",
|
462 |
-
" <td>Want</td>\n",
|
463 |
-
" <td>Isn%27t+that+gorgeous%3F</td>\n",
|
464 |
-
" <td>2010,edinburgh+museum,may,phonebox,wood</td>\n",
|
465 |
-
" <td>NaN</td>\n",
|
466 |
-
" <td>...</td>\n",
|
467 |
-
" <td>http://creativecommons.org/licenses/by-sa/2.0/</td>\n",
|
468 |
-
" <td>4066</td>\n",
|
469 |
-
" <td>5</td>\n",
|
470 |
-
" <td>77c3b3a254</td>\n",
|
471 |
-
" <td>c4697e1511</td>\n",
|
472 |
-
" <td>jpg</td>\n",
|
473 |
-
" <td>0</td>\n",
|
474 |
-
" <td>60f72775f433cf8de3efaeb431866153</td>\n",
|
475 |
-
" <td>Want</td>\n",
|
476 |
-
" <td>Isn't that gorgeous?</td>\n",
|
477 |
-
" </tr>\n",
|
478 |
-
" <tr>\n",
|
479 |
-
" <th>999999</th>\n",
|
480 |
-
" <td>4655503987</td>\n",
|
481 |
-
" <td>8457193@N07</td>\n",
|
482 |
-
" <td>zackojones</td>\n",
|
483 |
-
" <td>2010-05-30 15:34:58.0</td>\n",
|
484 |
-
" <td>1275310230</td>\n",
|
485 |
-
" <td>Canon+EOS+7D</td>\n",
|
486 |
-
" <td>Summertime</td>\n",
|
487 |
-
" <td>You+gotta+love+it%21</td>\n",
|
488 |
-
" <td>georgia,savannah,united+states,us</td>\n",
|
489 |
-
" <td>NaN</td>\n",
|
490 |
-
" <td>...</td>\n",
|
491 |
-
" <td>http://creativecommons.org/licenses/by-nc-sa/2.0/</td>\n",
|
492 |
-
" <td>4043</td>\n",
|
493 |
-
" <td>5</td>\n",
|
494 |
-
" <td>caff543bfe</td>\n",
|
495 |
-
" <td>f60952ac4d</td>\n",
|
496 |
-
" <td>jpg</td>\n",
|
497 |
-
" <td>0</td>\n",
|
498 |
-
" <td>60f687e11b913bce461e9525d8047e0</td>\n",
|
499 |
-
" <td>Summertime</td>\n",
|
500 |
-
" <td>You gotta love it!</td>\n",
|
501 |
-
" </tr>\n",
|
502 |
-
" </tbody>\n",
|
503 |
-
"</table>\n",
|
504 |
-
"<p>1000000 rows × 26 columns</p>\n",
|
505 |
-
"</div>"
|
506 |
-
],
|
507 |
-
"text/plain": [
|
508 |
-
" photoid uid unickname datetaken \\\n",
|
509 |
-
"0 137943 48600072071@N01 doctor+paradox 2004-08-01 18:13:06.0 \n",
|
510 |
-
"1 1246361 44124324682@N01 mharrsch 2004-11-03 23:04:02.0 \n",
|
511 |
-
"2 1251599 51035803024@N01 bmitd67 2004-10-30 17:09:32.0 \n",
|
512 |
-
"3 2348587 73621375@N00 Thom+Watson 2004-12-18 21:08:09.0 \n",
|
513 |
-
"4 3516047 48600072071@N01 doctor+paradox 2005-01-18 16:44:18.0 \n",
|
514 |
-
"... ... ... ... ... \n",
|
515 |
-
"999995 4648651054 24511045@N04 mtfrazier 2010-05-02 15:47:45.0 \n",
|
516 |
-
"999996 4652130996 21963865@N04 GRAB1.0 2010-05-29 19:23:10.0 \n",
|
517 |
-
"999997 4652568339 64025277@N00 1Sock 2010-05-13 15:38:37.0 \n",
|
518 |
-
"999998 4653110895 20483509@N00 subberculture 2010-05-30 15:37:05.0 \n",
|
519 |
-
"999999 4655503987 8457193@N07 zackojones 2010-05-30 15:34:58.0 \n",
|
520 |
-
"\n",
|
521 |
-
" dateuploaded capturedevice \\\n",
|
522 |
-
"0 1091409186 NaN \n",
|
523 |
-
"1 1099523042 NaN \n",
|
524 |
-
"2 1099538888 Canon+PowerShot+S30 \n",
|
525 |
-
"3 1103497228 SONY+DSC-W1 \n",
|
526 |
-
"4 1106084658 NaN \n",
|
527 |
-
"... ... ... \n",
|
528 |
-
"999995 1275083371 Canon+EOS+50D \n",
|
529 |
-
"999996 1275200833 SONY+DSLR-A230 \n",
|
530 |
-
"999997 1275234267 Canon+EOS+DIGITAL+REBEL+XT \n",
|
531 |
-
"999998 1275245596 Canon+DIGITAL+IXUS+40 \n",
|
532 |
-
"999999 1275310230 Canon+EOS+7D \n",
|
533 |
-
"\n",
|
534 |
-
" title \\\n",
|
535 |
-
"0 A+Picture+Share%21 \n",
|
536 |
-
"1 An+ornate+Roman+urn \n",
|
537 |
-
"2 Jai+%26+Tara+on+the+Cumberland \n",
|
538 |
-
"3 Castle+gate+-+%22lite-brited%22 \n",
|
539 |
-
"4 A+Picture+Share%21 \n",
|
540 |
-
"... ... \n",
|
541 |
-
"999995 U.S.+Navy+Blue+Angels%3A+2010 \n",
|
542 |
-
"999996 Attempts+on+Her+Life \n",
|
543 |
-
"999997 Carlsbad+Caverns+3 \n",
|
544 |
-
"999998 Want \n",
|
545 |
-
"999999 Summertime \n",
|
546 |
-
"\n",
|
547 |
-
" description \\\n",
|
548 |
-
"0 Antenna \n",
|
549 |
-
"1 Photographed+at+the+%3Ca+href%3D%22http%3A%2F%... \n",
|
550 |
-
"2 Another+trip+for+the+happy+couple. \n",
|
551 |
-
"3 Taken+at+the+Miracle+of+Lights+display+in+Cent... \n",
|
552 |
-
"4 Tabular \n",
|
553 |
-
"... ... \n",
|
554 |
-
"999995 2+May+2010%0ASunday%0ASt.+Joseph%2C+Missouri \n",
|
555 |
-
"999996 BAPA+1+production+of+Martin+Crimp%27s+Attempts... \n",
|
556 |
-
"999997 %E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%E2%99%A5%... \n",
|
557 |
-
"999998 Isn%27t+that+gorgeous%3F \n",
|
558 |
-
"999999 You+gotta+love+it%21 \n",
|
559 |
-
"\n",
|
560 |
-
" usertags machinetags ... \\\n",
|
561 |
-
"0 cameraphone,cayugaheights,green,hydrant,ithaca... NaN ... \n",
|
562 |
-
"1 ancient,baltimore,burial,death,empire,funeral,... NaN ... \n",
|
563 |
-
"2 blue+heron,cumberland+river,jai,tara,tennessee NaN ... \n",
|
564 |
-
"3 bullrunpark,castle,centreville,christmas,decor... NaN ... \n",
|
565 |
-
"4 cameraphone,moblog,unfound NaN ... \n",
|
566 |
-
"... ... ... ... \n",
|
567 |
-
"999995 NaN NaN ... \n",
|
568 |
-
"999996 NaN NaN ... \n",
|
569 |
-
"999997 carlsbad,carlsbad+caverns,cave,faa,new+mexico,... NaN ... \n",
|
570 |
-
"999998 2010,edinburgh+museum,may,phonebox,wood NaN ... \n",
|
571 |
-
"999999 georgia,savannah,united+states,us NaN ... \n",
|
572 |
-
"\n",
|
573 |
-
" licenseurl serverid farmid \\\n",
|
574 |
-
"0 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
|
575 |
-
"1 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
|
576 |
-
"2 http://creativecommons.org/licenses/by-nc-sa/2.0/ 1 1 \n",
|
577 |
-
"3 http://creativecommons.org/licenses/by-nc-sa/2.0/ 2 1 \n",
|
578 |
-
"4 http://creativecommons.org/licenses/by-nc-sa/2.0/ 3 1 \n",
|
579 |
-
"... ... ... ... \n",
|
580 |
-
"999995 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4072 5 \n",
|
581 |
-
"999996 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4003 5 \n",
|
582 |
-
"999997 http://creativecommons.org/licenses/by-nc-nd/2.0/ 4010 5 \n",
|
583 |
-
"999998 http://creativecommons.org/licenses/by-sa/2.0/ 4066 5 \n",
|
584 |
-
"999999 http://creativecommons.org/licenses/by-nc-sa/2.0/ 4043 5 \n",
|
585 |
-
"\n",
|
586 |
-
" secret secretoriginal ext marker \\\n",
|
587 |
-
"0 1650c7cdc6 1650c7cdc6 jpg 0 \n",
|
588 |
-
"1 cf37054610 cf37054610 jpg 0 \n",
|
589 |
-
"2 4a4234e32c 4a4234e32c jpg 0 \n",
|
590 |
-
"3 7162c974c3 7162c974c3 jpg 0 \n",
|
591 |
-
"4 663e0d8b3d 663e0d8b3d jpg 0 \n",
|
592 |
-
"... ... ... ... ... \n",
|
593 |
-
"999995 2d12d73fb0 dd5856ea42 jpg 0 \n",
|
594 |
-
"999996 8889121579 2f46599456 jpg 0 \n",
|
595 |
-
"999997 0a1808a69e cf6d348e3d jpg 0 \n",
|
596 |
-
"999998 77c3b3a254 c4697e1511 jpg 0 \n",
|
597 |
-
"999999 caff543bfe f60952ac4d jpg 0 \n",
|
598 |
-
"\n",
|
599 |
-
" key title_clean \\\n",
|
600 |
-
"0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
|
601 |
-
"1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
|
602 |
-
"2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
|
603 |
-
"3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
|
604 |
-
"4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
|
605 |
-
"... ... ... \n",
|
606 |
-
"999995 60fa2911cb81eb25b356e9fee978aef U.S. Navy Blue Angels: 2010 \n",
|
607 |
-
"999996 60f5ef5ce4c2d24566226abebd67d4 Attempts on Her Life \n",
|
608 |
-
"999997 60f029482d1d1028fda5281daf498f Carlsbad Caverns 3 \n",
|
609 |
-
"999998 60f72775f433cf8de3efaeb431866153 Want \n",
|
610 |
-
"999999 60f687e11b913bce461e9525d8047e0 Summertime \n",
|
611 |
-
"\n",
|
612 |
-
" description_clean \n",
|
613 |
-
"0 Antenna \n",
|
614 |
-
"1 Photographed at the Walters Art Museum, Baltim... \n",
|
615 |
-
"2 Another trip for the happy couple. \n",
|
616 |
-
"3 Taken at the Miracle of Lights display in Cent... \n",
|
617 |
-
"4 Tabular \n",
|
618 |
-
"... ... \n",
|
619 |
-
"999995 2 May 2010 Sunday St. Joseph, Missouri \n",
|
620 |
-
"999996 BAPA 1 production of Martin Crimp's Attempts o... \n",
|
621 |
-
"999997 ♥♥♥♥♥♥♥ Interested in purchasing this photogra... \n",
|
622 |
-
"999998 Isn't that gorgeous? \n",
|
623 |
-
"999999 You gotta love it! \n",
|
624 |
-
"\n",
|
625 |
-
"[1000000 rows x 26 columns]"
|
626 |
-
]
|
627 |
-
},
|
628 |
-
"execution_count": 25,
|
629 |
-
"metadata": {},
|
630 |
-
"output_type": "execute_result"
|
631 |
-
}
|
632 |
-
],
|
633 |
-
"source": [
|
634 |
-
"# looking up at a chunk\n",
|
635 |
-
"pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")"
|
636 |
-
]
|
637 |
-
},
|
638 |
-
{
|
639 |
-
"cell_type": "code",
|
640 |
-
"execution_count": 98,
|
641 |
-
"id": "c51c5597",
|
642 |
-
"metadata": {},
|
643 |
-
"outputs": [
|
644 |
-
{
|
645 |
-
"data": {
|
646 |
-
"text/html": [
|
647 |
-
"<div>\n",
|
648 |
-
"<style scoped>\n",
|
649 |
-
" .dataframe tbody tr th:only-of-type {\n",
|
650 |
-
" vertical-align: middle;\n",
|
651 |
-
" }\n",
|
652 |
-
"\n",
|
653 |
-
" .dataframe tbody tr th {\n",
|
654 |
-
" vertical-align: top;\n",
|
655 |
-
" }\n",
|
656 |
-
"\n",
|
657 |
-
" .dataframe thead th {\n",
|
658 |
-
" text-align: right;\n",
|
659 |
-
" }\n",
|
660 |
-
"</style>\n",
|
661 |
-
"<table border=\"1\" class=\"dataframe\">\n",
|
662 |
-
" <thead>\n",
|
663 |
-
" <tr style=\"text-align: right;\">\n",
|
664 |
-
" <th></th>\n",
|
665 |
-
" <th>key</th>\n",
|
666 |
-
" <th>title_clean</th>\n",
|
667 |
-
" <th>description_clean</th>\n",
|
668 |
-
" <th>ext</th>\n",
|
669 |
-
" </tr>\n",
|
670 |
-
" </thead>\n",
|
671 |
-
" <tbody>\n",
|
672 |
-
" <tr>\n",
|
673 |
-
" <th>0</th>\n",
|
674 |
-
" <td>d29e7c6a3028418c64eb15e3cf577c2</td>\n",
|
675 |
-
" <td>A Picture Share!</td>\n",
|
676 |
-
" <td>Antenna</td>\n",
|
677 |
-
" <td>jpg</td>\n",
|
678 |
-
" </tr>\n",
|
679 |
-
" <tr>\n",
|
680 |
-
" <th>1</th>\n",
|
681 |
-
" <td>d29f01b149167d683f9ddde464bb3db</td>\n",
|
682 |
-
" <td>An ornate Roman urn</td>\n",
|
683 |
-
" <td>Photographed at the Walters Art Museum, Baltim...</td>\n",
|
684 |
-
" <td>jpg</td>\n",
|
685 |
-
" </tr>\n",
|
686 |
-
" <tr>\n",
|
687 |
-
" <th>2</th>\n",
|
688 |
-
" <td>d296e9e34bdae41edb6c679ff824ab2a</td>\n",
|
689 |
-
" <td>Jai & Tara on the Cumberland</td>\n",
|
690 |
-
" <td>Another trip for the happy couple.</td>\n",
|
691 |
-
" <td>jpg</td>\n",
|
692 |
-
" </tr>\n",
|
693 |
-
" <tr>\n",
|
694 |
-
" <th>3</th>\n",
|
695 |
-
" <td>d29ce96395848478b1e8396e44899</td>\n",
|
696 |
-
" <td>Castle gate - \"lite-brited\"</td>\n",
|
697 |
-
" <td>Taken at the Miracle of Lights display in Cent...</td>\n",
|
698 |
-
" <td>jpg</td>\n",
|
699 |
-
" </tr>\n",
|
700 |
-
" <tr>\n",
|
701 |
-
" <th>4</th>\n",
|
702 |
-
" <td>d29abf32c4e12ff881f975b70e0cec0</td>\n",
|
703 |
-
" <td>A Picture Share!</td>\n",
|
704 |
-
" <td>Tabular</td>\n",
|
705 |
-
" <td>jpg</td>\n",
|
706 |
-
" </tr>\n",
|
707 |
-
" </tbody>\n",
|
708 |
-
"</table>\n",
|
709 |
-
"</div>"
|
710 |
-
],
|
711 |
-
"text/plain": [
|
712 |
-
" key title_clean \\\n",
|
713 |
-
"0 d29e7c6a3028418c64eb15e3cf577c2 A Picture Share! \n",
|
714 |
-
"1 d29f01b149167d683f9ddde464bb3db An ornate Roman urn \n",
|
715 |
-
"2 d296e9e34bdae41edb6c679ff824ab2a Jai & Tara on the Cumberland \n",
|
716 |
-
"3 d29ce96395848478b1e8396e44899 Castle gate - \"lite-brited\" \n",
|
717 |
-
"4 d29abf32c4e12ff881f975b70e0cec0 A Picture Share! \n",
|
718 |
-
"\n",
|
719 |
-
" description_clean ext \n",
|
720 |
-
"0 Antenna jpg \n",
|
721 |
-
"1 Photographed at the Walters Art Museum, Baltim... jpg \n",
|
722 |
-
"2 Another trip for the happy couple. jpg \n",
|
723 |
-
"3 Taken at the Miracle of Lights display in Cent... jpg \n",
|
724 |
-
"4 Tabular jpg "
|
725 |
-
]
|
726 |
-
},
|
727 |
-
"execution_count": 98,
|
728 |
-
"metadata": {},
|
729 |
-
"output_type": "execute_result"
|
730 |
-
}
|
731 |
-
],
|
732 |
-
"source": [
|
733 |
-
"# Looking at a chunk with only the relevant columns that we need\n",
|
734 |
-
"df = pd.read_csv(\"./chunks/chunk1.tsv\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
|
735 |
-
"df.head()"
|
736 |
-
]
|
737 |
-
},
|
738 |
-
{
|
739 |
-
"cell_type": "markdown",
|
740 |
-
"id": "cc1668f8",
|
741 |
-
"metadata": {},
|
742 |
-
"source": [
|
743 |
-
"### Grabbing each chunks from the folder, cleaning it up, only taking the entries which image exist and appending it to the global df"
|
744 |
-
]
|
745 |
-
},
|
746 |
-
{
|
747 |
-
"cell_type": "code",
|
748 |
-
"execution_count": null,
|
749 |
-
"id": "abbcccf3",
|
750 |
-
"metadata": {},
|
751 |
-
"outputs": [],
|
752 |
-
"source": [
|
753 |
-
"# the function that helps us to decide whether an image with certain id exists in storage, we only take the ones that we have the images for\n",
|
754 |
-
"def image_exists(item):\n",
|
755 |
-
" name, _, _, ext, _ = item\n",
|
756 |
-
" root=str(yfcc100m_images)\n",
|
757 |
-
" image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".\"+ext)\n",
|
758 |
-
" if image_path.exists():\n",
|
759 |
-
" return True\n",
|
760 |
-
" else:\n",
|
761 |
-
" return None"
|
762 |
-
]
|
763 |
-
},
|
764 |
-
{
|
765 |
-
"cell_type": "code",
|
766 |
-
"execution_count": 86,
|
767 |
-
"id": "44fa86ab",
|
768 |
-
"metadata": {},
|
769 |
-
"outputs": [],
|
770 |
-
"source": [
|
771 |
-
"# This cell does it all, grabs each chunk, cleans it up based on image existing condition, etc.\n",
|
772 |
-
"global_df = pd.DataFrame()\n",
|
773 |
-
"chunks_dir = \"./chunks\"\n",
|
774 |
-
"for filename in os.listdir(chunks_dir):\n",
|
775 |
-
" df = pd.read_csv(f\"./chunks/{str(filename)}\", sep=\"\\t\")[[\"key\", \"title_clean\", \"description_clean\", \"ext\"]]\n",
|
776 |
-
" df['caption'] = df[\"title_clean\"]+\". \"+df['description_clean']\n",
|
777 |
-
" df['is_exist'] = df.apply(image_exists, axis=1)\n",
|
778 |
-
" df = df.dropna()[[\"key\", \"caption\"]]\n",
|
779 |
-
" df.columns = ['image_file', 'caption']\n",
|
780 |
-
" global_df = global_df.append(df, ignore_index=True)"
|
781 |
-
]
|
782 |
-
},
|
783 |
-
{
|
784 |
-
"cell_type": "code",
|
785 |
-
"execution_count": 89,
|
786 |
-
"id": "45024fdc",
|
787 |
-
"metadata": {},
|
788 |
-
"outputs": [],
|
789 |
-
"source": [
|
790 |
-
"# saving the tsv to disk\n",
|
791 |
-
"global_df.to_csv('./chunks/YFCC_subset_clean.tsv', sep=\"\\t\", index=False)"
|
792 |
-
]
|
793 |
-
},
|
794 |
-
{
|
795 |
-
"cell_type": "code",
|
796 |
-
"execution_count": 101,
|
797 |
-
"id": "dca4eb73",
|
798 |
-
"metadata": {},
|
799 |
-
"outputs": [],
|
800 |
-
"source": [
|
801 |
-
"# loading the tsv from disk (for explicitness, also my electricity was gone, glad it happened after I saved to the disk :( )\n",
|
802 |
-
"\n",
|
803 |
-
"dataset = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")"
|
804 |
-
]
|
805 |
-
},
|
806 |
-
{
|
807 |
-
"cell_type": "code",
|
808 |
-
"execution_count": 153,
|
809 |
-
"id": "a511264a",
|
810 |
-
"metadata": {},
|
811 |
-
"outputs": [],
|
812 |
-
"source": [
|
813 |
-
"\"\"\"\n",
|
814 |
-
"Luke Melas-Kyriazi's dataset.py's modified version for YFCC\n",
|
815 |
-
"\"\"\"\n",
|
816 |
-
"import warnings\n",
|
817 |
-
"from typing import Optional, Callable\n",
|
818 |
-
"from pathlib import Path\n",
|
819 |
-
"import numpy as np\n",
|
820 |
-
"import torch\n",
|
821 |
-
"import pandas as pd\n",
|
822 |
-
"from torch.utils.data import Dataset\n",
|
823 |
-
"from torchvision.datasets.folder import default_loader\n",
|
824 |
-
"from PIL import ImageFile\n",
|
825 |
-
"from PIL.Image import DecompressionBombWarning\n",
|
826 |
-
"ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
|
827 |
-
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
828 |
-
"warnings.filterwarnings(\"ignore\", category=DecompressionBombWarning)\n",
|
829 |
-
"\n",
|
830 |
-
"\n",
|
831 |
-
"class CaptionDataset(Dataset):\n",
|
832 |
-
" \"\"\"\n",
|
833 |
-
" A PyTorch Dataset class for (image, texts) tasks. Note that this dataset \n",
|
834 |
-
" returns the raw text rather than tokens. This is done on purpose, because\n",
|
835 |
-
" it's easy to tokenize a batch of text after loading it from this dataset.\n",
|
836 |
-
" \"\"\"\n",
|
837 |
-
"\n",
|
838 |
-
" def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None, \n",
|
839 |
-
" image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',\n",
|
840 |
-
" include_captions: bool = True):\n",
|
841 |
-
" \"\"\"\n",
|
842 |
-
" :param images_root: folder where images are stored\n",
|
843 |
-
" :param captions_path: path to csv that maps image filenames to captions\n",
|
844 |
-
" :param image_transform: image transform pipeline\n",
|
845 |
-
" :param text_transform: image transform pipeline\n",
|
846 |
-
" :param image_transform_type: image transform type, either `torchvision` or `albumentations`\n",
|
847 |
-
" :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.\n",
|
848 |
-
" \"\"\"\n",
|
849 |
-
"\n",
|
850 |
-
" # Base path for images\n",
|
851 |
-
" self.images_root = Path(images_root)\n",
|
852 |
-
"\n",
|
853 |
-
" # Load captions as DataFrame\n",
|
854 |
-
" self.captions = pd.read_csv(f\"./chunks/YFCC_subset_clean.tsv\", sep=\"\\t\")\n",
|
855 |
-
" self.captions['image_file'] = self.captions['image_file'].astype(str)\n",
|
856 |
-
"\n",
|
857 |
-
" # PyTorch transformation pipeline for the image (normalizing, etc.)\n",
|
858 |
-
" self.text_transform = text_transform\n",
|
859 |
-
" self.image_transform = image_transform\n",
|
860 |
-
" self.image_transform_type = image_transform_type.lower()\n",
|
861 |
-
" assert self.image_transform_type in ['torchvision', 'albumentations']\n",
|
862 |
-
"\n",
|
863 |
-
" # Total number of datapoints\n",
|
864 |
-
" self.size = len(self.captions)\n",
|
865 |
-
"\n",
|
866 |
-
" # Return image+captions or just images\n",
|
867 |
-
" self.include_captions = include_captions\n",
|
868 |
-
" \n",
|
869 |
-
" def image_exists(item):\n",
|
870 |
-
" name, caption = item\n",
|
871 |
-
" root=str(self.images_root)\n",
|
872 |
-
" image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
|
873 |
-
"\n",
|
874 |
-
" return image_path.exists()\n",
|
875 |
-
"\n",
|
876 |
-
" def verify_that_all_images_exist(self):\n",
|
877 |
-
" for image_file in self.captions['image_file']:\n",
|
878 |
-
" if not image_exists:\n",
|
879 |
-
" print(f'file does not exist: {p}')\n",
|
880 |
-
"\n",
|
881 |
-
" def _get_raw_image(self, i):\n",
|
882 |
-
" name = self.captions.iloc[i]['image_file']\n",
|
883 |
-
" image_path = (Path(self.images_root)/name[0:3]/name[3:6]/name).with_suffix(\".jpg\")\n",
|
884 |
-
" image = default_loader(image_path)\n",
|
885 |
-
" return image\n",
|
886 |
-
"\n",
|
887 |
-
" def _get_raw_text(self, i):\n",
|
888 |
-
" return self.captions.iloc[i]['caption']\n",
|
889 |
-
"\n",
|
890 |
-
" def __getitem__(self, i):\n",
|
891 |
-
" image = self._get_raw_image(i)\n",
|
892 |
-
" caption = self._get_raw_text(i)\n",
|
893 |
-
" if self.image_transform is not None:\n",
|
894 |
-
" if self.image_transform_type == 'torchvision':\n",
|
895 |
-
" image = self.image_transform(image)\n",
|
896 |
-
" elif self.image_transform_type == 'albumentations':\n",
|
897 |
-
" image = self.image_transform(image=np.array(image))['image']\n",
|
898 |
-
" else:\n",
|
899 |
-
" raise NotImplementedError(f\"{self.image_transform_type=}\")\n",
|
900 |
-
" return {'image': image, 'text': caption} if self.include_captions else image\n",
|
901 |
-
"\n",
|
902 |
-
" def __len__(self):\n",
|
903 |
-
" return self.size\n",
|
904 |
-
"\n",
|
905 |
-
"\n",
|
906 |
-
"if __name__ == \"__main__\":\n",
|
907 |
-
" import albumentations as A\n",
|
908 |
-
" from albumentations.pytorch import ToTensorV2\n",
|
909 |
-
" from transformers import AutoTokenizer\n",
|
910 |
-
" \n",
|
911 |
-
"\n",
|
912 |
-
" images_root = \"/home/khali/TPU-Test/YFCC100M_OpenAI_subset/data/data/images\"\n",
|
913 |
-
" captions_path = './YFCC_subset_clean.tsv'\n",
|
914 |
-
" image_size = 256\n",
|
915 |
-
" \n",
|
916 |
-
" # Create transforms\n",
|
917 |
-
" def image_transform(image):\n",
|
918 |
-
" s = min(image.size)\n",
|
919 |
-
" r = image_size / s\n",
|
920 |
-
" s = (round(r * image.size[1]), round(r * image.size[0]))\n",
|
921 |
-
" image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
|
922 |
-
" image = TF.center_crop(image, output_size = 2 * [image_size])\n",
|
923 |
-
" image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
|
924 |
-
" image = image.permute(0, 2, 3, 1).numpy()\n",
|
925 |
-
" return image\n",
|
926 |
-
" \n",
|
927 |
-
" # Create dataset\n",
|
928 |
-
" dataset = CaptionDataset(\n",
|
929 |
-
" images_root=images_root,\n",
|
930 |
-
" captions_path=captions_path,\n",
|
931 |
-
" image_transform=image_transform,\n",
|
932 |
-
" image_transform_type='torchvision',\n",
|
933 |
-
" include_captions=False\n",
|
934 |
-
" )"
|
935 |
-
]
|
936 |
-
},
|
937 |
-
{
|
938 |
-
"cell_type": "code",
|
939 |
-
"execution_count": 155,
|
940 |
-
"id": "cc922704",
|
941 |
-
"metadata": {},
|
942 |
-
"outputs": [
|
943 |
-
{
|
944 |
-
"data": {
|
945 |
-
"text/plain": [
|
946 |
-
"2483316"
|
947 |
-
]
|
948 |
-
},
|
949 |
-
"execution_count": 155,
|
950 |
-
"metadata": {},
|
951 |
-
"output_type": "execute_result"
|
952 |
-
}
|
953 |
-
],
|
954 |
-
"source": [
|
955 |
-
"len(dataset)"
|
956 |
-
]
|
957 |
-
},
|
958 |
-
{
|
959 |
-
"cell_type": "code",
|
960 |
-
"execution_count": 156,
|
961 |
-
"id": "6e47ba46",
|
962 |
-
"metadata": {},
|
963 |
-
"outputs": [],
|
964 |
-
"source": [
|
965 |
-
"dataloader = DataLoader(dataset, batch_size=32, num_workers=4)"
|
966 |
-
]
|
967 |
-
},
|
968 |
-
{
|
969 |
-
"cell_type": "code",
|
970 |
-
"execution_count": 1,
|
971 |
-
"id": "c8a130eb",
|
972 |
-
"metadata": {},
|
973 |
-
"outputs": [],
|
974 |
-
"source": [
|
975 |
-
"# looking at a batch\n",
|
976 |
-
"next(iter(dataloader))"
|
977 |
-
]
|
978 |
-
},
|
979 |
-
{
|
980 |
-
"cell_type": "code",
|
981 |
-
"execution_count": null,
|
982 |
-
"id": "c192fd44",
|
983 |
-
"metadata": {},
|
984 |
-
"outputs": [],
|
985 |
-
"source": [
|
986 |
-
"# import matplotlib.pyplot as plt\n",
|
987 |
-
"# for tensor_image, _ in dataloader:\n",
|
988 |
-
"# print(tensor_image)\n",
|
989 |
-
"# plt.imshow(tensor_image.permute(1, 2, 0))\n",
|
990 |
-
"# break"
|
991 |
-
]
|
992 |
-
},
|
993 |
-
{
|
994 |
-
"cell_type": "markdown",
|
995 |
-
"id": "62ad01c3",
|
996 |
-
"metadata": {},
|
997 |
-
"source": [
|
998 |
-
"## Encoding"
|
999 |
-
]
|
1000 |
-
},
|
1001 |
-
{
|
1002 |
-
"cell_type": "code",
|
1003 |
-
"execution_count": 158,
|
1004 |
-
"id": "88f36d0b",
|
1005 |
-
"metadata": {},
|
1006 |
-
"outputs": [],
|
1007 |
-
"source": [
|
1008 |
-
"def encode(model, batch):\n",
|
1009 |
-
"# print(\"jitting encode function\")\n",
|
1010 |
-
" _, indices = model.encode(batch)\n",
|
1011 |
-
" return indices"
|
1012 |
-
]
|
1013 |
-
},
|
1014 |
-
{
|
1015 |
-
"cell_type": "code",
|
1016 |
-
"execution_count": 160,
|
1017 |
-
"id": "1f35f0cb",
|
1018 |
-
"metadata": {},
|
1019 |
-
"outputs": [],
|
1020 |
-
"source": [
|
1021 |
-
"def superbatch_generator(dataloader, num_tpus):\n",
|
1022 |
-
" iter_loader = iter(dataloader)\n",
|
1023 |
-
" for batch in iter_loader:\n",
|
1024 |
-
" superbatch = [batch.squeeze(1)]\n",
|
1025 |
-
" try:\n",
|
1026 |
-
" for b in range(num_tpus-1):\n",
|
1027 |
-
" batch = next(iter_loader)\n",
|
1028 |
-
" if batch is None:\n",
|
1029 |
-
" break\n",
|
1030 |
-
" # Skip incomplete last batch\n",
|
1031 |
-
" if batch.shape[0] == dataloader.batch_size:\n",
|
1032 |
-
" superbatch.append(batch.squeeze(1))\n",
|
1033 |
-
" except StopIteration:\n",
|
1034 |
-
" pass\n",
|
1035 |
-
" superbatch = torch.stack(superbatch, axis=0)\n",
|
1036 |
-
" yield superbatch"
|
1037 |
-
]
|
1038 |
-
},
|
1039 |
-
{
|
1040 |
-
"cell_type": "code",
|
1041 |
-
"execution_count": 170,
|
1042 |
-
"id": "2210705b",
|
1043 |
-
"metadata": {},
|
1044 |
-
"outputs": [],
|
1045 |
-
"source": [
|
1046 |
-
"import os\n",
|
1047 |
-
"\n",
|
1048 |
-
"def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
|
1049 |
-
" if os.path.isfile(output_tsv):\n",
|
1050 |
-
" print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
|
1051 |
-
" return\n",
|
1052 |
-
" \n",
|
1053 |
-
" num_tpus = 8 \n",
|
1054 |
-
" dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
|
1055 |
-
" superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
|
1056 |
-
" \n",
|
1057 |
-
" p_encoder = pmap(lambda batch: encode(model, batch))\n",
|
1058 |
-
"\n",
|
1059 |
-
" # We save each superbatch to avoid reallocation of buffers as we process them.\n",
|
1060 |
-
" # We keep the file open to prevent excessive file seeks.\n",
|
1061 |
-
" with open(output_tsv, \"w\") as file:\n",
|
1062 |
-
" iterations = len(dataset) // (batch_size * num_tpus)\n",
|
1063 |
-
" for n in tqdm(range(iterations)):\n",
|
1064 |
-
" superbatch = next(superbatches)\n",
|
1065 |
-
" encoded = p_encoder(superbatch.numpy())\n",
|
1066 |
-
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
1067 |
-
"\n",
|
1068 |
-
" # Extract fields from the dataset internal `captions` property, and save to disk\n",
|
1069 |
-
" start_index = n * batch_size * num_tpus\n",
|
1070 |
-
" end_index = (n+1) * batch_size * num_tpus\n",
|
1071 |
-
" paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
|
1072 |
-
" captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
|
1073 |
-
" encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
|
1074 |
-
" batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
|
1075 |
-
" batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)"
|
1076 |
-
]
|
1077 |
-
},
|
1078 |
-
{
|
1079 |
-
"cell_type": "code",
|
1080 |
-
"execution_count": 171,
|
1081 |
-
"id": "7704863d",
|
1082 |
-
"metadata": {},
|
1083 |
-
"outputs": [
|
1084 |
-
{
|
1085 |
-
"name": "stderr",
|
1086 |
-
"output_type": "stream",
|
1087 |
-
"text": [
|
1088 |
-
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4850/4850 [2:27:51<00:00, 1.83s/it]\n"
|
1089 |
-
]
|
1090 |
-
}
|
1091 |
-
],
|
1092 |
-
"source": [
|
1093 |
-
"encode_captioned_dataset(dataset, yfcc100m_output, batch_size=64, num_workers=16)"
|
1094 |
-
]
|
1095 |
-
},
|
1096 |
-
{
|
1097 |
-
"cell_type": "markdown",
|
1098 |
-
"id": "8953dd84",
|
1099 |
-
"metadata": {},
|
1100 |
-
"source": [
|
1101 |
-
"----"
|
1102 |
-
]
|
1103 |
-
}
|
1104 |
-
],
|
1105 |
-
"metadata": {
|
1106 |
-
"interpreter": {
|
1107 |
-
"hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
|
1108 |
-
},
|
1109 |
-
"kernelspec": {
|
1110 |
-
"display_name": "Python 3 (ipykernel)",
|
1111 |
-
"language": "python",
|
1112 |
-
"name": "python3"
|
1113 |
-
},
|
1114 |
-
"language_info": {
|
1115 |
-
"codemirror_mode": {
|
1116 |
-
"name": "ipython",
|
1117 |
-
"version": 3
|
1118 |
-
},
|
1119 |
-
"file_extension": ".py",
|
1120 |
-
"mimetype": "text/x-python",
|
1121 |
-
"name": "python",
|
1122 |
-
"nbconvert_exporter": "python",
|
1123 |
-
"pygments_lexer": "ipython3",
|
1124 |
-
"version": "3.8.10"
|
1125 |
-
}
|
1126 |
-
},
|
1127 |
-
"nbformat": 4,
|
1128 |
-
"nbformat_minor": 5
|
1129 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/encoding/vqgan-jax-encoding.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
dev/environment.yaml
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
name: dalle
|
2 |
-
channels:
|
3 |
-
- defaults
|
4 |
-
dependencies:
|
5 |
-
- python=3.9.5
|
6 |
-
- pip=21.1.3
|
7 |
-
- ipython=7.22.0
|
8 |
-
- cudatoolkit
|
9 |
-
- pip:
|
10 |
-
- -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/requirements.txt
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
requests
|
2 |
-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
3 |
-
jax[tpu]>=0.2.16
|
4 |
-
transformers
|
5 |
-
datasets
|
6 |
-
flax
|
7 |
-
jupyter
|
8 |
-
wandb
|
9 |
-
nltk
|
10 |
-
optax
|
11 |
-
git+https://github.com/patil-suraj/vqgan-jax.git@610d842dd33c739325a944102ed33acc07692dd5
|
12 |
-
|
13 |
-
# Inference
|
14 |
-
ftfy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/seq2seq/do_big_run.sh
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
python run_seq2seq_flax.py \
|
2 |
-
--dataset_repo_or_path dalle-mini/encoded \
|
3 |
-
--train_file **/train/*/*.jsonl \
|
4 |
-
--validation_file **/valid/*/*.jsonl \
|
5 |
-
--len_train 129847128 \
|
6 |
-
--len_eval 157312 \
|
7 |
-
--eval_steps 1000 \
|
8 |
-
--streaming \
|
9 |
-
--normalize_text \
|
10 |
-
--output_dir output \
|
11 |
-
--per_device_train_batch_size 56 \
|
12 |
-
--per_device_eval_batch_size 56 \
|
13 |
-
--preprocessing_num_workers 80 \
|
14 |
-
--warmup_steps 5000 \
|
15 |
-
--gradient_accumulation_steps 8 \
|
16 |
-
--do_train \
|
17 |
-
--do_eval \
|
18 |
-
--adafactor \
|
19 |
-
--num_train_epochs 6 \
|
20 |
-
--log_model \
|
21 |
-
--learning_rate 0.005
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/seq2seq/do_small_run.sh
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
python run_seq2seq_flax.py \
|
2 |
-
--dataset_repo_or_path dalle-mini/encoded \
|
3 |
-
--train_file **/train/CC3M/*.jsonl \
|
4 |
-
--validation_file **/valid/*/*.jsonl \
|
5 |
-
--len_train 129847128 \
|
6 |
-
--len_eval 157312 \
|
7 |
-
--streaming \
|
8 |
-
--output_dir output \
|
9 |
-
--per_device_train_batch_size 16 \
|
10 |
-
--per_device_eval_batch_size 16 \
|
11 |
-
--preprocessing_num_workers 80 \
|
12 |
-
--warmup_steps 125 \
|
13 |
-
--gradient_accumulation_steps 8 \
|
14 |
-
--do_train \
|
15 |
-
--do_eval \
|
16 |
-
--adafactor \
|
17 |
-
--num_train_epochs 1 \
|
18 |
-
--max_train_samples 10000 \
|
19 |
-
--learning_rate 0.005
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dev/vqgan/JAX_VQGAN_f16_16384_Reconstruction.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[tool.isort]
|
2 |
+
profile = "black"
|
setup.cfg
CHANGED
@@ -16,3 +16,11 @@ install_requires =
|
|
16 |
ftfy
|
17 |
jax
|
18 |
flax
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
ftfy
|
17 |
jax
|
18 |
flax
|
19 |
+
|
20 |
+
[options.extras_require]
|
21 |
+
dev =
|
22 |
+
tqdm
|
23 |
+
wandb
|
24 |
+
optax
|
25 |
+
black[jupyter]
|
26 |
+
isort
|
setup.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
from setuptools import setup
|
2 |
|
3 |
if __name__ == "__main__":
|
4 |
-
setup()
|
|
|
1 |
from setuptools import setup
|
2 |
|
3 |
if __name__ == "__main__":
|
4 |
+
setup()
|
dev/encoding/vqgan-jax-encoding-webdataset.ipynb → tools/dataset/encode_dataset.ipynb
RENAMED
@@ -5,7 +5,7 @@
|
|
5 |
"id": "d0b72877",
|
6 |
"metadata": {},
|
7 |
"source": [
|
8 |
-
"#
|
9 |
]
|
10 |
},
|
11 |
{
|
@@ -15,7 +15,11 @@
|
|
15 |
"source": [
|
16 |
"This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
|
17 |
"\n",
|
18 |
-
"
|
|
|
|
|
|
|
|
|
19 |
]
|
20 |
},
|
21 |
{
|
@@ -25,19 +29,15 @@
|
|
25 |
"metadata": {},
|
26 |
"outputs": [],
|
27 |
"source": [
|
28 |
-
"
|
29 |
-
"from tqdm import tqdm\n",
|
30 |
"\n",
|
31 |
-
"import torch\n",
|
32 |
"import torchvision.transforms as T\n",
|
33 |
-
"import torchvision.transforms.functional as TF\n",
|
34 |
-
"from torchvision.transforms import InterpolationMode\n",
|
35 |
-
"import math\n",
|
36 |
"\n",
|
37 |
"import webdataset as wds\n",
|
38 |
"\n",
|
39 |
"import jax\n",
|
40 |
-
"
|
|
|
41 |
]
|
42 |
},
|
43 |
{
|
@@ -45,184 +45,110 @@
|
|
45 |
"id": "c7c4c1e6",
|
46 |
"metadata": {},
|
47 |
"source": [
|
48 |
-
"##
|
49 |
-
]
|
50 |
-
},
|
51 |
-
{
|
52 |
-
"cell_type": "markdown",
|
53 |
-
"id": "9822850f",
|
54 |
-
"metadata": {},
|
55 |
-
"source": [
|
56 |
-
"The following is the list of shards we'll process. We hardcode the length of data so that we can see nice progress bars using `tqdm`."
|
57 |
]
|
58 |
},
|
59 |
{
|
60 |
"cell_type": "code",
|
61 |
-
"execution_count":
|
62 |
"id": "1265dbfe",
|
63 |
"metadata": {},
|
64 |
"outputs": [],
|
65 |
"source": [
|
66 |
-
"shards =
|
67 |
-
"
|
68 |
-
]
|
69 |
-
},
|
70 |
-
{
|
71 |
-
"cell_type": "markdown",
|
72 |
-
"id": "7e38fa14",
|
73 |
-
"metadata": {},
|
74 |
-
"source": [
|
75 |
-
"If we are extra cautious or our server is unreliable, we can enable retries by providing a custom `curl` retrieval command:"
|
76 |
-
]
|
77 |
-
},
|
78 |
-
{
|
79 |
-
"cell_type": "code",
|
80 |
-
"execution_count": null,
|
81 |
-
"id": "4c8c5960",
|
82 |
-
"metadata": {},
|
83 |
-
"outputs": [],
|
84 |
-
"source": [
|
85 |
-
"# Enable curl retries to try to work around temporary network / server errors.\n",
|
86 |
-
"# This shouldn't be necessary when using reliable servers.\n",
|
87 |
-
"# shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'"
|
88 |
-
]
|
89 |
-
},
|
90 |
-
{
|
91 |
-
"cell_type": "code",
|
92 |
-
"execution_count": null,
|
93 |
-
"id": "13c6631b",
|
94 |
-
"metadata": {},
|
95 |
-
"outputs": [],
|
96 |
-
"source": [
|
97 |
-
"from pathlib import Path\n",
|
98 |
"\n",
|
99 |
-
"
|
100 |
-
"
|
|
|
|
|
101 |
"\n",
|
102 |
-
"
|
103 |
-
"
|
|
|
|
|
|
|
104 |
]
|
105 |
},
|
106 |
{
|
107 |
"cell_type": "code",
|
108 |
-
"execution_count":
|
109 |
-
"id": "
|
110 |
-
"metadata": {},
|
111 |
-
"outputs": [
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
]
|
116 |
},
|
117 |
{
|
118 |
"cell_type": "markdown",
|
119 |
-
"id": "
|
120 |
"metadata": {},
|
121 |
"source": [
|
122 |
-
"
|
123 |
-
]
|
124 |
-
},
|
125 |
-
{
|
126 |
-
"cell_type": "code",
|
127 |
-
"execution_count": null,
|
128 |
-
"id": "669b35df",
|
129 |
-
"metadata": {},
|
130 |
-
"outputs": [],
|
131 |
-
"source": [
|
132 |
-
"def center_crop(image, max_size=256):\n",
|
133 |
-
" # Note: we allow upscaling too. We should exclude small images. \n",
|
134 |
-
" image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
|
135 |
-
" image = TF.center_crop(image, output_size=2 * [max_size])\n",
|
136 |
-
" return image\n",
|
137 |
-
"\n",
|
138 |
-
"preprocess_image = T.Compose([\n",
|
139 |
-
" center_crop,\n",
|
140 |
-
" T.ToTensor(),\n",
|
141 |
-
" lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
|
142 |
-
"])"
|
143 |
]
|
144 |
},
|
145 |
{
|
146 |
"cell_type": "markdown",
|
147 |
-
"id": "
|
148 |
"metadata": {},
|
149 |
"source": [
|
150 |
-
"
|
151 |
-
"\n",
|
152 |
-
"Note that we receive the contents of the `json` structure, which will be replaced by the string we return.\n",
|
153 |
-
"If we want to keep other fields inside `json`, we can add `caption` as a new field."
|
154 |
]
|
155 |
},
|
156 |
{
|
157 |
"cell_type": "code",
|
158 |
"execution_count": null,
|
159 |
-
"id": "
|
160 |
"metadata": {},
|
161 |
"outputs": [],
|
162 |
"source": [
|
163 |
-
"
|
164 |
-
"
|
165 |
-
"
|
166 |
-
"
|
167 |
-
"
|
|
|
168 |
]
|
169 |
},
|
170 |
{
|
171 |
"cell_type": "markdown",
|
172 |
-
"id": "
|
173 |
"metadata": {},
|
174 |
"source": [
|
175 |
-
"
|
176 |
-
"
|
177 |
-
"
|
178 |
-
"
|
179 |
-
"We can also create our custom exception handler as demonstrated here:"
|
180 |
]
|
181 |
},
|
182 |
{
|
183 |
-
"cell_type": "
|
184 |
-
"
|
185 |
-
"id": "369d9719",
|
186 |
-
"metadata": {},
|
187 |
-
"outputs": [],
|
188 |
-
"source": [
|
189 |
-
"# UNUSED - Log exceptions to a file\n",
|
190 |
-
"def ignore_and_log(exn):\n",
|
191 |
-
" with open('errors.txt', 'a') as f:\n",
|
192 |
-
" f.write(f'{repr(exn)}\\n')\n",
|
193 |
-
" return True"
|
194 |
-
]
|
195 |
-
},
|
196 |
-
{
|
197 |
-
"cell_type": "code",
|
198 |
-
"execution_count": null,
|
199 |
-
"id": "27de1414",
|
200 |
-
"metadata": {},
|
201 |
-
"outputs": [],
|
202 |
-
"source": [
|
203 |
-
"# Or simply use `wds.ignore_and_continue`\n",
|
204 |
-
"exception_handler = wds.warn_and_continue"
|
205 |
-
]
|
206 |
-
},
|
207 |
-
{
|
208 |
-
"cell_type": "code",
|
209 |
-
"execution_count": null,
|
210 |
-
"id": "5149b6d5",
|
211 |
"metadata": {},
|
212 |
-
"outputs": [],
|
213 |
"source": [
|
214 |
-
"
|
215 |
-
" length=batches, # Hint so `len` is implemented\n",
|
216 |
-
" shardshuffle=False, # Keep same order for encoded files for easier bookkeeping. Set to `True` for training.\n",
|
217 |
-
" handler=exception_handler, # Ignore read errors instead of failing.\n",
|
218 |
-
")\n",
|
219 |
-
"\n",
|
220 |
-
"dataset = (dataset \n",
|
221 |
-
" .decode('pil') # decode image with PIL\n",
|
222 |
-
"# .map_dict(jpg=preprocess_image, json=create_caption, handler=exception_handler) # Process fields with functions defined above\n",
|
223 |
-
" .map_dict(jpg=preprocess_image, json=create_caption) # Process fields with functions defined above\n",
|
224 |
-
" .to_tuple('__key__', 'jpg', 'json') # filter to keep only key (for reference), image, caption.\n",
|
225 |
-
" .batched(bs)) # better to batch in the dataset (but we could also do it in the dataloader) - this arg does not affect speed and we could remove it"
|
226 |
]
|
227 |
},
|
228 |
{
|
@@ -235,7 +161,7 @@
|
|
235 |
"outputs": [],
|
236 |
"source": [
|
237 |
"%%time\n",
|
238 |
-
"
|
239 |
]
|
240 |
},
|
241 |
{
|
@@ -251,54 +177,50 @@
|
|
251 |
{
|
252 |
"cell_type": "code",
|
253 |
"execution_count": null,
|
254 |
-
"id": "
|
255 |
"metadata": {},
|
256 |
"outputs": [],
|
257 |
"source": [
|
258 |
-
"
|
259 |
-
]
|
260 |
-
},
|
261 |
-
{
|
262 |
-
"cell_type": "markdown",
|
263 |
-
"id": "44d50a51",
|
264 |
-
"metadata": {},
|
265 |
-
"source": [
|
266 |
-
"### Torch DataLoader"
|
267 |
]
|
268 |
},
|
269 |
{
|
270 |
"cell_type": "code",
|
271 |
"execution_count": null,
|
272 |
-
"id": "
|
273 |
"metadata": {},
|
274 |
"outputs": [],
|
275 |
"source": [
|
276 |
-
"
|
277 |
]
|
278 |
},
|
279 |
{
|
280 |
"cell_type": "markdown",
|
281 |
-
"id": "
|
282 |
"metadata": {},
|
283 |
"source": [
|
284 |
-
"
|
285 |
]
|
286 |
},
|
287 |
{
|
288 |
"cell_type": "code",
|
289 |
"execution_count": null,
|
290 |
-
"id": "
|
291 |
"metadata": {},
|
292 |
"outputs": [],
|
293 |
"source": [
|
294 |
-
"
|
|
|
|
|
295 |
]
|
296 |
},
|
297 |
{
|
298 |
"cell_type": "markdown",
|
299 |
-
"id": "
|
300 |
"metadata": {},
|
301 |
"source": [
|
|
|
|
|
302 |
"We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
|
303 |
]
|
304 |
},
|
@@ -311,7 +233,11 @@
|
|
311 |
},
|
312 |
"outputs": [],
|
313 |
"source": [
|
314 |
-
"
|
|
|
|
|
|
|
|
|
315 |
]
|
316 |
},
|
317 |
{
|
@@ -327,18 +253,7 @@
|
|
327 |
"id": "20357f74",
|
328 |
"metadata": {},
|
329 |
"source": [
|
330 |
-
"Encoding is really simple using `shard` to automatically distribute
|
331 |
-
]
|
332 |
-
},
|
333 |
-
{
|
334 |
-
"cell_type": "code",
|
335 |
-
"execution_count": null,
|
336 |
-
"id": "6686b004",
|
337 |
-
"metadata": {},
|
338 |
-
"outputs": [],
|
339 |
-
"source": [
|
340 |
-
"from flax.training.common_utils import shard\n",
|
341 |
-
"from functools import partial"
|
342 |
]
|
343 |
},
|
344 |
{
|
@@ -348,21 +263,17 @@
|
|
348 |
"metadata": {},
|
349 |
"outputs": [],
|
350 |
"source": [
|
|
|
|
|
|
|
|
|
351 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
352 |
-
"def
|
353 |
" # Not sure if we should `replicate` params, does not seem to have any effect\n",
|
354 |
-
" _, indices =
|
355 |
" return indices"
|
356 |
]
|
357 |
},
|
358 |
-
{
|
359 |
-
"cell_type": "markdown",
|
360 |
-
"id": "14375a41",
|
361 |
-
"metadata": {},
|
362 |
-
"source": [
|
363 |
-
"### Encoding loop"
|
364 |
-
]
|
365 |
-
},
|
366 |
{
|
367 |
"cell_type": "code",
|
368 |
"execution_count": null,
|
@@ -370,49 +281,48 @@
|
|
370 |
"metadata": {},
|
371 |
"outputs": [],
|
372 |
"source": [
|
373 |
-
"import os\n",
|
374 |
"import pandas as pd\n",
|
375 |
"\n",
|
376 |
-
"def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
|
377 |
-
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
378 |
"\n",
|
379 |
-
"
|
380 |
-
"
|
381 |
-
"
|
382 |
-
"
|
383 |
-
"
|
384 |
-
" for
|
385 |
-
"
|
386 |
-
"
|
387 |
-
"
|
388 |
-
"
|
389 |
-
"
|
390 |
-
"\n",
|
391 |
-
"
|
392 |
-
"
|
|
|
|
|
|
|
|
|
393 |
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
|
|
|
|
394 |
"\n",
|
395 |
-
"
|
396 |
-
"
|
397 |
-
"
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
"
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
"metadata": {},
|
413 |
-
"outputs": [],
|
414 |
-
"source": [
|
415 |
-
"save_every = 318"
|
416 |
]
|
417 |
},
|
418 |
{
|
@@ -422,7 +332,7 @@
|
|
422 |
"metadata": {},
|
423 |
"outputs": [],
|
424 |
"source": [
|
425 |
-
"
|
426 |
]
|
427 |
},
|
428 |
{
|
@@ -453,7 +363,7 @@
|
|
453 |
"name": "python",
|
454 |
"nbconvert_exporter": "python",
|
455 |
"pygments_lexer": "ipython3",
|
456 |
-
"version": "3.
|
457 |
}
|
458 |
},
|
459 |
"nbformat": 4,
|
|
|
5 |
"id": "d0b72877",
|
6 |
"metadata": {},
|
7 |
"source": [
|
8 |
+
"# Pre-encoding a dataset for DALLE·mini"
|
9 |
]
|
10 |
},
|
11 |
{
|
|
|
15 |
"source": [
|
16 |
"This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
|
17 |
"\n",
|
18 |
+
"Adapt it to your own dataset and image encoder.\n",
|
19 |
+
"\n",
|
20 |
+
"At the end you should have a dataset of pairs:\n",
|
21 |
+
"* a caption defined as a string\n",
|
22 |
+
"* an encoded image defined as a list of int."
|
23 |
]
|
24 |
},
|
25 |
{
|
|
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
32 |
+
"from tqdm.notebook import tqdm\n",
|
|
|
33 |
"\n",
|
|
|
34 |
"import torchvision.transforms as T\n",
|
|
|
|
|
|
|
35 |
"\n",
|
36 |
"import webdataset as wds\n",
|
37 |
"\n",
|
38 |
"import jax\n",
|
39 |
+
"import braceexpand\n",
|
40 |
+
"from pathlib import Path"
|
41 |
]
|
42 |
},
|
43 |
{
|
|
|
45 |
"id": "c7c4c1e6",
|
46 |
"metadata": {},
|
47 |
"source": [
|
48 |
+
"## Configuration Parameters"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
]
|
50 |
},
|
51 |
{
|
52 |
"cell_type": "code",
|
53 |
+
"execution_count": 3,
|
54 |
"id": "1265dbfe",
|
55 |
"metadata": {},
|
56 |
"outputs": [],
|
57 |
"source": [
|
58 |
+
"shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
|
59 |
+
"encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
"\n",
|
61 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
62 |
+
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
63 |
+
" \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
|
64 |
+
")\n",
|
65 |
"\n",
|
66 |
+
"# good defaults for a TPU v3-8\n",
|
67 |
+
"batch_size = 128 # Per device\n",
|
68 |
+
"num_workers = 8 # For parallel processing\n",
|
69 |
+
"total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
|
70 |
+
"save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
|
71 |
]
|
72 |
},
|
73 |
{
|
74 |
"cell_type": "code",
|
75 |
+
"execution_count": 5,
|
76 |
+
"id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [
|
79 |
+
{
|
80 |
+
"data": {
|
81 |
+
"text/plain": [
|
82 |
+
"['XXX/shard-0000.tar',\n",
|
83 |
+
" 'XXX/shard-0001.tar',\n",
|
84 |
+
" 'XXX/shard-0002.tar',\n",
|
85 |
+
" 'XXX/shard-0003.tar',\n",
|
86 |
+
" 'XXX/shard-0004.tar',\n",
|
87 |
+
" 'XXX/shard-0005.tar',\n",
|
88 |
+
" 'XXX/shard-0006.tar',\n",
|
89 |
+
" 'XXX/shard-0007.tar',\n",
|
90 |
+
" 'XXX/shard-0008.tar']"
|
91 |
+
]
|
92 |
+
},
|
93 |
+
"execution_count": 5,
|
94 |
+
"metadata": {},
|
95 |
+
"output_type": "execute_result"
|
96 |
+
}
|
97 |
+
],
|
98 |
+
"source": [
|
99 |
+
"shards = list(\n",
|
100 |
+
" braceexpand.braceexpand(shards)\n",
|
101 |
+
") # better display for tqdm with known length"
|
102 |
]
|
103 |
},
|
104 |
{
|
105 |
"cell_type": "markdown",
|
106 |
+
"id": "75dba8e2",
|
107 |
"metadata": {},
|
108 |
"source": [
|
109 |
+
"## Load data"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
]
|
111 |
},
|
112 |
{
|
113 |
"cell_type": "markdown",
|
114 |
+
"id": "a1e8fb95",
|
115 |
"metadata": {},
|
116 |
"source": [
|
117 |
+
"We load data using `webdataset`."
|
|
|
|
|
|
|
118 |
]
|
119 |
},
|
120 |
{
|
121 |
"cell_type": "code",
|
122 |
"execution_count": null,
|
123 |
+
"id": "9ef5de9e",
|
124 |
"metadata": {},
|
125 |
"outputs": [],
|
126 |
"source": [
|
127 |
+
"ds = (\n",
|
128 |
+
" wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
|
129 |
+
" .decode(\"rgb\", handler=wds.warn_and_continue)\n",
|
130 |
+
" .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
|
131 |
+
" .batched(total_bs) # load in batch per worker (faster)\n",
|
132 |
+
")"
|
133 |
]
|
134 |
},
|
135 |
{
|
136 |
"cell_type": "markdown",
|
137 |
+
"id": "90981824",
|
138 |
"metadata": {},
|
139 |
"source": [
|
140 |
+
"Note:\n",
|
141 |
+
"* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
|
142 |
+
"* you may need to resize images in your pipeline (with `map_dict` for example), we assume they are already set to 256x256.\n",
|
143 |
+
"* you can also filter out some items using `select`."
|
|
|
144 |
]
|
145 |
},
|
146 |
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"id": "129c377d",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
"metadata": {},
|
|
|
150 |
"source": [
|
151 |
+
"We can now inspect our data."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
]
|
153 |
},
|
154 |
{
|
|
|
161 |
"outputs": [],
|
162 |
"source": [
|
163 |
"%%time\n",
|
164 |
+
"images, captions = next(iter(ds))"
|
165 |
]
|
166 |
},
|
167 |
{
|
|
|
177 |
{
|
178 |
"cell_type": "code",
|
179 |
"execution_count": null,
|
180 |
+
"id": "5acfc4d8",
|
181 |
"metadata": {},
|
182 |
"outputs": [],
|
183 |
"source": [
|
184 |
+
"captions[:10]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
]
|
186 |
},
|
187 |
{
|
188 |
"cell_type": "code",
|
189 |
"execution_count": null,
|
190 |
+
"id": "c24693c0",
|
191 |
"metadata": {},
|
192 |
"outputs": [],
|
193 |
"source": [
|
194 |
+
"T.ToPILImage()(images[0].permute(2, 0, 1))"
|
195 |
]
|
196 |
},
|
197 |
{
|
198 |
"cell_type": "markdown",
|
199 |
+
"id": "3059ffb1",
|
200 |
"metadata": {},
|
201 |
"source": [
|
202 |
+
"Finally we create our dataloader."
|
203 |
]
|
204 |
},
|
205 |
{
|
206 |
"cell_type": "code",
|
207 |
"execution_count": null,
|
208 |
+
"id": "c227c551",
|
209 |
"metadata": {},
|
210 |
"outputs": [],
|
211 |
"source": [
|
212 |
+
"dl = (\n",
|
213 |
+
" wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
|
214 |
+
") # avoid partial batch at the end of each worker"
|
215 |
]
|
216 |
},
|
217 |
{
|
218 |
"cell_type": "markdown",
|
219 |
+
"id": "a354472b",
|
220 |
"metadata": {},
|
221 |
"source": [
|
222 |
+
"## Image encoder\n",
|
223 |
+
"\n",
|
224 |
"We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
|
225 |
]
|
226 |
},
|
|
|
233 |
},
|
234 |
"outputs": [],
|
235 |
"source": [
|
236 |
+
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
237 |
+
"from flax.jax_utils import replicate\n",
|
238 |
+
"\n",
|
239 |
+
"vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
|
240 |
+
"vqgan_params = replicate(vqgan.params)"
|
241 |
]
|
242 |
},
|
243 |
{
|
|
|
253 |
"id": "20357f74",
|
254 |
"metadata": {},
|
255 |
"source": [
|
256 |
+
"Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
]
|
258 |
},
|
259 |
{
|
|
|
263 |
"metadata": {},
|
264 |
"outputs": [],
|
265 |
"source": [
|
266 |
+
"from flax.training.common_utils import shard\n",
|
267 |
+
"from functools import partial\n",
|
268 |
+
"\n",
|
269 |
+
"\n",
|
270 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
271 |
+
"def p_encode(batch, params):\n",
|
272 |
" # Not sure if we should `replicate` params, does not seem to have any effect\n",
|
273 |
+
" _, indices = vqgan.encode(batch, params=params)\n",
|
274 |
" return indices"
|
275 |
]
|
276 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
{
|
278 |
"cell_type": "code",
|
279 |
"execution_count": null,
|
|
|
281 |
"metadata": {},
|
282 |
"outputs": [],
|
283 |
"source": [
|
|
|
284 |
"import pandas as pd\n",
|
285 |
"\n",
|
|
|
|
|
286 |
"\n",
|
287 |
+
"def encode_dataset(dataloader, output_dir, save_frequency):\n",
|
288 |
+
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
289 |
+
" all_captions = []\n",
|
290 |
+
" all_encoding = []\n",
|
291 |
+
" n_file = 1\n",
|
292 |
+
" for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
|
293 |
+
" images = images.numpy()\n",
|
294 |
+
" n = len(images) // 8 * 8\n",
|
295 |
+
" if n != len(images):\n",
|
296 |
+
" # get the max number of images we can (multiple of 8)\n",
|
297 |
+
" print(f\"Different sizes {n} vs {len(images)}\")\n",
|
298 |
+
" images = images[:n]\n",
|
299 |
+
" captions = captions[:n]\n",
|
300 |
+
" if not len(captions):\n",
|
301 |
+
" print(f\"No images/captions in batch...\")\n",
|
302 |
+
" continue\n",
|
303 |
+
" images = shard(images)\n",
|
304 |
+
" encoded = p_encode(images, vqgan_params)\n",
|
305 |
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
306 |
+
" all_captions.extend(captions)\n",
|
307 |
+
" all_encoding.extend(encoded.tolist())\n",
|
308 |
"\n",
|
309 |
+
" # save files\n",
|
310 |
+
" if (idx + 1) % save_frequency == 0:\n",
|
311 |
+
" print(f\"Saving file {n_file}\")\n",
|
312 |
+
" batch_df = pd.DataFrame.from_dict(\n",
|
313 |
+
" {\"caption\": all_captions, \"encoding\": all_encoding}\n",
|
314 |
+
" )\n",
|
315 |
+
" batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
|
316 |
+
" all_captions = []\n",
|
317 |
+
" all_encoding = []\n",
|
318 |
+
" n_file += 1\n",
|
319 |
+
"\n",
|
320 |
+
" if len(all_captions):\n",
|
321 |
+
" print(f\"Saving final file {n_file}\")\n",
|
322 |
+
" batch_df = pd.DataFrame.from_dict(\n",
|
323 |
+
" {\"caption\": all_captions, \"encoding\": all_encoding}\n",
|
324 |
+
" )\n",
|
325 |
+
" batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
|
|
|
|
|
|
|
|
|
326 |
]
|
327 |
},
|
328 |
{
|
|
|
332 |
"metadata": {},
|
333 |
"outputs": [],
|
334 |
"source": [
|
335 |
+
"encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
|
336 |
]
|
337 |
},
|
338 |
{
|
|
|
363 |
"name": "python",
|
364 |
"nbconvert_exporter": "python",
|
365 |
"pygments_lexer": "ipython3",
|
366 |
+
"version": "3.9.7"
|
367 |
}
|
368 |
},
|
369 |
"nbformat": 4,
|
tools/inference/inference_pipeline.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
tools/inference/log_inference_samples.ipynb
CHANGED
@@ -31,11 +31,14 @@
|
|
31 |
"metadata": {},
|
32 |
"outputs": [],
|
33 |
"source": [
|
34 |
-
"run_ids = [
|
35 |
-
"ENTITY, PROJECT =
|
36 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID =
|
37 |
-
"
|
38 |
-
"
|
|
|
|
|
|
|
39 |
"add_clip_32 = False"
|
40 |
]
|
41 |
},
|
@@ -63,8 +66,8 @@
|
|
63 |
"num_images = 128\n",
|
64 |
"top_k = 8\n",
|
65 |
"text_normalizer = TextNormalizer()\n",
|
66 |
-
"padding_item =
|
67 |
-
"seed = random.randint(0, 2**32-1)\n",
|
68 |
"key = jax.random.PRNGKey(seed)\n",
|
69 |
"api = wandb.Api()"
|
70 |
]
|
@@ -100,12 +103,15 @@
|
|
100 |
"def p_decode(indices, params):\n",
|
101 |
" return vqgan.decode_code(indices, params=params)\n",
|
102 |
"\n",
|
|
|
103 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
104 |
"def p_clip16(inputs, params):\n",
|
105 |
" logits = clip16(params=params, **inputs).logits_per_image\n",
|
106 |
" return logits\n",
|
107 |
"\n",
|
|
|
108 |
"if add_clip_32:\n",
|
|
|
109 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
110 |
" def p_clip32(inputs, params):\n",
|
111 |
" logits = clip32(params=params, **inputs).logits_per_image\n",
|
@@ -119,13 +125,13 @@
|
|
119 |
"metadata": {},
|
120 |
"outputs": [],
|
121 |
"source": [
|
122 |
-
"with open(
|
123 |
" samples = [l.strip() for l in f.readlines()]\n",
|
124 |
" # make list multiple of batch_size by adding elements\n",
|
125 |
" samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
|
126 |
" samples.extend(samples_to_add)\n",
|
127 |
" # reshape\n",
|
128 |
-
" samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
|
129 |
]
|
130 |
},
|
131 |
{
|
@@ -138,9 +144,17 @@
|
|
138 |
"def get_artifact_versions(run_id, latest_only=False):\n",
|
139 |
" try:\n",
|
140 |
" if latest_only:\n",
|
141 |
-
" return [
|
|
|
|
|
|
|
|
|
142 |
" else:\n",
|
143 |
-
" return api.artifact_versions(
|
|
|
|
|
|
|
|
|
144 |
" except:\n",
|
145 |
" return []"
|
146 |
]
|
@@ -153,7 +167,7 @@
|
|
153 |
"outputs": [],
|
154 |
"source": [
|
155 |
"def get_training_config(run_id):\n",
|
156 |
-
" training_run = api.run(f
|
157 |
" config = training_run.config\n",
|
158 |
" return config"
|
159 |
]
|
@@ -168,8 +182,8 @@
|
|
168 |
"# retrieve inference run details\n",
|
169 |
"def get_last_inference_version(run_id):\n",
|
170 |
" try:\n",
|
171 |
-
" inference_run = api.run(f
|
172 |
-
" return inference_run.summary.get(
|
173 |
" except:\n",
|
174 |
" return None"
|
175 |
]
|
@@ -183,7 +197,6 @@
|
|
183 |
"source": [
|
184 |
"# compile functions - needed only once per run\n",
|
185 |
"def pmap_model_function(model):\n",
|
186 |
-
" \n",
|
187 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
188 |
" def _generate(tokenized_prompt, key, params):\n",
|
189 |
" return model.generate(\n",
|
@@ -195,7 +208,7 @@
|
|
195 |
" top_k=gen_top_k,\n",
|
196 |
" top_p=gen_top_p\n",
|
197 |
" )\n",
|
198 |
-
"
|
199 |
" return _generate"
|
200 |
]
|
201 |
},
|
@@ -222,13 +235,21 @@
|
|
222 |
"training_config = get_training_config(run_id)\n",
|
223 |
"run = None\n",
|
224 |
"p_generate = None\n",
|
225 |
-
"model_files = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
"for artifact in artifact_versions:\n",
|
227 |
-
" print(f
|
228 |
" version = int(artifact.version[1:])\n",
|
229 |
" results16, results32 = [], []\n",
|
230 |
-
" columns = [
|
231 |
-
"
|
232 |
" if latest_only:\n",
|
233 |
" assert last_inference_version is None or version > last_inference_version\n",
|
234 |
" else:\n",
|
@@ -236,14 +257,23 @@
|
|
236 |
" # we should start from v0\n",
|
237 |
" assert version == 0\n",
|
238 |
" elif version <= last_inference_version:\n",
|
239 |
-
" print(
|
|
|
|
|
240 |
" else:\n",
|
241 |
" # check we are logging the correct version\n",
|
242 |
" assert version == last_inference_version + 1\n",
|
243 |
"\n",
|
244 |
" # start/resume corresponding run\n",
|
245 |
" if run is None:\n",
|
246 |
-
" run = wandb.init(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
"\n",
|
248 |
" # work in temporary directory\n",
|
249 |
" with tempfile.TemporaryDirectory() as tmp:\n",
|
@@ -264,64 +294,109 @@
|
|
264 |
"\n",
|
265 |
" # process one batch of captions\n",
|
266 |
" for batch in tqdm(samples):\n",
|
267 |
-
" processed_prompts =
|
|
|
|
|
|
|
|
|
268 |
"\n",
|
269 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
270 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
271 |
-
" tokenized_prompt = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
" tokenized_prompt = shard(tokenized_prompt)\n",
|
273 |
"\n",
|
274 |
" # generate images\n",
|
275 |
" images = []\n",
|
276 |
-
" pbar = tqdm(
|
|
|
|
|
|
|
|
|
277 |
" for i in pbar:\n",
|
278 |
" key, subkey = jax.random.split(key)\n",
|
279 |
-
" encoded_images = p_generate(
|
|
|
|
|
280 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
281 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
282 |
-
" decoded_images = decoded_images.clip(0
|
|
|
|
|
283 |
" for img in decoded_images:\n",
|
284 |
-
" images.append(
|
|
|
|
|
285 |
"\n",
|
286 |
-
" def add_clip_results(results, processor, p_clip, clip_params)
|
287 |
-
" clip_inputs = processor(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
289 |
-
" images_per_prompt_indices = np.asarray(
|
290 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
" clip_inputs = shard(clip_inputs)\n",
|
292 |
" logits = p_clip(clip_inputs, clip_params)\n",
|
293 |
" logits = logits.reshape(-1, num_images)\n",
|
294 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
295 |
" logits = jax.device_get(logits)\n",
|
296 |
" # add to results table\n",
|
297 |
-
" for i, (idx, scores, sample) in enumerate(
|
298 |
-
"
|
|
|
|
|
|
|
299 |
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
300 |
-
" top_images = [
|
|
|
|
|
|
|
301 |
" results.append([sample] + top_images)\n",
|
302 |
-
"
|
303 |
" # get clip scores\n",
|
304 |
-
" pbar.set_description(
|
305 |
" add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
|
306 |
-
"
|
307 |
" # get clip 32 scores\n",
|
308 |
" if add_clip_32:\n",
|
309 |
-
" pbar.set_description(
|
310 |
" add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
|
311 |
"\n",
|
312 |
" pbar.close()\n",
|
313 |
"\n",
|
314 |
-
" \n",
|
315 |
-
"\n",
|
316 |
" # log results\n",
|
317 |
" table = wandb.Table(columns=columns, data=results16)\n",
|
318 |
-
" run.log({
|
319 |
" wandb.finish()\n",
|
320 |
-
"
|
321 |
-
" if add_clip_32
|
322 |
-
" run = wandb.init(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
" table = wandb.Table(columns=columns, data=results32)\n",
|
324 |
-
" run.log({
|
325 |
" wandb.finish()\n",
|
326 |
" run = None # ensure we don't log on this run"
|
327 |
]
|
|
|
31 |
"metadata": {},
|
32 |
"outputs": [],
|
33 |
"source": [
|
34 |
+
"run_ids = [\"63otg87g\"]\n",
|
35 |
+
"ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
|
36 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
37 |
+
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
38 |
+
" \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
|
39 |
+
")\n",
|
40 |
+
"latest_only = True # log only latest or all versions\n",
|
41 |
+
"suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
|
42 |
"add_clip_32 = False"
|
43 |
]
|
44 |
},
|
|
|
66 |
"num_images = 128\n",
|
67 |
"top_k = 8\n",
|
68 |
"text_normalizer = TextNormalizer()\n",
|
69 |
+
"padding_item = \"NONE\"\n",
|
70 |
+
"seed = random.randint(0, 2 ** 32 - 1)\n",
|
71 |
"key = jax.random.PRNGKey(seed)\n",
|
72 |
"api = wandb.Api()"
|
73 |
]
|
|
|
103 |
"def p_decode(indices, params):\n",
|
104 |
" return vqgan.decode_code(indices, params=params)\n",
|
105 |
"\n",
|
106 |
+
"\n",
|
107 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
108 |
"def p_clip16(inputs, params):\n",
|
109 |
" logits = clip16(params=params, **inputs).logits_per_image\n",
|
110 |
" return logits\n",
|
111 |
"\n",
|
112 |
+
"\n",
|
113 |
"if add_clip_32:\n",
|
114 |
+
"\n",
|
115 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
116 |
" def p_clip32(inputs, params):\n",
|
117 |
" logits = clip32(params=params, **inputs).logits_per_image\n",
|
|
|
125 |
"metadata": {},
|
126 |
"outputs": [],
|
127 |
"source": [
|
128 |
+
"with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
|
129 |
" samples = [l.strip() for l in f.readlines()]\n",
|
130 |
" # make list multiple of batch_size by adding elements\n",
|
131 |
" samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
|
132 |
" samples.extend(samples_to_add)\n",
|
133 |
" # reshape\n",
|
134 |
+
" samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
|
135 |
]
|
136 |
},
|
137 |
{
|
|
|
144 |
"def get_artifact_versions(run_id, latest_only=False):\n",
|
145 |
" try:\n",
|
146 |
" if latest_only:\n",
|
147 |
+
" return [\n",
|
148 |
+
" api.artifact(\n",
|
149 |
+
" type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
|
150 |
+
" )\n",
|
151 |
+
" ]\n",
|
152 |
" else:\n",
|
153 |
+
" return api.artifact_versions(\n",
|
154 |
+
" type_name=\"bart_model\",\n",
|
155 |
+
" name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
|
156 |
+
" per_page=10000,\n",
|
157 |
+
" )\n",
|
158 |
" except:\n",
|
159 |
" return []"
|
160 |
]
|
|
|
167 |
"outputs": [],
|
168 |
"source": [
|
169 |
"def get_training_config(run_id):\n",
|
170 |
+
" training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
|
171 |
" config = training_run.config\n",
|
172 |
" return config"
|
173 |
]
|
|
|
182 |
"# retrieve inference run details\n",
|
183 |
"def get_last_inference_version(run_id):\n",
|
184 |
" try:\n",
|
185 |
+
" inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
|
186 |
+
" return inference_run.summary.get(\"version\", None)\n",
|
187 |
" except:\n",
|
188 |
" return None"
|
189 |
]
|
|
|
197 |
"source": [
|
198 |
"# compile functions - needed only once per run\n",
|
199 |
"def pmap_model_function(model):\n",
|
|
|
200 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
201 |
" def _generate(tokenized_prompt, key, params):\n",
|
202 |
" return model.generate(\n",
|
|
|
208 |
" top_k=gen_top_k,\n",
|
209 |
" top_p=gen_top_p\n",
|
210 |
" )\n",
|
211 |
+
"\n",
|
212 |
" return _generate"
|
213 |
]
|
214 |
},
|
|
|
235 |
"training_config = get_training_config(run_id)\n",
|
236 |
"run = None\n",
|
237 |
"p_generate = None\n",
|
238 |
+
"model_files = [\n",
|
239 |
+
" \"config.json\",\n",
|
240 |
+
" \"flax_model.msgpack\",\n",
|
241 |
+
" \"merges.txt\",\n",
|
242 |
+
" \"special_tokens_map.json\",\n",
|
243 |
+
" \"tokenizer.json\",\n",
|
244 |
+
" \"tokenizer_config.json\",\n",
|
245 |
+
" \"vocab.json\",\n",
|
246 |
+
"]\n",
|
247 |
"for artifact in artifact_versions:\n",
|
248 |
+
" print(f\"Processing artifact: {artifact.name}\")\n",
|
249 |
" version = int(artifact.version[1:])\n",
|
250 |
" results16, results32 = [], []\n",
|
251 |
+
" columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
|
252 |
+
"\n",
|
253 |
" if latest_only:\n",
|
254 |
" assert last_inference_version is None or version > last_inference_version\n",
|
255 |
" else:\n",
|
|
|
257 |
" # we should start from v0\n",
|
258 |
" assert version == 0\n",
|
259 |
" elif version <= last_inference_version:\n",
|
260 |
+
" print(\n",
|
261 |
+
" f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
|
262 |
+
" )\n",
|
263 |
" else:\n",
|
264 |
" # check we are logging the correct version\n",
|
265 |
" assert version == last_inference_version + 1\n",
|
266 |
"\n",
|
267 |
" # start/resume corresponding run\n",
|
268 |
" if run is None:\n",
|
269 |
+
" run = wandb.init(\n",
|
270 |
+
" job_type=\"inference\",\n",
|
271 |
+
" entity=\"dalle-mini\",\n",
|
272 |
+
" project=\"dalle-mini\",\n",
|
273 |
+
" config=training_config,\n",
|
274 |
+
" id=f\"{run_id}-clip16{suffix}\",\n",
|
275 |
+
" resume=\"allow\",\n",
|
276 |
+
" )\n",
|
277 |
"\n",
|
278 |
" # work in temporary directory\n",
|
279 |
" with tempfile.TemporaryDirectory() as tmp:\n",
|
|
|
294 |
"\n",
|
295 |
" # process one batch of captions\n",
|
296 |
" for batch in tqdm(samples):\n",
|
297 |
+
" processed_prompts = (\n",
|
298 |
+
" [text_normalizer(x) for x in batch]\n",
|
299 |
+
" if model.config.normalize_text\n",
|
300 |
+
" else list(batch)\n",
|
301 |
+
" )\n",
|
302 |
"\n",
|
303 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
304 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
305 |
+
" tokenized_prompt = tokenizer(\n",
|
306 |
+
" processed_prompts,\n",
|
307 |
+
" return_tensors=\"jax\",\n",
|
308 |
+
" padding=\"max_length\",\n",
|
309 |
+
" truncation=True,\n",
|
310 |
+
" max_length=128,\n",
|
311 |
+
" ).data\n",
|
312 |
" tokenized_prompt = shard(tokenized_prompt)\n",
|
313 |
"\n",
|
314 |
" # generate images\n",
|
315 |
" images = []\n",
|
316 |
+
" pbar = tqdm(\n",
|
317 |
+
" range(num_images // jax.device_count()),\n",
|
318 |
+
" desc=\"Generating Images\",\n",
|
319 |
+
" leave=True,\n",
|
320 |
+
" )\n",
|
321 |
" for i in pbar:\n",
|
322 |
" key, subkey = jax.random.split(key)\n",
|
323 |
+
" encoded_images = p_generate(\n",
|
324 |
+
" tokenized_prompt, shard_prng_key(subkey), model_params\n",
|
325 |
+
" )\n",
|
326 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
327 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
328 |
+
" decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
|
329 |
+
" (-1, 256, 256, 3)\n",
|
330 |
+
" )\n",
|
331 |
" for img in decoded_images:\n",
|
332 |
+
" images.append(\n",
|
333 |
+
" Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
|
334 |
+
" )\n",
|
335 |
"\n",
|
336 |
+
" def add_clip_results(results, processor, p_clip, clip_params):\n",
|
337 |
+
" clip_inputs = processor(\n",
|
338 |
+
" text=batch,\n",
|
339 |
+
" images=images,\n",
|
340 |
+
" return_tensors=\"np\",\n",
|
341 |
+
" padding=\"max_length\",\n",
|
342 |
+
" max_length=77,\n",
|
343 |
+
" truncation=True,\n",
|
344 |
+
" ).data\n",
|
345 |
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
346 |
+
" images_per_prompt_indices = np.asarray(\n",
|
347 |
+
" range(0, len(images), batch_size)\n",
|
348 |
+
" )\n",
|
349 |
+
" clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
|
350 |
+
" list(\n",
|
351 |
+
" clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
|
352 |
+
" for i in range(batch_size)\n",
|
353 |
+
" )\n",
|
354 |
+
" )\n",
|
355 |
" clip_inputs = shard(clip_inputs)\n",
|
356 |
" logits = p_clip(clip_inputs, clip_params)\n",
|
357 |
" logits = logits.reshape(-1, num_images)\n",
|
358 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
359 |
" logits = jax.device_get(logits)\n",
|
360 |
" # add to results table\n",
|
361 |
+
" for i, (idx, scores, sample) in enumerate(\n",
|
362 |
+
" zip(top_scores, logits, batch)\n",
|
363 |
+
" ):\n",
|
364 |
+
" if sample == padding_item:\n",
|
365 |
+
" continue\n",
|
366 |
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
367 |
+
" top_images = [\n",
|
368 |
+
" wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
|
369 |
+
" for x in idx\n",
|
370 |
+
" ]\n",
|
371 |
" results.append([sample] + top_images)\n",
|
372 |
+
"\n",
|
373 |
" # get clip scores\n",
|
374 |
+
" pbar.set_description(\"Calculating CLIP 16 scores\")\n",
|
375 |
" add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
|
376 |
+
"\n",
|
377 |
" # get clip 32 scores\n",
|
378 |
" if add_clip_32:\n",
|
379 |
+
" pbar.set_description(\"Calculating CLIP 32 scores\")\n",
|
380 |
" add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
|
381 |
"\n",
|
382 |
" pbar.close()\n",
|
383 |
"\n",
|
|
|
|
|
384 |
" # log results\n",
|
385 |
" table = wandb.Table(columns=columns, data=results16)\n",
|
386 |
+
" run.log({\"Samples\": table, \"version\": version})\n",
|
387 |
" wandb.finish()\n",
|
388 |
+
"\n",
|
389 |
+
" if add_clip_32:\n",
|
390 |
+
" run = wandb.init(\n",
|
391 |
+
" job_type=\"inference\",\n",
|
392 |
+
" entity=\"dalle-mini\",\n",
|
393 |
+
" project=\"dalle-mini\",\n",
|
394 |
+
" config=training_config,\n",
|
395 |
+
" id=f\"{run_id}-clip32{suffix}\",\n",
|
396 |
+
" resume=\"allow\",\n",
|
397 |
+
" )\n",
|
398 |
" table = wandb.Table(columns=columns, data=results32)\n",
|
399 |
+
" run.log({\"Samples\": table, \"version\": version})\n",
|
400 |
" wandb.finish()\n",
|
401 |
" run = None # ensure we don't log on this run"
|
402 |
]
|
{dev/seq2seq → tools/train}/sweep.yaml
RENAMED
File without changes
|
dev/seq2seq/run_seq2seq_flax.py → tools/train/train.py
RENAMED
@@ -18,37 +18,31 @@ Fine-tuning the library models for seq2seq, text to image.
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
21 |
-
import
|
22 |
import logging
|
|
|
23 |
import sys
|
24 |
import time
|
25 |
-
from dataclasses import dataclass, field
|
26 |
from pathlib import Path
|
27 |
from typing import Callable, Optional
|
28 |
-
import json
|
29 |
|
30 |
import datasets
|
31 |
-
from datasets import Dataset
|
32 |
-
from tqdm import tqdm
|
33 |
-
from dataclasses import asdict
|
34 |
-
|
35 |
import jax
|
36 |
import jax.numpy as jnp
|
37 |
import optax
|
38 |
import transformers
|
|
|
|
|
39 |
from flax import jax_utils, traverse_util
|
40 |
-
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.jax_utils import unreplicate
|
|
|
42 |
from flax.training import train_state
|
43 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
44 |
-
from
|
45 |
-
|
46 |
-
HfArgumentParser,
|
47 |
-
)
|
48 |
from transformers.models.bart.modeling_flax_bart import BartConfig
|
49 |
|
50 |
-
import wandb
|
51 |
-
|
52 |
from dalle_mini.data import Dataset
|
53 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
54 |
|
@@ -797,7 +791,7 @@ def main():
|
|
797 |
|
798 |
# init variables
|
799 |
last_time = time.perf_counter()
|
800 |
-
|
801 |
|
802 |
for epoch in epochs:
|
803 |
state.replace(epoch=jax_utils.replicate(epoch))
|
@@ -821,20 +815,20 @@ def main():
|
|
821 |
last_time = new_time
|
822 |
|
823 |
# train step
|
824 |
-
state,
|
825 |
state, batch, jax_utils.replicate(delta_time)
|
826 |
)
|
827 |
step = unreplicate(state.step)
|
828 |
|
829 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
830 |
# log metrics
|
831 |
-
|
832 |
# log state parameters
|
833 |
state_dict = {
|
834 |
k.split("_")[-1]: unreplicate(getattr(state, k))
|
835 |
for k in ["epoch", "train_time", "train_samples"]
|
836 |
}
|
837 |
-
wandb_log(state_dict, step=step, prefix="train")
|
838 |
|
839 |
eval_metrics = None
|
840 |
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
@@ -844,12 +838,12 @@ def main():
|
|
844 |
run_save_model(state, eval_metrics)
|
845 |
|
846 |
# log final train metrics
|
847 |
-
if
|
848 |
-
|
849 |
-
wandb_log(
|
850 |
|
851 |
epochs.write(
|
852 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {
|
853 |
)
|
854 |
|
855 |
# Final evaluation
|
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
21 |
+
import json
|
22 |
import logging
|
23 |
+
import os
|
24 |
import sys
|
25 |
import time
|
26 |
+
from dataclasses import asdict, dataclass, field
|
27 |
from pathlib import Path
|
28 |
from typing import Callable, Optional
|
|
|
29 |
|
30 |
import datasets
|
|
|
|
|
|
|
|
|
31 |
import jax
|
32 |
import jax.numpy as jnp
|
33 |
import optax
|
34 |
import transformers
|
35 |
+
import wandb
|
36 |
+
from datasets import Dataset
|
37 |
from flax import jax_utils, traverse_util
|
|
|
38 |
from flax.jax_utils import unreplicate
|
39 |
+
from flax.serialization import from_bytes, to_bytes
|
40 |
from flax.training import train_state
|
41 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
42 |
+
from tqdm import tqdm
|
43 |
+
from transformers import AutoTokenizer, HfArgumentParser
|
|
|
|
|
44 |
from transformers.models.bart.modeling_flax_bart import BartConfig
|
45 |
|
|
|
|
|
46 |
from dalle_mini.data import Dataset
|
47 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
48 |
|
|
|
791 |
|
792 |
# init variables
|
793 |
last_time = time.perf_counter()
|
794 |
+
train_metrics = None
|
795 |
|
796 |
for epoch in epochs:
|
797 |
state.replace(epoch=jax_utils.replicate(epoch))
|
|
|
815 |
last_time = new_time
|
816 |
|
817 |
# train step
|
818 |
+
state, train_metrics = p_train_step(
|
819 |
state, batch, jax_utils.replicate(delta_time)
|
820 |
)
|
821 |
step = unreplicate(state.step)
|
822 |
|
823 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
824 |
# log metrics
|
825 |
+
metrics = unreplicate(train_metrics)
|
826 |
# log state parameters
|
827 |
state_dict = {
|
828 |
k.split("_")[-1]: unreplicate(getattr(state, k))
|
829 |
for k in ["epoch", "train_time", "train_samples"]
|
830 |
}
|
831 |
+
wandb_log({**metrics, **state_dict}, step=step, prefix="train")
|
832 |
|
833 |
eval_metrics = None
|
834 |
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
|
|
838 |
run_save_model(state, eval_metrics)
|
839 |
|
840 |
# log final train metrics
|
841 |
+
if train_metrics is not None:
|
842 |
+
train_metrics = unreplicate(train_metrics)
|
843 |
+
wandb_log(train_metrics, step=step, prefix="train")
|
844 |
|
845 |
epochs.write(
|
846 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
|
847 |
)
|
848 |
|
849 |
# Final evaluation
|