from typing import Dict, List
import numpy as np
import tensorflow as tf
import os
from phasenet.model import ModelConfig, UNet
from phasenet.postprocess import extract_picks
# Run in TF1-style graph mode: the UNet below is built as a static graph and
# executed through tf.compat.v1.Session, so eager execution must be disabled.
tf.compat.v1.disable_eager_execution()
# Silence TF deprecation/info chatter; only errors are logged.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
class PreTrainedPipeline:
    """PhaseNet seismic phase-picking pipeline (TF1 graph mode).

    Builds the UNet prediction graph once, restores the pretrained
    checkpoint, and keeps the session open so ``__call__`` can be invoked
    repeatedly without re-loading the model.
    """

    def __init__(self, path: str = ""):
        """Build the model graph and restore pretrained weights.

        Heavy one-time work (graph construction, checkpoint restore)
        happens here, as the pipeline template requires.

        Args:
            path: Root directory that contains ``model/190703-214543``
                with the TF checkpoint files. Defaults to the CWD.
        """
        # Rebuild the graph from scratch so instantiating the pipeline more
        # than once in a process does not accumulate duplicate variables.
        tf.compat.v1.reset_default_graph()
        model = UNet(mode="pred")

        # Let the GPU memory pool grow on demand instead of reserving all
        # device memory up front.
        sess_config = tf.compat.v1.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=sess_config)

        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        sess.run(tf.compat.v1.global_variables_initializer())

        latest_check_point = tf.train.latest_checkpoint(
            os.path.join(path, "model/190703-214543")
        )
        print(f"restoring model {latest_check_point}")
        saver.restore(sess, latest_check_point)

        self.sess = sess
        self.model = model

    def __call__(self, inputs) -> List[List[Dict[str, float]]]:
        """Run phase picking on a single three-component waveform.

        Args:
            inputs: Array-like of shape ``(nt, 3)`` — waveform samples for
                three components. NOTE(review): the upstream template
                annotated this as ``str``, but the code converts it with
                ``np.array`` and indexes it as a 2-D array; the ``str``
                annotation was wrong. Confirm the exact dtype with callers.

        Return:
            A list of one list like ``[[{"label": ..., "score": ...}]]``.
            Currently a fixed debug placeholder is returned; the extracted
            picks are computed but not yet wired into the return value.
        """
        # Reshape (nt, 3) -> (1, nt, 1, 3): add the batch and station axes
        # expected by the UNet input placeholder.
        vec = np.array(inputs)[np.newaxis, :, np.newaxis, :]
        feed = {
            self.model.X: vec,
            self.model.drop_rate: 0,
            self.model.is_training: False,
        }
        preds = self.sess.run(self.model.preds, feed_dict=feed)
        # TODO: return the extracted picks (station_id/phase_time/phase_score/
        # phase_type) instead of the debug placeholder below.
        picks = extract_picks(preds)
        return [[{"label": "debug", "score": 0.1}]]
if __name__ == "__main__":
    # Smoke test: run the pipeline on a random 1000-sample, 3-component trace.
    pipeline = PreTrainedPipeline()
    inputs = np.random.rand(1000, 3).tolist()
    picks = pipeline(inputs)
    # Surface the result — previously it was computed and silently discarded,
    # which made the smoke test unobservable.
    print(picks)