Spaces:
Running
Running
fix: style
Browse files- tools/inference/inference_pipeline.ipynb +46 -19
- tools/train/train.py +5 -5
tools/inference/inference_pipeline.ipynb
CHANGED
@@ -70,15 +70,15 @@
|
|
70 |
"# Model references\n",
|
71 |
"\n",
|
72 |
"# dalle-mini\n",
|
73 |
-
"DALLE_MODEL =
|
74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
75 |
"\n",
|
76 |
"# VQGAN model\n",
|
77 |
-
"VQGAN_REPO =
|
78 |
-
"VQGAN_COMMIT_ID =
|
79 |
"\n",
|
80 |
"# CLIP model\n",
|
81 |
-
"CLIP_REPO =
|
82 |
"CLIP_COMMIT_ID = None"
|
83 |
]
|
84 |
},
|
@@ -121,18 +121,28 @@
|
|
121 |
"import wandb\n",
|
122 |
"\n",
|
123 |
"# Load dalle-mini\n",
|
124 |
-
"if
|
125 |
" # wandb artifact\n",
|
126 |
" artifact = wandb.Api().artifact(DALLE_MODEL)\n",
|
127 |
" # we only download required files (no need for opt_state which is large)\n",
|
128 |
-
" model_files = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
" for f in model_files:\n",
|
130 |
-
" artifact.get_path(f).download(
|
131 |
-
" model = DalleBart.from_pretrained(
|
132 |
-
" tokenizer = AutoTokenizer.from_pretrained(
|
133 |
"else:\n",
|
134 |
" # local folder or 🤗 Hub\n",
|
135 |
-
" model = DalleBart.from_pretrained(
|
|
|
|
|
136 |
" tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
|
137 |
"\n",
|
138 |
"# Load VQGAN\n",
|
@@ -191,7 +201,7 @@
|
|
191 |
"from functools import partial\n",
|
192 |
"\n",
|
193 |
"# model inference\n",
|
194 |
-
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3,4))\n",
|
195 |
"def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
|
196 |
" return model.generate(\n",
|
197 |
" **tokenized_prompt,\n",
|
@@ -203,11 +213,13 @@
|
|
203 |
" top_p=top_p\n",
|
204 |
" )\n",
|
205 |
"\n",
|
|
|
206 |
"# decode images\n",
|
207 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
208 |
"def p_decode(indices, params):\n",
|
209 |
" return vqgan.decode_code(indices, params=params)\n",
|
210 |
"\n",
|
|
|
211 |
"# score images\n",
|
212 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
213 |
"def p_clip(inputs, params):\n",
|
@@ -235,7 +247,7 @@
|
|
235 |
"import random\n",
|
236 |
"\n",
|
237 |
"# create a random key\n",
|
238 |
-
"seed = random.randint(0, 2**32-1)\n",
|
239 |
"key = jax.random.PRNGKey(seed)"
|
240 |
]
|
241 |
},
|
@@ -287,7 +299,7 @@
|
|
287 |
},
|
288 |
"outputs": [],
|
289 |
"source": [
|
290 |
-
"prompt =
|
291 |
]
|
292 |
},
|
293 |
{
|
@@ -323,7 +335,13 @@
|
|
323 |
"repeated_prompts = [processed_prompt] * jax.device_count()\n",
|
324 |
"\n",
|
325 |
"# tokenize\n",
|
326 |
-
"tokenized_prompt = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
"tokenized_prompt"
|
328 |
]
|
329 |
},
|
@@ -408,12 +426,14 @@
|
|
408 |
" # get a new key\n",
|
409 |
" key, subkey = jax.random.split(key)\n",
|
410 |
" # generate images\n",
|
411 |
-
" encoded_images = p_generate(
|
|
|
|
|
412 |
" # remove BOS\n",
|
413 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
414 |
" # decode images\n",
|
415 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
416 |
-
" decoded_images = decoded_images.clip(0
|
417 |
" for img in decoded_images:\n",
|
418 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
419 |
]
|
@@ -436,7 +456,14 @@
|
|
436 |
"outputs": [],
|
437 |
"source": [
|
438 |
"# get clip scores\n",
|
439 |
-
"clip_inputs = processor(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
441 |
"logits = logits.squeeze().flatten()"
|
442 |
]
|
@@ -458,10 +485,10 @@
|
|
458 |
},
|
459 |
"outputs": [],
|
460 |
"source": [
|
461 |
-
"print(f
|
462 |
"for idx in logits.argsort()[::-1]:\n",
|
463 |
" display(images[idx])\n",
|
464 |
-
" print(f
|
465 |
]
|
466 |
}
|
467 |
],
|
|
|
70 |
"# Model references\n",
|
71 |
"\n",
|
72 |
"# dalle-mini\n",
|
73 |
+
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-3bqwu04f:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
|
74 |
"DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
|
75 |
"\n",
|
76 |
"# VQGAN model\n",
|
77 |
+
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
|
78 |
+
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
|
79 |
"\n",
|
80 |
"# CLIP model\n",
|
81 |
+
"CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
|
82 |
"CLIP_COMMIT_ID = None"
|
83 |
]
|
84 |
},
|
|
|
121 |
"import wandb\n",
|
122 |
"\n",
|
123 |
"# Load dalle-mini\n",
|
124 |
+
"if \":\" in DALLE_MODEL:\n",
|
125 |
" # wandb artifact\n",
|
126 |
" artifact = wandb.Api().artifact(DALLE_MODEL)\n",
|
127 |
" # we only download required files (no need for opt_state which is large)\n",
|
128 |
+
" model_files = [\n",
|
129 |
+
" \"config.json\",\n",
|
130 |
+
" \"flax_model.msgpack\",\n",
|
131 |
+
" \"merges.txt\",\n",
|
132 |
+
" \"special_tokens_map.json\",\n",
|
133 |
+
" \"tokenizer.json\",\n",
|
134 |
+
" \"tokenizer_config.json\",\n",
|
135 |
+
" \"vocab.json\",\n",
|
136 |
+
" ]\n",
|
137 |
" for f in model_files:\n",
|
138 |
+
" artifact.get_path(f).download(\"model\")\n",
|
139 |
+
" model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
|
140 |
+
" tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
|
141 |
"else:\n",
|
142 |
" # local folder or 🤗 Hub\n",
|
143 |
+
" model = DalleBart.from_pretrained(\n",
|
144 |
+
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
|
145 |
+
" )\n",
|
146 |
" tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
|
147 |
"\n",
|
148 |
"# Load VQGAN\n",
|
|
|
201 |
"from functools import partial\n",
|
202 |
"\n",
|
203 |
"# model inference\n",
|
204 |
+
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
|
205 |
"def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
|
206 |
" return model.generate(\n",
|
207 |
" **tokenized_prompt,\n",
|
|
|
213 |
" top_p=top_p\n",
|
214 |
" )\n",
|
215 |
"\n",
|
216 |
+
"\n",
|
217 |
"# decode images\n",
|
218 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
219 |
"def p_decode(indices, params):\n",
|
220 |
" return vqgan.decode_code(indices, params=params)\n",
|
221 |
"\n",
|
222 |
+
"\n",
|
223 |
"# score images\n",
|
224 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
225 |
"def p_clip(inputs, params):\n",
|
|
|
247 |
"import random\n",
|
248 |
"\n",
|
249 |
"# create a random key\n",
|
250 |
+
"seed = random.randint(0, 2 ** 32 - 1)\n",
|
251 |
"key = jax.random.PRNGKey(seed)"
|
252 |
]
|
253 |
},
|
|
|
299 |
},
|
300 |
"outputs": [],
|
301 |
"source": [
|
302 |
+
"prompt = \"a red T-shirt\""
|
303 |
]
|
304 |
},
|
305 |
{
|
|
|
335 |
"repeated_prompts = [processed_prompt] * jax.device_count()\n",
|
336 |
"\n",
|
337 |
"# tokenize\n",
|
338 |
+
"tokenized_prompt = tokenizer(\n",
|
339 |
+
" repeated_prompts,\n",
|
340 |
+
" return_tensors=\"jax\",\n",
|
341 |
+
" padding=\"max_length\",\n",
|
342 |
+
" truncation=True,\n",
|
343 |
+
" max_length=128,\n",
|
344 |
+
").data\n",
|
345 |
"tokenized_prompt"
|
346 |
]
|
347 |
},
|
|
|
426 |
" # get a new key\n",
|
427 |
" key, subkey = jax.random.split(key)\n",
|
428 |
" # generate images\n",
|
429 |
+
" encoded_images = p_generate(\n",
|
430 |
+
" tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
|
431 |
+
" )\n",
|
432 |
" # remove BOS\n",
|
433 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
434 |
" # decode images\n",
|
435 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
436 |
+
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
|
437 |
" for img in decoded_images:\n",
|
438 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
439 |
]
|
|
|
456 |
"outputs": [],
|
457 |
"source": [
|
458 |
"# get clip scores\n",
|
459 |
+
"clip_inputs = processor(\n",
|
460 |
+
" text=[prompt] * jax.device_count(),\n",
|
461 |
+
" images=images,\n",
|
462 |
+
" return_tensors=\"np\",\n",
|
463 |
+
" padding=\"max_length\",\n",
|
464 |
+
" max_length=77,\n",
|
465 |
+
" truncation=True,\n",
|
466 |
+
").data\n",
|
467 |
"logits = p_clip(shard(clip_inputs), clip_params)\n",
|
468 |
"logits = logits.squeeze().flatten()"
|
469 |
]
|
|
|
485 |
},
|
486 |
"outputs": [],
|
487 |
"source": [
|
488 |
+
"print(f\"Prompt: {prompt}\\n\")\n",
|
489 |
"for idx in logits.argsort()[::-1]:\n",
|
490 |
" display(images[idx])\n",
|
491 |
+
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
492 |
]
|
493 |
}
|
494 |
],
|
tools/train/train.py
CHANGED
@@ -219,9 +219,7 @@ class TrainingArguments:
|
|
219 |
"help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
|
220 |
},
|
221 |
)
|
222 |
-
weight_decay: float = field(
|
223 |
-
default=None, metadata={"help": "Weight decay."}
|
224 |
-
)
|
225 |
beta1: float = field(
|
226 |
default=0.9,
|
227 |
metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
|
@@ -237,13 +235,15 @@ class TrainingArguments:
|
|
237 |
default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
|
238 |
)
|
239 |
block_size: int = field(
|
240 |
-
default=1024,
|
|
|
241 |
)
|
242 |
preconditioning_compute_steps: int = field(
|
243 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
244 |
)
|
245 |
skip_preconditioning_dim_size_gt: int = field(
|
246 |
-
default=4096,
|
|
|
247 |
)
|
248 |
optim_quantized: bool = field(
|
249 |
default=False,
|
|
|
219 |
"help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
|
220 |
},
|
221 |
)
|
222 |
+
weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
|
|
|
|
|
223 |
beta1: float = field(
|
224 |
default=0.9,
|
225 |
metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
|
|
|
235 |
default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
|
236 |
)
|
237 |
block_size: int = field(
|
238 |
+
default=1024,
|
239 |
+
metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
|
240 |
)
|
241 |
preconditioning_compute_steps: int = field(
|
242 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
243 |
)
|
244 |
skip_preconditioning_dim_size_gt: int = field(
|
245 |
+
default=4096,
|
246 |
+
metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
|
247 |
)
|
248 |
optim_quantized: bool = field(
|
249 |
default=False,
|