"""Gradio demo that runs PhaseNet phase picking on uploaded waveform files.

At import time this script builds the PhaseNet ``UNet`` in prediction mode,
restores the latest checkpoint found under ``./model/190703-214543`` into a
TF1-style session, and launches a Gradio interface around ``predict``.
"""
import os

import gradio as gr
import numpy as np
import obspy
import pandas as pd
import tensorflow as tf

from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Columns kept from each pick; shared by the CSV, the returned DataFrame and
# the Gradio table headers so the three cannot drift apart.
PICK_COLUMNS = ["phase_time", "phase_index", "phase_score", "phase_type"]
PICKS_CSV = "picks.csv"

## load model
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint("./model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)


def normalize(vec):
    """Standardize ``vec`` to zero mean and unit variance along axis 1.

    Slices whose standard deviation is exactly zero are divided by 1 instead,
    so constant input yields zeros rather than NaN/inf.
    """
    mu = np.mean(vec, axis=1, keepdims=True)
    std = np.std(vec, axis=1, keepdims=True)
    std[std == 0] = 1.0
    return (vec - mu) / std


def reshape_input(vec):
    """Expand ``vec`` to the 4-D layout fed to the model.

    A 2-D array gains a leading axis and a singleton axis in position 2;
    a 3-D array gains only a leading axis. Any other rank is returned
    unchanged. ``predict`` treats the result as (nb, nt, nsta, nch).
    """
    if vec.ndim == 2:
        return vec[np.newaxis, :, np.newaxis, :]
    if vec.ndim == 3:
        return vec[np.newaxis, :, :, :]
    return vec


# inference fn
def predict(inputs):
    """Pick phases in every uploaded waveform file.

    Args:
        inputs: Iterable of uploaded files. Each item is either a path
            string or an object exposing the path as ``.name`` (the form
            Gradio hands over depends on its version). ``None`` is treated
            as "no files".

    Returns:
        A ``(picks, csv_path)`` tuple: ``picks`` is a DataFrame with the
        ``PICK_COLUMNS`` columns (empty if nothing was picked) and
        ``csv_path`` is the path of the CSV file the same table was
        written to.
    """
    picks = []
    for item in inputs or []:
        path = item if isinstance(item, str) else item.name
        stream = obspy.read(path)

        begin_time = min(tr.stats.starttime for tr in stream)
        end_time = max(tr.stats.endtime for tr in stream)
        # Pad shorter traces with zeros: without pad=True, trimming to the
        # outermost bounds is a no-op and traces of unequal length cannot be
        # stacked into one rectangular array below.
        stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)

        vec = np.asarray([tr.data for tr in stream]).T
        vec = reshape_input(vec)  # (nb, nt, nsta, nch)
        vec = normalize(vec)

        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        file_picks = extract_picks(
            preds,
            begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")],
        )
        # Keep only the columns exposed to the user.
        picks.extend(
            {column: pick[column] for column in PICK_COLUMNS}
            for pick in file_picks
        )

    # Passing the columns explicitly keeps the header row even with no picks.
    picks = pd.DataFrame(picks, columns=PICK_COLUMNS)
    picks.to_csv(PICKS_CSV, index=False)
    return picks, PICKS_CSV


# gradio components
inputs = gr.File(file_count="multiple")
outputs = [gr.Dataframe(headers=PICK_COLUMNS), gr.File()]

gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
).launch()