# zhuwq0's picture
# update
# 127cbe6
import os
from collections import defaultdict, namedtuple
from datetime import datetime, timedelta
from json import dumps
from typing import Any, AnyStr, Dict, List, NamedTuple, Union

import numpy as np
import requests
import scipy
import scipy.signal
import tensorflow as tf
from fastapi import FastAPI
from kafka import KafkaProducer
from pydantic import BaseModel
from scipy.interpolate import interp1d

from model import UNet
# The model below is driven through tf.compat.v1 sessions, so eager execution is switched off.
tf.compat.v1.disable_eager_execution()
# Restrict TensorFlow logging to ERROR-level messages.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Parent of the directory that contains this file; used to locate the model checkpoint.
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
# Loose type aliases for JSON-like payloads.
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]
# FastAPI application; the route decorators in this file register on it.
app = FastAPI()
# NOTE(review): presumably the nominal model input shape (nt, 1, channels) -- not referenced in the visible code; confirm.
X_SHAPE = [3000, 1, 3]
# NOTE(review): presumably samples per second; get_prediction defines its own FS = 100 rather than reading this.
SAMPLING_RATE = 100
# load model
# Build the UNet graph in prediction mode, open one session that is shared by all
# requests, and restore the trained weights from the newest checkpoint on disk.
# The statement order matters: the graph must exist before the Saver collects its
# variables, and the restore must come after the initializer so that checkpoint
# values are the ones left in the session.
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
# Let TensorFlow grow its GPU memory allocation on demand instead of reserving it all up front.
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# NOTE(review): latest_checkpoint returns None when the directory holds no checkpoint;
# the restore below is not guarded against that case.
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
# Kafka producer
# `use_kafka` stays False (and `producer` stays None) when no broker can be reached,
# so the service can still start without Kafka.
use_kafka = False
producer = None
# BROKER_URL = 'localhost:9092'
# BROKER_URL = 'my-kafka-headless:9092'
BROKER_URL = "quakeflow-kafka-headless:9092"


def _connect_kafka(broker_url):
    """Create a KafkaProducer for `broker_url` that sends keys and values as UTF-8 encoded JSON."""
    return KafkaProducer(
        bootstrap_servers=[broker_url],
        key_serializer=lambda x: dumps(x).encode("utf-8"),
        value_serializer=lambda x: dumps(x).encode("utf-8"),
    )


# Try the in-cluster broker first and fall back to a broker on localhost.
# Only `Exception` is caught: KeyboardInterrupt / SystemExit raised during start-up
# must propagate instead of being reported as a connection error.
for _label, _url in (("k8s", BROKER_URL), ("local", "localhost:9092")):
    print(f"Connecting to {_label} kafka")
    try:
        producer = _connect_kafka(_url)
    except Exception:
        print(f"{_label} Kafka connection error")
    else:
        use_kafka = True
        print(f"{_label} kafka connection success!")
        break
def normalize_batch(data, window=200):
    """
    Normalize each batch item with a sliding-window mean and standard deviation.

    data: nbt, nf, nt, nimg (4-D array)
    window: length, along the nt axis, of each statistics window; consecutive
        windows start window // 2 samples apart

    For every window the mean and std are taken over the nf, nt and nimg axes,
    linearly interpolated to one value per time sample, and used to standardize
    data. Returns a new array with the same shape as data.
    """
    assert len(data.shape) == 4
    half = window // 2
    n_batch, _, n_time, _ = data.shape

    # Reflect-pad the time axis by half a window on each side, so that the window
    # starting at padded index k is centred on sample k of the original data.
    padded = np.pad(data, ((0, 0), (0, 0), (half, half), (0, 0)), mode="reflect")
    centers = np.arange(0, n_time + half - 1, half, dtype="int")  # 201 => 0, 100, 200
    n_win = len(centers)

    ## std / mean in sliding windows
    sigma = np.zeros([n_batch, n_win])
    mu = np.zeros([n_batch, n_win])
    for k, start in enumerate(centers):
        chunk = padded[:, :, start : start + window, :]
        sigma[:, k] = np.std(chunk, axis=(1, 2, 3))
        mu[:, k] = np.mean(chunk, axis=(1, 2, 3))

    # The statistics of the last and the first window are replaced by those of
    # their inner neighbour (last one first, same order as before).
    sigma[:, -1] = sigma[:, -2]
    mu[:, -1] = mu[:, -2]
    sigma[:, 0] = sigma[:, 1]
    mu[:, 0] = mu[:, 1]

    ## normalize data with interpolated statistics
    samples = np.arange(n_time, dtype="int")
    sigma_full = interp1d(centers, sigma, kind="slinear")(samples)
    sigma_full[sigma_full == 0] = 1.0  # avoid division by zero on constant stretches
    mu_full = interp1d(centers, mu, kind="slinear")(samples)

    normalized = (data - mu_full[:, None, :, None]) / sigma_full[:, None, :, None]
    if n_win > 3:  ##need to address this normalization issue in training
        normalized /= 2.0
    return normalized
