# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
from transformers import EsmForSequenceClassification, AutoTokenizer
import torch
def model_fn(model_dir):
    """Load the fine-tuned ESM sequence classifier and its tokenizer.

    SageMaker-convention model-loading hook: invoked once with the model
    artifact directory; the returned pair is handed to predict_fn.

    Args:
        model_dir: path to the directory holding the saved model/tokenizer.

    Returns:
        Tuple of (model, tokenizer).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # device_map="auto" lets accelerate place the model on GPU when present.
    model = EsmForSequenceClassification.from_pretrained(
        model_dir, device_map="auto"
    )
    return model, tokenizer
def predict_fn(data, model_and_tokenizer):
    """Classify a protein sequence and return the membrane probability.

    SageMaker-convention prediction hook.

    Args:
        data: request payload dict; the sequence is read from data["inputs"]
            (falls back to the whole payload if that key is absent).
        model_and_tokenizer: (model, tokenizer) pair as returned by model_fn.

    Returns:
        Dict with key "membrane_probability": sigmoid of the class-1 logit
        for the first (only) item in the batch.
    """
    model, tokenizer = model_and_tokenizer
    model.eval()
    inputs = data.pop("inputs", data)
    encoding = tokenizer(inputs, return_tensors="pt")
    encoding = {k: v.to(model.device) for k, v in encoding.items()}
    # Inference only: disable autograd so the forward pass doesn't build a
    # graph (saves memory and compute). The original ran with grad enabled.
    with torch.no_grad():
        results = model(**encoding)
    probs = torch.sigmoid(results.logits).cpu()
    # [0][1]: batch item 0, logit index 1 — presumably the "membrane" class
    # of a 2-label head; TODO confirm against the training label order.
    return {"membrane_probability": probs[0][1].item()}