"""Inference wrapper around the scenic Vid2Seq dense-video-captioning model.

`ScenicModel` loads a Vid2Seq checkpoint once at construction time and, on
each call, runs pmapped inference on one validation batch and returns the
decoded "[start s, end s] caption" strings.  `ScenicCall` is a thin generic
runner that performs the same JAX/TF host setup before invoking an arbitrary
`main`.
"""

import functools
from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from clu import platform
import flax.linen as nn
import jax
from ml_collections import config_flags
import tensorflow as tf
import sys, os
from pathlib import Path

# Make the vendored `scenic` checkout importable relative to this file.
sys.path.append(str(Path(__file__).parent.parent.parent / "scenic"))

# NOTE(review): this stdlib `import logging` shadows the earlier
# `from absl import logging`; kept as-is because removing it would reroute
# all logging.info() calls below through absl instead of the root logger.
import logging
import flax
from flax import jax_utils
from flax.training import checkpoints
from scenic.projects.vid2seq import models, trainer
from scenic.train_lib_deprecated import train_utils
# NOTE(review): shadows `from absl import app` above; kept in case the
# scenic app module has import-time side effects.
from scenic import app
import ml_collections
import numpy as np
import jax.numpy as jnp
from scenic.projects.vid2seq.datasets.dense_video_captioning_tfrecord_dataset import get_datasets
from scenic.projects.vid2seq import dvc_eval

# Fixed-width buffers used when converting strings to uint8 arrays for eval.
MAX_CAPTION_STR_LEN = 200
MAX_KEY_STR_LEN = 400


class ScenicModel:
  """Loads a Vid2Seq model and exposes batch captioning via `__call__`."""

  def __init__(self, flags):
    """Performs host setup and model initialization.

    Args:
      flags: an absl-style FLAGS object providing `config`, `workdir`,
        `data_dir` and `ckpt_dir`.
    """
    self.FLAGS = flags
    jax.config.config_with_absl()
    # `_run_main` discards its positional `argv`, so the extra argument
    # passed here is intentionally ignored.
    run = functools.partial(self._run_main, main=self._init_model)
    run(self._init_model)

  def _run_main(self, argv, *, main):
    """Runs the `main` method after some initial setup."""
    del argv
    # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')
    # Enable wrapping of all module calls in a named_call for easier
    # profiling:
    nn.enable_named_call()
    logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count())
    logging.info('JAX devices: %r', jax.devices())
    # Add a note so that we can tell which task is which JAX host.
    # (task 0 is not guaranteed to be the host 0)
    platform.work_unit().set_task_status(
        f'host_id: {jax.process_index()}, host_count: {jax.process_count()}')
    if jax.process_index() == 0:
      platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                           self.FLAGS.workdir, 'Workdir')
    # Point the config at the CLI-provided data and checkpoint locations.
    self.FLAGS.config.dataset_configs.base_dir = self.FLAGS.data_dir
    self.FLAGS.config.init_from.checkpoint_path = self.FLAGS.ckpt_dir
    rng = jax.random.PRNGKey(self.FLAGS.config.rng_seed)
    logging.info('RNG: %s', rng)
    writer = metric_writers.create_default_writer(
        self.FLAGS.workdir,
        just_logging=jax.process_index() > 0,
        asynchronous=True)
    return main(rng=rng, config=self.FLAGS.config,
                workdir=self.FLAGS.workdir, writer=writer)

  def _init_model(self, rng: jnp.ndarray,
                  config: ml_collections.ConfigDict,
                  workdir: str,
                  writer: metric_writers.MetricWriter):
    """Builds datasets, restores the checkpoint, and compiles inference.

    Stores the replicated train state, the pmapped inference step, the
    tokenizer, the config and the data RNG on `self` for use in `__call__`.
    """
    data_rng, rng = jax.random.split(rng)
    dataset_dict = get_datasets(config, data_rng=data_rng)
    datasets_metadata = {
        name: ds.meta_data for name, ds in dataset_dict.items()
    }
    all_datasets = []
    all_datasets_num_train_examples = []
    for name, metadata in datasets_metadata.items():
      all_datasets.append(name)
      all_datasets_num_train_examples.append(
          metadata.get('num_train_examples', 0))
    # The first dataset's metadata is used to build the model.
    dataset = dataset_dict[all_datasets[0]]
    model_cls = models.DenseVideoCaptioningModel
    model = model_cls(config, dataset.meta_data)
    # Restore from checkpoint; the returned start step is unused here.
    train_state, _ = trainer.init_state(model, dataset, config, workdir, rng)
    self.train_state = jax_utils.replicate(train_state)
    logging.info('Number of processes is %s', jax.process_count())
    del rng
    self.infer_step_pmapped = jax.pmap(
        functools.partial(
            trainer.infer_step,
            model=model,
            config=config,
            debug=config.debug_eval),
        axis_name='batch',
    )
    self.tokenizer = trainer.get_tokenizer(config)
    self.config = config
    self.data_rng = data_rng

  def __call__(self, data_dir=None):
    """Captions one validation batch of the "youcook" dataset.

    Args:
      data_dir: unused; kept for interface compatibility.

    Returns:
      A list of strings, one per decoded event, formatted as
      "[<start>s, <end>s] <caption>".
    """
    dataset_dict = get_datasets(self.config, data_rng=self.data_rng)
    self.iterator = dataset_dict["youcook"].valid_iter['validation']
    batch = next(self.iterator)
    train_state = train_utils.sync_model_state_across_replicas(
        self.train_state)
    # NOTE(review): `eval_packs` is checked below for duplicate videos but
    # is never populated, so the dedup guard is currently a no-op — likely
    # intended to track `keys`; confirm before changing.
    eval_packs = {}
    keys = []
    eval_pack = {
        'gts':
            dvc_eval.convert_strings_to_uint8_arrays(
                batch['caption_strings'], MAX_CAPTION_STR_LEN),
        'key':
            dvc_eval.convert_strings_to_uint8_arrays(
                batch['videoid'], MAX_KEY_STR_LEN),
        'batch_mask':
            batch['batch_mask'],
        'duration':
            batch['duration'],
        'gts_start':
            batch['timestamp_start'],
        'gts_end':
            batch['timestamp_end'],
        'split':
            batch['split'] if 'split' in batch else np.ones_like(
                batch['timestamp_start']),
    }
    # Drop string-valued keys the pmapped step cannot handle.
    to_del = ['caption_strings', 'key', 'videoid', 'timestamp_start',
              'timestamp_end', 'split']  # 'duration',
    for x in to_del:
      if x in batch:
        del batch[x]
    _, preds = self.infer_step_pmapped(train_state, batch)
    eval_pack['pred'] = preds
    # Fold the (devices, per-device-batch) leading axes into one.
    eval_pack = jax.tree_map(
        lambda x: x.reshape((np.prod(x.shape[:2]),) + x.shape[2:]), eval_pack)
    vocabulary_size = self.config.dataset_configs.vocabulary_size
    format_outputs = []
    for i, valid in enumerate(eval_pack['batch_mask']):
      # BUG FIX: previously printed the constant 0 instead of the index i.
      print("===============video[", str(i), "]====================")
      if valid:
        key = dvc_eval.convert_uint8_array_to_string(eval_pack['key'][i])
        if key in eval_packs:  # redundant video
          continue
        keys.append(key)
        pred, pred_timestamps = [], []
        # get indexes in the predicted seq that delimit the pred segments
        indexes = [
            j for j in range(len(eval_pack['pred'][i]) - 1)
            if eval_pack['pred'][i][j] >= vocabulary_size and
            eval_pack['pred'][i][j + 1] >= vocabulary_size
        ]  # pylint: disable=g-complex-comprehension
        last_processed = -2
        order = self.config.dataset_configs.order
        # iterate over predicted segments and decode them
        for j in range(len(indexes)):
          if indexes[j] == last_processed + 1:  # 3 timestamps != 2 events
            continue
          # get predicted tokens and transform to string
          if order == 'ld':
            start_idx = indexes[j] + 2
            end_idx = indexes[j + 1] if j < len(indexes) - 1 else len(
                eval_pack['pred'][i])
          else:
            start_idx = indexes[j - 1] + 2 if j > 0 else 0
            end_idx = indexes[j]
          pred_seq = [
              int(eval_pack['pred'][i][k]) for k in range(start_idx, end_idx)
          ]
          pred_text = trainer.decode_tokens(pred_seq, self.tokenizer,
                                            vocabulary_size)
          # Map time-bin tokens (values >= vocabulary_size) back to seconds.
          num_bins = 100  # from config
          max_offset = num_bins - 1
          pred_time = [
              (int(eval_pack['pred'][i][indexes[j]]) - vocabulary_size) *
              eval_pack['duration'][i] / max_offset,
              (int(eval_pack['pred'][i][indexes[j] + 1]) - vocabulary_size) *
              eval_pack['duration'][i] / max_offset
          ]
          last_processed = indexes[j]
          pred.append(pred_text)
          pred_timestamps.append(pred_time)
          # round to 2 decimal places; duration appears to be in
          # microseconds, hence the 1e6 divisor — TODO confirm.
          format_output = "[{x}s, {y}s] ".format(
              x=np.around(pred_time[0][0] / 1000000, decimals=2),
              y=np.around(pred_time[1][0] / 1000000, decimals=2))
          format_output += pred_text
          format_outputs.append(format_output)
    print(format_outputs)
    print("===============================================")
    return format_outputs


