pain commited on
Commit
5953c71
·
1 Parent(s): cd62611

update the code files, adding flicker 8k dataset

Browse files
Files changed (2) hide show
  1. app.py +18 -13
  2. utils.py +46 -34
app.py CHANGED
@@ -2,13 +2,18 @@ import gradio as gr
2
  import utils
3
 
4
 
 
5
  # Araclip demo
6
  with gr.Blocks() as demo_araclip:
7
 
8
- gr.Markdown("## Input parameters")
9
 
 
 
 
 
10
  txt = gr.Textbox(label="Text Query (Caption)")
11
- num = gr.Slider(label="Number of retrieved image", value=1, minimum=1, maximum=1000)
12
 
13
 
14
  with gr.Row():
@@ -32,20 +37,18 @@ with gr.Blocks() as demo_araclip:
32
  json_output = gr.JSON()
33
 
34
  with gr.Column(scale=1):
35
- # gr.Markdown("### Data Retrieved based on Text similarity")
36
- # gr.Markdown("<div style='text-align: center;'> Data Retrieved based on Text similarity </div>")
37
  gr.Markdown("<div style='text-align: center; font-size: 24px; font-weight: bold;'>Data Retrieved based on Text similarity</div>")
38
  json_text = gr.JSON()
39
 
40
 
41
- btn.click(utils.predict, inputs=[txt, num], outputs=[gallery,lables, json_output, json_text])
42
 
43
 
