tasmiachow commited on
Commit
5bf9861
·
verified ·
1 Parent(s): bc20f8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -9,26 +9,25 @@ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
9
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
10
 
11
  # Define a list of target words for the game
12
- words = ["cat", "car", "tree", "house", "dog"]
13
-
14
 
 
15
  text_inputs = processor(text=words, return_tensors="pt", padding=True)
16
  with torch.no_grad():
17
  text_features = model.get_text_features(**text_inputs)
18
 
19
-
20
  def guess_drawing(drawing):
21
-
22
- drawing_data = drawing['data']
23
- image_array = np.array(drawing_data, dtype=np.uint8)
24
-
25
-
26
  image = Image.fromarray(image_array)
27
 
28
-
29
  image_inputs = processor(images=image, return_tensors="pt")
30
 
31
-
32
  with torch.no_grad():
33
  image_features = model.get_image_features(**image_inputs)
34
 
@@ -49,4 +48,3 @@ interface = gr.Interface(
49
  )
50
 
51
  interface.launch()
52
-
 
9
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
10
 
11
  # Define a list of target words for the game
12
+ words = ["cat", "car", "tree", "house", "dog"] # Add more words as needed
 
13
 
14
+ # Precompute text embeddings for faster comparisons
15
  text_inputs = processor(text=words, return_tensors="pt", padding=True)
16
  with torch.no_grad():
17
  text_features = model.get_text_features(**text_inputs)
18
 
19
+ # Define the function to process drawing and make a prediction
20
  def guess_drawing(drawing):
21
+ # Assuming `drawing` is provided as an RGB or grayscale array
22
+ image_array = np.array(drawing, dtype=np.uint8) # Directly convert it to a NumPy array
23
+
24
+ # Convert to PIL image
 
25
  image = Image.fromarray(image_array)
26
 
27
+ # Prepare the image for the model
28
  image_inputs = processor(images=image, return_tensors="pt")
29
 
30
+ # Get image features from the model
31
  with torch.no_grad():
32
  image_features = model.get_image_features(**image_inputs)
33
 
 
48
  )
49
 
50
  interface.launch()