Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ from decimal import Decimal # Import the Decimal module
|
|
9 |
|
10 |
# Load the CLIP model
|
11 |
model, preprocess = clip.load("ViT-B/32")
|
12 |
-
device = "cuda" if torch.cuda.
|
13 |
model.to(device).eval()
|
14 |
|
15 |
# Define a function to find similarity
|
@@ -18,7 +18,7 @@ def find_similarity(base64_image, text_input):
|
|
18 |
image_bytes = base64.b64decode(base64_image)
|
19 |
|
20 |
# Convert the bytes to a PIL image
|
21 |
-
image = Image.open(BytesIO(image_bytes)
|
22 |
|
23 |
# Preprocess the image
|
24 |
image = preprocess(image).unsqueeze(0).to(device)
|
|
|
9 |
|
10 |
# Load the CLIP model
|
11 |
model, preprocess = clip.load("ViT-B/32")
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
model.to(device).eval()
|
14 |
|
15 |
# Define a function to find similarity
|
|
|
18 |
image_bytes = base64.b64decode(base64_image)
|
19 |
|
20 |
# Convert the bytes to a PIL image
|
21 |
+
image = Image.open(BytesIO(image_bytes))
|
22 |
|
23 |
# Preprocess the image
|
24 |
image = preprocess(image).unsqueeze(0).to(device)
|