|
from typing import Dict, List |
|
import numpy as np |
|
import tensorflow as tf |
|
import os |
|
|
|
from phasenet.model import ModelConfig, UNet |
|
from phasenet.postprocess import extract_picks |
|
|
|
# The pipeline below builds and runs the model with TF1-style graph APIs
# (tf.compat.v1.Session, feed_dict, tf.compat.v1.train.Saver), so eager
# execution is switched off at import time, before any model is constructed.
tf.compat.v1.disable_eager_execution()
# Suppress TensorFlow log output below ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
|
|
class PreTrainedPipeline:
    """Inference wrapper around a restored PhaseNet ``UNet`` model.

    The TensorFlow session and model graph are created once in ``__init__``
    and reused by every ``__call__``.
    """

    def __init__(self, path=""):
        """Build the prediction graph and restore the latest checkpoint.

        Args:
            path (:obj:`str`): Base directory; the checkpoint is looked up in
                ``os.path.join(path, "model/190703-214543")``. The default
                ``""`` resolves that directory relative to the current
                working directory.

        Raises:
            ValueError: If no checkpoint is found in the checkpoint directory.
        """
        # Clear the global default graph so this instance's model is built
        # in a fresh graph.
        tf.compat.v1.reset_default_graph()

        model = UNet(mode="pred")

        sess_config = tf.compat.v1.ConfigProto()
        # Let TensorFlow grow GPU memory usage on demand.
        sess_config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=sess_config)

        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        checkpoint_dir = os.path.join(path, "model/190703-214543")
        latest_check_point = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_check_point is None:
            # latest_checkpoint returns None when no checkpoint exists; fail
            # with an explicit message (instead of letting saver.restore fail
            # on a None path) and release the session right away.
            sess.close()
            raise ValueError(
                f"no TensorFlow checkpoint found in {checkpoint_dir!r}"
            )
        print(f"restoring model {latest_check_point}")
        saver.restore(sess, latest_check_point)

        self.sess = sess
        self.model = model

    def __call__(self, inputs: List[List[float]]) -> List[List[Dict[str, float]]]:
        """Run the model on a single waveform.

        Args:
            inputs: 2-D array-like (anything ``np.array`` accepts), e.g. the
                ``(1000, 3)`` nested list used in this module's ``__main__``
                example. Presumably samples x channels -- confirm against the
                ``UNet`` input definition.

        Returns:
            A list containing one list of ``{"label": ..., "score": ...}``
            dicts. NOTE(review): currently a hard-coded placeholder,
            ``[[{"label": "debug", "score": 0.1}]]``; the picks computed by
            ``extract_picks`` are not returned yet.

        Raises:
            ValueError: If ``inputs`` is not two-dimensional.
        """
        waveform = np.array(inputs)
        if waveform.ndim != 2:
            # Any other rank would make the reshape below produce the wrong
            # number of dimensions (or raise IndexError for 1-D input).
            raise ValueError(
                f"inputs must be 2-D, got array of shape {waveform.shape}"
            )
        # (nt, nch) -> (1, nt, 1, nch), the layout fed to model.X.
        # Presumably (batch, time, 1, channel) -- confirm against UNet.
        vec = waveform[np.newaxis, :, np.newaxis, :]

        # drop_rate=0 / is_training=False select inference behavior.
        feed = {
            self.model.X: vec,
            self.model.drop_rate: 0,
            self.model.is_training: False,
        }
        preds = self.sess.run(self.model.preds, feed_dict=feed)

        # TODO(review): `picks` is computed but discarded; convert it into the
        # label/score result once the structure returned by extract_picks is
        # settled, then drop the placeholder below.
        picks = extract_picks(preds)

        return [[{"label": "debug", "score": 0.1}]]
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: restore the model using the default (current-directory)
    # path, then run it once on uniform random data shaped (1000, 3).
    demo_pipeline = PreTrainedPipeline()
    random_waveform = np.random.rand(1000, 3).tolist()
    demo_result = demo_pipeline(random_waveform)
|
|