robinhad committed
Commit 7d9b785 · verified · 1 Parent(s): db68561

Create app.py

Files changed (1)
  1. app.py +182 -0
app.py ADDED
@@ -0,0 +1,182 @@
+ import gradio as gr
+ from transformers import AutoProcessor, Llama4ForConditionalGeneration, TextIteratorStreamer
+ from transformers.image_utils import load_image
+ from threading import Thread
+ import time
+ import torch
+ import spaces
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+ def progress_bar_html(label: str) -> str:
+     """
+     Returns an HTML snippet for a thin progress bar with a label.
+     The progress bar is styled as a dark animated bar.
+     """
+     return f'''
+     <div style="display: flex; align-items: center;">
+         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+         <div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
+             <div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
+         </div>
+     </div>
+     <style>
+     @keyframes loading {{
+         0% {{ transform: translateX(-100%); }}
+         100% {{ transform: translateX(100%); }}
+     }}
+     </style>
+     '''
+
+ def downsample_video(video_path):
+     """
+     Downsamples the video to 10 evenly spaced frames.
+     Each frame is converted to a PIL Image along with its timestamp.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     if total_frames <= 0 or fps <= 0:
+         vidcap.release()
+         return frames
+     # Sample 10 evenly spaced frames.
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ """model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16
+ ).to("cuda").eval()"""
+
+ model = Llama4ForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     attn_implementation="flex_attention",
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ ).eval()  # device_map="auto" already places the weights, so no extra .to("cuda")
+
+ @spaces.GPU
+ def model_inference(input_dict, history):
+     text = input_dict["text"]
+     files = input_dict["files"]
+
+     if text.strip().lower().startswith("@video-infer"):
+         # Remove the tag from the query.
+         text = text[len("@video-infer"):].strip()
+         if not files:
+             # Surface the problem in the UI instead of returning silently.
+             raise gr.Error("Please upload a video file along with your @video-infer query.")
+         # Assume the first file is a video.
+         video_path = files[0]
+         frames = downsample_video(video_path)
+         if not frames:
+             # No frames could be decoded from the uploaded file.
+             raise gr.Error("Could not process video.")
+         # Build messages: start with the text prompt.
+         messages = [
+             {
+                 "role": "user",
+                 "content": [{"type": "text", "text": text}]
+             }
+         ]
+         # Append each frame with a timestamp label.
+         for image, timestamp in frames:
+             messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+             messages[0]["content"].append({"type": "image", "image": image})
+         # Collect only the images from the frames.
+         video_images = [image for image, _ in frames]
+         # Prepare the prompt.
+         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = processor(
+             text=[prompt],
+             images=video_images,
+             return_tensors="pt",
+             padding=True,
+         ).to("cuda")
+         # Set up streaming generation.
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing video with Llama 4 Scout")
+         for new_text in streamer:
+             buffer += new_text
+             time.sleep(0.01)
+             yield buffer
+         return
+
+     if len(files) > 1:
+         images = [load_image(image) for image in files]
+     elif len(files) == 1:
+         images = [load_image(files[0])]
+     else:
+         images = []
+
+     if text == "" and not images:
+         # An empty query cannot be processed; show the error in the UI.
+         raise gr.Error("Please input a query and optionally image(s).")
+     if text == "" and images:
+         # Images were supplied without an accompanying text query.
+         raise gr.Error("Please input a text query along with the image(s).")
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
+             ],
+         }
+     ]
+     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(
+         text=[prompt],
+         images=images if images else None,
+         return_tensors="pt",
+         padding=True,
+     ).to("cuda")
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     yield progress_bar_html("Processing with Llama 4 Scout")
+     for new_text in streamer:
+         buffer += new_text
+         time.sleep(0.01)
+         yield buffer
+
+ examples = [
+     [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
+     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
+     [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
+     [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
+ ]
+
+ demo = gr.ChatInterface(
+     fn=model_inference,
+     description="# **meta-llama/Llama-4-Scout-17B-16E-Instruct `@video-infer for video understanding`** (based on demo from here: https://huggingface.co/spaces/prithivMLmods/Qwen2.5-VL-7B-Instruct)",
+     examples=examples,
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
+     stop_btn="Stop Generation",
+     multimodal=True,
+     cache_examples=False,
+ )
+
+ demo.launch(debug=True)
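
Reviewer note: to sanity-check the frame sampler outside the Space, the following minimal sketch (not part of this commit) calls downsample_video directly. It assumes the function has been copied into the current Python session (importing app.py as-is would also load the model and launch the demo), and "clip.mp4" is a hypothetical local file.

frames = downsample_video("clip.mp4")  # hypothetical local video file
print(f"sampled {len(frames)} frames")
for image, timestamp in frames:
    # Each entry is a (PIL.Image, timestamp-in-seconds) pair.
    print(f"  t={timestamp}s  size={image.size}  mode={image.mode}")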
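Reviewer note: the streamed output in model_inference comes from transformers' TextIteratorStreamer fed by model.generate running in a background thread. Below is a minimal, self-contained sketch of that pattern (not part of this commit); it uses the small text-only model "gpt2" purely as a stand-in so the mechanics can be tried on CPU.

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a worker thread and pushes decoded text
# into the streamer; the main thread consumes chunks as they arrive.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32))
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(buffer)
thread.join()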