prithivMLmods committed on
Commit
f022e05
·
verified ·
1 Parent(s): dec51b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -17
app.py CHANGED
@@ -2,9 +2,9 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import cv2
 
5
  import time
6
  import re
7
- import spaces
8
  from PIL import Image
9
  from threading import Thread
10
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
@@ -12,7 +12,7 @@ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIter
12
  #####################################
13
  # 1. Load Model & Processor
14
  #####################################
15
- MODEL_ID = "google/gemma-3-12b-it" # Adjust to your needs
16
 
17
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
18
  model = Gemma3ForConditionalGeneration.from_pretrained(
@@ -23,18 +23,32 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
23
  model.eval()
24
 
25
  #####################################
26
- # 2. Helper Function: Capture Live Frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  #####################################
28
  def capture_live_frames(duration=5, num_frames=10):
29
  """
30
- Captures live frames from the default webcam for a specified duration.
31
- Returns a list of (PIL image, timestamp) tuples.
32
  """
33
- cap = cv2.VideoCapture(0) # Use default webcam
34
- if not cap.isOpened():
35
- return []
36
 
37
- # Try to get FPS, default to 30 if not available.
38
  fps = cap.get(cv2.CAP_PROP_FPS)
39
  if fps <= 0:
40
  fps = 30
@@ -50,20 +64,19 @@ def capture_live_frames(duration=5, num_frames=10):
50
  if not ret:
51
  break
52
  if frame_count in frame_indices:
53
- # Convert BGR (OpenCV) to RGB (PIL)
54
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
55
  pil_image = Image.fromarray(frame_rgb)
56
  timestamp = round(frame_count / fps, 2)
57
  captured_frames.append((pil_image, timestamp))
58
  frame_count += 1
59
- # Break if the elapsed time exceeds the duration.
60
  if time.time() - start_time > duration:
61
  break
62
  cap.release()
63
  return captured_frames
64
 
65
  #####################################
66
- # 3. Live Inference Function
67
  #####################################
68
  @spaces.GPU
69
  def live_inference(duration=5):
@@ -74,7 +87,7 @@ def live_inference(duration=5):
74
  if not frames:
75
  return "Could not capture live frames from the webcam."
76
 
77
- # Build prompt using the captured frames.
78
  messages = [{
79
  "role": "user",
80
  "content": [{"type": "text", "text": "Please describe what's happening in this live video."}]
@@ -93,7 +106,7 @@ def live_inference(duration=5):
93
  padding=True
94
  ).to("cuda")
95
 
96
- # Generate text using streaming.
97
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
98
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
99
 
@@ -108,7 +121,7 @@ def live_inference(duration=5):
108
  return generated_text
109
 
110
  #####################################
111
- # 4. Build Gradio Live App
112
  #####################################
113
  def build_live_app():
114
  with gr.Blocks() as demo:
@@ -119,7 +132,7 @@ def build_live_app():
119
  output_text = gr.Textbox(label="Model Output")
120
  restart_btn = gr.Button("Start Again", visible=False)
121
 
122
- # This function triggers the live inference and also makes the restart button visible.
123
  def start_inference(duration):
124
  text = live_inference(duration)
125
  return text, gr.update(visible=True)
@@ -130,4 +143,4 @@ def build_live_app():
130
 
131
  if __name__ == "__main__":
132
  app = build_live_app()
133
- app.launch(debug=True)
 
2
  import torch
3
  import numpy as np
4
  import cv2
5
+ import spaces
6
  import time
7
  import re
 
8
  from PIL import Image
9
  from threading import Thread
10
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
12
  #####################################
13
  # 1. Load Model & Processor
14
  #####################################
15
+ MODEL_ID = "google/gemma-3-12b-it" # Adjust model ID as needed
16
 
17
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
18
  model = Gemma3ForConditionalGeneration.from_pretrained(
 
23
  model.eval()
24
 
25
  #####################################
26
+ # 2. Helper Function: Get a Working Camera
27
+ #####################################
28
def get_working_camera():
    """
    Return the first webcam that can be opened, or None if there is none.

    Probes camera indices 0, 1, and 2 in that order and stops at the first
    one whose capture reports ``isOpened()``.

    Returns:
        An opened ``cv2.VideoCapture`` (the caller should ``release()`` it
        when done), or ``None`` if none of the probed indices could be opened.
    """
    for index in range(3):
        cap = cv2.VideoCapture(index)
        if cap.isOpened():
            return cap
        # This index did not open: release the capture object before moving
        # on so failed probes do not leave backend/device handles dangling.
        cap.release()
    return None
38
+
39
+ #####################################
40
+ # 3. Helper Function: Capture Live Frames
41
  #####################################
42
  def capture_live_frames(duration=5, num_frames=10):
43
  """
44
+ Captures live frames from a working webcam for a specified duration.
45
+ Returns a list of (PIL Image, timestamp) tuples.
46
  """
47
+ cap = get_working_camera()
48
+ if cap is None:
49
+ return [] # No working camera found
50
 
51
+ # Try to get FPS; default to 30 if not available.
52
  fps = cap.get(cv2.CAP_PROP_FPS)
53
  if fps <= 0:
54
  fps = 30
 
64
  if not ret:
65
  break
66
  if frame_count in frame_indices:
67
+ # Convert from BGR to RGB for PIL
68
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
69
  pil_image = Image.fromarray(frame_rgb)
70
  timestamp = round(frame_count / fps, 2)
71
  captured_frames.append((pil_image, timestamp))
72
  frame_count += 1
 
73
  if time.time() - start_time > duration:
74
  break
75
  cap.release()
76
  return captured_frames
77
 
78
  #####################################
79
+ # 4. Live Inference Function
80
  #####################################
81
  @spaces.GPU
82
  def live_inference(duration=5):
 
87
  if not frames:
88
  return "Could not capture live frames from the webcam."
89
 
90
+ # Build prompt using captured frames and timestamps.
91
  messages = [{
92
  "role": "user",
93
  "content": [{"type": "text", "text": "Please describe what's happening in this live video."}]
 
106
  padding=True
107
  ).to("cuda")
108
 
109
+ # Generate text output using a streaming approach.
110
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
111
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
112
 
 
121
  return generated_text
122
 
123
  #####################################
124
+ # 5. Build Gradio Live App
125
  #####################################
126
  def build_live_app():
127
  with gr.Blocks() as demo:
 
132
  output_text = gr.Textbox(label="Model Output")
133
  restart_btn = gr.Button("Start Again", visible=False)
134
 
135
+ # Function to trigger live inference and reveal the restart button
136
  def start_inference(duration):
137
  text = live_inference(duration)
138
  return text, gr.update(visible=True)
 
143
 
144
  if __name__ == "__main__":
145
  app = build_live_app()
146
+ app.launch(debug=True, share=True)