File size: 3,924 Bytes
087fe06
 
 
 
 
 
 
 
 
06ffe20
e59258f
ffaf3d7
b564a70
90348c5
 
 
 
e59258f
 
 
90348c5
b564a70
 
 
 
 
 
 
 
 
 
37a1406
 
c36f360
087fe06
 
74888a6
ffaf3d7
 
 
 
60b18f5
37a1406
60b18f5
087fe06
 
e59258f
087fe06
87c1220
087fe06
fde2555
 
147b3ce
3c7ac16
e59258f
087fe06
147b3ce
087fe06
147b3ce
 
087fe06
 
 
e9213d0
ea1e3b9
087fe06
b564a70
087fe06
 
 
 
 
 
87c1220
087fe06
 
 
547056a
c36f360
087fe06
27b8fd9
ffaf3d7
3c7ac16
27b8fd9
 
 
 
147b3ce
087fe06
147b3ce
27b8fd9
087fe06
 
 
 
 
 
 
 
a8854b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Import Libraries
from pathlib import Path
import pandas as pd
import numpy as np
import torch
from PIL import Image
from io import BytesIO
import requests
import gradio as gr
import os
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import urllib.request

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# All three CLIP components come from the same OpenAI ViT-B/32 checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
tokenizer = CLIPTokenizer.from_pretrained(_CLIP_CHECKPOINT)
# NOTE(review): `model` is never moved to `device` here, so it stays on the CPU.

# Photo IDs; find_matches maps row indices of the feature matrix straight
# back into this list, so the row order must match features.npy.
photo_ids = list(pd.read_csv("./photo_ids.csv")['photo_id'])

# Unsplash photo metadata (tab-separated); not read by the code below.
photos = pd.read_csv("./photos.tsv000", sep="\t", header=0)

# Precomputed photo embeddings (presumably one row per photo ID - confirm).
photo_features = np.load("./features.npy")

# Directory for locally stored photos; not read by the code below.
IMAGES_DIR = './photos'

def show_output_image(matched_images, timeout=30):
    """Download the Unsplash photos with the given IDs as PIL images.

    Args:
        matched_images: Iterable of Unsplash photo IDs.
        timeout: Seconds to wait on each HTTP request before giving up
            (forwarded to ``requests.get``).

    Returns:
        List of ``PIL.Image.Image`` objects, one per ID, in input order.

    Raises:
        requests.HTTPError: If Unsplash answers with an error status.
    """
    images = []
    for photo_id in matched_images:
        photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
        # Without a timeout a stalled download blocks the request forever.
        response = requests.get(photo_image_url, timeout=timeout)
        # Report the HTTP failure instead of a confusing PIL decode error.
        response.raise_for_status()
        # Buffer the whole body: decoding lazily from the raw socket stream
        # skipped requests' content decoding and held the connection open.
        img = Image.open(BytesIO(response.content))
        img.load()  # decode now so corrupt data fails here, not in the UI
        images.append(img)
    return images
   
# Encode and normalize the search query using CLIP
def encode_search_query(search_query, model, device):
    """Encode a text query into an L2-normalised CLIP text embedding.

    Args:
        search_query: Query string to embed.
        model: CLIP model providing ``get_text_features``.
        device: Kept for backward compatibility. Token tensors are placed on
            the device holding the model's parameters, so the call works
            whether or not the model was moved to ``device``.

    Returns:
        NumPy array with one row per query (a single row here), each row
        scaled to unit L2 norm.
    """
    # Inputs must live on the same device as the model's weights.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        inputs = tokenizer([search_query], padding=True, return_tensors="pt")
        inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}
        # The forward pass must also run under no_grad: otherwise the output
        # requires grad and .numpy() raises a RuntimeError.
        text_features = model.get_text_features(**inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()
        
def find_matches(features, photo_ids, results_count=4, *, candidate_features=None):
    """Return the IDs of the photos most similar to a query embedding.

    Similarity is the dot product between the query and every candidate
    embedding, which equals cosine similarity when both are L2-normalised.

    Args:
        features: Query embedding, shape ``(dim,)`` or ``(1, dim)``.
        photo_ids: Sequence of photo IDs aligned with the rows of the
            candidate feature matrix.
        results_count: Maximum number of IDs to return.
        candidate_features: Optional ``(num_photos, dim)`` matrix to search;
            defaults to the module-level ``photo_features``.

    Returns:
        List of up to ``results_count`` IDs, best match first. Equal scores
        keep their original row order.
    """
    if candidate_features is None:
        candidate_features = photo_features
    # Flatten so both a single row and a plain vector work as the query.
    query = np.asarray(features).reshape(-1)
    similarities = candidate_features @ query
    # Stable sort on the negated scores: descending order, deterministic ties.
    best_photo_idx = np.argsort(-similarities, kind="stable")[:results_count]
    return [photo_ids[i] for i in best_photo_idx]
  
def _encode_search_image(search_image):
    """Return the L2-normalised CLIP embedding of a PIL query image as NumPy."""
    # Send pixels to wherever the model's weights live, not to the global
    # `device`: the model is not guaranteed to have been moved there.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        pixel_values = processor(text=None, images=search_image, return_tensors="pt", padding=True)["pixel_values"]
        image_feature = model.get_image_features(pixel_values.to(model_device))
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
    return image_feature.cpu().numpy()


def image_search(search_text, search_image, option):
    """Gradio callback: find the 4 best-matching photos for a text or image query.

    Args:
        search_text: Query text, used when ``option`` is "Text-To-Image".
        search_image: PIL query image, used when ``option`` is "Image-To-Image".
        option: Search mode chosen in the dropdown.

    Returns:
        List of downloaded PIL images for the best matches, best first.

    Raises:
        ValueError: If no image is given for image search, or the option is
            not one of the two supported modes.
    """
    if option == "Text-To-Image":
        query_features = encode_search_query(search_text, model, device)
    elif option == "Image-To-Image":
        if search_image is None:
            raise ValueError("Image-To-Image search needs an input image.")
        query_features = _encode_search_image(search_image)
    else:
        raise ValueError(f"Unknown search option: {option!r}")

    # find_matches takes (features, photo_ids, results_count); the old text
    # branch also passed photo_features positionally and raised a TypeError.
    matched_images = find_matches(query_features, photo_ids, 4)
    return show_output_image(matched_images)
  
# UI wiring: query text, an optional query image and the search-mode
# dropdown feed image_search; four result images are shown in a carousel.
search_inputs = [
    gr.inputs.Textbox(lines=7, label="Input Text"),
    gr.inputs.Image(type="pil", optional=True),
    gr.inputs.Dropdown(["Text-To-Image", "Image-To-Image"]),
]
result_carousel = gr.outputs.Carousel(
    [gr.outputs.Image(type="pil") for _ in range(4)]
)
demo = gr.Interface(
    fn=image_search,
    inputs=search_inputs,
    outputs=result_carousel,
    enable_queue=True,
)
demo.launch(debug=True, share=True)