prithivMLmods committed
Commit f5ecaf8 · verified · 1 Parent(s): 7493bfa

Update app.py

Files changed (1):
  1. app.py +114 -164
app.py CHANGED
@@ -1,179 +1,129 @@
  import gradio as gr
- import torch
- import numpy as np
- import cv2
- from PIL import Image
  from threading import Thread
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
- )
- import spaces
  import time

- # Load Model & Processor
- MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
- model = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16
- ).to("cuda")
- model.eval()
-
- # Helper Function: Downsample Video
- def downsample_video(video_path, max_duration=10, num_frames=10):
-     """
-     Downsamples the video to `num_frames` evenly spaced frames within the first `max_duration` seconds.
-     Returns a list of (PIL Image, timestamp) tuples.
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     if fps <= 0 or total_frames <= 0:
-         vidcap.release()
-         return []
-
-     # Limit to first `max_duration` seconds
-     max_frames = min(int(fps * max_duration), total_frames)
-     frame_indices = np.linspace(0, max_frames - 1, num_frames, dtype=int)

-     frames = []
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames

- # Inference Function
  @spaces.GPU
- def video_inference(video_file):
-     """
-     Processes the video file and generates a text description based on the first 10 seconds.
-     Returns the generated text.
-     """
-     if video_file is None:
-         return "No video provided."
-
-     frames = downsample_video(video_file, max_duration=10, num_frames=10)
-     if not frames:
-         return "Could not read frames from video."
-
-     # Construct prompt
-     messages = [
-         {
-             "role": "user",
-             "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
-         }
-     ]
-     for (image, ts) in frames:
-         messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
-         messages[0]["content"].append({"type": "image", "image": image})
-
-     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     frame_images = [img for (img, _) in frames]
-
-     inputs = processor(
-         text=[prompt],
-         images=frame_images,
-         return_tensors="pt",
-         padding=True
-     ).to("cuda")
-
-     # Generate text with streaming
      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()

-     generated_text = ""
      for new_text in streamer:
-         generated_text += new_text
          time.sleep(0.01)
-
-     return generated_text
-
- # Button Toggle Function
- def toggle_button(has_result):
-     """
-     Returns visibility states for start_again_btn and start_btn based on has_result.
-     """
-     if has_result:
-         return gr.update(visible=True), gr.update(visible=False)
-     else:
-         return gr.update(visible=False), gr.update(visible=True)
-
- # Build the Gradio App
- def build_app():
-     with gr.Blocks() as demo:
-         gr.Markdown("""
-         # **Gemma-3 Live Video Analysis**
-         Press **Start** to record a short video clip (up to 10 seconds). Stop recording to see the analysis.
-         After the result, press **Start Again** to analyze another clip.
-         """)
-
-         # State to track if a result has been generated
-         has_result = gr.State(value=False)
-
-         with gr.Row():
-             with gr.Column():
-                 video = gr.Video(
-                     sources=["webcam"],
-                     label="Webcam Recording",
-                     format="mp4"
                  )
-                 # Two buttons: one for Start, one for Start Again
-                 start_btn = gr.Button("Start", visible=True)
-                 start_again_btn = gr.Button("Start Again", visible=False)
-             with gr.Column():
-                 output_text = gr.Textbox(label="Model Output")
-
-         # When video is recorded and stopped, process it
-         def process_video(video_file, has_result_state):
-             if video_file is None:
-                 return "Please record a video.", has_result_state
-             result = video_inference(video_file)
-             return result, True
-
-         video.change(
-             fn=process_video,
-             inputs=[video, has_result],
-             outputs=[output_text, has_result]
-         )
-
-         # Update button visibility based on has_result
-         has_result.change(
-             fn=toggle_button,
-             inputs=has_result,
-             outputs=[start_again_btn, start_btn]
-         )
-
-         # Clicking either button resets the video and output
-         def reset_state():
-             return None, "", False
-
-         start_btn.click(
-             fn=reset_state,
-             inputs=None,
-             outputs=[video, output_text, has_result]
-         )
-         start_again_btn.click(
-             fn=reset_state,
-             inputs=None,
-             outputs=[video, output_text, has_result]
-         )
-
-     return demo

- if __name__ == "__main__":
-     app = build_app()
-     app.launch(debug=True)
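For reference, the removed `downsample_video` helper can be exercised on its own. A minimal sketch, assuming the helper and its `cv2`/`numpy`/`PIL` imports from the removed app.py are in scope; `clip.mp4` is a hypothetical local file:

# Sample the first 10 seconds of a clip into 10 evenly spaced frames.
# Assumes downsample_video() from the removed app.py above is defined,
# along with its cv2 / numpy / PIL imports; "clip.mp4" is hypothetical.
frames = downsample_video("clip.mp4", max_duration=10, num_frames=10)
for image, timestamp in frames:
    # Each entry is a (PIL.Image, timestamp-in-seconds) tuple.
    print(f"{timestamp:6.2f}s  {image.size}")

The updated app.py below drops this OpenCV sampling step and hands image and video paths directly to the SmolVLM2 processor.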
 
  import gradio as gr
+ from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
  from threading import Thread
+ import re
  import time
+ import torch
+ import spaces
+ import subprocess
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

+ from io import BytesIO

+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+ model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+         _attn_implementation="flash_attention_2",
+         torch_dtype=torch.bfloat16).to("cuda:0")

  @spaces.GPU
+ def model_inference(
+     input_dict, history, max_tokens
+ ):
+     text = input_dict["text"]
+     images = []
+     user_content = []
+     media_queue = []
+     if history == []:
+         text = input_dict["text"].strip()
+
+         for file in input_dict.get("files", []):
+             if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                 media_queue.append({"type": "image", "path": file})
+             elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                 media_queue.append({"type": "video", "path": file})
+
+         if "<image>" in text or "<video>" in text:
+             parts = re.split(r'(<image>|<video>)', text)
+             for part in parts:
+                 if part == "<image>" and media_queue:
+                     user_content.append(media_queue.pop(0))
+                 elif part == "<video>" and media_queue:
+                     user_content.append(media_queue.pop(0))
+                 elif part.strip():
+                     user_content.append({"type": "text", "text": part.strip()})
+         else:
+             user_content.append({"type": "text", "text": text})
+
+             for media in media_queue:
+                 user_content.append(media)
+
+         resulting_messages = [{"role": "user", "content": user_content}]
+
+     elif len(history) > 0:
+         resulting_messages = []
+         user_content = []
+         media_queue = []
+         for hist in history:
+             if hist["role"] == "user" and isinstance(hist["content"], tuple):
+                 file_name = hist["content"][0]
+                 if file_name.endswith((".png", ".jpg", ".jpeg")):
+                     media_queue.append({"type": "image", "path": file_name})
+                 elif file_name.endswith(".mp4"):
+                     media_queue.append({"type": "video", "path": file_name})
+
+         for hist in history:
+             if hist["role"] == "user" and isinstance(hist["content"], str):
+                 text = hist["content"]
+                 parts = re.split(r'(<image>|<video>)', text)
+
+                 for part in parts:
+                     if part == "<image>" and media_queue:
+                         user_content.append(media_queue.pop(0))
+                     elif part == "<video>" and media_queue:
+                         user_content.append(media_queue.pop(0))
+                     elif part.strip():
+                         user_content.append({"type": "text", "text": part.strip()})
+
+             elif hist["role"] == "assistant":
+                 resulting_messages.append({
+                     "role": "user",
+                     "content": user_content
+                 })
+                 resulting_messages.append({
+                     "role": "assistant",
+                     "content": [{"type": "text", "text": hist["content"]}]
+                 })
+                 user_content = []
+
+     if text == "" and not images:
+         gr.Error("Please input a query and optionally image(s).")
+
+     if text == "" and images:
+         gr.Error("Please input a text query along the images(s).")
+     print("resulting_messages", resulting_messages)
+     inputs = processor.apply_chat_template(
+         resulting_messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     )
+
+     inputs = inputs.to(model.device)
+
+     # Generate
      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
+     generated_text = ""

+     thread = Thread(target=model.generate, kwargs=generation_args)
      thread.start()

+     yield "..."
+     buffer = ""
+
      for new_text in streamer:
+         buffer += new_text
+         generated_text_without_prompt = buffer
          time.sleep(0.01)
+         yield buffer
+
+ demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video Model Ever 📺",
+         description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
+         textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
+         cache_examples=False,
+         additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
+         type="messages"
          )

+ demo.launch(debug=True)
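A minimal sketch of driving the updated SmolVLM2 pipeline outside Gradio, mirroring the `apply_chat_template` call in the new app.py. Assumptions: the checkpoint fits on a local GPU, the default attention implementation is used instead of flash-attn, and `sample.jpg` is a hypothetical input file.

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

# Same checkpoint as the updated app.py; flash-attn omitted for portability.
model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda:0")

# Message format mirrors what model_inference() builds; "sample.jpg" is hypothetical.
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "path": "sample.jpg"},
        {"type": "text", "text": "Describe this image."},
    ],
}]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

generated = model.generate(**inputs, max_new_tokens=200)
# Decodes prompt plus completion; the Space streams only the completion instead.
print(processor.batch_decode(generated, skip_special_tokens=True)[0])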