bla commited on
Commit
7715eea
·
verified ·
1 Parent(s): a7d9ced

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -78
app.py CHANGED
@@ -127,10 +127,10 @@ custom_css = """
127
 
128
  # Available model sizes
129
  DETECTION_MODELS = {
130
- "tiny": "yoloworld-t",
131
- "small": "yoloworld-s",
132
- "base": "yoloworld-b",
133
- "large": "yoloworld-l",
134
  }
135
 
136
  SEGMENTATION_MODELS = {
@@ -147,11 +147,24 @@ class YOLOWorldDetector:
147
  self.model_name = DETECTION_MODELS[model_size]
148
 
149
  print(f"Loading {self.model_name} on {self.device}...")
150
- self.model = AutoModel.from_pretrained(f"deepdatacloud/{self.model_name}",
151
- trust_remote_code=True)
152
- self.model.to(self.device)
153
- self.processor = AutoProcessor.from_pretrained(f"deepdatacloud/{self.model_name}")
154
- print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # Segmentation models
157
  self.seg_models = {}
@@ -162,13 +175,18 @@ class YOLOWorldDetector:
162
  self.model_name = DETECTION_MODELS[model_size]
163
 
164
  print(f"Loading {self.model_name} on {self.device}...")
165
- self.model = AutoModel.from_pretrained(f"deepdatacloud/{self.model_name}",
166
- trust_remote_code=True)
167
- self.model.to(self.device)
168
- self.processor = AutoProcessor.from_pretrained(f"deepdatacloud/{self.model_name}")
169
- print("Model loaded successfully!")
 
 
 
 
 
170
  return f"Using {self.model_name} model"
171
-
172
  def load_seg_model(self, model_name):
173
  if model_name not in self.seg_models:
174
  print(f"Loading segmentation model {model_name}...")
@@ -180,75 +198,48 @@ class YOLOWorldDetector:
180
  if image is None:
181
  return None, "No image provided"
182
 
 
 
 
 
 
 
 
 
183
  if isinstance(image, str):
184
- image = Image.open(image).convert("RGB")
185
  elif isinstance(image, np.ndarray):
186
- image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
187
-
188
- # Process inputs
189
- inputs = self.processor(text=text_prompt, images=image, return_tensors="pt")
190
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
191
 
192
  # Run inference
193
- with torch.no_grad():
194
- outputs = self.model(**inputs)
195
-
196
- # Process results
197
- target_sizes = torch.tensor([image.size[::-1]], device=self.device)
198
- results = self.processor.post_process_object_detection(
199
- outputs=outputs,
200
- target_sizes=target_sizes,
201
- threshold=confidence_threshold
202
- )[0]
203
-
204
- # Convert image to numpy for drawing
205
- image_np = np.array(image)
206
-
207
- # Draw bounding boxes
208
- for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
209
- box = box.cpu().numpy().astype(int)
210
- score = score.cpu().item()
211
- label = label.cpu().item()
212
-
213
- # Get class name from model's config
214
- class_name = f"{text_prompt.split(',')[label] if label < len(text_prompt.split(',')) else 'Object'}: {score:.2f}"
215
-
216
- # Draw rectangle
217
- cv2.rectangle(
218
- image_np,
219
- (box[0], box[1]),
220
- (box[2], box[3]),
221
- (0, 255, 0),
222
- 2
223
- )
224
-
225
- # Draw label background
226
- text_size = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
227
- cv2.rectangle(
228
- image_np,
229
- (box[0], box[1] - text_size[1] - 5),
230
- (box[0] + text_size[0], box[1]),
231
- (0, 255, 0),
232
- -1
233
  )
234
-
235
- # Draw text
236
- cv2.putText(
237
- image_np,
238
- class_name,
239
- (box[0], box[1] - 5),
240
- cv2.FONT_HERSHEY_SIMPLEX,
241
- 0.5,
242
- (0, 0, 0),
243
- 2
244
  )
245
 
 
 
 
246
  # Convert results to JSON format (percentages)
247
  json_results = []
248
- img_height, img_width = image_np.shape[:2]
249
 
250
- for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
251
- box = box.cpu().numpy()
 
 
 
252
  x1, y1, x2, y2 = box
253
 
254
  json_results.append({
@@ -258,12 +249,12 @@ class YOLOWorldDetector:
258
  "width": ((x2 - x1) / img_width) * 100,
259
  "height": ((y2 - y1) / img_height) * 100
260
  },
261
- "score": float(score.cpu().item()),
262
- "label": int(label.cpu().item()),
263
- "label_text": text_prompt.split(',')[label] if label < len(text_prompt.split(',')) else 'Object'
264
  })
265
 
266
- return cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR), json_results
267
 
268
  def segment(self, image, model_name, confidence_threshold=0.3):
269
  if image is None:
 
127
 
128
  # Available model sizes
129
  DETECTION_MODELS = {
130
+ "small": "yolov8s-worldv2.pt",
131
+ "medium": "yolov8m-worldv2.pt",
132
+ "large": "yolov8l-worldv2.pt",
133
+ "xlarge": "yolov8x-worldv2.pt",
134
  }
135
 
136
  SEGMENTATION_MODELS = {
 
147
  self.model_name = DETECTION_MODELS[model_size]
148
 
149
  print(f"Loading {self.model_name} on {self.device}...")
150
+ try:
151
+ # Use the correct repository ID
152
+ self.model = AutoModel.from_pretrained(
153
+ f"IDEA-Research/{self.model_name}",
154
+ trust_remote_code=True
155
+ )
156
+ self.model.to(self.device)
157
+ self.processor = AutoProcessor.from_pretrained(
158
+ f"IDEA-Research/{self.model_name}"
159
+ )
160
+ print("Model loaded successfully!")
161
+ except Exception as e:
162
+ print(f"Error loading model: {e}")
163
+ print("Falling back to YOLOv8 for detection...")
164
+ # Fallback to YOLOv8 if YOLOWorld fails to load
165
+ self.model = None
166
+ self.processor = None
167
+ self.fallback_model = YOLO("yolov8n.pt")
168
 
169
  # Segmentation models
170
  self.seg_models = {}
 
175
  self.model_name = DETECTION_MODELS[model_size]
176
 
177
  print(f"Loading {self.model_name} on {self.device}...")
178
+ try:
179
+ # Use Ultralytics YOLOWorld model
180
+ from ultralytics import YOLOWorld
181
+ self.model = YOLOWorld(self.model_name)
182
+ print("Model loaded successfully!")
183
+ except Exception as e:
184
+ print(f"Error loading YOLOWorld model: {e}")
185
+ print("Falling back to standard YOLOv8 for detection...")
186
+ # Fallback to YOLOv8 if YOLOWorld fails to load
187
+ self.model = YOLO("yolov8n.pt")
188
  return f"Using {self.model_name} model"
189
+
190
  def load_seg_model(self, model_name):
191
  if model_name not in self.seg_models:
192
  print(f"Loading segmentation model {model_name}...")
 
198
  if image is None:
199
  return None, "No image provided"
200
 
201
+ try:
202
+ # Check if we're using YOLOWorld or standard YOLO
203
+ from ultralytics import YOLOWorld
204
+ is_yoloworld = isinstance(self.model, YOLOWorld)
205
+ except:
206
+ is_yoloworld = False
207
+
208
+ # Process the image
209
  if isinstance(image, str):
210
+ img_for_json = cv2.imread(image)
211
  elif isinstance(image, np.ndarray):
212
+ img_for_json = image.copy()
 
 
 
 
213
 
214
  # Run inference
215
+ if is_yoloworld:
216
+ # YOLOWorld supports text prompts
217
+ results = self.model.predict(
218
+ source=image,
219
+ classes=text_prompt.split(','),
220
+ conf=confidence_threshold,
221
+ verbose=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  )
223
+ else:
224
+ # Standard YOLO doesn't use text prompts
225
+ results = self.model.predict(
226
+ source=image,
227
+ conf=confidence_threshold,
228
+ verbose=False
 
 
 
 
229
  )
230
 
231
+ # Get the plotted result
232
+ res_plotted = results[0].plot()
233
+
234
  # Convert results to JSON format (percentages)
235
  json_results = []
236
+ img_height, img_width = img_for_json.shape[:2]
237
 
238
+ for i, (box, cls, conf) in enumerate(zip(
239
+ results[0].boxes.xyxy.cpu().numpy(),
240
+ results[0].boxes.cls.cpu().numpy(),
241
+ results[0].boxes.conf.cpu().numpy()
242
+ )):
243
  x1, y1, x2, y2 = box
244
 
245
  json_results.append({
 
249
  "width": ((x2 - x1) / img_width) * 100,
250
  "height": ((y2 - y1) / img_height) * 100
251
  },
252
+ "score": float(conf),
253
+ "label": int(cls),
254
+ "label_text": results[0].names[int(cls)]
255
  })
256
 
257
+ return res_plotted, json_results
258
 
259
  def segment(self, image, model_name, confidence_threshold=0.3):
260
  if image is None: