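# Gradio app for PhaseNet seismic phase picking: upload miniSEED waveform files,
# run the pretrained U-Net picker, and get phase picks back as a table, a CSV
# download, and a JSON string.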
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
import obspy
import pandas as pd
import tensorflow as tf

from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Load the pretrained PhaseNet model: build the U-Net in prediction mode with the
# TF1 compatibility API and restore the latest checkpoint into a persistent session.
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint("./model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)


def normalize(vec):
    """Standardize each channel to zero mean and unit variance along the time axis."""
    mu = np.mean(vec, axis=1, keepdims=True)
    std = np.std(vec, axis=1, keepdims=True)
    std[std == 0] = 1.0  # guard against division by zero on flat/dead channels
    vec = (vec - mu) / std
    return vec


def reshape_input(vec):
    """Expand a waveform array to the (batch, time, station, channel) shape the model expects."""
    if len(vec.shape) == 2:
        vec = vec[np.newaxis, :, np.newaxis, :]  # (nt, nch) -> (1, nt, 1, nch)
    elif len(vec.shape) == 3:
        vec = vec[np.newaxis, :, :, :]  # (nt, nsta, nch) -> (1, nt, nsta, nch)
    return vec


# Inference function wired to the Gradio interface.
def predict(mseeds=None, waveforms="", stations=""):
    # Optional JSON inputs from the hidden textboxes; currently parsed only for logging.
    if len(stations) > 0:
        stations = json.loads(stations)
        print(f"{len(stations)}: {stations = }")
    if len(waveforms) > 0:
        waveforms = json.loads(waveforms)
        waveforms = np.array(waveforms)
        print(f"{waveforms.shape = }")
    picks = []
    if mseeds is None:
        mseeds = []
    for mseed in mseeds:
        file = mseed.name  # gr.File uploads expose the temporary file path via .name
        stream = obspy.read(file)
        begin_time = min([tr.stats.starttime for tr in stream])
        end_time = max([tr.stats.endtime for tr in stream])
        stream = stream.trim(begin_time, end_time)
        vec = np.asarray([tr.data for tr in stream]).T  # (nt, nch)
        vec = reshape_input(vec)  # (nb, nt, nsta, nch)
        vec = normalize(vec)

        # Run the network and convert the predicted probability traces into phase picks.
        feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
        preds = sess.run(model.preds, feed_dict=feed)
        tmp_picks = extract_picks(
            preds, begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")]
        )
        # Keep only the fields used by the outputs below.
        tmp_picks = [
            {
                "phase_time": x["phase_time"],
                "phase_index": x["phase_index"],
                "phase_score": x["phase_score"],
                "phase_type": x["phase_type"],
            }
            for x in tmp_picks
        ]

        picks.extend(tmp_picks)

    if len(picks) > 0:
        picks_df = pd.DataFrame(picks)
        picks_df.to_csv("picks.csv", index=False)
    else:
        picks_df = None
        Path("picks.csv").touch()  # create an empty file so the CSV output stays valid

    return picks_df, "picks.csv", json.dumps(picks)


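# Gradio I/O: a multi-file upload for miniSEED data plus hidden textboxes that mirror
# `predict`'s optional JSON arguments (hidden from the UI but still passed as inputs).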
inputs = [
    gr.File(label="mseeds", file_count="multiple"),
    gr.Textbox(label="waveform", visible=False),
    gr.Textbox(label="stations", visible=False),
]

outputs = [
    gr.Dataframe(label="picks", headers=["phase_time", "phase_score", "phase_type"]),
    gr.File(label="csv"),
    gr.Textbox(label="json", visible=False),
]
gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
).queue().launch()

# Local smoke test; `predict` expects uploads with a `.name` attribute, so wrap plain paths:
# if __name__ == "__main__":
#     from types import SimpleNamespace
#     picks_df, csv_path, picks_json = predict([SimpleNamespace(name="tests/test.mseed")])
#     print(picks_df)