class ScenicCall:
  """Generic runner: performs host setup, then invokes a user `main`."""

  def __init__(self, main, flags):
    """Args:
      main: callable taking (rng, config, workdir, writer) keyword args.
      flags: absl-style FLAGS object providing `config`, `workdir`,
        `data_dir`.
    """
    self.main = main
    self.FLAGS = flags

  def __call__(self):
    return self.run()

  def run(self):
    # Provide access to --jax_backend_target and --jax_xla_backend flags.
    jax.config.config_with_absl()
    # `_run_main` discards its positional `argv`; the extra argument passed
    # here is intentionally ignored.
    run = functools.partial(self._run_main, main=self.main)
    return run(self.main)

  def _run_main(self, argv, *, main):
    """Runs the `main` method after some initial setup."""
    del argv
    # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')
    # Enable wrapping of all module calls in a named_call for easier
    # profiling:
    nn.enable_named_call()
    logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count())
    logging.info('JAX devices: %r', jax.devices())
    # Add a note so that we can tell which task is which JAX host.
    # (task 0 is not guaranteed to be the host 0)
    platform.work_unit().set_task_status(
        f'host_id: {jax.process_index()}, host_count: {jax.process_count()}')
    if jax.process_index() == 0:
      platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                           self.FLAGS.workdir, 'Workdir')
    # Unlike ScenicModel._run_main, this variant does not override the
    # checkpoint path — only the dataset base dir.
    self.FLAGS.config.dataset_configs.base_dir = self.FLAGS.data_dir
    rng = jax.random.PRNGKey(self.FLAGS.config.rng_seed)
    logging.info('RNG: %s', rng)
    writer = metric_writers.create_default_writer(
        self.FLAGS.workdir,
        just_logging=jax.process_index() > 0,
        asynchronous=True)
    return main(rng=rng, config=self.FLAGS.config,
                workdir=self.FLAGS.workdir, writer=writer)