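"""Command-line entry point for training, validating, testing, and running
prediction with the UNet denoising model defined in ``model.py``.

The mode is selected with ``--mode`` (train/valid/test/pred); data readers
come from ``data_reader.py`` and post-processing helpers from ``util.py``.
"""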
# import warnings
# warnings.filterwarnings('ignore', category=FutureWarning)
import numpy as np
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import argparse
import os
import time
import logging
from model import UNet
from data_reader import *
from util import *
from tqdm import tqdm
import multiprocessing
from functools import partial
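
# Note: argparse's ``type=bool`` treats any non-empty string (including
# "False") as True, so ``--summary False`` would silently parse as True.
# ``str2bool`` below is a small helper added here as a common workaround;
# the ``--summary`` flag uses it in place of ``type=bool``.
def str2bool(v):
    """Parse a command-line string into a boolean."""
    if isinstance(v, bool):
        return v
    return str(v).strip().lower() in ("true", "t", "yes", "y", "1")
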
def read_args():
    """Parses command-line arguments and returns them."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode",
                        default="train",
                        help="train/valid/test/pred (default: train)")
    parser.add_argument("--epochs",
                        default=10,
                        type=int,
                        help="number of epochs (default: 10)")
    parser.add_argument("--batch_size",
                        default=20,
                        type=int,
                        help="batch size (default: 20)")
    parser.add_argument("--learning_rate",
                        default=0.001,
                        type=float,
                        help="learning rate (default: 0.001)")
    parser.add_argument("--decay_step",
                        default=-1,
                        type=int,
                        help="decay step; -1 means one epoch of steps (default: -1)")
    parser.add_argument("--decay_rate",
                        default=0.9,
                        type=float,
                        help="decay rate (default: 0.9)")
    parser.add_argument("--momentum",
                        default=0.9,
                        type=float,
                        help="momentum (default: 0.9)")
    parser.add_argument("--filters_root",
                        default=8,
                        type=int,
                        help="filters root (default: 8)")
    parser.add_argument("--depth",
                        default=6,
                        type=int,
                        help="depth (default: 6)")
    parser.add_argument("--kernel_size",
                        nargs="+",
                        type=int,
                        default=[3, 3],
                        help="kernel size (default: [3, 3])")
    parser.add_argument("--pool_size",
                        nargs="+",
                        type=int,
                        default=[2, 2],
                        help="pool size (default: [2, 2])")
    parser.add_argument("--drop_rate",
                        default=0,
                        type=float,
                        help="dropout rate (default: 0)")
    parser.add_argument("--dilation_rate",
                        nargs="+",
                        type=int,
                        default=[1, 1],
                        help="dilation rate (default: [1, 1])")
    parser.add_argument("--loss_type",
                        default="cross_entropy",
                        help="loss type: cross_entropy, IOU, mean_squared (default: cross_entropy)")
    parser.add_argument("--weight_decay",
                        default=0,
                        type=float,
                        help="weight decay (default: 0)")
    parser.add_argument("--optimizer",
                        default="adam",
                        help="optimizer: adam, momentum (default: adam)")
    parser.add_argument("--summary",
                        default=True,
                        type=str2bool,
                        help="write TensorBoard summaries (default: True)")
    parser.add_argument("--class_weights",
                        nargs="+",
                        default=[1, 1],
                        type=float,
                        help="class weights (default: [1, 1])")
    parser.add_argument("--log_dir",
                        default="log",
                        help="TensorBoard log directory (default: log)")
    parser.add_argument("--model_dir",
                        default=None,
                        help="checkpoint directory")
    parser.add_argument("--num_plots",
                        default=10,
                        type=int,
                        help="number of training results to plot (default: 10)")
    parser.add_argument("--input_length",
                        default=None,
                        type=int,
                        help="input length")
    parser.add_argument("--sampling_rate",
                        default=100,
                        type=int,
                        help="sampling rate of prediction data in Hz (default: 100)")
    parser.add_argument("--train_signal_dir",
                        default="./Dataset/train/",
                        help="input file directory (default: ./Dataset/train/)")
    parser.add_argument("--train_signal_list",
                        default="./Dataset/train.csv",
                        help="input csv file (default: ./Dataset/train.csv)")
    parser.add_argument("--train_noise_dir",
                        default="./Dataset/train/",
                        help="input file directory (default: ./Dataset/train/)")
    parser.add_argument("--train_noise_list",
                        default="./Dataset/train.csv",
                        help="input csv file (default: ./Dataset/train.csv)")
    parser.add_argument("--valid_signal_dir",
                        default="./Dataset/",
                        help="input file directory (default: ./Dataset/)")
    parser.add_argument("--valid_signal_list",
                        default=None,
                        help="input csv file")
    parser.add_argument("--valid_noise_dir",
                        default="./Dataset/",
                        help="input file directory (default: ./Dataset/)")
    parser.add_argument("--valid_noise_list",
                        default=None,
                        help="input csv file")
    parser.add_argument("--data_dir",
                        default="./Dataset/pred/",
                        help="input file directory (default: ./Dataset/pred/)")
    parser.add_argument("--data_list",
                        default="./Dataset/pred.csv",
                        help="input csv file (default: ./Dataset/pred.csv)")
    parser.add_argument("--output_dir",
                        default=None,
                        help="output directory")
    parser.add_argument("--fpred",
                        default="preds.npz",
                        help="output file name of test data")
    parser.add_argument("--plot_figure",
                        action="store_true",
                        help="if set, plot figures for test")
    parser.add_argument("--save_result",
                        action="store_true",
                        help="if set, save results for test")
    args = parser.parse_args()
    return args
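
# Example invocations (assuming this script is saved as main.py; the paths
# are placeholders for your own dataset layout):
#   python main.py --mode=train --train_signal_dir=./Dataset/train/ \
#       --train_signal_list=./Dataset/train.csv --epochs=10 --batch_size=20
#   python main.py --mode=pred --model_dir=<checkpoint_dir> \
#       --data_dir=./Dataset/pred/ --data_list=./Dataset/pred.csv --save_result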
def set_config(args, data_reader):
    config = Config()
    config.X_shape = data_reader.X_shape
    config.n_channel = config.X_shape[-1]
    config.Y_shape = data_reader.Y_shape
    config.n_class = config.Y_shape[-1]
    config.depths = args.depth
    config.filters_root = args.filters_root
    config.kernel_size = args.kernel_size
    config.pool_size = args.pool_size
    config.dilation_rate = args.dilation_rate
    config.batch_size = args.batch_size
    config.class_weights = args.class_weights
    config.loss_type = args.loss_type
    config.weight_decay = args.weight_decay
    config.optimizer = args.optimizer
    config.learning_rate = args.learning_rate
    if (args.decay_step == -1) and (args.mode == 'train'):
        # Default: decay once per epoch (the number of batches per epoch).
        config.decay_step = data_reader.n_signal // args.batch_size
    else:
        config.decay_step = args.decay_step
    config.decay_rate = args.decay_rate
    config.momentum = args.momentum
    config.summary = args.summary
    config.drop_rate = args.drop_rate
    return config
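
# Illustrative example (hypothetical numbers): with 5000 training signals and
# batch_size=20, the default decay_step is 5000 // 20 = 250, so the learning
# rate is multiplied by decay_rate once per epoch.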
def train_fn(args, data_reader, data_reader_valid=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    log_dir = os.path.join(args.log_dir, current_time)
    logging.info("Training log: {}".format(log_dir))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    figure_dir = os.path.join(log_dir, 'figures')
    if not os.path.exists(figure_dir):
        os.makedirs(figure_dir)

    config = set_config(args, data_reader)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        batch = data_reader.dequeue(args.batch_size)
        if data_reader_valid is not None:
            batch_valid = data_reader_valid.dequeue(args.batch_size)

    model = UNet(config)
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        if args.model_dir is not None:
            logging.info("restoring models...")
            latest_check_point = tf.train.latest_checkpoint(args.model_dir)
            saver.restore(sess, latest_check_point)
            # Note: resumes with a hard-coded learning rate of 0.01, not args.learning_rate.
            model.reset_learning_rate(sess, learning_rate=0.01, global_step=0)

        threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
        if data_reader_valid is not None:
            threads_valid = data_reader_valid.start_threads(sess, n_threads=multiprocessing.cpu_count())

        flog = open(os.path.join(log_dir, 'loss.log'), 'w')
        total_step = 0
        mean_loss = 0
        pool = multiprocessing.Pool(2)
        for epoch in range(args.epochs):
            progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size),
                               desc="{}: ".format(log_dir.split("/")[-1]))
            for step in progressbar:
                X_batch, Y_batch = sess.run(batch)
                loss_batch = model.train_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
                if epoch < 1:
                    # During the first epoch, report the current batch loss only.
                    mean_loss = loss_batch
                else:
                    # From the second epoch on, keep an incremental running mean.
                    total_step += 1
                    mean_loss += (loss_batch - mean_loss) / total_step
                progressbar.set_description("{}: epoch={}, loss={:.6f}, mean loss={:.6f}".format(
                    log_dir.split("/")[-1], epoch, loss_batch, mean_loss))
                flog.write("Epoch: {}, step: {}, loss: {}, mean loss: {}\n".format(
                    epoch, step // args.batch_size, loss_batch, mean_loss))
            saver.save(sess, os.path.join(log_dir, "model_{}.ckpt".format(epoch)))

            ## valid
            if data_reader_valid is not None:
                mean_loss_valid = 0
                total_step_valid = 0
                progressbar = tqdm(range(0, data_reader_valid.n_signal, args.batch_size), desc="Valid: ")
                for step in progressbar:
                    X_batch, Y_batch = sess.run(batch_valid)
                    loss_batch, preds_batch = model.valid_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
                    total_step_valid += 1
                    mean_loss_valid += (loss_batch - mean_loss_valid) / total_step_valid
                    progressbar.set_description("Valid: loss={:.6f}, mean loss={:.6f}".format(loss_batch, mean_loss_valid))
                    flog.write("Valid: {}, step: {}, loss: {}, mean loss: {}\n".format(
                        epoch, step // args.batch_size, loss_batch, mean_loss_valid))
                # plot_result(epoch, args.num_plots, figure_dir, preds_batch, X_batch, Y_batch)
                pool.map(partial(plot_result_thread,
                                 epoch=epoch,
                                 preds=preds_batch,
                                 X=X_batch,
                                 Y=Y_batch,
                                 figure_dir=figure_dir),
                         range(args.num_plots))

        flog.close()
        pool.close()
        data_reader.coord.request_stop()
        if data_reader_valid is not None:
            data_reader_valid.coord.request_stop()
        try:
            data_reader.coord.join(threads, stop_grace_period_secs=10, ignore_live_threads=True)
            if data_reader_valid is not None:
                data_reader_valid.coord.join(threads_valid, stop_grace_period_secs=10, ignore_live_threads=True)
        except Exception:
            # Reader threads may already have exited; ignore join errors on shutdown.
            pass
        sess.run(data_reader.queue.close(cancel_pending_enqueues=True))
        if data_reader_valid is not None:
            sess.run(data_reader_valid.queue.close(cancel_pending_enqueues=True))
    return 0
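
# The running-mean updates in the loops above use the incremental identity
#   mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n,
# which equals the arithmetic mean of x_1..x_n but needs no stored history.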
def test_fn(args, data_reader, figure_dir=None, result_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    log_dir = os.path.join(args.log_dir, args.mode, current_time)
    logging.info("{} log: {}".format(args.mode, log_dir))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if args.plot_figure and (figure_dir is None):
        figure_dir = os.path.join(log_dir, 'figures')
        if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)
    if args.save_result and (result_dir is None):
        result_dir = os.path.join(log_dir, 'results')
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

    config = set_config(args, data_reader)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        batch = data_reader.dequeue(args.batch_size)

    model = UNet(config, input_batch=batch, mode='test')
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        logging.info("restoring models...")
        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        saver.restore(sess, latest_check_point)

        threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
        flog = open(os.path.join(log_dir, 'loss.log'), 'w')
        total_step = 0
        mean_loss = 0
        progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size), desc=args.mode)
        if args.plot_figure:
            num_pool = multiprocessing.cpu_count() * 2
        elif args.save_result:
            num_pool = multiprocessing.cpu_count()
        else:
            num_pool = 2
        pool = multiprocessing.Pool(num_pool)
        for step in progressbar:
            if step + args.batch_size >= data_reader.n_signal:
                # Last batch: wait for the reader threads and close the queue.
                for t in threads:
                    t.join()
                sess.run(data_reader.queue.close())
            (loss_batch, preds_batch, X_batch, Y_batch, ratio_batch,
             signal_batch, noise_batch, fname_batch) = model.test_on_batch(sess, summary_writer)
            total_step += 1
            mean_loss += (loss_batch - mean_loss) / total_step
            progressbar.set_description("{}: loss={:.6f}, mean loss={:.6f}".format(args.mode, loss_batch, mean_loss))
            flog.write("step: {}, loss: {}\n".format(step, loss_batch))
            flog.flush()
            pool.map(partial(postprocessing_test,
                             preds=preds_batch,
                             X=X_batch * ratio_batch[:, np.newaxis, np.newaxis, np.newaxis],
                             fname=fname_batch,
                             figure_dir=figure_dir,
                             result_dir=result_dir,
                             signal_FT=signal_batch,
                             noise_FT=noise_batch),
                     range(len(X_batch)))
        flog.close()
        pool.close()
    return 0
def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    if log_dir is None:
        log_dir = os.path.join(args.log_dir, "pred", current_time)
    logging.info("Pred log: %s" % log_dir)
    # logging.info("Dataset size: {}".format(data_reader.num_data))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if args.plot_figure:
        figure_dir = os.path.join(log_dir, 'figures')
        os.makedirs(figure_dir, exist_ok=True)
    if args.save_result:
        result_dir = os.path.join(log_dir, 'results')
        os.makedirs(result_dir, exist_ok=True)

    config = set_config(args, data_reader)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        data_batch = data_reader.dataset(args.batch_size)

    # model = UNet(config, input_batch=batch, mode='pred')
    model = UNet(config, mode='pred')
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        logging.info("restoring models...")
        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        saver.restore(sess, latest_check_point)

        if args.plot_figure:
            num_pool = multiprocessing.cpu_count()
        elif args.save_result:
            num_pool = multiprocessing.cpu_count()
        else:
            num_pool = 2
        # 'spawn' starts workers in fresh interpreters, avoiding problems with
        # forking a process that holds a TensorFlow session.
        multiprocessing.set_start_method('spawn')
        pool = multiprocessing.Pool(num_pool)
        for step in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
            X_batch, ratio_batch, fname_batch = sess.run(data_batch)
            preds_batch = sess.run(model.preds, feed_dict={model.X: X_batch,
                                                           model.drop_rate: 0,
                                                           model.is_training: False})
            pool.map(partial(postprocessing_pred,
                             preds=preds_batch,
                             X=X_batch * ratio_batch[:, np.newaxis, :, np.newaxis],
                             fname=[x.decode() for x in fname_batch],
                             figure_dir=figure_dir,
                             result_dir=result_dir),
                     range(len(X_batch)))
        pool.close()
    return 0
def main(args):
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    coord = tf.train.Coordinator()

    if args.mode == "train":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader(
                signal_dir=args.train_signal_dir,
                signal_list=args.train_signal_list,
                noise_dir=args.train_noise_dir,
                noise_list=args.train_noise_list,
                queue_size=args.batch_size * 2,
                coord=coord)
            if (args.valid_signal_list is not None) and (args.valid_noise_list is not None):
                data_reader_valid = DataReader(
                    signal_dir=args.valid_signal_dir,
                    signal_list=args.valid_signal_list,
                    noise_dir=args.valid_noise_dir,
                    noise_list=args.valid_noise_list,
                    queue_size=args.batch_size * 2,
                    coord=coord)
                logging.info("Dataset size: training %d, validation %d" % (data_reader.n_signal, data_reader_valid.n_signal))
            else:
                data_reader_valid = None
                logging.info("Dataset size: training %d, validation 0" % data_reader.n_signal)
        train_fn(args, data_reader, data_reader_valid)

    elif args.mode in ("valid", "test"):
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader_test(
                signal_dir=args.valid_signal_dir,
                signal_list=args.valid_signal_list,
                noise_dir=args.valid_noise_dir,
                noise_list=args.valid_noise_list,
                queue_size=args.batch_size * 2,
                coord=coord)
            logging.info("Dataset size: {}".format(data_reader.n_signal))
        test_fn(args, data_reader)

    elif args.mode == "pred":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader_pred(
                signal_dir=args.data_dir,
                signal_list=args.data_list,
                sampling_rate=args.sampling_rate)
            logging.info("Dataset size: {}".format(data_reader.n_signal))
        pred_fn(args, data_reader, log_dir=args.output_dir)

    else:
        print("mode should be: train, valid, test, or pred")

    coord.request_stop()
    coord.join()
    return 0
if __name__ == '__main__':
    args = read_args()
    main(args)