Spaces · Runtime error

add flagging feature
Catherine ZHOU committed · Commit d7e5ae1 · Parent(s): 7db0ed5
app.py CHANGED
Old version of app.py (lines removed in this commit are prefixed with "-"):

@@ -1,30 +1,114 @@
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
-import
-from sentence_transformers import SentenceTransformer, util
import pickle
from PIL import Image
import os


## Define model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

-examples = [[("Dog in the beach"), 2],
-            [("Paris during night."), 1],
-            [("A cute kangaroo"), 5],
-            [("Dois cachorros"), 2],
-            [("un homme marchant sur le parc"), 3],
-            [("et høyt fjell"), 2]]

#Open the precomputed embeddings
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
#print(f'img_emb: {print(img_emb)}')
-#print(f'img_names: {print(img_names)}')

# helper functions
def search_text(query, top_k=1):
@@ -38,6 +122,8 @@ def search_text(query, top_k=1):
        [list]: list of images that are related to the query.
        [list]: list of image embs that are related to the query.
    """
    # First, we encode the query.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)
@@ -53,49 +139,44 @@ def search_text(query, top_k=1):
        object = Image.open(os.path.join(
            "photos/", img_names[hit['corpus_id']]))
        image.append(object)
    #print(f'array length is: {len(image)}')
    return image

-def select_image(evt: gr.SelectData):
-    ... (old lines 60-69: body of select_image, not rendered in this view)
-def select_image_relevance(evt: gr.SelectData, gallery, selected_index, image_relevance_state):
-    """ Returns the relevance of the selected image
-
-    Args:
-        evt (SelectData): the event we are listening to
-        gallery (Gallery): the gallery of images
-        selected_index (Number): the index of the selected image
-        image_relevance_state (State): the current state of the image relevance
-
-    Returns:
-        state: the new state of the image relevance
-    """
-    image_relevance_state[gallery.value[selected_index.value]] = evt.value
-    return image_relevance_state


-callback =
with gr.Blocks() as demo:
    # create display
    gr.Markdown(
        """
        # Text to Image using CLIP Model 📸

-        ... (old lines 93-98: previous intro text, not rendered in this view)
        """
    )
    with gr.Row():
@@ -103,40 +184,41 @@ with gr.Blocks() as demo:
            query = gr.Textbox(lines=4,
                               label="Write what you are looking for in an image...",
                               placeholder="Text Here...")
-            top_k = gr.Slider(0, 5, step=1)
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[3], height="auto")
-            relevance = gr.Dropdown(
                label="How relevent is this image to your input text?")
    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Submit")
        with gr.Column():
-            save_btn = gr.Button("Save")
    gr.Markdown("## Here are some examples you can use:")
-    gr.Examples(examples, [query, top_k])

    # when user input query and top_k
    submit_btn.click(search_text, [query, top_k], [gallery])

-    image_relevance_state = gr.State({}
-    selected_index = gr.Number(value=0, visible=False)
-
-    callback.setup([image_relevance_state], "flagged")

    # when user select an image in the gallery
-    gallery.select(select_image, None, selected_index)
    # when user select the relevance of the image
-    relevance.select(fn=select_image_relevance,
-        ... (old lines 133-134: remaining arguments of relevance.select, not rendered in this view)

    # when user click save button
-    # we will flag the current
-    save_btn.click(lambda *args: callback.flag(args), [
-    gallery_embs = []

    gr.Markdown(
        """
New version of app.py (lines added in this commit are prefixed with "+"):

@@ -1,30 +1,114 @@
import gradio as gr
+from gradio.flagging import FlaggingCallback
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
+from sentence_transformers import util
import pickle
from PIL import Image
import os
+import logging
+import csv
+import datetime
+from gradio_client import utils as client_utils
+from pathlib import Path
+from typing import List, Any
+from gradio.components import IOComponent
+
+class SaveRelevanceCallback(FlaggingCallback):
+    """ Callback to save the image relevance state to a csv file
+    """
+
+    def __init__(self):
+        pass
+
+    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
+        """
+        This method gets called once at the beginning of the Interface.launch() method.

+        Args:
+            components ([IOComponent]): Set of components that will provide flagged data.
+            flagging_dir (string): typically containing the path to the directory where the flagging file should be stored
+                (provided as an argument to Interface.__init__()).
+        """
+        self.components = components
+        self.flagging_dir = flagging_dir
+        os.makedirs(flagging_dir, exist_ok=True)
+        logging.info(f"[SaveRelevance]: Flagging directory set to {flagging_dir}")
+
+    def flag(self,
+             flag_data: List[Any],
+             flag_option: str | None = None,
+             flag_index: int | None = None,
+             username: str | None = None,
+             ) -> int:
+        """
+        This gets called every time the <flag> button is pressed.
+
+        Args:
+            flag_data: The data to be flagged.
+            flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
+            flag_index (optional): The index of the sample that is being flagged.
+            username (optional): The username of the user that is flagging the data, if logged in.
+
+        Returns:
+            (int): The total number of samples that have been flagged.
+        """
+        logging.info("[SaveRelevance]: Flagging data...")
+        flagging_dir = self.flagging_dir
+        log_filepath = Path(flagging_dir) / "log.csv"
+        is_new = not Path(log_filepath).exists()
+        headers = ["query", "image directory", "relevance", "username", "timestamp"]
+
+        csv_data = []
+        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
+            save_dir = Path(
+                flagging_dir
+            ) / client_utils.strip_invalid_filename_characters(
+                getattr(component, "label", None) or f"component {idx}"
+            )
+            if gr.utils.is_update(sample):
+                csv_data.append(str(sample))
+            else:
+                new_data = component.deserialize(sample, save_dir=save_dir) if sample is not None else ""
+                if new_data and idx == 1:
+                    # TO-DO: change this to a more robust way of getting the image name/identifier
+                    # This doesn't work - the directory contains all the images in gallery
+                    new_data = new_data.split('/')[-1]
+                csv_data.append(new_data)
+        csv_data.append(str(datetime.datetime.now()))
+
+        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
+            writer = csv.writer(csvfile)
+            if is_new:
+                writer.writerow(gr.utils.sanitize_list_for_csv(headers))
+            writer.writerow(gr.utils.sanitize_list_for_csv(csv_data))
+
+        with open(log_filepath, "r", encoding="utf-8") as csvfile:
+            line_count = len([None for _ in csv.reader(csvfile)]) - 1
+
+        logging.info(f"[SaveRelevance]: Saved a total of {line_count} samples to {log_filepath}")
+        return line_count

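For context, this is the conventional way a custom FlaggingCallback is attached to a Gradio app; the sketch below is not part of app.py and assumes Gradio 3.x's Interface flagging API (the placeholder echo function and component labels are illustrative). This Space instead wires the callback by hand inside gr.Blocks, via callback.setup(...) and save_btn.click(...) further down in the listing.

    # Illustrative only: attach the callback defined above to a plain gr.Interface.
    import gradio as gr

    cb = SaveRelevanceCallback()           # the class from the listing above
    demo_flag = gr.Interface(
        fn=lambda q: q,                    # placeholder function for the sketch
        inputs=gr.Textbox(label="query"),
        outputs=gr.Textbox(label="echo"),
        allow_flagging="manual",           # shows Gradio's built-in "Flag" button
        flagging_callback=cb,              # Gradio calls cb.setup(...) and cb.flag(...)
        flagging_dir="flagged",
    )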
## Define model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

+examples = [[("Dog in the beach"), 2, 'ghost'],
+            [("Paris during night."), 1, 'ghost'],
+            [("A cute kangaroo"), 5, 'ghost'],
+            [("Dois cachorros"), 2, 'ghost'],
+            [("un homme marchant sur le parc"), 3, 'ghost'],
+            [("et høyt fjell"), 2, 'ghost']]
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

#Open the precomputed embeddings
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
#print(f'img_emb: {print(img_emb)}')
+#print(f'img_names: {print(img_names)}')
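The diff does not show how unsplash-25k-photos-embeddings.pkl was produced. A common recipe, and only an assumption here, is to encode the photos once with a CLIP image encoder from sentence-transformers and pickle the (filenames, embeddings) pair that the code above unpacks:

    # Hypothetical precompute step (not in this repo): build the embeddings pickle.
    import os, pickle
    from PIL import Image
    from sentence_transformers import SentenceTransformer

    clip_encoder = SentenceTransformer('clip-ViT-B-32')   # CLIP ViT-B/32 image+text encoder
    names = sorted(os.listdir('photos/'))
    embs = clip_encoder.encode(
        [Image.open(os.path.join('photos/', n)) for n in names],
        batch_size=32, convert_to_tensor=True, show_progress_bar=True,
    )
    with open('unsplash-25k-photos-embeddings.pkl', 'wb') as fOut:
        pickle.dump((names, embs), fOut)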

# helper functions
def search_text(query, top_k=1):
@@ -38,6 +122,8 @@ def search_text(query, top_k=1):
        [list]: list of images that are related to the query.
        [list]: list of image embs that are related to the query.
    """
+    logging.info(f"[SearchText]: Searching for {query} with top_k={top_k}...")
+
    # First, we encode the query.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)
@@ -53,49 +139,44 @@ def search_text(query, top_k=1):
        object = Image.open(os.path.join(
            "photos/", img_names[hit['corpus_id']]))
        image.append(object)
+        # selected_image_embs.append(img_emb[hit['corpus_id']])
    #print(f'array length is: {len(image)}')
+    logging.info(f"[SearchText]: Found {len(image)} images.")
    return image
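The unchanged middle of search_text (new lines 130-138) is not shown in this diff. Given the sentence_transformers.util import and the hit['corpus_id'] lookups above, the hidden retrieval step is presumably a cosine-similarity search of the query embedding against the precomputed image embeddings; the helper below is a sketch of that step, not the hidden lines verbatim:

    # Sketch (assumption): the retrieval step that search_text likely performs.
    import torch
    from sentence_transformers import util

    def top_hits(query_emb: torch.Tensor, img_emb: torch.Tensor, top_k: int):
        # util.semantic_search returns one list of hits per query; each hit is a dict
        # with 'corpus_id' (an index into img_names / img_emb) and a similarity 'score'.
        return util.semantic_search(query_emb, img_emb, top_k=top_k)[0]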

+# def select_image(evt: gr.SelectData):
+#     """ Returns the index of the selected image

+#     Args:
+#         evt (SelectData): the event we are listening to

+#     Returns:
+#         int: index of the selected image
+#     """
+#     logging.info(f"[SelectImage]: Selected image {evt.index}.")
+#     return evt.index


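select_image is kept only as a comment in this version. When such a handler is wired to a gallery's select event, Gradio passes a gr.SelectData event whose index field identifies the clicked thumbnail; the snippet below is a minimal, self-contained sketch of that pattern (component names are illustrative, not taken from app.py):

    # Illustrative gallery-selection wiring using gr.SelectData (not part of app.py).
    import gradio as gr

    def on_select(evt: gr.SelectData):
        # evt.index is the position of the clicked image within the gallery
        return evt.index

    with gr.Blocks() as demo_select:
        results = gr.Gallery(label="results")
        picked = gr.Number(value=0, visible=False, precision=0)
        results.select(on_select, None, picked)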
+callback = SaveRelevanceCallback()
with gr.Blocks() as demo:
    # create display
    gr.Markdown(
        """
        # Text to Image using CLIP Model 📸

+        My version of the Gradio demo for the CLIP model, with the option to select a relevance level for each image. \n
+        This demo is based on an assessment for the 🤗 Huggingface course 2.
+
+
+        - To use it, simply write which image you are looking for. See the examples section below for more details.
+        - After you submit your query, you will see a gallery of images that are related to your query.
+        - You can select the relevance of each image by using the dropdown menu.

+        ---

+        To-do:
+        - [ ] Add a way to save multiple image-relevance pairs at once.
+        - [ ] Improve image identification in the csv file.
        """
    )
    with gr.Row():
@@ -103,40 +184,41 @@ with gr.Blocks() as demo:
            query = gr.Textbox(lines=4,
                               label="Write what you are looking for in an image...",
                               placeholder="Text Here...")
+            top_k = gr.Slider(0, 5, step=1, label="Top K relevant images to show")
+            username = gr.Textbox(lines=1, label="Input your unique username 👻 ", placeholder="Text username here...")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[3], height="auto")
+            relevance = gr.Dropdown([str(i) for i in range(6)], multiselect=False,
                label="How relevent is this image to your input text?")
    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Submit")
        with gr.Column():
+            save_btn = gr.Button("Save after you select the relevance of each image")
    gr.Markdown("## Here are some examples you can use:")
+    gr.Examples(examples, [query, top_k, username])
+
+    callback.setup([query, gallery, relevance, username], "flagged")

    # when user input query and top_k
    submit_btn.click(search_text, [query, top_k], [gallery])

+    # image_relevance_state = gr.State(value={})
+    # selected_index = gr.Number(value=0, visible=False, precision=0)

    # when user select an image in the gallery
+    # gallery.select(select_image, None, selected_index)
    # when user select the relevance of the image
+    # relevance.select(fn=select_image_relevance,
+    #                  inputs=[gallery, selected_index, image_relevance_state],
+    #                  outputs=image_relevance_state)

    # when user click save button
+    # we will flag the current query, selected image, relevance, and username
+    save_btn.click(lambda *args: callback.flag(args), [query, gallery, relevance, username], preprocess=False)
+    # gallery_embs = []

    gr.Markdown(
        """
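For reference, the relevance labels that SaveRelevanceCallback writes can be read back from flagged/log.csv; the column names in the sketch below follow the headers defined in flag() above:

    # Read back the flagged samples written by SaveRelevanceCallback.
    import csv

    with open("flagged/log.csv", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            print(row["query"], row["image directory"], row["relevance"], row["username"], row["timestamp"])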