Catherine ZHOU committed
Commit d7e5ae1 · 1 Parent(s): 7db0ed5

add flagging feature

Files changed (1):
  app.py  +134 −52
app.py CHANGED
@@ -1,30 +1,114 @@
 import gradio as gr
+from gradio.flagging import FlaggingCallback
 from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
-import sentence_transformers
-from sentence_transformers import SentenceTransformer, util
+from sentence_transformers import util
 import pickle
 from PIL import Image
 import os
+import logging
+import csv
+import datetime
+from gradio_client import utils as client_utils
+from pathlib import Path
+from typing import List, Any
+from gradio.components import IOComponent
+
+class SaveRelevanceCallback(FlaggingCallback):
+    """ Callback to save the image relevance state to a csv file
+    """
+
+    def __init__(self):
+        pass
+
+    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
+        """
+        This method gets called once at the beginning of the Interface.launch() method.

+        Args:
+            components ([IOComponent]): Set of components that will provide flagged data.
+            flagging_dir (string): typically containing the path to the directory where the flagging file should be storied
+            (provided as an argument to Interface.__init__()).
+        """
+        self.components = components
+        self.flagging_dir = flagging_dir
+        os.makedirs(flagging_dir, exist_ok=True)
+        logging.info(f"[SaveRelevance]: Flagging directory set to {flagging_dir}")
+
+    def flag(self,
+             flag_data: List[Any],
+             flag_option: str | None = None,
+             flag_index: int | None = None,
+             username: str | None = None,
+             ) -> int:
+        """
+        This gets called every time the <flag> button is pressed.
+
+        Args:
+            interface: The Interface object that is being used to launch the flagging interface.
+            flag_data: The data to be flagged.
+            flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
+            flag_index (optional): The index of the sample that is being flagged.
+            username (optional): The username of the user that is flagging the data, if logged in.
+
+        Returns:
+            (int): The total number of samples that have been flagged.
+        """
+        logging.info("[SaveRelevance]: Flagging data...")
+        flagging_dir = self.flagging_dir
+        log_filepath = Path(flagging_dir) / "log.csv"
+        is_new = not Path(log_filepath).exists()
+        headers = ["query", "image directory", "relevance", "username", "timestamp"]
+
+        csv_data = []
+        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
+            save_dir = Path(
+                flagging_dir
+            ) / client_utils.strip_invalid_filename_characters(
+                getattr(component, "label", None) or f"component {idx}"
+            )
+            if gr.utils.is_update(sample):
+                csv_data.append(str(sample))
+            else:
+                new_data = component.deserialize(sample, save_dir=save_dir) if sample is not None else ""
+                if new_data and idx == 1:
+                    # TO-DO: change this to a more robust way of getting the image name/identifier
+                    # This doesn't work - the directory contains all the images in gallery
+                    new_data = new_data.split('/')[-1]
+                csv_data.append(new_data)
+        csv_data.append(str(datetime.datetime.now()))
+
+        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
+            writer = csv.writer(csvfile)
+            if is_new:
+                writer.writerow(gr.utils.sanitize_list_for_csv(headers))
+            writer.writerow(gr.utils.sanitize_list_for_csv(csv_data))
+
+        with open(log_filepath, "r", encoding="utf-8") as csvfile:
+            line_count = len([None for _ in csv.reader(csvfile)]) - 1
+
+        logging.info(f"[SaveRelevance]: Saved a total of {line_count} samples to {log_filepath}")
+        return line_count

 ## Define model
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

-examples = [[("Dog in the beach"), 2],
-            [("Paris during night."), 1],
-            [("A cute kangaroo"), 5],
-            [("Dois cachorros"), 2],
-            [("un homme marchant sur le parc"), 3],
-            [("et høyt fjell"), 2]]
+examples = [[("Dog in the beach"), 2, 'ghost'],
+            [("Paris during night."), 1, 'ghost'],
+            [("A cute kangaroo"), 5, 'ghost'],
+            [("Dois cachorros"), 2, 'ghost'],
+            [("un homme marchant sur le parc"), 3, 'ghost'],
+            [("et høyt fjell"), 2, 'ghost']]
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

 #Open the precomputed embeddings
 emb_filename = 'unsplash-25k-photos-embeddings.pkl'
 with open(emb_filename, 'rb') as fIn:
     img_names, img_emb = pickle.load(fIn)
 #print(f'img_emb: {print(img_emb)}')
-#print(f'img_names: {print(img_names)}')
+#print(f'img_names: {print(img_names)}')

 # helper functions
 def search_text(query, top_k=1):
