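"""Gradio app for browsing the gigant/tib-bench dataset: each talk is rendered
as an interleaved document of slide images and transcript text, together with
its Qwen2-VL token count."""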
import gradio as gr
from transformers import AutoProcessor  # only the processor is needed for token counting
from qwen_vl_utils import process_vision_info
from PIL import Image
import io
import base64
from datasets import load_dataset
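
# Qwen2-VL spends roughly one visual token per 28x28-pixel patch (after 2x2
# patch merging), so max_pixels = max_token_budget * 28 * 28 caps each image
# at about max_token_budget visual tokens; the processor resizes images to
# keep them within [min_pixels, max_pixels].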
max_token_budget = 512
min_pixels = 1 * 28 * 28
max_pixels = max_token_budget * 28 * 28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
ds = load_dataset("gigant/tib-bench")["train"]
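
# The code below assumes each example provides: "title", "abstract", "slides"
# (a list of PIL images), "keyframes" (with a "timestamp" list of [start, end]
# pairs in seconds), and "transcript_segments" (parallel "seek"/"text" lists).
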
def segments(example):
    """Create the interleaved text: one <image> placeholder per extracted
    keyframe, followed by the transcript segments spoken while that keyframe
    was on screen."""
    text = ""
    segment_i = 0
    for timestamp in example["keyframes"]["timestamp"]:
        text += "<image>"
        end = timestamp[1]
        # Append every transcript segment that starts before this keyframe's
        # interval ends (the 0.01 factor converts "seek" units to seconds).
        while segment_i < len(example["transcript_segments"]["seek"]) and end > example["transcript_segments"]["seek"][segment_i] * 0.01:
            text += example["transcript_segments"]["text"][segment_i]
            segment_i += 1
    # Append any transcript remaining after the last keyframe.
    if segment_i < len(example["transcript_segments"]["text"]):
        text += "".join(example["transcript_segments"]["text"][segment_i:])
    return text
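
# Illustrative shape of the result (not actual dataset content):
#   "<image> Welcome everyone ...<image> On this slide we see ..."
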
def create_interleaved_html(text, slides, scale=0.4, max_width=600):
    """
    Creates an HTML string with interleaved images and text segments.
    The images are converted to base64 and embedded directly in the HTML.
    """
    html = []
    segments = text.split("<image>")
    # The first split element is the (usually empty) text before the leading
    # <image>, so an image is prepended only for elements with j > 0.
    for j, segment in enumerate(segments):
        if j > 0:
            img = slides[j - 1]
            img_width = int(img.width * scale)
            img_height = int(img.height * scale)
            if img_width > max_width:
                ratio = max_width / img_width
                img_width = max_width
                img_height = int(img_height * ratio)
            # Convert the resized image to base64 for inline embedding
            buffer = io.BytesIO()
            img.resize((img_width, img_height)).save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            html.append(f'<img src="data:image/png;base64,{img_str}" style="max-width: {max_width}px; display: block; margin: 20px auto;">')
        # Add the text segment after the image
        html.append(f'<div style="white-space: pre-wrap;">{segment}</div>')
    return "".join(html)
def doc_to_messages(text, slides):
    """Build a Qwen2-VL chat message from the interleaved text and slides,
    then tokenize it with the processor."""
    content = []
    segments = text.split("<image>")
    for j, segment in enumerate(segments):
        if j > 0:
            content.append({"type": "image", "image": slides[j - 1]})
        content.append({"type": "text", "text": segment})
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs
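
# The returned BatchFeature's input_ids tensor has shape (1, sequence_length),
# so .input_ids.shape[1] (used in load_document below) is the total token
# count, text and visual tokens combined.
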
# Global variables to keep track of current document
current_doc_index = 0
annotations = []  # currently unused
titles = ds["title"]  # fetch the title column once rather than once per entry
choices = [f"{i} | {title}" for i, title in enumerate(titles)]  # "index | title", parsed back in get_selected_document
def load_document(index):
    """Load a specific document from the dataset"""
    if 0 <= index < len(ds):
        doc = ds[index]
        segments_doc = segments(doc)
        return (
            doc["title"],
            doc["abstract"],
            create_interleaved_html(segments_doc, doc["slides"], scale=0.7),
            doc_to_messages(segments_doc, doc["slides"]).input_ids.shape[1],
            choices[index],
        )
    return ("", "", "", "", "")
# The prev/next helpers only read current_doc_index; the index itself is
# updated in get_selected_document, fired by the dropdown's change event.
def get_next_document():
    """Return the dropdown choice for the next document (wrapping around)"""
    return choices[(current_doc_index + 1) % len(ds)]

def get_prev_document():
    """Return the dropdown choice for the previous document (wrapping around)"""
    return choices[(current_doc_index - 1) % len(ds)]

def get_selected_document(arg):
    """Parse the selected "index | title" choice and load that document"""
    global current_doc_index
    current_doc_index = int(arg.split(" | ")[0])
    return load_document(current_doc_index)
theme = gr.themes.Ocean()
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Slide Presentation Visualization Tool")
    pres_selection_dd = gr.Dropdown(label="Presentation", value=choices[0], choices=choices)
    with gr.Row():
        with gr.Column():
            body = gr.HTML(max_height=400)
        with gr.Column():
            title = gr.Textbox(label="Title", interactive=False, max_lines=1)
            abstract = gr.Textbox(label="Abstract", interactive=False, max_lines=8)
            token_count = gr.Textbox(label=f"Token Count (Qwen2-VL, at most {max_token_budget} tokens per image)", interactive=False, max_lines=1)

    # Pre-load the first document so the app starts with content displayed
    title_val, abstract_val, body_val, token_count_val, choices_val = load_document(current_doc_index)
    title.value = title_val
    abstract.value = abstract_val
    body.value = body_val
    token_count.value = str(token_count_val)
    pres_selection_dd.value = choices_val

    # Selecting a presentation (directly or via the buttons below) reloads the view
    pres_selection_dd.change(
        fn=get_selected_document,
        inputs=pres_selection_dd,
        outputs=[title, abstract, body, token_count, pres_selection_dd],
    )

    with gr.Row():
        prev_button = gr.Button("Previous Document")
        prev_button.click(fn=get_prev_document, inputs=[], outputs=[pres_selection_dd])
        next_button = gr.Button("Next Document")
        next_button.click(fn=get_next_document, inputs=[], outputs=[pres_selection_dd])

demo.launch()