44
  gr.Examples(
45
  examples=[["تخطي لاعب فريق بيتسبرج بايرتس منطقة اللوحة الرئيسية في مباراة بدوري البيسبول", 5],
46
  ["وقوف قطة بمخالبها على فأرة حاسوب على المكتب", 10],
47
  ["صحن به شوربة صينية بالخضار، وإلى جانبه بطاطس مقلية وزجاجة ماء", 7]],
48
- inputs=[txt, num],
49
  outputs=[gallery,lables, json_output, json_text],
50
  fn=utils.predict,
51
  cache_examples=False,
@@ -54,10 +57,15 @@ with gr.Blocks() as demo_araclip:
54
  # mclip demo
55
  with gr.Blocks() as demo_mclip:
56
 
 
 
 
 
 
57
  gr.Markdown("## Input parameters")
58
 
59
  txt = gr.Textbox(label="Text Query (Caption)")
60
- num = gr.Slider(label="Number of retrieved image", value=1, minimum=1, maximum=1000)
61
 
62
  with gr.Row():
63
  btn = gr.Button("Retrieve images", scale=1)
@@ -81,16 +89,13 @@ with gr.Blocks() as demo_mclip:
81
  gr.Markdown("## Text Retrieved")
82
  json_text = gr.JSON()
83
 
84
- btn.click(utils.predict_mclip, inputs=[txt, num], outputs=[gallery,lables, json_output, json_text])
85
-
86
-
87
-
88
 
89
  gr.Examples(
90
  examples=[["تخطي لاعب فريق بيتسبرج بايرتس منطقة اللوحة الرئيسية في مباراة بدوري البيسبول", 5],
91
  ["وقوف قطة بمخالبها على فأرة حاسوب على المكتب", 10],
92
  ["صحن به شوربة صينية بالخضار، وإلى جانبه بطاطس مقلية وزجاجة ماء", 7]],
93
- inputs=[txt, num],
94
  outputs=[gallery,lables, json_output, json_text],
95
  fn=utils.predict_mclip,
96
  cache_examples=False,
@@ -101,8 +106,8 @@ with gr.Blocks() as demo_mclip:
101
  with gr.Blocks() as demo:
102
 
103
  gr.Markdown("<font color=red size=10><center>AraClip: Arabic Image Retrieval Application</center></font>")
104
- gr.TabbedInterface([demo_araclip, demo_mclip], ["Our Model", "Mclip model"])
105
 
 
106
 
107
 
108
  if __name__ == "__main__":
 
2
  import utils
3
 
4
 
5
+
6
  # Araclip demo
7
  with gr.Blocks() as demo_araclip:
8
 
9
+ gr.Markdown("## Choose the dataset")
10
 
11
+ dadtaset_select = gr.Radio(["XTD dataset", "Flicker 8k dataset"], value="XTD dataset", label="Dataset", info="Which dataset you would like to search in?")
12
+
13
+ gr.Markdown("## Input parameters")
14
+
15
  txt = gr.Textbox(label="Text Query (Caption)")
16
+ num = gr.Slider(label="Number of retrieved image", value=1, minimum=1)
17
 
18
 
19
  with gr.Row():
 
37
  json_output = gr.JSON()
38
 
39
  with gr.Column(scale=1):
 
 
40
  gr.Markdown("<div style='text-align: center; font-size: 24px; font-weight: bold;'>Data Retrieved based on Text similarity</div>")
41
  json_text = gr.JSON()
42
 
43
 
44
+ btn.click(utils.predict, inputs=[txt, num, dadtaset_select], outputs=[gallery,lables, json_output, json_text])
45
 
46
 
47
  gr.Examples(
48
  examples=[["تخطي لاعب فريق بيتسبرج بايرتس منطقة اللوحة الرئيسية في مباراة بدوري البيسبول", 5],
49
  ["وقوف قطة بمخالبها على فأرة حاسوب على المكتب", 10],
50
  ["صحن به شوربة صينية بالخضار، وإلى جانبه بطاطس مقلية وزجاجة ماء", 7]],
51
+ inputs=[txt, num, dadtaset_select],
52
  outputs=[gallery,lables, json_output, json_text],
53
  fn=utils.predict,
54
  cache_examples=False,
 
57
  # mclip demo
58
  with gr.Blocks() as demo_mclip:
59
 
60
+ gr.Markdown("## Choose the dataset")
61
+
62
+ dadtaset_select = gr.Radio(["XTD dataset", "Flicker 8k dataset"], value="XTD dataset", label="Dataset", info="Which dataset you would like to search in?")
63
+
64
+
65
  gr.Markdown("## Input parameters")
66
 
67
  txt = gr.Textbox(label="Text Query (Caption)")
68
+ num = gr.Slider(label="Number of retrieved image", value=1, minimum=1)
69
 
70
  with gr.Row():
71
  btn = gr.Button("Retrieve images", scale=1)
 
89
  gr.Markdown("## Text Retrieved")
90
  json_text = gr.JSON()
91
 
92
+ btn.click(utils.predict_mclip, inputs=[txt, num, dadtaset_select], outputs=[gallery,lables, json_output, json_text])
 
 
 
93
 
94
  gr.Examples(
95
  examples=[["تخطي لاعب فريق بيتسبرج بايرتس منطقة اللوحة الرئيسية في مباراة بدوري البيسبول", 5],
96
  ["وقوف قطة بمخالبها على فأرة حاسوب على المكتب", 10],
97
  ["صحن به شوربة صينية بالخضار، وإلى جانبه بطاطس مقلية وزجاجة ماء", 7]],
98
+ inputs=[txt, num, dadtaset_select],
99
  outputs=[gallery,lables, json_output, json_text],
100
  fn=utils.predict_mclip,
101
  cache_examples=False,
 
106
  with gr.Blocks() as demo:
107
 
108
  gr.Markdown("<font color=red size=10><center>AraClip: Arabic Image Retrieval Application</center></font>")
 
109
 
110
+ gr.TabbedInterface([demo_araclip, demo_mclip], ["Our Model", "Mclip model"])
111
 
112
 
113
  if __name__ == "__main__":
utils.py CHANGED
@@ -43,10 +43,9 @@ def features_pickle(file_path=None):
43
  return features_pickle
44
 
45
 
46
- def dataset_loading():
47
-
48
- with open("photos/en_ar_XTD10_edited_v2.jsonl") as filino:
49
 
 
50
 
51
  data = [json.loads(file_i) for file_i in filino]
52
 
@@ -89,7 +88,7 @@ def compare_embeddings_text(full_text_embds, txt_embs):
89
 
90
 
91
 
92
- def find_image(language_model,clip_model, text_query, dataset, image_features, text_features_new,sorted_data, num=1):
93
 
94
  embedding, _ = text_encoder(language_model, text_query)
95
 
@@ -111,7 +110,7 @@ def find_image(language_model,clip_model, text_query, dataset, image_features, t
111
 
112
  for i in range(1, num+1):
113
  idx = np.argsort(probs, axis=0)[-i, 0]
114
- path = 'photos/XTD10_dataset/' + dataset.get_image_name(idx)
115
 
116
  path_l = (path,f"{sorted_data[idx]['caption_ar']}")
117
 
@@ -142,27 +141,32 @@ class AraClip():
142
  self.text_model = load_model('bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M', in_features= 768, out_features=768)
143
  self.language_model = lambda queries: np.asarray(self.text_model(queries).detach().to('cpu'))
144
  self.clip_model, self.compose = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-512')
145
- self.sorted_data, self.image_name_list = dataset_loading()
146
 
147
- def load_images(self):
148
- # Return the features of the text and images
149
- image_features_new = features_pickle('cashed_pickles/image_features_XTD_1000_images_arabert_siglib_best_model.pickle')
150
- return image_features_new
151
-
152
- def load_text(self):
153
- text_features_new = features_pickle('cashed_pickles/text_features_XTD_1000_images_arabert_siglib_best_model.pickle')
154
- return text_features_new
155
-
156
- def load_dataset(self):
157
- dataset = CustomDataSet("photos/XTD10_dataset", self.compose, self.image_name_list)
158
  return dataset
159
 
 
 
 
160
 
161
  araclip = AraClip()
162
 
163
- def predict(text, num):
 
 
 
164
 
165
- image_paths, labels, json_data, json_text = find_image(araclip.language_model,araclip.clip_model, text, araclip.load_dataset(), araclip.load_images() , araclip.load_text(), araclip.sorted_data, num=int(num))
 
166
 
167
  return image_paths, labels, json_data, json_text
168
 
@@ -175,26 +179,34 @@ class Mclip():
175
  self.text_model_mclip = pt_multilingual_clip.MultilingualCLIP.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
176
  self.language_model_mclip = lambda queries: np.asarray(self.text_model_mclip.forward(queries, self.tokenizer_mclip).detach().to('cpu'))
177
  self.clip_model_mclip, _, self.compose_mclip = create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
178
- self.sorted_data, self.image_name_list = dataset_loading()
 
179
 
180
- def load_images(self):
181
- # Return the features of the text and images
182
- image_features_mclip = features_pickle('cashed_pickles/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
183
- return image_features_mclip
184
-
185
- def load_text(self):
186
- text_features_new_mclip = features_pickle('cashed_pickles/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
187
- return text_features_new_mclip
188
-
189
- def load_dataset(self):
190
- dataset_mclip = CustomDataSet("photos/XTD10_dataset", self.compose_mclip, self.image_name_list)
191
- return dataset_mclip
192
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  mclip = Mclip()
195
 
196
- def predict_mclip(text, num):
 
197
 
198
- image_paths, labels, json_data, json_text = find_image(mclip.language_model_mclip,mclip.clip_model_mclip, text, mclip.load_dataset() , mclip.load_text() , mclip.load_text() , mclip.sorted_data , num=int(num))
 
199
 
 
 
 
 
200
  return image_paths, labels, json_data, json_text
 
43
  return features_pickle
44
 
45
 
46
+ def dataset_loading(file_name):
 
 
47
 
48
+ with open(file_name) as filino:
49
 
50
  data = [json.loads(file_i) for file_i in filino]
51
 
 
88
 
89
 
90
 
91
+ def find_image(language_model,clip_model, text_query, dataset, image_features, text_features_new,sorted_data, images_path,num=1):
92
 
93
  embedding, _ = text_encoder(language_model, text_query)
94
 
 
110
 
111
  for i in range(1, num+1):
112
  idx = np.argsort(probs, axis=0)[-i, 0]
113
+ path = images_path + dataset.get_image_name(idx)
114
 
115
  path_l = (path,f"{sorted_data[idx]['caption_ar']}")
116
 
 
141
  self.text_model = load_model('bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M', in_features= 768, out_features=768)
142
  self.language_model = lambda queries: np.asarray(self.text_model(queries).detach().to('cpu'))
143
  self.clip_model, self.compose = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-512')
 
144
 
145
+ self.sorted_data_xtd, self.image_name_list_xtd = dataset_loading("photos/en_ar_XTD10_edited_v2.jsonl")
146
+ self.sorted_data_flicker8k, self.image_name_list_flicker8k = dataset_loading("photos/flicker_8k.jsonl")
147
+
148
+ def load_pickle_file(self, file_name):
149
+
150
+ return features_pickle(file_name)
151
+
152
+ def load_xtd_dataset(self):
153
+ dataset = CustomDataSet("photos/XTD10_dataset", self.compose, self.image_name_list_xtd)
154
+
 
155
  return dataset
156
 
157
+ def load_flicker8k_dataset(self):
158
+ dataset = CustomDataSet("photos/Flicker8k_Dataset", self.compose, self.image_name_list_flicker8k)
159
+ return dataset
160
 
161
  araclip = AraClip()
162
 
163
+ def predict(text, num, dadtaset_select):
164
+
165
+ if dadtaset_select == "XTD dataset":
166
+ image_paths, labels, json_data, json_text = find_image(araclip.language_model,araclip.clip_model, text, araclip.load_xtd_dataset(), araclip.load_pickle_file("cashed_pickles/XTD_pickles/image_features_XTD_1000_images_arabert_siglib_best_model.pickle") , araclip.load_pickle_file("cashed_pickles/XTD_pickles/image_features_XTD_1000_images_arabert_siglib_best_model.pickle"), araclip.sorted_data_xtd, 'photos/XTD10_dataset/', num=int(num))
167
 
168
+ else:
169
+ image_paths, labels, json_data, json_text = find_image(araclip.language_model,araclip.clip_model, text, araclip.load_flicker8k_dataset(), araclip.load_pickle_file("cashed_pickles/flicker_8k/image_features_flicker_8k_images_arabert_siglib_best_model.pickle") , araclip.load_pickle_file("cashed_pickles/flicker_8k/text_features_flicker_8k_images_arabert_siglib_best_model.pickle"), araclip.sorted_data_flicker8k, "photos/Flicker8k_Dataset/", num=int(num))
170
 
171
  return image_paths, labels, json_data, json_text
172
 
 
179
  self.text_model_mclip = pt_multilingual_clip.MultilingualCLIP.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
180
  self.language_model_mclip = lambda queries: np.asarray(self.text_model_mclip.forward(queries, self.tokenizer_mclip).detach().to('cpu'))
181
  self.clip_model_mclip, _, self.compose_mclip = create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
182
+ self.sorted_data_xtd, self.image_name_list_xtd = dataset_loading("photos/en_ar_XTD10_edited_v2.jsonl")
183
+ self.sorted_data_flicker8k, self.image_name_list_flicker8k = dataset_loading("photos/flicker_8k.jsonl")
184
 
185
+ def load_pickle_file(self, file_name):
 
 
 
 
 
 
 
 
 
 
 
186
 
187
+ return features_pickle(file_name)
188
+
189
+
190
+ def load_xtd_dataset(self):
191
+ dataset = CustomDataSet("photos/XTD10_dataset", self.compose_mclip, self.image_name_list_xtd)
192
+
193
+ return dataset
194
+
195
+ def load_flicker8k_dataset(self):
196
+ dataset = CustomDataSet("photos/Flicker8k_Dataset", self.compose_mclip, self.image_name_list_flicker8k)
197
+ return dataset
198
+
199
 
200
  mclip = Mclip()
201
 
202
+ def predict_mclip(text, num, dadtaset_select):
203
+
204
 
205
+ if dadtaset_select == "XTD dataset":
206
+ image_paths, labels, json_data, json_text = find_image(mclip.language_model_mclip,mclip.clip_model_mclip, text, mclip.load_xtd_dataset() , mclip.load_pickle_file("cashed_pickles/XTD_pickles/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.load_pickle_file("cashed_pickles/XTD_pickles/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.sorted_data_xtd , 'photos/XTD10_dataset/', num=int(num))
207
 
208
+ else:
209
+ image_paths, labels, json_data, json_text = find_image(mclip.language_model_mclip,mclip.clip_model_mclip, text, mclip.load_flicker8k_dataset() , mclip.load_pickle_file("cashed_pickles/flicker_8k/image_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.load_pickle_file("cashed_pickles/flicker_8k/text_features_flicker_8k_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle") , mclip.sorted_data_flicker8k , 'photos/Flicker8k_Dataset/', num=int(num))
210
+
211
+
212
  return image_paths, labels, json_data, json_text