scold / inference.py
enalis's picture
Update inference.py
bf999fb verified
raw
history blame contribute delete
948 Bytes
import torch
from model import LVL
from transformers import RobertaTokenizer
from PIL import Image
from torchvision import transforms
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the LVL model (from the project-local `model` module) and restore its
# weights from "scold.pth", resolved relative to the current working directory.
# `map_location=device` lets a checkpoint saved on GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider `weights_only=True` if the installed torch supports it.
model = LVL()
model.load_state_dict(torch.load("scold.pth", map_location=device))
model.to(device)
# Evaluation mode: layers such as dropout/batch-norm use their inference behaviour.
model.eval()
# Tokenizer for the "roberta-base" checkpoint (loaded via from_pretrained).
# Presumably matches the text encoder inside LVL - TODO confirm.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Image preprocessing: resize to 224x224, then convert the PIL image to a
# float tensor with values scaled to [0, 1].
# NOTE(review): no mean/std normalization is applied here - confirm this
# matches the preprocessing used when the model was trained.
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
def predict(image_path, text):
    """Score how well an image matches a piece of text.

    The image is opened, converted to RGB, run through the module-level
    ``transform`` and given a leading batch dimension of one. The text is
    tokenized with the module-level ``tokenizer``. Both are passed to
    ``model``, and the score is the matrix product of the returned image
    features with the transposed text features.

    Args:
        image_path: Location of the image to score (anything
            ``PIL.Image.open`` accepts).
        text: Text to compare against the image. Pass a single string: the
            result is read with ``.item()``, which only works when the
            similarity tensor holds exactly one element.

    Returns:
        The similarity score as a plain Python number.
        NOTE(review): whether this is a cosine similarity depends on whether
        ``model`` normalizes its output features - confirm in ``LVL``.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically. ``convert`` returns a new, fully loaded image that
    # stays usable after the source image is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")

    # unsqueeze(0) adds the batch dimension in front of the transformed image.
    image = transform(rgb_image).unsqueeze(0).to(device)
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        img_feat, txt_feat = model(image, tokens["input_ids"], tokens["attention_mask"])
        similarity = torch.matmul(img_feat, txt_feat.T).squeeze()
    return similarity.item()