Commit 087fe06
1 Parent(s): 05170c1
app.py
app.py
ADDED
@@ -0,0 +1,89 @@
# Import libraries
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import clip
from PIL import Image
from io import BytesIO
import requests
import gradio as gr

# Load OpenAI's CLIP model
model, preprocess = clip.load("ViT-B/32", jit=False)

# Display the output photos
def show_output_image(matched_images):
    images = []
    for photo_id in matched_images:
        photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
        #photo_image_url = f"https://unsplash.com/photos/{photo_id}?w=280"
        response = requests.get(photo_image_url)
        img = Image.open(BytesIO(response.content))
        images.append(img)
    return images

# Encode and normalize the search query using CLIP
def encode_search_query(search_query, model, device):
    with torch.no_grad():
        text_encoded = model.encode_text(clip.tokenize(search_query).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    # Retrieve the feature vector from the GPU and convert it to a numpy array
    return text_encoded.cpu().numpy()

# Find all matched photos
def find_matches(text_features, photo_features, photo_ids, results_count=4):
    # Compute the similarity between the search query and each photo using cosine similarity
    similarities = (photo_features @ text_features.T).squeeze(1)
    # Sort the photos by their similarity score
    best_photo_idx = (-similarities).argsort()
    # Return the photo IDs of the best matches
    return [photo_ids[i] for i in best_photo_idx[:results_count]]
def image_search(search_text, search_image, option):
    # Load the photo IDs
    photo_ids = pd.read_csv("./photo_ids.csv")
    photo_ids = list(photo_ids['photo_id'])

    # Load the precomputed feature vectors
    photo_features = np.load("./features.npy")

    # Check if CUDA is available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Move the globally loaded CLIP model to the selected device
    #model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
    model.to(device)

    # Example text query
    #search_query = "The feeling when your program finally works"

    if option == "Text-To-Image":
        # Extract the text features
        text_features = encode_search_query(search_text, model, device)

        # Find the matched images
        matched_images = find_matches(text_features, photo_features, photo_ids, 4)

        # ---- debug purpose ------#
        print(matched_images[0])
        photo_id = matched_images[0]
        photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
        print(photo_image_url)
        # -------------------------#

        return show_output_image(matched_images)
    elif option == "Image-To-Image":
        # Encode the input image for the search
        with torch.no_grad():
            image_feature = model.encode_image(preprocess(search_image).unsqueeze(0).to(device))
            image_feature = (image_feature / image_feature.norm(dim=-1, keepdim=True)).cpu().numpy()
        # Find the matched images
        matched_images = find_matches(image_feature, photo_features, photo_ids, 4)
        images = show_output_image(matched_images)
        return images

gr.Interface(fn=image_search,
             inputs=[gr.inputs.Textbox(lines=7, label="Input Text"),
                     gr.inputs.Image(type="pil", optional=True),
                     gr.inputs.Dropdown(["Text-To-Image", "Image-To-Image"])
                    ],
             outputs=gr.outputs.Carousel([gr.outputs.Image(type="pil"),
                                          gr.outputs.Image(type="pil"),
                                          gr.outputs.Image(type="pil"),
                                          gr.outputs.Image(type="pil")]),
             enable_queue=True
             ).launch(debug=True)
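The app loads two precomputed artifacts, ./photo_ids.csv and ./features.npy, that this commit does not create. Below is a minimal sketch of how such files could be produced with the same CLIP model; the ./photos folder, the one-JPEG-per-photo-ID naming, and the one-image-at-a-time loop are illustrative assumptions, not part of this repo.

# Illustrative precompute step (not part of this commit): encode a local photo
# collection with CLIP and save the IDs and normalized features the app expects.
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

photos_dir = Path("./photos")  # assumed layout: one <photo_id>.jpg per photo
photo_paths = sorted(photos_dir.glob("*.jpg"))

photo_ids, features = [], []
with torch.no_grad():
    for path in photo_paths:
        image = preprocess(Image.open(path).convert("RGB")).unsqueeze(0).to(device)
        feature = model.encode_image(image)
        feature /= feature.norm(dim=-1, keepdim=True)  # L2-normalize each feature
        photo_ids.append(path.stem)
        features.append(feature.cpu().numpy())

pd.DataFrame({"photo_id": photo_ids}).to_csv("./photo_ids.csv", index=False)
np.save("./features.npy", np.concatenate(features, axis=0))

Normalizing the features here matters: with unit-length vectors on both sides, the dot product in find_matches behaves as cosine similarity, which is what the ranking assumes.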