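"""Gradio app for browsing the gigant/tib-bench dataset: each talk is shown
as its slides interleaved with the transcript, along with the number of
tokens the document occupies under a Qwen2-VL per-image token budget."""
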
import base64
import io

import gradio as gr
from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor

max_token_budget = 512

min_pixels = 1 * 28 * 28
max_pixels = max_token_budget * 28 * 28
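# Qwen2-VL resizes each image so its area stays within [min_pixels, max_pixels];
# one visual token corresponds to a 28x28 pixel patch, so this caps every
# image at roughly `max_token_budget` (512) visual tokens.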
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)

ds = load_dataset("gigant/tib-bench")["train"]
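# Fields used below: title, abstract, slides (PIL images), keyframes.timestamp,
# and transcript_segments with parallel `seek`/`text` columns.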


def segments(example):
    """Interleave "<image>" placeholders with the transcript, using the
    keyframe timestamps to place each slide at the right point in the text."""
    text = ""
    segment_i = 0
    seeks = example["transcript_segments"]["seek"]
    texts = example["transcript_segments"]["text"]
    for timestamp in example["keyframes"]["timestamp"]:
        text += "<image>"
        end = timestamp[1]  # keyframe end time in seconds
        # `seek` looks to be in 10 ms units, hence the 0.01 factor to seconds
        while segment_i < len(seeks) and end > seeks[segment_i] * 0.01:
            text += texts[segment_i]
            segment_i += 1
    # Flush any transcript left after the last keyframe. The length check must
    # be against the list of segments, not the dict of columns.
    if segment_i < len(texts):
        text += "".join(texts[segment_i:])
    return text
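
# The result looks like "<image>text for slide 1<image>text for slide 2 ...";
# create_interleaved_html and doc_to_messages split it back on "<image>".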

def create_interleaved_html(text, slides, scale=0.4, max_width=600):
    """
    Creates an HTML string with interleaved images and text segments.
    The images are converted to base64 and embedded directly in the HTML.
    """
    html = []
    # `parts`, not `segments`, to avoid shadowing the segments() function above
    parts = text.split("<image>")

    for j, part in enumerate(parts):
        # Every part after the first is preceded by a slide image; the first
        # part is any transcript before the first keyframe (often empty).
        if j > 0:
            img = slides[j - 1]
            img_width = int(img.width * scale)
            img_height = int(img.height * scale)
            if img_width > max_width:
                ratio = max_width / img_width
                img_width = max_width
                img_height = int(img_height * ratio)

            # Convert the resized image to base64 so it can be inlined
            buffer = io.BytesIO()
            img.resize((img_width, img_height)).save(buffer, format="PNG")
            img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")

            html.append(f'<img src="data:image/png;base64,{img_str}" style="max-width: {max_width}px; display: block; margin: 20px auto;">')
        # Add the text segment after the image
        html.append(f'<div style="white-space: pre-wrap;">{part}</div>')

    return "".join(html)

def doc_to_messages(text, slides):
    """Build a Qwen2-VL chat message with interleaved images and text, and
    return the processor's tokenized inputs (used here for token counting)."""
    content = []
    parts = text.split("<image>")
    for j, part in enumerate(parts):
        if j > 0:
            content.append({"type": "image", "image": slides[j - 1]})
        content.append({"type": "text", "text": part})
    messages = [
        {
            "role": "user",
            "content": content,
        }
    ]
    # Preparation for inference
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    return inputs
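
# This app only counts tokens; for actual generation one could do the
# following (a sketch, not executed here — the model checkpoint and decoding
# settings are assumptions):
#   from transformers import Qwen2VLForConditionalGeneration
#   model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
#   inputs = doc_to_messages(segments(ds[0]), ds[0]["slides"])
#   out = model.generate(**inputs, max_new_tokens=128)
#   print(processor.batch_decode(out, skip_special_tokens=True)[0])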


# Global state: index of the document currently shown
current_doc_index = 0

choices = [f"{i} | {ds['title'][i]}" for i in range(len(ds))]
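# Dropdown entries are formatted "index | title"; get_selected_document
# parses the index back out of the selected string.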

def load_document(index):
    """Load a document from the dataset and return the values shown in the UI:
    (title, abstract, interleaved HTML, token count, dropdown choice)."""
    if 0 <= index < len(ds):
        doc = ds[index]
        segments_doc = segments(doc)
        return (
            doc["title"],
            doc["abstract"],
            create_interleaved_html(segments_doc, doc["slides"], scale=0.7),
            doc_to_messages(segments_doc, doc["slides"]).input_ids.shape[1],
            choices[index],
        )
    return ("", "", "", "", "")

def get_next_document():
    """Return the dropdown choice for the next document (wrapping around);
    updating the dropdown fires its change event, which loads the document."""
    return choices[(current_doc_index + 1) % len(ds)]

def get_prev_document():
    """Return the dropdown choice for the previous document (wrapping around)."""
    return choices[(current_doc_index - 1) % len(ds)]

def get_selected_document(arg):
    """Parse the index out of an "index | title" choice and load that document."""
    global current_doc_index
    current_doc_index = int(arg.split(" | ")[0])
    return load_document(current_doc_index)


theme = gr.themes.Ocean()

with gr.Blocks(theme=theme) as demo:
    gr.Markdown("# Slide Presentation Visualization Tool")

    # Load the first document up front so the components start populated;
    # assigning component.value after construction is unreliable across
    # Gradio versions, so pass the values to the constructors instead.
    title_val, abstract_val, body_val, token_count_val, choices_val = load_document(current_doc_index)

    pres_selection_dd = gr.Dropdown(label="Presentation", value=choices_val, choices=choices)
    with gr.Row():
        with gr.Column():
            body = gr.HTML(value=body_val, max_height=400)

        with gr.Column():
            title = gr.Textbox(label="Title", value=title_val, interactive=False, max_lines=1)
            abstract = gr.Textbox(label="Abstract", value=abstract_val, interactive=False, max_lines=8)
            token_count = gr.Textbox(
                label=f"Token Count (Qwen2-VL with under {max_token_budget} tokens per image)",
                value=str(token_count_val),
                interactive=False,
                max_lines=1,
            )

    pres_selection_dd.change(
        fn=get_selected_document,
        inputs=pres_selection_dd,
        outputs=[title, abstract, body, token_count, pres_selection_dd],
    )

    with gr.Row():
        prev_button = gr.Button("Previous Document")
        prev_button.click(fn=get_prev_document, inputs=[], outputs=[pres_selection_dd])
        next_button = gr.Button("Next Document")
        next_button.click(fn=get_next_document, inputs=[], outputs=[pres_selection_dd])

demo.launch()