ibrim committed on
Commit 1878377 · verified · 1 Parent(s): 3b6db54

Update app.py

Files changed (1): app.py (+92 -18)
app.py CHANGED
@@ -1,3 +1,39 @@
+import os
+import zipfile
+
+# Define the filename
+zip_filename = 'Images.zip'
+
+# Get the current directory path
+current_directory = os.getcwd()
+print(f"Current directory: {current_directory}")
+
+# Append a custom string to the current directory path (for demonstration)
+custom_directory = os.path.join(current_directory, 'UnzippedContent')
+
+# Ensure the custom directory exists
+os.makedirs(custom_directory, exist_ok=True)
+
+# Print the contents of the current directory before unzipping
+print(f"Contents of current directory before unzipping: {os.listdir(current_directory)}")
+
+# Check if the zip file exists in the current directory
+zip_file_path = os.path.join(current_directory, zip_filename)
+if os.path.isfile(zip_file_path):
+    # Open the zip file
+    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+        # Extract all contents of the zip file to the custom directory
+        zip_ref.extractall(custom_directory)
+    print(f"'{zip_filename}' has been successfully unzipped to '{custom_directory}'.")
+
+    # Print the contents of the custom directory after unzipping
+    print(f"Contents of '{custom_directory}': {os.listdir(custom_directory)}")
+else:
+    print(f"'{zip_filename}' not found in the current directory.")
+
+# Print the contents of the current directory after unzipping
+print(f"Contents of current directory after unzipping: {os.listdir(current_directory)}")
+
 import gradio as gr
 import gc
 import cv2
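The extraction block added above condenses into a small reusable helper. A minimal sketch, not part of the commit itself; "unzip_to" is a hypothetical name, while "Images.zip" and "UnzippedContent" are the commit's own values:

import os
import zipfile

def unzip_to(zip_name: str, dest_subdir: str = 'UnzippedContent'):
    """Extract zip_name (resolved against the CWD) into dest_subdir; return its path or None."""
    zip_path = os.path.join(os.getcwd(), zip_name)
    if not os.path.isfile(zip_path):
        print(f"'{zip_name}' not found in the current directory.")
        return None
    dest = os.path.join(os.getcwd(), dest_subdir)
    os.makedirs(dest, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest)
    print(f"'{zip_name}' extracted to '{dest}'.")
    return dest

unzip_to('Images.zip')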
@@ -11,30 +47,68 @@ from implement import *
 # from main import build_loaders
 # from CLIP import CLIPModel
 import os
-import zipfile
-
-# Define the filename
-zip_filename = 'Images.zip'
-
-import os
-import zipfile
 
+
+
 with gr.Blocks(css="style.css") as demo:
-    # Define the filename
-    zip_filename = 'Images.zip'
-
-    # Check if the file exists
-    if os.path.isfile(zip_filename):
-        # Open the zip file
-        with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
-            # Extract all contents of the zip file to the current directory
-            zip_ref.extractall()
-        print(f"'{zip_filename}' has been successfully unzipped.")
-    else:
-        print(f"'{zip_filename}' not found in the current directory.")
+    def get_image_embeddings(valid_df, model_path):
+        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+        valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+
+        model = CLIPModel().to(CFG.device)
+        model.load_state_dict(torch.load(model_path, map_location=CFG.device))
+        model.eval()
+
+        valid_image_embeddings = []
+        with torch.no_grad():
+            for batch in tqdm(valid_loader):
+                image_features = model.image_encoder(batch["image"].to(CFG.device))
+                image_embeddings = model.image_projection(image_features)
+                valid_image_embeddings.append(image_embeddings)
+        return model, torch.cat(valid_image_embeddings)
+
+    _, valid_df = make_train_valid_dfs()
+    model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
+
+    def find_matches(query, n=9):
+        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+        encoded_query = tokenizer([query])
+        batch = {
+            key: torch.tensor(values).to(CFG.device)
+            for key, values in encoded_query.items()
+        }
+        with torch.no_grad():
+            text_features = model.text_encoder(
+                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
+            )
+            text_embeddings = model.text_projection(text_features)
+
+        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
+        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
+        dot_similarity = text_embeddings_n @ image_embeddings_n.T
+
+        # Each image appears once per caption, so take the n * 5 best hits and keep every 5th.
+        _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
+        matches = [valid_df['image'].values[idx] for idx in indices[::5]]
+
+        images = []
+        for match in matches:
+            image = cv2.imread(f"{CFG.image_path}/{match}")
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            images.append(image)
+
+        # The output component below is a single image, so return the top match.
+        return images[0]
+
+    with gr.Row():
+        textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
+        image = gr.Image(type="numpy")
+
+    button = gr.Button("Press")
+    button.click(
+        fn=find_matches,
+        inputs=textbox,
+        outputs=image
+    )
 
 # Create Gradio interface
 demo.launch(share=True)
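The retrieval step in find_matches is plain cosine similarity over L2-normalized embeddings. A self-contained sketch with dummy tensors; the shapes (50 embeddings of width 256, i.e. 10 images x 5 captions each) are illustrative assumptions, not values from the commit:

import torch
import torch.nn.functional as F

# Dummy stand-ins for the precomputed image embeddings and one encoded query.
image_embeddings = torch.randn(50, 256)
text_embeddings = torch.randn(1, 256)

# L2-normalize so the dot product equals cosine similarity.
image_n = F.normalize(image_embeddings, p=2, dim=-1)
text_n = F.normalize(text_embeddings, p=2, dim=-1)
dot_similarity = text_n @ image_n.T          # shape (1, 50)

# Take the n * 5 best rows, then every 5th, mirroring the dedupe in find_matches.
n = 9
_, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
unique_indices = indices[::5]                # 9 candidate image indices
print(unique_indices.tolist())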
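find_matches computes n matches but the interface shows only the first, so a gr.Gallery output is one way to surface all of them. A hypothetical variant of the commit's UI block, assuming find_matches were changed to return the whole images list:

import gradio as gr

# Variant, not in the commit: find_matches here is the function defined in
# app.py, modified to `return images` (a list of RGB arrays) instead of images[0].
with gr.Blocks(css="style.css") as demo:
    textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
    gallery = gr.Gallery(label="Matches")
    gr.Button("Press").click(fn=find_matches, inputs=textbox, outputs=gallery)

demo.launch(share=True)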