Commit 81c99dc · Parent(s): init
Files changed:
- Dockerfile +26 -0
- LICENSE +21 -0
- deepdenoiser/__init__.py +0 -0
- deepdenoiser/app.py +180 -0
- deepdenoiser/data_reader.py +816 -0
- deepdenoiser/model.py +495 -0
- deepdenoiser/predict.py +136 -0
- deepdenoiser/train.py +557 -0
- deepdenoiser/util.py +875 -0
- docs/README.md +60 -0
- docs/example_batch_prediction.ipynb +0 -0
- docs/example_interactive.ipynb +0 -0
- env.yml +19 -0
- mkdocs.yml +18 -0
- requirements.txt +5 -0
Dockerfile
ADDED
@@ -0,0 +1,26 @@
FROM tensorflow/tensorflow

# Create the environment:
# COPY env.yml /app
# RUN conda env create --name cs329s --file=env.yml
# Make RUN commands use the new environment:
# SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]

RUN pip install tqdm obspy pandas
RUN pip install uvicorn fastapi kafka-python

WORKDIR /opt

# Copy files
COPY deepdenoiser /opt/deepdenoiser
# COPY model /opt/model
RUN wget https://github.com/AI4EPS/models/releases/download/DeepDenoiser/model.tar && tar -xvf model.tar && rm model.tar

# Expose API port
EXPOSE 8000

ENV PYTHONUNBUFFERED=1

# Start API server
#ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "cs329s", "uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
ENTRYPOINT ["uvicorn", "--app-dir", "deepdenoiser", "app:app", "--reload", "--port", "7860", "--host", "0.0.0.0"]
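Note: the active ENTRYPOINT serves the API on port 7860, while the EXPOSE line still lists port 8000 (the port used by the commented-out conda ENTRYPOINT). To try the image locally, one would typically run something like `docker build -t deepdenoiser .` followed by `docker run -p 7860:7860 deepdenoiser` (illustrative commands, not part of this commit) and then query the API on port 7860.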
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Weiqiang Zhu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
deepdenoiser/__init__.py
ADDED
File without changes
deepdenoiser/app.py
ADDED
@@ -0,0 +1,180 @@
import os
from collections import defaultdict, namedtuple
from datetime import datetime, timedelta
from json import dumps
from typing import Any, AnyStr, Dict, List, NamedTuple, Union

import numpy as np
import requests
import tensorflow as tf
from fastapi import FastAPI
from kafka import KafkaProducer
from pydantic import BaseModel
import scipy
from scipy.interpolate import interp1d

from model import UNet

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

app = FastAPI()
X_SHAPE = [3000, 1, 3]
SAMPLING_RATE = 100

# load model
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)

# Kafka producer
use_kafka = False
# BROKER_URL = 'localhost:9092'
# BROKER_URL = 'my-kafka-headless:9092'

try:
    print("Connecting to k8s kafka")
    BROKER_URL = "quakeflow-kafka-headless:9092"
    producer = KafkaProducer(
        bootstrap_servers=[BROKER_URL],
        key_serializer=lambda x: dumps(x).encode("utf-8"),
        value_serializer=lambda x: dumps(x).encode("utf-8"),
    )
    use_kafka = True
    print("k8s kafka connection success!")
except BaseException:
    print("k8s Kafka connection error")
    try:
        print("Connecting to local kafka")
        producer = KafkaProducer(
            bootstrap_servers=["localhost:9092"],
            key_serializer=lambda x: dumps(x).encode("utf-8"),
            value_serializer=lambda x: dumps(x).encode("utf-8"),
        )
        use_kafka = True
        print("local kafka connection success!")
    except BaseException:
        print("local Kafka connection error")


def normalize_batch(data, window=200):
    """
    data: nbn, nf, nt, 2
    """
    assert len(data.shape) == 4
    shift = window // 2
    nbt, nf, nt, nimg = data.shape

    ## std in slide windows
    data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
    t = np.arange(0, nt + shift - 1, shift, dtype="int")  # 201 => 0, 100, 200
    # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
    std = np.zeros([nbt, len(t)])
    mean = np.zeros([nbt, len(t)])
    for i in range(std.shape[1]):
        std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
        mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))

    std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
    std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]

    ## normalize data with interpolated std
    t_interp = np.arange(nt, dtype="int")
    std_interp = interp1d(t, std, kind="slinear")(t_interp)
    std_interp[std_interp == 0] = 1.0
    mean_interp = interp1d(t, mean, kind="slinear")(t_interp)

    data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]

    if len(t) > 3:  ## need to address this normalization issue in training
        data /= 2.0

    return data


def get_prediction(meta):

    FS = 100
    NPERSEG = 30
    NFFT = 60

    vec = np.array(meta.vec)  # [batch, nt, chn]
    nbt, nt, nch = vec.shape
    vec = np.transpose(vec, [0, 2, 1])  # [batch, chn, nt]
    vec = np.reshape(vec, [nbt * nch, nt])  ## [batch * chn, nt]

    if np.mod(vec.shape[-1], 3000) == 1:  # 3001 => 3000
        vec = vec[..., :-1]

    if meta.dt != 0.01:
        t = np.linspace(0, 1, len(vec))
        t_interp = np.linspace(0, 1, np.int(np.around(len(vec) * meta.dt * FS)))
        vec = interp1d(t, vec, kind="slinear")(t_interp)

    # sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos')  ## for stability of long sequence
    # vec = scipy.signal.sosfilt(sos, vec)
    f, t, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary='zeros')
    noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # [batch * chn, nf, nt, 2]
    noisy_signal[np.isnan(noisy_signal)] = 0
    noisy_signal[np.isinf(noisy_signal)] = 0
    X_input = normalize_batch(noisy_signal)

    feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)

    _, denoised_signal = scipy.signal.istft(
        (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
        fs=FS,
        nperseg=NPERSEG,
        nfft=NFFT,
        boundary='zeros',
    )
    # _, denoised_noise = scipy.signal.istft(
    #     (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 1],
    #     fs=FS,
    #     nperseg=NPERSEG,
    #     nfft=NFFT,
    #     boundary='zeros',
    # )

    denoised_signal = np.reshape(denoised_signal, [nbt, nch, nt])
    denoised_signal = np.transpose(denoised_signal, [0, 2, 1])

    result = meta.copy()
    result.vec = denoised_signal.tolist()
    return result


class Data(BaseModel):
    # id: Union[List[str], str]
    # timestamp: Union[List[str], str]
    # vec: Union[List[List[List[float]]], List[List[float]]]
    id: List[str]
    timestamp: List[str]
    vec: List[List[List[float]]]
    dt: float = 0.01


@app.post("/predict")
def predict(data: Data):

    denoised = get_prediction(data)

    return denoised


@app.get("/healthz")
def healthz():
    return {"status": "ok"}
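For reference, a minimal client sketch for the /predict endpoint defined above. Assumptions: the server is reachable at localhost:7860 (the port in the Dockerfile ENTRYPOINT), and the id/timestamp values are placeholders; the input must be shaped [batch, nt, channels] to match get_prediction.

import numpy as np
import requests

# One batch item: 3000 samples x 3 channels at 100 Hz (dt = 0.01 s),
# matching X_SHAPE and SAMPLING_RATE in app.py.
payload = {
    "id": ["NC.STA01"],                        # placeholder station id
    "timestamp": ["2021-01-01T00:00:00.000"],  # placeholder start time
    "vec": np.random.randn(1, 3000, 3).tolist(),
    "dt": 0.01,
}
resp = requests.post("http://localhost:7860/predict", json=payload)
result = resp.json()  # same fields, with "vec" replaced by the denoised waveforms
print(np.array(result["vec"]).shape)  # (1, 3000, 3)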
deepdenoiser/data_reader.py
ADDED
@@ -0,0 +1,816 @@
import numpy as np
import pandas as pd
import scipy.signal
import tensorflow as tf

pd.options.mode.chained_assignment = None
import logging
import os
import threading

import obspy
from scipy.interpolate import interp1d

tf.compat.v1.disable_eager_execution()
# from tensorflow.python.ops.linalg_ops import norm
# from tensorflow.python.util import nest


class Config:
    seed = 100
    n_class = 2
    fs = 100
    dt = 1.0 / fs
    freq_range = [0, fs / 2]
    time_range = [0, 30]
    nperseg = 30
    nfft = 60
    plot = False
    nt = 3000
    X_shape = [31, 201, 2]
    Y_shape = [31, 201, n_class]
    signal_shape = [31, 201]
    noise_shape = signal_shape
    use_seed = False
    queue_size = 10
    noise_mean = 2
    noise_std = 1
    # noise_low = 1
    # noise_high = 5
    use_buffer = True
    snr_threshold = 10


# %%
# def normalize(data, window=3000):
#     """
#     data: nsta, chn, nt
#     """
#     shift = window//2
#     nt = len(data)

#     ## std in slide windows
#     data_pad = np.pad(data, ((window//2, window//2)), mode="reflect")
#     t = np.arange(0, nt, shift, dtype="int")
#     # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
#     std = np.zeros(len(t))
#     mean = np.zeros(len(t))
#     for i in range(len(std)):
#         std[i] = np.std(data_pad[i*shift:i*shift+window])
#         mean[i] = np.mean(data_pad[i*shift:i*shift+window])

#     t = np.append(t, nt)
#     std = np.append(std, [np.std(data_pad[-window:])])
#     mean = np.append(mean, [np.mean(data_pad[-window:])])

#     # print(t)
#     ## normalize data with interpolated std
#     t_interp = np.arange(nt, dtype="int")
#     std_interp = interp1d(t, std, kind="slinear")(t_interp)
#     mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
#     data = (data - mean_interp)/(std_interp)
#     return data, std_interp

# %%
def normalize(data, window=200):
    """
    data: nsta, chn, nt
    """
    shift = window // 2
    nt = data.shape[1]

    ## std in slide windows
    data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
    t = np.arange(0, nt, shift, dtype="int")
    # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
    std = np.zeros(len(t))
    mean = np.zeros(len(t))
    for i in range(len(std)):
        std[i] = np.std(data_pad[:, i * shift : i * shift + window, :])
        mean[i] = np.mean(data_pad[:, i * shift : i * shift + window, :])

    t = np.append(t, nt)
    std = np.append(std, [np.std(data_pad[:, -window:, :])])
    mean = np.append(mean, [np.mean(data_pad[:, -window:, :])])
    # print(t)
    ## normalize data with interpolated std
    t_interp = np.arange(nt, dtype="int")
    std_interp = interp1d(t, std, kind="slinear")(t_interp)
    std_interp[std_interp == 0] = 1.0
    mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
    data = (data - mean_interp[np.newaxis, :, np.newaxis]) / std_interp[np.newaxis, :, np.newaxis]
    return data, std_interp


def normalize_batch(data, window=200):
    """
    data: nbn, nf, nt, 2
    """
    assert len(data.shape) == 4
    shift = window // 2
    nbt, nf, nt, nimg = data.shape

    ## std in slide windows
    data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
    t = np.arange(0, nt + shift - 1, shift, dtype="int")  # 201 => 0, 100, 200
    std = np.zeros([nbt, len(t)])
    mean = np.zeros([nbt, len(t)])
    for i in range(std.shape[1]):
        std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
        mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))

    std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
    std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]

    ## normalize data with interpolated std
    t_interp = np.arange(nt, dtype="int")
    std_interp = interp1d(t, std, kind="slinear")(t_interp)  ## nbt, nt
    std_interp[std_interp == 0] = 1.0
    mean_interp = interp1d(t, mean, kind="slinear")(t_interp)

    data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]

    if len(t) > 3:  ## need to address this normalization issue in training
        data /= 2.0

    return data


# %%
def py_func_decorator(output_types=None, output_shapes=None, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            nonlocal output_shapes
            # flat_output_types = nest.flatten(output_types)
            flat_output_types = tf.nest.flatten(output_types)
            # flat_values = tf.py_func(
            flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
            if output_shapes is not None:
                for v, s in zip(flat_values, output_shapes):
                    v.set_shape(s)
            # return nest.pack_sequence_as(output_types, flat_values)
            return tf.nest.pack_sequence_as(output_types, flat_values)

        return call

    return decorator


def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None):
    dataset = tf.data.Dataset.range(len(iterator))

    @py_func_decorator(output_types, output_shapes, name=name)
    def index_to_entry(idx):
        return iterator[idx]

    return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)


class DataReader(object):
    def __init__(
        self,
        signal_dir=None,
        signal_list=None,
        noise_dir=None,
        noise_list=None,
        queue_size=None,
        coord=None,
        config=Config(),
    ):

        self.config = config

        signal_list = pd.read_csv(signal_list, header=0)
        noise_list = pd.read_csv(noise_list, header=0)

        self.signal = signal_list
        self.noise = noise_list
        self.n_signal = len(self.signal)

        self.signal_dir = signal_dir
        self.noise_dir = noise_dir

        self.X_shape = config.X_shape
        self.Y_shape = config.Y_shape
        self.n_class = config.n_class

        self.coord = coord
        self.threads = []
        self.queue_size = queue_size

        self.add_queue()
        self.buffer_signal = {}
        self.buffer_noise = {}
        self.buffer_channels_signal = {}
        self.buffer_channels_noise = {}

    def add_queue(self):
        with tf.device('/cpu:0'):
            self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
            self.target_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
            self.queue = tf.queue.PaddingFIFOQueue(
                self.queue_size, ['float32', 'float32'], shapes=[self.config.X_shape, self.config.Y_shape]
            )
            self.enqueue = self.queue.enqueue([self.sample_placeholder, self.target_placeholder])
        return 0

    def dequeue(self, num_elements):
        output = self.queue.dequeue_many(num_elements)
        return output

    def get_snr(self, data, itp, dit=300):
        tmp_std = np.std(data[itp - dit : itp])
        if tmp_std > 0:
            return np.std(data[itp : itp + dit]) / tmp_std
        else:
            return 0

    def add_event(self, sample, channels, j):
        while np.random.uniform(0, 1) < 0.2:
            shift = None
            if channels not in self.buffer_channels_signal:
                self.buffer_channels_signal[channels] = self.signal[self.signal['channels'] == channels]
            fname = os.path.join(self.signal_dir, self.buffer_channels_signal[channels].sample(n=1).iloc[0]['fname'])
            try:
                if fname not in self.buffer_signal:
                    meta = np.load(fname)
                    data_FT = []
                    snr = []
                    for i in range(3):
                        tmp_data = meta['data'][:, i]
                        tmp_itp = meta['itp']
                        snr.append(self.get_snr(tmp_data, tmp_itp))
                        tmp_data -= np.mean(tmp_data)
                        f, t, tmp_FT = scipy.signal.stft(
                            tmp_data,
                            fs=self.config.fs,
                            nperseg=self.config.nperseg,
                            nfft=self.config.nfft,
                            boundary='zeros',
                        )
                        data_FT.append(tmp_FT)
                    data_FT = np.stack(data_FT, axis=-1)
                    self.buffer_signal[fname] = {
                        'data_FT': data_FT,
                        'itp': tmp_itp,
                        'channels': meta['channels'],
                        'snr': snr,
                    }
                meta_signal = self.buffer_signal[fname]
            except:
                logging.error("Failed reading signal: {}".format(fname))
                continue
            if meta_signal['snr'][j] > self.config.snr_threshold:
                tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex_)
                shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
                tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
                if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
                    continue
                tmp_signal = tmp_signal / np.std(tmp_signal)
                sample += tmp_signal / np.random.uniform(1, 5)
        return sample

    def thread_main(self, sess, n_threads=1, start=0):
        stop = False
        while not stop:
            index = list(range(start, self.n_signal, n_threads))
            np.random.shuffle(index)
            for i in index:
                fname_signal = os.path.join(self.signal_dir, self.signal.iloc[i]['fname'])
                try:
                    if fname_signal not in self.buffer_signal:
                        meta = np.load(fname_signal)
                        data_FT = []
                        snr = []
                        for j in range(3):
                            tmp_data = meta['data'][..., j]
                            tmp_itp = meta['itp']
                            snr.append(self.get_snr(tmp_data, tmp_itp))
                            tmp_data -= np.mean(tmp_data)
                            f, t, tmp_FT = scipy.signal.stft(
                                tmp_data,
                                fs=self.config.fs,
                                nperseg=self.config.nperseg,
                                nfft=self.config.nfft,
                                boundary='zeros',
                            )
                            data_FT.append(tmp_FT)
                        data_FT = np.stack(data_FT, axis=-1)
                        self.buffer_signal[fname_signal] = {
                            'data_FT': data_FT,
                            'itp': tmp_itp,
                            'channels': meta['channels'],
                            'snr': snr,
                        }
                    meta_signal = self.buffer_signal[fname_signal]
                except:
                    logging.error("Failed reading signal: {}".format(fname_signal))
                    continue
                channels = meta_signal['channels'].tolist()
                start_tp = meta_signal['itp'].tolist()

                if channels not in self.buffer_channels_noise:
                    self.buffer_channels_noise[channels] = self.noise[self.noise['channels'] == channels]
                fname_noise = os.path.join(
                    self.noise_dir, self.buffer_channels_noise[channels].sample(n=1).iloc[0]['fname']
                )
                try:
                    if fname_noise not in self.buffer_noise:
                        meta = np.load(fname_noise)
                        data_FT = []
                        for i in range(3):
                            tmp_data = meta['data'][: self.config.nt, i]
                            tmp_data -= np.mean(tmp_data)
                            f, t, tmp_FT = scipy.signal.stft(
                                tmp_data,
                                fs=self.config.fs,
                                nperseg=self.config.nperseg,
                                nfft=self.config.nfft,
                                boundary='zeros',
                            )
                            data_FT.append(tmp_FT)
                        data_FT = np.stack(data_FT, axis=-1)
                        self.buffer_noise[fname_noise] = {'data_FT': data_FT, 'channels': meta['channels']}
                    meta_noise = self.buffer_noise[fname_noise]
                except:
                    logging.error("Failed reading noise: {}".format(fname_noise))
                    continue

                if self.coord.should_stop():
                    stop = True
                    break

                j = np.random.choice([0, 1, 2])
                if meta_signal['snr'][j] <= self.config.snr_threshold:
                    continue

                tmp_noise = meta_noise['data_FT'][..., j]
                if np.isinf(tmp_noise).any() or np.isnan(tmp_noise).any() or (not np.any(tmp_noise)):
                    continue
                tmp_noise = tmp_noise / np.std(tmp_noise)

                tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex_)
                if np.random.random() < 0.9:
                    shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
                    tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
                    if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
                        continue
                    tmp_signal = tmp_signal / np.std(tmp_signal)
                    tmp_signal = self.add_event(tmp_signal, channels, j)

                if np.random.random() < 0.2:
                    tmp_signal = np.fliplr(tmp_signal)

                ratio = 0
                while ratio <= 0:
                    ratio = self.config.noise_mean + np.random.randn() * self.config.noise_std
                # ratio = np.random.uniform(self.config.noise_low, self.config.noise_high)
                tmp_noisy_signal = tmp_signal + ratio * tmp_noise
                noisy_signal = np.stack([tmp_noisy_signal.real, tmp_noisy_signal.imag], axis=-1)
                if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any():
                    continue
                noisy_signal = noisy_signal / np.std(noisy_signal)
                tmp_mask = np.abs(tmp_signal) / (np.abs(tmp_signal) + np.abs(ratio * tmp_noise) + 1e-4)
                tmp_mask[tmp_mask >= 1] = 1
                tmp_mask[tmp_mask <= 0] = 0
                mask = np.zeros([tmp_mask.shape[0], tmp_mask.shape[1], self.n_class])
                mask[:, :, 0] = tmp_mask
                mask[:, :, 1] = 1 - tmp_mask
                sess.run(self.enqueue, feed_dict={self.sample_placeholder: noisy_signal, self.target_placeholder: mask})

    def start_threads(self, sess, n_threads=8):
        for i in range(n_threads):
            thread = threading.Thread(target=self.thread_main, args=(sess, n_threads, i))
            thread.daemon = True
            thread.start()
            self.threads.append(thread)
        return self.threads


class DataReader_test(DataReader):
    def __init__(
        self,
        signal_dir=None,
        signal_list=None,
        noise_dir=None,
        noise_list=None,
        queue_size=None,
        coord=None,
        config=Config(),
    ):
        self.config = config

        signal_list = pd.read_csv(signal_list, header=0)
        noise_list = pd.read_csv(noise_list, header=0)
        self.signal = signal_list
        self.noise = noise_list
        self.n_signal = len(self.signal)

        self.signal_dir = signal_dir
        self.noise_dir = noise_dir

        self.X_shape = config.X_shape
        self.Y_shape = config.Y_shape
        self.n_class = config.n_class

        self.coord = coord
        self.threads = []
        self.queue_size = queue_size

        self.add_queue()
        self.buffer_signal = {}
        self.buffer_noise = {}
        self.buffer_channels_signal = {}
        self.buffer_channels_noise = {}

    def add_queue(self):
        self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
        self.target_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
        self.ratio_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
        self.signal_placeholder = tf.compat.v1.placeholder(dtype=tf.complex64, shape=None)
        self.noise_placeholder = tf.compat.v1.placeholder(dtype=tf.complex64, shape=None)
        self.fname_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=None)
        self.queue = tf.queue.PaddingFIFOQueue(
            self.queue_size,
            ['float32', 'float32', 'float32', 'complex64', 'complex64', 'string'],
            shapes=[
                self.config.X_shape,
                self.config.Y_shape,
                [],
                self.config.signal_shape,
                self.config.noise_shape,
                [],
            ],
        )
        self.enqueue = self.queue.enqueue(
            [
                self.sample_placeholder,
                self.target_placeholder,
                self.ratio_placeholder,
                self.signal_placeholder,
                self.noise_placeholder,
                self.fname_placeholder,
            ]
        )
        return 0

    def dequeue(self, num_elements):
        output = self.queue.dequeue_up_to(num_elements)
        return output

    def thread_main(self, sess, n_threads=1, start=0):
        index = list(range(start, self.n_signal, n_threads))
        for i in index:
            np.random.seed(i)

            fname = self.signal.iloc[i]['fname']
            fname_signal = os.path.join(self.signal_dir, fname)
            meta = np.load(fname_signal)
            data_FT = []
            snr = []
            for j in range(3):
                tmp_data = meta['data'][..., j]
                tmp_itp = meta['itp']
                snr.append(self.get_snr(tmp_data, tmp_itp))
                tmp_data -= np.mean(tmp_data)
                f, t, tmp_FT = scipy.signal.stft(
                    tmp_data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
                )
                data_FT.append(tmp_FT)
            data_FT = np.stack(data_FT, axis=-1)
            meta_signal = {'data_FT': data_FT, 'itp': tmp_itp, 'channels': meta['channels'], 'snr': snr}
            channels = meta['channels'].tolist()
            start_tp = meta['itp'].tolist()

            if channels not in self.buffer_channels_noise:
                self.buffer_channels_noise[channels] = self.noise[self.noise['channels'] == channels]
            fname_noise = os.path.join(
                self.noise_dir, self.buffer_channels_noise[channels].sample(n=1, random_state=i).iloc[0]['fname']
            )
            meta = np.load(fname_noise)
            data_FT = []
            for i in range(3):
                tmp_data = meta['data'][: self.config.nt, i]
                tmp_data -= np.mean(tmp_data)
                f, t, tmp_FT = scipy.signal.stft(
                    tmp_data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
                )
                data_FT.append(tmp_FT)
            data_FT = np.stack(data_FT, axis=-1)
            meta_noise = {'data_FT': data_FT, 'channels': meta['channels']}

            if self.coord.should_stop():
                stop = True
                break

            j = np.random.choice([0, 1, 2])
            tmp_noise = meta_noise['data_FT'][..., j]
            if np.isinf(tmp_noise).any() or np.isnan(tmp_noise).any() or (not np.any(tmp_noise)):
                continue
            tmp_noise = tmp_noise / np.std(tmp_noise)

            tmp_signal = np.zeros([self.X_shape[0], self.X_shape[1]], dtype=np.complex_)
            if np.random.random() < 0.9:
                shift = np.random.randint(-self.X_shape[1], 1, None, 'int')
                tmp_signal[:, -shift:] = meta_signal['data_FT'][:, self.X_shape[1] : 2 * self.X_shape[1] + shift, j]
                if np.isinf(tmp_signal).any() or np.isnan(tmp_signal).any() or (not np.any(tmp_signal)):
                    continue
                tmp_signal = tmp_signal / np.std(tmp_signal)
            # tmp_signal = self.add_event(tmp_signal, channels, j)
            # if np.random.random() < 0.2:
            #     tmp_signal = np.fliplr(tmp_signal)

            ratio = 0
            while ratio <= 0:
                ratio = self.config.noise_mean + np.random.randn() * self.config.noise_std
            tmp_noisy_signal = tmp_signal + ratio * tmp_noise
            noisy_signal = np.stack([tmp_noisy_signal.real, tmp_noisy_signal.imag], axis=-1)
            if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any():
                continue
            std_noisy_signal = np.std(noisy_signal)
            noisy_signal = noisy_signal / std_noisy_signal
            tmp_mask = np.abs(tmp_signal) / (np.abs(tmp_signal) + np.abs(ratio * tmp_noise) + 1e-4)
            tmp_mask[tmp_mask >= 1] = 1
            tmp_mask[tmp_mask <= 0] = 0
            mask = np.zeros([tmp_mask.shape[0], tmp_mask.shape[1], self.n_class])
            mask[:, :, 0] = tmp_mask
            mask[:, :, 1] = 1 - tmp_mask

            sess.run(
                self.enqueue,
                feed_dict={
                    self.sample_placeholder: noisy_signal,
                    self.target_placeholder: mask,
                    self.ratio_placeholder: std_noisy_signal,
                    self.signal_placeholder: tmp_signal,
                    self.noise_placeholder: ratio * tmp_noise,
                    self.fname_placeholder: fname,
                },
            )


class DataReader_pred_queue(DataReader):
    def __init__(self, signal_dir, signal_list, queue_size, coord, config=Config()):
        self.config = config
        signal_list = pd.read_csv(signal_list)
        self.signal = signal_list
        self.n_signal = len(self.signal)
        self.n_class = config.n_class
        self.X_shape = config.X_shape
        self.Y_shape = config.Y_shape
        self.signal_dir = signal_dir

        self.coord = coord
        self.threads = []
        self.queue_size = queue_size
        self.add_placeholder()

    def add_placeholder(self):
        self.sample_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
        self.ratio_placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=None)
        self.fname_placeholder = tf.compat.v1.placeholder(dtype=tf.string, shape=None)
        self.queue = tf.queue.PaddingFIFOQueue(
            self.queue_size, ['float32', 'float32', 'string'], shapes=[self.config.X_shape, [], []]
        )
        self.enqueue = self.queue.enqueue([self.sample_placeholder, self.ratio_placeholder, self.fname_placeholder])

    def dequeue(self, num_elements):
        output = self.queue.dequeue_up_to(num_elements)
        return output

    def thread_main(self, sess, n_threads=1, start=0):
        index = list(range(start, self.n_signal, n_threads))
        shift = 0
        for i in index:
            fname = self.signal.iloc[i]['fname']
            data_signal = np.load(os.path.join(self.signal_dir, fname))
            f, t, tmp_signal = scipy.signal.stft(
                scipy.signal.detrend(np.squeeze(data_signal['data'][shift : self.config.nt + shift])),
                fs=self.config.fs,
                nperseg=self.config.nperseg,
                nfft=self.config.nfft,
                boundary='zeros',
            )
            noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)
            if np.isnan(noisy_signal).any() or np.isinf(noisy_signal).any() or (not np.any(noisy_signal)):
                continue
            std_noisy_signal = np.std(noisy_signal)
            if std_noisy_signal == 0:
                continue
            noisy_signal = noisy_signal / std_noisy_signal
            sess.run(
                self.enqueue,
                feed_dict={
                    self.sample_placeholder: noisy_signal,
                    self.ratio_placeholder: std_noisy_signal,
                    self.fname_placeholder: fname,
                },
            )


class DataReader_pred:
    def __init__(self, signal_dir, signal_list, format="numpy", sampling_rate=100, config=Config()):
        self.buffer = {}
        self.config = config
        self.format = format
        self.dtype = "float32"
        try:
            signal_list = pd.read_csv(signal_list, sep="\t")["fname"]
        except:
            signal_list = pd.read_csv(signal_list)["fname"]
        self.signal_list = signal_list
        self.n_signal = len(self.signal_list)
        self.signal_dir = signal_dir
        self.sampling_rate = sampling_rate
        self.n_class = config.n_class
        FT_shape = self.get_data_shape()
        self.X_shape = [*FT_shape, 2]

    def get_data_shape(self):
        # fname = self.signal_list.iloc[0]['fname']
        # data = np.load(os.path.join(self.signal_dir, fname), allow_pickle=True)["data"]
        # data = np.squeeze(data)
        base_name = self.signal_list[0]
        if self.format == "numpy":
            meta = self.read_numpy(os.path.join(self.signal_dir, base_name))
        elif self.format == "mseed":
            meta = self.read_mseed(os.path.join(self.signal_dir, base_name))
        elif self.format == "hdf5":
            meta = self.read_hdf5(base_name)

        data = meta["data"]
        data = np.transpose(data, [2, 1, 0])

        if self.sampling_rate != 100:
            t = np.linspace(0, 1, data.shape[-1])
            t_interp = np.linspace(0, 1, np.int(np.around(data.shape[-1] * 100.0 / self.sampling_rate)))
            data = interp1d(t, data, kind="slinear")(t_interp)
        f, t, tmp_signal = scipy.signal.stft(
            data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
        )
        logging.info(f"Input data shape: {tmp_signal.shape} measured on file {base_name}")
        return tmp_signal.shape

    def __len__(self):
        return self.n_signal

    def read_numpy(self, fname):
        # try:
        if fname not in self.buffer:
            npz = np.load(fname)
            meta = {}
            if len(npz['data'].shape) == 1:
                meta["data"] = npz['data'][:, np.newaxis, np.newaxis]
            elif len(npz['data'].shape) == 2:
                meta["data"] = npz['data'][:, np.newaxis, :]
            else:
                meta["data"] = npz['data']
            if "p_idx" in npz.files:
                if len(npz["p_idx"].shape) == 0:
                    meta["itp"] = [[npz["p_idx"]]]
                else:
                    meta["itp"] = npz["p_idx"]
            if "s_idx" in npz.files:
                if len(npz["s_idx"].shape) == 0:
                    meta["its"] = [[npz["s_idx"]]]
                else:
                    meta["its"] = npz["s_idx"]
            if "t0" in npz.files:
                meta["t0"] = npz["t0"]
            self.buffer[fname] = meta
        else:
            meta = self.buffer[fname]
        return meta
        # except:
        #     logging.error("Failed reading {}".format(fname))
        #     return None

    def read_hdf5(self, fname):
        data = self.h5_data[fname][()]
        attrs = self.h5_data[fname].attrs
        meta = {}
        if len(data.shape) == 2:
            meta["data"] = data[:, np.newaxis, :]
        else:
            meta["data"] = data
        if "p_idx" in attrs:
            if len(attrs["p_idx"].shape) == 0:
                meta["itp"] = [[attrs["p_idx"]]]
            else:
                meta["itp"] = attrs["p_idx"]
        if "s_idx" in attrs:
            if len(attrs["s_idx"].shape) == 0:
                meta["its"] = [[attrs["s_idx"]]]
            else:
                meta["its"] = attrs["s_idx"]
        if "t0" in attrs:
            meta["t0"] = attrs["t0"]
        return meta

    def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
        with self.s3fs.open(bucket + "/" + fname, 'rb') as fp:
            if format == "numpy":
                meta = self.read_numpy(fp)
            elif format == "mseed":
                meta = self.read_mseed(fp)
            else:
                raise (f"Format {format} not supported")
        return meta

    def read_mseed(self, fname):

        mseed = obspy.read(fname)
        mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
        mseed = mseed.merge(fill_value=0)
        starttime = min([st.stats.starttime for st in mseed])
        endtime = max([st.stats.endtime for st in mseed])
        mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
        if mseed[0].stats.sampling_rate != self.sampling_rate:
            logging.warning(f"Sampling rate {mseed[0].stats.sampling_rate} != {self.sampling_rate} Hz")

        order = ['3', '2', '1', 'E', 'N', 'Z']
        order = {key: i for i, key in enumerate(order)}
        comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}

        t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
        nt = len(mseed[0].data)
        data = np.zeros([nt, 3], dtype=self.dtype)
        ids = [x.get_id() for x in mseed]
        if len(ids) == 3:
            for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
                data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
        else:
            if len(ids) > 3:
                logging.warning(f"More than 3 channels {ids}!")
            for jj, id in enumerate(ids):
                j = comp2idx[id[-1]]
                data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)

        data = data[:, np.newaxis, :]
        meta = {"data": data, "t0": t0}
        return meta

    def __getitem__(self, i):
        # fname = self.signal.iloc[i]['fname']
        # data = np.load(os.path.join(self.signal_dir, fname), allow_pickle=True)["data"]
        # data = np.squeeze(data)
        base_name = self.signal_list[i]

        if self.format == "numpy":
            meta = self.read_numpy(os.path.join(self.signal_dir, base_name))
        elif self.format == "mseed":
            meta = self.read_mseed(os.path.join(self.signal_dir, base_name))
        elif self.format == "hdf5":
            meta = self.read_hdf5(base_name)

        data = meta["data"]  # nt, 1, nch
        data = np.transpose(data, [2, 1, 0])  # nch, 1, nt
        if np.mod(data.shape[-1], 3000) == 1:  # 3001 => 3000
            data = data[..., :-1]
        if "t0" in meta:
            t0 = meta["t0"]
        else:
            t0 = "1970-01-01T00:00:00.000"

        if self.sampling_rate != 100:
            logging.warning(f"Resample from {self.sampling_rate} to 100!")
            t = np.linspace(0, 1, data.shape[-1])
            t_interp = np.linspace(0, 1, np.int(np.around(data.shape[-1] * 100.0 / self.sampling_rate)))
            data = interp1d(t, data, kind="slinear")(t_interp)
        # sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos')  ## for stability of long sequence
        # data = scipy.signal.sosfilt(sos, data)
        f, t, tmp_signal = scipy.signal.stft(
            data, fs=self.config.fs, nperseg=self.config.nperseg, nfft=self.config.nfft, boundary='zeros'
        )  # nch, 1, nf, nt
        noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # nch, 1, nf, nt, 2
        noisy_signal[np.isnan(noisy_signal)] = 0
        noisy_signal[np.isinf(noisy_signal)] = 0
        # noisy_signal, std_noisy_signal = normalize(noisy_signal)
        # return noisy_signal.astype(self.dtype), std_noisy_signal.astype(self.dtype), fname

        return noisy_signal.astype(self.dtype), base_name, t0

    def dataset(self, batch_size, num_parallel_calls=4):
        dataset = dataset_map(
            self,
            output_types=(self.dtype, "string", "string"),
            output_shapes=(self.X_shape, None, None),
            num_parallel_calls=num_parallel_calls,
        )
        dataset = tf.compat.v1.data.make_one_shot_iterator(
            dataset.batch(batch_size).prefetch(batch_size * 3)
        ).get_next()
        return dataset


if __name__ == "__main__":

    # %%
    data_reader = DataReader_pred(signal_dir="./Dataset/yixiao/", signal_list="./Dataset/yixiao.csv")
    noisy_signal, std_noisy_signal, fname = data_reader[0]
    print(noisy_signal.shape, std_noisy_signal.shape, fname)
    batch = data_reader.dataset(10)
    init = tf.compat.v1.initialize_all_variables()
    sess = tf.compat.v1.Session()
    sess.run(init)
    print(sess.run(batch))
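As a quick sanity check of the sliding-window normalization above, a minimal sketch (assumptions: the dependencies from the Dockerfile are installed and this is run from the deepdenoiser directory so that data_reader is importable):

import numpy as np
from data_reader import normalize_batch

# Random stand-in for an STFT batch: [batch, n_freq, n_time, real/imag],
# matching Config.X_shape = [31, 201, 2] with a batch of 4.
x = np.random.randn(4, 31, 201, 2)
y = normalize_batch(x, window=200)
print(y.shape)            # (4, 31, 201, 2): shape is preserved
print(y.mean(), y.std())  # roughly zero mean and unit scale per window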
deepdenoiser/model.py
ADDED
@@ -0,0 +1,495 @@
1 |
+
import logging
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import tensorflow as tf
|
5 |
+
|
6 |
+
from util import *
|
7 |
+
|
8 |
+
tf.compat.v1.disable_eager_execution()
|
9 |
+
|
10 |
+
|
11 |
+
class ModelConfig:
|
12 |
+
|
13 |
+
batch_size = 20
|
14 |
+
depths = 6
|
15 |
+
filters_root = 8
|
16 |
+
kernel_size = [3, 3]
|
17 |
+
pool_size = [2, 2]
|
18 |
+
dilation_rate = [1, 1]
|
19 |
+
class_weights = [1.0, 1.0, 1.0]
|
20 |
+
loss_type = "cross_entropy"
|
21 |
+
weight_decay = 0.0
|
22 |
+
optimizer = "adam"
|
23 |
+
momentum = 0.9
|
24 |
+
learning_rate = 0.01
|
25 |
+
decay_step = 1e9
|
26 |
+
decay_rate = 0.9
|
27 |
+
drop_rate = 0.0
|
28 |
+
summary = True
|
29 |
+
|
30 |
+
X_shape = [31, 201, 2]
|
31 |
+
n_channel = X_shape[-1]
|
32 |
+
Y_shape = [31, 201, 2]
|
33 |
+
n_class = Y_shape[-1]
|
34 |
+
|
35 |
+
def __init__(self, **kwargs):
|
36 |
+
for k, v in kwargs.items():
|
37 |
+
setattr(self, k, v)
|
38 |
+
|
39 |
+
def update_args(self, args):
|
40 |
+
for k, v in vars(args).items():
|
41 |
+
setattr(self, k, v)
|
42 |
+
|
43 |
+
|
44 |
+
def crop_and_concat(net1, net2):
|
45 |
+
"""
|
46 |
+
the size(net1) <= size(net2)
|
47 |
+
"""
|
48 |
+
# net1_shape = net1.get_shape().as_list()
|
49 |
+
# net2_shape = net2.get_shape().as_list()
|
50 |
+
# # print(net1_shape)
|
51 |
+
# # print(net2_shape)
|
52 |
+
# # if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
53 |
+
# offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
54 |
+
# size = [-1, net1_shape[1], net1_shape[2], -1]
|
55 |
+
# net2_resize = tf.slice(net2, offsets, size)
|
56 |
+
# return tf.concat([net1, net2_resize], 3)
|
57 |
+
# # else:
|
58 |
+
# # offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
|
59 |
+
# # size = [-1, net2_shape[1], net2_shape[2], -1]
|
60 |
+
# # net1_resize = tf.slice(net1, offsets, size)
|
61 |
+
# # return tf.concat([net1_resize, net2], 3)
|
62 |
+
|
63 |
+
## dynamic shape
|
64 |
+
chn1 = net1.get_shape().as_list()[-1]
|
65 |
+
chn2 = net2.get_shape().as_list()[-1]
|
66 |
+
net1_shape = tf.shape(net1)
|
67 |
+
net2_shape = tf.shape(net2)
|
68 |
+
# print(net1_shape)
|
69 |
+
# print(net2_shape)
|
70 |
+
# if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
71 |
+
offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
72 |
+
size = [-1, net1_shape[1], net1_shape[2], -1]
|
73 |
+
net2_resize = tf.slice(net2, offsets, size)
|
74 |
+
|
75 |
+
out = tf.concat([net1, net2_resize], 3)
|
76 |
+
out.set_shape([None, None, None, chn1 + chn2])
|
77 |
+
return out
|
78 |
+
|
79 |
+
|
80 |
+
def crop_only(net1, net2):
|
81 |
+
"""
|
82 |
+
the size(net1) <= size(net2)
|
83 |
+
"""
|
84 |
+
net1_shape = net1.get_shape().as_list()
|
85 |
+
net2_shape = net2.get_shape().as_list()
|
86 |
+
# print(net1_shape)
|
87 |
+
# print(net2_shape)
|
88 |
+
# if net2_shape[1] >= net1_shape[1] and net2_shape[2] >= net1_shape[2]:
|
89 |
+
offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
|
90 |
+
size = [-1, net1_shape[1], net1_shape[2], -1]
|
91 |
+
net2_resize = tf.slice(net2, offsets, size)
|
92 |
+
# return tf.concat([net1, net2_resize], 3)
|
93 |
+
return net2_resize
|
94 |
+
|
95 |
+
|
96 |
+
class UNet:
|
97 |
+
def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
|
98 |
+
self.depths = config.depths
|
99 |
+
self.filters_root = config.filters_root
|
100 |
+
self.kernel_size = config.kernel_size
|
101 |
+
self.dilation_rate = config.dilation_rate
|
102 |
+
self.pool_size = config.pool_size
|
103 |
+
self.X_shape = config.X_shape
|
104 |
+
self.Y_shape = config.Y_shape
|
105 |
+
self.n_channel = config.n_channel
|
106 |
+
self.n_class = config.n_class
|
107 |
+
self.class_weights = config.class_weights
|
108 |
+
self.batch_size = config.batch_size
|
109 |
+
self.loss_type = config.loss_type
|
110 |
+
self.weight_decay = config.weight_decay
|
111 |
+
self.optimizer = config.optimizer
|
112 |
+
self.decay_step = config.decay_step
|
113 |
+
self.decay_rate = config.decay_rate
|
114 |
+
self.momentum = config.momentum
|
115 |
+
self.learning_rate = config.learning_rate
|
116 |
+
self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
|
117 |
+
self.summary_train = []
|
118 |
+
self.summary_valid = []
|
119 |
+
|
120 |
+
self.build(input_batch, mode=mode)
|
121 |
+
|
122 |
+
def add_placeholders(self, input_batch=None, mode='train'):
|
123 |
+
if input_batch is None:
|
124 |
+
self.X = tf.compat.v1.placeholder(
|
125 |
+
dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X'
|
126 |
+
)
|
127 |
+
self.Y = tf.compat.v1.placeholder(
|
128 |
+
dtype=tf.float32, shape=[None, None, None, self.n_class], name='y'
|
129 |
+
)
|
130 |
+
else:
|
131 |
+
self.X = input_batch[0]
|
132 |
+
if mode in ["train", "valid", "test"]:
|
133 |
+
self.Y = input_batch[1]
|
134 |
+
self.input_batch = input_batch
|
135 |
+
|
136 |
+
self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
|
137 |
+
# self.keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")
|
138 |
+
self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")
|
139 |
+
# self.learning_rate = tf.placeholder_with_default(tf.constant(0.01, dtype=tf.float32), shape=[], name="learning_rate")
|
140 |
+
# self.global_step = tf.placeholder_with_default(tf.constant(0, dtype=tf.int32), shape=[], name="global_step")
|
141 |
+
|
    def add_prediction_op(self):
        logging.info(
            "Model: depths {depths}, filters {filters}, "
            "filter size {kernel_size[0]}x{kernel_size[1]}, "
            "pool size: {pool_size[0]}x{pool_size[1]}, "
            "dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
                depths=self.depths,
                filters=self.filters_root,
                kernel_size=self.kernel_size,
                dilation_rate=self.dilation_rate,
                pool_size=self.pool_size,
            )
        )

        if self.weight_decay > 0:
            weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
            self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
        else:
            self.regularizer = None

        self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1.0, mode="fan_avg", distribution="uniform"
        )

        # down sample layers
        convs = [None] * self.depths  # store output of each depth

        with tf.compat.v1.variable_scope("Input"):
            net = self.X
            net = tf.compat.v1.layers.conv2d(
                net,
                filters=self.filters_root,
                kernel_size=self.kernel_size,
                activation=None,
                use_bias=False,
                padding='same',
                dilation_rate=self.dilation_rate,
                kernel_initializer=self.initializer,
                kernel_regularizer=self.regularizer,
                # bias_regularizer=self.regularizer,
                name="input_conv",
            )
            net = tf.compat.v1.layers.batch_normalization(net, training=self.is_training, name="input_bn")
            net = tf.nn.relu(net, name="input_relu")
            # net = tf.nn.dropout(net, self.keep_prob)
            net = tf.compat.v1.layers.dropout(net, rate=self.drop_rate, training=self.is_training, name="input_dropout")

        for depth in range(0, self.depths):
            with tf.compat.v1.variable_scope("DownConv_%d" % depth):
                filters = int(2 ** (depth) * self.filters_root)

                net = tf.compat.v1.layers.conv2d(
                    net,
                    filters=filters,
                    kernel_size=self.kernel_size,
                    activation=None,
                    use_bias=False,
                    padding='same',
                    dilation_rate=self.dilation_rate,
                    kernel_initializer=self.initializer,
                    kernel_regularizer=self.regularizer,
                    # bias_regularizer=self.regularizer,
                    name="down_conv1_{}".format(depth + 1),
                )
                net = tf.compat.v1.layers.batch_normalization(
                    net, training=self.is_training, name="down_bn1_{}".format(depth + 1)
                )
                net = tf.nn.relu(net, name="down_relu1_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(
                    net, rate=self.drop_rate, training=self.is_training, name="down_dropout1_{}".format(depth + 1)
                )

                convs[depth] = net

                if depth < self.depths - 1:
                    net = tf.compat.v1.layers.conv2d(
                        net,
                        filters=filters,
                        kernel_size=self.kernel_size,
                        strides=self.pool_size,
                        activation=None,
                        use_bias=False,
                        padding='same',
                        # dilation_rate=self.dilation_rate,
                        kernel_initializer=self.initializer,
                        kernel_regularizer=self.regularizer,
                        # bias_regularizer=self.regularizer,
                        name="down_conv3_{}".format(depth + 1),
                    )
                    net = tf.compat.v1.layers.batch_normalization(
                        net, training=self.is_training, name="down_bn3_{}".format(depth + 1)
                    )
                    net = tf.nn.relu(net, name="down_relu3_{}".format(depth + 1))
                    net = tf.compat.v1.layers.dropout(
                        net, rate=self.drop_rate, training=self.is_training, name="down_dropout3_{}".format(depth + 1)
                    )

        # up layers
        for depth in range(self.depths - 2, -1, -1):
            with tf.compat.v1.variable_scope("UpConv_%d" % depth):
                filters = int(2 ** (depth) * self.filters_root)
                net = tf.compat.v1.layers.conv2d_transpose(
                    net,
                    filters=filters,
                    kernel_size=self.kernel_size,
                    strides=self.pool_size,
                    activation=None,
                    use_bias=False,
                    padding="same",
                    kernel_initializer=self.initializer,
                    kernel_regularizer=self.regularizer,
                    # bias_regularizer=self.regularizer,
                    name="up_conv0_{}".format(depth + 1),
                )
                net = tf.compat.v1.layers.batch_normalization(
                    net, training=self.is_training, name="up_bn0_{}".format(depth + 1)
                )
                net = tf.nn.relu(net, name="up_relu0_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(
                    net, rate=self.drop_rate, training=self.is_training, name="up_dropout0_{}".format(depth + 1)
                )

                # skip connection
                net = crop_and_concat(convs[depth], net)
                # net = crop_only(convs[depth], net)

                net = tf.compat.v1.layers.conv2d(
                    net,
                    filters=filters,
                    kernel_size=self.kernel_size,
                    activation=None,
                    use_bias=False,
                    padding='same',
                    dilation_rate=self.dilation_rate,
                    kernel_initializer=self.initializer,
                    kernel_regularizer=self.regularizer,
                    # bias_regularizer=self.regularizer,
                    name="up_conv1_{}".format(depth + 1),
                )
                net = tf.compat.v1.layers.batch_normalization(
                    net, training=self.is_training, name="up_bn1_{}".format(depth + 1)
                )
                net = tf.nn.relu(net, name="up_relu1_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(
                    net, rate=self.drop_rate, training=self.is_training, name="up_dropout1_{}".format(depth + 1)
                )

        # Output Map
        with tf.compat.v1.variable_scope("Output"):
            net = tf.compat.v1.layers.conv2d(
                net,
                filters=self.n_class,
                kernel_size=(1, 1),
                activation=None,
                use_bias=True,
                padding='same',
                # dilation_rate=self.dilation_rate,
                kernel_initializer=self.initializer,
                kernel_regularizer=self.regularizer,
                # bias_regularizer=self.regularizer,
                name="output_conv",
            )
            # net = tf.nn.relu(net,
            #                  name="output_relu")
            # net = tf.layers.dropout(net,
            #                         rate=self.drop_rate,
            #                         training=self.is_training,
            #                         name="output_dropout")
            # net = tf.layers.batch_normalization(net,
            #                                     training=self.is_training,
            #                                     name="output_bn")
            output = net

        with tf.compat.v1.variable_scope("representation"):
            self.representation = convs[-1]

        with tf.compat.v1.variable_scope("logits"):
            self.logits = output
            tmp = tf.compat.v1.summary.histogram("logits", self.logits)
            self.summary_train.append(tmp)

        with tf.compat.v1.variable_scope("preds"):
            self.preds = tf.nn.softmax(output)
            tmp = tf.compat.v1.summary.histogram("preds", self.preds)
            self.summary_train.append(tmp)

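    # add_loss_op selects the training objective from config: class-weighted
    # softmax cross-entropy (default), a soft IoU loss over the non-background
    # classes, or mean squared error between logits and labels; regularization
    # losses are added on top when weight decay is enabled.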
    def add_loss_op(self):
        if self.loss_type == "cross_entropy":
            with tf.compat.v1.variable_scope("cross_entropy"):
                flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
                flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
                if (np.array(self.class_weights) != 1).any():
                    class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
                    weight_map = tf.multiply(flat_labels, class_weights)
                    weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
                    loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
                    # loss_map = tf.nn.sigmoid_cross_entropy_with_logits(logits=flat_logits,
                    #                                                    labels=flat_labels)
                    weighted_loss = tf.multiply(loss_map, weight_map)
                    loss = tf.reduce_mean(input_tensor=weighted_loss)
                else:
                    loss = tf.reduce_mean(
                        input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
                    )
                    # loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=flat_logits,
                    #                                                               labels=flat_labels))
        elif self.loss_type == "IOU":
            with tf.compat.v1.variable_scope("IOU"):
                eps = 1e-7
                loss = 0
                for i in range(1, self.n_class):
                    intersection = eps + tf.reduce_sum(
                        input_tensor=self.preds[:, :, :, i] * self.Y[:, :, :, i], axis=[1, 2]
                    )
                    union = (
                        eps
                        + tf.reduce_sum(input_tensor=self.preds[:, :, :, i], axis=[1, 2])
                        + tf.reduce_sum(input_tensor=self.Y[:, :, :, i], axis=[1, 2])
                    )
                    loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
        elif self.loss_type == "mean_squared":
            with tf.compat.v1.variable_scope("mean_squared"):
                flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
                flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
                loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
        else:
            raise ValueError("Unknown loss function: %s" % self.loss_type)

        tmp = tf.compat.v1.summary.scalar("train_loss", loss)
        self.summary_train.append(tmp)
        tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
        self.summary_valid.append(tmp)

        if self.weight_decay > 0:
            with tf.compat.v1.name_scope('weight_loss'):
                tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
                weight_loss = tf.add_n(tmp, name="weight_loss")
            self.loss = loss + weight_loss
        else:
            self.loss = loss

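    # add_training_op attaches an exponentially decayed learning rate and
    # either a momentum or Adam optimizer; UPDATE_OPS is taken as a control
    # dependency so batch-norm statistics are refreshed on every train step.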
    def add_training_op(self):
        if self.optimizer == "momentum":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(
                learning_rate=self.learning_rate,
                global_step=self.global_step,
                decay_steps=self.decay_step,
                decay_rate=self.decay_rate,
                staircase=True,
            )
            optimizer = tf.compat.v1.train.MomentumOptimizer(
                learning_rate=self.learning_rate_node, momentum=self.momentum
            )
        elif self.optimizer == "adam":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(
                learning_rate=self.learning_rate,
                global_step=self.global_step,
                decay_steps=self.decay_step,
                decay_rate=self.decay_rate,
                staircase=True,
            )

            optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
        tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
        self.summary_train.append(tmp)

    def reset_learning_rate(self, sess, learning_rate, global_step):
        self.learning_rate = learning_rate
        assign_op = self.global_step.assign(global_step)
        sess.run(assign_op)
        if self.optimizer == "momentum":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(
                learning_rate=learning_rate,
                global_step=self.global_step,
                decay_steps=self.decay_step,
                decay_rate=self.decay_rate,
                staircase=True,
            )
            optimizer = tf.compat.v1.train.MomentumOptimizer(
                learning_rate=self.learning_rate_node, momentum=self.momentum
            )
        elif self.optimizer == "adam":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(
                learning_rate=self.learning_rate,
                global_step=self.global_step,
                decay_steps=self.decay_step,
                decay_rate=self.decay_rate,
                staircase=True,
            )

            optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)

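    # The *_on_batch helpers below run one session step each: training updates
    # the weights, validation returns loss and predictions, and testing also
    # pulls the raw signal/noise spectrograms and file names from input_batch.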
    def train_on_batch(self, sess, X_batch, Y_batch, summary_writer, drop_rate=0.0):
        feed = {self.drop_rate: drop_rate, self.is_training: True, self.X: X_batch, self.Y: Y_batch}
        _, step_summary, step, loss = sess.run(
            [self.train_op, self.summary_train, self.global_step, self.loss], feed_dict=feed
        )
        summary_writer.add_summary(step_summary, step)
        return loss

    def valid_on_batch(self, sess, X_batch, Y_batch, summary_writer, drop_rate=0.0):
        feed = {self.drop_rate: drop_rate, self.is_training: False, self.X: X_batch, self.Y: Y_batch}
        step_summary, step, loss, preds = sess.run(
            [self.summary_valid, self.global_step, self.loss, self.preds], feed_dict=feed
        )
        summary_writer.add_summary(step_summary, step)
        return loss, preds

    def test_on_batch(self, sess, summary_writer):
        feed = {self.drop_rate: 0, self.is_training: False}
        (
            step_summary,
            step,
            loss,
            preds,
            X_batch,
            Y_batch,
            ratio_batch,
            signal_batch,
            noise_batch,
            fname_batch,
        ) = sess.run(
            [
                self.summary_valid,
                self.global_step,
                self.loss,
                self.preds,
                self.X,
                self.Y,
                self.input_batch[2],
                self.input_batch[3],
                self.input_batch[4],
                self.input_batch[5],
            ],
            feed_dict=feed,
        )
        summary_writer.add_summary(step_summary, step)

        return loss, preds, X_batch, Y_batch, ratio_batch, signal_batch, noise_batch, fname_batch

    def build(self, input_batch=None, mode='train'):
        self.add_placeholders(input_batch, mode)
        self.add_prediction_op()
        if mode in ["train", "valid", "test"]:
            self.add_loss_op()
            self.add_training_op()
        # self.add_metrics_op()
        self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
        self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
        return 0
deepdenoiser/predict.py
ADDED
@@ -0,0 +1,136 @@
import argparse
import logging
import multiprocessing
import os
import time
from functools import partial

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from data_reader import DataReader_pred, normalize_batch
from model import UNet
from util import *

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


def read_args():
    """Returns args"""

    parser = argparse.ArgumentParser()

    parser.add_argument("--format", default="numpy", type=str, help="Input data format: numpy or mseed")
    parser.add_argument("--batch_size", default=20, type=int, help="Batch size")
    parser.add_argument("--output_dir", default="output", help="Output directory (default: output)")
    parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
    parser.add_argument("--sampling_rate", default=100, type=int, help="sampling rate of pred data")
    parser.add_argument("--data_dir", default="./Dataset/pred/", help="Input file directory")
    parser.add_argument("--data_list", default="./Dataset/pred.csv", help="Input csv file")
    parser.add_argument("--plot_figure", action="store_true", help="If plot figure")
    parser.add_argument("--save_signal", action="store_true", help="If save denoised signal")
    parser.add_argument("--save_noise", action="store_true", help="If save denoised noise")

    args = parser.parse_args()
    return args

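# pred_fn restores the latest checkpoint from --model_dir, pulls batches from
# the tf.data pipeline, folds the channel and station axes into the batch
# dimension, normalizes each spectrogram with normalize_batch, runs the
# network, and saves denoised results / figures through a multiprocessing pool.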
def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    if log_dir is None:
        log_dir = os.path.join(args.log_dir, "pred", current_time)
    logging.info("Pred log: %s" % log_dir)
    # logging.info("Dataset size: {}".format(data_reader.num_data))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if args.plot_figure:
        figure_dir = os.path.join(log_dir, 'figures')
        os.makedirs(figure_dir, exist_ok=True)
    if args.save_signal or args.save_noise:
        result_dir = os.path.join(log_dir, 'results')
        os.makedirs(result_dir, exist_ok=True)

    with tf.compat.v1.name_scope('Input_Batch'):
        data_batch = data_reader.dataset(args.batch_size)

    # model = UNet(input_batch=data_batch, mode='pred')
    model = UNet(mode='pred')
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:

        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        logging.info(f"restoring models: {latest_check_point}")
        saver.restore(sess, latest_check_point)

        if args.plot_figure:
            num_pool = multiprocessing.cpu_count()
        else:
            num_pool = 2
        multiprocessing.set_start_method('spawn')
        pool = multiprocessing.Pool(num_pool)
        for _ in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
            X_batch, fname_batch, t0_batch = sess.run(data_batch)
            nbt, nch, nst, nf, nt, nimg = X_batch.shape
            X_batch_ = np.reshape(X_batch, [nbt * nch * nst, nf, nt, nimg])
            X_batch_ = normalize_batch(X_batch_)
            preds_batch = sess.run(
                model.preds,
                feed_dict={model.X: X_batch_, model.drop_rate: 0, model.is_training: False},
            )
            preds_batch = np.reshape(preds_batch, [nbt, nch, nst, nf, nt, preds_batch.shape[-1]])
            # preds_batch, X_batch, ratio_batch, fname_batch = sess.run(
            #     [model.preds, data_batch[0], data_batch[1], data_batch[2]],
            #     feed_dict={model.drop_rate: 0, model.is_training: False},
            # )

            if args.save_signal or args.save_noise:
                save_results(
                    preds_batch,
                    X_batch,
                    fname=[x.decode() for x in fname_batch],
                    t0=[x.decode() for x in t0_batch],
                    save_signal=args.save_signal,
                    save_noise=args.save_noise,
                    result_dir=result_dir,
                )

            if args.plot_figure:
                pool.starmap(
                    partial(
                        plot_figures,
                        figure_dir=figure_dir,
                    ),
                    zip(preds_batch, X_batch, [x.decode() for x in fname_batch]),
                )

        pool.close()

    return 0

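# main builds a DataReader_pred over --data_dir/--data_list and runs pred_fn,
# writing logs, results, and figures under --output_dir.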
def main(args):

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    with tf.compat.v1.name_scope('create_inputs'):
        data_reader = DataReader_pred(
            format=args.format, signal_dir=args.data_dir, signal_list=args.data_list, sampling_rate=args.sampling_rate
        )
        logging.info("Dataset Size: {}".format(data_reader.n_signal))
    pred_fn(args, data_reader, log_dir=args.output_dir)

    return 0


if __name__ == '__main__':
    args = read_args()
    main(args)
deepdenoiser/train.py
ADDED
@@ -0,0 +1,557 @@
#import warnings
#warnings.filterwarnings('ignore', category=FutureWarning)
import numpy as np
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import argparse
import os
import time
import logging
from model import UNet
from data_reader import *
from util import *
from tqdm import tqdm
import multiprocessing
from functools import partial


def read_args():
    """Returns args"""

    parser = argparse.ArgumentParser()

    parser.add_argument("--mode",
                        default="train",
                        help="train/valid/test/debug (default: train)")

    parser.add_argument("--epochs",
                        default=10,
                        type=int,
                        help="Number of epochs (default: 10)")

    parser.add_argument("--batch_size",
                        default=20,
                        type=int,
                        help="Batch size (default: 20)")

    parser.add_argument("--learning_rate",
                        default=0.001,
                        type=float,
                        help="learning rate (default: 0.001)")

    parser.add_argument("--decay_step",
                        default=-1,
                        type=int,
                        help="decay step (default: -1)")

    parser.add_argument("--decay_rate",
                        default=0.9,
                        type=float,
                        help="decay rate (default: 0.9)")

    parser.add_argument("--momentum",
                        default=0.9,
                        type=float,
                        help="momentum (default: 0.9)")

    parser.add_argument("--filters_root",
                        default=8,
                        type=int,
                        help="filters root (default: 8)")

    parser.add_argument("--depth",
                        default=6,
                        type=int,
                        help="depth (default: 6)")

    parser.add_argument("--kernel_size",
                        nargs="+",
                        type=int,
                        default=[3, 3],
                        help="kernel size (default: [3, 3])")

    parser.add_argument("--pool_size",
                        nargs="+",
                        type=int,
                        default=[2, 2],
                        help="pool size (default: [2, 2])")

    parser.add_argument("--drop_rate",
                        default=0,
                        type=float,
                        help="drop out rate (default: 0)")

    parser.add_argument("--dilation_rate",
                        nargs="+",
                        type=int,
                        default=[1, 1],
                        help="dilation rate (default: [1, 1])")

    parser.add_argument("--loss_type",
                        default="cross_entropy",
                        help="loss type: cross_entropy, IOU, mean_squared (default: cross_entropy)")

    parser.add_argument("--weight_decay",
                        default=0,
                        type=float,
                        help="weight decay (default: 0)")

    parser.add_argument("--optimizer",
                        default="adam",
                        help="optimizer: adam, momentum (default: adam)")

    parser.add_argument("--summary",
                        default=True,
                        type=bool,
                        help="summary (default: True)")

    parser.add_argument("--class_weights",
                        nargs="+",
                        default=[1, 1],
                        type=float,
                        help="class weights (default: [1, 1])")

    parser.add_argument("--log_dir",
                        default="log",
                        help="Tensorboard log directory (default: log)")

    parser.add_argument("--model_dir",
                        default=None,
                        help="Checkpoint directory")

    parser.add_argument("--num_plots",
                        default=10,
                        type=int,
                        help="number of training results to plot (default: 10)")

    parser.add_argument("--input_length",
                        default=None,
                        type=int,
                        help="input length")
    parser.add_argument("--sampling_rate",
                        default=100,
                        type=int,
                        help="sampling rate of pred data in Hz (default: 100)")

    parser.add_argument("--train_signal_dir",
                        default="./Dataset/train/",
                        help="Input file directory (default: ./Dataset/train/)")
    parser.add_argument("--train_signal_list",
                        default="./Dataset/train.csv",
                        help="Input csv file (default: ./Dataset/train.csv)")
    parser.add_argument("--train_noise_dir",
                        default="./Dataset/train/",
                        help="Input file directory (default: ./Dataset/train/)")
    parser.add_argument("--train_noise_list",
                        default="./Dataset/train.csv",
                        help="Input csv file (default: ./Dataset/train.csv)")

    parser.add_argument("--valid_signal_dir",
                        default="./Dataset/",
                        help="Input file directory (default: ./Dataset/)")
    parser.add_argument("--valid_signal_list",
                        default=None,
                        help="Input csv file")
    parser.add_argument("--valid_noise_dir",
                        default="./Dataset/",
                        help="Input file directory (default: ./Dataset/)")
    parser.add_argument("--valid_noise_list",
                        default=None,
                        help="Input csv file")

    parser.add_argument("--data_dir",
                        default="./Dataset/pred/",
                        help="Input file directory (default: ./Dataset/pred/)")
    parser.add_argument("--data_list",
                        default="./Dataset/pred.csv",
                        help="Input csv file (default: ./Dataset/pred.csv)")

    parser.add_argument("--output_dir",
                        default=None,
                        help="Output directory")

    parser.add_argument("--fpred",
                        default="preds.npz",
                        help="output file name of test data")
    parser.add_argument("--plot_figure",
                        action="store_true",
                        help="If plot figure for test")
    parser.add_argument("--save_result",
                        action="store_true",
                        help="If save result for test")

    args = parser.parse_args()
    return args

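# set_config collects the input/output shapes reported by the data reader and
# the parsed command-line hyperparameters into one Config object for the UNet.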
def set_config(args, data_reader):
    config = Config()

    config.X_shape = data_reader.X_shape
    config.n_channel = config.X_shape[-1]
    config.Y_shape = data_reader.Y_shape
    config.n_class = config.Y_shape[-1]

    config.depths = args.depth
    config.filters_root = args.filters_root
    config.kernel_size = args.kernel_size
    config.pool_size = args.pool_size
    config.dilation_rate = args.dilation_rate
    config.batch_size = args.batch_size
    config.class_weights = args.class_weights
    config.loss_type = args.loss_type
    config.weight_decay = args.weight_decay
    config.optimizer = args.optimizer

    config.learning_rate = args.learning_rate
    if (args.decay_step == -1) and (args.mode == 'train'):
        config.decay_step = data_reader.n_signal // args.batch_size
    else:
        config.decay_step = args.decay_step
    config.decay_rate = args.decay_rate
    config.momentum = args.momentum

    config.summary = args.summary
    config.drop_rate = args.drop_rate
    config.class_weights = args.class_weights

    return config

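# train_fn builds the model and the queue-based input pipeline, optionally
# restores a checkpoint (resetting the learning-rate schedule), then runs the
# epoch loop: train on every batch, save a checkpoint per epoch, and, when a
# validation reader is given, compute validation loss and plot sample results.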
def train_fn(args, data_reader, data_reader_valid=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    log_dir = os.path.join(args.log_dir, current_time)
    logging.info("Training log: {}".format(log_dir))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    figure_dir = os.path.join(log_dir, 'figures')
    if not os.path.exists(figure_dir):
        os.makedirs(figure_dir)

    config = set_config(args, data_reader)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        batch = data_reader.dequeue(args.batch_size)
        if data_reader_valid is not None:
            batch_valid = data_reader_valid.dequeue(args.batch_size)

    model = UNet(config)
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:

        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        if args.model_dir is not None:
            logging.info("restoring models...")
            latest_check_point = tf.train.latest_checkpoint(args.model_dir)
            saver.restore(sess, latest_check_point)
            model.reset_learning_rate(sess, learning_rate=0.01, global_step=0)

        threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())
        if data_reader_valid is not None:
            threads_valid = data_reader_valid.start_threads(sess, n_threads=multiprocessing.cpu_count())
        flog = open(os.path.join(log_dir, 'loss.log'), 'w')

        total_step = 0
        mean_loss = 0
        pool = multiprocessing.Pool(2)
        for epoch in range(args.epochs):
            progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size), desc="{}: ".format(log_dir.split("/")[-1]))
            for step in progressbar:
                X_batch, Y_batch = sess.run(batch)
                loss_batch = model.train_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
                if epoch < 1:
                    mean_loss = loss_batch
                else:
                    total_step += 1
                    mean_loss += (loss_batch - mean_loss) / total_step
                progressbar.set_description("{}: epoch={}, loss={:.6f}, mean loss={:.6f}".format(log_dir.split("/")[-1], epoch, loss_batch, mean_loss))
                flog.write("Epoch: {}, step: {}, loss: {}, mean loss: {}\n".format(epoch, step // args.batch_size, loss_batch, mean_loss))
            saver.save(sess, os.path.join(log_dir, "model_{}.ckpt".format(epoch)))

            ## valid
            if data_reader_valid is not None:
                mean_loss_valid = 0
                total_step_valid = 0
                progressbar = tqdm(range(0, data_reader_valid.n_signal, args.batch_size), desc="Valid: ")
                for step in progressbar:
                    X_batch, Y_batch = sess.run(batch_valid)
                    loss_batch, preds_batch = model.valid_on_batch(sess, X_batch, Y_batch, summary_writer, args.drop_rate)
                    total_step_valid += 1
                    mean_loss_valid += (loss_batch - mean_loss_valid) / total_step_valid
                    progressbar.set_description("Valid: loss={:.6f}, mean loss={:.6f}".format(loss_batch, mean_loss_valid))
                    flog.write("Valid: {}, step: {}, loss: {}, mean loss: {}\n".format(epoch, step // args.batch_size, loss_batch, mean_loss_valid))

                # plot_result(epoch, args.num_plots, figure_dir, preds_batch, X_batch, Y_batch)
                pool.map(partial(plot_result_thread,
                                 epoch=epoch,
                                 preds=preds_batch,
                                 X=X_batch,
                                 Y=Y_batch,
                                 figure_dir=figure_dir),
                         range(args.num_plots))

        flog.close()
        pool.close()
        data_reader.coord.request_stop()
        if data_reader_valid is not None:
            data_reader_valid.coord.request_stop()
        try:
            data_reader.coord.join(threads, stop_grace_period_secs=10, ignore_live_threads=True)
            if data_reader_valid is not None:
                data_reader_valid.coord.join(threads_valid, stop_grace_period_secs=10, ignore_live_threads=True)
        except:
            pass
        sess.run(data_reader.queue.close(cancel_pending_enqueues=True))
        if data_reader_valid is not None:
            sess.run(data_reader_valid.queue.close(cancel_pending_enqueues=True))
    return 0

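# test_fn evaluates a restored model on the valid/test queue, tracks a running
# mean loss, and dispatches per-sample post-processing (figures and .npz
# results) to a multiprocessing pool.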
def test_fn(args, data_reader, figure_dir=None, result_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    log_dir = os.path.join(args.log_dir, args.mode, current_time)
    logging.info("{} log: {}".format(args.mode, log_dir))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if (args.plot_figure == True) and (figure_dir is None):
        figure_dir = os.path.join(log_dir, 'figures')
        if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)
    if (args.save_result == True) and (result_dir is None):
        result_dir = os.path.join(log_dir, 'results')
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

    config = set_config(args, data_reader)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        batch = data_reader.dequeue(args.batch_size)

    model = UNet(config, input_batch=batch, mode='test')
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:

        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        logging.info("restoring models...")
        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        saver.restore(sess, latest_check_point)

        threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())

        flog = open(os.path.join(log_dir, 'loss.log'), 'w')
        total_step = 0
        mean_loss = 0
        progressbar = tqdm(range(0, data_reader.n_signal, args.batch_size), desc=args.mode)
        if args.plot_figure:
            num_pool = multiprocessing.cpu_count() * 2
        elif args.save_result:
            num_pool = multiprocessing.cpu_count()
        else:
            num_pool = 2
        pool = multiprocessing.Pool(num_pool)
        for step in progressbar:

            if step + args.batch_size >= data_reader.n_signal:
                for t in threads:
                    t.join()
                sess.run(data_reader.queue.close())

            loss_batch, preds_batch, X_batch, Y_batch, ratio_batch, \
                signal_batch, noise_batch, fname_batch = model.test_on_batch(sess, summary_writer)
            total_step += 1
            mean_loss += (loss_batch - mean_loss) / total_step
            progressbar.set_description("{}: loss={:.6f}, mean loss={:.6f}".format(args.mode, loss_batch, mean_loss))
            flog.write("step: {}, loss: {}\n".format(step, loss_batch))
            flog.flush()

            pool.map(partial(postprocessing_test,
                             preds=preds_batch,
                             X=X_batch * ratio_batch[:, np.newaxis, np.newaxis, np.newaxis],
                             fname=fname_batch,
                             figure_dir=figure_dir,
                             result_dir=result_dir,
                             signal_FT=signal_batch,
                             noise_FT=noise_batch),
                     range(len(X_batch)))

        flog.close()
        pool.close()

    return 0

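# train.py keeps its own pred_fn alongside the standalone predict.py entry
# point: it restores the latest checkpoint, runs model.preds batch by batch,
# and post-processes each sample with postprocessing_pred in a worker pool.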
def pred_fn(args, data_reader, figure_dir=None, result_dir=None, log_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    if log_dir is None:
        log_dir = os.path.join(args.log_dir, "pred", current_time)
    logging.info("Pred log: %s" % log_dir)
    # logging.info("Dataset size: {}".format(data_reader.num_data))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if args.plot_figure:
        figure_dir = os.path.join(log_dir, 'figures')
        os.makedirs(figure_dir, exist_ok=True)
    if args.save_result:
        result_dir = os.path.join(log_dir, 'results')
        os.makedirs(result_dir, exist_ok=True)

    config = set_config(args, data_reader)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        data_batch = data_reader.dataset(args.batch_size)

    # model = UNet(config, input_batch=batch, mode='pred')
    model = UNet(config, mode='pred')
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:

        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        logging.info("restoring models...")
        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        saver.restore(sess, latest_check_point)

        # threads = data_reader.start_threads(sess, n_threads=multiprocessing.cpu_count())

        if args.plot_figure:
            num_pool = multiprocessing.cpu_count()
        elif args.save_result:
            num_pool = multiprocessing.cpu_count()
        else:
            num_pool = 2
        multiprocessing.set_start_method('spawn')
        pool = multiprocessing.Pool(num_pool)
        for step in tqdm(range(0, data_reader.n_signal, args.batch_size), desc="Pred"):
            # if step + args.batch_size >= data_reader.n_signal:
            #     for t in threads:
            #         t.join()
            #     sess.run(data_reader.queue.close())
            # X_batch = []
            # ratio_batch = []
            # fname_batch = []
            # for i in range(step, min(step + args.batch_size, data_reader.n_signal)):
            #     X, ratio, fname = data_reader[i]
            #     if np.std(X) == 0:
            #         continue
            #     X_batch.append(X)
            #     ratio_batch.append(ratio)
            #     fname_batch.append(fname)
            # X_batch = np.stack(X_batch, axis=0)
            # ratio_batch = np.array(ratio_batch)
            X_batch, ratio_batch, fname_batch = sess.run(data_batch)
            preds_batch = sess.run(model.preds, feed_dict={model.X: X_batch,
                                                           model.drop_rate: 0,
                                                           model.is_training: False})
            # preds_batch, X_batch, ratio_batch, fname_batch = sess.run([model.preds,
            #                                                            batch[0],
            #                                                            batch[1],
            #                                                            batch[2]],
            #                                                           feed_dict={model.drop_rate: 0,
            #                                                                      model.is_training: False})

            pool.map(partial(postprocessing_pred,
                             preds=preds_batch,
                             X=X_batch * ratio_batch[:, np.newaxis, :, np.newaxis],
                             fname=[x.decode() for x in fname_batch],
                             figure_dir=figure_dir,
                             result_dir=result_dir),
                     range(len(X_batch)))

            # for i in range(len(X_batch)):
            #     postprocessing_thread(i,
            #                           preds=preds_batch,
            #                           X=X_batch*ratio_batch[:, np.newaxis, np.newaxis, np.newaxis],
            #                           fname=fname_batch,
            #                           figure_dir=figure_dir,
            #                           result_dir=result_dir)

        pool.close()

    return 0

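# main dispatches on --mode: "train" wires up training (and optional
# validation) DataReaders, "valid"/"test" evaluate through DataReader_test,
# and "pred" runs inference with DataReader_pred.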
def main(args):

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    coord = tf.train.Coordinator()

    if args.mode == "train":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader(
                signal_dir=args.train_signal_dir,
                signal_list=args.train_signal_list,
                noise_dir=args.train_noise_dir,
                noise_list=args.train_noise_list,
                queue_size=args.batch_size * 2,
                coord=coord)
            if (args.valid_signal_list is not None) and (args.valid_noise_list is not None):
                data_reader_valid = DataReader(
                    signal_dir=args.valid_signal_dir,
                    signal_list=args.valid_signal_list,
                    noise_dir=args.valid_noise_dir,
                    noise_list=args.valid_noise_list,
                    queue_size=args.batch_size * 2,
                    coord=coord)
                logging.info("Dataset size: training %d, validation %d" % (data_reader.n_signal, data_reader_valid.n_signal))
            else:
                data_reader_valid = None
                logging.info("Dataset size: training %d, validation 0" % (data_reader.n_signal))
        train_fn(args, data_reader, data_reader_valid)

    elif args.mode == "valid" or args.mode == "test":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader_test(
                signal_dir=args.valid_signal_dir,
                signal_list=args.valid_signal_list,
                noise_dir=args.valid_noise_dir,
                noise_list=args.valid_noise_list,
                queue_size=args.batch_size * 2,
                coord=coord)
            logging.info("Dataset Size: {}".format(data_reader.n_signal))
        test_fn(args, data_reader)

    elif args.mode == "pred":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader_pred(
                signal_dir=args.data_dir,
                signal_list=args.data_list,
                sampling_rate=args.sampling_rate)
            logging.info("Dataset Size: {}".format(data_reader.n_signal))
        pred_fn(args, data_reader, log_dir=args.output_dir)

    else:
        print("mode should be: train, valid, test, debug or pred")

    coord.request_stop()
    coord.join()
    return 0


if __name__ == '__main__':
    args = read_args()
    main(args)
deepdenoiser/util.py
ADDED
@@ -0,0 +1,875 @@
import os

import matplotlib

# Select the non-interactive Agg backend before pyplot is imported so figure
# rendering works on headless servers.
matplotlib.use('agg')

import matplotlib.pyplot as plt
import numpy as np
import scipy
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
from scipy import signal
from tqdm import tqdm

from data_reader import Config

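# plot_result inverts the complex STFT inputs back to waveforms and, for each
# of the first `num` samples, saves a figure comparing the noisy input, the
# ideal mask (Y), the predicted mask, and the corresponding denoised signals.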
def plot_result(epoch, num, figure_dir, preds, X, Y, mode="valid"):
    config = Config()
    for i in range(min(num, len(X))):

        t, noisy_signal = scipy.signal.istft(
            X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
        )
        t, ideal_denoised_signal = scipy.signal.istft(
            (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0],
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )
        t, denoised_signal = scipy.signal.istft(
            (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )

        plt.figure(i)
        fig_size = plt.gcf().get_size_inches()
        plt.gcf().set_size_inches(fig_size * [1.5, 1.5])
        plt.subplot(4, 2, 1)
        plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2)
        plt.title("Noisy signal")
        plt.gca().set_xticklabels([])
        plt.subplot(4, 2, 2)
        plt.plot(t, noisy_signal, label='Noisy signal', linewidth=0.1)
        signal_ylim = plt.gca().get_ylim()
        plt.gca().set_xticklabels([])
        plt.legend(loc='lower left')
        plt.margins(x=0)

        plt.subplot(4, 2, 3)
        plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1)
        plt.gca().set_xticklabels([])
        plt.title("Y")
        plt.subplot(4, 2, 4)
        plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1)
        plt.title(r"$\hat{Y}$")
        plt.gca().set_xticklabels([])

        plt.subplot(4, 2, 5)
        plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2)
        plt.title("Ideal denoised signal")
        plt.gca().set_xticklabels([])
        plt.subplot(4, 2, 6)
        plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2)
        plt.title("Denoised signal")
        plt.gca().set_xticklabels([])

        plt.subplot(4, 2, 7)
        plt.plot(t, ideal_denoised_signal, label='Ideal denoised signal', linewidth=0.1)
        plt.ylim(signal_ylim)
        plt.xlabel("Time (s)")
        plt.legend(loc='lower left')
        plt.margins(x=0)
        plt.subplot(4, 2, 8)
        plt.plot(t, denoised_signal, label='Denoised signal', linewidth=0.1)
        plt.ylim(signal_ylim)
        plt.xlabel("Time (s)")
        plt.legend(loc='lower left')
        plt.margins(x=0)

        plt.tight_layout()
        plt.gcf().align_labels()
        plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight')
        # plt.savefig(os.path.join(figure_dir, "epoch%03d_%03d.pdf" % (epoch, i)), bbox_inches='tight')
        plt.close(i)
    return 0

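# plot_result_thread is the per-sample variant of plot_result, written so it
# can be mapped over sample indices with a multiprocessing.Pool.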
def plot_result_thread(i, epoch, preds, X, Y, figure_dir, mode="valid"):
    config = Config()
    t, noisy_signal = scipy.signal.istft(
        X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
    )
    t, ideal_denoised_signal = scipy.signal.istft(
        (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0],
        fs=config.fs,
        nperseg=config.nperseg,
        nfft=config.nfft,
        boundary='zeros',
    )
    t, denoised_signal = scipy.signal.istft(
        (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
        fs=config.fs,
        nperseg=config.nperseg,
        nfft=config.nfft,
        boundary='zeros',
    )

    plt.figure(i)
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1.5, 1.5])
    plt.subplot(4, 2, 1)
    plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2)
    plt.title("Noisy signal")
    plt.gca().set_xticklabels([])
    plt.subplot(4, 2, 2)
    plt.plot(t, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
    signal_ylim = plt.gca().get_ylim()
    plt.gca().set_xticklabels([])
    plt.legend(loc='lower left')
    plt.margins(x=0)

    plt.subplot(4, 2, 3)
    plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1)
    plt.gca().set_xticklabels([])
    plt.title("Y")
    plt.subplot(4, 2, 4)
    plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1)
    plt.title(r"$\hat{Y}$")
    plt.gca().set_xticklabels([])

    plt.subplot(4, 2, 5)
    plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2)
    plt.title("Ideal denoised signal")
    plt.gca().set_xticklabels([])
    plt.subplot(4, 2, 6)
    plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2)
    plt.title("Denoised signal")
    plt.gca().set_xticklabels([])

    plt.subplot(4, 2, 7)
    plt.plot(t, ideal_denoised_signal, 'k', label='Ideal denoised signal', linewidth=0.5)
    plt.ylim(signal_ylim)
    plt.xlabel("Time (s)")
    plt.legend(loc='lower left')
    plt.margins(x=0)
    plt.subplot(4, 2, 8)
    plt.plot(t, denoised_signal, 'k', label='Denoised signal', linewidth=0.5)
    plt.ylim(signal_ylim)
    plt.xlabel("Time (s)")
    plt.legend(loc='lower left')
    plt.margins(x=0)

    plt.tight_layout()
    plt.gcf().align_labels()
    plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight')
    plt.close(i)
    return 0

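# postprocessing_test reconstructs waveforms from the masked spectrograms via
# inverse STFT, saves them to .npz when result_dir is given, and renders
# spectrogram and waveform comparison figures when figure_dir is given.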
def postprocessing_test(
    i, preds, X, fname, figure_dir=None, result_dir=None, signal_FT=None, noise_FT=None, data_dir=None
):
    if (figure_dir is not None) or (result_dir is not None):
        config = Config()
        t1, noisy_signal = scipy.signal.istft(
            X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
        )
        t1, denoised_signal = scipy.signal.istft(
            (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )
        t1, denoised_noise = scipy.signal.istft(
            (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * (1 - preds[i, :, :, 0]),
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )
        t1, signal = scipy.signal.istft(
            signal_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
        )
        t1, noise = scipy.signal.istft(
            noise_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros'
        )

    if result_dir is not None:
        try:
            np.savez(
                os.path.join(result_dir, fname[i].decode()),
                preds=preds[i],
                X=X[i],
                signal_FT=signal_FT[i],
                noise_FT=noise_FT[i],
                noisy_signal=noisy_signal,
                denoised_signal=denoised_signal,
                denoised_noise=denoised_noise,
                signal=signal,
                noise=noise,
            )
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
            np.savez(
                os.path.join(result_dir, fname[i].decode()),
                preds=preds[i],
                X=X[i],
                signal_FT=signal_FT[i],
                noise_FT=noise_FT[i],
                noisy_signal=noisy_signal,
                denoised_signal=denoised_signal,
                denoised_noise=denoised_noise,
                signal=signal,
                noise=noise,
            )

    if figure_dir is not None:
        t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
        f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])

        raw_data = None
        if data_dir is not None:
            raw_data = np.load(os.path.join(data_dir, fname[i].decode().split('/')[-1]))
            itp = raw_data['itp']
            its = raw_data['its']
            ix1 = (750 - 50) / 100
            ix2 = (750 + (its - itp) + 50) / 100
            if ix2 - ix1 > 3:
                ix2 = ix1 + 3

        box = dict(boxstyle='round', facecolor='white', alpha=1)

        text_loc = [0.05, 0.8]
        plt.figure(i)
        fig_size = plt.gcf().get_size_inches()
        plt.gcf().set_size_inches(fig_size * [1, 2])
        plt.subplot(511)
        plt.pcolormesh(t_FT, f_FT, np.abs(signal_FT[i, :, :]), vmin=0, vmax=1)
        plt.gca().set_xticklabels([])
        plt.text(
            text_loc[0],
            text_loc[1],
            '(i)',
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )
        plt.subplot(512)
        plt.pcolormesh(t_FT, f_FT, np.abs(noise_FT[i, :, :]), vmin=0, vmax=1)
        plt.gca().set_xticklabels([])
        plt.text(
            text_loc[0],
            text_loc[1],
            '(ii)',
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )
        plt.subplot(513)
        plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=1)
        plt.ylabel("Frequency (Hz)", fontsize='large')
        plt.gca().set_xticklabels([])
        plt.text(
            text_loc[0],
            text_loc[1],
            '(iii)',
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )
        plt.subplot(514)
        plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=1)
        plt.gca().set_xticklabels([])
        plt.text(
            text_loc[0],
            text_loc[1],
            '(iv)',
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )
        plt.subplot(515)
        plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1], vmin=0, vmax=1)
        plt.xlabel("Time (s)", fontsize='large')
        plt.text(
            text_loc[0],
            text_loc[1],
            '(v)',
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

        try:
            plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight')
            # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.pdf'), bbox_inches='tight')
        except FileNotFoundError:
            os.makedirs(
                os.path.dirname(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png')), exist_ok=True
            )
            plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight')
            # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.pdf'), bbox_inches='tight')
        plt.close(i)

        text_loc = [0.05, 0.8]
        plt.figure(i)
        fig_size = plt.gcf().get_size_inches()
        plt.gcf().set_size_inches(fig_size * [1, 2])

        ax3 = plt.subplot(513)
        plt.plot(t1, noisy_signal, 'k', linewidth=0.5, label='Noisy signal')
        plt.legend(loc='lower left', fontsize='medium')
        plt.xlim([np.around(t1[0]), np.around(t1[-1])])
        plt.ylim([-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))])
        signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))]
        plt.ylim(signal_ylim)
        plt.ylabel("Amplitude", fontsize='large')
        plt.gca().set_xticklabels([])
        plt.text(
            text_loc[0],
            text_loc[1],
            '(iii)',
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

        ax1 = plt.subplot(511)
        plt.plot(t1, signal, 'k', linewidth=0.5, label='Signal')
        plt.legend(loc='lower left', fontsize='medium')
        plt.xlim([np.around(t1[0]), np.around(t1[-1])])
        plt.ylim(signal_ylim)
        plt.gca().set_xticklabels([])
        plt.text(
            text_loc[0],
            text_loc[1],
            '(i)',
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

        plt.subplot(512)
        plt.plot(t1, noise, 'k', linewidth=0.5, label='Noise')
        plt.legend(loc='lower left', fontsize='medium')
        plt.xlim([np.around(t1[0]), np.around(t1[-1])])
        plt.ylim([-np.max(np.abs(noise)), np.max(np.abs(noise))])
        noise_ylim = [-np.max(np.abs(noise[100:-100])), np.max(np.abs(noise[100:-100]))]
|
367 |
+
plt.ylim(noise_ylim)
|
368 |
+
plt.gca().set_xticklabels([])
|
369 |
+
plt.text(
|
370 |
+
text_loc[0],
|
371 |
+
text_loc[1],
|
372 |
+
'(ii)',
|
373 |
+
horizontalalignment='center',
|
374 |
+
transform=plt.gca().transAxes,
|
375 |
+
fontsize="medium",
|
376 |
+
fontweight="bold",
|
377 |
+
bbox=box,
|
378 |
+
)
|
379 |
+
|
380 |
+
ax4 = plt.subplot(514)
|
381 |
+
plt.plot(t1, denoised_signal, 'k', linewidth=0.5, label='Recovered signal')
|
382 |
+
plt.legend(loc='lower left', fontsize='medium')
|
383 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
384 |
+
plt.ylim(signal_ylim)
|
385 |
+
plt.gca().set_xticklabels([])
|
386 |
+
plt.text(
|
387 |
+
text_loc[0],
|
388 |
+
text_loc[1],
|
389 |
+
'(iv)',
|
390 |
+
horizontalalignment='center',
|
391 |
+
transform=plt.gca().transAxes,
|
392 |
+
fontsize="medium",
|
393 |
+
fontweight="bold",
|
394 |
+
bbox=box,
|
395 |
+
)
|
396 |
+
|
397 |
+
plt.subplot(515)
|
398 |
+
plt.plot(t1, denoised_noise, 'k', linewidth=0.5, label='Recovered noise')
|
399 |
+
plt.legend(loc='lower left', fontsize='medium')
|
400 |
+
plt.xlim([np.around(t1[0]), np.around(t1[-1])])
|
401 |
+
plt.xlabel("Time (s)", fontsize='large')
|
402 |
+
plt.ylim(noise_ylim)
|
403 |
+
plt.text(
|
404 |
+
text_loc[0],
|
405 |
+
text_loc[1],
|
406 |
+
'(v)',
|
407 |
+
horizontalalignment='center',
|
408 |
+
transform=plt.gca().transAxes,
|
409 |
+
fontsize="medium",
|
410 |
+
fontweight="bold",
|
411 |
+
bbox=box,
|
412 |
+
)
|
413 |
+
|
414 |
+
if data_dir is not None:
|
415 |
+
axins = inset_axes(
|
416 |
+
ax1, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax1.transAxes
|
417 |
+
)
|
418 |
+
axins.plot(t1, signal, 'k', linewidth=0.5)
|
419 |
+
x1, x2 = ix1, ix2
|
420 |
+
y1 = -np.max(np.abs(signal[(t1 > ix1) & (t1 < ix2)]))
|
421 |
+
y2 = -y1
|
422 |
+
axins.set_xlim(x1, x2)
|
423 |
+
axins.set_ylim(y1, y2)
|
424 |
+
plt.xticks(visible=False)
|
425 |
+
plt.yticks(visible=False)
|
426 |
+
mark_inset(ax1, axins, loc1=1, loc2=3, fc="none", ec="0.5")
|
427 |
+
|
428 |
+
axins = inset_axes(
|
429 |
+
ax3, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.3), bbox_transform=ax3.transAxes
|
430 |
+
)
|
431 |
+
axins.plot(t1, noisy_signal, 'k', linewidth=0.5)
|
432 |
+
x1, x2 = ix1, ix2
|
433 |
+
axins.set_xlim(x1, x2)
|
434 |
+
axins.set_ylim(y1, y2)
|
435 |
+
plt.xticks(visible=False)
|
436 |
+
plt.yticks(visible=False)
|
437 |
+
mark_inset(ax3, axins, loc1=1, loc2=3, fc="none", ec="0.5")
|
438 |
+
|
439 |
+
axins = inset_axes(
|
440 |
+
ax4, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax4.transAxes
|
441 |
+
)
|
442 |
+
axins.plot(t1, denoised_signal, 'k', linewidth=0.5)
|
443 |
+
x1, x2 = ix1, ix2
|
444 |
+
axins.set_xlim(x1, x2)
|
445 |
+
axins.set_ylim(y1, y2)
|
446 |
+
plt.xticks(visible=False)
|
447 |
+
plt.yticks(visible=False)
|
448 |
+
mark_inset(ax4, axins, loc1=1, loc2=3, fc="none", ec="0.5")
|
449 |
+
|
450 |
+
plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_wave.png'), bbox_inches='tight')
|
451 |
+
# plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_wave.pdf'), bbox_inches='tight')
|
452 |
+
plt.close(i)
|
453 |
+
|
454 |
+
return
|
455 |
+
|
456 |
+
|
457 |
+
def postprocessing_pred(i, preds, X, fname, figure_dir=None, result_dir=None):

    if (result_dir is not None) or (figure_dir is not None):
        config = Config()

        # Invert the noisy spectrogram and the masked signal/noise spectrograms back to waveforms.
        t1, noisy_signal = scipy.signal.istft(
            (X[i, :, :, 0] + X[i, :, :, 1] * 1j),
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )
        t1, denoised_signal = scipy.signal.istft(
            (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )
        t1, denoised_noise = scipy.signal.istft(
            (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1],
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )

        if result_dir is not None:
            try:
                np.savez(
                    os.path.join(result_dir, fname[i]),
                    noisy_signal=noisy_signal,
                    denoised_signal=denoised_signal,
                    denoised_noise=denoised_noise,
                    t=t1,
                )
            except FileNotFoundError:
                os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i])), exist_ok=True)
                np.savez(
                    os.path.join(result_dir, fname[i]),
                    noisy_signal=noisy_signal,
                    denoised_signal=denoised_signal,
                    denoised_noise=denoised_noise,
                    t=t1,
                )

        if figure_dir is not None:

            t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
            f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])

            box = dict(boxstyle='round', facecolor='white', alpha=1)
            text_loc = [0.05, 0.77]

            # Three-panel spectrogram figure: noisy input, recovered signal, recovered noise.
            plt.figure(i)
            fig_size = plt.gcf().get_size_inches()
            plt.gcf().set_size_inches(fig_size * [1, 1.2])
            vmax = np.std(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j)) * 1.8

            plt.subplot(311)
            plt.pcolormesh(
                t_FT,
                f_FT,
                np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j),
                vmin=0,
                vmax=vmax,
                shading='auto',
                label='Noisy signal',
            )
            plt.gca().set_xticklabels([])
            plt.text(
                text_loc[0],
                text_loc[1],
                '(i)',
                horizontalalignment='center',
                transform=plt.gca().transAxes,
                fontsize="medium",
                fontweight="bold",
                bbox=box,
            )
            plt.subplot(312)
            plt.pcolormesh(
                t_FT,
                f_FT,
                np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0],
                vmin=0,
                vmax=vmax,
                shading='auto',
                label='Recovered signal',
            )
            plt.gca().set_xticklabels([])
            plt.ylabel("Frequency (Hz)", fontsize='large')
            plt.text(
                text_loc[0],
                text_loc[1],
                '(ii)',
                horizontalalignment='center',
                transform=plt.gca().transAxes,
                fontsize="medium",
                fontweight="bold",
                bbox=box,
            )
            plt.subplot(313)
            plt.pcolormesh(
                t_FT,
                f_FT,
                np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1],
                vmin=0,
                vmax=vmax,
                shading='auto',
                label='Recovered noise',
            )
            plt.xlabel("Time (s)", fontsize='large')
            plt.text(
                text_loc[0],
                text_loc[1],
                '(iii)',
                horizontalalignment='center',
                transform=plt.gca().transAxes,
                fontsize="medium",
                fontweight="bold",
                bbox=box,
            )

            try:
                plt.savefig(os.path.join(figure_dir, os.path.splitext(fname[i])[0] + '_FT.png'), bbox_inches='tight')
                # plt.savefig(os.path.join(figure_dir, os.path.splitext(fname[i].split('/')[-1])[0]+'_FT.pdf'), bbox_inches='tight')
            except FileNotFoundError:
                os.makedirs(os.path.dirname(os.path.join(figure_dir, os.path.splitext(fname[i])[0] + '_FT.png')), exist_ok=True)
                plt.savefig(os.path.join(figure_dir, os.path.splitext(fname[i])[0] + '_FT.png'), bbox_inches='tight')
                # plt.savefig(os.path.join(figure_dir, os.path.splitext(fname[i].split('/')[-1])[0]+'_FT.pdf'), bbox_inches='tight')
            plt.close(i)

            # Three-panel waveform figure: noisy input, recovered signal, recovered noise.
            plt.figure(i)
            fig_size = plt.gcf().get_size_inches()
            plt.gcf().set_size_inches(fig_size * [1, 1.2])

            ax4 = plt.subplot(311)
            plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
            plt.xlim([np.around(t1[0]), np.around(t1[-1])])
            signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))]
            plt.ylim(signal_ylim)
            plt.gca().set_xticklabels([])
            plt.legend(loc='lower left', fontsize='medium')
            plt.text(
                text_loc[0],
                text_loc[1],
                '(i)',
                horizontalalignment='center',
                transform=plt.gca().transAxes,
                fontsize="medium",
                fontweight="bold",
                bbox=box,
            )

            ax5 = plt.subplot(312)
            plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5)
            plt.xlim([np.around(t1[0]), np.around(t1[-1])])
            plt.ylim(signal_ylim)
            plt.gca().set_xticklabels([])
            plt.ylabel("Amplitude", fontsize='large')
            plt.legend(loc='lower left', fontsize='medium')
            plt.text(
                text_loc[0],
                text_loc[1],
                '(ii)',
                horizontalalignment='center',
                transform=plt.gca().transAxes,
                fontsize="medium",
                fontweight="bold",
                bbox=box,
            )

            plt.subplot(313)
            plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5)
            plt.xlim([np.around(t1[0]), np.around(t1[-1])])
            plt.ylim(signal_ylim)
            plt.xlabel("Time (s)", fontsize='large')
            plt.legend(loc='lower left', fontsize='medium')
            plt.text(
                text_loc[0],
                text_loc[1],
                '(iii)',
                horizontalalignment='center',
                transform=plt.gca().transAxes,
                fontsize="medium",
                fontweight="bold",
                bbox=box,
            )

            plt.savefig(os.path.join(figure_dir, os.path.splitext(fname[i])[0] + '_wave.png'), bbox_inches='tight')
            # plt.savefig(os.path.join(figure_dir, os.path.splitext(fname[i])[0]+'_wave.pdf'), bbox_inches='tight')
            plt.close(i)

    return


def save_results(mask, X, fname, t0, save_signal=True, save_noise=True, result_dir="results"):

    config = Config()

    # Apply the predicted signal/noise masks to the complex spectrogram and invert the STFT.
    if save_signal:
        _, denoised_signal = scipy.signal.istft(
            (X[..., 0] + X[..., 1] * 1j) * mask[..., 0],
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )  # nbt, nch, nst, nt
        denoised_signal = np.transpose(denoised_signal, [0, 3, 2, 1])  # nbt, nt, nst, nch
    if save_noise:
        _, denoised_noise = scipy.signal.istft(
            (X[..., 0] + X[..., 1] * 1j) * mask[..., 1],
            fs=config.fs,
            nperseg=config.nperseg,
            nfft=config.nfft,
            boundary='zeros',
        )
        denoised_noise = np.transpose(denoised_noise, [0, 3, 2, 1])

    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    for i in range(len(X)):
        np.savez(
            os.path.join(result_dir, fname[i]),
            data=denoised_signal[i] if save_signal else None,
            noise=denoised_noise[i] if save_noise else None,
            t0=t0[i],
        )


def plot_figures(mask, X, fname, figure_dir="figures"):

    config = Config()

    # plot the last channel
    mask = mask[-1, -1, ...]  # nch, nst, nf, nt, 2 => nf, nt, 2
    X = X[-1, -1, ...]

    t1, noisy_signal = scipy.signal.istft(
        (X[..., 0] + X[..., 1] * 1j),
        fs=config.fs,
        nperseg=config.nperseg,
        nfft=config.nfft,
        boundary='zeros',
    )
    t1, denoised_signal = scipy.signal.istft(
        (X[..., 0] + X[..., 1] * 1j) * mask[..., 0],
        fs=config.fs,
        nperseg=config.nperseg,
        nfft=config.nfft,
        boundary='zeros',
    )
    t1, denoised_noise = scipy.signal.istft(
        (X[..., 0] + X[..., 1] * 1j) * mask[..., 1],
        fs=config.fs,
        nperseg=config.nperseg,
        nfft=config.nfft,
        boundary='zeros',
    )

    if not os.path.exists(figure_dir):
        os.makedirs(figure_dir)

    t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[1])
    f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[0])

    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    # Three-panel spectrogram figure: noisy input, recovered signal, recovered noise.
    plt.figure()
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 1.2])
    vmax = np.std(np.abs(X[:, :, 0] + X[:, :, 1] * 1j)) * 1.8

    plt.subplot(311)
    plt.pcolormesh(
        t_FT,
        f_FT,
        np.abs(X[:, :, 0] + X[:, :, 1] * 1j),
        vmin=0,
        vmax=vmax,
        shading='auto',
        label='Noisy signal',
    )
    plt.gca().set_xticklabels([])
    plt.text(
        text_loc[0],
        text_loc[1],
        '(i)',
        horizontalalignment='center',
        transform=plt.gca().transAxes,
        fontsize="medium",
        fontweight="bold",
        bbox=box,
    )
    plt.subplot(312)
    plt.pcolormesh(
        t_FT,
        f_FT,
        np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 0],
        vmin=0,
        vmax=vmax,
        shading='auto',
        label='Recovered signal',
    )
    plt.gca().set_xticklabels([])
    plt.ylabel("Frequency (Hz)", fontsize='large')
    plt.text(
        text_loc[0],
        text_loc[1],
        '(ii)',
        horizontalalignment='center',
        transform=plt.gca().transAxes,
        fontsize="medium",
        fontweight="bold",
        bbox=box,
    )
    plt.subplot(313)
    plt.pcolormesh(
        t_FT,
        f_FT,
        np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 1],
        vmin=0,
        vmax=vmax,
        shading='auto',
        label='Recovered noise',
    )
    plt.xlabel("Time (s)", fontsize='large')
    plt.text(
        text_loc[0],
        text_loc[1],
        '(iii)',
        horizontalalignment='center',
        transform=plt.gca().transAxes,
        fontsize="medium",
        fontweight="bold",
        bbox=box,
    )

    try:
        plt.savefig(os.path.join(figure_dir, os.path.splitext(fname)[0] + '_FT.png'), bbox_inches='tight')
        # plt.savefig(os.path.join(figure_dir, os.path.splitext(fname.split('/')[-1])[0]+'_FT.pdf'), bbox_inches='tight')
    except FileNotFoundError:
        os.makedirs(os.path.dirname(os.path.join(figure_dir, os.path.splitext(fname)[0] + '_FT.png')), exist_ok=True)
        plt.savefig(os.path.join(figure_dir, os.path.splitext(fname)[0] + '_FT.png'), bbox_inches='tight')
        # plt.savefig(os.path.join(figure_dir, os.path.splitext(fname.split('/')[-1])[0]+'_FT.pdf'), bbox_inches='tight')
    plt.close()

    # Three-panel waveform figure: noisy input, recovered signal, recovered noise.
    plt.figure()
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 1.2])

    ax4 = plt.subplot(311)
    plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
    plt.xlim([np.around(t1[0]), np.around(t1[-1])])
    signal_ylim = [-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))]
    if signal_ylim[0] != signal_ylim[1]:  # guard against an all-zero trace
        plt.ylim(signal_ylim)
    plt.gca().set_xticklabels([])
    plt.legend(loc='lower left', fontsize='medium')
    plt.text(
        text_loc[0],
        text_loc[1],
        '(i)',
        horizontalalignment='center',
        transform=plt.gca().transAxes,
        fontsize="medium",
        fontweight="bold",
        bbox=box,
    )

    ax5 = plt.subplot(312)
    plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5)
    plt.xlim([np.around(t1[0]), np.around(t1[-1])])
    if signal_ylim[0] != signal_ylim[1]:
        plt.ylim(signal_ylim)
    plt.gca().set_xticklabels([])
    plt.ylabel("Amplitude", fontsize='large')
    plt.legend(loc='lower left', fontsize='medium')
    plt.text(
        text_loc[0],
        text_loc[1],
        '(ii)',
        horizontalalignment='center',
        transform=plt.gca().transAxes,
        fontsize="medium",
        fontweight="bold",
        bbox=box,
    )

    plt.subplot(313)
    plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5)
    plt.xlim([np.around(t1[0]), np.around(t1[-1])])
    if signal_ylim[0] != signal_ylim[1]:
        plt.ylim(signal_ylim)
    plt.xlabel("Time (s)", fontsize='large')
    plt.legend(loc='lower left', fontsize='medium')
    plt.text(
        text_loc[0],
        text_loc[1],
        '(iii)',
        horizontalalignment='center',
        transform=plt.gca().transAxes,
        fontsize="medium",
        fontweight="bold",
        bbox=box,
    )

    plt.savefig(os.path.join(figure_dir, os.path.splitext(fname)[0] + '_wave.png'), bbox_inches='tight')
    # plt.savefig(os.path.join(figure_dir, os.path.splitext(fname)[0]+'_wave.pdf'), bbox_inches='tight')
    plt.close()

    return


if __name__ == "__main__":
    pass
docs/README.md
ADDED
@@ -0,0 +1,60 @@
---
title: DeepDenoiser
emoji: 🌊
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
---

# DeepDenoiser: Seismic Signal Denoising and Decomposition Using Deep Neural Networks

[Documentation](https://ai4eps.github.io/DeepDenoiser)

## 1. Install [miniconda](https://docs.conda.io/en/latest/miniconda.html) and requirements
- Download the DeepDenoiser repository
```bash
git clone https://github.com/wayneweiqiang/DeepDenoiser.git
cd DeepDenoiser
```
- Install to the default environment
```bash
conda env update -f=env.yml -n base
```
- Install to a "deepdenoiser" virtual environment
```bash
conda env create -f env.yml
conda activate deepdenoiser
```

## 2. Pre-trained model
Located in directory: **model/190614-104802**

## 3. Related papers
- Zhu, Weiqiang, S. Mostafa Mousavi, and Gregory C. Beroza. "Seismic Signal Denoising and Decomposition Using Deep Neural Networks." arXiv preprint arXiv:1811.02695 (2018).

## 4. Interactive example
See details in the [notebook](https://github.com/wayneweiqiang/DeepDenoiser/blob/master/docs/example_interactive.ipynb): [example_interactive.ipynb](example_interactive.ipynb)
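The notebook's core denoising step mirrors `postprocessing_pred` in `deepdenoiser/app.py` above: the network predicts a mask over the noisy waveform's STFT, and multiplying the complex STFT by that mask before inverting it recovers the denoised trace. Here is a minimal sketch with a random stand-in mask; the `fs`/`nperseg`/`nfft` values are illustrative placeholders, not the values from the repo's `Config`:

```python
import numpy as np
import scipy.signal

fs, nperseg, nfft = 100, 30, 60  # placeholder STFT settings; the repo reads these from Config()

noisy = np.random.randn(3001)  # a synthetic 30-s trace sampled at 100 Hz

f, t, Zxx = scipy.signal.stft(noisy, fs=fs, nperseg=nperseg, nfft=nfft, boundary='zeros')

# Stand-in for the network output: a signal mask in [0, 1] on the (freq, time) grid.
signal_mask = np.random.rand(*Zxx.shape)

# Masked inverse STFTs, as in postprocessing_pred above.
_, denoised_signal = scipy.signal.istft(Zxx * signal_mask, fs=fs, nperseg=nperseg, nfft=nfft)
_, recovered_noise = scipy.signal.istft(Zxx * (1 - signal_mask), fs=fs, nperseg=nperseg, nfft=nfft)
```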

## 5. Batch prediction
See details in the [notebook](https://github.com/wayneweiqiang/DeepDenoiser/blob/master/docs/example_batch_prediction.ipynb): [example_batch_prediction.ipynb](example_batch_prediction.ipynb)
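The batch pipeline writes one npz file per input trace. As a rough sketch of inspecting an output file, assuming it was written by `save_results` in `deepdenoiser/app.py` above (the path is hypothetical, and the notebook's own output layout may differ):

```python
import numpy as np

out = np.load("results/demo_000.npz", allow_pickle=True)  # hypothetical output path

denoised = out["data"]  # denoised waveform, shape (nt, nst, nch); None if save_signal=False
noise = out["noise"]    # recovered noise, same shape; None if save_noise=False
t0 = out["t0"]          # trace start time carried through from the input
```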
## 6. Train
### Data format

Required: two csv files (one listing the signal samples, one listing the noise samples) and the corresponding directories of npz files.

Each csv file contains the columns "fname", "itp", and "channels".

Each npz file contains the variables "data", "itp", and "channels".

The "data" variable has a shape of 9001 x 3 (time samples x channels).

The "itp" variable is the sample index of the first P arrival.

Note: In the demo data, for simplicity we use the waveform before itp as noise samples, so train_noise_list is the same as train_signal_list here. A sketch of generating data in this format follows.
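For reference, a minimal sketch that writes one training sample in this format. The file name, channel codes, P-arrival index, and the exact csv encoding of "channels" are illustrative assumptions, not the project's canonical values:

```python
import os
import numpy as np
import pandas as pd

os.makedirs("Dataset/train", exist_ok=True)

# One synthetic 3-component record of 9001 samples, with an assumed P arrival at sample 3000.
data = np.random.randn(9001, 3).astype(np.float32)
itp = 3000
channels = ["HHE", "HHN", "HHZ"]  # hypothetical channel codes

np.savez("Dataset/train/demo_000.npz", data=data, itp=itp, channels=channels)

# One csv row per npz file, with the columns listed above.
df = pd.DataFrame({"fname": ["demo_000.npz"], "itp": [itp], "channels": ["_".join(channels)]})
df.to_csv("Dataset/train.csv", index=False)
```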
Training can then be started with:

~~~bash
python deepdenoiser/train.py --mode=train --train_signal_dir=./Dataset/train --train_signal_list=./Dataset/train.csv --train_noise_dir=./Dataset/train --train_noise_list=./Dataset/train.csv --batch_size=20
~~~

Please let us know of any bugs found in the code. Suggestions and collaborations are welcome.
docs/example_batch_prediction.ipynb
ADDED
The diff for this file is too large to render.
docs/example_interactive.ipynb
ADDED
The diff for this file is too large to render.
env.yml
ADDED
@@ -0,0 +1,19 @@
name: deepdenoiser
channels:
  - defaults
  - conda-forge
dependencies:
  - python=3.7
  - numpy
  - scipy
  - matplotlib
  - pandas
  - scikit-learn
  - tqdm
  - obspy
  - uvicorn
  - fastapi
  - kafka-python
  - tensorflow
mkdocs.yml
ADDED
@@ -0,0 +1,18 @@
site_name: "DeepDenoiser"
site_description: 'DeepDenoiser: Seismic Signal Denoising and Decomposition Using Deep Neural Networks'
site_author: 'Weiqiang Zhu'
docs_dir: docs/
repo_name: 'wayneweiqiang/DeepDenoiser'
repo_url: 'https://github.com/wayneweiqiang/DeepDenoiser'
nav:
  - Overview: README.md
  - Interactive Example: example_interactive.ipynb
  - Batch Prediction: example_batch_prediction.ipynb
theme:
  name: 'material'
plugins:
  - mkdocs-jupyter
extra:
  analytics:
    provider: google
    property: G-FMMP8CQRDZ
requirements.txt
ADDED
@@ -0,0 +1,5 @@
tensorflow
matplotlib
scipy
pandas
tqdm