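"""FastAPI service wrapping the DeepDenoiser UNet model.

On startup the module restores a pretrained TensorFlow 1.x checkpoint. The /predict
endpoint converts incoming 3-component waveforms to STFT spectrograms, predicts a
time-frequency signal mask with the UNet, and returns the inverse-STFT (denoised)
waveforms. A Kafka producer is also set up for optional downstream streaming.
"""
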
import os
from json import dumps
from typing import Any, AnyStr, Dict, List, Union

import numpy as np
import scipy.signal
import tensorflow as tf
from fastapi import FastAPI
from kafka import KafkaProducer
from model import UNet
from pydantic import BaseModel
from scipy.interpolate import interp1d

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

app = FastAPI()
X_SHAPE = [3000, 1, 3]  # [nt, 1, n_channels]: 3000 samples (30 s at 100 Hz), 3 components
SAMPLING_RATE = 100  # Hz

# Build the UNet graph in prediction mode and restore the pretrained checkpoint below.
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True  # allocate GPU memory on demand

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)

# Kafka producer: try the in-cluster (k8s) broker first, then a local broker;
# if neither is reachable, use_kafka stays False.
use_kafka = False
# BROKER_URL = 'localhost:9092'
# BROKER_URL = 'my-kafka-headless:9092'

try:
    print("Connecting to k8s kafka")
    BROKER_URL = "quakeflow-kafka-headless:9092"
    producer = KafkaProducer(
        bootstrap_servers=[BROKER_URL],
        key_serializer=lambda x: dumps(x).encode("utf-8"),
        value_serializer=lambda x: dumps(x).encode("utf-8"),
    )
    use_kafka = True
    print("k8s kafka connection success!")
except Exception:
    print("k8s Kafka connection error")
    try:
        print("Connecting to local kafka")
        producer = KafkaProducer(
            bootstrap_servers=["localhost:9092"],
            key_serializer=lambda x: dumps(x).encode("utf-8"),
            value_serializer=lambda x: dumps(x).encode("utf-8"),
        )
        use_kafka = True
        print("local kafka connection success!")
    except Exception:
        print("local Kafka connection error")


def normalize_batch(data, window=200):
    """
    data: nbn, nf, nt, 2
    """
    assert len(data.shape) == 4
    shift = window // 2
    nbt, nf, nt, nimg = data.shape

    ## mean/std in sliding windows
    data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
    t = np.arange(0, nt + shift - 1, shift, dtype="int")  # 201 => 0, 100, 200
    # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
    std = np.zeros([nbt, len(t)])
    mean = np.zeros([nbt, len(t)])
    for i in range(std.shape[1]):
        std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
        mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))

    std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
    std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]

    ## normalize data with interpolated mean/std
    t_interp = np.arange(nt, dtype="int")
    std_interp = interp1d(t, std, kind="slinear")(t_interp)
    std_interp[std_interp == 0] = 1.0
    mean_interp = interp1d(t, mean, kind="slinear")(t_interp)

    data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]

    if len(t) > 3:  ## need to address this normalization issue in training
        data /= 2.0

    return data
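
# Shape sketch (assumed values for a 30 s, 3-component, 100 Hz trace): the STFT in
# get_prediction() with nperseg=30 and nfft=60 yields roughly [batch * 3, 31, 201, 2],
# and normalize_batch() preserves that shape, e.g.
#   normalize_batch(np.zeros([3, 31, 201, 2])).shape == (3, 31, 201, 2)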


def get_prediction(meta):
    """Denoise the waveforms in `meta` by masking their STFT with the UNet prediction."""
    FS = 100  # target sampling rate (Hz)
    NPERSEG = 30  # STFT segment length (samples)
    NFFT = 60  # FFT length

    vec = np.array(meta.vec)  # [batch, nt, chn]
    nbt, nt, nch = vec.shape
    vec = np.transpose(vec, [0, 2, 1])  # [batch, chn, nt]
    vec = np.reshape(vec, [nbt * nch, nt])  ## [batch * chn, nt]

    if np.mod(vec.shape[-1], 3000) == 1:  # 3001=>3000
        vec = vec[..., :-1]

    if meta.dt != 0.01:
        # resample along the time axis to the 100 Hz target rate
        t = np.linspace(0, 1, vec.shape[-1])
        t_interp = np.linspace(0, 1, int(np.around(vec.shape[-1] * meta.dt * FS)))
        vec = interp1d(t, vec, kind="slinear")(t_interp)
    nt = vec.shape[-1]  # keep nt in sync after the optional trim/resample above

    # sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos')  ## for stability of long sequence
    # vec = scipy.signal.sosfilt(sos, vec)
    f, t, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary="zeros")
    noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # [batch * chn, nf, nt, 2]
    noisy_signal[np.isnan(noisy_signal)] = 0
    noisy_signal[np.isinf(noisy_signal)] = 0
    X_input = normalize_batch(noisy_signal)

    feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)

    # apply the predicted signal mask to the complex spectrogram and invert the STFT
    _, denoised_signal = scipy.signal.istft(
        (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
        fs=FS,
        nperseg=NPERSEG,
        nfft=NFFT,
        boundary="zeros",
    )
    # _, denoised_noise = scipy.signal.istft(
    #     (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 1],
    #     fs=FS,
    #     nperseg=NPERSEG,
    #     nfft=NFFT,
    #     boundary='zeros',
    # )

    denoised_signal = np.reshape(denoised_signal, [nbt, nch, nt])
    denoised_signal = np.transpose(denoised_signal, [0, 2, 1])

    result = meta.copy()
    result.vec = denoised_signal.tolist()
    return result


class Data(BaseModel):
    # id: Union[List[str], str]
    # timestamp: Union[List[str], str]
    # vec: Union[List[List[List[float]]], List[List[float]]]
    id: List[str]
    timestamp: List[str]
    vec: List[List[List[float]]]  # [batch, nt, n_channels]
    dt: float = 0.01  # sample spacing in seconds (0.01 s = 100 Hz)


@app.post("/predict")
def predict(data: Data):
    denoised = get_prediction(data)

    return denoised


@app.get("/healthz")
def healthz():
    return {"status": "ok"}


@app.get("/")
def read_root():
    return {"Hello": "DeepDenoiser!"}