# Hugging Face Space (status: Running)
# Third-party dependencies: Gradio for the UI, Transformers for the CLIP model,
# Pillow/NumPy for converting the sketch into an image, PyTorch for tensor math.
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
import torch

# Load the pretrained CLIP model and its matching pre-processor (same checkpoint).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Candidate answers the AI chooses between; extend this list to add more.
words = [
    "cat",
    "car",
    "tree",
    "house",
    "dog",
]

# Embed every candidate word once at startup, so each guess only needs an image pass.
text_inputs = processor(text=words, padding=True, return_tensors="pt")
with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)
def _to_pil_image(drawing):
    """Normalize a Sketchpad value to an RGB PIL image, or return None if it is empty.

    Accepts either a bare image array or an editor dict carrying a ``"composite"``
    entry. Returns None when there is no image to classify.
    """
    if isinstance(drawing, dict):
        # NOTE(review): newer Gradio Sketchpad versions send an editor dict
        # ({"background", "layers", "composite"}); "composite" is the flattened
        # canvas. Confirm against the Gradio version this Space pins.
        drawing = drawing.get("composite")
    if drawing is None:
        # Live mode can call the handler with no image, e.g. after a canvas clear.
        return None
    image_array = np.asarray(drawing, dtype=np.uint8)
    if image_array.size == 0:
        return None
    image = Image.fromarray(image_array)
    if image.mode in ("RGBA", "LA"):
        # A plain RGBA -> RGB conversion drops alpha, so transparent pixels keep
        # whatever colour they store (often black) and can hide black strokes.
        # Flatten onto a white background instead.
        rgba = image.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, rgba)
    return image.convert("RGB")


# Define the function to process drawing and make a prediction
def guess_drawing(drawing):
    """Return CLIP's best guess, chosen from ``words``, for what the sketch shows.

    Args:
        drawing: The value produced by ``gr.Sketchpad``, either an image array or
            an editor dict with a ``"composite"`` image. May be None.

    Returns:
        ``"AI's guess: <word>"`` for the word whose text embedding is most
        cosine-similar to the sketch's image embedding. If there is no image,
        returns a prompt asking the user to draw.
    """
    image = _to_pil_image(drawing)
    if image is None:
        return "Draw something for the AI to guess!"

    image_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)
        # The single-image embedding row broadcasts against one text-embedding
        # row per word, yielding one similarity score per word.
        similarity = torch.nn.functional.cosine_similarity(image_features, text_features)

    best_match = words[similarity.argmax().item()]
    return f"AI's guess: {best_match}"
# Wire up the web UI: a sketch canvas goes in, the AI's text guess comes out,
# and live mode re-runs the guess whenever the drawing changes.
interface = gr.Interface(
    guess_drawing,
    gr.Sketchpad(),
    "text",
    description="Draw something and see if the AI can guess it!",
    live=True,
)
interface.launch()