prithivMLmods committed
Commit 8081540 · verified · 1 Parent(s): f022e05

Update app.py

Files changed (1):
  1. app.py  +124 -98

app.py CHANGED
@@ -2,18 +2,14 @@ import gradio as gr
 import torch
 import numpy as np
 import cv2
-import spaces
-import time
-import re
 from PIL import Image
 from threading import Thread
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
+import spaces
+import time
 
-#####################################
-# 1. Load Model & Processor
-#####################################
-MODEL_ID = "google/gemma-3-12b-it"  # Adjust model ID as needed
-
+# Load Model & Processor
+MODEL_ID = "google/gemma-3-12b-it"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model = Gemma3ForConditionalGeneration.from_pretrained(
     MODEL_ID,
@@ -22,125 +18,155 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
 ).to("cuda")
 model.eval()
 
-#####################################
-# 2. Helper Function: Get a Working Camera
-#####################################
-def get_working_camera():
-    """
-    Tries camera indices 0, 1, and 2 until a working camera is found.
-    Returns the VideoCapture object or None if no camera can be opened.
+# Helper Function: Downsample Video
+def downsample_video(video_path, max_duration=10, num_frames=10):
     """
-    for i in range(3):
-        cap = cv2.VideoCapture(i)
-        if cap.isOpened():
-            return cap
-    return None
-
-#####################################
-# 3. Helper Function: Capture Live Frames
-#####################################
-def capture_live_frames(duration=5, num_frames=10):
-    """
-    Captures live frames from a working webcam for a specified duration.
+    Downsamples the video to `num_frames` evenly spaced frames within the first `max_duration` seconds.
     Returns a list of (PIL Image, timestamp) tuples.
     """
-    cap = get_working_camera()
-    if cap is None:
-        return []  # No working camera found
-
-    # Try to get FPS; default to 30 if not available.
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    if fps <= 0:
-        fps = 30
-    total_frames_to_capture = int(duration * fps)
-    frame_indices = np.linspace(0, total_frames_to_capture - 1, num_frames, dtype=int)
-
-    captured_frames = []
-    frame_count = 0
-    start_time = time.time()
-
-    while frame_count < total_frames_to_capture:
-        ret, frame = cap.read()
-        if not ret:
-            break
-        if frame_count in frame_indices:
-            # Convert from BGR to RGB for PIL
-            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            pil_image = Image.fromarray(frame_rgb)
-            timestamp = round(frame_count / fps, 2)
-            captured_frames.append((pil_image, timestamp))
-        frame_count += 1
-        if time.time() - start_time > duration:
-            break
-    cap.release()
-    return captured_frames
-
-#####################################
-# 4. Live Inference Function
-#####################################
+    vidcap = cv2.VideoCapture(video_path)
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    if fps <= 0 or total_frames <= 0:
+        vidcap.release()
+        return []
+
+    # Limit to first `max_duration` seconds
+    max_frames = min(int(fps * max_duration), total_frames)
+    frame_indices = np.linspace(0, max_frames - 1, num_frames, dtype=int)
+
+    frames = []
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
+
+# Inference Function
 @spaces.GPU
-def live_inference(duration=5):
+def video_inference(video_file):
     """
-    Captures live frames from the webcam, builds a prompt, and returns the generated text.
+    Processes the video file and generates a text description based on the first 10 seconds.
+    Returns the generated text.
     """
-    frames = capture_live_frames(duration=duration, num_frames=10)
+    if video_file is None:
+        return "No video provided."
+
+    frames = downsample_video(video_file, max_duration=10, num_frames=10)
     if not frames:
-        return "Could not capture live frames from the webcam."
-
-    # Build prompt using captured frames and timestamps.
-    messages = [{
-        "role": "user",
-        "content": [{"type": "text", "text": "Please describe what's happening in this live video."}]
-    }]
+        return "Could not read frames from video."
+
+    # Construct prompt
+    messages = [
+        {
+            "role": "user",
+            "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
+        }
+    ]
     for (image, ts) in frames:
         messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
         messages[0]["content"].append({"type": "image", "image": image})
-
+
     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     frame_images = [img for (img, _) in frames]
-
+
     inputs = processor(
         text=[prompt],
         images=frame_images,
         return_tensors="pt",
         padding=True
     ).to("cuda")
-
-    # Generate text output using a streaming approach.
+
+    # Generate text with streaming
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
-
+
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
+
     generated_text = ""
     for new_text in streamer:
         generated_text += new_text
         time.sleep(0.01)
-
+
     return generated_text
 
-#####################################
-# 5. Build Gradio Live App
-#####################################
-def build_live_app():
+# Button Toggle Function
+def toggle_button(has_result):
+    """
+    Returns button label and visibility states based on whether a result has been generated.
+    """
+    if has_result:
+        return "Start Again", gr.Button(visible=True), gr.Button(visible=False)
+    return "Start", gr.Button(visible=False), gr.Button(visible=True)
+
+# Build the Gradio App
+def build_app():
     with gr.Blocks() as demo:
-        gr.Markdown("# **Live Video Analysis**\n\nPress **Start** to capture a few seconds of live video from your webcam and analyze the content.")
-        with gr.Column():
-            duration_input = gr.Number(label="Capture Duration (seconds)", value=5, precision=0)
-            start_btn = gr.Button("Start")
-            output_text = gr.Textbox(label="Model Output")
-            restart_btn = gr.Button("Start Again", visible=False)
-
-        # Function to trigger live inference and reveal the restart button
-        def start_inference(duration):
-            text = live_inference(duration)
-            return text, gr.update(visible=True)
-
-        start_btn.click(fn=start_inference, inputs=duration_input, outputs=[output_text, restart_btn])
-        restart_btn.click(fn=start_inference, inputs=duration_input, outputs=[output_text, restart_btn])
+        gr.Markdown("""
+        # **Gemma-3 Live Video Analysis**
+        Press **Start** to record a short video clip (up to 10 seconds). Stop recording to see the analysis.
+        After the result, press **Start Again** to analyze another clip.
+        """)
+
+        # State to track if a result has been generated
+        has_result = gr.State(value=False)
+
+        with gr.Row():
+            with gr.Column():
+                video = gr.Video(
+                    source="webcam",
+                    label="Webcam Recording",
+                    format="mp4"
+                )
+                # Two buttons: one for Start, one for Start Again
+                start_again_btn = gr.Button("Start Again", visible=False)
+                start_btn = gr.Button("Start", visible=True)
+            with gr.Column():
+                output_text = gr.Textbox(label="Model Output")
+
+        # When video is recorded and stopped, process it
+        def process_video(video_file, has_result_state):
+            if video_file is None:
+                return "Please record a video.", has_result_state
+            result = video_inference(video_file)
+            return result, True
+
+        video.change(
+            fn=process_video,
+            inputs=[video, has_result],
+            outputs=[output_text, has_result]
+        )
+
+        # Update button visibility based on has_result
+        has_result.change(
+            fn=toggle_button,
+            inputs=has_result,
+            outputs=[start_again_btn, start_again_btn, start_btn]
+        )
+
+        # Clicking either button resets the video and output
+        def reset_state():
+            return None, "", False
+
+        start_btn.click(
+            fn=reset_state,
+            inputs=None,
+            outputs=[video, output_text, has_result]
+        )
+        start_again_btn.click(
+            fn=reset_state,
+            inputs=None,
+            outputs=[video, output_text, has_result]
+        )
+
     return demo
 
 if __name__ == "__main__":
-    app = build_live_app()
-    app.launch(debug=True, share=True)
+    app = build_app()
+    app.launch(debug=True)
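
For reference, the `downsample_video` helper added in this commit picks `num_frames` evenly spaced frame indices within the first `max_duration` seconds instead of decoding the whole clip. A minimal sketch of that index math, assuming a 30 fps recording at least 10 seconds long (illustrative values only, not part of the commit):

```python
import numpy as np

# Same index math as downsample_video: 10 evenly spaced frames from the
# first 10 seconds of an assumed 30 fps clip.
fps, max_duration, num_frames = 30.0, 10, 10
max_frames = int(fps * max_duration)  # 300 frames available
frame_indices = np.linspace(0, max_frames - 1, num_frames, dtype=int)
timestamps = [round(int(i) / fps, 2) for i in frame_indices]

print(frame_indices)  # [  0  33  66  99 132 166 199 232 265 299]
print(timestamps)     # [0.0, 1.1, 2.2, 3.3, 4.4, 5.53, 6.63, 7.73, 8.83, 9.97]
```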
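The prompt that `video_inference` assembles is a single user turn interleaving the caption request, per-frame timestamp labels, and the sampled frames. A sketch of the resulting `messages` structure for two frames (the placeholder images stand in for the PIL frames returned by `downsample_video`; the timestamps are example values):

```python
from PIL import Image

# Placeholder frames standing in for two sampled webcam images.
img0 = Image.new("RGB", (640, 480))
img1 = Image.new("RGB", (640, 480))

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Please describe what's happening in this video."},
        {"type": "text", "text": "Frame at 0.0 seconds:"},
        {"type": "image", "image": img0},
        {"type": "text", "text": "Frame at 1.1 seconds:"},
        {"type": "image", "image": img1},
    ],
}]
# processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# flattens this into the text prompt that is passed to processor(...) together
# with the list of frame images.
```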
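One compatibility note: `gr.Video(source="webcam", ...)` is the Gradio 3.x keyword. If the Space is later moved to Gradio 4.x, the webcam input is selected with the list-valued `sources` argument instead; a sketch of the equivalent component under that assumption:

```python
import gradio as gr

# Gradio 4.x form of the component above (assumption: gradio >= 4.0;
# on 3.x keep the original source="webcam" keyword from the commit).
video = gr.Video(
    sources=["webcam"],
    label="Webcam Recording",
    format="mp4"
)
```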