cakiki committed on
Commit
99cea46
·
1 Parent(s): 125e851

Upload 2 files

Browse files
Files changed (2) hide show
  1. dataset.py +110 -0
  2. vae_embeddings.ipynb +276 -0
dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # TODO: Address all TODOs and remove all explanatory comments
15
+ """TODO: Add a description here."""
16
+
17
+
18
+ import zipfile
19
+ import os
20
+ import datasets
21
+ from PIL import Image
22
+ from io import BytesIO
23
+
24
+ # TODO: The name of the dataset usually matches the script name, with CamelCase instead of snake_case
25
class sdbias(datasets.GeneratorBasedBuilder):
    """Dataset builder that reads JPEG images out of a directory of zip archives.

    The parsing code assumes each archive is named
    ``<adjective>_<profession>.zip`` and that each image path inside an
    archive contains a ``Seed_<int>`` path component. Every ``.jpg`` member
    becomes one example with the fields ``adjective``, ``profession``,
    ``seed`` and ``image``.
    """

    VERSION = datasets.Version("1.1.0")

    # Only one configuration exists; it is also the default.
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="first_domain", version=VERSION, description="This part of my dataset covers a first domain"),
    ]

    DEFAULT_CONFIG_NAME = "first_domain"

    # Fallback location of the zip archives, used when the caller does not
    # pass ``data_dir`` to ``load_dataset``.
    _DEFAULT_DATA_DIR = "/mnt/1da05489-3812-4f15-a6e5-c8d3c57df39e/StableDiffusionBiasExplorer/zipped_images"

    def _info(self):
        """Return the ``DatasetInfo`` describing the columns of this dataset."""
        # Defined unconditionally: there is a single configuration, and binding
        # ``features`` only inside an ``if`` on the config name would raise
        # UnboundLocalError for any other name.
        features = datasets.Features(
            {
                "adjective": datasets.Value("string"),
                "profession": datasets.Value("string"),
                "seed": datasets.Value("int32"),
                "image": datasets.Image(),
            }
        )
        return datasets.DatasetInfo(
            # TODO: the description, homepage, license and citation below are
            # still placeholders.
            description="bla",
            features=features,
            homepage="bla",
            license="bla",
            citation="bli",
        )

    def _split_generators(self, dl_manager):
        """Define the single ``train`` split.

        Nothing is downloaded; ``dl_manager`` is unused. The archives are read
        from ``self.config.data_dir`` when it is set (e.g. via
        ``load_dataset(..., data_dir=...)``) and from ``_DEFAULT_DATA_DIR``
        otherwise.
        """
        data_dir = getattr(self.config, "data_dir", None) or self._DEFAULT_DATA_DIR
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs are passed to _generate_examples.
                gen_kwargs={
                    "filepath": data_dir,
                    "split": "train",
                },
            ),
        ]

    def _generate_examples(self, filepath, split):
        """Yield ``(key, example)`` pairs for every ``.jpg`` in every archive.

        Args:
            filepath: Directory containing the ``<adjective>_<profession>.zip``
                archives.
            split: Name of the split being generated (unused; there is only
                one split).

        Yields:
            A running integer key and a dict with ``adjective``,
            ``profession``, ``seed`` and ``image`` (an open ``PIL.Image``).

        Raises:
            ValueError: If the text following ``Seed_`` in a member path is
                not an integer.
        """
        extension = ".zip"
        key = 0
        # Sort so that keys and example order are reproducible; os.listdir
        # returns entries in arbitrary order.
        for zip_name in sorted(os.listdir(filepath)):
            if not zip_name.endswith(extension):
                # Ignore stray files/directories instead of failing to open
                # them as zip archives.
                continue
            # Strip the extension once, from the end only, then split on the
            # first underscore: "<adjective>_<profession>".
            name_parts = zip_name[: -len(extension)].split("_", 1)
            adjective = name_parts[0]
            profession = name_parts[-1]
            with zipfile.ZipFile(os.path.join(filepath, zip_name), "r") as zf:
                for member in zf.infolist():
                    if not member.filename.endswith(".jpg"):
                        continue
                    jpg_content = BytesIO(zf.read(member))
                    # The seed is the path component that follows "Seed_".
                    seed = int(member.filename.split("Seed_")[-1].split("/")[0])
                    with Image.open(jpg_content) as image:
                        yield key, {
                            "adjective": adjective,
                            "profession": profession,
                            "seed": seed,
                            "image": image,
                        }
                    key += 1
