# PhaseNet / app.py
# Page header captured from the hosting site (not part of the program):
#   zhuwq0's picture | spaces | 300fbbb | raw | history blame | 2.95 kB
import gradio as gr
import tensorflow as tf
import numpy as np
import os
import obspy
import pandas as pd
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks
# Use the TensorFlow v1 graph/session API and only log TF errors.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Load the model once at import time. `model` and `sess` are module-level
# globals that predict() reuses on every call.
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
# Let TensorFlow grow its GPU memory allocation on demand.
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
# The saver covers every global variable that exists in the graph at this point.
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# Restore the newest checkpoint found in the bundled model directory.
# NOTE(review): tf.train.latest_checkpoint returns None when the directory
# holds no checkpoint, in which case the restore below fails.
latest_check_point = tf.train.latest_checkpoint("./model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
def normalize(vec):
    """Standardize `vec` to zero mean and unit standard deviation along axis 1.

    predict() calls this on the output of reshape_input(), whose layout is
    annotated there as (nb, nt, nsta, nch), so axis 1 is the time axis in
    that usage.

    Args:
        vec: Array with at least two dimensions.

    Returns:
        A new array of the same shape as `vec`; the input is not modified.
        Slices whose standard deviation is exactly zero are divided by 1
        instead, so they come out as all zeros rather than NaN/inf.
    """
    mu = np.mean(vec, axis=1, keepdims=True)
    std = np.std(vec, axis=1, keepdims=True)
    # Guard against division by zero for constant (e.g. dead or padded) slices.
    std[std == 0] = 1.0
    return (vec - mu) / std
def reshape_input(vec):
    """Add the singleton axes needed to make `vec` four-dimensional.

    The call site in predict() annotates the result as (nb, nt, nsta, nch).

    Args:
        vec: NumPy array.
            * 2-D input gets a new leading axis and a new axis at
              position 2: (a, b) -> (1, a, 1, b).
            * 3-D input gets a new leading axis: (a, b, c) -> (1, a, b, c).
            * Any other rank is returned unchanged.

    Returns:
        The reshaped array (a view produced by indexing with np.newaxis),
        or `vec` itself when its rank is neither 2 nor 3.
    """
    if vec.ndim == 2:
        return vec[np.newaxis, :, np.newaxis, :]
    if vec.ndim == 3:
        return vec[np.newaxis, :, :, :]
    return vec
# inference fn
def predict(inputs):
    """Run the loaded model on each uploaded waveform file and collect picks.

    For every file: read it with ObsPy, pad all traces to the common
    [earliest start, latest end] window, stack them into an array, reshape
    and normalize it, run the module-level session, and convert the
    predictions to picks with extract_picks().

    Args:
        inputs: Iterable of uploaded files. Each item is either a path
            string or an object exposing the path as ``.name``. ``None``
            or an empty list yields an empty result.

    Returns:
        A ``(picks, csv_path)`` tuple: a DataFrame with columns
        ``phase_time``, ``phase_index``, ``phase_score`` and ``phase_type``,
        and the path ``"picks.csv"`` to which the same table was written.
    """
    columns = ["phase_time", "phase_index", "phase_score", "phase_type"]
    picks = []
    for item in inputs or []:
        # Accept plain path strings as well as file objects carrying `.name`.
        path = item if isinstance(item, str) else item.name
        mseed = obspy.read(path)
        begin_time = min(tr.stats.starttime for tr in mseed)
        end_time = max(tr.stats.endtime for tr in mseed)
        # Without pad=True, trimming to the widest window changes nothing and
        # traces of unequal extent cannot be stacked into one array.
        # Zero-fill so the padded samples are plain numbers, not masked values.
        mseed = mseed.trim(begin_time, end_time, pad=True, fill_value=0)
        vec = np.asarray([tr.data for tr in mseed]).T
        vec = reshape_input(vec)  # (nb, nt, nsta, nch)
        vec = normalize(vec)
        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        file_picks = extract_picks(
            preds,
            begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")],
        )  # , station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
        # Keep only the fields that are reported to the user.
        picks.extend({key: pick[key] for key in columns} for pick in file_picks)
    # Explicit columns keep the table and CSV header stable even with no picks.
    picks = pd.DataFrame(picks, columns=columns)
    picks.to_csv("picks.csv", index=False)
    return picks, "picks.csv"
# gradio components
# One request may carry several waveform files.
inputs = gr.File(file_count="multiple")
# outputs = gr.outputs.JSON()
# The table headers list every column of the DataFrame returned by predict();
# the second output is the CSV file offered for download.
outputs = [
    gr.Dataframe(headers=["phase_time", "phase_index", "phase_score", "phase_type"]),
    gr.File(),
]
gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    # A single example consisting of a one-file upload, resolved relative to this script.
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
).launch()
# Manual smoke test (disabled):
# if __name__ == "__main__":
#     picks = predict(["tests/test.mseed"])
#     print(picks)