embeddings / model.py
Pavithiran's picture
Create model.py
ab631a4 verified
raw
history blame
918 Bytes
from sentence_transformers import SentenceTransformer
import torch
class Model:
    """Sentence-embedding handler backed by a SentenceTransformer model.

    Calling an instance with a request payload encodes the requested text(s)
    and returns the embeddings as JSON-serialisable nested lists plus the
    shape of the embedding tensor.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Load the pre-trained SentenceTransformer model.

        Args:
            model_name: Model identifier or local path handed to
                ``SentenceTransformer``. Defaults to ``'all-MiniLM-L6-v2'``,
                the model this handler always loaded before the parameter
                existed.
        """
        self.embedding_model = SentenceTransformer(model_name)

    @staticmethod
    def _collect_chunks(inputs):
        """Turn the ``"inputs"`` field of a payload into the list of texts to encode."""
        # A bare string or a list of strings is encoded as-is.
        if isinstance(inputs, str):
            return [inputs]
        if isinstance(inputs, (list, tuple)):
            if not inputs:
                raise ValueError("'inputs' must contain at least one sentence")
            return list(inputs)
        # Otherwise treat it as a mapping (duck-typed via .get, as before).
        # An explicit null is treated the same as a missing field.
        if inputs is None:
            inputs = {}
        source_sentence = inputs.get("source_sentence")
        if source_sentence is None:
            source_sentence = ""
        sentences = inputs.get("sentences")
        if sentences is None:
            sentences = []
        elif isinstance(sentences, str):
            # A single sentence sent as a string, not a one-element list.
            sentences = [sentences]
        # The source sentence always comes first (even when it is ""), so that
        # embeddings[0] corresponds to it, followed by the sentences in order.
        return [source_sentence, *sentences]

    def __call__(self, payload):
        """Encode the texts in ``payload`` and return their embeddings.

        Args:
            payload: Request dict whose ``"inputs"`` entry is one of:
                * a dict with optional ``"source_sentence"`` (str) and
                  ``"sentences"`` (a list of str, or a single str). The source
                  sentence (``""`` when absent) is encoded first, followed by
                  the sentences in order;
                * a single string; or
                * a non-empty list of strings, encoded in order.
                A missing or null ``"inputs"`` behaves like an empty dict.

        Returns:
            dict with ``"embeddings"`` (the tensor returned by ``encode``,
            converted to nested Python lists; presumably one row per encoded
            text) and ``"shape"`` (that tensor's shape as a list of ints).

        Raises:
            ValueError: if ``"inputs"`` is an empty list.
        """
        chunks = self._collect_chunks(payload.get("inputs", {}))
        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)
        return {
            # .tolist() turns the tensor into plain nested lists for JSON serialization.
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }