philschmid HF staff commited on
Commit
7c28c52
1 Parent(s): bd0c184

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +4 -3
pipeline.py CHANGED
@@ -1,12 +1,12 @@
1
  from typing import Dict, List, Any
2
- from optimum.onnxruntime import ORTModelForSequenceClassification
3
  from transformers import pipeline, AutoTokenizer
4
 
5
 
6
  class PreTrainedPipeline():
7
  def __init__(self, path=""):
8
  # load the optimized model
9
- model = ORTModelForSequenceClassification.from_pretrained(path)
10
  tokenizer = AutoTokenizer.from_pretrained(path, model_max_length=128)
11
  # create inference pipeline
12
  self.pipeline = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
@@ -30,4 +30,5 @@ class PreTrainedPipeline():
30
  return [_h[0] for _h in pipeline_output]
31
 
32
  embeddings = cls_pooling(self.pipeline(inputs))
33
- return embeddings
 
 
1
  from typing import Dict, List, Any
2
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
3
  from transformers import pipeline, AutoTokenizer
4
 
5
 
6
  class PreTrainedPipeline():
7
  def __init__(self, path=""):
8
  # load the optimized model
9
+ model = ORTModelForFeatureExtraction.from_pretrained(path)
10
  tokenizer = AutoTokenizer.from_pretrained(path, model_max_length=128)
11
  # create inference pipeline
12
  self.pipeline = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
 
30
  return [_h[0] for _h in pipeline_output]
31
 
32
  embeddings = cls_pooling(self.pipeline(inputs))
33
+ return {"vectors": embeddings}
34
+