|
from sentence_transformers import SentenceTransformer |
|
import torch |
|
|
|
class Model:
    """Embed a source sentence together with a list of comparison sentences.

    Wraps a ``SentenceTransformer`` model and returns the embeddings converted
    to nested Python lists (via ``.tolist()``) so the response contains only
    plain Python values.
    """

    # Checkpoint used when no model name is given; this is the value the
    # original implementation hard-coded.
    DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Load the sentence-embedding model.

        Args:
            model_name: Name or path passed unchanged to ``SentenceTransformer``.
                Defaults to ``"all-MiniLM-L6-v2"``.
        """
        self.embedding_model = SentenceTransformer(model_name)

    def __call__(self, payload):
        """Encode the sentences in ``payload["inputs"]`` and return embeddings.

        Args:
            payload: Mapping with an optional ``"inputs"`` mapping that may hold
                ``"source_sentence"`` (a string) and ``"sentences"`` (an
                iterable of strings, or a single string). A missing or ``None``
                ``"inputs"`` is treated as empty; a missing or ``None``
                ``"source_sentence"`` becomes ``""``; a missing or ``None``
                ``"sentences"`` becomes an empty list.

        Returns:
            A dict ``{"embeddings": ..., "shape": ...}``. ``"embeddings"`` is
            the encoder output converted with ``.tolist()``; ``"shape"`` is
            ``list(embeddings.shape)``. Row 0 always corresponds to the source
            sentence (encoded as ``""`` when absent), followed by ``sentences``
            in the order given.
        """
        # `or {}` also covers an explicit `"inputs": None`, which would
        # otherwise fail on `.get` below.
        inputs = payload.get("inputs") or {}

        source_sentence = inputs.get("source_sentence")
        if source_sentence is None:
            source_sentence = ""

        sentences = inputs.get("sentences")
        if sentences is None:
            sentences = []
        elif isinstance(sentences, str):
            # A bare string cannot be concatenated onto a list (and iterating
            # it would split it into characters); treat it as one sentence.
            sentences = [sentences]

        # Source sentence first so callers can rely on row 0 being the source.
        chunks = [source_sentence, *sentences]

        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)

        return {
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }
|
|