# foodDetectionDemo/clip_component.py
import cv2
import torch
from PIL import Image
import clip

def get_token_from_clip(image):
    """Return the label from a fixed list of food names that best matches a BGR image."""
    text_inputs = ["Bacon", "Bread", "Fruit", "Beans and Rice", "fries", "Lasagna"]
    text_tokens = clip.tokenize(text_inputs)

    # Note: loading the model on every call is slow; for repeated use,
    # load it once at module scope instead.
    device = "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    print("device: ", device)

    # Encode and L2-normalize the candidate labels. encode_text must run under
    # no_grad, otherwise .numpy() below fails on a grad-tracking tensor.
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # OpenCV images are BGR; CLIP's preprocess expects an RGB PIL image.
    image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    image_input = preprocess(image_pil).unsqueeze(0).to(device)  # add batch dimension

    # Encode and L2-normalize the image.
    with torch.no_grad():
        image_feature = model.encode_image(image_input)
        image_feature /= image_feature.norm(dim=-1, keepdim=True)

    # Cosine similarity of each label against the image: shape (num_labels, 1).
    similarity = text_features.cpu().numpy() @ image_feature.cpu().numpy().T

    # Pick the label with the highest similarity score.
    best_similarity = 0
    best_text_input = ""
    for i in range(similarity.shape[0]):
        similarity_num = 100.0 * similarity[i][0]
        if similarity_num > best_similarity:
            best_similarity = similarity_num
            best_text_input = text_inputs[i]

    print("Best caption for the image: ", best_text_input)
    return best_text_input
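

# A minimal usage sketch, assuming OpenAI's clip package and opencv-python are
# installed; "food.jpg" is a hypothetical example path, not part of the demo.
if __name__ == "__main__":
    frame = cv2.imread("food.jpg")  # hypothetical test image, loaded as BGR
    if frame is None:
        raise FileNotFoundError("food.jpg not found")
    print(get_token_from_clip(frame))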