File size: 968 Bytes
649aa1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sagemaker_inference import content_types, decoder, default_inference_handler, encoder

def model_fn(model_dir):
    """Construct the embedding model from a model directory.

    Args:
        model_dir: Location passed straight to ``SentenceTransformer``;
            presumably the directory holding the unpacked model
            artifacts (confirm against the serving container).

    Returns:
        The ``SentenceTransformer`` instance built from ``model_dir``.
    """
    return SentenceTransformer(model_dir)

def input_fn(request_body, request_content_type):
    """Decode an incoming request body into model input.

    Args:
        request_body: Raw payload of the request.
        request_content_type: Content type declared for the payload.

    Returns:
        Whatever ``decoder.decode`` produces for a JSON payload.

    Raises:
        ValueError: If the content type is anything other than
            ``content_types.JSON``.
    """
    # Guard clause: reject every non-JSON content type before decoding.
    if request_content_type != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")
    return decoder.decode(request_body, content_types.JSON)

def predict_fn(input_data, model):
    """Run the model on already-decoded input.

    Args:
        input_data: Value handed unchanged to ``model.encode``.
        model: Any object exposing an ``encode`` method; presumably the
            object returned by ``model_fn`` (confirm against the caller).

    Returns:
        The result of ``model.encode(input_data)``.
    """
    return model.encode(input_data)

def output_fn(prediction, accept):
    """Serialize a prediction for the response.

    Args:
        prediction: Value to serialize (the output of ``predict_fn``,
            presumably — confirm against the caller).
        accept: Content type requested for the response.

    Returns:
        Whatever ``encoder.encode`` produces for JSON output.

    Raises:
        ValueError: If ``accept`` is anything other than
            ``content_types.JSON``.
    """
    # Guard clause: only JSON responses are supported.
    if accept != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")
    return encoder.encode(prediction, content_types.JSON)