prithivMLmods committed
Commit 652a69b · verified · 1 Parent(s): 67f9a49

Update app.py

Files changed (1)
app.py +72 -125
app.py CHANGED
@@ -2,20 +2,16 @@ import gradio as gr
  import torch
  import numpy as np
  import cv2
- import matplotlib.pyplot as plt
- import random
- import spaces
  import time
  import re
  from PIL import Image
  from threading import Thread
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
- from transformers.image_utils import load_image

  #####################################
  # 1. Load Model & Processor
  #####################################
- MODEL_ID = "google/gemma-3-12b-it"  # Example model ID (adjust to your needs)
+ MODEL_ID = "google/gemma-3-12b-it"  # Adjust to your needs

  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
  model = Gemma3ForConditionalGeneration.from_pretrained(
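The `from_pretrained` keyword arguments sit in the four unchanged lines between these two hunks, so the diff elides them. A plausible configuration for a 12B vision-language checkpoint (an assumption for illustration, not the commit's actual settings) would be:

    model = Gemma3ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map="auto",           # assumed: place the weights on the available GPU
        torch_dtype=torch.bfloat16,  # assumed: bf16 so a 12B model fits in memory
    )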
@@ -26,159 +22,110 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
  model.eval()

  #####################################
- # 2. Helper Function: Downsample Video
+ # 2. Helper Function: Capture Live Frames
  #####################################
- def downsample_video(video_path, num_frames=10):
+ def capture_live_frames(duration=5, num_frames=10):
      """
-     Downsamples the video file to `num_frames` evenly spaced frames.
-     Each frame is converted to a PIL Image along with its timestamp.
+     Captures live frames from the default webcam for a specified duration.
+     Returns a list of (PIL image, timestamp) tuples.
      """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     if total_frames <= 0 or fps <= 0:
-         vidcap.release()
-         return frames
-
-     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
+     cap = cv2.VideoCapture(0)  # Use default webcam
+     if not cap.isOpened():
+         return []
+
+     # Try to get FPS, default to 30 if not available.
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     if fps <= 0:
+         fps = 30
+     total_frames_to_capture = int(duration * fps)
+     frame_indices = np.linspace(0, total_frames_to_capture - 1, num_frames, dtype=int)
+
+     captured_frames = []
+     frame_count = 0
+     start_time = time.time()
+
+     while frame_count < total_frames_to_capture:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         if frame_count in frame_indices:
+             # Convert BGR (OpenCV) to RGB (PIL)
+             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(frame_rgb)
+             timestamp = round(frame_count / fps, 2)
+             captured_frames.append((pil_image, timestamp))
+         frame_count += 1
+         # Break if the elapsed time exceeds the duration.
+         if time.time() - start_time > duration:
+             break
+     cap.release()
+     return captured_frames

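Both the removed and the added helper choose evenly spaced frame indices with the same `np.linspace(..., dtype=int)` call, and the membership test `frame_count in frame_indices` works because NumPy arrays implement element-wise `__contains__`. A worked example of the sampling (assuming 5 s at 30 fps, i.e. 150 frames, 10 samples):

    >>> import numpy as np
    >>> np.linspace(0, 149, 10, dtype=int)
    array([  0,  16,  33,  49,  66,  82,  99, 115, 132, 149])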
  #####################################
- # 3. The Inference Function
+ # 3. Live Inference Function
  #####################################
- @spaces.GPU
- def video_inference(video_file, duration):
+ def live_inference(duration=5):
      """
-     - Takes a recorded video file and a chosen duration (string).
-     - Downsamples the video, passes frames to the model for inference.
-     - Returns model-generated text + a bar chart based on the text.
+     Captures live frames from the webcam, builds a prompt, and returns the generated text.
      """
-     if video_file is None:
-         return "No video provided.", None
-
-     # 3.1: Downsample the recorded video
-     frames = downsample_video(video_file)
+     frames = capture_live_frames(duration=duration, num_frames=10)
      if not frames:
-         return "Could not read frames from video.", None
-
-     # 3.2: Construct prompt
-     messages = [
-         {
-             "role": "user",
-             "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
-         }
-     ]
-
-     # Add frames (with timestamp) to the messages
+         return "Could not capture live frames from the webcam."
+
+     # Build prompt using the captured frames.
+     messages = [{
+         "role": "user",
+         "content": [{"type": "text", "text": "Please describe what's happening in this live video."}]
+     }]
      for (image, ts) in frames:
          messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
          messages[0]["content"].append({"type": "image", "image": image})
-
-     # Prepare final prompt
+
      prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-     # Gather images for the model
      frame_images = [img for (img, _) in frames]
-
+
      inputs = processor(
          text=[prompt],
          images=frame_images,
          return_tensors="pt",
          padding=True
      ).to("cuda")
-
-     # 3.3: Generate text output (streaming)
+
+     # Generate text using streaming.
      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
      generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
-
+
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()
-
+
      generated_text = ""
      for new_text in streamer:
          generated_text += new_text
          time.sleep(0.01)
-
-     # 3.4: Build a bar chart based on top keywords from the generated text
-     # (Naive approach: frequency of top 5 words)
-     words = re.findall(r'\w+', generated_text.lower())
-     freq = {}
-     for w in words:
-         freq[w] = freq.get(w, 0) + 1
-
-     # Sort words by frequency (descending)
-     sorted_items = sorted(freq.items(), key=lambda x: x[1], reverse=True)
-     # Pick top 5 words (if fewer than 5, pick all)
-     top5 = sorted_items[:5]
-
-     if not top5:
-         # If there's no text or no valid words, return no chart
-         return generated_text, None
-
-     categories = [item[0] for item in top5]
-     values = [item[1] for item in top5]
-
-     # Create the figure
-     fig, ax = plt.subplots()
-     colors = ["#4B0082", "#9370DB", "#8A2BE2", "#DA70D6", "#BA55D3"]  # Purple-ish palette
-     # Make sure we have enough colors for the number of bars
-     color_list = colors[: len(categories)]
-
-     ax.bar(categories, values, color=color_list)
-     ax.set_title("Top Keywords in Generated Description")
-     ax.set_ylabel("Frequency")
-     ax.set_xlabel("Keyword")
-
-     # Return the final text and the figure
-     return generated_text, fig
+
+     return generated_text
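For reference, the message list that `live_inference` hands to `apply_chat_template` interleaves text and image parts, one pair per sampled frame. With two captured frames it comes out roughly as below (timestamps and the `frame_0`/`frame_1` PIL images are illustrative):

    messages = [{
        "role": "user",
        "content": [
            {"type": "text",  "text": "Please describe what's happening in this live video."},
            {"type": "text",  "text": "Frame at 0.0 seconds:"},
            {"type": "image", "image": frame_0},
            {"type": "text",  "text": "Frame at 0.53 seconds:"},
            {"type": "image", "image": frame_1},
        ],
    }]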

  #####################################
- # 4. Build a Professional Gradio UI
+ # 4. Build Gradio Live App
  #####################################
- def build_app():
+ def build_live_app():
      with gr.Blocks() as demo:
-         gr.Markdown("""
-         # **Gemma-3 (Example) Live Video Analysis**
-         Record a video (from webcam or file), then click **Stop**.
-         Next, click **Analyze** to run the model and see textual + chart outputs.
-         """)
-
-         with gr.Row():
-             with gr.Column():
-                 duration = gr.Radio(
-                     choices=["5", "10", "20", "30"],
-                     value="5",
-                     label="Suggested Recording Duration (seconds)",
-                     info="Select how long you plan to record before pressing Stop."
-                 )
-                 # For older Gradio versions, avoid `source="webcam"`.
-                 video = gr.Video(
-                     label="Webcam Recording (press the Record button, then Stop)",
-                     format="mp4"
-                 )
-                 analyze_btn = gr.Button("Analyze", variant="primary")
-             with gr.Column():
-                 output_text = gr.Textbox(label="Model Output")
-                 output_plot = gr.Plot(label="Analytics Chart")
-
-         analyze_btn.click(
-             fn=video_inference,
-             inputs=[video, duration],
-             outputs=[output_text, output_plot]
-         )
-
+         gr.Markdown("# **Live Video Analysis**\n\nPress **Start** to capture a few seconds of live video from your webcam and analyze the content.")
+         with gr.Column():
+             duration_input = gr.Number(label="Capture Duration (seconds)", value=5, precision=0)
+             start_btn = gr.Button("Start")
+             output_text = gr.Textbox(label="Model Output")
+             restart_btn = gr.Button("Start Again", visible=False)
+
+         # This function triggers the live inference and also makes the restart button visible.
+         def start_inference(duration):
+             text = live_inference(duration)
+             return text, gr.update(visible=True)
+
+         start_btn.click(fn=start_inference, inputs=duration_input, outputs=[output_text, restart_btn])
+         restart_btn.click(fn=start_inference, inputs=duration_input, outputs=[output_text, restart_btn])
      return demo

  if __name__ == "__main__":
-     app = build_app()
+     app = build_live_app()
      app.launch(debug=True)
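The Start / Start Again flow works because the handler returns `gr.update(visible=True)` as its second output, which re-renders the hidden button. A minimal sketch of the same pattern in isolation (illustrative names, not part of the commit):

    import gradio as gr

    def handler():
        # Second output re-renders the hidden button with visible=True.
        return "done", gr.update(visible=True)

    with gr.Blocks() as demo:
        start = gr.Button("Start")
        out = gr.Textbox()
        again = gr.Button("Start Again", visible=False)
        start.click(fn=handler, inputs=None, outputs=[out, again])

    demo.launch()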
 