Eeman Majumder committed
Commit 14f7e5e · 1 Parent(s): 0965163
.streamlit/secrets.toml ADDED
@@ -0,0 +1 @@
1
+ AUTH_KEY='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'
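The Streamlit app added further down in this commit reads this value at runtime: Streamlit automatically loads .streamlit/secrets.toml and exposes its keys through st.secrets. A minimal sketch of that lookup, mirroring what app.py does below (variable names here are illustrative):

import streamlit as st

# "AUTH_KEY" is the key defined in .streamlit/secrets.toml above
auth_token = st.secrets["AUTH_KEY"]
# app.py passes this token as use_auth_token when loading the Stable Diffusion weights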
DALL·E_mini_Inference_pipeline .ipynb ADDED
@@ -0,0 +1,597 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "118UKH5bWCGa"
7
+ },
8
+ "source": [
9
+ "# DALL·E mini - Inference pipeline\n",
10
+ "\n",
11
+ "*Generate images from a text prompt*\n",
12
+ "\n",
13
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
14
+ "\n",
15
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
16
+ "\n",
17
+ "Just want to play? Use directly [the app](https://www.craiyon.com/).\n",
18
+ "\n",
19
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "dS8LbaonYm3a"
26
+ },
27
+ "source": [
28
+ "## 🛠️ Installation and set-up"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 1,
34
+ "metadata": {
35
+ "colab": {
36
+ "base_uri": "https://localhost:8080/"
37
+ },
38
+ "id": "uzjAM2GBYpZX",
39
+ "outputId": "9042b53c-1260-4ae6-ff54-be878c99d505"
40
+ },
41
+ "outputs": [
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
47
+ "tensorflow-metal 0.5.0 requires six~=1.15.0, but you have six 1.16.0 which is incompatible.\u001b[0m\u001b[31m\n",
48
+ "\u001b[0m"
49
+ ]
50
+ }
51
+ ],
52
+ "source": [
53
+ "# Install required libraries\n",
54
+ "!pip install -q dalle-mini\n",
55
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {
61
+ "id": "ozHzTkyv8cqU"
62
+ },
63
+ "source": [
64
+ "We load required models:\n",
65
+ "* DALL·E mini for text to encoded images\n",
66
+ "* VQGAN for decoding images\n",
67
+ "* CLIP for scoring predictions"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 2,
73
+ "metadata": {
74
+ "id": "K6CxW2o42f-w"
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "# Model references\n",
79
+ "\n",
80
+ "# dalle-mega\n",
81
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
82
+ "DALLE_COMMIT_ID = None\n",
83
+ "\n",
84
+ "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
85
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
86
+ "\n",
87
+ "# VQGAN model\n",
88
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
89
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 3,
95
+ "metadata": {
96
+ "colab": {
97
+ "base_uri": "https://localhost:8080/"
98
+ },
99
+ "id": "Yv-aR3t4Oe5v",
100
+ "outputId": "850b9a43-2506-432f-ae8e-b8b2598e4a98"
101
+ },
102
+ "outputs": [
103
+ {
104
+ "data": {
105
+ "text/plain": [
106
+ "1"
107
+ ]
108
+ },
109
+ "execution_count": 3,
110
+ "metadata": {},
111
+ "output_type": "execute_result"
112
+ }
113
+ ],
114
+ "source": [
115
+ "import jax\n",
116
+ "import jax.numpy as jnp\n",
117
+ "\n",
118
+ "# check how many devices are available\n",
119
+ "jax.local_device_count()"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 1,
125
+ "metadata": {
126
+ "colab": {
127
+ "base_uri": "https://localhost:8080/",
128
+ "height": 240
129
+ },
130
+ "id": "92zYmvsQ38vL",
131
+ "outputId": "556dc277-a885-443b-8848-373696f5acc7"
132
+ },
133
+ "outputs": [
134
+ {
135
+ "ename": "NameError",
136
+ "evalue": "ignored",
137
+ "output_type": "error",
138
+ "traceback": [
139
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
140
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
141
+ "\u001b[0;32m<ipython-input-1-4a35db7a446b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# Load dalle-mini\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m model, params = DalleBart.from_pretrained(\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mDALLE_MODEL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrevision\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mDALLE_COMMIT_ID\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat16\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_do_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m )\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
142
+ "\u001b[0;31mNameError\u001b[0m: name 'DALLE_MODEL' is not defined"
143
+ ]
144
+ }
145
+ ],
146
+ "source": [
147
+ "# Load models & tokenizer\n",
148
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
149
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
150
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
151
+ "\n",
152
+ "# Load dalle-mini\n",
153
+ "model, params = DalleBart.from_pretrained(\n",
154
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
155
+ ")\n",
156
+ "\n",
157
+ "# Load VQGAN\n",
158
+ "vqgan, vqgan_params = VQModel.from_pretrained(\n",
159
+ " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
160
+ ")"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "metadata": {
166
+ "id": "o_vH2X1tDtzA"
167
+ },
168
+ "source": [
169
+ "Model parameters are replicated on each device for faster inference."
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "metadata": {
176
+ "id": "wtvLoM48EeVw"
177
+ },
178
+ "outputs": [],
179
+ "source": [
180
+ "from flax.jax_utils import replicate\n",
181
+ "\n",
182
+ "params = replicate(params)\n",
183
+ "vqgan_params = replicate(vqgan_params)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "metadata": {
189
+ "id": "0A9AHQIgZ_qw"
190
+ },
191
+ "source": [
192
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {
199
+ "id": "sOtoOmYsSYPz"
200
+ },
201
+ "outputs": [],
202
+ "source": [
203
+ "from functools import partial\n",
204
+ "\n",
205
+ "# model inference\n",
206
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
207
+ "def p_generate(\n",
208
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
209
+ "):\n",
210
+ " return model.generate(\n",
211
+ " **tokenized_prompt,\n",
212
+ " prng_key=key,\n",
213
+ " params=params,\n",
214
+ " top_k=top_k,\n",
215
+ " top_p=top_p,\n",
216
+ " temperature=temperature,\n",
217
+ " condition_scale=condition_scale,\n",
218
+ " )\n",
219
+ "\n",
220
+ "\n",
221
+ "# decode image\n",
222
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
223
+ "def p_decode(indices, params):\n",
224
+ " return vqgan.decode_code(indices, params=params)"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "markdown",
229
+ "metadata": {
230
+ "id": "HmVN6IBwapBA"
231
+ },
232
+ "source": [
233
+ "Keys are passed to the model on each device to generate unique inference per device."
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {
240
+ "id": "4CTXmlUkThhX"
241
+ },
242
+ "outputs": [],
243
+ "source": [
244
+ "import random\n",
245
+ "\n",
246
+ "# create a random key\n",
247
+ "seed = random.randint(0, 2**32 - 1)\n",
248
+ "key = jax.random.PRNGKey(seed)"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "metadata": {
254
+ "id": "BrnVyCo81pij"
255
+ },
256
+ "source": [
257
+ "## 🖍 Text Prompt"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "markdown",
262
+ "metadata": {
263
+ "id": "rsmj0Aj5OQox"
264
+ },
265
+ "source": [
266
+ "Our model requires processing prompts."
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {
273
+ "id": "YjjhUychOVxm"
274
+ },
275
+ "outputs": [],
276
+ "source": [
277
+ "from dalle_mini import DalleBartProcessor\n",
278
+ "\n",
279
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "markdown",
284
+ "metadata": {
285
+ "id": "BQ7fymSPyvF_"
286
+ },
287
+ "source": [
288
+ "Let's define some text prompts."
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "metadata": {
295
+ "id": "x_0vI9ge1oKr"
296
+ },
297
+ "outputs": [],
298
+ "source": [
299
+ "prompts = [\n",
300
+ " \"sunset over a lake in the mountains\",\n",
301
+ " \"the Eiffel tower landing on the moon\",\n",
302
+ "]"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "markdown",
307
+ "metadata": {
308
+ "id": "XlZUG3SCLnGE"
309
+ },
310
+ "source": [
311
+ "Note: we could use the same prompt multiple times for faster inference."
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "metadata": {
318
+ "id": "VKjEZGjtO49k"
319
+ },
320
+ "outputs": [],
321
+ "source": [
322
+ "tokenized_prompts = processor(prompts)"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {
328
+ "id": "-CEJBnuJOe5z"
329
+ },
330
+ "source": [
331
+ "Finally we replicate the prompts onto each device."
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {
338
+ "id": "lQePgju5Oe5z"
339
+ },
340
+ "outputs": [],
341
+ "source": [
342
+ "tokenized_prompt = replicate(tokenized_prompts)"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "markdown",
347
+ "metadata": {
348
+ "id": "phQ9bhjRkgAZ"
349
+ },
350
+ "source": [
351
+ "## 🎨 Generate images\n",
352
+ "\n",
353
+ "We generate images using dalle-mini model and decode them with the VQGAN."
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "metadata": {
360
+ "id": "d0wVkXpKqnHA"
361
+ },
362
+ "outputs": [],
363
+ "source": [
364
+ "# number of predictions per prompt\n",
365
+ "n_predictions = 8\n",
366
+ "\n",
367
+ "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
368
+ "gen_top_k = None\n",
369
+ "gen_top_p = None\n",
370
+ "temperature = None\n",
371
+ "cond_scale = 10.0"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "metadata": {
378
+ "id": "SDjEx9JxR3v8"
379
+ },
380
+ "outputs": [],
381
+ "source": [
382
+ "from flax.training.common_utils import shard_prng_key\n",
383
+ "import numpy as np\n",
384
+ "from PIL import Image\n",
385
+ "from tqdm.notebook import trange\n",
386
+ "\n",
387
+ "print(f\"Prompts: {prompts}\\n\")\n",
388
+ "# generate images\n",
389
+ "images = []\n",
390
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
391
+ " # get a new key\n",
392
+ " key, subkey = jax.random.split(key)\n",
393
+ " # generate images\n",
394
+ " encoded_images = p_generate(\n",
395
+ " tokenized_prompt,\n",
396
+ " shard_prng_key(subkey),\n",
397
+ " params,\n",
398
+ " gen_top_k,\n",
399
+ " gen_top_p,\n",
400
+ " temperature,\n",
401
+ " cond_scale,\n",
402
+ " )\n",
403
+ " # remove BOS\n",
404
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
405
+ " # decode images\n",
406
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
407
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
408
+ " for decoded_img in decoded_images:\n",
409
+ " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
410
+ " images.append(img)\n",
411
+ " display(img)\n",
412
+ " print()"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "markdown",
417
+ "metadata": {
418
+ "id": "tw02wG9zGmyB"
419
+ },
420
+ "source": [
421
+ "## 🏅 Optional: Rank images by CLIP score\n",
422
+ "\n",
423
+ "We can rank images according to CLIP.\n",
424
+ "\n",
425
+ "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "metadata": {
432
+ "id": "RGjlIW_f6GA0"
433
+ },
434
+ "outputs": [],
435
+ "source": [
436
+ "# CLIP model\n",
437
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
438
+ "CLIP_COMMIT_ID = None\n",
439
+ "\n",
440
+ "# Load CLIP\n",
441
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
442
+ " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
443
+ ")\n",
444
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
445
+ "clip_params = replicate(clip_params)\n",
446
+ "\n",
447
+ "# score images\n",
448
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
449
+ "def p_clip(inputs, params):\n",
450
+ " logits = clip(params=params, **inputs).logits_per_image\n",
451
+ " return logits"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": null,
457
+ "metadata": {
458
+ "id": "FoLXpjCmGpju"
459
+ },
460
+ "outputs": [],
461
+ "source": [
462
+ "from flax.training.common_utils import shard\n",
463
+ "\n",
464
+ "# get clip scores\n",
465
+ "clip_inputs = clip_processor(\n",
466
+ " text=prompts * jax.device_count(),\n",
467
+ " images=images,\n",
468
+ " return_tensors=\"np\",\n",
469
+ " padding=\"max_length\",\n",
470
+ " max_length=77,\n",
471
+ " truncation=True,\n",
472
+ ").data\n",
473
+ "logits = p_clip(shard(clip_inputs), clip_params)\n",
474
+ "\n",
475
+ "# organize scores per prompt\n",
476
+ "p = len(prompts)\n",
477
+ "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "markdown",
482
+ "metadata": {
483
+ "id": "4AAWRm70LgED"
484
+ },
485
+ "source": [
486
+ "Let's now display images ranked by CLIP score."
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "metadata": {
493
+ "id": "zsgxxubLLkIu"
494
+ },
495
+ "outputs": [],
496
+ "source": [
497
+ "for i, prompt in enumerate(prompts):\n",
498
+ " print(f\"Prompt: {prompt}\\n\")\n",
499
+ " for idx in logits[i].argsort()[::-1]:\n",
500
+ " display(images[idx * p + i])\n",
501
+ " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
502
+ " print()"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "markdown",
507
+ "metadata": {
508
+ "id": "oZT9i3jCjir0"
509
+ },
510
+ "source": [
511
+ "## 🪄 Optional: Save your Generated Images as W&B Tables\n",
512
+ "\n",
513
+ "W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world."
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {
520
+ "id": "-pSiv6Vwjkn0"
521
+ },
522
+ "outputs": [],
523
+ "source": [
524
+ "import wandb\n",
525
+ "\n",
526
+ "# Initialize a W&B run.\n",
527
+ "project = 'dalle-mini-tables-colab'\n",
528
+ "run = wandb.init(project=project)\n",
529
+ "\n",
530
+ "# Initialize an empty W&B Tables.\n",
531
+ "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
532
+ "gen_table = wandb.Table(columns=columns)\n",
533
+ "\n",
534
+ "# Add data to the table.\n",
535
+ "for i, prompt in enumerate(prompts):\n",
536
+ " # If CLIP scores exist, sort the Images\n",
537
+ " if logits is not None:\n",
538
+ " idxs = logits[i].argsort()[::-1]\n",
539
+ " tmp_imgs = images[i::len(prompts)]\n",
540
+ " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
541
+ " else:\n",
542
+ " tmp_imgs = images[i::len(prompts)]\n",
543
+ "\n",
544
+ " # Add the data to the table.\n",
545
+ " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
546
+ "\n",
547
+ "# Log the Table to W&B dashboard.\n",
548
+ "wandb.log({\"Generated Images\": gen_table})\n",
549
+ "\n",
550
+ "# Close the W&B run.\n",
551
+ "run.finish()"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "markdown",
556
+ "metadata": {
557
+ "id": "Ck2ZnHwVjnRd"
558
+ },
559
+ "source": [
560
+ "Click on the link above to check out your generated images."
561
+ ]
562
+ }
563
+ ],
564
+ "metadata": {
565
+ "accelerator": "GPU",
566
+ "colab": {
567
+ "collapsed_sections": [],
568
+ "machine_shape": "hm",
569
+ "name": "DALL·E mini - Inference pipeline.ipynb",
570
+ "provenance": []
571
+ },
572
+ "kernelspec": {
573
+ "display_name": "Python 3.9.13 ('base')",
574
+ "language": "python",
575
+ "name": "python3"
576
+ },
577
+ "language_info": {
578
+ "codemirror_mode": {
579
+ "name": "ipython",
580
+ "version": 3
581
+ },
582
+ "file_extension": ".py",
583
+ "mimetype": "text/x-python",
584
+ "name": "python",
585
+ "nbconvert_exporter": "python",
586
+ "pygments_lexer": "ipython3",
587
+ "version": "3.9.13"
588
+ },
589
+ "vscode": {
590
+ "interpreter": {
591
+ "hash": "3e91440bae70fe36b08f2decfecf198c5281689ed89adf5e1c2c93a1bdd6e28e"
592
+ }
593
+ }
594
+ },
595
+ "nbformat": 4,
596
+ "nbformat_minor": 0
597
+ }
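The notebook's multi-device strategy is: replicate the model parameters with flax.jax_utils.replicate, split the PRNG key with shard_prng_key, and wrap the generate/decode calls in jax.pmap. A minimal, self-contained sketch of that pattern with a toy function (the parameter and function names below are illustrative, not part of the notebook):

import jax
import jax.numpy as jnp
from functools import partial
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key

# toy parameters standing in for the DalleBart / VQGAN weights
params = {"w": jnp.ones((4,))}
params = replicate(params)  # adds a leading axis of size jax.local_device_count()

@partial(jax.pmap, axis_name="batch")
def p_apply(x, key, params):
    # each device receives its own slice of x, its own PRNG key, and its own copy of params
    noise = jax.random.normal(key, x.shape)
    return x * params["w"] + noise

n = jax.local_device_count()
x = jnp.ones((n, 4))  # one row per device
out = p_apply(x, shard_prng_key(jax.random.PRNGKey(0)), params)
print(out.shape)  # (n, 4): each device produces a different result because of its unique key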
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Eeman Majumder
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README copy.md ADDED
@@ -0,0 +1,2 @@
1
+ # Some_Gen_Stuff
2
+ Trying out inference models
Testing_SD.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,60 @@
1
+ #___________________________________________________________________________________________________________________________
2
+
3
+ import streamlit as st
4
+ import os
5
+
6
+ #___________________________________________________________________________________________________________________________
7
+
8
+ import torch
9
+ from torch import autocast
10
+ from diffusers import StableDiffusionPipeline
11
+ from datasets import load_dataset
12
+ from PIL import Image
13
+ import re
14
+
15
+ #___________________________________________________________________________________________________________________________
16
+
17
+ st.title('IMGTEXTA')
18
+
19
+ #___________________________________________________________________________________________________________________________
20
+
21
+ model_id = "CompVis/stable-diffusion-v1-4"
22
+ device = "cpu"
23
+
24
+ #___________________________________________________________________________________________________________________________
25
+
26
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=st.secrets["AUTH_KEY"], torch_dtype=torch.float32)
27
+ def dummy(images, **kwargs): return images, False # bypass the safety checker so every generated image is returned
28
+ pipe.safety_checker = dummy
29
+
30
+ #___________________________________________________________________________________________________________________________
31
+
32
+ def infer(prompt, width, height, steps, scale, seed):
33
+ if seed != -1: # a specific seed was requested, so seed the generator; -1 means use a random seed
34
+ images_list = pipe(
35
+ [prompt],
36
+ height=height,
37
+ width=width,
38
+ num_inference_steps=steps,
39
+ guidance_scale=scale,
40
+ generator=torch.Generator(device=device).manual_seed(seed))
41
+ else:
42
+ images_list = pipe(
43
+ [prompt],
44
+ height=height,
45
+ width=width,
46
+ num_inference_steps=steps,
47
+ guidance_scale=scale)
48
+
49
+ return images_list["sample"] # note: recent diffusers versions expose the result as .images instead of ["sample"]
50
+
51
+ #___________________________________________________________________________________________________________________________
52
+
53
+ def onclick(prompt):
54
+ st.image(infer(prompt,512,512,30,7.5,-1))
55
+ prompt=st.text_input('Enter Your Prompt')
56
+ if prompt: # a string is never == True; run inference once a non-empty prompt is entered
57
+ onclick(prompt)
58
+
59
+ #___________________________________________________________________________________________________________________________
60
+
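With the secrets file above in place, the app can be launched locally with streamlit run app.py. Because infer takes an explicit seed argument, generation can also be made reproducible by passing a fixed value instead of -1; a hypothetical call:

# illustrative only: a fixed seed gives repeatable images for the same prompt and settings
images = infer("sunset over a lake in the mountains", 512, 512, 30, 7.5, 42)
st.image(images)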
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ diffusers
2
+ transformers
3
+ nvidia-ml-py3
4
+ ftfy
5
+ datasets
6
+ --extra-index-url https://download.pytorch.org/whl/cu113 torch