prithivMLmods committed on
Commit
67f9a49
·
verified ·
1 Parent(s): 5935cac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -43
app.py CHANGED
@@ -6,15 +6,16 @@ import matplotlib.pyplot as plt
6
  import random
7
  import spaces
8
  import time
 
9
  from PIL import Image
10
  from threading import Thread
11
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
12
  from transformers.image_utils import load_image
13
 
14
  #####################################
15
- # 1. Load Gemma3 Model & Processor
16
  #####################################
17
- MODEL_ID = "google/gemma-3-12b-it" # Example placeholder
18
 
19
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
20
  model = Gemma3ForConditionalGeneration.from_pretrained(
@@ -52,29 +53,6 @@ def downsample_video(video_path, num_frames=10):
52
  vidcap.release()
53
  return frames
54
 
55
- #####################################
56
- # 2.5: Parse Categories from Model Output
57
- #####################################
58
- def parse_inferred_categories(generated_text):
59
- """
60
- A naive parser that looks for lines starting with 'Category:'
61
- and collects the text after that as the category name.
62
- Example lines in model output:
63
- Category: Nutrition
64
- Category: Outdoor Scenes
65
- Returns a list of category strings.
66
- """
67
- categories = []
68
- for line in generated_text.split("\n"):
69
- line = line.strip()
70
- # Check if the line starts with 'Category:' (case-insensitive)
71
- if line.lower().startswith("category:"):
72
- # Extract everything after 'Category:'
73
- cat = line.split(":", 1)[1].strip()
74
- if cat:
75
- categories.append(cat)
76
- return categories
77
-
78
  #####################################
79
  # 3. The Inference Function
80
  #####################################
@@ -82,8 +60,8 @@ def parse_inferred_categories(generated_text):
82
  def video_inference(video_file, duration):
83
  """
84
  - Takes a recorded video file and a chosen duration (string).
85
- - Downsamples the video, passes frames to the Gemma3 model for inference.
86
- - Returns model-generated text + a bar chart with categories derived from that text.
87
  """
88
  if video_file is None:
89
  return "No video provided.", None
@@ -100,6 +78,7 @@ def video_inference(video_file, duration):
100
  "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
101
  }
102
  ]
 
103
  # Add frames (with timestamp) to the messages
104
  for (image, ts) in frames:
105
  messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
@@ -108,7 +87,7 @@ def video_inference(video_file, duration):
108
  # Prepare final prompt
109
  prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110
 
111
- # Collect images for model
112
  frame_images = [img for (img, _) in frames]
113
 
