import os
from json import dumps
from typing import Any, AnyStr, Dict, List, Union

import numpy as np
import scipy.signal
import tensorflow as tf
from fastapi import FastAPI
from kafka import KafkaProducer
from model import UNet
from pydantic import BaseModel
from scipy.interpolate import interp1d

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))

JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]

app = FastAPI()
X_SHAPE = [3000, 1, 3]
SAMPLING_RATE = 100

# load model
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190614-104802")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)

# Kafka producer: try the in-cluster broker first, then fall back to a local one
use_kafka = False
# BROKER_URL = 'localhost:9092'
# BROKER_URL = 'my-kafka-headless:9092'
try:
    print("Connecting to k8s kafka")
    BROKER_URL = "quakeflow-kafka-headless:9092"
    producer = KafkaProducer(
        bootstrap_servers=[BROKER_URL],
        key_serializer=lambda x: dumps(x).encode("utf-8"),
        value_serializer=lambda x: dumps(x).encode("utf-8"),
    )
    use_kafka = True
    print("k8s kafka connection success!")
except Exception:
    print("k8s Kafka connection error")
    try:
        print("Connecting to local kafka")
        producer = KafkaProducer(
            bootstrap_servers=["localhost:9092"],
            key_serializer=lambda x: dumps(x).encode("utf-8"),
            value_serializer=lambda x: dumps(x).encode("utf-8"),
        )
        use_kafka = True
        print("local kafka connection success!")
    except Exception:
        print("local Kafka connection error")


def normalize_batch(data, window=200):
    """
    data: [nbt, nf, nt, 2] (batch, frequency, time, real/imaginary)
    """
    assert len(data.shape) == 4
    shift = window // 2
    nbt, nf, nt, nimg = data.shape

    ## std/mean in sliding windows along the time axis
    data_pad = np.pad(data, ((0, 0), (0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
    t = np.arange(0, nt + shift - 1, shift, dtype="int")  # 201 => 0, 100, 200
    std = np.zeros([nbt, len(t)])
    mean = np.zeros([nbt, len(t)])
    for i in range(std.shape[1]):
        std[:, i] = np.std(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
        mean[:, i] = np.mean(data_pad[:, :, i * shift : i * shift + window, :], axis=(1, 2, 3))
    std[:, -1], mean[:, -1] = std[:, -2], mean[:, -2]
    std[:, 0], mean[:, 0] = std[:, 1], mean[:, 1]

    ## normalize data with interpolated std/mean
    t_interp = np.arange(nt, dtype="int")
    std_interp = interp1d(t, std, kind="slinear")(t_interp)
    std_interp[std_interp == 0] = 1.0
    mean_interp = interp1d(t, mean, kind="slinear")(t_interp)
    data = (data - mean_interp[:, np.newaxis, :, np.newaxis]) / std_interp[:, np.newaxis, :, np.newaxis]

    if len(t) > 3:  ## need to address this normalization issue in training
        data /= 2.0

    return data


def get_prediction(meta):
    FS = 100
    NPERSEG = 30
    NFFT = 60

    vec = np.array(meta.vec)  # [batch, nt, chn]
    nbt, nt, nch = vec.shape
    vec = np.transpose(vec, [0, 2, 1])  # [batch, chn, nt]
    vec = np.reshape(vec, [nbt * nch, nt])  # [batch * chn, nt]
    if np.mod(vec.shape[-1], 3000) == 1:  # 3001 => 3000
        vec = vec[..., :-1]
    if meta.dt != 0.01:
        # resample along the time axis to the 100 Hz the model expects
        t = np.linspace(0, 1, vec.shape[-1])
        t_interp = np.linspace(0, 1, int(np.around(vec.shape[-1] * meta.dt * FS)))
        vec = interp1d(t, vec, kind="slinear")(t_interp)
    # sos = scipy.signal.butter(4, 0.1, 'high', fs=100, output='sos')  ## for stability of long sequence
    # vec = scipy.signal.sosfilt(sos, vec)

    f, t, tmp_signal = scipy.signal.stft(vec, fs=FS, nperseg=NPERSEG, nfft=NFFT, boundary="zeros")
    noisy_signal = np.stack([tmp_signal.real, tmp_signal.imag], axis=-1)  # [batch * chn, nf, nt, 2]
    noisy_signal[np.isnan(noisy_signal)] = 0
    noisy_signal[np.isinf(noisy_signal)] = 0

    X_input = normalize_batch(noisy_signal)
    feed = {model.X: X_input, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)

    # apply the predicted signal mask to the noisy STFT and invert back to the time domain
    _, denoised_signal = scipy.signal.istft(
        (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 0],
        fs=FS,
        nperseg=NPERSEG,
        nfft=NFFT,
        boundary="zeros",
    )
    # _, denoised_noise = scipy.signal.istft(
    #     (noisy_signal[..., 0] + noisy_signal[..., 1] * 1j) * preds[..., 1],
    #     fs=FS,
    #     nperseg=NPERSEG,
    #     nfft=NFFT,
    #     boundary='zeros',
    # )

    # the istft output length can differ from the original nt (e.g. after resampling), so infer it
    denoised_signal = np.reshape(denoised_signal, [nbt, nch, -1])
    denoised_signal = np.transpose(denoised_signal, [0, 2, 1])  # back to [batch, nt, chn]

    result = meta.copy()
    result.vec = denoised_signal.tolist()
    return result


class Data(BaseModel):
    # id: Union[List[str], str]
    # timestamp: Union[List[str], str]
    # vec: Union[List[List[List[float]]], List[List[float]]]
    id: List[str]
    timestamp: List[str]
    vec: List[List[List[float]]]
    dt: float = 0.01


@app.post("/predict")
def predict(data: Data):
    denoised = get_prediction(data)
    return denoised


@app.get("/healthz")
def healthz():
    return {"status": "ok"}


@app.get("/")
def read_root():
    return {"Hello": "DeepDenoiser!"}
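# -----------------------------------------------------------------------------
# Example request (a minimal sketch, not part of the service): assumes the app
# is served with `uvicorn app:app --port 8000`; the module name, host, port,
# and station id below are illustrative, and the random waveform stands in for
# real 100 Hz three-component data.
#
#   import numpy as np
#   import requests
#
#   nbt, nt, nch = 1, 3000, 3  # one 30 s window of 3-component data at 100 Hz
#   payload = {
#       "id": ["NC.TEST..HN"],
#       "timestamp": ["2019-06-14T10:48:02.000"],
#       "vec": np.random.randn(nbt, nt, nch).tolist(),  # [batch, nt, chn]
#       "dt": 0.01,
#   }
#   resp = requests.post("http://localhost:8000/predict", json=payload)
#   denoised = np.array(resp.json()["vec"])  # [batch, nt, chn]
# -----------------------------------------------------------------------------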