"""Gradio front-end that runs PhaseNet phase picking on uploaded miniSEED files."""

import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
import obspy
import pandas as pd
import tensorflow as tf
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

## load model
# The graph, session and restored weights are built once at import time and
# shared by every call to `predict`.
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint("./model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)


def normalize(vec):
    """Standardize ``vec`` along axis 1 (zero mean, unit standard deviation).

    Slices whose standard deviation is exactly zero are only mean-centred
    (their divisor is forced to 1.0), which avoids a division by zero.

    Args:
        vec: NumPy array with at least two dimensions.

    Returns:
        A new array of the same shape as ``vec``.
    """
    mu = np.mean(vec, axis=1, keepdims=True)
    std = np.std(vec, axis=1, keepdims=True)
    std[std == 0] = 1.0  # constant slices: keep them at zero instead of NaN
    return (vec - mu) / std


def reshape_input(vec):
    """Add the singleton axes needed to turn ``vec`` into a 4-D array.

    * 2-D ``(a, b)``    -> ``(1, a, 1, b)``
    * 3-D ``(a, b, c)`` -> ``(1, a, b, c)``
    * any other rank is returned unchanged.

    Args:
        vec: NumPy array.

    Returns:
        The reshaped view of ``vec`` (or ``vec`` itself for other ranks).
    """
    if vec.ndim == 2:
        return vec[np.newaxis, :, np.newaxis, :]
    if vec.ndim == 3:
        return vec[np.newaxis, :, :, :]
    return vec


# inference fn
def predict(mseeds=None, waveforms="", stations=""):
    """Run the restored PhaseNet model on each miniSEED file and collect picks.

    Args:
        mseeds: Iterable of uploaded files. Each item is either a filesystem
            path (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``. ``None`` is treated as "no files".
        waveforms: Optional JSON string; when non-empty it is decoded into a
            NumPy array and its shape is printed. It is not used for picking.
        stations: Optional JSON string; when non-empty it is decoded and
            printed. It is not used for picking.

    Returns:
        A 3-tuple ``(picks_df, "picks.csv", picks_json)`` where ``picks_df`` is
        a ``pandas.DataFrame`` of picks (``None`` when there are no picks),
        ``"picks.csv"`` is the path of the CSV written in the working
        directory, and ``picks_json`` is the JSON-encoded list of pick dicts.
    """
    # Truthiness (rather than len()) also tolerates None for these text inputs.
    if stations:
        stations = json.loads(stations)
        print(f"{len(stations)}: {stations = }")

    if waveforms:
        waveforms = json.loads(waveforms)
        waveforms = np.array(waveforms)
        print(f"{waveforms.shape = }")

    picks = []
    if mseeds is None:
        mseeds = []
    for mseed in mseeds:
        # Accept plain paths as well as upload objects that carry `.name`.
        file = mseed if isinstance(mseed, (str, os.PathLike)) else mseed.name
        stream = obspy.read(file)
        begin_time = min(tr.stats.starttime for tr in stream)
        end_time = max(tr.stats.endtime for tr in stream)
        # NOTE(review): trimming to the overall [min start, max end] window
        # without padding does not equalize trace lengths; stacking below
        # presumably relies on all traces having the same number of samples.
        stream = stream.trim(begin_time, end_time)
        vec = np.asarray([tr.data for tr in stream]).T
        vec = reshape_input(vec)  # (nb, nt, nsta, nch)
        vec = normalize(vec)

        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        tmp_picks = extract_picks(
            preds,
            begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")],
        )
        # Keep only the fields exposed by this app.
        picks.extend(
            {
                "phase_time": x["phase_time"],
                "phase_index": x["phase_index"],
                "phase_score": x["phase_score"],
                "phase_type": x["phase_type"],
            }
            for x in tmp_picks
        )

    if picks:
        picks_df = pd.DataFrame(picks)
        picks_df.to_csv("picks.csv", index=False)
    else:
        picks_df = None
        # Write an empty file instead of shelling out to `touch`: `touch` would
        # leave a previous request's picks in place and return them as if they
        # belonged to this request.
        Path("picks.csv").write_text("")

    return picks_df, "picks.csv", json.dumps(picks)


inputs = [
    gr.File(label="mseeds", file_count="multiple"),
    gr.Textbox(label="waveform", visible=False),
    gr.Textbox(label="stations", visible=False),
]
outputs = [
    gr.Dataframe(label="picks", headers=["phase_time", "phase_score", "phase_type"]),
    gr.File(label="csv"),
    gr.Textbox(label="json", visible=False),
]
gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
).queue().launch()

# if __name__ == "__main__":
#     picks = predict(["tests/test.mseed"])
#     print(picks)