import os
from json import dumps
from typing import Any, AnyStr, Dict, List, Union

import numpy as np
import scipy.signal
import tensorflow as tf
from fastapi import FastAPI
from kafka import KafkaProducer
from model import UNet
from pydantic import BaseModel
from scipy.interpolate import interp1d
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))

JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

app = FastAPI()
X_SHAPE = [3000, 1, 3]
SAMPLING_RATE = 100
# load model
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)
# Kafka producer
use_kafka = False
# BROKER_URL = 'localhost:9092'
# BROKER_URL = 'my-kafka-headless:9092'
try:
    print("Connecting to k8s kafka")
    BROKER_URL = "quakeflow-kafka-headless:9092"
    producer = KafkaProducer(
        bootstrap_servers=[BROKER_URL],
        key_serializer=lambda x: dumps(x).encode("utf-8"),
        value_serializer=lambda x: dumps(x).encode("utf-8"),
    )
    use_kafka = True
    print("k8s kafka connection success!")
except Exception:
    print("k8s Kafka connection error")
    try:
        print("Connecting to local kafka")
        producer = KafkaProducer(
            bootstrap_servers=["localhost:9092"],
            key_serializer=lambda x: dumps(x).encode("utf-8"),
            value_serializer=lambda x: dumps(x).encode("utf-8"),
        )
        use_kafka = True
        print("local kafka connection success!")
    except Exception:
        print("local Kafka connection error")

def normalize_batch(data, window=200):
    """
    Normalize a batch of spectrograms with a sliding-window mean/std.

    data: array of shape [nbt, nf, nt, 2]
    """
    assert len(data.shape) == 4
    shift = window // 2
    nbt, nf, nt, nimg = data.shape
    ## std in sliding windows
    data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
    t = np.arange(0, nt + shift - 1, shift, dtype="int")  # 201 => 0, 100, 200
    # print(f"nt = {nt}, nt+window//2 = {nt+window//2}")
    std = np.zeros([nbt, len(t)])
    mean = np.zeros([nbt, len(t)])
    for i in range(std.shape[1]):
        std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
        mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
    std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
    std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]
    ## normalize data with interpolated std
    t_interp = np.arange(nt, dtype="int")
    std_interp = interp1d(t, std, kind="slinear")(t_interp)
    std_interp[std_interp == 0] = 1.0
    mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
    data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]
    if len(t) > 3:  ## need to address this normalization issue in training
        data /= 2.0
    return data
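
# A quick sanity check (sketch): a random STFT-like batch of shape
# [nbt, nf, nt, 2] comes back with the same shape, roughly zero mean and
# unit std within each 200-sample sliding window.
#
#     x = np.random.randn(2, 31, 201, 2)
#     x_norm = normalize_batch(x)
#     assert x_norm.shape == x.shape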

def get_prediction(meta):
    FS = 100
    NPERSEG = 30
    NFFT = 60
    vec = np.array(meta.vec)  # [batch, nt, chn]
    nbt, nt, nch = vec.shape
    vec = np.transpose(vec, [0, 2, 1])  # [batch, chn, nt]
    vec = np.reshape(vec, [nbt * nch, nt])  ## [batch * chn, nt]
    if np.mod(vec.shape[-1], 3000) == 1:  # 3001 => 3000
        vec = vec[..., :-1]
    if meta.dt != 0.01:
        # resample the last (time) axis to 100 Hz when the input dt differs
        t = np.linspace(0, 1, vec.shape[-1])
        t_interp = np.linspace(0, 1, int(np.around(vec.shape[-1] * meta.dt * FS)))
        vec = interp1d(t, vec, kind="slinear")(t_interp)
    # sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos')  ## for stability of long sequence
    # vec = scipy.signal.sosfilt(sos, vec)
    f, t, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary="zeros")
    noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # [batch * chn, nf, nt, 2]
    noisy_signal[np.isnan(noisy_signal)] = 0
    noisy_signal[np.isinf(noisy_signal)] = 0
    X_input = normalize_batch(noisy_signal)
    feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)
    _, denoised_signal = scipy.signal.istft(
        (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
        fs=FS,
        nperseg=NPERSEG,
        nfft=NFFT,
        boundary="zeros",
    )
    # _, denoised_noise = scipy.signal.istft(
    #     (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 1],
    #     fs=FS,
    #     nperseg=NPERSEG,
    #     nfft=NFFT,
    #     boundary='zeros',
    # )
    denoised_signal = np.reshape(denoised_signal, [nbt, nch, nt])
    denoised_signal = np.transpose(denoised_signal, [0, 2, 1])
    result = meta.copy()
    result.vec = denoised_signal.tolist()
    return result
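
# Shape walk-through for get_prediction, derived from the code above:
#   input vec:    [batch, nt, chn]          -> transpose/reshape -> [batch*chn, nt]
#   stft:         [batch*chn, nf, nt'] with nf = NFFT//2 + 1 = 31
#                 (nt' is about 201 for a 3000-sample trace at nperseg=30)
#   model preds:  a per-time-frequency-bin mask applied to the complex spectrum
#   istft:        back to [batch*chn, nt], reshaped/transposed to [batch, nt, chn]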

class Data(BaseModel):
    # id: Union[List[str], str]
    # timestamp: Union[List[str], str]
    # vec: Union[List[List[List[float]]], List[List[float]]]
    id: List[str]
    timestamp: List[str]
    vec: List[List[List[float]]]
    dt: float = 0.01
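
# Example request body matching the Data model (a sketch; the station id and
# timestamp values are made up for illustration):
#
#     {
#         "id": ["NC.STA1"],
#         "timestamp": ["2021-01-01T00:00:00.000"],
#         "vec": [[[0.0, 0.0, 0.0], ...]],  # [batch, nt, 3] three-component waveform
#         "dt": 0.01
#     }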

@app.post("/predict")
def predict(data: Data):
    denoised = get_prediction(data)
    return denoised


@app.get("/healthz")
def healthz():
    return {"status": "ok"}


@app.get("/")
def read_root():
    return {"Hello": "DeepDenoiser!"}