zhuwq0 commited on
Commit
34dbcc7
·
1 Parent(s): 39cf57a

update to support gradio_client

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -1,9 +1,13 @@
 
 
 
 
1
  import gradio as gr
2
- import tensorflow as tf
3
  import numpy as np
4
- import os
5
  import obspy
6
  import pandas as pd
 
 
7
  from phasenet.model import ModelConfig, UNet
8
  from phasenet.postprocess import extract_picks
9
 
@@ -43,11 +47,17 @@ def reshape_input(vec):
43
 
44
 
45
  # inference fn
46
- def predict(inputs):
47
-
 
 
 
 
 
 
48
  picks = []
49
- for input in inputs:
50
- file = input.name # if not isinstance(input, str) else input
51
  mseed = obspy.read(file)
52
  begin_time = min([tr.stats.starttime for tr in mseed])
53
  end_time = max([tr.stats.endtime for tr in mseed])
@@ -73,16 +83,19 @@ def predict(inputs):
73
 
74
  picks.extend(tmp_picks)
75
 
76
- picks = pd.DataFrame(picks)
77
- picks.to_csv("picks.csv", index=False)
 
 
78
 
79
- return picks, "picks.csv"
80
 
 
 
 
 
 
81
 
82
- # gradio components
83
- inputs = gr.File(file_count="multiple")
84
- # outputs = gr.outputs.JSON()
85
- outputs = [gr.Dataframe(headers=["phase_time", "phase_score", "phase_type"]), gr.File()]
86
  gr.Interface(
87
  predict,
88
  inputs=inputs,
@@ -91,7 +104,7 @@ gr.Interface(
91
  description="PhaseNet",
92
  examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
93
  allow_flagging="never",
94
- ).launch()
95
 
96
  # if __name__ == "__main__":
97
  # picks = predict(["tests/test.mseed"])
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+
5
  import gradio as gr
 
6
  import numpy as np
 
7
  import obspy
8
  import pandas as pd
9
+ import tensorflow as tf
10
+
11
  from phasenet.model import ModelConfig, UNet
12
  from phasenet.postprocess import extract_picks
13
 
 
47
 
48
 
49
  # inference fn
50
+ def predict(mseeds, waveforms, stations):
51
+ if len(stations) > 0:
52
+ stations = json.loads(stations)
53
+ print(f"{len(stations)}: {stations = }")
54
+ if len(waveforms) > 0:
55
+ waveforms = json.loads(waveforms)
56
+ waveforms = np.array(waveforms)
57
+ print(f"{waveforms.shape = }")
58
  picks = []
59
+ for mseed in mseeds:
60
+ file = mseed.name # if not isinstance(input, str) else input
61
  mseed = obspy.read(file)
62
  begin_time = min([tr.stats.starttime for tr in mseed])
63
  end_time = max([tr.stats.endtime for tr in mseed])
 
83
 
84
  picks.extend(tmp_picks)
85
 
86
+ picks_df = pd.DataFrame(picks)
87
+ picks_df.to_csv("picks.csv", index=False)
88
+
89
+ return picks_df, "picks.csv", json.dumps(picks)
90
 
 
91
 
92
+ inputs = [
93
+ gr.File(label="mseed file", file_count="multiple"),
94
+ gr.Textbox(label="waveform", visible=False),
95
+ gr.Textbox(label="stations", visible=False),
96
+ ]
97
 
98
+ outputs = [gr.Dataframe(headers=["phase_time", "phase_score", "phase_type"]), gr.File(), gr.Textbox(visible=False)]
 
 
 
99
  gr.Interface(
100
  predict,
101
  inputs=inputs,
 
104
  description="PhaseNet",
105
  examples=[[[os.path.join(os.path.dirname(__file__), "tests/test.mseed")]]],
106
  allow_flagging="never",
107
+ ).queue().launch()
108
 
109
  # if __name__ == "__main__":
110
  # picks = predict(["tests/test.mseed"])