boris commited on
Commit
5b16588
·
1 Parent(s): cb127c4

fix: correct clip params

Browse files
tools/inference/log_inference_samples.ipynb CHANGED
@@ -24,25 +24,6 @@
24
  "from dalle_mini.text import TextNormalizer"
25
  ]
26
  },
27
- {
28
- "cell_type": "code",
29
- "execution_count": null,
30
- "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
31
- "metadata": {},
32
- "outputs": [],
33
- "source": [
34
- "run_ids = ['3kaut6e8']\n",
35
- "# Alamy - 3kaut6e8\n",
36
- "# YFCC - to do\n",
37
- "# HF spaces - 4oh3u7ca\n",
38
- "ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
39
- "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
40
- "normalize_text = False\n",
41
- "latest_only = True # log only latest or all versions\n",
42
- "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
43
- "add_clip_32 = False"
44
- ]
45
- },
46
  {
47
  "cell_type": "code",
48
  "execution_count": null,
@@ -50,13 +31,9 @@
50
  "metadata": {},
51
  "outputs": [],
52
  "source": [
53
- "run_ids = ['2u5lk3uw']\n",
54
- "# poorly shuffled 1nj161cl\n",
55
- "# well shuffled he9rrc3q\n",
56
- "# non normalized 1fwxpyfh ! requires changing normalize_text\n",
57
  "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
58
- "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
59
- "normalize_text = True\n",
60
  "latest_only = True # log only latest or all versions\n",
61
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
62
  "add_clip_32 = False"
@@ -85,7 +62,7 @@
85
  "batch_size = 8\n",
86
  "num_images = 128\n",
87
  "top_k = 8\n",
88
- "text_normalizer = TextNormalizer() if normalize_text else None\n",
89
  "padding_item = 'NONE'\n",
90
  "seed = random.randint(0, 2**32-1)\n",
91
  "key = jax.random.PRNGKey(seed)\n",
@@ -230,7 +207,7 @@
230
  "outputs": [],
231
  "source": [
232
  "run_id = run_ids[0]\n",
233
- "# TODO: turn everything into a class or loop over runs"
234
  ]
235
  },
236
  {
@@ -287,7 +264,7 @@
287
  "\n",
288
  " # process one batch of captions\n",
289
  " for batch in tqdm(samples):\n",
290
- " processed_prompts = [text_normalizer(x) for x in batch] if normalize_text else list(batch)\n",
291
  "\n",
292
  " # repeat the prompts to distribute over each device and tokenize\n",
293
  " processed_prompts = processed_prompts * jax.device_count()\n",
@@ -296,7 +273,7 @@
296
  "\n",
297
  " # generate images\n",
298
  " images = []\n",
299
- " pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=None)\n",
300
  " for i in pbar:\n",
301
  " key, subkey = jax.random.split(key)\n",
302
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
@@ -312,7 +289,7 @@
312
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
313
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
314
  " clip_inputs = shard(clip_inputs)\n",
315
- " logits = p_clip(clip_inputs, clip32_params)\n",
316
  " logits = logits.reshape(-1, num_images)\n",
317
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
318
  " logits = jax.device_get(logits)\n",
@@ -348,6 +325,14 @@
348
  " wandb.finish()\n",
349
  " run = None # ensure we don't log on this run"
350
  ]
 
 
 
 
 
 
 
 
351
  }
352
  ],
353
  "metadata": {
 
24
  "from dalle_mini.text import TextNormalizer"
25
  ]
26
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  {
28
  "cell_type": "code",
29
  "execution_count": null,
 
31
  "metadata": {},
32
  "outputs": [],
33
  "source": [
34
+ "run_ids = ['63otg87g']\n",
 
 
 
35
  "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
36
+ "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
 
37
  "latest_only = True # log only latest or all versions\n",
38
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
39
  "add_clip_32 = False"
 
62
  "batch_size = 8\n",
63
  "num_images = 128\n",
64
  "top_k = 8\n",
65
+ "text_normalizer = TextNormalizer()\n",
66
  "padding_item = 'NONE'\n",
67
  "seed = random.randint(0, 2**32-1)\n",
68
  "key = jax.random.PRNGKey(seed)\n",
 
207
  "outputs": [],
208
  "source": [
209
  "run_id = run_ids[0]\n",
210
+ "# TODO: loop over runs"
211
  ]
212
  },
213
  {
 
264
  "\n",
265
  " # process one batch of captions\n",
266
  " for batch in tqdm(samples):\n",
267
+ " processed_prompts = [text_normalizer(x) for x in batch] if model.config.normalize_text else list(batch)\n",
268
  "\n",
269
  " # repeat the prompts to distribute over each device and tokenize\n",
270
  " processed_prompts = processed_prompts * jax.device_count()\n",
 
273
  "\n",
274
  " # generate images\n",
275
  " images = []\n",
276
+ " pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=True)\n",
277
  " for i in pbar:\n",
278
  " key, subkey = jax.random.split(key)\n",
279
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
 
289
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
290
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
291
  " clip_inputs = shard(clip_inputs)\n",
292
+ " logits = p_clip(clip_inputs, clip_params)\n",
293
  " logits = logits.reshape(-1, num_images)\n",
294
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
295
  " logits = jax.device_get(logits)\n",
 
325
  " wandb.finish()\n",
326
  " run = None # ensure we don't log on this run"
327
  ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": null,
332
+ "id": "415d3f54-7226-43de-9eea-4283a948dc93",
333
+ "metadata": {},
334
+ "outputs": [],
335
+ "source": []
336
  }
337
  ],
338
  "metadata": {