File size: 3,617 Bytes
34dbcc7 300fbbb 34dbcc7 300fbbb 82d1974 34dbcc7 300fbbb 4da07fe 34dbcc7 300fbbb 4da07fe c8f6682 34dbcc7 300fbbb 34dbcc7 cc4f687 34dbcc7 300fbbb cc4f687 300fbbb 34dbcc7 300fbbb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import json
import os
from pathlib import Path
import gradio as gr
import numpy as np
import obspy
import pandas as pd
import tensorflow as tf
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks
# Everything below uses the TF1-style graph/session API (tf.compat.v1), so
# eager execution is switched off before any graph is built.
tf.compat.v1.disable_eager_execution()
# Only emit TensorFlow log messages at ERROR severity.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
## load model
# This runs once at import time. `model` and `sess` are module-level globals
# that predict() reads when it runs inference.
tf.compat.v1.reset_default_graph()
# Presumably builds the network in prediction mode -- see phasenet.model.UNet.
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
# Ask TensorFlow to grow GPU memory usage as needed instead of reserving it
# all up front.
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
# Initialise every variable first; the restore below then overwrites the
# variables covered by the saver with the checkpoint values.
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# NOTE(review): tf.train.latest_checkpoint returns None when the directory
# holds no checkpoint, in which case saver.restore() below will fail.
latest_check_point = tf.train.latest_checkpoint("./model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
def normalize(vec):
    """Standardise ``vec`` along axis 1 (zero mean, unit standard deviation).

    The mean and standard deviation are computed over axis 1 with
    ``keepdims=True`` so they broadcast against ``vec``. Wherever the standard
    deviation is exactly zero it is replaced by 1, so a constant series maps
    to zeros instead of producing a division by zero.

    Args:
        vec: Array with at least two dimensions.

    Returns:
        A new array of the same shape as ``vec``.
    """
    center = np.mean(vec, axis=1, keepdims=True)
    spread = np.std(vec, axis=1, keepdims=True)

    # Constant series: divide by 1 rather than by 0.
    flat = spread == 0
    spread[flat] = 1.0

    return (vec - center) / spread
def reshape_input(vec):
    """Insert singleton axes so 2-D and 3-D arrays become 4-D.

    * 2-D input of shape ``(a, b)`` becomes ``(1, a, 1, b)``.
    * 3-D input of shape ``(a, b, c)`` becomes ``(1, a, b, c)``.
    * Input of any other rank is returned unchanged.

    Args:
        vec: Array-like object exposing ``.shape`` and NumPy-style indexing.

    Returns:
        A view of ``vec`` with the extra axes, or ``vec`` itself.
    """
    rank = len(vec.shape)
    if rank == 2:
        # New leading axis plus a singleton axis between the two existing ones.
        return vec[None, :, None, :]
    if rank == 3:
        # Only the leading axis is missing.
        return vec[None, :, :, :]
    return vec
# inference fn
def predict(mseeds=None, waveforms="", stations=""):
    """Pick seismic phases on uploaded miniSEED files with the loaded model.

    Each file is read with ObsPy, its traces are stacked into one array,
    reshaped and normalised, run through the module-level ``model`` / ``sess``,
    and the predictions are turned into picks by ``extract_picks``.

    Args:
        mseeds: Iterable of uploaded files. Each item is either a filesystem
            path (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``. ``None`` or an empty iterable means "no files".
        waveforms: Optional JSON-encoded array. It is decoded and its shape is
            printed; it is not used for picking in this function.
        stations: Optional JSON-encoded station description. It is decoded and
            printed; it is not used for picking in this function.

    Returns:
        A 3-tuple ``(picks_df, csv_path, picks_json)``:

        * ``picks_df`` -- ``pandas.DataFrame`` with one row per pick, or
          ``None`` when there are no picks.
        * ``csv_path`` -- always ``"picks.csv"``; the file holds the picks as
          CSV, or is empty when there are no picks.
        * ``picks_json`` -- the picks serialised with ``json.dumps``.

    Raises:
        json.JSONDecodeError: If ``waveforms`` or ``stations`` is a non-empty
            string that is not valid JSON.
    """
    # Truthiness (rather than len()) so a missing value passed as None is
    # treated like an empty string instead of raising TypeError.
    if stations:
        stations = json.loads(stations)
        print(f"{len(stations)}: {stations = }")
    if waveforms:
        waveforms = json.loads(waveforms)
        waveforms = np.array(waveforms)
        print(f"{waveforms.shape = }")

    pick_fields = ("phase_time", "phase_index", "phase_score", "phase_type")
    picks = []
    for mseed in mseeds or []:
        # Uploads may arrive as plain paths or as file wrappers with `.name`.
        file = mseed if isinstance(mseed, (str, os.PathLike)) else mseed.name
        stream = obspy.read(file)
        begin_time = min(tr.stats.starttime for tr in stream)
        end_time = max(tr.stats.endtime for tr in stream)
        # pad=True zero-fills traces that start later / end earlier than the
        # common window, so all traces cover the same span before stacking.
        # Without padding, trimming to the outermost times changes nothing.
        stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
        vec = np.asarray([tr.data for tr in stream]).T
        vec = reshape_input(vec)  # (nb, nt, nsta, nch)
        vec = normalize(vec)
        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        tmp_picks = extract_picks(
            preds,
            begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")],
        )
        # Keep only the fields exposed to the UI / JSON output.
        picks.extend({key: x[key] for key in pick_fields} for x in tmp_picks)

    csv_path = "picks.csv"
    if picks:
        picks_df = pd.DataFrame(picks)
        picks_df.to_csv(csv_path, index=False)
    else:
        picks_df = None
        # Truncate the file instead of shelling out to `touch`: `touch` keeps
        # the existing contents, which would hand back the picks of an earlier
        # request, and it is not available on every platform.
        Path(csv_path).write_text("")
    return picks_df, csv_path, json.dumps(picks)
# Input widgets, in the order predict() receives its arguments. The two
# textboxes are hidden (visible=False), so only the file upload is shown.
mseed_upload = gr.File(label="mseeds", file_count="multiple")
waveform_box = gr.Textbox(label="waveform", visible=False)
stations_box = gr.Textbox(label="stations", visible=False)
inputs = [mseed_upload, waveform_box, stations_box]

# Output widgets, in the order of predict()'s returned tuple: the picks table,
# the CSV file, and a hidden textbox carrying the JSON string.
picks_table = gr.Dataframe(
    label="picks",
    headers=["phase_time", "phase_score", "phase_type"],
)
csv_download = gr.File(label="csv")
json_box = gr.Textbox(label="json", visible=False)
outputs = [picks_table, csv_download, json_box]

# Example file resolved relative to this script.
example_mseed = os.path.join(os.path.dirname(__file__), "tests/test.mseed")

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    examples=[[[example_mseed]]],
    allow_flagging="never",
)
# Enable request queuing, then start the server.
demo.queue().launch()
# if __name__ == "__main__":
# picks = predict(["tests/test.mseed"])
# print(picks)
|