Spaces:
Runtime error
Eeman Majumder
committed on
Commit · 14f7e5e
1 Parent(s): 0965163
anal
Browse files
- .streamlit/secrets.toml +1 -0
- DALL·E_mini_Inference_pipeline .ipynb +597 -0
- LICENSE +21 -0
- README copy.md +2 -0
- Testing_SD.ipynb +0 -0
- app.py +60 -0
- requirements.txt +6 -0
.streamlit/secrets.toml
ADDED
@@ -0,0 +1 @@
AUTH_KEY='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'
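This secret is what app.py (added below) reads at runtime through Streamlit's secrets API. A minimal sketch of that lookup, assuming only the AUTH_KEY entry above (the environment-variable fallback is an illustrative assumption, not part of this commit):

import os
import streamlit as st

# Streamlit exposes .streamlit/secrets.toml as the dict-like st.secrets.
auth_key = st.secrets.get("AUTH_KEY", os.environ.get("AUTH_KEY"))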
DALL·E_mini_Inference_pipeline .ipynb
ADDED
@@ -0,0 +1,597 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "118UKH5bWCGa"
},
"source": [
"# DALL·E mini - Inference pipeline\n",
"\n",
"*Generate images from a text prompt*\n",
"\n",
"<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
"\n",
"This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
"\n",
"Just want to play? Use directly [the app](https://www.craiyon.com/).\n",
"\n",
"For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dS8LbaonYm3a"
},
"source": [
"## 🛠️ Installation and set-up"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uzjAM2GBYpZX",
"outputId": "9042b53c-1260-4ae6-ff54-be878c99d505"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow-metal 0.5.0 requires six~=1.15.0, but you have six 1.16.0 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
],
"source": [
"# Install required libraries\n",
"!pip install -q dalle-mini\n",
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ozHzTkyv8cqU"
},
"source": [
"We load required models:\n",
"* DALL·E mini for text to encoded images\n",
"* VQGAN for decoding images\n",
"* CLIP for scoring predictions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "K6CxW2o42f-w"
},
"outputs": [],
"source": [
"# Model references\n",
"\n",
"# dalle-mega\n",
"DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\"  # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
"DALLE_COMMIT_ID = None\n",
"\n",
"# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
"# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
"\n",
"# VQGAN model\n",
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Yv-aR3t4Oe5v",
"outputId": "850b9a43-2506-432f-ae8e-b8b2598e4a98"
},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# check how many devices are available\n",
"jax.local_device_count()"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 240
},
"id": "92zYmvsQ38vL",
"outputId": "556dc277-a885-443b-8848-373696f5acc7"
},
"outputs": [
{
"ename": "NameError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-4a35db7a446b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# Load dalle-mini\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m model, params = DalleBart.from_pretrained(\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mDALLE_MODEL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrevision\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mDALLE_COMMIT_ID\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat16\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_do_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m )\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'DALLE_MODEL' is not defined"
]
}
],
"source": [
"# Load models & tokenizer\n",
"from dalle_mini import DalleBart, DalleBartProcessor\n",
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
"\n",
"# Load dalle-mini\n",
"model, params = DalleBart.from_pretrained(\n",
"    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
")\n",
"\n",
"# Load VQGAN\n",
"vqgan, vqgan_params = VQModel.from_pretrained(\n",
"    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o_vH2X1tDtzA"
},
"source": [
"Model parameters are replicated on each device for faster inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wtvLoM48EeVw"
},
"outputs": [],
"source": [
"from flax.jax_utils import replicate\n",
"\n",
"params = replicate(params)\n",
"vqgan_params = replicate(vqgan_params)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0A9AHQIgZ_qw"
},
"source": [
"Model functions are compiled and parallelized to take advantage of multiple devices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sOtoOmYsSYPz"
},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"# model inference\n",
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
"def p_generate(\n",
"    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
"):\n",
"    return model.generate(\n",
"        **tokenized_prompt,\n",
"        prng_key=key,\n",
"        params=params,\n",
"        top_k=top_k,\n",
"        top_p=top_p,\n",
"        temperature=temperature,\n",
"        condition_scale=condition_scale,\n",
"    )\n",
"\n",
"\n",
"# decode image\n",
"@partial(jax.pmap, axis_name=\"batch\")\n",
"def p_decode(indices, params):\n",
"    return vqgan.decode_code(indices, params=params)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HmVN6IBwapBA"
},
"source": [
"Keys are passed to the model on each device to generate unique inference per device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4CTXmlUkThhX"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"# create a random key\n",
"seed = random.randint(0, 2**32 - 1)\n",
"key = jax.random.PRNGKey(seed)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BrnVyCo81pij"
},
"source": [
"## 🖍 Text Prompt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rsmj0Aj5OQox"
},
"source": [
"Our model requires processing prompts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YjjhUychOVxm"
},
"outputs": [],
"source": [
"from dalle_mini import DalleBartProcessor\n",
"\n",
"processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BQ7fymSPyvF_"
},
"source": [
"Let's define some text prompts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x_0vI9ge1oKr"
},
"outputs": [],
"source": [
"prompts = [\n",
"    \"sunset over a lake in the mountains\",\n",
"    \"the Eiffel tower landing on the moon\",\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XlZUG3SCLnGE"
},
"source": [
"Note: we could use the same prompt multiple times for faster inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VKjEZGjtO49k"
},
"outputs": [],
"source": [
"tokenized_prompts = processor(prompts)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-CEJBnuJOe5z"
},
"source": [
"Finally we replicate the prompts onto each device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lQePgju5Oe5z"
},
"outputs": [],
"source": [
"tokenized_prompt = replicate(tokenized_prompts)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "phQ9bhjRkgAZ"
},
"source": [
"## 🎨 Generate images\n",
"\n",
"We generate images using dalle-mini model and decode them with the VQGAN."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d0wVkXpKqnHA"
},
"outputs": [],
"source": [
"# number of predictions per prompt\n",
"n_predictions = 8\n",
"\n",
"# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
"gen_top_k = None\n",
"gen_top_p = None\n",
"temperature = None\n",
"cond_scale = 10.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SDjEx9JxR3v8"
},
"outputs": [],
"source": [
"from flax.training.common_utils import shard_prng_key\n",
"import numpy as np\n",
"from PIL import Image\n",
"from tqdm.notebook import trange\n",
"\n",
"print(f\"Prompts: {prompts}\\n\")\n",
"# generate images\n",
"images = []\n",
"for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
"    # get a new key\n",
"    key, subkey = jax.random.split(key)\n",
"    # generate images\n",
"    encoded_images = p_generate(\n",
"        tokenized_prompt,\n",
"        shard_prng_key(subkey),\n",
"        params,\n",
"        gen_top_k,\n",
"        gen_top_p,\n",
"        temperature,\n",
"        cond_scale,\n",
"    )\n",
"    # remove BOS\n",
"    encoded_images = encoded_images.sequences[..., 1:]\n",
"    # decode images\n",
"    decoded_images = p_decode(encoded_images, vqgan_params)\n",
"    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
"    for decoded_img in decoded_images:\n",
"        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
"        images.append(img)\n",
"        display(img)\n",
"        print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tw02wG9zGmyB"
},
"source": [
"## 🏅 Optional: Rank images by CLIP score\n",
"\n",
"We can rank images according to CLIP.\n",
"\n",
"**Note: your session may crash if you don't have a subscription to Colab Pro.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RGjlIW_f6GA0"
},
"outputs": [],
"source": [
"# CLIP model\n",
"CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
"CLIP_COMMIT_ID = None\n",
"\n",
"# Load CLIP\n",
"clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
"    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
")\n",
"clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
"clip_params = replicate(clip_params)\n",
"\n",
"# score images\n",
"@partial(jax.pmap, axis_name=\"batch\")\n",
"def p_clip(inputs, params):\n",
"    logits = clip(params=params, **inputs).logits_per_image\n",
"    return logits"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FoLXpjCmGpju"
},
"outputs": [],
"source": [
"from flax.training.common_utils import shard\n",
"\n",
"# get clip scores\n",
"clip_inputs = clip_processor(\n",
"    text=prompts * jax.device_count(),\n",
"    images=images,\n",
"    return_tensors=\"np\",\n",
"    padding=\"max_length\",\n",
"    max_length=77,\n",
"    truncation=True,\n",
").data\n",
"logits = p_clip(shard(clip_inputs), clip_params)\n",
"\n",
"# organize scores per prompt\n",
"p = len(prompts)\n",
"logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4AAWRm70LgED"
},
"source": [
"Let's now display images ranked by CLIP score."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zsgxxubLLkIu"
},
"outputs": [],
"source": [
"for i, prompt in enumerate(prompts):\n",
"    print(f\"Prompt: {prompt}\\n\")\n",
"    for idx in logits[i].argsort()[::-1]:\n",
"        display(images[idx * p + i])\n",
"        print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
"    print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oZT9i3jCjir0"
},
"source": [
"## 🪄 Optional: Save your Generated Images as W&B Tables\n",
"\n",
"W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-pSiv6Vwjkn0"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"# Initialize a W&B run.\n",
"project = 'dalle-mini-tables-colab'\n",
"run = wandb.init(project=project)\n",
"\n",
"# Initialize an empty W&B Tables.\n",
"columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
"gen_table = wandb.Table(columns=columns)\n",
"\n",
"# Add data to the table.\n",
"for i, prompt in enumerate(prompts):\n",
"    # If CLIP scores exist, sort the Images\n",
"    if logits is not None:\n",
"        idxs = logits[i].argsort()[::-1]\n",
"        tmp_imgs = images[i::len(prompts)]\n",
"        tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
"    else:\n",
"        tmp_imgs = images[i::len(prompts)]\n",
"\n",
"    # Add the data to the table.\n",
"    gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
"\n",
"# Log the Table to W&B dashboard.\n",
"wandb.log({\"Generated Images\": gen_table})\n",
"\n",
"# Close the W&B run.\n",
"run.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ck2ZnHwVjnRd"
},
"source": [
"Click on the link above to check out your generated images."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "DALL·E mini - Inference pipeline.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3.9.13 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "3e91440bae70fe36b08f2decfecf198c5281689ed89adf5e1c2c93a1bdd6e28e"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
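The notebook's core device-parallel trick is the replicate-then-pmap pattern: parameters are copied to every device, and a pmapped function runs the same computation on each device's shard of the batch. A standalone toy sketch of that pattern (illustrative only, not part of the committed notebook):

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate

params = {"w": jnp.ones((2, 2))}
params = replicate(params)  # adds a leading device axis to every leaf

@jax.pmap
def apply(params, x):
    # runs once per device, on that device's slice of the inputs
    return x @ params["w"]

n = jax.local_device_count()
x = jnp.ones((n, 3, 2))        # one batch shard per device
print(apply(params, x).shape)  # (n, 3, 2)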
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Eeman Majumder

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README copy.md
ADDED
@@ -0,0 +1,2 @@
# Some_Gen_Stuff
Trying inference models
Testing_SD.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
app.py
ADDED
@@ -0,0 +1,60 @@
#___________________________________________________________________________________________________________________________

import streamlit as st
import os

#___________________________________________________________________________________________________________________________

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image
import re

#___________________________________________________________________________________________________________________________

st.title('IMGTEXTA')

#___________________________________________________________________________________________________________________________

model_id = "CompVis/stable-diffusion-v1-4"
device = "cpu"

#___________________________________________________________________________________________________________________________

# The auth token is read from .streamlit/secrets.toml (added above).
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=st.secrets["AUTH_KEY"], torch_dtype=torch.float32)

# Replace the safety checker with a no-op, i.e. disable NSFW filtering.
def dummy(images, **kwargs): return images, False
pipe.safety_checker = dummy

#___________________________________________________________________________________________________________________________

def infer(prompt, width, height, steps, scale, seed):
    if seed == -1:
        # seed == -1 means "no fixed seed": let the pipeline draw one
        images_list = pipe(
            [prompt],
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=scale)
    else:
        images_list = pipe(
            [prompt],
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=torch.Generator(device=device).manual_seed(seed))

    return images_list["sample"]

#___________________________________________________________________________________________________________________________

def onclick(prompt):
    st.image(infer(prompt, 512, 512, 30, 7.5, -1))

prompt = st.text_input('Enter Your Prompt')
if prompt:  # st.text_input returns a string, so check for non-empty input
    onclick(prompt)

#___________________________________________________________________________________________________________________________

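To try the app locally, install the dependencies and launch Streamlit. Note that streamlit itself is not listed in requirements.txt (on Spaces it is presumably provided by the Space's SDK setting), so it has to be installed separately:

pip install -r requirements.txt streamlit
streamlit run app.py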
requirements.txt
ADDED
@@ -0,0 +1,7 @@
diffusers
transformers
nvidia-ml-py3
ftfy
datasets
--extra-index-url https://download.pytorch.org/whl/cu113
torch
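requirements.txt pulls torch from the CUDA 11.3 wheel index even though app.py pins device = "cpu", so the GPU build is not strictly needed here. A quick sanity check of which build actually got installed (a sketch, not part of the commit):

import torch

print(torch.__version__)          # CUDA builds carry a +cu113-style suffix
print(torch.cuda.is_available())  # False on CPU-only hardware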