update
- pipeline.py +28 -8
- tests/test_api.ipynb +71 -4
pipeline.py
CHANGED
@@ -15,9 +15,6 @@ class PreTrainedPipeline():
         # Preload all the elements you are going to need at inference.
         # For instance your model, processors, tokenizer that might be needed.
         # This function is only called once, so do all the heavy processing I/O here"""
-        # raise NotImplementedError(
-        #     "Please implement PreTrainedPipeline __init__ function"
-        # )
 
         ## load model
         tf.compat.v1.reset_default_graph()
@@ -32,7 +29,6 @@ class PreTrainedPipeline():
         print(f"restoring model {latest_check_point}")
         saver.restore(sess, latest_check_point)
 
-        ##
         self.sess = sess
         self.model = model
 
@@ -51,21 +47,45 @@ class PreTrainedPipeline():
         #     "Please implement PreTrainedPipeline __call__ function"
         # )
 
-        vec = np.asarray(json.loads(inputs))
+        vec = np.asarray(json.loads(inputs))
+        vec = self.reshape_input(vec)  # (nb, nt, nsta, nch)
+        vec = self.normalize(vec)
 
         feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
         preds = self.sess.run(self.model.preds, feed_dict=feed)
 
         picks = extract_picks(preds)  # , station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
-
+        picks = [{'phase_index': x['phase_index'], 'phase_score': x['phase_score'], 'phase_type': x['phase_type']} for x in picks]
         # picks = [{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]} for pick in picks]
 
         # return picks
-        return [[picks, {"label": "debug", "score": 0.1}]]
+        # return [[picks, {"label": "debug", "score": 0.1}]]
+        return [picks]
+
+    def normalize(self, vec):
+        mu = np.mean(vec, axis=1, keepdims=True)
+        std = np.std(vec, axis=1, keepdims=True)
+        std[std == 0] = 1.0
+        vec = (vec - mu) / std
+        return vec
+
+    def reshape_input(self, vec):
+        if len(vec.shape) == 2:
+            vec = vec[np.newaxis, :, np.newaxis, :]
+        elif len(vec.shape) == 3:
+            vec = vec[np.newaxis, :, :, :]
+        else:
+            pass
+        return vec
 
 
 if __name__ == "__main__":
+    import obspy
+    waveform = obspy.read()
+    array = np.array([x.data for x in waveform]).T
+
     pipeline = PreTrainedPipeline()
-    inputs =
+    inputs = array.tolist()
    inputs = json.dumps(inputs)
     picks = pipeline(inputs)
+    print(picks)
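For context on the __call__ changes above: the JSON input is now promoted to the 4-D (nb, nt, nsta, nch) layout the model expects and standardized per trace before the session runs. A minimal standalone sketch of the same preprocessing, with the function bodies mirroring the diff (the dummy input shape is an assumption for illustration):

import numpy as np

def reshape_input(vec):
    # Promote a 2-D (nt, nch) or 3-D (nt, nsta, nch) array
    # to the 4-D (nb, nt, nsta, nch) layout the model expects.
    if vec.ndim == 2:
        vec = vec[np.newaxis, :, np.newaxis, :]
    elif vec.ndim == 3:
        vec = vec[np.newaxis, :, :, :]
    return vec

def normalize(vec):
    # Zero-mean, unit-std standardization along the time axis (axis=1).
    mu = np.mean(vec, axis=1, keepdims=True)
    std = np.std(vec, axis=1, keepdims=True)
    std[std == 0] = 1.0  # flat traces keep std=1 to avoid division by zero
    return (vec - mu) / std

# Dummy 3-component, 3000-sample record (shape chosen for illustration only).
vec = reshape_input(np.random.rand(3000, 3))
print(vec.shape)                    # (1, 3000, 1, 3)
print(normalize(vec).mean(axis=1))  # ~0 per channel

Forcing std to 1.0 on constant traces is what keeps the division from producing NaNs on dead channels.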
tests/test_api.ipynb
CHANGED
@@ -2,14 +2,60 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import obspy\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(3000, 3)\n"
+     ]
+    }
+   ],
+   "source": [
+    "waveform = obspy.read()\n",
+    "array = np.array([x.data for x in waveform]).T\n",
+    "print(array.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(3000, 3)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(array.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[[{'label': 'debug', 'score': 0.1}]]\n"
+      "[[[{'file_name': '0000', 'station_id': '0000', 'begin_time': '1970-01-01T00:00:00.000+00:00', 'phase_index': 573, 'phase_time': '1970-01-01T00:00:05.730+00:00', 'phase_score': 0.999, 'phase_type': 'S', 'dt': 0.01}], {'label': 'debug', 'score': 0.1}]]\n"
      ]
     }
    ],
@@ -27,8 +73,8 @@
     "    return response.json()\n",
     "    # return json.loads(response.content.decode(\"utf-8\"))\n",
     "\n",
-    "array = np.random.rand(10, 3).tolist()\n",
-    "inputs = json.dumps(array)\n",
+    "# array = np.random.rand(10, 3).tolist()\n",
+    "inputs = json.dumps(array.tolist())\n",
     "data = {\n",
     "\t# \"inputs\": \"I like you. I love you\",\n",
     "    \"inputs\": inputs,\n",
@@ -38,6 +84,27 @@
     "output = query(data)\n",
     "print(output)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
 "metadata": {