File size: 3,421 Bytes
34dbcc7
 
 
 
300fbbb
 
 
 
34dbcc7
 
300fbbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82d1974
34dbcc7
 
 
 
 
 
 
300fbbb
34dbcc7
 
300fbbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34dbcc7
 
 
 
300fbbb
 
34dbcc7
 
 
 
 
300fbbb
34dbcc7
300fbbb
 
 
 
 
 
 
 
34dbcc7
300fbbb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
import obspy
import pandas as pd
import tensorflow as tf

from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks

# The model code uses the TF1 graph/session API, so eager mode is disabled.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

## load model
# Build the U-Net in prediction mode and restore its trained weights once at
# import time; `predict` below reuses this single session for every request.
tf.compat.v1.reset_default_graph()
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True  # allocate GPU memory on demand
sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# Resolve the checkpoint directory next to this script rather than the current
# working directory, so the app starts correctly from any directory.
model_dir = Path(__file__).resolve().parent / "model" / "190703-214543"
latest_check_point = tf.train.latest_checkpoint(str(model_dir))
if latest_check_point is None:
    # latest_checkpoint returns None when no checkpoint exists; fail loudly
    # here instead of letting saver.restore raise an opaque error.
    raise FileNotFoundError(f"no TensorFlow checkpoint found in {model_dir}")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)


def normalize(vec):
    """Return ``vec`` standardized to zero mean and unit std along axis 1.

    Statistics are taken per slice along axis 1, which is the time axis of the
    (nb, nt, nsta, nch) arrays this module feeds to the model. Slices whose
    standard deviation is zero are only mean-centered, so there is no 0/0.
    """
    spread = np.std(vec, axis=1, keepdims=True)
    spread[spread == 0] = 1.0  # flat slice: divide by 1 instead of 0
    center = np.mean(vec, axis=1, keepdims=True)
    return (vec - center) / spread


def reshape_input(vec):
    """Lift ``vec`` to the 4-D layout (nb, nt, nsta, nch) used by the model.

    A 2-D array gains a leading batch axis and a singleton station axis at
    position 2. A 3-D array gains only the leading batch axis. Arrays of any
    other rank are returned unchanged.
    """
    rank = len(vec.shape)
    if rank == 2:
        return vec[np.newaxis, :, np.newaxis, :]
    if rank == 3:
        return vec[np.newaxis, ...]
    return vec


# Columns of the picks table returned by `predict` (also the CSV header).
_PICK_COLUMNS = ["phase_time", "phase_index", "phase_score", "phase_type"]


def _load_mseed(mseed):
    """Read one uploaded miniSEED file and return ``(vec, begin_time)``.

    ``vec`` is the normalized model input, shaped (nb, nt, nsta, nch) by
    ``reshape_input``. ``begin_time`` is the earliest trace start time (an
    obspy ``UTCDateTime``).
    """
    # Gradio may pass a plain path (e.g. for examples) or a temp-file object
    # whose path is in ``.name``; accept both.
    file = mseed if isinstance(mseed, (str, os.PathLike)) else mseed.name
    stream = obspy.read(file)
    begin_time = min(tr.stats.starttime for tr in stream)
    end_time = max(tr.stats.endtime for tr in stream)
    # pad=True extends every trace to the common window. Without it, trimming
    # to the union of all trace windows changes nothing, and traces of unequal
    # length cannot be stacked into one 2-D array below.
    stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
    vec = np.asarray([tr.data for tr in stream]).T  # (nt, n_traces)
    vec = reshape_input(vec)  # (nb, nt, nsta, nch)
    return normalize(vec), begin_time


def _pick_phases(vec, begin_time):
    """Run the restored model on ``vec`` and return its picks as plain dicts."""
    feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)
    raw_picks = extract_picks(
        preds, begin_times=[begin_time.datetime.isoformat(timespec="milliseconds")]
    )
    # Keep only the fields shown in the UI. Cast numbers to builtin types
    # because extract_picks may yield NumPy scalars, and json.dumps cannot
    # serialize every NumPy scalar type.
    return [
        {
            "phase_time": x["phase_time"],
            "phase_index": int(x["phase_index"]),
            "phase_score": float(x["phase_score"]),
            "phase_type": x["phase_type"],
        }
        for x in raw_picks
    ]


# inference fn
def predict(mseeds=None, waveforms="", stations=""):
    """Pick seismic phases in the uploaded miniSEED files.

    Args:
        mseeds: Items from the ``gr.File`` input. Each is either a path or an
            object exposing the path as ``.name``. ``None`` or an empty list
            produces no picks.
        waveforms: Optional JSON-encoded array. It is parsed and its shape is
            printed, but it is not used for picking.
        stations: Optional JSON-encoded station list. It is parsed and printed,
            but it is not used for picking.

    Returns:
        A tuple ``(picks_df, csv_path, picks_json)``:
        - the picks as a DataFrame with columns ``_PICK_COLUMNS`` (present
          even when there are no picks);
        - the path ``"picks.csv"``, written to the working directory;
        - the same picks serialized as a JSON string.
    """
    if stations:
        stations = json.loads(stations)
        print(f"{len(stations)}: {stations = }")
    if waveforms:
        waveforms = np.array(json.loads(waveforms))
        print(f"{waveforms.shape = }")

    picks = []
    for mseed in mseeds or []:
        vec, begin_time = _load_mseed(mseed)
        picks.extend(_pick_phases(vec, begin_time))

    # Explicit columns keep the header row even when no picks were found.
    picks_df = pd.DataFrame(picks, columns=_PICK_COLUMNS)
    picks_df.to_csv("picks.csv", index=False)

    return picks_df, "picks.csv", json.dumps(picks)


# Gradio inputs: the uploaded miniSEED files plus two hidden JSON text fields.
# predict() only parses and prints the hidden fields.
inputs = [
    gr.File(label="mseed file", file_count="multiple"),
    gr.Textbox(label="waveform", visible=False),
    gr.Textbox(label="stations", visible=False),
]

# Outputs match predict()'s return tuple: the picks table, the CSV file, and a
# hidden JSON string. The headers list every column predict() returns (in the
# same order), so the empty table matches a populated one.
outputs = [
    gr.Dataframe(headers=["phase_time", "phase_index", "phase_score", "phase_type"]),
    gr.File(),
    gr.Textbox(visible=False),
]

demo = gr.Interface(
    predict,
    inputs=inputs,
    outputs=outputs,
    title="PhaseNet",
    description="PhaseNet",
    # The bundled example is resolved relative to this script, not the cwd.
    examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
    allow_flagging="never",
)
# Enable Gradio's request queue and start the server at import time.
demo.queue().launch()

# if __name__ == "__main__":
#    picks = predict(["tests/test.mseed"])
#    print(picks)