# File size: 2,949 Bytes — commit 300fbbb (extraction metadata, kept as a comment)
import gradio as gr
import tensorflow as tf
import numpy as np
import os
import obspy
import pandas as pd
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks
# Run TF in graph (v1) mode: the PhaseNet UNet and checkpoint are TF1-style.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
## load model
# Start from a clean graph, build the network in prediction mode, and open a
# session that grows GPU memory on demand instead of grabbing it all upfront.
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
# Saver over all global variables so the checkpoint restore below can
# overwrite the freshly-initialized weights.
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# Restore the pretrained PhaseNet weights (190703-214543 release checkpoint).
latest_check_point = tf.train.latest_checkpoint("./model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
def normalize(vec):
    """Standardize *vec* to zero mean and unit std along axis 1 (time).

    Entries whose std is zero (e.g. flat/dead channels) are only
    mean-subtracted, which avoids a division by zero.
    """
    mean = np.mean(vec, axis=1, keepdims=True)
    scale = np.std(vec, axis=1, keepdims=True)
    scale[scale == 0] = 1.0
    return (vec - mean) / scale
def reshape_input(vec):
    """Expand *vec* to the 4-D layout (nb, nt, nsta, nch) the model expects.

    A 2-D (nt, nch) array gains batch and station axes; a 3-D
    (nt, nsta, nch) array gains a batch axis; anything else is passed
    through unchanged.
    """
    ndim = len(vec.shape)
    if ndim == 2:
        # (nt, nch) -> (1, nt, 1, nch)
        return vec[np.newaxis, :, np.newaxis, :]
    if ndim == 3:
        # (nt, nsta, nch) -> (1, nt, nsta, nch)
        return vec[np.newaxis, :, :, :]
    return vec
# inference fn
def predict(inputs):
    """Run PhaseNet picking on a batch of miniSEED files.

    Args:
        inputs: iterable of uploaded file objects (anything with a ``.name``
            path attribute, as gradio provides) or plain path strings.

    Returns:
        Tuple of (picks DataFrame, path of the CSV the picks were saved to).
    """
    picks = []
    # NOTE: renamed loop variable from `input` — it shadowed the builtin.
    for item in inputs:
        # Accept both gradio File wrappers and bare path strings.
        path = item if isinstance(item, str) else item.name
        stream = obspy.read(path)
        # Trim all traces to a common window so they stack into one array.
        begin_time = min(tr.stats.starttime for tr in stream)
        end_time = max(tr.stats.endtime for tr in stream)
        stream = stream.trim(begin_time, end_time)
        vec = np.asarray([tr.data for tr in stream]).T  # (nt, nch)
        vec = reshape_input(vec)  # (nb, nt, nsta, nch)
        vec = normalize(vec)
        # Inference pass: dropout off, batch-norm in eval mode.
        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        tmp_picks = extract_picks(
            preds, begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")]
        )  # , station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
        # Keep only the fields the UI displays / the CSV needs.
        picks.extend(
            {
                "phase_time": x["phase_time"],
                "phase_index": x["phase_index"],
                "phase_score": x["phase_score"],
                "phase_type": x["phase_type"],
            }
            for x in tmp_picks
        )
    picks = pd.DataFrame(picks)
    picks.to_csv("picks.csv", index=False)
    return picks, "picks.csv"
# gradio components
# Allow several miniSEED files per request; predict() iterates over them.
inputs = gr.File(file_count="multiple")
# outputs = gr.outputs.JSON()
# NOTE(review): predict() also emits a "phase_index" column that is not
# listed in these headers — confirm whether it should be shown or dropped.
outputs = [gr.Dataframe(headers=["phase_time", "phase_score", "phase_type"]), gr.File()]
# Wire the inference fn into a simple web UI and start the server.
gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
).launch()
# if __name__ == "__main__":
# picks = predict(["tests/test.mseed"])
# print(picks)
|