import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import io
import base64
from datasets import load_dataset
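
# Note on the image token budget (descriptive comment, based on how the
# Qwen2-VL processor is documented to behave): min_pixels/max_pixels bound the
# area each image is resized to, and one visual token covers roughly a
# 28x28-pixel patch after merging, so max_pixels = max_token_budget * 28 * 28
# caps each slide at about max_token_budget visual tokens.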
max_token_budget = 512
min_pixels = 1 * 28 * 28
max_pixels = max_token_budget * 28 * 28
processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
ds = load_dataset("gigant/tib-bench")["train"]
def segments(example):
    # Build a single text for the talk: one <image> token per extracted keyframe,
    # each followed by the transcript segments spoken while that slide was shown.
    text = ""
    segment_i = 0
    for i, timestamp in enumerate(example["keyframes"]["timestamp"]):
        text += "<image>"  # alternative: f"<image {i}>"
        start, end = timestamp[0], timestamp[1]
        while segment_i < len(example["transcript_segments"]["seek"]) and end > example["transcript_segments"]["seek"][segment_i] * 0.01:
            text += example["transcript_segments"]["text"][segment_i]
            segment_i += 1
    # Append any transcript left over after the last keyframe
    if segment_i < len(example["transcript_segments"]["text"]):
        text += "".join(example["transcript_segments"]["text"][segment_i:])
    return text
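
# Illustrative example of the interleaved text this produces (made-up values,
# not taken from the dataset):
#   "<image> Welcome to this talk on ... <image> On this slide we see ..."
# i.e. one <image> marker per keyframe, each followed by the transcript spoken
# while that slide was on screen. The `seek` values appear to be in 10 ms
# frames, hence the * 0.01 conversion to seconds (assumption).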
def create_interleaved_html(text, slides, scale=0.4, max_width=600):
    """
    Creates an HTML string with interleaved images and text segments.
    The images are converted to base64 and embedded directly in the HTML.
    """
    html = []
    segments = text.split("<image>")
    for j, segment in enumerate(segments):
        # The first segment (j == 0) is the empty string before the leading
        # <image>, so it gets no image; every later segment is preceded by
        # the corresponding slide.
        if j > 0:
            img = slides[j - 1]
            img_width = int(img.width * scale)
            img_height = int(img.height * scale)
            if img_width > max_width:
                ratio = max_width / img_width
                img_width = max_width
                img_height = int(img_height * ratio)
            # Convert image to base64
            buffer = io.BytesIO()
            img.resize((img_width, img_height)).save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            html.append(f'<img src="data:image/png;base64,{img_str}" style="max-width: {max_width}px; display: block; margin: 20px auto;">')
        # Add the text segment after the image
        html.append(f'<div style="white-space: pre-wrap;">{segment}</div>')
    return "".join(html)
def doc_to_messages(text, slides):
    content = []
    segments = text.split("<image>")
    for j, segment in enumerate(segments):
        if j > 0:
            content.append({"type": "image", "image": slides[j - 1]})
        content.append({"type": "text", "text": segment})
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs
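
# Illustrative usage for token counting: `input_ids` includes the text tokens
# plus the expanded image placeholder tokens, so its length is the full prompt
# size the model would see (this is what load_document reports below):
#   inputs = doc_to_messages(segments(ds[0]), ds[0]["slides"])
#   n_tokens = inputs.input_ids.shape[1]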
# Global variables to keep track of current document
current_doc_index = 0
annotations = []
choices = [f"{i} | {ds['title'][i]}" for i in range(len(ds))]
def load_document(index):
    """Load a specific document from the dataset"""
    if 0 <= index < len(ds):
        doc = ds[index]
        segments_doc = segments(doc)
        return (
            doc["title"],
            doc["abstract"],
            create_interleaved_html(segments_doc, doc["slides"], scale=0.7),
            doc_to_messages(segments_doc, doc["slides"]).input_ids.shape[1],
            choices[index],
        )
    return ("", "", "", "", "")
def get_next_document():
    """Get the next document in the dataset"""
    return choices[(current_doc_index + 1) % len(ds)]


def get_prev_document():
    """Get the previous document in the dataset"""
    return choices[(current_doc_index - 1) % len(ds)]


def get_selected_document(arg):
    """Get the selected document from the dataset"""
    global current_doc_index
    index = int(arg.split(" | ")[0])
    current_doc_index = index
    return load_document(current_doc_index)
theme = gr.themes.Ocean()
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Slide Presentation Visualization Tool")
    pres_selection_dd = gr.Dropdown(label="Presentation", value=choices[0], choices=choices)
    with gr.Row():
        with gr.Column():
            body = gr.HTML(max_height=400)
        with gr.Column():
            title = gr.Textbox(label="Title", interactive=False, max_lines=1)
            abstract = gr.Textbox(label="Abstract", interactive=False, max_lines=8)
            token_count = gr.Textbox(label=f"Token Count (Qwen2-VL with under {max_token_budget} tokens per image)", interactive=False, max_lines=1)

    # Load first document
    title_val, abstract_val, body_val, token_count_val, choices_val = load_document(current_doc_index)
    title.value = title_val
    abstract.value = abstract_val
    body.value = body_val
    token_count.value = str(token_count_val)
    pres_selection_dd.value = choices_val

    pres_selection_dd.change(
        fn=get_selected_document,
        inputs=pres_selection_dd,
        outputs=[title, abstract, body, token_count, pres_selection_dd],
    )

    with gr.Row():
        prev_button = gr.Button("Previous Document")
        prev_button.click(fn=get_prev_document, inputs=[], outputs=[pres_selection_dd])
        next_button = gr.Button("Next Document")
        next_button.click(fn=get_next_document, inputs=[], outputs=[pres_selection_dd])

demo.launch()