"""SageMaker inference handlers for serving a SentenceTransformer model.

Implements the four entry points the SageMaker Inference Toolkit looks for:
``model_fn`` (load), ``input_fn`` (deserialize), ``predict_fn`` (embed),
and ``output_fn`` (serialize). Only JSON is accepted on input and output.
"""

import os

import torch
from sagemaker_inference import (
    content_types,
    decoder,
    default_inference_handler,
    encoder,
)
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

# NOTE(review): os, torch, AutoModel and AutoTokenizer are not referenced
# below — kept to avoid changing module-level side effects; confirm and prune.


def model_fn(model_dir):
    """Load the SentenceTransformer model from the unpacked model artifacts.

    Args:
        model_dir: Directory containing the saved SentenceTransformer model.

    Returns:
        The loaded ``SentenceTransformer`` instance.
    """
    model = SentenceTransformer(model_dir)
    return model


def input_fn(request_body, request_content_type):
    """Deserialize the request body.

    Args:
        request_body: Raw request payload.
        request_content_type: MIME type of the payload.

    Returns:
        The decoded input data (sentences to embed).

    Raises:
        ValueError: If the content type is not JSON.
    """
    if request_content_type == content_types.JSON:
        input_data = decoder.decode(request_body, content_types.JSON)
        return input_data
    else:
        raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")


def predict_fn(input_data, model):
    """Compute sentence embeddings for the decoded input.

    Args:
        input_data: Sentence or list of sentences, as produced by ``input_fn``.
        model: The ``SentenceTransformer`` returned by ``model_fn``.

    Returns:
        The embeddings produced by ``model.encode`` (a numpy array by default).
    """
    embeddings = model.encode(input_data)
    return embeddings


def output_fn(prediction, accept):
    """Serialize the embeddings for the response.

    Args:
        prediction: Embeddings returned by ``predict_fn``.
        accept: Requested response MIME type.

    Returns:
        The JSON-encoded prediction.

    Raises:
        ValueError: If the requested Accept type is not JSON.
    """
    if accept == content_types.JSON:
        output = encoder.encode(prediction, content_types.JSON)
        return output
    else:
        raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")