zhuwq0 committed · Commit 81c99dc · 0 parent(s)
Dockerfile ADDED
@@ -0,0 +1,26 @@
+ FROM tensorflow/tensorflow
+
+ # Create the environment:
+ # COPY env.yml /app
+ # RUN conda env create --name cs329s --file=env.yml
+ # Make RUN commands use the new environment:
+ # SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]
+
+ RUN pip install tqdm obspy pandas
+ RUN pip install uvicorn fastapi kafka-python
+
+ WORKDIR /opt
+
+ # Copy files
+ COPY deepdenoiser /opt/deepdenoiser
+ # COPY model /opt/model
+ RUN wget https://github.com/AI4EPS/models/releases/download/DeepDenoiser/model.tar && tar -xvf model.tar && rm model.tar
+
+ # Expose API port (matches the uvicorn port in the ENTRYPOINT below)
+ EXPOSE 7860
+
+ ENV PYTHONUNBUFFERED=1
+
+ # Start API server
+ # ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "cs329s", "uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
+ ENTRYPOINT ["uvicorn", "--app-dir", "deepdenoiser", "app:app", "--reload", "--port", "7860", "--host", "0.0.0.0"]
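Once the image is built and running (e.g. with hypothetical commands such as `docker build -t deepdenoiser .` and `docker run -p 7860:7860 deepdenoiser`), the service can be smoke-tested against the /healthz route defined in deepdenoiser/app.py. A minimal Python sketch, assuming the container is reachable on localhost:

import requests

# Port 7860 comes from the uvicorn ENTRYPOINT above; the host is an assumption.
resp = requests.get("http://localhost:7860/healthz")
print(resp.json())  # expected: {"status": "ok"}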
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2021 Weiqiang Zhu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
deepdenoiser/__init__.py ADDED
File without changes
deepdenoiser/app.py ADDED
@@ -0,0 +1,180 @@
+ import os
+ from collections import defaultdict, namedtuple
+ from datetime import datetime, timedelta
+ from json import dumps
+ from typing import Any, AnyStr, Dict, List, NamedTuple, Union
+
+ import numpy as np
+ import requests
+ import tensorflow as tf
+ from fastapi import FastAPI
+ from kafka import KafkaProducer
+ from pydantic import BaseModel
+ import scipy
+ from scipy.interpolate import interp1d
+
+ from model import UNet
+
+ tf.compat.v1.disable_eager_execution()
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+ PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+ JSONObject = Dict[AnyStr, Any]
+ JSONArray = List[Any]
+ JSONStructure = Union[JSONArray, JSONObject]
+
+ app = FastAPI()
+ X_SHAPE = [3000, 1, 3]
+ SAMPLING_RATE = 100
+
+ # load model
+ model = UNet(mode="pred")
+ sess_config = tf.compat.v1.ConfigProto()
+ sess_config.gpu_options.allow_growth = True
+
+ sess = tf.compat.v1.Session(config=sess_config)
+ saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
+ init = tf.compat.v1.global_variables_initializer()
+ sess.run(init)
+ latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
+ print(f"restoring model {latest_check_point}")
+ saver.restore(sess, latest_check_point)
+
+ # Kafka producer
+ use_kafka = False
+ # BROKER_URL = 'localhost:9092'
+ # BROKER_URL = 'my-kafka-headless:9092'
+
+ try:
+     print("Connecting to k8s kafka")
+     BROKER_URL = "quakeflow-kafka-headless:9092"
+     producer = KafkaProducer(
+         bootstrap_servers=[BROKER_URL],
+         key_serializer=lambda x: dumps(x).encode("utf-8"),
+         value_serializer=lambda x: dumps(x).encode("utf-8"),
+     )
+     use_kafka = True
+     print("k8s kafka connection success!")
+ except BaseException:
+     print("k8s Kafka connection error")
+     try:
+         print("Connecting to local kafka")
+         producer = KafkaProducer(
+             bootstrap_servers=["localhost:9092"],
+             key_serializer=lambda x: dumps(x).encode("utf-8"),
+             value_serializer=lambda x: dumps(x).encode("utf-8"),
+         )
+         use_kafka = True
+         print("local kafka connection success!")
+     except BaseException:
+         print("local Kafka connection error")
+
+
+ def normalize_batch(data, window=200):
+     """
+     data: nbt, nf, nt, 2
+     """
+     assert len(data.shape) == 4
+     shift = window // 2
+     nbt, nf, nt, nimg = data.shape
+
+     ## std in sliding windows
+     data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
+     t = np.arange(0, nt + shift - 1, shift, dtype="int")  # 201 => 0, 100, 200
+     # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
+     std = np.zeros([nbt, len(t)])
+     mean = np.zeros([nbt, len(t)])
+     for i in range(std.shape[1]):
+         std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
+         mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
+
+     std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
+     std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]
+
+     ## normalize data with interpolated std
+     t_interp = np.arange(nt, dtype="int")
+     std_interp = interp1d(t, std, kind="slinear")(t_interp)
+     std_interp[std_interp == 0] = 1.0
+     mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
+
+     data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]
+
+     if len(t) > 3:  ## need to address this normalization issue in training
+         data /= 2.0
+
+     return data
+
+
+ def get_prediction(meta):
+
+     FS = 100
+     NPERSEG = 30
+     NFFT = 60
+
+     vec = np.array(meta.vec)  # [batch, nt, chn]
+     nbt, nt, nch = vec.shape
+     vec = np.transpose(vec, [0, 2, 1])  # [batch, chn, nt]
+     vec = np.reshape(vec, [nbt * nch, nt])  ## [batch * chn, nt]
+
+     if np.mod(vec.shape[-1], 3000) == 1:  # 3001 => 3000
+         vec = vec[..., :-1]
+
+     if meta.dt != 0.01:
+         t = np.linspace(0, 1, vec.shape[-1])  # resample along the time axis
+         t_interp = np.linspace(0, 1, int(np.around(vec.shape[-1] * meta.dt * FS)))
+         vec = interp1d(t, vec, kind="slinear")(t_interp)
+
+     # sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos')  ## for stability of long sequences
+     # vec = scipy.signal.sosfilt(sos, vec)
+     f, t, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary='zeros')
+     noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # [batch * chn, nf, nt, 2]
+     noisy_signal[np.isnan(noisy_signal)] = 0
+     noisy_signal[np.isinf(noisy_signal)] = 0
+     X_input = normalize_batch(noisy_signal)
+
+     feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
+     preds = sess.run(model.preds, feed_dict=feed)
+
+     _, denoised_signal = scipy.signal.istft(
+         (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
+         fs=FS,
+         nperseg=NPERSEG,
+         nfft=NFFT,
+         boundary='zeros',
+     )
+     # _, denoised_noise = scipy.signal.istft(
+     #     (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 1],
+     #     fs=FS,
+     #     nperseg=NPERSEG,
+     #     nfft=NFFT,
+     #     boundary='zeros',
+     # )
+
+     denoised_signal = np.reshape(denoised_signal, [nbt, nch, nt])
+     denoised_signal = np.transpose(denoised_signal, [0, 2, 1])
+
+     result = meta.copy()
+     result.vec = denoised_signal.tolist()
+     return result
+
+
+ class Data(BaseModel):
+     # id: Union[List[str], str]
+     # timestamp: Union[List[str], str]
+     # vec: Union[List[List[List[float]]], List[List[float]]]
+     id: List[str]
+     timestamp: List[str]
+     vec: List[List[List[float]]]
+     dt: float = 0.01
+
+
+ @app.post("/predict")
+ def predict(data: Data):
+
+     denoised = get_prediction(data)
+
+     return denoised
+
+
+ @app.get("/healthz")
+ def healthz():
+     return {"status": "ok"}
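For reference, a minimal client sketch for the /predict route above, assuming the server is reachable on localhost:7860 (the port in the Dockerfile's ENTRYPOINT); the id and timestamp values are placeholders:

import requests

# The payload mirrors the pydantic Data model: id, timestamp, vec [batch, nt, chn], dt.
payload = {
    "id": ["STA1"],                          # hypothetical trace id
    "timestamp": ["2021-01-01T00:00:00.000"],
    "vec": [[[0.0, 0.0, 0.0]] * 3000],       # 1 trace x 3000 samples x 3 channels
    "dt": 0.01,                              # 100 Hz sampling interval
}
resp = requests.post("http://localhost:7860/predict", json=payload)
denoised = resp.json()["vec"]                # denoised waveforms, same nesting as the input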
deepdenoiser/data_reader.py ADDED
@@ -0,0 +1,816 @@
+ import numpy as np
+ import pandas as pd
+ import scipy.signal
+ import tensorflow as tf
+
+ pd.options.mode.chained_assignment = None
+ import logging
+ import os
+ import threading
+
+ import obspy
+ from scipy.interpolate import interp1d
+
+ tf.compat.v1.disable_eager_execution()
+ # from tensorflow.python.ops.linalg_ops import norm
+ # from tensorflow.python.util import nest
+
+
+ class Config:
+     seed = 100
+     n_class = 2
+     fs = 100
+     dt = 1.0 / fs
+     freq_range = [0, fs / 2]
+     time_range = [0, 30]
+     nperseg = 30
+     nfft = 60
+     plot = False
+     nt = 3000
+     X_shape = [31, 201, 2]
+     Y_shape = [31, 201, n_class]
+     signal_shape = [31, 201]
+     noise_shape = signal_shape
+     use_seed = False
+     queue_size = 10
+     noise_mean = 2
+     noise_std = 1
+     # noise_low = 1
+     # noise_high = 5
+     use_buffer = True
+     snr_threshold = 10
+
+
+ # %%
+ # def normalize(data, window=3000):
+ #     """
+ #     data: nsta, chn, nt
+ #     """
+ #     shift = window//2
+ #     nt = len(data)
+
+ #     ## std in sliding windows
+ #     data_pad = np.pad(data, ((window//2, window//2)), mode="reflect")
+ #     t = np.arange(0, nt, shift, dtype="int")
+ #     # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
+ #     std = np.zeros(len(t))
+ #     mean = np.zeros(len(t))
+ #     for i in range(len(std)):
+ #         std[i] = np.std(data_pad[i*shift:i*shift+window])
+ #         mean[i] = np.mean(data_pad[i*shift:i*shift+window])
+
+ #     t = np.append(t, nt)
+ #     std = np.append(std, [np.std(data_pad[-window:])])
+ #     mean = np.append(mean, [np.mean(data_pad[-window:])])
+
+ #     # print(t)
+ #     ## normalize data with interpolated std
+ #     t_interp = np.arange(nt, dtype="int")
+ #     std_interp = interp1d(t, std, kind="slinear")(t_interp)
+ #     mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
+ #     data = (data - mean_interp)/(std_interp)
+ #     return data, std_interp
+
+ # %%
+ def normalize(data, window=200):
+     """
+     data: nsta, chn, nt
+     """
+     shift = window // 2
+     nt = data.shape[1]
+
+     ## std in sliding windows
+     data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
+     t = np.arange(0, nt, shift, dtype="int")
+     # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
+     std = np.zeros(len(t))
+     mean = np.zeros(len(t))
+     for i in range(len(std)):
+         std[i] = np.std(data_pad[:, i * shift : i * shift + window, :])
+         mean[i] = np.mean(data_pad[:, i * shift : i * shift + window, :])
+
+     t = np.append(t, nt)
+     std = np.append(std, [np.std(data_pad[:, -window:, :])])
+     mean = np.append(mean, [np.mean(data_pad[:, -window:, :])])
+     # print(t)
+     ## normalize data with interpolated std
+     t_interp = np.arange(nt, dtype="int")
+     std_interp = interp1d(t, std, kind="slinear")(t_interp)
+     std_interp[std_interp == 0] = 1.0
+     mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
+     data = (data - mean_interp[np.newaxis, :, np.newaxis]) / std_interp[np.newaxis, :, np.newaxis]
+     return data, std_interp
+
+
+ def normalize_batch(data, window=200):
+     """
+     data: nbt, nf, nt, 2
+     """
+     assert len(data.shape) == 4
+     shift = window // 2
+     nbt, nf, nt, nimg = data.shape
+
+     ## std in sliding windows
+     data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
+     t = np.arange(0, nt + shift - 1, shift, dtype="int")  # 201 => 0, 100, 200
+     std = np.zeros([nbt, len(t)])
+     mean = np.zeros([nbt, len(t)])
+     for i in range(std.shape[1]):
+         std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
+         mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
+
+     std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
+     std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]
+
+     ## normalize data with interpolated std
+     t_interp = np.arange(nt, dtype="int")
+     std_interp = interp1d(t, std, kind="slinear")(t_interp)  ## nbt, nt
+     std_interp[std_interp == 0] = 1.0
+     mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
+
+     data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]
+
+     if len(t) > 3:  ## need to address this normalization issue in training
+         data /= 2.0
+
+     return data
+
+
+ # %%
+ def py_func_decorator(output_types=None, output_shapes=None, name=None):
+     def decorator(func):
+         def call(*args, **kwargs):
+             nonlocal output_shapes
+             # flat_output_types = nest.flatten(output_types)
+             flat_output_types = tf.nest.flatten(output_types)
+             # flat_values = tf.py_func(
+             flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
+             if output_shapes is not None:
+                 for v, s in zip(flat_values, output_shapes):
+                     v.set_shape(s)
+             # return nest.pack_sequence_as(output_types, flat_values)
+             return tf.nest.pack_sequence_as(output_types, flat_values)
+
+         return call
+
+     return decorator
+
+
+ def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None):
+     dataset = tf.data.Dataset.range(len(iterator))
+
+     @py_func_decorator(output_types, output_shapes, name=name)
+     def index_to_entry(idx):
+         return iterator[idx]
+
+     return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
+
+
+ class DataReader(object):
+     def __init__(
+         self,
+         signal_dir=None,
+         signal_list=None,
+         noise_dir=None,
+         noise_list=None,
+         queue_size=None,
+         coord=None,
+         config=Config(),
+     ):
+
+         self.config = config
+
+         signal_list = pd.read_csv(signal_list, header=0)
+         noise_list = pd.read_csv(noise_list, header=0)
+
+         self.signal = signal_list
+         self.noise = noise_list
+         self.n_signal = len(self.signal)
+
+         self.signal_dir = signal_dir
+         self.noise_dir = noise_dir
+
+         self.X_shape = config.X_shape
+         self.Y_shape = config.Y_shape
+         self.n_class = config.n_class
+
+         self.coord = coord
+         self.threads = []
+         self.queue_size = queue_size
+
+         self.add_queue()
+         self.buffer_signal = {}
+         self.buffer_noise = {}
+         self.buffer_channels_signal = {}
+         self.buffer_channels_noise = {}
+
+     def add_queue(self):
+         with tf.device('/cpu:0'):
+             self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
+             self.target_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
+             self.queue = tf.queue.PaddingFIFOQueue(
+                 self.queue_size, ['float32', 'float32'], shapes=[self.config.X_shape, self.config.Y_shape]
+             )
+             self.enqueue = self.queue.enqueue([self.sample_placeholder, self.target_placeholder])
+         return 0
+
+     def dequeue(self, num_elements):
+         output = self.queue.dequeue_many(num_elements)
+         return output
+
+     def get_snr(self, data, itp, dit=300):
+         tmp_std = np.std(data[itp - dit : itp])
+         if tmp_std > 0:
+             return np.std(data[itp : itp + dit]) / tmp_std
+         else:
+             return 0
+
+     def add_event(self, sample, channels, j):
+         while np.random.uniform(0, 1) < 0.2:
+             shift = None
+             if channels not in self.buffer_channels_signal:
+                 self.buffer_channels_signal[channels] = self.signal[self.signal['channels'] == channels]
+             fname = os.path.join(self.signal_dir, self.buffer_channels_signal[channels].sample(n=1).iloc[0]['fname'])
+             try:
+                 if fname not in self.buffer_signal:
+                     meta = np.load(fname)
+                     data_FT = []
+                     snr = []
+                     for i in range(3):
+                         tmp_data = meta['data'][:, i]
+                         tmp_itp = meta['itp']
+                         snr.append(self.get_snr(tmp_data, tmp_itp))
+                         tmp_data -= np.mean(tmp_data)
+                         f, t, tmp_FT = scipy.signal.stft(
+                             tmp_data,
+                             fs=self.config.fs,
+                             nperseg=self.config.nperseg,
+                             nfft=self.config.nfft,
+                             boundary='zeros',
+                         )
+                         data_FT.append(tmp_FT)
+                     data_FT = np.stack(data_FT, axis=-1)
+                     self.buffer_signal[fname] = {
+                         'data_FT': data_FT,
+                         'itp': tmp_itp,
+                         'channels': meta['channels'],
+                         'snr': snr,
+                     }
+                 meta_signal = self.buffer_signal[fname]
+             except Exception:
+                 logging.error("Failed reading signal: {}".format(fname))
+                 continue
+             if meta_signal['snr'][j] > self.config.snr_threshold:
+                 tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex128)
+                 shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
+                 tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
+                 if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
+                     continue
+                 tmp_signal = tmp_signal / np.std(tmp_signal)
+                 sample += tmp_signal / np.random.uniform(1, 5)
+         return sample
+
+     def thread_main(self, sess, n_threads=1, start=0):
+         stop = False
+         while not stop:
+             index = list(range(start, self.n_signal, n_threads))
+             np.random.shuffle(index)
+             for i in index:
+                 fname_signal = os.path.join(self.signal_dir, self.signal.iloc[i]['fname'])
+                 try:
+                     if fname_signal not in self.buffer_signal:
+                         meta = np.load(fname_signal)
+                         data_FT = []
+                         snr = []
+                         for j in range(3):
+                             tmp_data = meta['data'][..., j]
+                             tmp_itp = meta['itp']
+                             snr.append(self.get_snr(tmp_data, tmp_itp))
+                             tmp_data -= np.mean(tmp_data)
+                             f, t, tmp_FT = scipy.signal.stft(
+                                 tmp_data,
+                                 fs=self.config.fs,
+                                 nperseg=self.config.nperseg,
+                                 nfft=self.config.nfft,
+                                 boundary='zeros',
+                             )
+                             data_FT.append(tmp_FT)
+                         data_FT = np.stack(data_FT, axis=-1)
+                         self.buffer_signal[fname_signal] = {
+                             'data_FT': data_FT,
+                             'itp': tmp_itp,
+                             'channels': meta['channels'],
+                             'snr': snr,
+                         }
+                     meta_signal = self.buffer_signal[fname_signal]
+                 except Exception:
+                     logging.error("Failed reading signal: {}".format(fname_signal))
+                     continue
+                 channels = meta_signal['channels'].tolist()
+                 start_tp = meta_signal['itp'].tolist()
+
+                 if channels not in self.buffer_channels_noise:
+                     self.buffer_channels_noise[channels] = self.noise[self.noise['channels'] == channels]
+                 fname_noise = os.path.join(
+                     self.noise_dir, self.buffer_channels_noise[channels].sample(n=1).iloc[0]['fname']
+                 )
+                 try:
+                     if fname_noise not in self.buffer_noise:
+                         meta = np.load(fname_noise)
+                         data_FT = []
+                         for k in range(3):  # renamed from i to avoid shadowing the outer loop index
+                             tmp_data = meta['data'][: self.config.nt, k]
+                             tmp_data -= np.mean(tmp_data)
+                             f, t, tmp_FT = scipy.signal.stft(
+                                 tmp_data,
+                                 fs=self.config.fs,
+                                 nperseg=self.config.nperseg,
+                                 nfft=self.config.nfft,
+                                 boundary='zeros',
+                             )
+                             data_FT.append(tmp_FT)
+                         data_FT = np.stack(data_FT, axis=-1)
+                         self.buffer_noise[fname_noise] = {'data_FT': data_FT, 'channels': meta['channels']}
+                     meta_noise = self.buffer_noise[fname_noise]
+                 except Exception:
+                     logging.error("Failed reading noise: {}".format(fname_noise))
+                     continue
+
+                 if self.coord.should_stop():
+                     stop = True
+                     break
+
+                 j = np.random.choice([0, 1, 2])
+                 if meta_signal['snr'][j] <= self.config.snr_threshold:
+                     continue
+
+                 tmp_noise = meta_noise['data_FT'][..., j]
+                 if np.isinf(tmp_noise).any() or np.isnan(tmp_noise).any() or (not np.any(tmp_noise)):
+                     continue
+                 tmp_noise = tmp_noise / np.std(tmp_noise)
+
+                 tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex128)
+                 if np.random.random() < 0.9:
+                     shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
+                     tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
+                     if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
+                         continue
+                     tmp_signal = tmp_signal / np.std(tmp_signal)
+                     tmp_signal = self.add_event(tmp_signal, channels, j)
+
+                 if np.random.random() < 0.2:
+                     tmp_signal = np.fliplr(tmp_signal)
+
+                 ratio = 0
+                 while ratio <= 0:
+                     ratio = self.config.noise_mean + np.random.randn() * self.config.noise_std
+                     # ratio = np.random.uniform(self.config.noise_low, self.config.noise_high)
+                 tmp_noisy_signal = tmp_signal + ratio * tmp_noise
+                 noisy_signal = np.stack([tmp_noisy_signal.real, tmp_noisy_signal.imag], axis=-1)
+                 if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any():
+                     continue
+                 noisy_signal = noisy_signal / np.std(noisy_signal)
+                 tmp_mask = np.abs(tmp_signal) / (np.abs(tmp_signal) + np.abs(ratio * tmp_noise) + 1e-4)
+                 tmp_mask[tmp_mask >= 1] = 1
+                 tmp_mask[tmp_mask <= 0] = 0
+                 mask = np.zeros([tmp_mask.shape[0], tmp_mask.shape[1], self.n_class])
+                 mask[:, :, 0] = tmp_mask
+                 mask[:, :, 1] = 1 - tmp_mask
+                 sess.run(self.enqueue, feed_dict={self.sample_placeholder: noisy_signal, self.target_placeholder: mask})
+
+     def start_threads(self, sess, n_threads=8):
+         for i in range(n_threads):
+             thread = threading.Thread(target=self.thread_main, args=(sess, n_threads, i))
+             thread.daemon = True
+             thread.start()
+             self.threads.append(thread)
+         return self.threads
+
+
+ class DataReader_test(DataReader):
+     def __init__(
+         self,
+         signal_dir=None,
+         signal_list=None,
+         noise_dir=None,
+         noise_list=None,
+         queue_size=None,
+         coord=None,
+         config=Config(),
+     ):
+         self.config = config
+
+         signal_list = pd.read_csv(signal_list, header=0)
+         noise_list = pd.read_csv(noise_list, header=0)
+         self.signal = signal_list
+         self.noise = noise_list
+         self.n_signal = len(self.signal)
+
+         self.signal_dir = signal_dir
+         self.noise_dir = noise_dir
+
+         self.X_shape = config.X_shape
+         self.Y_shape = config.Y_shape
+         self.n_class = config.n_class
+
+         self.coord = coord
+         self.threads = []
+         self.queue_size = queue_size
+
+         self.add_queue()
+         self.buffer_signal = {}
+         self.buffer_noise = {}
+         self.buffer_channels_signal = {}
+         self.buffer_channels_noise = {}
+
+     def add_queue(self):
+         self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
+         self.target_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
+         self.ratio_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
+         self.signal_placeholder = tf.compat.v1.placeholder(dtype=tf.complex64, shape=None)
+         self.noise_placeholder = tf.compat.v1.placeholder(dtype=tf.complex64, shape=None)
+         self.fname_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=None)
+         self.queue = tf.queue.PaddingFIFOQueue(
+             self.queue_size,
+             ['float32', 'float32', 'float32', 'complex64', 'complex64', 'string'],
+             shapes=[
+                 self.config.X_shape,
+                 self.config.Y_shape,
+                 [],
+                 self.config.signal_shape,
+                 self.config.noise_shape,
+                 [],
+             ],
+         )
+         self.enqueue = self.queue.enqueue(
+             [
+                 self.sample_placeholder,
+                 self.target_placeholder,
+                 self.ratio_placeholder,
+                 self.signal_placeholder,
+                 self.noise_placeholder,
+                 self.fname_placeholder,
+             ]
+         )
+         return 0
+
+     def dequeue(self, num_elements):
+         output = self.queue.dequeue_up_to(num_elements)
+         return output
+
+     def thread_main(self, sess, n_threads=1, start=0):
+         index = list(range(start, self.n_signal, n_threads))
+         for i in index:
+             np.random.seed(i)
+
+             fname = self.signal.iloc[i]['fname']
+             fname_signal = os.path.join(self.signal_dir, fname)
+             meta = np.load(fname_signal)
+             data_FT = []
+             snr = []
+             for j in range(3):
+                 tmp_data = meta['data'][..., j]
+                 tmp_itp = meta['itp']
+                 snr.append(self.get_snr(tmp_data, tmp_itp))
+                 tmp_data -= np.mean(tmp_data)
+                 f, t, tmp_FT = scipy.signal.stft(
+                     tmp_data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
+                 )
+                 data_FT.append(tmp_FT)
+             data_FT = np.stack(data_FT, axis=-1)
+             meta_signal = {'data_FT': data_FT, 'itp': tmp_itp, 'channels': meta['channels'], 'snr': snr}
+             channels = meta['channels'].tolist()
+             start_tp = meta['itp'].tolist()
+
+             if channels not in self.buffer_channels_noise:
+                 self.buffer_channels_noise[channels] = self.noise[self.noise['channels'] == channels]
+             fname_noise = os.path.join(
+                 self.noise_dir, self.buffer_channels_noise[channels].sample(n=1, random_state=i).iloc[0]['fname']
+             )
+             meta = np.load(fname_noise)
+             data_FT = []
+             for k in range(3):  # renamed from i to avoid clobbering the outer loop index
+                 tmp_data = meta['data'][: self.config.nt, k]
+                 tmp_data -= np.mean(tmp_data)
+                 f, t, tmp_FT = scipy.signal.stft(
+                     tmp_data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
+                 )
+                 data_FT.append(tmp_FT)
+             data_FT = np.stack(data_FT, axis=-1)
+             meta_noise = {'data_FT': data_FT, 'channels': meta['channels']}
+
+             if self.coord.should_stop():
+                 stop = True
+                 break
+
+             j = np.random.choice([0, 1, 2])
+             tmp_noise = meta_noise['data_FT'][..., j]
+             if np.isinf(tmp_noise).any() or np.isnan(tmp_noise).any() or (not np.any(tmp_noise)):
+                 continue
+             tmp_noise = tmp_noise / np.std(tmp_noise)
+
+             tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex128)
+             if np.random.random() < 0.9:
+                 shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
+                 tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
+                 if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
+                     continue
+                 tmp_signal = tmp_signal / np.std(tmp_signal)
+                 # tmp_signal = self.add_event(tmp_signal, channels, j)
+             # if np.random.random() < 0.2:
+             #     tmp_signal = np.fliplr(tmp_signal)
+
+             ratio = 0
+             while ratio <= 0:
+                 ratio = self.config.noise_mean + np.random.randn() * self.config.noise_std
+             tmp_noisy_signal = tmp_signal + ratio * tmp_noise
+             noisy_signal = np.stack([tmp_noisy_signal.real, tmp_noisy_signal.imag], axis=-1)
+             if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any():
+                 continue
+             std_noisy_signal = np.std(noisy_signal)
+             noisy_signal = noisy_signal / std_noisy_signal
+             tmp_mask = np.abs(tmp_signal) / (np.abs(tmp_signal) + np.abs(ratio * tmp_noise) + 1e-4)
+             tmp_mask[tmp_mask >= 1] = 1
+             tmp_mask[tmp_mask <= 0] = 0
+             mask = np.zeros([tmp_mask.shape[0], tmp_mask.shape[1], self.n_class])
+             mask[:, :, 0] = tmp_mask
+             mask[:, :, 1] = 1 - tmp_mask
+
+             sess.run(
+                 self.enqueue,
+                 feed_dict={
+                     self.sample_placeholder: noisy_signal,
+                     self.target_placeholder: mask,
+                     self.ratio_placeholder: std_noisy_signal,
+                     self.signal_placeholder: tmp_signal,
+                     self.noise_placeholder: ratio * tmp_noise,
+                     self.fname_placeholder: fname,
+                 },
+             )
+
+
+ class DataReader_pred_queue(DataReader):
+     def __init__(self, signal_dir, signal_list, queue_size, coord, config=Config()):
+         self.config = config
+         signal_list = pd.read_csv(signal_list)
+         self.signal = signal_list
+         self.n_signal = len(self.signal)
+         self.n_class = config.n_class
+         self.X_shape = config.X_shape
+         self.Y_shape = config.Y_shape
+         self.signal_dir = signal_dir
+
+         self.coord = coord
+         self.threads = []
+         self.queue_size = queue_size
+         self.add_placeholder()
+
+     def add_placeholder(self):
+         self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
+         self.ratio_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
+         self.fname_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=None)
+         self.queue = tf.queue.PaddingFIFOQueue(
+             self.queue_size, ['float32', 'float32', 'string'], shapes=[self.config.X_shape, [], []]
+         )
+         self.enqueue = self.queue.enqueue([self.sample_placeholder, self.ratio_placeholder, self.fname_placeholder])
+
+     def dequeue(self, num_elements):
+         output = self.queue.dequeue_up_to(num_elements)
+         return output
+
+     def thread_main(self, sess, n_threads=1, start=0):
+         index = list(range(start, self.n_signal, n_threads))
+         shift = 0
+         for i in index:
+             fname = self.signal.iloc[i]['fname']
+             data_signal = np.load(os.path.join(self.signal_dir, fname))
+             f, t, tmp_signal = scipy.signal.stft(
+                 scipy.signal.detrend(np.squeeze(data_signal['data'][shift : self.config.nt + shift])),
+                 fs=self.config.fs,
+                 nperseg=self.config.nperseg,
+                 nfft=self.config.nfft,
+                 boundary='zeros',
+             )
+             noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)
+             if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any() or (not np.any(noisy_signal)):
+                 continue
+             std_noisy_signal = np.std(noisy_signal)
+             if std_noisy_signal == 0:
+                 continue
+             noisy_signal = noisy_signal / std_noisy_signal
+             sess.run(
+                 self.enqueue,
+                 feed_dict={
+                     self.sample_placeholder: noisy_signal,
+                     self.ratio_placeholder: std_noisy_signal,
+                     self.fname_placeholder: fname,
+                 },
+             )
+
+
+ class DataReader_pred:
+     def __init__(self, signal_dir, signal_list, format="numpy", sampling_rate=100, config=Config()):
+         self.buffer = {}
+         self.config = config
+         self.format = format
+         self.dtype = "float32"
+         try:
+             signal_list = pd.read_csv(signal_list, sep="\t")["fname"]
+         except Exception:
+             signal_list = pd.read_csv(signal_list)["fname"]
+         self.signal_list = signal_list
+         self.n_signal = len(self.signal_list)
+         self.signal_dir = signal_dir
+         self.sampling_rate = sampling_rate
+         self.n_class = config.n_class
+         FT_shape = self.get_data_shape()
+         self.X_shape = [*FT_shape, 2]
+
+     def get_data_shape(self):
+         # fname = self.signal_list.iloc[0]['fname']
+         # data = np.load(os.path.join(self.signal_dir, fname), allow_pickle=True)["data"]
+         # data = np.squeeze(data)
+         base_name = self.signal_list[0]
+         if self.format == "numpy":
+             meta = self.read_numpy(os.path.join(self.signal_dir, base_name))
+         elif self.format == "mseed":
+             meta = self.read_mseed(os.path.join(self.signal_dir, base_name))
+         elif self.format == "hdf5":
+             meta = self.read_hdf5(base_name)
+
+         data = meta["data"]
+         data = np.transpose(data, [2, 1, 0])
+
+         if self.sampling_rate != 100:
+             t = np.linspace(0, 1, data.shape[-1])
+             t_interp = np.linspace(0, 1, int(np.around(data.shape[-1] * 100.0 / self.sampling_rate)))
+             data = interp1d(t, data, kind="slinear")(t_interp)
+         f, t, tmp_signal = scipy.signal.stft(
+             data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
+         )
+         logging.info(f"Input data shape: {tmp_signal.shape} measured on file {base_name}")
+         return tmp_signal.shape
+
+     def __len__(self):
+         return self.n_signal
+
+     def read_numpy(self, fname):
+         # try:
+         if fname not in self.buffer:
+             npz = np.load(fname)
+             meta = {}
+             if len(npz['data'].shape) == 1:
+                 meta["data"] = npz['data'][:, np.newaxis, np.newaxis]
+             elif len(npz['data'].shape) == 2:
+                 meta["data"] = npz['data'][:, np.newaxis, :]
+             else:
+                 meta["data"] = npz['data']
+             if "p_idx" in npz.files:
+                 if len(npz["p_idx"].shape) == 0:
+                     meta["itp"] = [[npz["p_idx"]]]
+                 else:
+                     meta["itp"] = npz["p_idx"]
+             if "s_idx" in npz.files:
+                 if len(npz["s_idx"].shape) == 0:
+                     meta["its"] = [[npz["s_idx"]]]
+                 else:
+                     meta["its"] = npz["s_idx"]
+             if "t0" in npz.files:
+                 meta["t0"] = npz["t0"]
+             self.buffer[fname] = meta
+         else:
+             meta = self.buffer[fname]
+         return meta
+         # except:
+         #     logging.error("Failed reading {}".format(fname))
+         #     return None
+
+     def read_hdf5(self, fname):
+         data = self.h5_data[fname][()]
+         attrs = self.h5_data[fname].attrs
+         meta = {}
+         if len(data.shape) == 2:
+             meta["data"] = data[:, np.newaxis, :]
+         else:
+             meta["data"] = data
+         if "p_idx" in attrs:
+             if len(attrs["p_idx"].shape) == 0:
+                 meta["itp"] = [[attrs["p_idx"]]]
+             else:
+                 meta["itp"] = attrs["p_idx"]
+         if "s_idx" in attrs:
+             if len(attrs["s_idx"].shape) == 0:
+                 meta["its"] = [[attrs["s_idx"]]]
+             else:
+                 meta["its"] = attrs["s_idx"]
+         if "t0" in attrs:
+             meta["t0"] = attrs["t0"]
+         return meta
+
+     def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
+         with self.s3fs.open(bucket + "/" + fname, 'rb') as fp:
+             if format == "numpy":
+                 meta = self.read_numpy(fp)
+             elif format == "mseed":
+                 meta = self.read_mseed(fp)
+             else:
+                 raise ValueError(f"Format {format} not supported")
+         return meta
+
+     def read_mseed(self, fname):
+
+         mseed = obspy.read(fname)
+         mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
+         mseed = mseed.merge(fill_value=0)
+         starttime = min([st.stats.starttime for st in mseed])
+         endtime = max([st.stats.endtime for st in mseed])
+         mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
+         if mseed[0].stats.sampling_rate != self.sampling_rate:
+             logging.warning(f"Sampling rate {mseed[0].stats.sampling_rate} != {self.sampling_rate} Hz")
+
+         order = ['3', '2', '1', 'E', 'N', 'Z']
+         order = {key: i for i, key in enumerate(order)}
+         comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
+
+         t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+         nt = len(mseed[0].data)
+         data = np.zeros([nt, 3], dtype=self.dtype)
+         ids = [x.get_id() for x in mseed]
+         if len(ids) == 3:
+             for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
+                 data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
+         else:
+             if len(ids) > 3:
+                 logging.warning(f"More than 3 channels {ids}!")
+             for jj, id in enumerate(ids):
+                 j = comp2idx[id[-1]]
+                 data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
+
+         data = data[:, np.newaxis, :]
+         meta = {"data": data, "t0": t0}
+         return meta
+
+     def __getitem__(self, i):
+         # fname = self.signal.iloc[i]['fname']
+         # data = np.load(os.path.join(self.signal_dir, fname), allow_pickle=True)["data"]
+         # data = np.squeeze(data)
+         base_name = self.signal_list[i]
+
+         if self.format == "numpy":
+             meta = self.read_numpy(os.path.join(self.signal_dir, base_name))
+         elif self.format == "mseed":
+             meta = self.read_mseed(os.path.join(self.signal_dir, base_name))
+         elif self.format == "hdf5":
+             meta = self.read_hdf5(base_name)
+
+         data = meta["data"]  # nt, 1, nch
+         data = np.transpose(data, [2, 1, 0])  # nch, 1, nt
+         if np.mod(data.shape[-1], 3000) == 1:  # 3001 => 3000
+             data = data[..., :-1]
+         if "t0" in meta:
+             t0 = meta["t0"]
+         else:
+             t0 = "1970-01-01T00:00:00.000"
+
+         if self.sampling_rate != 100:
+             logging.warning(f"Resample from {self.sampling_rate} to 100!")
+             t = np.linspace(0, 1, data.shape[-1])
+             t_interp = np.linspace(0, 1, int(np.around(data.shape[-1] * 100.0 / self.sampling_rate)))
+             data = interp1d(t, data, kind="slinear")(t_interp)
+         # sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos')  ## for stability of long sequences
+         # data = scipy.signal.sosfilt(sos, data)
+         f, t, tmp_signal = scipy.signal.stft(
+             data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
+         )  # nch, 1, nf, nt
+         noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # nch, 1, nf, nt, 2
+         noisy_signal[np.isnan(noisy_signal)] = 0
+         noisy_signal[np.isinf(noisy_signal)] = 0
+         # noisy_signal, std_noisy_signal = normalize(noisy_signal)
+         # return noisy_signal.astype(self.dtype), std_noisy_signal.astype(self.dtype), fname
+
+         return noisy_signal.astype(self.dtype), base_name, t0
+
+     def dataset(self, batch_size, num_parallel_calls=4):
+         dataset = dataset_map(
+             self,
+             output_types=(self.dtype, "string", "string"),
+             output_shapes=(self.X_shape, None, None),
+             num_parallel_calls=num_parallel_calls,
+         )
+         dataset = tf.compat.v1.data.make_one_shot_iterator(
+             dataset.batch(batch_size).prefetch(batch_size * 3)
+         ).get_next()
+         return dataset
+
+
+ if __name__ == "__main__":
+
+     # %%
+     data_reader = DataReader_pred(signal_dir="./Dataset/yixiao/", signal_list="./Dataset/yixiao.csv")
+     noisy_signal, fname, t0 = data_reader[0]  # __getitem__ returns (spectrogram, file name, start time)
+     print(noisy_signal.shape, fname, t0)
+     batch = data_reader.dataset(10)
+     init = tf.compat.v1.global_variables_initializer()
+     sess = tf.compat.v1.Session()
+     sess.run(init)
+     print(sess.run(batch))
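As a quick shape check for normalize_batch above (the same routine app.py applies to STFT spectrograms before inference), a minimal sketch with random input in the documented [nbt, nf, nt, 2] layout; importing from data_reader assumes it is run from the deepdenoiser/ directory:

import numpy as np
from data_reader import normalize_batch

# 4 traces x 31 frequency bins x 201 STFT frames x (real, imag), matching Config.X_shape.
x = np.random.randn(4, 31, 201, 2)
y = normalize_batch(x, window=200)
print(y.shape)  # (4, 31, 201, 2): sliding-window mean removed, std scaled per trace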
deepdenoiser/model.py ADDED
@@ -0,0 +1,495 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+
6
+ from util import *
7
+
8
+ tf.compat.v1.disable_eager_execution()
9
+
10
+
11
+ class ModelConfig:
12
+
13
+ batch_size = 20
14
+ depths = 6
15
+ filters_root = 8
16
+ kernel_size = [3, 3]
17
+ pool_size = [2, 2]
18
+ dilation_rate = [1, 1]
19
+ class_weights = [1.0, 1.0, 1.0]
20
+ loss_type = "cross_entropy"
21
+ weight_decay = 0.0
22
+ optimizer = "adam"
23
+ momentum = 0.9
24
+ learning_rate = 0.01
25
+ decay_step = 1e9
26
+ decay_rate = 0.9
27
+ drop_rate = 0.0
28
+ summary = True
29
+
30
+ X_shape = [31, 201, 2]
31
+ n_channel = X_shape[-1]
32
+ Y_shape = [31, 201, 2]
33
+ n_class = Y_shape[-1]
34
+
35
+ def __init__(self, **kwargs):
36
+ for k, v in kwargs.items():
37
+ setattr(self, k, v)
38
+
39
+ def update_args(self, args):
40
+ for k, v in vars(args).items():
41
+ setattr(self, k, v)
42
+
43
+
44
+ def crop_and_concat(net1, net2):
45
+ """
46
+ the size(net1) <= size(net2)
47
+ """
48
+ # net1_shape = net1.get_shape().as_list()
49
+ # net2_shape = net2.get_shape().as_list()
50
+ # # print(net1_shape)
51
+ # # print(net2_shape)
52
+ # # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
53
+ # offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
54
+ # size = [-1, net1_shape[1], net1_shape[2], -1]
55
+ # net2_resize = tf.slice(net2, offsets, size)
56
+ # return tf.concat([net1, net2_resize], 3)
57
+ # # else:
58
+ # # offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
59
+ # # size = [-1, net2_shape[1], net2_shape[2], -1]
60
+ # # net1_resize = tf.slice(net1, offsets, size)
61
+ # # return tf.concat([net1_resize, net2], 3)
62
+
63
+ ## dynamic shape
64
+ chn1 = net1.get_shape().as_list()[-1]
65
+ chn2 = net2.get_shape().as_list()[-1]
66
+ net1_shape = tf.shape(net1)
67
+ net2_shape = tf.shape(net2)
68
+ # print(net1_shape)
69
+ # print(net2_shape)
70
+ # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
71
+ offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
72
+ size = [-1, net1_shape[1], net1_shape[2], -1]
73
+ net2_resize = tf.slice(net2, offsets, size)
74
+
75
+ out = tf.concat([net1, net2_resize], 3)
76
+ out.set_shape([None, None, None, chn1 + chn2])
77
+ return out
78
+
79
+
80
+ def crop_only(net1, net2):
81
+ """
82
+ the size(net1) <= size(net2)
83
+ """
84
+ net1_shape = net1.get_shape().as_list()
85
+ net2_shape = net2.get_shape().as_list()
86
+ # print(net1_shape)
87
+ # print(net2_shape)
88
+ # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
89
+ offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
90
+ size = [-1, net1_shape[1], net1_shape[2], -1]
91
+ net2_resize = tf.slice(net2, offsets, size)
92
+ # return tf.concat([net1, net2_resize], 3)
93
+ return net2_resize
94
+
95
+
96
+ class UNet:
97
+ def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
98
+ self.depths = config.depths
99
+ self.filters_root = config.filters_root
100
+ self.kernel_size = config.kernel_size
101
+ self.dilation_rate = config.dilation_rate
102
+ self.pool_size = config.pool_size
103
+ self.X_shape = config.X_shape
104
+ self.Y_shape = config.Y_shape
105
+ self.n_channel = config.n_channel
106
+ self.n_class = config.n_class
107
+ self.class_weights = config.class_weights
108
+ self.batch_size = config.batch_size
109
+ self.loss_type = config.loss_type
110
+ self.weight_decay = config.weight_decay
111
+ self.optimizer = config.optimizer
112
+ self.decay_step = config.decay_step
113
+ self.decay_rate = config.decay_rate
114
+ self.momentum = config.momentum
115
+ self.learning_rate = config.learning_rate
116
+ self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
117
+ self.summary_train = []
118
+ self.summary_valid = []
119
+
120
+ self.build(input_batch, mode=mode)
121
+
122
+ def add_placeholders(self, input_batch=None, mode='train'):
123
+ if input_batch is None:
124
+ self.X = tf.compat.v1.placeholder(
125
+ dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X'
126
+ )
127
+ self.Y = tf.compat.v1.placeholder(
128
+ dtype=tf.float32, shape=[None, None, None, self.n_class], name='y'
129
+ )
130
+ else:
131
+ self.X = input_batch[0]
132
+ if mode in ["train", "valid", "test"]:
133
+ self.Y = input_batch[1]
134
+ self.input_batch = input_batch
135
+
136
+ self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
137
+ # self.keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")
138
+ self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")
139
+ # self.learning_rate = tf.placeholder_with_default(tf.constant(0.01, dtype=tf.float32), shape=[], name="learning_rate")
140
+ # self.global_step = tf.placeholder_with_default(tf.constant(0, dtype=tf.int32), shape=[], name="global_step")
141
+
142
+ def add_prediction_op(self):
143
+ logging.info(
144
+ "Model: depths {depths}, filters {filters}, "
145
+ "filter size {kernel_size[0]}x{kernel_size[1]}, "
146
+ "pool size: {pool_size[0]}x{pool_size[1]}, "
147
+ "dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
148
+ depths=self.depths,
149
+ filters=self.filters_root,
150
+ kernel_size=self.kernel_size,
151
+ dilation_rate=self.dilation_rate,
152
+ pool_size=self.pool_size,
153
+ )
154
+ )
155
+
156
+ if self.weight_decay > 0:
157
+ weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
158
+ self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
159
+ else:
160
+ self.regularizer = None
161
+
162
+ self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(
163
+ scale=1.0, mode="fan_avg", distribution="uniform"
164
+ )
165
+
166
+ # down sample layers
167
+ convs = [None] * self.depths # store output of each depth
168
+
169
+ with tf.compat.v1.variable_scope("Input"):
170
+ net = self.X
171
+ net = tf.compat.v1.layers.conv2d(
172
+ net,
173
+ filters=self.filters_root,
174
+ kernel_size=self.kernel_size,
175
+ activation=None,
176
+ use_bias=False,
177
+ padding='same',
178
+ dilation_rate=self.dilation_rate,
179
+ kernel_initializer=self.initializer,
180
+ kernel_regularizer=self.regularizer,
181
+ # bias_regularizer=self.regularizer,
182
+ name="input_conv",
183
+ )
184
+ net = tf.compat.v1.layers.batch_normalization(net, training=self.is_training, name="input_bn")
185
+ net = tf.nn.relu(net, name="input_relu")
186
+ # net = tf.nn.dropout(net, self.keep_prob)
187
+ net = tf.compat.v1.layers.dropout(net, rate=self.drop_rate, training=self.is_training, name="input_dropout")
188
+
189
+ for depth in range(0, self.depths):
190
+ with tf.compat.v1.variable_scope("DownConv_%d" % depth):
191
+ filters = int(2 ** (depth) * self.filters_root)
192
+
193
+ net = tf.compat.v1.layers.conv2d(
194
+ net,
195
+ filters=filters,
196
+ kernel_size=self.kernel_size,
197
+ activation=None,
198
+ use_bias=False,
199
+ padding='same',
200
+ dilation_rate=self.dilation_rate,
201
+ kernel_initializer=self.initializer,
202
+ kernel_regularizer=self.regularizer,
203
+ # bias_regularizer=self.regularizer,
204
+ name="down_conv1_{}".format(depth + 1),
205
+ )
206
+ net = tf.compat.v1.layers.batch_normalization(
207
+ net, training=self.is_training, name="down_bn1_{}".format(depth + 1)
208
+ )
209
+ net = tf.nn.relu(net, name="down_relu1_{}".format(depth + 1))
210
+ net = tf.compat.v1.layers.dropout(
211
+ net, rate=self.drop_rate, training=self.is_training, name="down_dropout1_{}".format(depth + 1)
212
+ )
213
+
214
+ convs[depth] = net
215
+
216
+ if depth < self.depths - 1:
217
+ net = tf.compat.v1.layers.conv2d(
218
+ net,
219
+ filters=filters,
220
+ kernel_size=self.kernel_size,
221
+ strides=self.pool_size,
222
+ activation=None,
223
+ use_bias=False,
224
+ padding='same',
225
+ # dilation_rate=self.dilation_rate,
226
+ kernel_initializer=self.initializer,
227
+ kernel_regularizer=self.regularizer,
228
+ # bias_regularizer=self.regularizer,
229
+ name="down_conv3_{}".format(depth + 1),
230
+ )
231
+ net = tf.compat.v1.layers.batch_normalization(
232
+ net, training=self.is_training, name="down_bn3_{}".format(depth + 1)
233
+ )
234
+ net = tf.nn.relu(net, name="down_relu3_{}".format(depth + 1))
235
+ net = tf.compat.v1.layers.dropout(
236
+ net, rate=self.drop_rate, training=self.is_training, name="down_dropout3_{}".format(depth + 1)
237
+ )
238
+
239
+ # up layers
240
+ for depth in range(self.depths - 2, -1, -1):
241
+ with tf.compat.v1.variable_scope("UpConv_%d" % depth):
242
+ filters = int(2 ** (depth) * self.filters_root)
243
+ net = tf.compat.v1.layers.conv2d_transpose(
244
+ net,
245
+ filters=filters,
246
+ kernel_size=self.kernel_size,
247
+ strides=self.pool_size,
248
+ activation=None,
249
+ use_bias=False,
250
+ padding="same",
251
+ kernel_initializer=self.initializer,
252
+ kernel_regularizer=self.regularizer,
253
+ # bias_regularizer=self.regularizer,
254
+ name="up_conv0_{}".format(depth + 1),
255
+ )
256
+ net = tf.compat.v1.layers.batch_normalization(
257
+ net, training=self.is_training, name="up_bn0_{}".format(depth + 1)
258
+ )
259
+ net = tf.nn.relu(net, name="up_relu0_{}".format(depth + 1))
260
+ net = tf.compat.v1.layers.dropout(
261
+ net, rate=self.drop_rate, training=self.is_training, name="up_dropout0_{}".format(depth + 1)
262
+ )
263
+
264
+ # skip connection
265
+ net = crop_and_concat(convs[depth], net)
266
+ # net = crop_only(convs[depth], net)
267
+
268
+ net = tf.compat.v1.layers.conv2d(
269
+ net,
270
+ filters=filters,
271
+ kernel_size=self.kernel_size,
272
+ activation=None,
273
+ use_bias=False,
274
+ padding='same',
275
+ dilation_rate=self.dilation_rate,
276
+ kernel_initializer=self.initializer,
277
+ kernel_regularizer=self.regularizer,
278
+ # bias_regularizer=self.regularizer,
279
+ name="up_conv1_{}".format(depth + 1),
280
+ )
281
+ net = tf.compat.v1.layers.batch_normalization(
282
+ net, training=self.is_training, name="up_bn1_{}".format(depth + 1)
283
+ )
284
+ net = tf.nn.relu(net, name="up_relu1_{}".format(depth + 1))
285
+ net = tf.compat.v1.layers.dropout(
286
+ net, rate=self.drop_rate, training=self.is_training, name="up_dropout1_{}".format(depth + 1)
287
+ )
288
+
289
+ # Output Map
290
+ with tf.compat.v1.variable_scope("Output"):
291
+ net = tf.compat.v1.layers.conv2d(
292
+ net,
293
+ filters=self.n_class,
294
+ kernel_size=(1, 1),
295
+ activation=None,
296
+ use_bias=True,
297
+ padding='same',
298
+ # dilation_rate=self.dilation_rate,
299
+ kernel_initializer=self.initializer,
300
+ kernel_regularizer=self.regularizer,
301
+ # bias_regularizer=self.regularizer,
302
+ name="output_conv",
303
+ )
304
+ # net = tf.nn.relu(net,
305
+ # name="output_relu")
306
+ # net = tf.layers.dropout(net,
307
+ # rate=self.drop_rate,
308
+ # training=self.is_training,
309
+ # name="output_dropout")
310
+ # net = tf.layers.batch_normalization(net,
311
+ # training=self.is_training,
312
+ # name="output_bn")
313
+ output = net
314
+
315
+ with tf.compat.v1.variable_scope("representation"):
316
+ self.representation = convs[-1]
317
+
318
+ with tf.compat.v1.variable_scope("logits"):
319
+ self.logits = output
320
+ tmp = tf.compat.v1.summary.histogram("logits", self.logits)
321
+ self.summary_train.append(tmp)
322
+
323
+ with tf.compat.v1.variable_scope("preds"):
324
+ self.preds = tf.nn.softmax(output)
325
+ tmp = tf.compat.v1.summary.histogram("preds", self.preds)
326
+ self.summary_train.append(tmp)
327
+
328
+ def add_loss_op(self):
329
+ if self.loss_type == "cross_entropy":
330
+ with tf.compat.v1.variable_scope("cross_entropy"):
331
+ flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
332
+ flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
333
+ if (np.array(self.class_weights) != 1).any():
334
+ class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
335
+ weight_map = tf.multiply(flat_labels, class_weights)
336
+ weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
337
+                     loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
+                     # loss_map = tf.nn.sigmoid_cross_entropy_with_logits(logits=flat_logits,
+                     #                                                    labels=flat_labels)
+                     weighted_loss = tf.multiply(loss_map, weight_map)
+                     loss = tf.reduce_mean(input_tensor=weighted_loss)
+                 else:
+                     loss = tf.reduce_mean(
+                         input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
+                     )
+                     # loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=flat_logits,
+                     #                                                               labels=flat_labels))
+         elif self.loss_type == "IOU":
+             with tf.compat.v1.variable_scope("IOU"):
+                 # Soft intersection-over-union loss summed over the non-background classes.
+                 eps = 1e-7
+                 loss = 0
+                 for i in range(1, self.n_class):
+                     intersection = eps + tf.reduce_sum(
+                         input_tensor=self.preds[:, :, :, i] * self.Y[:, :, :, i], axis=[1, 2]
+                     )
+                     union = (
+                         eps
+                         + tf.reduce_sum(input_tensor=self.preds[:, :, :, i], axis=[1, 2])
+                         + tf.reduce_sum(input_tensor=self.Y[:, :, :, i], axis=[1, 2])
+                     )
+                     loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
+         elif self.loss_type == "mean_squared":
+             with tf.compat.v1.variable_scope("mean_squared"):
+                 flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
+                 flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
+                 loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
+         else:
+             raise ValueError("Unknown loss function: %s" % self.loss_type)
+ 
+         tmp = tf.compat.v1.summary.scalar("train_loss", loss)
+         self.summary_train.append(tmp)
+         tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
+         self.summary_valid.append(tmp)
+ 
+         if self.weight_decay > 0:
+             with tf.compat.v1.name_scope('weight_loss'):
+                 tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
+                 weight_loss = tf.add_n(tmp, name="weight_loss")
+             self.loss = loss + weight_loss
+         else:
+             self.loss = loss
+ 
+     def add_training_op(self):
+         if self.optimizer == "momentum":
+             self.learning_rate_node = tf.compat.v1.train.exponential_decay(
+                 learning_rate=self.learning_rate,
+                 global_step=self.global_step,
+                 decay_steps=self.decay_step,
+                 decay_rate=self.decay_rate,
+                 staircase=True,
+             )
+             optimizer = tf.compat.v1.train.MomentumOptimizer(
+                 learning_rate=self.learning_rate_node, momentum=self.momentum
+             )
+         elif self.optimizer == "adam":
+             self.learning_rate_node = tf.compat.v1.train.exponential_decay(
+                 learning_rate=self.learning_rate,
+                 global_step=self.global_step,
+                 decay_steps=self.decay_step,
+                 decay_rate=self.decay_rate,
+                 staircase=True,
+             )
+ 
+             optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
+         update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
+         with tf.control_dependencies(update_ops):
+             self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
+         tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
+         self.summary_train.append(tmp)
+ 
+     def reset_learning_rate(self, sess, learning_rate, global_step):
+         self.learning_rate = learning_rate
+         assign_op = self.global_step.assign(global_step)
+         sess.run(assign_op)
+         if self.optimizer == "momentum":
+             self.learning_rate_node = tf.compat.v1.train.exponential_decay(
+                 learning_rate=learning_rate,
+                 global_step=self.global_step,
+                 decay_steps=self.decay_step,
+                 decay_rate=self.decay_rate,
+                 staircase=True,
+             )
+             optimizer = tf.compat.v1.train.MomentumOptimizer(
+                 learning_rate=self.learning_rate_node, momentum=self.momentum
+             )
+         elif self.optimizer == "adam":
+             self.learning_rate_node = tf.compat.v1.train.exponential_decay(
+                 learning_rate=self.learning_rate,
+                 global_step=self.global_step,
+                 decay_steps=self.decay_step,
+                 decay_rate=self.decay_rate,
+                 staircase=True,
+             )
+ 
+             optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
+ 
+     def train_on_batch(self, sess, X_batch, Y_batch, summary_writer, drop_rate=0.0):
+         feed = {self.drop_rate: drop_rate, self.is_training: True, self.X: X_batch, self.Y: Y_batch}
+         _, step_summary, step, loss = sess.run(
+             [self.train_op, self.summary_train, self.global_step, self.loss], feed_dict=feed
+         )
+         summary_writer.add_summary(step_summary, step)
+         return loss
+ 
+     def valid_on_batch(self, sess, X_batch, Y_batch, summary_writer, drop_rate=0.0):
+         feed = {self.drop_rate: drop_rate, self.is_training: False, self.X: X_batch, self.Y: Y_batch}
+         step_summary, step, loss, preds = sess.run(
+             [self.summary_valid, self.global_step, self.loss, self.preds], feed_dict=feed
+         )
+         summary_writer.add_summary(step_summary, step)
+         return loss, preds
+ 
+     def test_on_batch(self, sess, summary_writer):
+         feed = {self.drop_rate: 0, self.is_training: False}
+         (
+             step_summary,
+             step,
+             loss,
+             preds,
+             X_batch,
+             Y_batch,
+             ratio_batch,
+             signal_batch,
+             noise_batch,
+             fname_batch,
+         ) = sess.run(
+             [
+                 self.summary_valid,
+                 self.global_step,
+                 self.loss,
+                 self.preds,
+                 self.X,
+                 self.Y,
+                 self.input_batch[2],
+                 self.input_batch[3],
+                 self.input_batch[4],
+                 self.input_batch[5],
+             ],
+             feed_dict=feed,
+         )
+         summary_writer.add_summary(step_summary, step)
+ 
+         return loss, preds, X_batch, Y_batch, ratio_batch, signal_batch, noise_batch, fname_batch
+ 
+     def build(self, input_batch=None, mode='train'):
+         self.add_placeholders(input_batch, mode)
+         self.add_prediction_op()
+         if mode in ["train", "valid", "test"]:
+             self.add_loss_op()
+             self.add_training_op()
+         # self.add_metrics_op()
+         self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
+         self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
+         return 0
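+ 
+ # Usage sketch (illustrative addition, not part of the original file): driving the
+ # class above in TF1 graph mode, assuming a Config instance `config` carrying the
+ # fields consumed by add_loss_op/add_training_op and numpy batches X_batch/Y_batch.
+ #
+ #   model = UNet(config)
+ #   with tf.compat.v1.Session() as sess:
+ #       sess.run(tf.compat.v1.global_variables_initializer())
+ #       writer = tf.compat.v1.summary.FileWriter("log", sess.graph)
+ #       loss = model.train_on_batch(sess, X_batch, Y_batch, writer, drop_rate=0.1)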
deepdenoiser/predict.py ADDED
@@ -0,0 +1,136 @@
+ import argparse
+ import logging
+ import multiprocessing
+ import os
+ import time
+ from functools import partial
+ 
+ import numpy as np
+ import tensorflow as tf
+ from tqdm import tqdm
+ 
+ from data_reader import DataReader_pred, normalize_batch
+ from model import UNet
+ from util import *
+ 
+ tf.compat.v1.disable_eager_execution()
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+ 
+ 
+ def read_args():
+     """Returns parsed command-line args"""
+ 
+     parser = argparse.ArgumentParser()
+ 
+     parser.add_argument("--format", default="numpy", type=str, help="Input data format: numpy or mseed")
+     parser.add_argument("--batch_size", default=20, type=int, help="Batch size")
+     parser.add_argument("--output_dir", default="output", help="Output directory (default: output)")
+     parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
+     parser.add_argument("--sampling_rate", default=100, type=int, help="Sampling rate of prediction data in Hz")
+     parser.add_argument("--data_dir", default="./Dataset/pred/", help="Input file directory")
+     parser.add_argument("--data_list", default="./Dataset/pred.csv", help="Input csv file")
+     parser.add_argument("--plot_figure", action="store_true", help="Plot figures of the denoising results")
+     parser.add_argument("--save_signal", action="store_true", help="Save the denoised signal")
+     parser.add_argument("--save_noise", action="store_true", help="Save the separated noise")
+ 
+     args = parser.parse_args()
+     return args
+ 
+ 
+ def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
+     current_time = time.strftime("%y%m%d-%H%M%S")
+     if log_dir is None:
+         # Fallback path; note that read_args in this script does not define --log_dir,
+         # so main() below always passes log_dir=args.output_dir.
+         log_dir = os.path.join(args.log_dir, "pred", current_time)
+     logging.info("Pred log: %s" % log_dir)
+     # logging.info("Dataset size: {}".format(data_reader.num_data))
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+     if args.plot_figure:
+         figure_dir = os.path.join(log_dir, 'figures')
+         os.makedirs(figure_dir, exist_ok=True)
+     if args.save_signal or args.save_noise:
+         result_dir = os.path.join(log_dir, 'results')
+         os.makedirs(result_dir, exist_ok=True)
+ 
+     with tf.compat.v1.name_scope('Input_Batch'):
+         data_batch = data_reader.dataset(args.batch_size)
+ 
+     # model = UNet(input_batch=data_batch, mode='pred')
+     model = UNet(mode='pred')
+     sess_config = tf.compat.v1.ConfigProto()
+     sess_config.gpu_options.allow_growth = True
+     # sess_config.log_device_placement = False
+ 
+     with tf.compat.v1.Session(config=sess_config) as sess:
+ 
+         saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
+         init = tf.compat.v1.global_variables_initializer()
+         sess.run(init)
+ 
+         latest_check_point = tf.train.latest_checkpoint(args.model_dir)
+         logging.info(f"restoring model: {latest_check_point}")
+         saver.restore(sess, latest_check_point)
+ 
+         if args.plot_figure:
+             num_pool = multiprocessing.cpu_count()
+         else:
+             num_pool = 2
+         multiprocessing.set_start_method('spawn')
+         pool = multiprocessing.Pool(num_pool)
+         for _ in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
+             X_batch, fname_batch, t0_batch = sess.run(data_batch)
+             # Input is shaped [batch, channel, station, frequency, time, real/imag];
+             # flatten the first three axes into one batch axis for the 2D U-Net.
+             nbt, nch, nst, nf, nt, nimg = X_batch.shape
+             X_batch_ = np.reshape(X_batch, [nbt * nch * nst, nf, nt, nimg])
+             X_batch_ = normalize_batch(X_batch_)
+             preds_batch = sess.run(
+                 model.preds,
+                 feed_dict={model.X: X_batch_, model.drop_rate: 0, model.is_training: False},
+             )
+             preds_batch = np.reshape(preds_batch, [nbt, nch, nst, nf, nt, preds_batch.shape[-1]])
+             # preds_batch, X_batch, ratio_batch, fname_batch = sess.run(
+             #     [model.preds, data_batch[0], data_batch[1], data_batch[2]],
+             #     feed_dict={model.drop_rate: 0, model.is_training: False},
+             # )
+ 
+             if args.save_signal or args.save_noise:
+                 save_results(
+                     preds_batch,
+                     X_batch,
+                     fname=[x.decode() for x in fname_batch],
+                     t0=[x.decode() for x in t0_batch],
+                     save_signal=args.save_signal,
+                     save_noise=args.save_noise,
+                     result_dir=result_dir,
+                 )
+ 
+             if args.plot_figure:
+                 pool.starmap(
+                     partial(
+                         plot_figures,
+                         figure_dir=figure_dir,
+                     ),
+                     zip(preds_batch, X_batch, [x.decode() for x in fname_batch]),
+                 )
+ 
+         pool.close()
+ 
+     return 0
+ 
+ 
+ def main(args):
+ 
+     logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
+ 
+     with tf.compat.v1.name_scope('create_inputs'):
+         data_reader = DataReader_pred(
+             format=args.format, signal_dir=args.data_dir, signal_list=args.data_list, sampling_rate=args.sampling_rate
+         )
+     logging.info("Dataset Size: {}".format(data_reader.n_signal))
+     pred_fn(args, data_reader, log_dir=args.output_dir)
+ 
+     return 0
+ 
+ 
+ if __name__ == '__main__':
+     args = read_args()
+     main(args)
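+ 
+ # Example invocation (a sketch; the flags are those defined in read_args above, the
+ # data paths are placeholders, and the checkpoint directory follows docs/README.md):
+ #
+ #   python deepdenoiser/predict.py --model_dir model/190614-104802 \
+ #       --data_dir ./Dataset/pred/ --data_list ./Dataset/pred.csv \
+ #       --plot_figure --save_signal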
deepdenoiser/train.py ADDED
@@ -0,0 +1,557 @@
+ # import warnings
+ # warnings.filterwarnings('ignore', category=FutureWarning)
+ import numpy as np
+ import tensorflow as tf
+ tf.compat.v1.disable_eager_execution()
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+ import argparse
+ import os
+ import time
+ import logging
+ from model import UNet
+ from data_reader import *
+ from util import *
+ from tqdm import tqdm
+ import multiprocessing
+ from functools import partial
+ 
+ 
+ def read_args():
+     """Returns parsed command-line args"""
+ 
+     parser = argparse.ArgumentParser()
+ 
+     parser.add_argument("--mode",
+                         default="train",
+                         help="train/valid/test/debug (default: train)")
+ 
+     parser.add_argument("--epochs",
+                         default=10,
+                         type=int,
+                         help="Number of epochs (default: 10)")
+ 
+     parser.add_argument("--batch_size",
+                         default=20,
+                         type=int,
+                         help="Batch size (default: 20)")
+ 
+     parser.add_argument("--learning_rate",
+                         default=0.001,
+                         type=float,
+                         help="learning rate (default: 0.001)")
+ 
+     parser.add_argument("--decay_step",
+                         default=-1,
+                         type=int,
+                         help="decay step (default: -1)")
+ 
+     parser.add_argument("--decay_rate",
+                         default=0.9,
+                         type=float,
+                         help="decay rate (default: 0.9)")
+ 
+     parser.add_argument("--momentum",
+                         default=0.9,
+                         type=float,
+                         help="momentum (default: 0.9)")
+ 
+     parser.add_argument("--filters_root",
+                         default=8,
+                         type=int,
+                         help="filters root (default: 8)")
+ 
+     parser.add_argument("--depth",
+                         default=6,
+                         type=int,
+                         help="depth (default: 6)")
+ 
+     parser.add_argument("--kernel_size",
+                         nargs="+",
+                         type=int,
+                         default=[3, 3],
+                         help="kernel size (default: [3, 3])")
+ 
+     parser.add_argument("--pool_size",
+                         nargs="+",
+                         type=int,
+                         default=[2, 2],
+                         help="pool size (default: [2, 2])")
+ 
+     parser.add_argument("--drop_rate",
+                         default=0,
+                         type=float,
+                         help="dropout rate (default: 0)")
+ 
+     parser.add_argument("--dilation_rate",
+                         nargs="+",
+                         type=int,
+                         default=[1, 1],
+                         help="dilation rate (default: [1, 1])")
+ 
+     parser.add_argument("--loss_type",
+                         default="cross_entropy",
+                         help="loss type: cross_entropy, IOU, mean_squared (default: cross_entropy)")
+ 
+     parser.add_argument("--weight_decay",
+                         default=0,
+                         type=float,
+                         help="weight decay (default: 0)")
+ 
+     parser.add_argument("--optimizer",
+                         default="adam",
+                         help="optimizer: adam, momentum (default: adam)")
+ 
+     parser.add_argument("--summary",
+                         default=True,
+                         type=bool,  # note: argparse's type=bool treats any non-empty string as True
+                         help="summary (default: True)")
+ 
+     parser.add_argument("--class_weights",
+                         nargs="+",
+                         default=[1, 1],
+                         type=float,
+                         help="class weights (default: [1, 1])")
+ 
+     parser.add_argument("--log_dir",
+                         default="log",
+                         help="Tensorboard log directory (default: log)")
+ 
+     parser.add_argument("--model_dir",
+                         default=None,
+                         help="Checkpoint directory")
+ 
+     parser.add_argument("--num_plots",
+                         default=10,
+                         type=int,
+                         help="number of training results to plot (default: 10)")
+ 
+     parser.add_argument("--input_length",
+                         default=None,
+                         type=int,
+                         help="input length")
+     parser.add_argument("--sampling_rate",
+                         default=100,
+                         type=int,
+                         help="sampling rate of prediction data in Hz (default: 100)")
+ 
+     parser.add_argument("--train_signal_dir",
+                         default="./Dataset/train/",
+                         help="Input file directory (default: ./Dataset/train/)")
+     parser.add_argument("--train_signal_list",
+                         default="./Dataset/train.csv",
+                         help="Input csv file (default: ./Dataset/train.csv)")
+     parser.add_argument("--train_noise_dir",
+                         default="./Dataset/train/",
+                         help="Input file directory (default: ./Dataset/train/)")
+     parser.add_argument("--train_noise_list",
+                         default="./Dataset/train.csv",
+                         help="Input csv file (default: ./Dataset/train.csv)")
+ 
+     parser.add_argument("--valid_signal_dir",
+                         default="./Dataset/",
+                         help="Input file directory (default: ./Dataset/)")
+     parser.add_argument("--valid_signal_list",
+                         default=None,
+                         help="Input csv file")
+     parser.add_argument("--valid_noise_dir",
+                         default="./Dataset/",
+                         help="Input file directory (default: ./Dataset/)")
+     parser.add_argument("--valid_noise_list",
+                         default=None,
+                         help="Input csv file")
+ 
+     parser.add_argument("--data_dir",
+                         default="./Dataset/pred/",
+                         help="Input file directory (default: ./Dataset/pred/)")
+     parser.add_argument("--data_list",
+                         default="./Dataset/pred.csv",
+                         help="Input csv file (default: ./Dataset/pred.csv)")
+ 
+     parser.add_argument("--output_dir",
+                         default=None,
+                         help="Output directory")
+ 
+     parser.add_argument("--fpred",
+                         default="preds.npz",
+                         help="output file name of test data")
+     parser.add_argument("--plot_figure",
+                         action="store_true",
+                         help="Plot figures for test results")
+     parser.add_argument("--save_result",
+                         action="store_true",
+                         help="Save test results")
+ 
+     args = parser.parse_args()
+     return args
+ 
+ 
+ def set_config(args, data_reader):
+     config = Config()
+ 
+     config.X_shape = data_reader.X_shape
+     config.n_channel = config.X_shape[-1]
+     config.Y_shape = data_reader.Y_shape
+     config.n_class = config.Y_shape[-1]
+ 
+     config.depths = args.depth
+     config.filters_root = args.filters_root
+     config.kernel_size = args.kernel_size
+     config.pool_size = args.pool_size
+     config.dilation_rate = args.dilation_rate
+     config.batch_size = args.batch_size
+     config.class_weights = args.class_weights
+     config.loss_type = args.loss_type
+     config.weight_decay = args.weight_decay
+     config.optimizer = args.optimizer
+ 
+     config.learning_rate = args.learning_rate
+     if (args.decay_step == -1) and (args.mode == 'train'):
+         # By default, decay the learning rate once per epoch.
+         config.decay_step = data_reader.n_signal // args.batch_size
+     else:
+         config.decay_step = args.decay_step
+     config.decay_rate = args.decay_rate
+     config.momentum = args.momentum
+ 
+     config.summary = args.summary
+     config.drop_rate = args.drop_rate
+     config.class_weights = args.class_weights
+ 
+     return config
+ 
+ 
+ def train_fn(args, data_reader, data_reader_valid=None):
+     current_time = time.strftime("%y%m%d-%H%M%S")
+     log_dir = os.path.join(args.log_dir, current_time)
+     logging.info("Training log: {}".format(log_dir))
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+     figure_dir = os.path.join(log_dir, 'figures')
+     if not os.path.exists(figure_dir):
+         os.makedirs(figure_dir)
+ 
+     config = set_config(args, data_reader)
+     with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
+         fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
+ 
+     with tf.compat.v1.name_scope('Input_Batch'):
+         batch = data_reader.dequeue(args.batch_size)
+         if data_reader_valid is not None:
+             batch_valid = data_reader_valid.dequeue(args.batch_size)
+ 
+     model = UNet(config)
+     sess_config = tf.compat.v1.ConfigProto()
+     sess_config.gpu_options.allow_growth = True
+     sess_config.log_device_placement = False
+ 
+     with tf.compat.v1.Session(config=sess_config) as sess:
+ 
+         summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
+         saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
+         init = tf.compat.v1.global_variables_initializer()
+         sess.run(init)
+ 
+         if args.model_dir is not None:
+             logging.info("restoring models...")
+             latest_check_point = tf.train.latest_checkpoint(args.model_dir)
+             saver.restore(sess, latest_check_point)
+             model.reset_learning_rate(sess, learning_rate=0.01, global_step=0)
+ 
+         threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
+         if data_reader_valid is not None:
+             threads_valid = data_reader_valid.start_threads(sess, n_threads=multiprocessing.cpu_count())
+         flog = open(os.path.join(log_dir, 'loss.log'), 'w')
+ 
+         total_step = 0
+         mean_loss = 0
+         pool = multiprocessing.Pool(2)
+         for epoch in range(args.epochs):
+             progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size), desc="{}: ".format(log_dir.split("/")[-1]))
+             for step in progressbar:
+                 X_batch, Y_batch = sess.run(batch)
+                 loss_batch = model.train_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
+                 if epoch < 1:
+                     # During the first epoch, report the raw batch loss; the running mean starts afterwards.
+                     mean_loss = loss_batch
+                 else:
+                     total_step += 1
+                     mean_loss += (loss_batch - mean_loss) / total_step
+                 progressbar.set_description("{}: epoch={}, loss={:.6f}, mean loss={:.6f}".format(log_dir.split("/")[-1], epoch, loss_batch, mean_loss))
+                 flog.write("Epoch: {}, step: {}, loss: {}, mean loss: {}\n".format(epoch, step // args.batch_size, loss_batch, mean_loss))
+             saver.save(sess, os.path.join(log_dir, "model_{}.ckpt".format(epoch)))
+ 
+             ## valid
+             if data_reader_valid is not None:
+                 mean_loss_valid = 0
+                 total_step_valid = 0
+                 progressbar = tqdm(range(0, data_reader_valid.n_signal, args.batch_size), desc="Valid: ")
+                 for step in progressbar:
+                     X_batch, Y_batch = sess.run(batch_valid)
+                     loss_batch, preds_batch = model.valid_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
+                     total_step_valid += 1
+                     mean_loss_valid += (loss_batch - mean_loss_valid) / total_step_valid
+                     progressbar.set_description("Valid: loss={:.6f}, mean loss={:.6f}".format(loss_batch, mean_loss_valid))
+                     flog.write("Valid: {}, step: {}, loss: {}, mean loss: {}\n".format(epoch, step // args.batch_size, loss_batch, mean_loss_valid))
+ 
+                 # plot_result(epoch, args.num_plots, figure_dir, preds_batch, X_batch, Y_batch)
+                 pool.map(partial(plot_result_thread,
+                                  epoch=epoch,
+                                  preds=preds_batch,
+                                  X=X_batch,
+                                  Y=Y_batch,
+                                  figure_dir=figure_dir),
+                          range(args.num_plots))
+ 
+         flog.close()
+         pool.close()
+         data_reader.coord.request_stop()
+         if data_reader_valid is not None:
+             data_reader_valid.coord.request_stop()
+         try:
+             data_reader.coord.join(threads, stop_grace_period_secs=10, ignore_live_threads=True)
+             if data_reader_valid is not None:
+                 data_reader_valid.coord.join(threads_valid, stop_grace_period_secs=10, ignore_live_threads=True)
+         except Exception:
+             pass
+         sess.run(data_reader.queue.close(cancel_pending_enqueues=True))
+         if data_reader_valid is not None:
+             sess.run(data_reader_valid.queue.close(cancel_pending_enqueues=True))
+     return 0
+ 
+ 
+ def test_fn(args, data_reader, figure_dir=None, result_dir=None):
+     current_time = time.strftime("%y%m%d-%H%M%S")
+     log_dir = os.path.join(args.log_dir, args.mode, current_time)
+     logging.info("{} log: {}".format(args.mode, log_dir))
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+     if (args.plot_figure == True) and (figure_dir is None):
+         figure_dir = os.path.join(log_dir, 'figures')
+         if not os.path.exists(figure_dir):
+             os.makedirs(figure_dir)
+     if (args.save_result == True) and (result_dir is None):
+         result_dir = os.path.join(log_dir, 'results')
+         if not os.path.exists(result_dir):
+             os.makedirs(result_dir)
+ 
+     config = set_config(args, data_reader)
+     with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
+         fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
+ 
+     with tf.compat.v1.name_scope('Input_Batch'):
+         batch = data_reader.dequeue(args.batch_size)
+ 
+     model = UNet(config, input_batch=batch, mode='test')
+     sess_config = tf.compat.v1.ConfigProto()
+     sess_config.gpu_options.allow_growth = True
+     sess_config.log_device_placement = False
+ 
+     with tf.compat.v1.Session(config=sess_config) as sess:
+ 
+         summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
+         saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
+         init = tf.compat.v1.global_variables_initializer()
+         sess.run(init)
+ 
+         logging.info("restoring models...")
+         latest_check_point = tf.train.latest_checkpoint(args.model_dir)
+         saver.restore(sess, latest_check_point)
+ 
+         threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
+ 
+         flog = open(os.path.join(log_dir, 'loss.log'), 'w')
+         total_step = 0
+         mean_loss = 0
+         progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size), desc=args.mode)
+         if args.plot_figure:
+             num_pool = multiprocessing.cpu_count() * 2
+         elif args.save_result:
+             num_pool = multiprocessing.cpu_count()
+         else:
+             num_pool = 2
+         pool = multiprocessing.Pool(num_pool)
+         for step in progressbar:
+ 
+             if step + args.batch_size >= data_reader.n_signal:
+                 for t in threads:
+                     t.join()
+                 sess.run(data_reader.queue.close())
+ 
+             loss_batch, preds_batch, X_batch, Y_batch, ratio_batch, \
+                 signal_batch, noise_batch, fname_batch = model.test_on_batch(sess, summary_writer)
+             total_step += 1
+             mean_loss += (loss_batch - mean_loss) / total_step
+             progressbar.set_description("{}: loss={:.6f}, mean loss={:.6f}".format(args.mode, loss_batch, mean_loss))
+             flog.write("step: {}, loss: {}\n".format(step, loss_batch))
+             flog.flush()
+ 
+             pool.map(partial(postprocessing_test,
+                              preds=preds_batch,
+                              X=X_batch * ratio_batch[:, np.newaxis, np.newaxis, np.newaxis],
+                              fname=fname_batch,
+                              figure_dir=figure_dir,
+                              result_dir=result_dir,
+                              signal_FT=signal_batch,
+                              noise_FT=noise_batch),
+                      range(len(X_batch)))
+ 
+         flog.close()
+         pool.close()
+ 
+     return 0
+ 
+ 
+ def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
+     current_time = time.strftime("%y%m%d-%H%M%S")
+     if log_dir is None:
+         log_dir = os.path.join(args.log_dir, "pred", current_time)
+     logging.info("Pred log: %s" % log_dir)
+     # logging.info("Dataset size: {}".format(data_reader.num_data))
+     if not os.path.exists(log_dir):
+         os.makedirs(log_dir)
+     if args.plot_figure:
+         figure_dir = os.path.join(log_dir, 'figures')
+         os.makedirs(figure_dir, exist_ok=True)
+     if args.save_result:
+         result_dir = os.path.join(log_dir, 'results')
+         os.makedirs(result_dir, exist_ok=True)
+ 
+     config = set_config(args, data_reader)
+     with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
+         fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))
+ 
+     with tf.compat.v1.name_scope('Input_Batch'):
+         data_batch = data_reader.dataset(args.batch_size)
+ 
+     # model = UNet(config, input_batch=batch, mode='pred')
+     model = UNet(config, mode='pred')
+     sess_config = tf.compat.v1.ConfigProto()
+     sess_config.gpu_options.allow_growth = True
+     # sess_config.log_device_placement = False
+ 
+     with tf.compat.v1.Session(config=sess_config) as sess:
+ 
+         saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
+         init = tf.compat.v1.global_variables_initializer()
+         sess.run(init)
+ 
+         logging.info("restoring models...")
+         latest_check_point = tf.train.latest_checkpoint(args.model_dir)
+         saver.restore(sess, latest_check_point)
+ 
+         # threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
+ 
+         if args.plot_figure:
+             num_pool = multiprocessing.cpu_count()
+         elif args.save_result:
+             num_pool = multiprocessing.cpu_count()
+         else:
+             num_pool = 2
+         multiprocessing.set_start_method('spawn')
+         pool = multiprocessing.Pool(num_pool)
+         for step in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
+             # if step + args.batch_size >= data_reader.n_signal:
+             #     for t in threads:
+             #         t.join()
+             #     sess.run(data_reader.queue.close())
+             # X_batch = []
+             # ratio_batch = []
+             # fname_batch = []
+             # for i in range(step, min(step + args.batch_size, data_reader.n_signal)):
+             #     X, ratio, fname = data_reader[i]
+             #     if np.std(X) == 0:
+             #         continue
+             #     X_batch.append(X)
+             #     ratio_batch.append(ratio)
+             #     fname_batch.append(fname)
+             # X_batch = np.stack(X_batch, axis=0)
+             # ratio_batch = np.array(ratio_batch)
+             X_batch, ratio_batch, fname_batch = sess.run(data_batch)
+             preds_batch = sess.run(model.preds, feed_dict={model.X: X_batch,
+                                                            model.drop_rate: 0,
+                                                            model.is_training: False})
+             # preds_batch, X_batch, ratio_batch, fname_batch = sess.run([model.preds,
+             #                                                            batch[0],
+             #                                                            batch[1],
+             #                                                            batch[2]],
+             #                                                           feed_dict={model.drop_rate: 0,
+             #                                                                      model.is_training: False})
+ 
+             pool.map(partial(postprocessing_pred,
+                              preds=preds_batch,
+                              X=X_batch * ratio_batch[:, np.newaxis, :, np.newaxis],
+                              fname=[x.decode() for x in fname_batch],
+                              figure_dir=figure_dir,
+                              result_dir=result_dir),
+                      range(len(X_batch)))
+ 
+             # for i in range(len(X_batch)):
+             #     postprocessing_thread(i,
+             #                           preds=preds_batch,
+             #                           X=X_batch * ratio_batch[:, np.newaxis, np.newaxis, np.newaxis],
+             #                           fname=fname_batch,
+             #                           figure_dir=figure_dir,
+             #                           result_dir=result_dir)
+ 
+         pool.close()
+ 
+     return 0
+ 
+ 
+ def main(args):
+ 
+     logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
+ 
+     coord = tf.train.Coordinator()
+ 
+     if args.mode == "train":
+         with tf.compat.v1.name_scope('create_inputs'):
+             data_reader = DataReader(
+                 signal_dir=args.train_signal_dir,
+                 signal_list=args.train_signal_list,
+                 noise_dir=args.train_noise_dir,
+                 noise_list=args.train_noise_list,
+                 queue_size=args.batch_size * 2,
+                 coord=coord)
+             if (args.valid_signal_list is not None) and (args.valid_noise_list is not None):
+                 data_reader_valid = DataReader(
+                     signal_dir=args.valid_signal_dir,
+                     signal_list=args.valid_signal_list,
+                     noise_dir=args.valid_noise_dir,
+                     noise_list=args.valid_noise_list,
+                     queue_size=args.batch_size * 2,
+                     coord=coord)
+                 logging.info("Dataset size: training %d, validation %d" % (data_reader.n_signal, data_reader_valid.n_signal))
+             else:
+                 data_reader_valid = None
+                 logging.info("Dataset size: training %d, validation 0" % (data_reader.n_signal))
+         train_fn(args, data_reader, data_reader_valid)
+ 
+     elif args.mode == "valid" or args.mode == "test":
+         with tf.compat.v1.name_scope('create_inputs'):
+             data_reader = DataReader_test(
+                 signal_dir=args.valid_signal_dir,
+                 signal_list=args.valid_signal_list,
+                 noise_dir=args.valid_noise_dir,
+                 noise_list=args.valid_noise_list,
+                 queue_size=args.batch_size * 2,
+                 coord=coord)
+         logging.info("Dataset Size: {}".format(data_reader.n_signal))
+         test_fn(args, data_reader)
+ 
+     elif args.mode == "pred":
+         with tf.compat.v1.name_scope('create_inputs'):
+             data_reader = DataReader_pred(
+                 signal_dir=args.data_dir,
+                 signal_list=args.data_list,
+                 sampling_rate=args.sampling_rate)
+         logging.info("Dataset Size: {}".format(data_reader.n_signal))
+         pred_fn(args, data_reader, log_dir=args.output_dir)
+ 
+     else:
+         print("mode should be: train, valid, test, debug or pred")
+ 
+     coord.request_stop()
+     coord.join()
+     return 0
+ 
+ 
+ if __name__ == '__main__':
+     args = read_args()
+     main(args)
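+ 
+ # Example invocations (sketches; the flags are those defined in read_args above and
+ # the paths are placeholders):
+ #
+ #   python deepdenoiser/train.py --mode train \
+ #       --train_signal_dir ./Dataset/train/ --train_signal_list ./Dataset/train.csv \
+ #       --train_noise_dir ./Dataset/train/ --train_noise_list ./Dataset/train.csv
+ #
+ #   python deepdenoiser/train.py --mode pred --model_dir model/190614-104802 \
+ #       --data_dir ./Dataset/pred/ --data_list ./Dataset/pred.csv --plot_figure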
deepdenoiser/util.py ADDED
@@ -0,0 +1,875 @@
+ import os
2
+
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import scipy
7
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
8
+ from scipy import signal
9
+ from tqdm import tqdm
10
+
11
+ from data_reader import Config
12
+
13
+ matplotlib.use('agg')
14
+
15
+
16
+ def plot_result(epoch, num, figure_dir, preds, X, Y, mode="valid"):
17
+ config = Config()
18
+ for i in range(min(num, len(X))):
19
+
20
+ t, noisy_signal = scipy.signal.istft(
21
+ X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
22
+ )
23
+ t, ideal_denoised_signal = scipy.signal.istft(
24
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0],
25
+ fs=config.fs,
26
+ nperseg=config.nperseg,
27
+ nfft=config.nfft,
28
+ boundary='zeros',
29
+ )
30
+ t, denoised_signal = scipy.signal.istft(
31
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
32
+ fs=config.fs,
33
+ nperseg=config.nperseg,
34
+ nfft=config.nfft,
35
+ boundary='zeros',
36
+ )
37
+
38
+ plt.figure(i)
39
+ fig_size = plt.gcf().get_size_inches()
40
+ plt.gcf().set_size_inches(fig_size * [1.5, 1.5])
41
+ plt.subplot(4, 2, 1)
42
+ plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2)
43
+ plt.title("Noisy signal")
44
+ plt.gca().set_xticklabels([])
45
+ plt.subplot(4, 2, 2)
46
+ plt.plot(t, noisy_signal, label='Noisy signal', linewidth=0.1)
47
+ signal_ylim = plt.gca().get_ylim()
48
+ plt.gca().set_xticklabels([])
49
+ plt.legend(loc='lower left')
50
+ plt.margins(x=0)
51
+
52
+ plt.subplot(4, 2, 3)
53
+ plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1)
54
+ plt.gca().set_xticklabels([])
55
+ plt.title("Y")
56
+ plt.subplot(4, 2, 4)
57
+ plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1)
58
+ plt.title("$\hat{Y}$")
59
+ plt.gca().set_xticklabels([])
60
+
61
+ plt.subplot(4, 2, 5)
62
+ plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2)
63
+ plt.title("Ideal denoised signal")
64
+ plt.gca().set_xticklabels([])
65
+ plt.subplot(4, 2, 6)
66
+ plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2)
67
+ plt.title("Denoised signal")
68
+ plt.gca().set_xticklabels([])
69
+
70
+ plt.subplot(4, 2, 7)
71
+ plt.plot(t, ideal_denoised_signal, label='Ideal denoised signal', linewidth=0.1)
72
+ plt.ylim(signal_ylim)
73
+ plt.xlabel("Time (s)")
74
+ plt.legend(loc='lower left')
75
+ plt.margins(x=0)
76
+ plt.subplot(4, 2, 8)
77
+ plt.plot(t, denoised_signal, label='Denoised signal', linewidth=0.1)
78
+ plt.ylim(signal_ylim)
79
+ plt.xlabel("Time (s)")
80
+ plt.legend(loc='lower left')
81
+ plt.margins(x=0)
82
+
83
+ plt.tight_layout()
84
+ plt.gcf().align_labels()
85
+ plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight')
86
+ # plt.savefig(os.path.join(figure_dir, "epoch%03d_%03d.pdf" % (epoch, i)), bbox_inches='tight')
87
+ plt.close(i)
88
+ return 0
89
+
90
+
91
+ def plot_result_thread(i, epoch, preds, X, Y, figure_dir, mode="valid"):
92
+ config = Config()
93
+ t, noisy_signal = scipy.signal.istft(
94
+ X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
95
+ )
96
+ t, ideal_denoised_signal = scipy.signal.istft(
97
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0],
98
+ fs=config.fs,
99
+ nperseg=config.nperseg,
100
+ nfft=config.nfft,
101
+ boundary='zeros',
102
+ )
103
+ t, denoised_signal = scipy.signal.istft(
104
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
105
+ fs=config.fs,
106
+ nperseg=config.nperseg,
107
+ nfft=config.nfft,
108
+ boundary='zeros',
109
+ )
110
+
111
+ plt.figure(i)
112
+ fig_size = plt.gcf().get_size_inches()
113
+ plt.gcf().set_size_inches(fig_size * [1.5, 1.5])
114
+ plt.subplot(4, 2, 1)
115
+ plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2)
116
+ plt.title("Noisy signal")
117
+ plt.gca().set_xticklabels([])
118
+ plt.subplot(4, 2, 2)
119
+ plt.plot(t, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
120
+ signal_ylim = plt.gca().get_ylim()
121
+ plt.gca().set_xticklabels([])
122
+ plt.legend(loc='lower left')
123
+ plt.margins(x=0)
124
+
125
+ plt.subplot(4, 2, 3)
126
+ plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1)
127
+ plt.gca().set_xticklabels([])
128
+ plt.title("Y")
129
+ plt.subplot(4, 2, 4)
130
+ plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1)
131
+ plt.title("$\hat{Y}$")
132
+ plt.gca().set_xticklabels([])
133
+
134
+ plt.subplot(4, 2, 5)
135
+ plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2)
136
+ plt.title("Ideal denoised signal")
137
+ plt.gca().set_xticklabels([])
138
+ plt.subplot(4, 2, 6)
139
+ plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2)
140
+ plt.title("Denoised signal")
141
+ plt.gca().set_xticklabels([])
142
+
143
+ plt.subplot(4, 2, 7)
144
+ plt.plot(t, ideal_denoised_signal, 'k', label='Ideal denoised signal', linewidth=0.5)
145
+ plt.ylim(signal_ylim)
146
+ plt.xlabel("Time (s)")
147
+ plt.legend(loc='lower left')
148
+ plt.margins(x=0)
149
+ plt.subplot(4, 2, 8)
150
+ plt.plot(t, denoised_signal, 'k', label='Denoised signal', linewidth=0.5)
151
+ plt.ylim(signal_ylim)
152
+ plt.xlabel("Time (s)")
153
+ plt.legend(loc='lower left')
154
+ plt.margins(x=0)
155
+
156
+ plt.tight_layout()
157
+ plt.gcf().align_labels()
158
+ plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight')
159
+ plt.close(i)
160
+ return 0
161
+
162
+
163
+ def postprocessing_test(
164
+ i, preds, X, fname, figure_dir=None, result_dir=None, signal_FT=None, noise_FT=None, data_dir=None
165
+ ):
166
+ if (figure_dir is not None) or (result_dir is not None):
167
+ config = Config()
168
+ t1, noisy_signal = scipy.signal.istft(
169
+ X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
170
+ )
171
+ t1, denoised_signal = scipy.signal.istft(
172
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
173
+ fs=config.fs,
174
+ nperseg=config.nperseg,
175
+ nfft=config.nfft,
176
+ boundary='zeros',
177
+ )
178
+ t1, denoised_noise = scipy.signal.istft(
179
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * (1 - preds[i, :, :, 0]),
180
+ fs=config.fs,
181
+ nperseg=config.nperseg,
182
+ nfft=config.nfft,
183
+ boundary='zeros',
184
+ )
185
+ t1, signal = scipy.signal.istft(
186
+ signal_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
187
+ )
188
+ t1, noise = scipy.signal.istft(
189
+ noise_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
190
+ )
191
+
192
+ if result_dir is not None:
193
+ try:
194
+ np.savez(
195
+ os.path.join(result_dir, fname[i].decode()),
196
+ preds=preds[i],
197
+ X=X[i],
198
+ signal_FT=signal_FT[i],
199
+ noise_FT=noise_FT[i],
200
+ noisy_signal=noisy_signal,
201
+ denoised_signal=denoised_signal,
202
+ denoised_noise=denoised_noise,
203
+ signal=signal,
204
+ noise=noise,
205
+ )
206
+ except FileNotFoundError:
207
+ os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
208
+ np.savez(
209
+ os.path.join(result_dir, fname[i].decode()),
210
+ preds=preds[i],
211
+ X=X[i],
212
+ signal_FT=signal_FT[i],
213
+ noise_FT=noise_FT[i],
214
+ noisy_signal=noisy_signal,
215
+ denoised_signal=denoised_signal,
216
+ denoised_noise=denoised_noise,
217
+ signal=signal,
218
+ noise=noise,
219
+ )
220
+
221
+ if figure_dir is not None:
222
+ t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
223
+ f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])
224
+
225
+ raw_data = None
226
+ if data_dir is not None:
227
+ raw_data = np.load(os.path.join(data_dir, fname[i].decode().split('/')[-1]))
228
+ itp = raw_data['itp']
229
+ its = raw_data['its']
230
+ ix1 = (750 - 50) / 100
231
+ ix2 = (750 + (its - itp) + 50) / 100
232
+ if ix2 - ix1 > 3:
233
+ ix2 = ix1 + 3
234
+
235
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
236
+
237
+ text_loc = [0.05, 0.8]
238
+ plt.figure(i)
239
+ fig_size = plt.gcf().get_size_inches()
240
+ plt.gcf().set_size_inches(fig_size * [1, 2])
241
+ plt.subplot(511)
242
+ plt.pcolormesh(t_FT, f_FT, np.abs(signal_FT[i, :, :]), vmin=0, vmax=1)
243
+ plt.gca().set_xticklabels([])
244
+ plt.text(
245
+ text_loc[0],
246
+ text_loc[1],
247
+ '(i)',
248
+ horizontalalignment='center',
249
+ transform=plt.gca().transAxes,
250
+ fontsize="medium",
251
+ fontweight="bold",
252
+ bbox=box,
253
+ )
254
+ plt.subplot(512)
255
+ plt.pcolormesh(t_FT, f_FT, np.abs(noise_FT[i, :, :]), vmin=0, vmax=1)
256
+ plt.gca().set_xticklabels([])
257
+ plt.text(
258
+ text_loc[0],
259
+ text_loc[1],
260
+ '(ii)',
261
+ horizontalalignment='center',
262
+ transform=plt.gca().transAxes,
263
+ fontsize="medium",
264
+ fontweight="bold",
265
+ bbox=box,
266
+ )
267
+ plt.subplot(513)
268
+ plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=1)
269
+ plt.ylabel("Frequency (Hz)", fontsize='large')
270
+ plt.gca().set_xticklabels([])
271
+ plt.text(
272
+ text_loc[0],
273
+ text_loc[1],
274
+ '(iii)',
275
+ horizontalalignment='center',
276
+ transform=plt.gca().transAxes,
277
+ fontsize="medium",
278
+ fontweight="bold",
279
+ bbox=box,
280
+ )
281
+ plt.subplot(514)
282
+ plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=1)
283
+ plt.gca().set_xticklabels([])
284
+ plt.text(
285
+ text_loc[0],
286
+ text_loc[1],
287
+ '(iv)',
288
+ horizontalalignment='center',
289
+ transform=plt.gca().transAxes,
290
+ fontsize="medium",
291
+ fontweight="bold",
292
+ bbox=box,
293
+ )
294
+ plt.subplot(515)
295
+ plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1], vmin=0, vmax=1)
296
+ plt.xlabel("Time (s)", fontsize='large')
297
+ plt.text(
298
+ text_loc[0],
299
+ text_loc[1],
300
+ '(v)',
301
+ horizontalalignment='center',
302
+ transform=plt.gca().transAxes,
303
+ fontsize="medium",
304
+ fontweight="bold",
305
+ bbox=box,
306
+ )
307
+
308
+ try:
309
+ plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight')
310
+ # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
311
+ except FileNotFoundError:
312
+ os.makedirs(
313
+ os.path.dirname(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png')), exist_ok=True
314
+ )
315
+ plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight')
316
+ # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
317
+ plt.close(i)
318
+
319
+ text_loc = [0.05, 0.8]
320
+ plt.figure(i)
321
+ fig_size = plt.gcf().get_size_inches()
322
+ plt.gcf().set_size_inches(fig_size * [1, 2])
323
+
324
+ ax3 = plt.subplot(513)
325
+ plt.plot(t1, noisy_signal, 'k', linewidth=0.5, label='Noisy signal')
326
+ plt.legend(loc='lower left', fontsize='medium')
327
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
328
+ plt.ylim([-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))])
329
+ signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))]
330
+ plt.ylim(signal_ylim)
331
+ plt.ylabel("Amplitude", fontsize='large')
332
+ plt.gca().set_xticklabels([])
333
+ plt.text(
334
+ text_loc[0],
335
+ text_loc[1],
336
+ '(iii)',
337
+ horizontalalignment='center',
338
+ transform=plt.gca().transAxes,
339
+ fontsize="medium",
340
+ fontweight="bold",
341
+ bbox=box,
342
+ )
343
+
344
+ ax1 = plt.subplot(511)
345
+ plt.plot(t1, signal, 'k', linewidth=0.5, label='Signal')
346
+ plt.legend(loc='lower left', fontsize='medium')
347
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
348
+ plt.ylim(signal_ylim)
349
+ plt.gca().set_xticklabels([])
350
+ plt.text(
351
+ text_loc[0],
352
+ text_loc[1],
353
+ '(i)',
354
+ horizontalalignment='center',
355
+ transform=plt.gca().transAxes,
356
+ fontsize="medium",
357
+ fontweight="bold",
358
+ bbox=box,
359
+ )
360
+
361
+ plt.subplot(512)
362
+ plt.plot(t1, noise, 'k', linewidth=0.5, label='Noise')
363
+ plt.legend(loc='lower left', fontsize='medium')
364
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
365
+ plt.ylim([-np.max(np.abs(noise)), np.max(np.abs(noise))])
366
+ noise_ylim = [-np.max(np.abs(noise[100:-100])), np.max(np.abs(noise[100:-100]))]
367
+ plt.ylim(noise_ylim)
368
+ plt.gca().set_xticklabels([])
369
+ plt.text(
370
+ text_loc[0],
371
+ text_loc[1],
372
+ '(ii)',
373
+ horizontalalignment='center',
374
+ transform=plt.gca().transAxes,
375
+ fontsize="medium",
376
+ fontweight="bold",
377
+ bbox=box,
378
+ )
379
+
380
+ ax4 = plt.subplot(514)
381
+ plt.plot(t1, denoised_signal, 'k', linewidth=0.5, label='Recovered signal')
382
+ plt.legend(loc='lower left', fontsize='medium')
383
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
384
+ plt.ylim(signal_ylim)
385
+ plt.gca().set_xticklabels([])
386
+ plt.text(
387
+ text_loc[0],
388
+ text_loc[1],
389
+ '(iv)',
390
+ horizontalalignment='center',
391
+ transform=plt.gca().transAxes,
392
+ fontsize="medium",
393
+ fontweight="bold",
394
+ bbox=box,
395
+ )
396
+
397
+ plt.subplot(515)
398
+ plt.plot(t1, denoised_noise, 'k', linewidth=0.5, label='Recovered noise')
399
+ plt.legend(loc='lower left', fontsize='medium')
400
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
401
+ plt.xlabel("Time (s)", fontsize='large')
402
+ plt.ylim(noise_ylim)
403
+ plt.text(
404
+ text_loc[0],
405
+ text_loc[1],
406
+ '(v)',
407
+ horizontalalignment='center',
408
+ transform=plt.gca().transAxes,
409
+ fontsize="medium",
410
+ fontweight="bold",
411
+ bbox=box,
412
+ )
413
+
414
+ if data_dir is not None:
415
+ axins = inset_axes(
416
+ ax1, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax1.transAxes
417
+ )
418
+ axins.plot(t1, signal, 'k', linewidth=0.5)
419
+ x1, x2 = ix1, ix2
420
+ y1 = -np.max(np.abs(signal[(t1 > ix1) & (t1 < ix2)]))
421
+ y2 = -y1
422
+ axins.set_xlim(x1, x2)
423
+ axins.set_ylim(y1, y2)
424
+ plt.xticks(visible=False)
425
+ plt.yticks(visible=False)
426
+ mark_inset(ax1, axins, loc1=1, loc2=3, fc="none", ec="0.5")
427
+
428
+ axins = inset_axes(
429
+ ax3, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.3), bbox_transform=ax3.transAxes
430
+ )
431
+ axins.plot(t1, noisy_signal, 'k', linewidth=0.5)
432
+ x1, x2 = ix1, ix2
433
+ axins.set_xlim(x1, x2)
434
+ axins.set_ylim(y1, y2)
435
+ plt.xticks(visible=False)
436
+ plt.yticks(visible=False)
437
+ mark_inset(ax3, axins, loc1=1, loc2=3, fc="none", ec="0.5")
438
+
439
+ axins = inset_axes(
440
+ ax4, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax4.transAxes
441
+ )
442
+ axins.plot(t1, denoised_signal, 'k', linewidth=0.5)
443
+ x1, x2 = ix1, ix2
444
+ axins.set_xlim(x1, x2)
445
+ axins.set_ylim(y1, y2)
446
+ plt.xticks(visible=False)
447
+ plt.yticks(visible=False)
448
+ mark_inset(ax4, axins, loc1=1, loc2=3, fc="none", ec="0.5")
449
+
450
+ plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_wave.png'), bbox_inches='tight')
451
+ # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_wave.pdf'), bbox_inches='tight')
452
+ plt.close(i)
453
+
454
+ return
455
+
456
+
457
+ def postprocessing_pred(i, preds, X, fname, figure_dir=None, result_dir=None):
458
+
459
+ if (result_dir is not None) or (figure_dir is not None):
460
+ config = Config()
461
+
462
+ t1, noisy_signal = scipy.signal.istft(
463
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j),
464
+ fs=config.fs,
465
+ nperseg=config.nperseg,
466
+ nfft=config.nfft,
467
+ boundary='zeros',
468
+ )
469
+ t1, denoised_signal = scipy.signal.istft(
470
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
471
+ fs=config.fs,
472
+ nperseg=config.nperseg,
473
+ nfft=config.nfft,
474
+ boundary='zeros',
475
+ )
476
+ t1, denoised_noise = scipy.signal.istft(
477
+ (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1],
478
+ fs=config.fs,
479
+ nperseg=config.nperseg,
480
+ nfft=config.nfft,
481
+ boundary='zeros',
482
+ )
483
+
484
+ if result_dir is not None:
485
+ try:
486
+ np.savez(
487
+ os.path.join(result_dir, fname[i]),
488
+ noisy_signal=noisy_signal,
489
+ denoised_signal=denoised_signal,
490
+ denoised_noise=denoised_noise,
491
+ t=t1,
492
+ )
493
+ except FileNotFoundError:
494
+ os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i])))
495
+ np.savez(
496
+ os.path.join(result_dir, fname[i]),
497
+ noisy_signal=noisy_signal,
498
+ denoised_signal=denoised_signal,
499
+ denoised_noise=denoised_noise,
500
+ t=t1,
501
+ )
502
+
503
+ if figure_dir is not None:
504
+
505
+ t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
506
+ f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])
507
+
508
+ box = dict(boxstyle='round', facecolor='white', alpha=1)
509
+ text_loc = [0.05, 0.77]
510
+
511
+ plt.figure(i)
512
+ fig_size = plt.gcf().get_size_inches()
513
+ plt.gcf().set_size_inches(fig_size * [1, 1.2])
514
+ vmax = np.std(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j)) * 1.8
515
+
516
+ plt.subplot(311)
517
+ plt.pcolormesh(
518
+ t_FT,
519
+ f_FT,
520
+ np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j),
521
+ vmin=0,
522
+ vmax=vmax,
523
+ shading='auto',
524
+ label='Noisy signal',
525
+ )
526
+ plt.gca().set_xticklabels([])
527
+ plt.text(
528
+ text_loc[0],
529
+ text_loc[1],
530
+ '(i)',
531
+ horizontalalignment='center',
532
+ transform=plt.gca().transAxes,
533
+ fontsize="medium",
534
+ fontweight="bold",
535
+ bbox=box,
536
+ )
537
+ plt.subplot(312)
538
+ plt.pcolormesh(
539
+ t_FT,
540
+ f_FT,
541
+ np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
542
+ vmin=0,
543
+ vmax=vmax,
544
+ shading='auto',
545
+ label='Recovered signal',
546
+ )
547
+ plt.gca().set_xticklabels([])
548
+ plt.ylabel("Frequency (Hz)", fontsize='large')
549
+ plt.text(
550
+ text_loc[0],
551
+ text_loc[1],
552
+ '(ii)',
553
+ horizontalalignment='center',
554
+ transform=plt.gca().transAxes,
555
+ fontsize="medium",
556
+ fontweight="bold",
557
+ bbox=box,
558
+ )
559
+ plt.subplot(313)
560
+ plt.pcolormesh(
561
+ t_FT,
562
+ f_FT,
563
+ np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1],
564
+ vmin=0,
565
+ vmax=vmax,
566
+ shading='auto',
567
+ label='Recovered noise',
568
+ )
569
+ plt.xlabel("Time (s)", fontsize='large')
570
+ plt.text(
571
+ text_loc[0],
572
+ text_loc[1],
573
+ '(iii)',
574
+ horizontalalignment='center',
575
+ transform=plt.gca().transAxes,
576
+ fontsize="medium",
577
+ fontweight="bold",
578
+ bbox=box,
579
+ )
580
+
581
+ try:
582
+ plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png'), bbox_inches='tight')
583
+ # plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
584
+ except FileNotFoundError:
585
+ os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png')), exist_ok=True)
586
+ plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png'), bbox_inches='tight')
587
+ # plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight')
588
+ plt.close(i)
589
+
590
+ plt.figure(i)
591
+ fig_size = plt.gcf().get_size_inches()
592
+ plt.gcf().set_size_inches(fig_size * [1, 1.2])
593
+
594
+ ax4 = plt.subplot(311)
595
+ plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
596
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
597
+ signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))]
598
+ plt.ylim(signal_ylim)
599
+ plt.gca().set_xticklabels([])
600
+ plt.legend(loc='lower left', fontsize='medium')
601
+ plt.text(
602
+ text_loc[0],
603
+ text_loc[1],
604
+ '(i)',
605
+ horizontalalignment='center',
606
+ transform=plt.gca().transAxes,
607
+ fontsize="medium",
608
+ fontweight="bold",
609
+ bbox=box,
610
+ )
611
+
612
+ ax5 = plt.subplot(312)
613
+ plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5)
614
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
615
+ plt.ylim(signal_ylim)
616
+ plt.gca().set_xticklabels([])
617
+ plt.ylabel("Amplitude", fontsize='large')
618
+ plt.legend(loc='lower left', fontsize='medium')
619
+ plt.text(
620
+ text_loc[0],
621
+ text_loc[1],
622
+ '(ii)',
623
+ horizontalalignment='center',
624
+ transform=plt.gca().transAxes,
625
+ fontsize="medium",
626
+ fontweight="bold",
627
+ bbox=box,
628
+ )
629
+
630
+ plt.subplot(313)
631
+ plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5)
632
+ plt.xlim([np.around(t1[0]), np.around(t1[-1])])
633
+ plt.ylim(signal_ylim)
634
+ plt.xlabel("Time (s)", fontsize='large')
635
+ plt.legend(loc='lower left', fontsize='medium')
636
+ plt.text(
637
+ text_loc[0],
638
+ text_loc[1],
639
+ '(iii)',
640
+ horizontalalignment='center',
641
+ transform=plt.gca().transAxes,
642
+ fontsize="medium",
643
+ fontweight="bold",
644
+ bbox=box,
645
+ )
646
+
647
+ plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_wave.png'), bbox_inches='tight')
648
+ # plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz')+'_wave.pdf'), bbox_inches='tight')
649
+ plt.close(i)
650
+
651
+ return
652
+
653
+
654
+ def save_results(mask, X, fname, t0, save_signal=True, save_noise=True, result_dir="results"):
655
+
656
+ config = Config()
657
+
658
+ if save_signal:
659
+ _, denoised_signal = scipy.signal.istft(
660
+ (X[..., 0] + X[..., 1] * 1j) * mask[..., 0],
661
+ fs=config.fs,
662
+ nperseg=config.nperseg,
663
+ nfft=config.nfft,
664
+ boundary='zeros',
665
+ ) # nbt, nch, nst, nt
666
+ denoised_signal = np.transpose(denoised_signal, [0, 3, 2, 1]) # nbt, nt, nst, nch,
667
+ if save_noise:
668
+ _, denoised_noise = scipy.signal.istft(
669
+ (X[..., 0] + X[..., 1] * 1j) * mask[..., 1],
670
+ fs=config.fs,
671
+ nperseg=config.nperseg,
672
+ nfft=config.nfft,
673
+ boundary='zeros',
674
+ )
675
+ denoised_noise = np.transpose(denoised_noise, [0, 3, 2, 1])
676
+
677
+ if not os.path.exists(result_dir):
678
+ os.makedirs(result_dir)
679
+
680
+ for i in range(len(X)):
681
+ np.savez(
682
+ os.path.join(result_dir, fname[i]),
683
+ data=denoised_signal[i] if save_signal else None,
684
+             noise=denoised_noise[i] if save_noise else None,
+             t0=t0[i],
+         )
+
+
+ def plot_figures(mask, X, fname, figure_dir="figures"):
+
+     config = Config()
+
+     # plot the last channel
+     mask = mask[-1, -1, ...]  # nch, nst, nf, nt, 2 => nf, nt, 2
+     X = X[-1, -1, ...]
+
+     # Invert the complex STFT three ways: unmasked, signal-masked, noise-masked.
+     t1, noisy_signal = scipy.signal.istft(
+         (X[..., 0] + X[..., 1] * 1j),
+         fs=config.fs,
+         nperseg=config.nperseg,
+         nfft=config.nfft,
+         boundary='zeros',
+     )
+     t1, denoised_signal = scipy.signal.istft(
+         (X[..., 0] + X[..., 1] * 1j) * mask[..., 0],
+         fs=config.fs,
+         nperseg=config.nperseg,
+         nfft=config.nfft,
+         boundary='zeros',
+     )
+     t1, denoised_noise = scipy.signal.istft(
+         (X[..., 0] + X[..., 1] * 1j) * mask[..., 1],
+         fs=config.fs,
+         nperseg=config.nperseg,
+         nfft=config.nfft,
+         boundary='zeros',
+     )
+
+     if not os.path.exists(figure_dir):
+         os.makedirs(figure_dir)
+
+     t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[1])
+     f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[0])
+
+     box = dict(boxstyle='round', facecolor='white', alpha=1)
+     text_loc = [0.05, 0.77]
+
+     # Figure 1: spectrograms of the noisy input, recovered signal, and recovered noise.
+     plt.figure()
+     fig_size = plt.gcf().get_size_inches()
+     plt.gcf().set_size_inches(fig_size * [1, 1.2])
+     vmax = np.std(np.abs(X[:, :, 0] + X[:, :, 1] * 1j)) * 1.8
+
+     plt.subplot(311)
+     plt.pcolormesh(
+         t_FT,
+         f_FT,
+         np.abs(X[:, :, 0] + X[:, :, 1] * 1j),
+         vmin=0,
+         vmax=vmax,
+         shading='auto',
+         label='Noisy signal',
+     )
+     plt.gca().set_xticklabels([])
+     plt.text(
+         text_loc[0],
+         text_loc[1],
+         '(i)',
+         horizontalalignment='center',
+         transform=plt.gca().transAxes,
+         fontsize="medium",
+         fontweight="bold",
+         bbox=box,
+     )
+
+     plt.subplot(312)
+     plt.pcolormesh(
+         t_FT,
+         f_FT,
+         np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 0],
+         vmin=0,
+         vmax=vmax,
+         shading='auto',
+         label='Recovered signal',
+     )
+     plt.gca().set_xticklabels([])
+     plt.ylabel("Frequency (Hz)", fontsize='large')
+     plt.text(
+         text_loc[0],
+         text_loc[1],
+         '(ii)',
+         horizontalalignment='center',
+         transform=plt.gca().transAxes,
+         fontsize="medium",
+         fontweight="bold",
+         bbox=box,
+     )
+
+     plt.subplot(313)
+     plt.pcolormesh(
+         t_FT,
+         f_FT,
+         np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 1],
+         vmin=0,
+         vmax=vmax,
+         shading='auto',
+         label='Recovered noise',
+     )
+     plt.xlabel("Time (s)", fontsize='large')
+     plt.text(
+         text_loc[0],
+         text_loc[1],
+         '(iii)',
+         horizontalalignment='center',
+         transform=plt.gca().transAxes,
+         fontsize="medium",
+         fontweight="bold",
+         bbox=box,
+     )
+
+     # Use os.path.splitext rather than str.rstrip('.npz'): rstrip strips any
+     # trailing '.', 'n', 'p', 'z' characters, not the '.npz' suffix.
+     base = os.path.splitext(fname)[0]
+     ft_path = os.path.join(figure_dir, base + '_FT.png')
+     os.makedirs(os.path.dirname(ft_path), exist_ok=True)
+     plt.savefig(ft_path, bbox_inches='tight')
+     plt.close()
+
+     # Figure 2: the corresponding time-domain waveforms.
+     plt.figure()
+     fig_size = plt.gcf().get_size_inches()
+     plt.gcf().set_size_inches(fig_size * [1, 1.2])
+
+     plt.subplot(311)
+     plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
+     plt.xlim([np.around(t1[0]), np.around(t1[-1])])
+     signal_ylim = [-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))]
+     if signal_ylim[0] != signal_ylim[1]:
+         plt.ylim(signal_ylim)
+     plt.gca().set_xticklabels([])
+     plt.legend(loc='lower left', fontsize='medium')
+     plt.text(
+         text_loc[0],
+         text_loc[1],
+         '(i)',
+         horizontalalignment='center',
+         transform=plt.gca().transAxes,
+         fontsize="medium",
+         fontweight="bold",
+         bbox=box,
+     )
+
+     plt.subplot(312)
+     plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5)
+     plt.xlim([np.around(t1[0]), np.around(t1[-1])])
+     if signal_ylim[0] != signal_ylim[1]:
+         plt.ylim(signal_ylim)
+     plt.gca().set_xticklabels([])
+     plt.ylabel("Amplitude", fontsize='large')
+     plt.legend(loc='lower left', fontsize='medium')
+     plt.text(
+         text_loc[0],
+         text_loc[1],
+         '(ii)',
+         horizontalalignment='center',
+         transform=plt.gca().transAxes,
+         fontsize="medium",
+         fontweight="bold",
+         bbox=box,
+     )
+
+     plt.subplot(313)
+     plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5)
+     plt.xlim([np.around(t1[0]), np.around(t1[-1])])
+     if signal_ylim[0] != signal_ylim[1]:
+         plt.ylim(signal_ylim)
+     plt.xlabel("Time (s)", fontsize='large')
+     plt.legend(loc='lower left', fontsize='medium')
+     plt.text(
+         text_loc[0],
+         text_loc[1],
+         '(iii)',
+         horizontalalignment='center',
+         transform=plt.gca().transAxes,
+         fontsize="medium",
+         fontweight="bold",
+         bbox=box,
+     )
+
+     plt.savefig(os.path.join(figure_dir, base + '_wave.png'), bbox_inches='tight')
+     plt.close()
+
+     return
+
+
+ if __name__ == "__main__":
+     pass
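The function above realizes the core DeepDenoiser idea: apply the predicted signal and noise masks to the complex STFT and invert each product with `scipy.signal.istft`. A minimal, self-contained sketch of that masked-STFT reconstruction follows; the threshold mask here is a crude stand-in for the UNet output, and the `fs`/`nperseg`/`nfft` values are assumptions, not the repository's `Config`.

```python
import numpy as np
import scipy.signal

fs, nperseg, nfft = 100, 30, 60  # assumed STFT parameters, not the repo's Config

# Toy record: a decaying 5 Hz wavelet buried in white noise.
rng = np.random.default_rng(0)
t = np.arange(0, 30, 1 / fs)
clean = np.sin(2 * np.pi * 5 * t) * np.exp(-0.1 * t)
noisy = clean + 0.5 * rng.standard_normal(t.size)

# Forward STFT of the noisy record (complex spectrogram Z).
f, tt, Z = scipy.signal.stft(noisy, fs=fs, nperseg=nperseg, nfft=nfft, boundary='zeros')

# Crude magnitude-threshold masks standing in for the network's predicted masks.
mask_signal = (np.abs(Z) > np.median(np.abs(Z))).astype(float)
mask_noise = 1.0 - mask_signal

# Apply each mask and invert, as plot_figures does with the model output.
_, denoised = scipy.signal.istft(Z * mask_signal, fs=fs, nperseg=nperseg, nfft=nfft)
_, residual = scipy.signal.istft(Z * mask_noise, fs=fs, nperseg=nperseg, nfft=nfft)
```

In the real pipeline the two masks come from the UNet, and `denoised`/`residual` correspond to the "Recovered signal" and "Recovered noise" panels plotted above.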
docs/README.md ADDED
@@ -0,0 +1,60 @@
+ ---
+ title: DeepDenoiser
+ emoji: 🌊
+ colorFrom: purple
+ colorTo: blue
+ sdk: docker
+ pinned: false
+ ---
+
+ # DeepDenoiser: Seismic Signal Denoising and Decomposition Using Deep Neural Networks
+
+ [![](https://github.com/AI4EPS/DeepDenoiser/workflows/documentation/badge.svg)](https://ai4eps.github.io/DeepDenoiser)
+
+ ## 1. Install [miniconda](https://docs.conda.io/en/latest/miniconda.html) and requirements
+ - Download the DeepDenoiser repository
+ ```bash
+ git clone https://github.com/wayneweiqiang/DeepDenoiser.git
+ cd DeepDenoiser
+ ```
+ - Install into the default environment
+ ```bash
+ conda env update -f env.yml -n base
+ ```
+ - Install into a "deepdenoiser" virtual environment
+ ```bash
+ conda env create -f env.yml
+ conda activate deepdenoiser
+ ```
+
+ ## 2. Pre-trained model
+ Located in directory: **model/190614-104802**
+
+ ## 3. Related papers
+ - Zhu, Weiqiang, S. Mostafa Mousavi, and Gregory C. Beroza. "Seismic Signal Denoising and Decomposition Using Deep Neural Networks." arXiv preprint arXiv:1811.02695 (2018).
+
+ ## 4. Interactive example
+ See details in the [notebook](https://github.com/wayneweiqiang/DeepDenoiser/blob/master/docs/example_interactive.ipynb): [example_interactive.ipynb](example_interactive.ipynb)
+
+ ## 5. Batch prediction
+ See details in the [notebook](https://github.com/wayneweiqiang/DeepDenoiser/blob/master/docs/example_batch_prediction.ipynb): [example_batch_prediction.ipynb](example_batch_prediction.ipynb)
+
+ ## 6. Train
+ ### Data format
+
+ Required: two csv list files, one each for signal and noise, plus the corresponding directories of npz files.
+
+ Each csv file contains three columns: "fname", "itp", "channels".
+
+ Each npz file contains three variables: "data", "itp", "channels".
+
+ The "data" variable has a shape of 9001 x 3.
+
+ The "itp" variable is the sample index of the first P arrival.
+
+ Note: In the demo data, for simplicity we use the waveform before itp as noise samples, so train_noise_list is the same as train_signal_list here.
+
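+ A hypothetical sketch of writing one training sample and its csv row (the file names, itp value, and channel encoding below are illustrative assumptions, not fixed by the repository):
+ ```python
+ import os
+
+ import numpy as np
+ import pandas as pd
+
+ os.makedirs("Dataset/train", exist_ok=True)
+
+ # One synthetic 3-component waveform: 9001 samples x 3 channels.
+ data = np.random.randn(9001, 3).astype(np.float32)
+ itp = 3001        # assumed sample index of the first P arrival
+ channels = "ENZ"  # assumed encoding of the channel order
+
+ np.savez("Dataset/train/demo_000.npz", data=data, itp=itp, channels=channels)
+
+ # Matching row in the csv list file (one row per npz file).
+ df = pd.DataFrame([{"fname": "demo_000.npz", "itp": itp, "channels": channels}])
+ df.to_csv("Dataset/train.csv", index=False)
+ ```
+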
+ Training can then be started with:
+
+ ~~~bash
+ python deepdenoiser/train.py --mode=train --train_signal_dir=./Dataset/train --train_signal_list=./Dataset/train.csv --train_noise_dir=./Dataset/train --train_noise_list=./Dataset/train.csv --batch_size=20
+ ~~~
+
+ Please let us know of any bugs found in the code. Suggestions and collaborations are welcome.
docs/example_batch_prediction.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
docs/example_interactive.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
env.yml ADDED
@@ -0,0 +1,19 @@
+ name: deepdenoiser
+ channels:
+   - defaults
+   - conda-forge
+ dependencies:
+   - python=3.7
+   - numpy
+   - scipy
+   - matplotlib
+   - pandas
+   - scikit-learn
+   - tqdm
+   - obspy
+   - uvicorn
+   - fastapi
+   - kafka-python
+   - tensorflow
mkdocs.yml ADDED
@@ -0,0 +1,18 @@
+ site_name: "DeepDenoiser"
+ site_description: 'DeepDenoiser: Seismic Signal Denoising and Decomposition Using Deep Neural Networks'
+ site_author: 'Weiqiang Zhu'
+ docs_dir: docs/
+ repo_name: 'wayneweiqiang/DeepDenoiser'
+ repo_url: 'https://github.com/wayneweiqiang/DeepDenoiser'
+ nav:
+   - Overview: README.md
+   - Interactive Example: example_interactive.ipynb
+   - Batch Prediction: example_batch_prediction.ipynb
+ theme:
+   name: 'material'
+ plugins:
+   - mkdocs-jupyter
+ extra:
+   analytics:
+     provider: google
+     property: G-FMMP8CQRDZ
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ tensorflow
+ matplotlib
+ scipy
+ pandas
+ tqdm