RobotJelly committed on
Commit
fb23d7b
·
1 Parent(s): 547056a
Files changed (1)
  1. app.py +7 -24
app.py CHANGED
@@ -13,11 +13,12 @@ from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load the openAI's CLIP model
-#model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
 
+model = model.to(device)
+
 # taking photo IDs
 photo_ids = pd.read_csv("./photo_ids.csv")
 photo_ids = list(photo_ids['photo_id'])
@@ -32,32 +33,18 @@ def show_output_image(matched_images) :
     image=[]
     for photo_id in matched_images:
         photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
-        #photo_image_url = f"https://unsplash.com/photos/{photo_id}?w=640"
-        #photo_image_url = f"https://unsplash.com/photos/{photo_id}?ixid=2yJhcHBfaWQiOjEyMDd9&fm=jpg"
-        #photo_found = photos[photos["photo_id"] == photo_id].iloc[0]
-        #response = requests.get(photo_found["photo_image_url"] + "?w=640")
         response = requests.get(photo_image_url, stream=True)
         img = Image.open(BytesIO(response.content))
-        #return img
-        #photo_jpg = photo_id + '.jpg'
-        #image_path = './photos/'
-        #img = Image.open('./photos/'+photo_jpg)
         image.append(img)
     return image
 
 # Encode and normalize the search query using CLIP
-def encode_search_query(search_query, model, device):
+def encode_search_query(search_query, model):
     with torch.no_grad():
-        inputs = tokenizer([search_query], padding=True, return_tensors="pt")
-        #text_encoded = model.encode_text(clip.tokenize(search_query).to(device))
-        #text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-        # Retrieve the feature vector from the GPU and convert it to a numpy array
-        #text_features = model.get_text_features(**inputs).detach().numpy()
-        #text_features = model.get_text_features(**inputs).cpu().numpy()
+        #inputs = tokenizer([search_query], padding=True, return_tensors="pt")
+        inputs = processor(text=[search_query], images=None, return_tensors="pt", padding=True)
         text_features = model.get_text_features(**inputs).detach().numpy()
-    return np.array(text_features)
-    #return text_features
-    #return text_encoded.cpu().numpy()
+    return text_features
 
 # Find all matched photos
 def find_matches(text_features, photo_features, photo_ids, results_count=4):
@@ -70,14 +57,12 @@ def find_matches(text_features, photo_features, photo_ids, results_count=4):
 
 def image_search(search_text, search_image, option):
 
-    #model = model.to(device)
-
     # Input Text Query
     #search_query = "The feeling when your program finally works"
 
     if option == "Text-To-Image" :
         # Extracting text features
-        text_features = encode_search_query(search_text, model, device)
+        text_features = encode_search_query(search_text, model)
 
         # Find the matched Images
         matched_images = find_matches(text_features, photo_features, photo_ids, 4)
@@ -89,11 +74,9 @@ def image_search(search_text, search_image, option):
         processed_image = processor(text=None, images=search_image, return_tensors="pt", padding=True)["pixel_values"]
         image_feature = model.get_image_features(processed_image.to(device))
         image_feature /= image_feature.norm(dim=-1, keepdim=True)
-        #image_feature = image_feature.cpu().numpy()
         image_feature = image_feature.detach().numpy()
         # Find the matched Images
         matched_images = find_matches(image_feature, photo_features, photo_ids, 4)
-        #is_input_image = True
         return show_output_image(matched_images)
 
 gr.Interface(fn=image_search,
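
For reference, a minimal standalone sketch of the text-to-image query path after this change. It assumes the same openai/clip-vit-base-patch32 checkpoint; the encode_text helper, the "./features.npy" path, the sample query, and the cosine-similarity ranking are illustrative stand-ins, not the repository's actual find_matches implementation.

import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def encode_text(query):
    # Tokenize the query with the processor (as in this commit) and keep tensors
    # on the same device as the model.
    inputs = processor(text=[query], images=None, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    # Normalize so a dot product against normalized photo features is a cosine similarity.
    features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy()

# Illustrative ranking against a hypothetical precomputed, normalized feature matrix
# of shape [num_photos, 512], aligned row-for-row with photo_ids.csv.
photo_features = np.load("./features.npy")
text_features = encode_text("two dogs playing in the snow")
scores = (photo_features @ text_features.T).squeeze(1)
best = np.argsort(-scores)[:4]  # indices of the top-4 matching photos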