pictionary / app.py
tasmiachow's picture
Update app.py
5bf9861 verified
raw
history blame
1.67 kB
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
import torch
# Pretrained CLIP checkpoint shared by the model and its matching processor.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Load the CLIP model and its processor once, at import time.
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Vocabulary the AI picks its guess from; extend this list for a harder game.
words = ["cat", "car", "tree", "house", "dog"]

# Embed every candidate word once up front, so each guess only needs to
# embed the drawing itself.
text_inputs = processor(text=words, return_tensors="pt", padding=True)
with torch.no_grad():  # inference only: no gradients needed
    text_features = model.get_text_features(**text_inputs)
# Message returned while there is nothing on the canvas to classify.
_EMPTY_CANVAS_MESSAGE = "Draw something for the AI to guess!"


def _drawing_to_array(drawing):
    """Convert a Sketchpad value to a uint8 image array, or None if blank.

    Accepts either a bare array-like image or a dict such as the one newer
    Gradio releases send (``{"background", "layers", "composite"}``). For a
    dict, the flattened ``"composite"`` image is used. RGBA input is
    alpha-composited onto a white background. Otherwise, dropping the alpha
    channel later would leave transparent canvas pixels at their stored RGB
    value (typically black), which hides dark strokes.

    Returns None when there is nothing to classify: no image at all, an
    empty array, or an image whose pixels are all identical.
    """
    if isinstance(drawing, dict):
        drawing = drawing.get("composite")
    if drawing is None:
        return None
    image_array = np.array(drawing, dtype=np.uint8)
    if image_array.ndim == 3 and image_array.shape[-1] == 4:
        alpha = image_array[..., 3:].astype(np.float32) / 255.0
        rgb = image_array[..., :3].astype(np.float32)
        image_array = np.round(rgb * alpha + 255.0 * (1.0 - alpha)).astype(np.uint8)
    # A uniform image means nothing has been drawn yet.
    if image_array.size == 0 or image_array.min() == image_array.max():
        return None
    return image_array


def guess_drawing(drawing):
    """Return the AI's best guess for a sketch, formatted for display.

    Args:
        drawing: value produced by ``gr.Sketchpad``. This is either an image
            array, or a dict whose ``"composite"`` entry holds the flattened
            image. May be None while the canvas is empty.

    Returns:
        ``"AI's guess: <word>"``, where ``<word>`` is the entry in ``words``
        whose CLIP text embedding is most cosine-similar to the drawing's
        CLIP image embedding. Returns a prompt to start drawing when the
        canvas is blank.
    """
    image_array = _drawing_to_array(drawing)
    # With live=True this callback also fires for an empty or cleared canvas;
    # don't ask CLIP to label a blank image.
    if image_array is None:
        return _EMPTY_CANVAS_MESSAGE
    image = Image.fromarray(image_array)
    image_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)
    # Compare the single image embedding against every precomputed word
    # embedding: one similarity score per word.
    similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
    best_match = words[similarity.argmax().item()]
    return f"AI's guess: {best_match}"
# Build the Gradio UI: a sketchpad goes in, the AI's guess comes out as text.
# live=True makes Gradio re-run guess_drawing whenever the sketch changes.
sketch_input = gr.Sketchpad()
interface = gr.Interface(
    fn=guess_drawing,
    inputs=sketch_input,
    outputs="text",
    live=True,
    description="Draw something and see if the AI can guess it!",
)
interface.launch()