boris commited on
Commit
7851774
·
2 Parent(s): bc78bfd 62e13ba

Merge pull request #42 from borisdayma/chore-clean

Browse files

chore: cleanup repo
Former-commit-id: 9977d1dc821ac8be7eef928e1aa6e2aaacd2c5f7

Files changed (3) hide show
  1. README.md +7 -3
  2. dev/seq2seq/run_seq2seq_flax.py +6 -52
  3. img/logo.png +0 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Dalle Mini
3
- emoji: 🎨
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
@@ -12,13 +12,17 @@ pinned: false
12
 
13
  _Generate images from a text prompt_
14
 
15
- TODO: add some cool example
 
 
 
 
16
 
17
  ## Create my own images with the demo → Coming soon
18
 
19
  ## How does it work?
20
 
21
- Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA?accessToken=2ua7j8ebc810fuxyv49wbipmq3fb2e78yq3rvs5dy4wew07wwm2csdo8zcuyr14e).
22
 
23
  ## Development
24
 
 
1
  ---
2
  title: Dalle Mini
3
+ emoji: 🥑
4
  colorFrom: red
5
  colorTo: blue
6
  sdk: gradio
 
12
 
13
  _Generate images from a text prompt_
14
 
15
+ <img src="img/logo.png" width="200">
16
+
17
+ Our logo was generated with DALL-E mini by typing "logo of an armchair in the shape of an avocado".
18
+
19
+ You can also create your own pictures with the demo (TODO: add link).
20
 
21
  ## Create my own images with the demo → Coming soon
22
 
23
  ## How does it work?
24
 
25
+ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
26
 
27
  ## Development
28
 
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -83,6 +83,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
83
 
84
 
85
  # Model hyperparameters, for convenience
 
86
  OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
87
  OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
88
  BOS_TOKEN_ID = 16384
@@ -217,7 +218,7 @@ class DataTrainingArguments:
217
  default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
218
  )
219
  predict_with_generate: bool = field(
220
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
221
  )
222
  num_beams: Optional[int] = field(
223
  default=None,
@@ -376,9 +377,6 @@ def main():
376
  else:
377
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
378
 
379
- logger.warning(f"WARNING: eval_steps has been manually hardcoded") # TODO: remove it later, convenient for now
380
- training_args.eval_steps = 400
381
-
382
  if (
383
  os.path.exists(training_args.output_dir)
384
  and os.listdir(training_args.output_dir)
@@ -425,11 +423,10 @@ def main():
425
  # (the dataset will be downloaded automatically from the datasets Hub).
426
  #
427
  data_files = {}
428
- logger.warning(f"WARNING: Datasets path have been manually hardcoded") # TODO: remove it later, convenient for now
429
  if data_args.train_file is not None:
430
- data_files["train"] = ["/data/CC3M/training-encoded.tsv", "/data/CC12M/encoded-train.tsv", "/data/YFCC/metadata_encoded.tsv"]
431
  if data_args.validation_file is not None:
432
- data_files["validation"] = ["/data/CC3M/validation-encoded.tsv"]
433
  if data_args.test_file is not None:
434
  data_files["test"] = data_args.test_file
435
  dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
@@ -608,35 +605,6 @@ def main():
608
  desc="Running tokenizer on prediction dataset",
609
  )
610
 
611
- # Metric
612
- #metric = load_metric("rouge")
613
-
614
- def postprocess_text(preds, labels):
615
- preds = [pred.strip() for pred in preds]
616
- labels = [label.strip() for label in labels]
617
-
618
- # rougeLSum expects newline after each sentence
619
- preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
620
- labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
621
-
622
- return preds, labels
623
-
624
- def compute_metrics(preds, labels):
625
- decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
626
- decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
627
-
628
- # Some simple post-processing
629
- decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
630
-
631
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
632
- # Extract a few results from ROUGE
633
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
634
-
635
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
636
- result["gen_len"] = np.mean(prediction_lens)
637
- result = {k: round(v, 4) for k, v in result.items()}
638
- return result
639
-
640
  # Initialize our training
641
  rng = jax.random.PRNGKey(training_args.seed)
642
  rng, dropout_rng = jax.random.split(rng)
@@ -822,15 +790,8 @@ def main():
822
  # log metrics
823
  wandb_log(eval_metrics, step=global_step, prefix='eval')
824
 
825
- # compute ROUGE metrics
826
- rouge_desc = ""
827
- # if data_args.predict_with_generate:
828
- # rouge_metrics = compute_metrics(eval_preds, eval_labels)
829
- # eval_metrics.update(rouge_metrics)
830
- # rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
831
-
832
  # Print metrics and update progress bar
833
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
834
  epochs.write(desc)
835
  epochs.desc = desc
836
 
@@ -955,15 +916,8 @@ def main():
955
  pred_metrics = get_metrics(pred_metrics)
956
  pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
957
 
958
- # compute ROUGE metrics
959
- rouge_desc = ""
960
- if data_args.predict_with_generate:
961
- rouge_metrics = compute_metrics(pred_generations, pred_labels)
962
- pred_metrics.update(rouge_metrics)
963
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
964
-
965
  # Print metrics
966
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
967
  logger.info(desc)
968
 
969
 
 
83
 
84
 
85
  # Model hyperparameters, for convenience
86
+ # TODO: the model has now it's own definition file and should be imported
87
  OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
88
  OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
89
  BOS_TOKEN_ID = 16384
 
218
  default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
219
  )
220
  predict_with_generate: bool = field(
221
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
222
  )
223
  num_beams: Optional[int] = field(
224
  default=None,
 
377
  else:
378
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
379
 
 
 
 
380
  if (
381
  os.path.exists(training_args.output_dir)
382
  and os.listdir(training_args.output_dir)
 
423
  # (the dataset will be downloaded automatically from the datasets Hub).
424
  #
425
  data_files = {}
 
426
  if data_args.train_file is not None:
427
+ data_files["train"] = data_args.train_file
428
  if data_args.validation_file is not None:
429
+ data_files["validation"] = data_args.validation_file
430
  if data_args.test_file is not None:
431
  data_files["test"] = data_args.test_file
432
  dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
 
605
  desc="Running tokenizer on prediction dataset",
606
  )
607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
  # Initialize our training
609
  rng = jax.random.PRNGKey(training_args.seed)
610
  rng, dropout_rng = jax.random.split(rng)
 
790
  # log metrics
791
  wandb_log(eval_metrics, step=global_step, prefix='eval')
792
 
 
 
 
 
 
 
 
793
  # Print metrics and update progress bar
794
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
795
  epochs.write(desc)
796
  epochs.desc = desc
797
 
 
916
  pred_metrics = get_metrics(pred_metrics)
917
  pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
918
 
 
 
 
 
 
 
 
919
  # Print metrics
920
+ desc = f"Predict Loss: {pred_metrics['loss']})"
921
  logger.info(desc)
922
 
923
 
img/logo.png ADDED