110
+
vae_embeddings.ipynb ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "873b1354-b85f-4c5b-9163-95190f07b39a",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import zipfile\n",
12
+ "from PIL import Image\n",
13
+ "from io import BytesIO\n",
14
+ "import numpy as np\n",
15
+ "from datasets import load_dataset\n",
16
+ "import torch\n",
17
+ "from diffusers import AutoencoderKL, UNet2DModel, UNet2DConditionModel\n",
18
+ "import pickle"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "35949720-3e01-43b0-8487-a1b2131d5a9e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "def preprocess_image(image):\n",
29
+ " w, h = image.size\n",
30
+ " w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32\n",
31
+ " image = image.resize((w, h), resample=Image.Resampling.LANCZOS)\n",
32
+ " image = np.array(image).astype(np.float32) / 255.0\n",
33
+ " image = image[None].transpose(0, 3, 1, 2)\n",
34
+ " return 2.0 * image - 1.0\n",
35
+ "\n",
36
+ "def vae_embedding(preprocessed, num_samples=5, device=\"cuda\"):\n",
37
+ " with torch.no_grad():\n",
38
+ " processed_image = preprocessed.to(device=device)\n",
39
+ " latent_dist = vae.encode(processed_image).latent_dist\n",
40
+ " t = [0.18215*latent_dist.sample().to(\"cpu\").squeeze() for i in range(num_samples)] # sample num_samples latent vecs\n",
41
+ " t = torch.stack(t) # stack them\n",
42
+ " return torch.mean(t, axis=0).numpy() #average them. output shape: (4,64,64)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 3,
48
+ "id": "6ebd9d84-98f7-4883-ac4b-0ec875b86911",
49
+ "metadata": {
50
+ "tags": []
51
+ },
52
+ "outputs": [
53
+ {
54
+ "name": "stderr",
55
+ "output_type": "stream",
56
+ "text": [
57
+ "Using custom data configuration SDbiaseval--dataset-cc8e38e46c1acd54\n",
58
+ "Found cached dataset parquet (/mnt/1da05489-3812-4f15-a6e5-c8d3c57df39e/cache/huggingface/SDbiaseval___parquet/SDbiaseval--dataset-cc8e38e46c1acd54/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
59
+ ]
60
+ },
61
+ {
62
+ "data": {
63
+ "application/vnd.jupyter.widget-view+json": {
64
+ "model_id": "f184861d2e2749c9b7c1c1ea3910be27",
65
+ "version_major": 2,
66
+ "version_minor": 0
67
+ },
68
+ "text/plain": [
69
+ " 0%| | 0/1 [00:00<?, ?it/s]"
70
+ ]
71
+ },
72
+ "metadata": {},
73
+ "output_type": "display_data"
74
+ },
75
+ {
76
+ "name": "stdout",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "CPU times: user 196 ms, sys: 23.3 ms, total: 219 ms\n",
80
+ "Wall time: 2.51 s\n"
81
+ ]
82
+ }
83
+ ],
84
+ "source": [
85
+ "%%time\n",
86
+ "# dset = load_dataset(\"./dataset.py\", ignore_verifications=True) This uses the loading script and loads data from the zipped folders\n",
87
+ "dset = load_dataset(\"SDbiaseval/dataset\")\n",
88
+ "ds = dset[\"train\"]"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 4,
94
+ "id": "fd832e2b-6ced-43ca-a4ca-fd54f523d22e",
95
+ "metadata": {
96
+ "tags": []
97
+ },
98
+ "outputs": [],
99
+ "source": [
100
+ "vae = AutoencoderKL.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"vae\");\n",
101
+ "vae.eval()\n",
102
+ "vae.to(\"cuda\");"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 5,
108
+ "id": "b2af2692-a372-4b96-8250-8c83c122457d",
109
+ "metadata": {},
110
+ "outputs": [
111
+ {
112
+ "name": "stdout",
113
+ "output_type": "stream",
114
+ "text": [
115
+ "19554 batches of 16. Last batch of size 15.\n"
116
+ ]
117
+ }
118
+ ],
119
+ "source": [
120
+ "ix = np.arange(len(ds))\n",
121
+ "np.random.shuffle(ix)\n",
122
+ "batch_size = 16\n",
123
+ "batche_indices = np.array_split(ix, np.ceil(len(ix)/batch_size))\n",
124
+ "print(f\"{len(batche_indices)} batches of {batch_size}. Last batch of size {len(batche_indices[-1])}.\")"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": 15,
130
+ "id": "8a54fdf1-f0e5-487e-b53d-afc8dbcc989c",
131
+ "metadata": {},
132
+ "outputs": [
133
+ {
134
+ "name": "stdout",
135
+ "output_type": "stream",
136
+ "text": [
137
+ "CPU times: user 9h 52min 30s, sys: 2min 25s, total: 9h 54min 55s\n",
138
+ "Wall time: 7h 54min 48s\n"
139
+ ]
140
+ }
141
+ ],
142
+ "source": [
143
+ "%%time\n",
144
+ "embs = []\n",
145
+ "for i in batche_indices:\n",
146
+ " imx = ds.select(i)[\"image\"]\n",
147
+ " preprocessed = np.concatenate([preprocess_image(im) for im in imx])\n",
148
+ " emb = vae_embedding(torch.from_numpy(preprocessed), num_samples=10)\n",
149
+ " embs.append(emb)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 16,
155
+ "id": "06d9346c-912f-4e24-a0ff-d5386c1780a1",
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": [
159
+ "with open('embs.pkl', 'wb') as f:\n",
160
+ " pickle.dump(embs, f)"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "id": "3d0cbe87-dfb2-4c59-adf5-b4d015e2d441",
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": [
170
+ "embeddings = np.concatenate(embs)"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 4,
176
+ "id": "a6e826a9-93e0-4298-813d-9c42d139ff96",
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "with open(\"embs.pkl\", \"rb\") as f:\n",
181
+ " embeddings = pickle.load(f)"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 5,
187
+ "id": "0783bb60-5439-4a62-a4ac-15198688b331",
188
+ "metadata": {},
189
+ "outputs": [
190
+ {
191
+ "name": "stdout",
192
+ "output_type": "stream",
193
+ "text": [
194
+ "CPU times: user 3.82 s, sys: 4.34 s, total: 8.16 s\n",
195
+ "Wall time: 8.2 s\n"
196
+ ]
197
+ }
198
+ ],
199
+ "source": [
200
+ "%%time\n",
201
+ "embeddings = np.concatenate(embeddings)"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": 6,
207
+ "id": "50369f37-a4f1-4a7c-89dd-b4ef9a8ebf8b",
208
+ "metadata": {},
209
+ "outputs": [
210
+ {
211
+ "data": {
212
+ "text/plain": [
213
+ "(312860, 4, 64, 64)"
214
+ ]
215
+ },
216
+ "execution_count": 6,
217
+ "metadata": {},
218
+ "output_type": "execute_result"
219
+ }
220
+ ],
221
+ "source": [
222
+ "embeddings.shape"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": 7,
228
+ "id": "93f1ea7b-cbcd-49c3-a7c7-4ea26012f9b3",
229
+ "metadata": {},
230
+ "outputs": [
231
+ {
232
+ "name": "stdout",
233
+ "output_type": "stream",
234
+ "text": [
235
+ "CPU times: user 0 ns, sys: 10.3 s, total: 10.3 s\n",
236
+ "Wall time: 10.3 s\n"
237
+ ]
238
+ }
239
+ ],
240
+ "source": [
241
+ "%%time\n",
242
+ "with open('vae_embeddings.npy', 'wb') as f:\n",
243
+ " np.save(f, embeddings)"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "id": "2b316682-f5cc-44d7-a8ed-f1da9b6c3089",
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": []
253
+ }
254
+ ],
255
+ "metadata": {
256
+ "kernelspec": {
257
+ "display_name": "Python 3",
258
+ "language": "python",
259
+ "name": "python3"
260
+ },
261
+ "language_info": {
262
+ "codemirror_mode": {
263
+ "name": "ipython",
264
+ "version": 3
265
+ },
266
+ "file_extension": ".py",
267
+ "mimetype": "text/x-python",
268
+ "name": "python",
269
+ "nbconvert_exporter": "python",
270
+ "pygments_lexer": "ipython3",
271
+ "version": "3.9.5"
272
+ }
273
+ },
274
+ "nbformat": 4,
275
+ "nbformat_minor": 5
276
+ }