ibrim commited on
Commit
3b6db54
·
verified ·
1 Parent(s): df02a37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -59
app.py CHANGED
@@ -19,12 +19,8 @@ zip_filename = 'Images.zip'
19
  import os
20
  import zipfile
21
 
22
-
23
-
24
-
25
  with gr.Blocks(css="style.css") as demo:
26
- def get_image_embeddings(valid_df, model_path):
27
- # Define the filename
28
  zip_filename = 'Images.zip'
29
 
30
  # Check if the file exists
@@ -36,61 +32,9 @@ with gr.Blocks(css="style.css") as demo:
36
  print(f"'{zip_filename}' has been successfully unzipped.")
37
  else:
38
  print(f"'{zip_filename}' not found in the current directory.")
39
- tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
40
- valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
41
-
42
- model = CLIPModel().to(CFG.device)
43
- model.load_state_dict(torch.load(model_path, map_location=CFG.device))
44
- model.eval()
45
-
46
- valid_image_embeddings = []
47
- with torch.no_grad():
48
- for batch in tqdm(valid_loader):
49
- image_features = model.image_encoder(batch["image"].to(CFG.device))
50
- image_embeddings = model.image_projection(image_features)
51
- valid_image_embeddings.append(image_embeddings)
52
- return model, torch.cat(valid_image_embeddings)
53
-
54
- _, valid_df = make_train_valid_dfs()
55
- model, image_embeddings = get_image_embeddings(valid_df, "best.pt")
56
-
57
- def find_matches(query, n=9):
58
- tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
59
- encoded_query = tokenizer([query])
60
- batch = {
61
- key: torch.tensor(values).to(CFG.device)
62
- for key, values in encoded_query.items()
63
- }
64
- with torch.no_grad():
65
- text_features = model.text_encoder(
66
- input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
67
- )
68
- text_embeddings = model.text_projection(text_features)
69
-
70
- image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
71
- text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
72
- dot_similarity = text_embeddings_n @ image_embeddings_n.T
73
-
74
- _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
75
- matches = [valid_df['image'].values[idx] for idx in indices[::5]]
76
-
77
- images = []
78
- for match in matches:
79
- image = cv2.imread(f"{CFG.image_path}/{match}")
80
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
81
- # images.append(image)
82
-
83
- return image
84
- with gr.Row():
85
- textbox = gr.Textbox(label = "Enter a query to find matching images using a CLIP model.")
86
- image = gr.Image(type="numpy")
87
 
88
- button = gr.Button("Press")
89
- button.click(
90
- fn = find_matches,
91
- inputs=textbox,
92
- outputs=image
93
- )
94
 
95
  # Create Gradio interface
96
  demo.launch(share=True)
 
19
  import os
20
  import zipfile
21
 
 
 
 
22
  with gr.Blocks(css="style.css") as demo:
23
+ # Define the filename
 
24
  zip_filename = 'Images.zip'
25
 
26
  # Check if the file exists
 
32
  print(f"'{zip_filename}' has been successfully unzipped.")
33
  else:
34
  print(f"'{zip_filename}' not found in the current directory.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+
37
+
 
 
 
 
38
 
39
  # Create Gradio interface
40
  demo.launch(share=True)