boris committed on
Commit
bf4da91
·
unverified ·
2 Parent(s): 3073ff4 5b79afd

Merge pull request #16 from borisdayma/feat-log_model

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. seq2seq/run_seq2seq_flax.py +34 -11
requirements.txt CHANGED
@@ -9,4 +9,4 @@ flax
9
  jupyter
10
  # for logging
11
  tensorboard
12
- tetnsorflow
 
9
  jupyter
10
  # for logging
11
  tensorboard
12
+ tensorflow
seq2seq/run_seq2seq_flax.py CHANGED
@@ -199,7 +199,7 @@ class DataTrainingArguments:
199
  },
200
  )
201
  preprocessing_num_workers: Optional[int] = field(
202
- default=None,
203
  metadata={"help": "The number of processes to use for the preprocessing."},
204
  )
205
  source_prefix: Optional[str] = field(
@@ -225,6 +225,9 @@ class DataTrainingArguments:
225
  "value if set."
226
  },
227
  )
 
 
 
228
 
229
  def __post_init__(self):
230
  if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -812,6 +815,36 @@ def main():
812
  cur_step = epoch * (len(train_dataset) // train_batch_size)
813
  write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
815
  # ======================== Prediction loop ==============================
816
  if training_args.do_predict:
817
  logger.info("*** Predict ***")
@@ -851,16 +884,6 @@ def main():
851
  desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
852
  logger.info(desc)
853
 
854
- # save checkpoint after each epoch and push checkpoint to the hub
855
- if jax.process_index() == 0:
856
- params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
857
- model.save_pretrained(
858
- training_args.output_dir,
859
- params=params,
860
- push_to_hub=training_args.push_to_hub,
861
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
862
- )
863
-
864
 
865
  if __name__ == "__main__":
866
  main()
 
199
  },
200
  )
201
  preprocessing_num_workers: Optional[int] = field(
202
+ default=80, # ensure we have the same datasets cached data and avoid using too much space
203
  metadata={"help": "The number of processes to use for the preprocessing."},
204
  )
205
  source_prefix: Optional[str] = field(
 
225
  "value if set."
226
  },
227
  )
228
+ log_model: bool = field(
229
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
230
+ )
231
 
232
  def __post_init__(self):
233
  if self.dataset_name is None and self.train_file is None and self.validation_file is None:
 
815
  cur_step = epoch * (len(train_dataset) // train_batch_size)
816
  write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
817
 
818
+ # save checkpoint after each epoch and push checkpoint to the hub
819
+ if jax.process_index() == 0:
820
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
821
+
822
+ # save model locally
823
+ model.save_pretrained(
824
+ training_args.output_dir,
825
+ params=params,
826
+ )
827
+
828
+ # save to W&B
829
+ if data_args.log_model:
830
+ metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
831
+ artifact = wandb.Artifact(
832
+ name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
833
+ )
834
+ artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
835
+ artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
836
+ wandb.run.log_artifact(artifact)
837
+
838
+ # save to the hub
839
+ if training_args.push_to_hub:
840
+ model.save_pretrained(
841
+ training_args.output_dir,
842
+ params=params,
843
+ push_to_hub=training_args.push_to_hub,
844
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
845
+ temp_dir=True # avoid issues with being in a repository
846
+ )
847
+
848
  # ======================== Prediction loop ==============================
849
  if training_args.do_predict:
850
  logger.info("*** Predict ***")
 
884
  desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
885
  logger.info(desc)
886
 
 
 
 
 
 
 
 
 
 
 
887
 
888
  if __name__ == "__main__":
889
  main()