onuralpszr commited on
Commit
49f86c8
Β·
verified Β·
1 Parent(s): 47768b2

fix: πŸ› single input for detection and handle class names

Browse files

Signed-off-by: Onuralp SEZER <[email protected]>

Files changed (1) hide show
  1. app.py +36 -15
app.py CHANGED
@@ -54,6 +54,12 @@ model_id = "google/paligemma2-3b-pt-448"
54
  model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(DEVICE)
55
  processor = PaliGemmaProcessor.from_pretrained(model_id)
56
 
 
 
 
 
 
 
57
  @spaces.GPU
58
  def paligemma_detection(input_image, input_text, max_new_tokens):
59
  model_inputs = processor(text=input_text,
@@ -70,13 +76,17 @@ def paligemma_detection(input_image, input_text, max_new_tokens):
70
 
71
 
72
 
73
- def annotate_image(result, resolution_wh, class_names, cv_image):
 
 
 
 
74
 
75
  detections = sv.Detections.from_lmm(
76
- sv.LMM.PALIGEMMA,
77
- result,
78
- resolution_wh=resolution_wh,
79
- classes=class_names.split(',')
80
  )
81
 
82
  annotated_image = BOX_ANNOTATOR.annotate(
@@ -98,17 +108,17 @@ def annotate_image(result, resolution_wh, class_names, cv_image):
98
  return annotated_image
99
 
100
 
101
- def process_image(input_image, input_text, class_names, max_new_tokens):
102
  cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
103
  result = paligemma_detection(input_image, input_text, max_new_tokens)
104
  annotated_image = annotate_image(result,
105
  (input_image.width, input_image.height),
106
- class_names, cv_image)
107
  return annotated_image, result
108
 
109
 
110
  @spaces.GPU
111
- def process_video(input_video, input_text, class_names, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
112
  if not input_video:
113
  gr.Info("Please upload a video.")
114
  return None
@@ -117,6 +127,11 @@ def process_video(input_video, input_text, class_names, max_new_tokens, progress
117
  gr.Info("Please enter a text prompt.")
118
  return None
119
 
 
 
 
 
 
120
  name = generate_unique_name()
121
  frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
122
  create_directory(frame_directory_path)
@@ -146,7 +161,7 @@ def process_video(input_video, input_text, class_names, max_new_tokens, progress
146
  sv.LMM.PALIGEMMA,
147
  result,
148
  resolution_wh=(video_info.width, video_info.height),
149
- classes=class_names.split(',')
150
  )
151
 
152
  annotated_frame = BOX_ANNOTATOR.annotate(
@@ -177,15 +192,18 @@ with gr.Blocks() as app:
177
  with gr.Row():
178
  with gr.Column():
179
  input_image = gr.Image(type="pil", label="Input Image")
180
- input_text = gr.Textbox(lines=2, placeholder="Enter text here...", label="Enter prompt for example 'detect person;dog")
181
- class_names = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...", label="Class Names")
 
 
 
182
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Set to larger for longer generation.")
183
  with gr.Column():
184
  annotated_image = gr.Image(type="pil", label="Annotated Image")
185
  detection_result = gr.Textbox(label="Detection Result")
186
  gr.Button("Submit").click(
187
  fn=process_image,
188
- inputs=[input_image, input_text, class_names, max_new_tokens],
189
  outputs=[annotated_image, detection_result]
190
  )
191
 
@@ -193,8 +211,11 @@ with gr.Blocks() as app:
193
  with gr.Row():
194
  with gr.Column():
195
  input_video = gr.Video(label="Input Video")
196
- input_text = gr.Textbox(lines=2, placeholder="Enter text here...", label="Enter prompt for example 'detect person;dog")
197
- class_names = gr.Textbox(lines=1, placeholder="Enter class names separated by commas...", label="Class Names")
 
 
 
198
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Set to larger for longer generation.")
199
  with gr.Column():
200
  output_video = gr.Video(label="Annotated Video")
@@ -202,7 +223,7 @@ with gr.Blocks() as app:
202
 
203
  gr.Button("Process Video").click(
204
  fn=process_video,
205
- inputs=[input_video, input_text, class_names, max_new_tokens],
206
  outputs=[output_video, detection_result]
207
  )
208
 
 
54
  model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(DEVICE)
55
  processor = PaliGemmaProcessor.from_pretrained(model_id)
56
 
57
+ def parse_class_names(prompt):
58
+ if not prompt.lower().startswith('detect '):
59
+ return []
60
+ classes_text = prompt[7:].strip()
61
+ return [cls.strip() for cls in classes_text.split(';') if cls.strip()]
62
+
63
  @spaces.GPU
64
  def paligemma_detection(input_image, input_text, max_new_tokens):
65
  model_inputs = processor(text=input_text,
 
76
 
77
 
78
 
79
+ def annotate_image(result, resolution_wh, prompt, cv_image):
80
+ class_names = parse_class_names(prompt)
81
+ if not class_names:
82
+ gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
83
+ return cv_image
84
 
85
  detections = sv.Detections.from_lmm(
86
+ sv.LMM.PALIGEMMA,
87
+ result,
88
+ resolution_wh=resolution_wh,
89
+ classes=class_names
90
  )
91
 
92
  annotated_image = BOX_ANNOTATOR.annotate(
 
108
  return annotated_image
109
 
110
 
111
+ def process_image(input_image, input_text, max_new_tokens):
112
  cv_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
113
  result = paligemma_detection(input_image, input_text, max_new_tokens)
114
  annotated_image = annotate_image(result,
115
  (input_image.width, input_image.height),
116
+ input_text, cv_image)
117
  return annotated_image, result
118
 
119
 
120
  @spaces.GPU
121
+ def process_video(input_video, input_text, max_new_tokens, progress=gr.Progress(track_tqdm=True)):
122
  if not input_video:
123
  gr.Info("Please upload a video.")
124
  return None
 
127
  gr.Info("Please enter a text prompt.")
128
  return None
129
 
130
+ class_names = parse_class_names(input_text)
131
+ if not class_names:
132
+ gr.Warning("Invalid prompt format. Please use 'detect class1;class2;class3' format")
133
+ return None, None
134
+
135
  name = generate_unique_name()
136
  frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
137
  create_directory(frame_directory_path)
 
161
  sv.LMM.PALIGEMMA,
162
  result,
163
  resolution_wh=(video_info.width, video_info.height),
164
+ classes=class_names
165
  )
166
 
167
  annotated_frame = BOX_ANNOTATOR.annotate(
 
192
  with gr.Row():
193
  with gr.Column():
194
  input_image = gr.Image(type="pil", label="Input Image")
195
+ input_text = gr.Textbox(
196
+ lines=2,
197
+ placeholder="Enter prompt in format like this: detect person;dog;building",
198
+ label="Enter detection prompt"
199
+ )
200
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=10, label="Max New Tokens", info="Set to larger for longer generation.")
201
  with gr.Column():
202
  annotated_image = gr.Image(type="pil", label="Annotated Image")
203
  detection_result = gr.Textbox(label="Detection Result")
204
  gr.Button("Submit").click(
205
  fn=process_image,
206
+ inputs=[input_image, input_text, max_new_tokens],
207
  outputs=[annotated_image, detection_result]
208
  )
209
 
 
211
  with gr.Row():
212
  with gr.Column():
213
  input_video = gr.Video(label="Input Video")
214
+ input_text = gr.Textbox(
215
+ lines=2,
216
+ placeholder="Enter prompt in format like this: detect person;dog;building",
217
+ label="Enter detection prompt"
218
+ )
219
  max_new_tokens = gr.Slider(minimum=20, maximum=200, value=100, step=1, label="Max New Tokens", info="Set to larger for longer generation.")
220
  with gr.Column():
221
  output_video = gr.Video(label="Annotated Video")
 
223
 
224
  gr.Button("Process Video").click(
225
  fn=process_video,
226
+ inputs=[input_video, input_text, max_new_tokens],
227
  outputs=[output_video, detection_result]
228
  )
229