kaushalya committed
Commit af144e3 · 1 Parent(s): acfbf02

Update notebooks
run_medclip.sh CHANGED
@@ -1,4 +1,4 @@
- python src/hybrid_clip/run_hybrid_clip.py \
+ python src/medclip/run_medclip.py \
      --output_dir ./snapshots/vision_augmented_biobert \
      --text_model_name_or_path="allenai/scibert_scivocab_uncased" \
      --vision_model_name_or_path="openai/clip-vit-base-patch32" \
src/medclip/run_medclip.py ADDED
@@ -0,0 +1,570 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Training a CLIP-like dual encoder model using text and vision encoders from the library.
+
+ The script can be used to train CLIP-like models for languages other than English by using
+ a text encoder pre-trained in the desired language. Currently this script supports the following vision
+ and text models:
+ Vision models: ViT (https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
+ Text models: BERT, RoBERTa (https://huggingface.co/models?filter=masked-lm)
+ """
+
+ import json
+ import logging
+ import os
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import torch
+ from torchvision.datasets import VisionDataset
+ from torchvision.io import ImageReadMode, read_image
+ from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
+ from torchvision.transforms.functional import InterpolationMode
+ from torchvision.transforms.transforms import GaussianBlur, RandomAutocontrast, RandomHorizontalFlip
+ from tqdm import tqdm
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import transformers
+ from flax import jax_utils
+ from flax.jax_utils import unreplicate
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
+ from modeling_hybrid_clip import FlaxHybridCLIP
+ from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
+
+
+ logger = logging.getLogger(__name__)
+
+ # Cache the result
+ has_tensorboard = is_tensorboard_available()
+ if has_tensorboard:
+     try:
+         from flax.metrics.tensorboard import SummaryWriter
+     except ImportError as ie:
+         has_tensorboard = False
+         print(f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}")
+
+ else:
+     print(
+         "Unable to display metrics through TensorBoard because the package is not installed. "
+         "Please run `pip install tensorboard` to enable."
+     )
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     text_model_name_or_path: str = field(
+         metadata={
+             "help": "The text model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     vision_model_name_or_path: str = field(
+         metadata={
+             "help": "The vision model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     from_pt: bool = field(
+         default=True,
+         metadata={"help": "Whether to load the text and vision model using PyTorch checkpoints."},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
+     train_file: Optional[str] = field(
+         default=None, metadata={"help": "The input training data file (a jsonlines file)."}
+     )
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
+     )
+     max_seq_length: Optional[int] = field(
+         default=72,
+         metadata={
+             "help": "The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+             "value if set."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+
+     def __post_init__(self):
+         if self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension == "json", "`train_file` should be a json file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension == "json", "`validation_file` should be a json file."
+
+
+ # We use torchvision for faster image pre-processing.
+ # We need to ensure faster processing speed as it can become a bottleneck on TPU.
+ class Transform(torch.nn.Module):
+     def __init__(self, image_size):
+         super().__init__()
+         self.transforms = torch.nn.Sequential(
+             Resize([image_size], interpolation=InterpolationMode.BICUBIC),
+             CenterCrop(image_size),
+             GaussianBlur(3, sigma=(0.05, 0.2)),
+             RandomAutocontrast(p=0.5),
+             RandomHorizontalFlip(p=0.5),
+             ConvertImageDtype(torch.float),
+             Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         with torch.no_grad():
+             x = self.transforms(x)
+         return x
+
+
+ class ImageTextDataset(VisionDataset):
+     """
+     Dataset for loading image-text data for tasks like CLIP training and image captioning.
+
+     Args:
+         root (string): The root path where the dataset is stored.
+         file_path (string): Path to the file containing the image paths and associated captions.
+             The expected format is jsonlines, where each line is a json object containing two keys:
+             `image_path`: The path to the image.
+             `captions`: An `array` of captions.
+         captions_per_image (int, optional): Number of captions to keep per image.
+         transform (callable, optional): A function/transform that takes in a PIL image
+             and returns a transformed version. E.g., ``transforms.ToTensor``
+         target_transform (callable, optional): A function/transform that takes in the
+             target and transforms it.
+         transforms (callable, optional): A function/transform that takes an input sample and its target as entry
+             and returns a transformed version.
+     """
+
+     def __init__(
+         self,
+         root: str,
+         file_path: str,
+         captions_per_image=2,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         transforms: Optional[Callable] = None,
+     ):
+         super().__init__(root, transforms, transform, target_transform)
+
+         with open(file_path, "r") as f:
+             examples = [json.loads(line) for line in f.readlines()]
+
+         self.captions = []
+         self.image_paths = []
+
+         for example in examples:
+             self.captions.extend(example["captions"][:captions_per_image])
+             self.image_paths.extend([example["image_path"]] * captions_per_image)
+
+     def _load_image(self, idx: int):
+         path = self.image_paths[idx]
+         return read_image(path, mode=ImageReadMode.RGB)
+
+     def _load_target(self, idx):
+         return self.captions[idx]
+
+     def __getitem__(self, index: int):
+         image = self._load_image(index)
+         target = self._load_target(index)
+
+         if self.transforms is not None:
+             image, target = self.transforms(image, target)
+
+         return image, target
+
+     def __len__(self) -> int:
+         return len(self.captions)
+
+
+ class TrainState(train_state.TrainState):
+     dropout_rng: jnp.ndarray
+
+     def replicate(self):
+         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+
+
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+     summary_writer.scalar("train_time", train_time, step)
+
+     train_metrics = get_metrics(train_metrics)
+     for key, vals in train_metrics.items():
+         tag = f"train_{key}"
+         for i, val in enumerate(vals):
+             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+     for metric_name, value in eval_metrics.items():
+         summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+ def create_learning_rate_fn(
+     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ ) -> Callable[[int], jnp.array]:
+     """Returns a linear warmup, linear decay learning rate function."""
+     steps_per_epoch = train_ds_size // train_batch_size
+     num_train_steps = steps_per_epoch * num_train_epochs
+     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+     decay_fn = optax.linear_schedule(
+         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+     )
+     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+     return schedule_fn
+
+
+ def main():
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Setup logging, we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         transformers.utils.logging.set_verbosity_error()
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     if model_args.tokenizer_name:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     elif model_args.text_model_name_or_path:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     else:
+         raise ValueError(
+             "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
+             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+         )
+
+     model = FlaxHybridCLIP.from_text_vision_pretrained(
+         model_args.text_model_name_or_path,
+         model_args.vision_model_name_or_path,
+         seed=training_args.seed,
+         dtype=getattr(jnp, model_args.dtype),
+         text_from_pt=model_args.from_pt,
+         vision_from_pt=model_args.from_pt,
+     )
+     config = model.config
+     # Set seed for torch dataloaders
+     set_seed(training_args.seed)
+
+     # Initialize torchvision transforms and jit them for faster processing
+     preprocess = Transform(config.vision_config.image_size)
+     preprocess = torch.jit.script(preprocess)
+
+     # Initialize the image-text datasets
+     train_dataset = ImageTextDataset(
+         data_args.data_dir,
+         data_args.train_file,
+         captions_per_image=1,
+         transform=preprocess,
+     )
+
+     eval_dataset = ImageTextDataset(
+         data_args.data_dir,
+         data_args.validation_file,
+         captions_per_image=1,
+         transform=preprocess,
+     )
+
+     # Store some constants
+     num_epochs = int(training_args.num_train_epochs)
+     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+     steps_per_epoch = len(train_dataset) // train_batch_size
+     total_train_steps = steps_per_epoch * num_epochs
+
+     # Use a collate function to tokenize the text and convert the processed images to numpy
+     def collate_fn(examples):
+         pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
+         captions = [example[1] for example in examples]
+         inputs = tokenizer(
+             captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np", truncation=True
+         )
+
+         batch = {
+             "pixel_values": pixel_values,
+             "input_ids": inputs["input_ids"],
+             "attention_mask": inputs["attention_mask"],
+         }
+
+         return batch
+
+     # Create data loaders
+     train_loader = torch.utils.data.DataLoader(
+         train_dataset,
+         batch_size=train_batch_size,
+         shuffle=True,
+         num_workers=data_args.preprocessing_num_workers,
+         persistent_workers=True,
+         drop_last=True,
+         collate_fn=collate_fn,
+     )
+
+     eval_loader = torch.utils.data.DataLoader(
+         eval_dataset,
+         batch_size=eval_batch_size,
+         shuffle=False,
+         num_workers=data_args.preprocessing_num_workers,
+         persistent_workers=True,
+         drop_last=True,
+         collate_fn=collate_fn,
+     )
+
+     # Enable tensorboard only on the master node
+     if has_tensorboard and jax.process_index() == 0:
+         summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     rng, dropout_rng = jax.random.split(rng)
+
+     # Create learning rate schedule
+     linear_decay_lr_schedule_fn = create_learning_rate_fn(
+         len(train_dataset),
+         train_batch_size,
+         training_args.num_train_epochs,
+         training_args.warmup_steps,
+         training_args.learning_rate,
+     )
+
+     # Create the AdamW optimizer
+     adamw = optax.adamw(
+         learning_rate=linear_decay_lr_schedule_fn,
+         b1=training_args.adam_beta1,
+         b2=training_args.adam_beta2,
+         eps=training_args.adam_epsilon,
+         weight_decay=training_args.weight_decay,
+     )
+
+     # Setup train state
+     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+
+     def cross_entropy(logits, axis):
+         logprobs = jax.nn.log_softmax(logits, axis=axis)
+         nll = jnp.diag(logprobs)
+         ce = -jnp.mean(nll)
+         return ce
+
+     def clip_loss(similarity):
+         loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
+         return loss
+
+     # Define gradient update step fn
+     def train_step(state, batch):
+         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+         def compute_loss(params):
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = clip_loss(logits)
+             return loss
+
+         grad_fn = jax.value_and_grad(compute_loss)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+
+         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+         return new_state, metrics
+
+     # Define eval fn
+     def eval_step(params, batch):
+         logits = model(**batch, params=params, train=False)[0]
+         loss = clip_loss(logits)
+
+         # summarize metrics
+         metrics = {"loss": loss}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+         return metrics
+
+     # Create parallel versions of the train and eval step
+     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+     p_eval_step = jax.pmap(eval_step, "batch")
+
+     # Replicate the train state on each device
+     state = state.replicate()
+
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {len(train_dataset)}")
+     logger.info(f"  Num Epochs = {num_epochs}")
+     logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
+     logger.info(f"  Total optimization steps = {total_train_steps}")
+
+     train_time = 0
+     # Create sampling rng
+     rng, input_rng = jax.random.split(rng)
+
+     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+     for epoch in epochs:
+         # ======================== Training ================================
+         train_start = time.time()
+
+         # Create sampling rng
+         rng, input_rng = jax.random.split(rng)
+         train_metrics = []
+
+         steps_per_epoch = len(train_dataset) // train_batch_size
+         train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+         # train
+         for batch in train_loader:
+             batch = shard(batch)
+             state, train_metric = p_train_step(state, batch)
+             train_metrics.append(train_metric)
+
+             train_step_progress_bar.update(1)
+
+         train_time += time.time() - train_start
+
+         train_metric = unreplicate(train_metric)
+
+         train_step_progress_bar.close()
+         epochs.write(
+             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+         )
+
+         # ======================== Evaluating ==============================
+         eval_metrics = []
+         eval_steps = len(eval_dataset) // eval_batch_size
+         eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
+         for batch in eval_loader:
+             # Model forward
+             batch = shard(batch)
+             metrics = p_eval_step(state.params, batch)
+             eval_metrics.append(metrics)
+
+             eval_step_progress_bar.update(1)
+
+         # normalize eval metrics
+         eval_metrics = get_metrics(eval_metrics)
+
+         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+         # Print metrics and update progress bar
+         eval_step_progress_bar.close()
+         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
+         epochs.write(desc)
+         epochs.desc = desc
+
+         # Save metrics
+         if has_tensorboard and jax.process_index() == 0:
+             cur_step = epoch * (len(train_dataset) // train_batch_size)
+             write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+
+         # Save checkpoint after each epoch and push checkpoint to the hub
+         if jax.process_index() == 0:
+             params = jax.device_get(unreplicate(state.params))
+             model.save_pretrained(
+                 training_args.output_dir,
+                 params=params,
+                 push_to_hub=training_args.push_to_hub,
+                 commit_message=f"Saving weights and logs of epoch {epoch+1}",
+             )
+
+
+ if __name__ == "__main__":
+     main()
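
For reference, the `ImageTextDataset` above reads `--train_file` / `--validation_file` as jsonlines: one JSON object per line with an `image_path` and a list of `captions`. A minimal sketch of producing such a file follows; the output file name, image paths, and caption strings are hypothetical.

```python
# Minimal sketch of the jsonlines format consumed by ImageTextDataset.
# Each line is a JSON object with an "image_path" and a "captions" array;
# the concrete paths and captions below are made up for illustration.
import json

examples = [
    {"image_path": "images/case_0001.jpg", "captions": ["Chest X-ray, frontal view."]},
    {"image_path": "images/case_0002.jpg", "captions": ["Axial CT slice of the abdomen."]},
]

with open("train.json", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```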
src/medclip/test_clip.ipynb ADDED
The diff for this file is too large to render. See raw diff
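
Since the test_clip.ipynb diff cannot be rendered here, the sketch below shows the kind of check such a notebook could run against the checkpoint saved by run_medclip.py: load the model, embed one image and a few candidate captions, and compare similarities. It assumes `FlaxHybridCLIP` exposes `from_pretrained`, `get_image_features`, and `get_text_features` as in the upstream `hybrid_clip` example, that it is run from `src/medclip`, and the image path and caption strings are hypothetical.

```python
# Hedged sketch: score a few hypothetical captions against one image with a trained checkpoint.
import jax.numpy as jnp
from torchvision.io import ImageReadMode, read_image
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP
from run_medclip import Transform  # reuse the training-time preprocessing

model = FlaxHybridCLIP.from_pretrained("./snapshots/vision_augmented_biobert")
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")

# Preprocess a single image into the channels-last layout used during training.
preprocess = Transform(model.config.vision_config.image_size)
image = read_image("sample_image.jpg", mode=ImageReadMode.RGB)  # hypothetical file
pixel_values = jnp.array(preprocess(image).permute(1, 2, 0).numpy()[None, ...])

captions = ["Chest X-ray with no acute findings.", "Axial CT slice of the abdomen."]
inputs = tokenizer(captions, max_length=72, padding="max_length", truncation=True, return_tensors="np")

image_embeds = model.get_image_features(pixel_values=pixel_values)
text_embeds = model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Cosine similarity between the image and each candidate caption.
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
print(jnp.matmul(text_embeds, image_embeds.T))
```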