File size: 2,949 Bytes
300fbbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import tensorflow as tf
import numpy as np
import os
import obspy
import pandas as pd
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks

# The model is driven TF1-style: placeholders (model.X, model.drop_rate,
# model.is_training) are fed through a v1 Session in predict(), which needs
# graph mode rather than eager execution.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

## load model
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
# Allocate GPU memory on demand instead of reserving it all up front.
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
# Initialize every variable first; restore() below overwrites the ones stored
# in the checkpoint.
sess.run(init)
model_dir = "./model/190703-214543"  # resolved relative to the current working directory
latest_check_point = tf.train.latest_checkpoint(model_dir)
if latest_check_point is None:
    # latest_checkpoint returns None when no checkpoint is found; fail here with
    # a clear message instead of the opaque error saver.restore gives for None.
    raise FileNotFoundError(f"no TensorFlow checkpoint found in {model_dir!r}")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)


def normalize(vec):
    """Standardize ``vec`` to zero mean and unit standard deviation along axis 1.

    Mean and standard deviation are taken over axis 1 with ``keepdims`` so they
    broadcast back against ``vec``. Wherever the standard deviation is zero the
    divisor is replaced by 1, so constant slices come out as zeros instead of NaN.

    Args:
        vec: array-like with at least two dimensions. In this file it is the
            ``(nb, nt, nsta, nch)`` array produced by ``reshape_input``, so
            axis 1 is the time axis.

    Returns:
        A new array of the same shape holding the standardized values.
    """
    stats = dict(axis=1, keepdims=True)
    centre = np.mean(vec, **stats)
    spread = np.std(vec, **stats)
    # A flat slice would give 0 / 0 = NaN; dividing by 1 keeps it at zero.
    spread[spread == 0] = 1.0
    centred = vec - centre
    return centred / spread


def reshape_input(vec):
    """Insert singleton axes so a waveform array has the model's 4-D layout.

    * 2-D ``(nt, nch)`` input becomes ``(1, nt, 1, nch)`` (batch and station
      axes added).
    * 3-D input gets a single leading batch axis prepended (presumably
      ``(nt, nsta, nch)`` -> ``(1, nt, nsta, nch)``; confirm with callers).
    * Input of any other rank is returned unchanged (the same object).

    For NumPy arrays the reshaped result is a view; no data is copied.
    """
    rank = len(vec.shape)
    if rank == 2:
        return vec[None, :, None, :]
    if rank == 3:
        return vec[None, ...]
    return vec


def _load_waveform(path):
    """Read one waveform file and turn it into a normalized model input.

    Args:
        path: filesystem path of a file readable by ``obspy.read``.

    Returns:
        ``(vec, begin_time)``: the normalized array laid out by
        ``reshape_input`` as ``(nb, nt, nsta, nch)``, and the earliest trace
        start time (passed to ``extract_picks`` as the window's begin time).
    """
    stream = obspy.read(path)
    begin_time = min(tr.stats.starttime for tr in stream)
    end_time = max(tr.stats.endtime for tr in stream)
    # pad=True extends traces that start late or end early to the common window,
    # filling with zeros. Without it, trimming to the union of all trace windows
    # changes nothing, and traces of unequal length make np.asarray below fail
    # (or build an object array on older NumPy).
    stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
    # One row per trace, transposed so samples run along axis 0.
    vec = np.asarray([tr.data for tr in stream]).T
    vec = reshape_input(vec)  # (nb, nt, nsta, nch)
    return normalize(vec), begin_time


# inference fn
def predict(inputs):
    """Run the loaded PhaseNet model on each input file and collect phase picks.

    Args:
        inputs: iterable of files to process. Each item is either a path
            (``str`` / ``os.PathLike``) or an object exposing its path as
            ``.name`` (e.g. an uploaded file). ``None`` means no files.

    Returns:
        ``(picks, "picks.csv")``: a DataFrame with columns phase_time,
        phase_index, phase_score and phase_type (one row per pick, files in
        input order), and the path of the CSV file the table was written to.
    """
    columns = ["phase_time", "phase_index", "phase_score", "phase_type"]
    picks = []
    for item in inputs or []:
        path = item if isinstance(item, (str, os.PathLike)) else item.name
        vec, begin_time = _load_waveform(path)

        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        raw_picks = extract_picks(
            preds, begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")]
        )
        # Keep only the fields shown in the table / CSV.
        picks.extend({col: x[col] for col in columns} for x in raw_picks)

    # Explicit columns keep the header stable even when nothing was picked.
    picks = pd.DataFrame(picks, columns=columns)
    picks.to_csv("picks.csv", index=False)

    return picks, "picks.csv"


# gradio components
# One or more uploaded files; predict() reads each of them with obspy.
inputs = gr.File(file_count="multiple")
outputs = [
    # Headers mirror the columns (and their order) of the DataFrame predict() returns.
    gr.Dataframe(headers=["phase_time", "phase_index", "phase_score", "phase_type"]),
    gr.File(),  # the CSV file path returned by predict()
]
gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    # Sample waveform located relative to this script's directory, not the CWD.
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
).launch()

# if __name__ == "__main__":
#    picks = predict(["tests/test.mseed"])
#    print(picks)