zhuwq0 commited on
Commit
d944358
·
1 Parent(s): 1e0743b

upload phasenet

Browse files
Files changed (1) hide show
  1. pipeline.py +3 -2
pipeline.py CHANGED
@@ -2,7 +2,7 @@ from typing import Dict, List
2
  import numpy as np
3
  import tensorflow as tf
4
  import os
5
-
6
  from phasenet.model import ModelConfig, UNet
7
  from phasenet.postprocess import extract_picks
8
 
@@ -51,7 +51,7 @@ class PreTrainedPipeline():
51
  # "Please implement PreTrainedPipeline __call__ function"
52
  # )
53
 
54
- vec = np.array(inputs)[np.newaxis, :, np.newaxis, :]
55
 
56
  feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
57
  preds = self.sess.run(self.model.preds, feed_dict=feed)
@@ -67,4 +67,5 @@ class PreTrainedPipeline():
67
  if __name__ == "__main__":
68
  pipeline = PreTrainedPipeline()
69
  inputs = np.random.rand(1000, 3).tolist()
 
70
  picks = pipeline(inputs)
 
2
  import numpy as np
3
  import tensorflow as tf
4
  import os
5
+ import json
6
  from phasenet.model import ModelConfig, UNet
7
  from phasenet.postprocess import extract_picks
8
 
 
51
  # "Please implement PreTrainedPipeline __call__ function"
52
  # )
53
 
54
+ vec = np.asarray(json.loads(inputs))[np.newaxis, :, np.newaxis, :]
55
 
56
  feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
57
  preds = self.sess.run(self.model.preds, feed_dict=feed)
 
67
  if __name__ == "__main__":
68
  pipeline = PreTrainedPipeline()
69
  inputs = np.random.rand(1000, 3).tolist()
70
+ inputs = json.dumps(inputs)
71
  picks = pipeline(inputs)