@@ -38,6 +122,8 @@ def search_text(query, top_k=1):
         [list]: list of images that are related to the query.
         [list]: list of image embs that are related to the query.
     """
+    logging.info(f"[SearchText]: Searching for {query} with top_k={top_k}...")
+
     # First, we encode the query.
     inputs = tokenizer([query], padding=True, return_tensors="pt")
     query_emb = model.get_text_features(**inputs)
@@ -53,49 +139,44 @@ def search_text(query, top_k=1):
         object = Image.open(os.path.join(
             "photos/", img_names[hit['corpus_id']]))
         image.append(object)
+        # selected_image_embs.append(img_emb[hit['corpus_id']])
     #print(f'array length is: {len(image)}')
+    logging.info(f"[SearchText]: Found {len(image)} images.")
     return image

-def select_image(evt: gr.SelectData):
-    """ Returns the index of the selected image
+# def select_image(evt: gr.SelectData):
+#     """ Returns the index of the selected image

-    Argrs:
-        evt (SelectData): the event we are listening to
+#     Argrs:
+#         evt (SelectData): the event we are listening to

-    Returns:
-        int: index of the selected image
-    """
-    return evt.index
-
-def select_image_relevance(evt: gr.SelectData, gallery, selected_index, image_relevance_state):
-    """ Returns the relevance of the selected image
-
-    Args:
-        evt (SelectData): the event we are listening to
-        gallery (Gallery): the gallery of images
-        selected_index (Number): the index of the selected image
-        image_relevance_state (State): the current state of the image relevance
-
-    Returns:
-        state: the new state of the image relevance
-    """
-    image_relevance_state[gallery.value[selected_index.value]] = evt.value
-    return image_relevance_state
+#     Returns:
+#         int: index of the selected image
+#     """
+#     logging.info(f"[SelectImage]: Selected image {evt.index}.")
+#     return evt.index


-callback = gr.CSVLogger()
+callback = SaveRelevanceCallback()
 with gr.Blocks() as demo:
     # create display
     gr.Markdown(
         """
         # Text to Image using CLIP Model 📸

-        ---
+        My version of the Gradio Demo fo CLIP model with the option to select relevance level of each image. \n
+        This demo is based on assessment for the 🤗 Huggingface course 2.
+
+
+        - To use it, simply write which image you are looking for. See the examples section below for more details.
+        - After you submit your query, you will see a gallery of images that are related to your query.
+        - You can select the relevance of each image by using the dropdown menu.

-        My version of the Gradio Demo fo CLIP model with the option to select relevance level of each image.
+        ---

-        This demo is based on assessment for the 🤗 Huggingface course 2. \n
-        To use it, simply write which image you are looking for. Read more at the links below.
+        To-do:
+        - [ ] Add a way to save multiple image-relevance pairs at once.
+        - [ ] Improve image identification in the csv file.
         """
     )
     with gr.Row():
@@ -103,40 +184,41 @@ with gr.Blocks() as demo:
             query = gr.Textbox(lines=4,
                                label="Write what you are looking for in an image...",
                                placeholder="Text Here...")
-            top_k = gr.Slider(0, 5, step=1)
+            top_k = gr.Slider(0, 5, step=1, label="Top K relevant images to show")
+            username = gr.Textbox(lines=1, label="Input your unique username 👻 ", placeholder="Text username here...")
         with gr.Column():
             gallery = gr.Gallery(
                 label="Generated images", show_label=False, elem_id="gallery"
             ).style(grid=[3], height="auto")
-            relevance = gr.Dropdown(list(range(0, 6)), multiselect=False,
+            relevance = gr.Dropdown([str(i) for i in range(6)], multiselect=False,
                                     label="How relevent is this image to your input text?")
     with gr.Row():
         with gr.Column():
            submit_btn = gr.Button("Submit")
         with gr.Column():
-            save_btn = gr.Button("Save")
+            save_btn = gr.Button("Save after you select the relevance of each image")
     gr.Markdown("## Here are some examples you can use:")
-    gr.Examples(examples, [query, top_k])
+    gr.Examples(examples, [query, top_k, username])
+
+    callback.setup([query, gallery, relevance, username], "flagged")

     # when user input query and top_k
     submit_btn.click(search_text, [query, top_k], [gallery])

-    image_relevance_state = gr.State({}, label="image_relevance_state")
-    selected_index = gr.Number(value=0, visible=False)
-
-    callback.setup([image_relevance_state], "flagged")
+    # image_relevance_state = gr.State(value={})
+    # selected_index = gr.Number(value=0, visible=False, precision=0)

     # when user select an image in the gallery
-    gallery.select(select_image, None, selected_index)
+    # gallery.select(select_image, None, selected_index)
     # when user select the relevance of the image
-    relevance.select(fn=select_image_relevance,
-                     inputs=[gallery, selected_index, image_relevance_state],
-                     outputs=image_relevance_state)
+    # relevance.select(fn=select_image_relevance,
+    #                  inputs=[gallery, selected_index, image_relevance_state],
+    #                  outputs=image_relevance_state)

     # when user click save button
-    # we will flag the current image_relevance_state
-    save_btn.click(lambda *args: callback.flag(args), [image_relevance_state], None, preprocess=False)
-    gallery_embs = []
+    # we will flag the current query, selected image, relevance, and username
+    save_btn.click(lambda *args: callback.flag(args), [query, gallery, relevance, username], preprocess=False)
+    # gallery_embs = []

     gr.Markdown(
         """