ktllc commited on
Commit
625973c
·
1 Parent(s): 15470aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -5,10 +5,11 @@ import torch
5
  from PIL import Image
6
  import base64
7
  from io import BytesIO
 
8
 
9
  # Load the CLIP model
10
  model, preprocess = clip.load("ViT-B/32")
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model.to(device).eval()
13
 
14
  # Define a function to find similarity
@@ -17,7 +18,7 @@ def find_similarity(base64_image, text_input):
17
  image_bytes = base64.b64decode(base64_image)
18
 
19
  # Convert the bytes to a PIL image
20
- image = Image.open(BytesIO(image_bytes))
21
 
22
  # Preprocess the image
23
  image = preprocess(image).unsqueeze(0).to(device)
@@ -33,7 +34,10 @@ def find_similarity(base64_image, text_input):
33
  # Calculate cosine similarity
34
  similarity = (image_features @ text_features.T).squeeze(0).cpu().numpy()
35
 
36
- return similarity[0,0]
 
 
 
37
 
38
  # Create a Gradio interface
39
  iface = gr.Interface(
@@ -42,7 +46,7 @@ iface = gr.Interface(
42
  gr.inputs.Textbox(label="Base64 Image", lines=8),
43
  "text"
44
  ],
45
- outputs="number",
46
  live=True,
47
  interpretation="default",
48
  title="CLIP Model Image-Text Cosine Similarity",
 
5
  from PIL import Image
6
  import base64
7
  from io import BytesIO
8
+ from decimal import Decimal # Import the Decimal module
9
 
10
  # Load the CLIP model
11
  model, preprocess = clip.load("ViT-B/32")
12
+ device = "cuda" if torch.cuda.isavailable() else "cpu"
13
  model.to(device).eval()
14
 
15
  # Define a function to find similarity
 
18
  image_bytes = base64.b64decode(base64_image)
19
 
20
  # Convert the bytes to a PIL image
21
+ image = Image.open(BytesIO(image_bytes)
22
 
23
  # Preprocess the image
24
  image = preprocess(image).unsqueeze(0).to(device)
 
34
  # Calculate cosine similarity
35
  similarity = (image_features @ text_features.T).squeeze(0).cpu().numpy()
36
 
37
+ # Convert the similarity score to a Decimal
38
+ similarity_decimal = Decimal(similarity)
39
+
40
+ return similarity_decimal # Return the similarity score as a Decimal
41
 
42
  # Create a Gradio interface
43
  iface = gr.Interface(
 
46
  gr.inputs.Textbox(label="Base64 Image", lines=8),
47
  "text"
48
  ],
49
+ outputs="text", # Set the output type to "text"
50
  live=True,
51
  interpretation="default",
52
  title="CLIP Model Image-Text Cosine Similarity",