# Catherine ZHOU
# make buttons into two columns
# 89c8ebd
# raw
# history blame
# 10.9 kB
import gradio as gr
from gradio.flagging import FlaggingCallback, SimpleCSVLogger
from gradio.components import IOComponent
from gradio_client import utils as client_utils
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from sentence_transformers import util
import pickle
from PIL import Image
import os
import logging
import csv
import datetime
import time
from pathlib import Path
from typing import List, Any
class SaveRelevanceCallback(FlaggingCallback):
    """Callback to save the image relevance state to a csv file.

    Every call to :meth:`flag` appends one row to
    ``<flagging_dir>/relevance_log.csv`` (writing a header row first when the
    file does not exist yet) and returns the number of data rows in the file.
    """

    def __init__(self):
        # Nothing to initialise: ``components`` and ``flagging_dir`` are
        # assigned in ``setup()``, which must run before ``flag()``.
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str | Path):
        """Remember the flagged components and create the flagging directory.

        Args:
            components (List[IOComponent]): components whose values will be
                flagged, in the same order as the ``flag_data`` later passed
                to :meth:`flag`.
            flagging_dir (str | Path): directory that will hold the CSV log and
                any files written by component deserialization; created if it
                does not exist.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)
        logging.info(f"[SaveRelevance]: Flagging directory set to {flagging_dir}")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """Append one row of flagged data to ``relevance_log.csv``.

        Args:
            flag_data: raw values of the components given to :meth:`setup`,
                in the same order.
            flag_option (optional): unused; accepted for FlaggingCallback
                compatibility.
            flag_index (optional): unused; accepted for compatibility.
            username (optional): unused; accepted for compatibility (any
                username column must come from ``flag_data``).

        Returns:
            int: number of data rows (header excluded) now in the log file.
        """
        logging.info("[SaveRelevance]: Flagging data...")
        flagging_dir = self.flagging_dir
        log_filepath = Path(flagging_dir) / "relevance_log.csv"
        is_new = not log_filepath.exists()
        headers = ["query", "selected image", "relevance", "username", "timestamp"]

        csv_data = []
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            # Files produced by deserialization go into a per-component
            # sub-directory named after the component's label.
            save_dir = Path(flagging_dir) / client_utils.strip_invalid_filename_characters(
                getattr(component, "label", None) or f"component {idx}"
            )
            if gr.utils.is_update(sample):
                csv_data.append(str(sample))
                continue
            new_data = (
                component.deserialize(sample, save_dir=save_dir)
                if sample is not None
                else ""
            )
            if new_data and idx == 1:
                # Index 1 is the "selected image" column of ``headers``: keep
                # only the file name of the saved copy. Path(...).name honours
                # the platform's separator (e.g. "\\" on Windows) as well as "/".
                # TODO: change this to a more robust way of getting the image
                # name/identifier. Original author's note: this doesn't work -
                # the directory contains all the images in gallery.
                new_data = Path(new_data).name
            csv_data.append(new_data)
        csv_data.append(str(datetime.datetime.now()))

        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            if is_new:
                writer.writerow(gr.utils.sanitize_list_for_csv(headers))
            writer.writerow(gr.utils.sanitize_list_for_csv(csv_data))

        # Count CSV records (not raw lines) so quoted multi-line values count
        # once. newline="" is what the csv module expects when reading.
        # Subtract 1 for the header row.
        with open(log_filepath, "r", newline="", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1
        logging.info(f"[SaveRelevance]: Saved a total of {line_count} samples to {log_filepath}")
        return line_count
# Configure logging first. logging.basicConfig() silently does nothing once the
# root logger already has handlers (which happens as soon as anything calls a
# module-level function such as logging.warning), so calling it before model
# loading keeps the INFO level and format below from being lost.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s', datefmt='%d-%b-%y %H:%M:%S')

## Define model
# Model, processor and tokenizer are all loaded from the same CLIP checkpoint.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Example rows for gr.Examples, in the order [query, top_k, username].
examples = [
    ["Dog in the beach", 2, 'ghost'],
    ["Paris during night.", 1, 'ghost'],
    ["A cute kangaroo", 5, 'ghost'],
    ["Dois cachorros", 2, 'ghost'],
    ["un homme marchant sur le parc", 3, 'ghost'],
    ["et høyt fjell", 2, 'ghost'],
]

# Open the precomputed embeddings: the pickle holds an
# (image file names, image embeddings) pair.
# NOTE(review): pickle.load can execute arbitrary code - only load this file
# from a trusted source.
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
with open(emb_filename, 'rb') as fIn:
    img_names, img_emb = pickle.load(fIn)
# helper functions
def search_text(query, top_k=1):
    """Search the precomputed image embeddings with a text query.

    Args:
        query (str): query you want to search for.
        top_k (int, optional): number of images to return. Defaults to 1.

    Returns:
        list: ``(image, caption)`` pairs for the gallery.
        list: the matching PIL images, in ranking order.
        list: the captions for those images (currently always empty strings).
        float: ``time.time()`` when the results were produced; used as the
            start time for marking the relevance of the images.
    """
    # Defensive cast: semantic_search passes top_k to torch's top-k, which
    # rejects floats, and a UI number widget may deliver e.g. 2.0.
    top_k = int(top_k)
    logging.info(f"[SearchText]: Searching for {query} with top_k={top_k}...")

    # First, we encode the query.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    query_emb = model.get_text_features(**inputs)

    # Then, we use the util.semantic_search function, which computes the cosine-similarity
    # between the query embedding and all image embeddings.
    # It then returns the top_k highest ranked images, which we output.
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]

    images = [
        Image.open(os.path.join("photos/", img_names[hit['corpus_id']]))
        for hit in hits
    ]
    captions = ["" for _ in images]
    image_caption = list(zip(images, captions))

    curr_time = time.time()
    logging.info(f"[SearchText]: Found {len(image_caption)} images at "
                 f"{time.ctime(curr_time)}.")
    return image_caption, images, captions, curr_time
def display(images, texts, event_data: gr.SelectData):
    """Return the gallery item the user clicked, together with its caption.

    Args:
        images (list): images currently shown in the gallery.
        texts (list): captions aligned position-by-position with ``images``.
        event_data (gr.SelectData): payload of the select event; its
            ``index`` picks the item.

    Returns:
        tuple: ``(image, caption)`` at the selected position.
    """
    position = event_data.index
    picked_image = images[position]
    picked_caption = texts[position]
    return picked_image, picked_caption
callback = SaveRelevanceCallback()
time_record = SimpleCSVLogger()

with gr.Blocks(title="Text to Image using CLIP Model 📸") as demo:
    # create display
    gr.Markdown(
        """
        # Text to Image using CLIP Model 📸
        My version of the Gradio Demo for the CLIP model with the option to select the relevance level of each image. \n
        This demo is based on assessment for the 🤗 Huggingface course 2.
        - To use it, simply write which image you are looking for. See the examples section below for more details.
        - After you submit your query, you will see a gallery of images that are related to your query.
        - You can select the relevance of each image by using the dropdown menu.
        - Click the save button to save the image and its relevance to [a csv file](./blob/main/image_relevance/relevance_log.csv).
        - After you are done with all the images, click the `I'm finished!` button. We will save the time you spent to mark all images.
        ---
        To-do:
        - Add a way to save multiple image-relevance pairs at once.
        - Improve image identification in the csv file. ✅
        - Record time spent to mark all images. ✅
        """
    )
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(lines=4,
                               label="Query",
                               placeholder="Text Here...")
            # Minimum (and default) of 1: with 0, submitting without touching
            # the slider returned an empty gallery.
            top_k = gr.Slider(1, 5, value=1, step=1, label="Top K")
            username = gr.Textbox(lines=1, label="Your Name",
                                  placeholder="Text username here...")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[3], height="auto")
            t = gr.Textbox(label="Image Caption")
            relevance = gr.Dropdown(
                ["0: Not relevant",
                 "1: Related but not relevant",
                 "2: Somehow relevant",
                 "3: Highly relevant"
                 ], multiselect=False,
                label="How relevant is this image?"
            )
            with gr.Row():
                save_btn = gr.Button(
                    "Save after you select the relevance of each image")
                save_all_btn = gr.Button("I'm finished!")
    # Hidden holder for the currently selected image; its value is what gets flagged.
    i = gr.Image(interactive=False, label="Selected Image", visible=False)

    gr.Markdown("## Here are some examples you can use:")
    gr.Examples(examples, [query, top_k, username])

    # states for passing images and texts to other blocks
    images = gr.State()
    texts = gr.State()
    start_time = gr.Number(visible=False)
    time_spent = gr.Number(visible=False)

    # when user input query and top_k
    submit_btn.click(search_text, [query, top_k], [gallery, images, texts, start_time])
    gallery.select(display, [images, texts], [i, t])

    # when user click save button
    # we will flag the current query, selected image, relevance, and username
    callback.setup([query, i, relevance, username], "image_relevance")
    time_record.setup([query, username, start_time, time_spent], "time")
    save_btn.click(lambda *args: callback.flag(list(args)),
                   [query, i, relevance, username], preprocess=False)

    def log_time(query_text, user_name, started_at):
        """Log how long the user spent marking the images of one query.

        Writes ``[query, username, start time, seconds spent]`` via
        ``time_record``, matching its ``[query, username, start_time,
        time_spent]`` setup.

        Args:
            query_text: the current query text.
            user_name: the name entered by the user.
            started_at: ``time.time()`` stored by ``search_text``; empty until
                a search has been submitted.
        """
        if not started_at:
            # No search yet: without this guard, ``now - started_at`` would
            # raise a TypeError on None.
            logging.warning("[SaveAll]: No search has been submitted yet; nothing to record.")
            return
        now = time.time()  # read the clock once so the logged values agree
        logging.info(f"[SaveAll]: Saving time for {query_text} by {user_name} from {time.ctime(started_at)}.")
        time_record.flag([query_text, user_name,
                          str(datetime.datetime.fromtimestamp(started_at)),
                          round(now - started_at, 3)])

    save_all_btn.click(log_time, [query, username, start_time], preprocess=False)

    gr.Markdown(
        """
        You can find more information about this demo on my ✨ github repository [marcelcastrobr](https://github.com/marcelcastrobr/huggingface_course2)
        """
    )

if __name__ == "__main__":
    demo.launch(debug=True)