Update app.py
app.py CHANGED
@@ -9,26 +9,25 @@ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Define a list of target words for the game
-words = ["cat", "car", "tree", "house", "dog"]
-
+words = ["cat", "car", "tree", "house", "dog"]  # Add more words as needed
 
+# Precompute text embeddings for faster comparisons
 text_inputs = processor(text=words, return_tensors="pt", padding=True)
 with torch.no_grad():
     text_features = model.get_text_features(**text_inputs)
 
-
+# Define the function to process drawing and make a prediction
 def guess_drawing(drawing):
-
-
-
-
-
+    # Assuming `drawing` is provided as an RGB or grayscale array
+    image_array = np.array(drawing, dtype=np.uint8)  # Directly convert it to a NumPy array
+
+    # Convert to PIL image
     image = Image.fromarray(image_array)
 
-
+    # Prepare the image for the model
     image_inputs = processor(images=image, return_tensors="pt")
 
-
+    # Get image features from the model
     with torch.no_grad():
         image_features = model.get_image_features(**image_inputs)
 
@@ -49,4 +48,3 @@ interface = gr.Interface(
 )
 
 interface.launch()
-
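Between the two hunks the viewer elides lines 34-47, where guess_drawing presumably compares image_features against the precomputed text_features and returns the best-matching word. A minimal sketch of that comparison, assuming the standard CLIP recipe of cosine similarity over L2-normalized embeddings; the helper name rank_words is hypothetical, not taken from the app:

import torch

def rank_words(image_features, text_features, words):
    # Hypothetical helper: normalize both embeddings so the dot
    # product equals cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # (1, N) similarity matrix: one score per candidate word
    similarity = image_features @ text_features.T
    best = similarity.squeeze(0).argmax().item()
    return words[best]

Precomputing text_features once at startup, as the diff does, keeps the per-guess work down to a single small matrix multiply.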
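The second hunk shows only the tail of the gr.Interface(...) call, so its arguments are not recoverable from the diff. A hypothetical wiring consistent with a drawing-guessing game, assuming a Sketchpad input and a Label output:

import gradio as gr

# Hypothetical arguments; the real call is not visible in the diff
interface = gr.Interface(
    fn=guess_drawing,       # defined above
    inputs=gr.Sketchpad(),  # free-hand drawing canvas
    outputs=gr.Label(),     # predicted word
)

interface.launch()

Note that depending on the Gradio version, Sketchpad may hand the function a dict of layers rather than a raw array, which would require extracting the composite image before the np.array conversion.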