DSWLAC / app.py
wfdwed's picture
Update app.py
418bc05 verified
raw
history blame
1.26 kB
import gradio as gr
from transformers import CLIPModel, AutoTokenizer, RawImage
import torch
import torch.nn.functional as F
# Load the CLIP model and its tokenizer from the same Hub repository once,
# at import time, so every request reuses the same objects.
# NOTE(review): this repository appears to be packaged for transformers.js
# (ONNX); confirm it ships weights loadable by the PyTorch ``CLIPModel``.
model, tokenizer = (
    loader.from_pretrained("Xenova/mobileclip_blt")
    for loader in (CLIPModel, AutoTokenizer)
)
# Define the inference function
# Cache of image processors keyed by checkpoint name, filled on first use.
_IMAGE_PROCESSORS = {}


def _load_image_processor(name_or_path):
    """Return the image processor for ``name_or_path``, loading it only once."""
    if name_or_path not in _IMAGE_PROCESSORS:
        # Imported lazily so only this helper depends on it.
        from transformers import AutoImageProcessor

        _IMAGE_PROCESSORS[name_or_path] = AutoImageProcessor.from_pretrained(
            name_or_path
        )
    return _IMAGE_PROCESSORS[name_or_path]


def compute_probability(
    image,
    labels=("cats", "dogs", "birds"),
    *,
    logit_scale=100.0,
    clip_model=None,
    text_tokenizer=None,
    image_processor=None,
):
    """Zero-shot classify ``image`` against ``labels`` with CLIP.

    The image and every label are embedded and the embeddings are
    L2-normalised. A softmax over ``logit_scale * cosine_similarity`` then
    gives a probability distribution over the labels.

    Args:
        image: The uploaded image, in any form ``image_processor`` accepts.
        labels: Candidate text labels; must be non-empty.
        logit_scale: Multiplier applied to the cosine similarities before the
            softmax (default 100.0).
        clip_model: Object exposing ``get_image_features`` and
            ``get_text_features``; defaults to the module-level ``model``.
        text_tokenizer: Tokenizer used for ``labels``; defaults to the
            module-level ``tokenizer``.
        image_processor: Callable turning ``image`` into ``pixel_values``;
            defaults to the image processor of ``clip_model``'s checkpoint.

    Returns:
        ``{"probability": p}``, where ``p`` is a float: the probability that
        the image matches the first label.

    Raises:
        ValueError: If no image was provided or ``labels`` is empty.
    """
    if image is None:
        raise ValueError("No image provided.")
    labels = list(labels)
    if not labels:
        raise ValueError("At least one label is required.")

    clip_model = model if clip_model is None else clip_model
    text_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
    if image_processor is None:
        image_processor = _load_image_processor(clip_model.name_or_path)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        pixel_values = image_processor(images=image, return_tensors="pt")[
            "pixel_values"
        ]
        image_embeds = clip_model.get_image_features(pixel_values=pixel_values)

        text_inputs = text_tokenizer(
            labels, padding="max_length", truncation=True, return_tensors="pt"
        )
        # Forward only the tensors the text tower takes; some tokenizers also
        # emit extra keys such as token_type_ids.
        text_embeds = clip_model.get_text_features(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs.get("attention_mask"),
        )

    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    # (num_images, num_labels) matrix of scaled cosine similarities.
    logits = logit_scale * image_embeds @ text_embeds.T
    probabilities = logits.softmax(dim=-1)
    return {"probability": probabilities[0, 0].item()}
# Build the web UI: one image upload feeds compute_probability, and the
# returned value is rendered in a plain text box.
iface = gr.Interface(
    compute_probability,
    "image",
    "text",
    title="CLIP Probability",
    description="Upload an image and get the probability scores!",
)

# Start the Gradio server; this runs as soon as the module is executed.
iface.launch()