import argparse
import logging
import multiprocessing
import os
import time
from functools import partial

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from data_reader import DataReader_pred, normalize_batch
from model import UNet
from util import *  # provides save_results and plot_figures

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def read_args():
    """Parses and returns command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--format", default="numpy", type=str, help="Input data format: numpy or mseed")
    parser.add_argument("--batch_size", default=20, type=int, help="Batch size")
    parser.add_argument("--output_dir", default="output", help="Output directory (default: output)")
    parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
    parser.add_argument("--sampling_rate", default=100, type=int, help="Sampling rate of the input data (Hz)")
    parser.add_argument("--data_dir", default="./Dataset/pred/", help="Input file directory")
    parser.add_argument("--data_list", default="./Dataset/pred.csv", help="Input csv file")
    parser.add_argument("--plot_figure", action="store_true", help="If set, plot figures")
    parser.add_argument("--save_signal", action="store_true", help="If set, save the denoised signal")
    parser.add_argument("--save_noise", action="store_true", help="If set, save the separated noise")
    args = parser.parse_args()
    return args


def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    if log_dir is None:
        # Fall back to a timestamped subdirectory of the output directory.
        # (args has no log_dir attribute, so the original args.log_dir here was a bug.)
        log_dir = os.path.join(args.output_dir, "pred", current_time)
    logging.info("Pred log: %s" % log_dir)
    # logging.info("Dataset size: {}".format(data_reader.num_data))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if args.plot_figure:
        figure_dir = os.path.join(log_dir, 'figures')
        os.makedirs(figure_dir, exist_ok=True)
    if args.save_signal or args.save_noise:
        result_dir = os.path.join(log_dir, 'results')
        os.makedirs(result_dir, exist_ok=True)

    with tf.compat.v1.name_scope('Input_Batch'):
        data_batch = data_reader.dataset(args.batch_size)

    # model = UNet(input_batch=data_batch, mode='pred')
    model = UNet(mode='pred')
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        logging.info(f"Restoring model: {latest_check_point}")
        saver.restore(sess, latest_check_point)

        # Plotting dominates the runtime, so use all cores when figures are requested.
        if args.plot_figure:
            num_pool = multiprocessing.cpu_count()
        else:
            num_pool = 2
        multiprocessing.set_start_method('spawn')
        pool = multiprocessing.Pool(num_pool)
        for _ in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
            X_batch, fname_batch, t0_batch = sess.run(data_batch)
            # Collapse the (batch, channel, station) axes into a single batch axis
            # before normalizing and feeding the time-frequency images to the network.
            nbt, nch, nst, nf, nt, nimg = X_batch.shape
            X_batch_ = np.reshape(X_batch, [nbt * nch * nst, nf, nt, nimg])
            X_batch_ = normalize_batch(X_batch_)
            preds_batch = sess.run(
                model.preds,
                feed_dict={model.X: X_batch_, model.drop_rate: 0, model.is_training: False},
            )
            preds_batch = np.reshape(preds_batch, [nbt, nch, nst, nf, nt, preds_batch.shape[-1]])
            # preds_batch, X_batch, ratio_batch, fname_batch = sess.run(
            #     [model.preds, data_batch[0], data_batch[1], data_batch[2]],
            #     feed_dict={model.drop_rate: 0, model.is_training: False},
            # )
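
            # Note (an assumption, inferred from the save_signal/save_noise flags
            # rather than stated in this file): the last axis of preds_batch is
            # expected to hold the predicted signal and noise masks over the
            # time-frequency image; save_results and plot_figures below consume
            # that layout.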

            if args.save_signal or args.save_noise:
                save_results(
                    preds_batch,
                    X_batch,
                    fname=[x.decode() for x in fname_batch],
                    t0=[x.decode() for x in t0_batch],
                    save_signal=args.save_signal,
                    save_noise=args.save_noise,
                    result_dir=result_dir,
                )

            if args.plot_figure:
                pool.starmap(
                    partial(
                        plot_figures,
                        figure_dir=figure_dir,
                    ),
                    zip(preds_batch, X_batch, [x.decode() for x in fname_batch]),
                )
        pool.close()

    return 0


def main(args):
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    with tf.compat.v1.name_scope('create_inputs'):
        data_reader = DataReader_pred(
            format=args.format,
            signal_dir=args.data_dir,
            signal_list=args.data_list,
            sampling_rate=args.sampling_rate,
        )
        logging.info("Dataset size: {}".format(data_reader.n_signal))
    pred_fn(args, data_reader, log_dir=args.output_dir)

    return 0


if __name__ == '__main__':
    args = read_args()
    main(args)
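
# Example invocation (a sketch, assuming this script is saved as predict.py;
# the checkpoint directory below is an illustrative placeholder, while the data
# paths match the argparse defaults above):
#
#   python predict.py \
#       --model_dir=./model/your_checkpoint_dir \
#       --data_dir=./Dataset/pred/ \
#       --data_list=./Dataset/pred.csv \
#       --output_dir=./output \
#       --save_signal --save_noise --plot_figure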