114
  inputs = processor(
@@ -130,23 +109,37 @@ def video_inference(video_file, duration):
130
  generated_text += new_text
131
  time.sleep(0.01)
132
 
133
- # 3.4: Parse categories from model output
134
- categories = parse_inferred_categories(generated_text)
135
- # If no categories were found, use fallback
136
- if not categories:
137
- categories = ["Category A", "Category B", "Category C"]
 
 
 
 
 
 
138
 
139
- # Create dummy values for each category
140
- values = [random.randint(1, 10) for _ in categories]
 
141
 
142
- # 3.5: Create bar chart
 
 
 
143
  fig, ax = plt.subplots()
144
- ax.bar(categories, values, color=["#4B0082", "#9370DB", "#4B0082"]*(len(categories)//3+1))
145
- ax.set_title("Inferred Categories from Model Output")
146
- ax.set_ylabel("Value")
147
- ax.set_xlabel("Categories")
148
- plt.xticks(rotation=30, ha="right")
 
 
 
149
 
 
150
  return generated_text, fig
151
 
152
  #####################################
@@ -155,7 +148,7 @@ def video_inference(video_file, duration):
155
  def build_app():
156
  with gr.Blocks() as demo:
157
  gr.Markdown("""
158
- # **Gemma3 (or Qwen2.5-VL) Live Video Analysis**
159
  Record a video (from webcam or file), then click **Stop**.
160
  Next, click **Analyze** to run the model and see textual + chart outputs.
161
  """)
@@ -168,8 +161,9 @@ def build_app():
168
  label="Suggested Recording Duration (seconds)",
169
  info="Select how long you plan to record before pressing Stop."
170
  )
 
171
  video = gr.Video(
172
- label="Webcam Recording (press Record, then Stop)",
173
  format="mp4"
174
  )
175
  analyze_btn = gr.Button("Analyze", variant="primary")
 
6
  import random
7
  import spaces
8
  import time
9
+ import re
10
  from PIL import Image
11
  from threading import Thread
12
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
13
  from transformers.image_utils import load_image
14
 
15
  #####################################
16
+ # 1. Load Model & Processor
17
  #####################################
18
+ MODEL_ID = "google/gemma-3-12b-it" # Example model ID (adjust to your needs)
19
 
20
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
21
  model = Gemma3ForConditionalGeneration.from_pretrained(
 
53
  vidcap.release()
54
  return frames
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  #####################################
57
  # 3. The Inference Function
58
  #####################################
 
60
  def video_inference(video_file, duration):
61
  """
62
  - Takes a recorded video file and a chosen duration (string).
63
+ - Downsamples the video, passes frames to the model for inference.
64
+ - Returns model-generated text + a bar chart based on the text.
65
  """
66
  if video_file is None:
67
  return "No video provided.", None
 
78
  "content": [{"type": "text", "text": "Please describe what's happening in this video."}]
79
  }
80
  ]
81
+
82
  # Add frames (with timestamp) to the messages
83
  for (image, ts) in frames:
84
  messages[0]["content"].append({"type": "text", "text": f"Frame at {ts} seconds:"})
 
87
  # Prepare final prompt
88
  prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
89
 
90
+ # Gather images for the model
91
  frame_images = [img for (img, _) in frames]
92
 
93
  inputs = processor(
 
109
  generated_text += new_text
110
  time.sleep(0.01)
111
 
112
+ # 3.4: Build a bar chart based on top keywords from the generated text
113
+ # (Naive approach: frequency of top 5 words)
114
+ words = re.findall(r'\w+', generated_text.lower())
115
+ freq = {}
116
+ for w in words:
117
+ freq[w] = freq.get(w, 0) + 1
118
+
119
+ # Sort words by frequency (descending)
120
+ sorted_items = sorted(freq.items(), key=lambda x: x[1], reverse=True)
121
+ # Pick top 5 words (if fewer than 5, pick all)
122
+ top5 = sorted_items[:5]
123
 
124
+ if not top5:
125
+ # If there's no text or no valid words, return no chart
126
+ return generated_text, None
127
 
128
+ categories = [item[0] for item in top5]
129
+ values = [item[1] for item in top5]
130
+
131
+ # Create the figure
132
  fig, ax = plt.subplots()
133
+ colors = ["#4B0082", "#9370DB", "#8A2BE2", "#DA70D6", "#BA55D3"] # Purple-ish palette
134
+ # Make sure we have enough colors for the number of bars
135
+ color_list = colors[: len(categories)]
136
+
137
+ ax.bar(categories, values, color=color_list)
138
+ ax.set_title("Top Keywords in Generated Description")
139
+ ax.set_ylabel("Frequency")
140
+ ax.set_xlabel("Keyword")
141
 
142
+ # Return the final text and the figure
143
  return generated_text, fig
144
 
145
  #####################################
 
148
  def build_app():
149
  with gr.Blocks() as demo:
150
  gr.Markdown("""
151
+ # **Gemma-3 (Example) Live Video Analysis**
152
  Record a video (from webcam or file), then click **Stop**.
153
  Next, click **Analyze** to run the model and see textual + chart outputs.
154
  """)
 
161
  label="Suggested Recording Duration (seconds)",
162
  info="Select how long you plan to record before pressing Stop."
163
  )
164
+ # For older Gradio versions, avoid `source="webcam"`.
165
  video = gr.Video(
166
+ label="Webcam Recording (press the Record button, then Stop)",
167
  format="mp4"
168
  )
169
  analyze_btn = gr.Button("Analyze", variant="primary")