def get_prediction(meta):
    """Denoise a batch of waveforms with the loaded model.

    The waveforms are moved to the time-frequency domain with an STFT, the
    network predicts a mask per time-frequency bin, and the masked STFT is
    inverted back to the time domain.

    Args:
        meta: request object exposing `vec` (nested list laid out as
            [batch, nt, chn]), `dt` (sample spacing; the model rate corresponds
            to dt == 0.01) and `copy()`.

    Returns:
        A copy of `meta` whose `vec` holds the denoised waveforms as a
        [batch, nt, chn] nested list. The time axis has the length of the
        signal that was actually fed to the STFT: inputs with
        nt % 3000 == 1 lose their last sample, and inputs with dt != 0.01
        are returned at the model rate with `dt` set to 0.01.

    Raises:
        ValueError: if `meta.vec` is not a rectangular 3-D array.
    """
    FS = 100  # sampling rate assumed by the STFT / model
    NPERSEG = 30
    NFFT = 60

    vec = np.asarray(meta.vec, dtype=float)  # [batch, nt, chn]
    if vec.ndim != 3:
        raise ValueError(f"vec must have shape [batch, nt, chn], got {vec.shape}")
    nbt, nt, nch = vec.shape
    vec = np.transpose(vec, [0, 2, 1])  # [batch, chn, nt]
    vec = np.reshape(vec, [nbt * nch, nt])  ## [batch * chn, nt]

    if np.mod(vec.shape[-1], 3000) == 1:  # 3001=>3000
        vec = vec[..., :-1]

    # Resample to the model rate when the input uses a different sample spacing.
    # The sample count is taken from the time axis (vec is 2-D here, so len(vec)
    # would be the number of traces), and the builtin int replaces the removed np.int.
    resampled = not np.isclose(meta.dt, 1.0 / FS)
    if resampled:
        n_in = vec.shape[-1]
        n_out = int(np.around(n_in * meta.dt * FS))
        t_in = np.linspace(0, 1, n_in)
        t_out = np.linspace(0, 1, n_out)
        vec = interp1d(t_in, vec, kind="slinear")(t_out)
    n_samples = vec.shape[-1]

    # A 4th-order 0.1 Hz high-pass (scipy.signal.butter / sosfilt) was previously
    # sketched here "for stability of long sequence" but is not applied.
    _, _, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary="zeros")
    noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # [batch * chn, nf, nt, 2]
    noisy_signal[~np.isfinite(noisy_signal)] = 0  # zero out NaN / inf bins

    X_input = normalize_batch(noisy_signal)
    feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)

    # preds[..., 0] is applied as the signal mask on the un-normalized STFT.
    # NOTE(review): preds[..., 1] looks like the complementary noise mask (it was
    # used that way in code that was commented out here) -- confirm against the model.
    _, denoised_signal = scipy.signal.istft(
        (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
        fs=FS,
        nperseg=NPERSEG,
        nfft=NFFT,
        boundary="zeros",
    )

    # stft zero-pads the signal to a whole number of hops, so istft can return more
    # samples than went in; cut back to the analysed length before un-flattening.
    # Reshaping with the original nt would fail whenever the length changed above.
    denoised_signal = denoised_signal[..., :n_samples]
    denoised_signal = np.reshape(denoised_signal, [nbt, nch, denoised_signal.shape[-1]])
    denoised_signal = np.transpose(denoised_signal, [0, 2, 1])  # [batch, nt, chn]

    result = meta.copy()
    result.vec = denoised_signal.tolist()
    if resampled:
        # The returned samples are spaced at the model rate, not the request's dt.
        result.dt = 1.0 / FS
    return result
class Data(BaseModel):
    """Payload accepted by the /predict endpoint and returned by it after denoising.

    Only the list forms of the fields are accepted; scalar (non-list) variants
    of id / timestamp / vec were sketched earlier but are not part of the schema.
    """

    # Identifier strings -- presumably one per waveform in `vec`; TODO confirm.
    id: List[str]
    # Timestamp strings -- presumably the start time of each waveform; TODO confirm.
    timestamp: List[str]
    # Waveform samples nested as [batch][nt][chn], the layout get_prediction unpacks.
    vec: List[List[List[float]]]
    # Sample spacing; get_prediction resamples when it differs from 0.01.
    dt: float = 0.01
@app.post("/predict")
def predict(data: Data):
    """Run the denoiser on the posted waveforms and return the resulting payload."""
    return get_prediction(data)
@app.get("/healthz")
def healthz():
    """Health-check endpoint: always answers with a constant status payload."""
    status_payload = {"status": "ok"}
    return status_payload
@app.get("/")
def read_root():
    """Root endpoint: answers with a constant greeting that names the service."""
    greeting = {"Hello": "DeepDenoiser!"}
    return greeting