zhuwq0 committed on
Commit 1b6e4e8 · 1 parent: 7bb513d
Files changed (2):
  1. pipeline.py +28 -8
  2. tests/test_api.ipynb +71 -4
pipeline.py CHANGED
@@ -15,9 +15,6 @@ class PreTrainedPipeline():
         # Preload all the elements you are going to need at inference.
         # For instance your model, processors, tokenizer that might be needed.
         # This function is only called once, so do all the heavy processing I/O here"""
-        # raise NotImplementedError(
-        #     "Please implement PreTrainedPipeline __init__ function"
-        # )
 
         ## load model
         tf.compat.v1.reset_default_graph()
@@ -32,7 +29,6 @@ class PreTrainedPipeline():
         print(f"restoring model {latest_check_point}")
         saver.restore(sess, latest_check_point)
 
-        ##
         self.sess = sess
         self.model = model
 
@@ -51,21 +47,45 @@ class PreTrainedPipeline():
         # "Please implement PreTrainedPipeline __call__ function"
         # )
 
-        vec = np.asarray(json.loads(inputs))[np.newaxis, :, np.newaxis, :]
+        vec = np.asarray(json.loads(inputs))
+        vec = self.reshape_input(vec)  # (nb, nt, nsta, nch)
+        vec = self.normalize(vec)
 
         feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
         preds = self.sess.run(self.model.preds, feed_dict=feed)
 
         picks = extract_picks(preds)  # , station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
-
+        picks = [{'phase_index': x['phase_index'], 'phase_score': x['phase_score'], 'phase_type': x['phase_type']} for x in picks]
         # picks = [{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]} for pick in picks]
 
         # return picks
-        return [[picks, {"label": "debug", "score": 0.1}]]
+        # return [[picks, {"label": "debug", "score": 0.1}]]
+        return [picks]
+
+    def normalize(self, vec):
+        mu = np.mean(vec, axis=1, keepdims=True)
+        std = np.std(vec, axis=1, keepdims=True)
+        std[std == 0] = 1.0
+        vec = (vec - mu) / std
+        return vec
+
+    def reshape_input(self, vec):
+        if len(vec.shape) == 2:
+            vec = vec[np.newaxis, :, np.newaxis, :]
+        elif len(vec.shape) == 3:
+            vec = vec[np.newaxis, :, :, :]
+        else:
+            pass
+        return vec
 
 
 if __name__ == "__main__":
+    import obspy
+    waveform = obspy.read()
+    array = np.array([x.data for x in waveform]).T
+
     pipeline = PreTrainedPipeline()
-    inputs = np.random.rand(1000, 3).tolist()
+    inputs = array.tolist()
     inputs = json.dumps(inputs)
     picks = pipeline(inputs)
+    print(picks)
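For reference, a minimal sketch of how the two preprocessing helpers added above compose. The standalone functions and the random input array are illustrative (the commit defines them as methods on PreTrainedPipeline); the shape convention (nb, nt, nsta, nch) is taken from the diff's own comment.

import numpy as np

def reshape_input(vec):
    # (nt, nch) -> (1, nt, 1, nch); (nt, nsta, nch) -> (1, nt, nsta, nch)
    if vec.ndim == 2:
        vec = vec[np.newaxis, :, np.newaxis, :]
    elif vec.ndim == 3:
        vec = vec[np.newaxis, :, :, :]
    return vec

def normalize(vec):
    # Zero-mean, unit-std along the time axis; flat traces get std = 1 to avoid division by zero.
    mu = np.mean(vec, axis=1, keepdims=True)
    std = np.std(vec, axis=1, keepdims=True)
    std[std == 0] = 1.0
    return (vec - mu) / std

waveform = np.random.randn(3000, 3)   # stand-in for one station, 3 channels
vec = normalize(reshape_input(waveform))
print(vec.shape)                       # (1, 3000, 1, 3) == (nb, nt, nsta, nch)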
tests/test_api.ipynb CHANGED
@@ -2,14 +2,60 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import obspy\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(3000, 3)\n"
+     ]
+    }
+   ],
+   "source": [
+    "waveform = obspy.read()\n",
+    "array = np.array([x.data for x in waveform]).T\n",
+    "print(array.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(3000, 3)\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(array.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[[{'label': 'debug', 'score': 0.1}]]\n"
+      "[[[{'file_name': '0000', 'station_id': '0000', 'begin_time': '1970-01-01T00:00:00.000+00:00', 'phase_index': 573, 'phase_time': '1970-01-01T00:00:05.730+00:00', 'phase_score': 0.999, 'phase_type': 'S', 'dt': 0.01}], {'label': 'debug', 'score': 0.1}]]\n"
      ]
     }
    ],
@@ -27,8 +73,8 @@
     "    return response.json()\n",
     "    # return json.loads(response.content.decode(\"utf-8\"))\n",
     "\n",
-    "array = np.random.rand(10, 3).tolist()\n",
-    "inputs = json.dumps(array)\n",
+    "# array = np.random.rand(10, 3).tolist()\n",
+    "inputs = json.dumps(array.tolist())\n",
     "data = {\n",
     "\t# \"inputs\": \"I like you. I love you\",\n",
     "    \"inputs\": inputs,\n",
@@ -38,6 +84,27 @@
     "output = query(data)\n",
     "print(output)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {