# PhaseNet / app.py
# zhuwq0's picture
# add default values
# cc4f687
import json
import os
from pathlib import Path
import gradio as gr
import numpy as np
import obspy
import pandas as pd
import tensorflow as tf
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks
# Use TensorFlow through its v1 compatibility API: eager execution is switched
# off and TF logging is restricted to the ERROR level.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
## load model
# Everything below runs once at import time. `model` and `sess` are kept as
# module-level globals and are read by predict() on every request.
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
# NOTE(review): allow_growth presumably makes TF claim GPU memory on demand
# instead of reserving it all up front -- confirm against the ConfigProto docs.
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
# The saver covers every global variable that exists at this point.
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
# Run the initializer first; the restore() below is then applied on top of it.
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# NOTE(review): the checkpoint directory is a hard-coded path relative to the
# working directory. If nothing is found there, latest_check_point is
# presumably None and restore() will fail -- confirm the failure mode.
latest_check_point = tf.train.latest_checkpoint("./model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
def normalize(vec):
    """Standardize *vec* to zero mean and unit variance along axis 1.

    The statistics are computed per slice along axis 1 (the ``nt`` axis of
    the ``(nb, nt, nsta, nch)`` layout that ``predict`` feeds in) and are
    broadcast back over that axis. Slices with zero standard deviation are
    divided by 1 instead, so constant input yields zeros rather than NaN/inf.

    The input array is not modified; a new array is returned.
    """
    center = np.mean(vec, axis=1, keepdims=True)
    scale = np.std(vec, axis=1, keepdims=True)
    # Replace zero spreads in place so the division below is always defined.
    np.putmask(scale, scale == 0, 1.0)
    return (vec - center) / scale
def reshape_input(vec):
    """Expand *vec* to four dimensions by inserting singleton axes.

    * 2-D input ``(a, b)`` becomes ``(1, a, 1, b)``.
    * 3-D input ``(a, b, c)`` becomes ``(1, a, b, c)``.
    * Input of any other rank is returned unchanged.

    ``predict`` annotates the resulting layout as ``(nb, nt, nsta, nch)``.
    """
    rank = len(vec.shape)
    if rank == 2:
        # Add a leading axis and a singleton axis between the two existing ones.
        return vec[np.newaxis, :, np.newaxis, :]
    if rank == 3:
        # Only the leading axis is missing.
        return vec[np.newaxis, ...]
    return vec
# inference fn
def predict(mseeds=None, waveforms="", stations=""):
    """Pick seismic phases on the uploaded miniSEED files.

    Args:
        mseeds: Iterable of uploaded files. Each item is either a filesystem
            path (``str`` / ``os.PathLike``) or an object that exposes the
            path as ``.name``. ``None`` is treated as "no files".
        waveforms: Optional JSON-encoded array. It is decoded and its shape
            is logged; it does not take part in the picking.
        stations: Optional JSON-encoded value. It is decoded and logged; it
            does not take part in the picking.

    Returns:
        A 3-tuple ``(picks_df, csv_path, picks_json)``:

        * ``picks_df`` -- ``pandas.DataFrame`` with the columns
          ``phase_time``, ``phase_index``, ``phase_score``, ``phase_type``,
          or ``None`` when nothing was picked.
        * ``csv_path`` -- always ``"picks.csv"``; the file holds the same
          table, or is empty when nothing was picked.
        * ``picks_json`` -- the picks as a JSON-encoded list.

    Raises:
        json.JSONDecodeError: if ``waveforms`` or ``stations`` is a non-empty
            string that is not valid JSON.
    """
    if stations:
        stations = json.loads(stations)
        print(f"{len(stations)}: {stations = }")

    if waveforms:
        waveforms = np.array(json.loads(waveforms))
        print(f"{waveforms.shape = }")

    if mseeds is None:
        mseeds = []

    pick_fields = ("phase_time", "phase_index", "phase_score", "phase_type")
    picks = []
    for mseed in mseeds:
        # Accept plain paths as well as upload wrappers carrying `.name`.
        path = mseed if isinstance(mseed, (str, os.PathLike)) else mseed.name
        stream = obspy.read(path)

        begin_time = min(tr.stats.starttime for tr in stream)
        end_time = max(tr.stats.endtime for tr in stream)
        # NOTE(review): the bounds are the overall extremes of the stream and
        # no padding is requested, so this trim is not expected to change the
        # traces. Traces of unequal length would make the stack below ragged;
        # confirm whether padding was intended.
        stream = stream.trim(begin_time, end_time)

        vec = np.asarray([tr.data for tr in stream]).T
        vec = normalize(reshape_input(vec))  # (nb, nt, nsta, nch)

        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        raw_picks = extract_picks(
            preds,
            begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")],
        )
        # Keep only the fields that are reported to the caller.
        picks.extend({key: pick[key] for key in pick_fields} for pick in raw_picks)

    # NOTE(review): every call writes to the same relative path, so concurrent
    # requests can overwrite each other's file.
    csv_path = "picks.csv"
    if picks:
        picks_df = pd.DataFrame(picks)
        picks_df.to_csv(csv_path, index=False)
    else:
        picks_df = None
        # Create or truncate the file so an earlier request's picks are never
        # returned, without spawning a shell.
        Path(csv_path).write_text("")

    return picks_df, csv_path, json.dumps(picks)
# Gradio UI wiring. The input components correspond, in order, to predict()'s
# parameters (mseeds, waveforms, stations); the output components correspond,
# in order, to the three values predict() returns.
inputs = [
    gr.File(label="mseeds", file_count="multiple"),
    # Hidden text fields that carry JSON payloads.
    gr.Textbox(label="waveform", visible=False),
    gr.Textbox(label="stations", visible=False),
]
outputs = [
    # The headers list every column predict() puts into the picks table, in
    # the same order, so the empty table matches a populated one.
    gr.Dataframe(
        label="picks",
        headers=["phase_time", "phase_index", "phase_score", "phase_type"],
    ),
    gr.File(label="csv"),
    gr.Textbox(label="json", visible=False),
]
# Build the interface and start serving as soon as this module is executed.
gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    # One example row: a single-element file list for the "mseeds" input,
    # resolved relative to this file's directory.
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
).queue().launch()
# if __name__ == "__main__":
# picks = predict(["tests/test.mseed"])
# print(picks)