bla committed on
Commit
005ff51
·
verified ·
1 Parent(s): 939d2b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -35
app.py CHANGED
@@ -148,23 +148,18 @@ class YOLOWorldDetector:
148
 
149
  print(f"Loading {self.model_name} on {self.device}...")
150
  try:
151
- # Use the correct repository ID
152
- self.model = AutoModel.from_pretrained(
153
- f"IDEA-Research/{self.model_name}",
154
- trust_remote_code=True
155
- )
156
- self.model.to(self.device)
157
- self.processor = AutoProcessor.from_pretrained(
158
- f"IDEA-Research/{self.model_name}"
159
- )
160
- print("Model loaded successfully!")
161
  except Exception as e:
162
- print(f"Error loading model: {e}")
163
- print("Falling back to YOLOv8 for detection...")
164
- # Fallback to YOLOv8 if YOLOWorld fails to load
165
- self.model = None
166
- self.processor = None
167
- self.fallback_model = YOLO("yolov8n.pt")
168
 
169
  # Segmentation models
170
  self.seg_models = {}
@@ -176,15 +171,18 @@ class YOLOWorldDetector:
176
 
177
  print(f"Loading {self.model_name} on {self.device}...")
178
  try:
179
- # Use Ultralytics YOLOWorld model
180
  from ultralytics import YOLOWorld
181
  self.model = YOLOWorld(self.model_name)
182
- print("Model loaded successfully!")
 
183
  except Exception as e:
184
  print(f"Error loading YOLOWorld model: {e}")
185
  print("Falling back to standard YOLOv8 for detection...")
186
- # Fallback to YOLOv8 if YOLOWorld fails to load
187
  self.model = YOLO("yolov8n.pt")
 
 
188
  return f"Using {self.model_name} model"
189
 
190
  def load_seg_model(self, model_name):
@@ -198,28 +196,33 @@ class YOLOWorldDetector:
198
  if image is None:
199
  return None, "No image provided"
200
 
201
- try:
202
- # Check if we're using YOLOWorld or standard YOLO
203
- from ultralytics import YOLOWorld
204
- is_yoloworld = isinstance(self.model, YOLOWorld)
205
- except:
206
- is_yoloworld = False
207
-
208
  # Process the image
209
  if isinstance(image, str):
210
  img_for_json = cv2.imread(image)
211
  elif isinstance(image, np.ndarray):
212
  img_for_json = image.copy()
 
 
 
213
 
214
- # Run inference
215
- if is_yoloworld:
216
- # YOLOWorld supports text prompts
217
- results = self.model.predict(
218
- source=image,
219
- classes=text_prompt.split(','),
220
- conf=confidence_threshold,
221
- verbose=False
222
- )
 
 
 
 
 
 
 
 
 
223
  else:
224
  # Standard YOLO doesn't use text prompts
225
  results = self.model.predict(
 
148
 
149
  print(f"Loading {self.model_name} on {self.device}...")
150
  try:
151
+ # Try to load using Ultralytics YOLOWorld
152
+ from ultralytics import YOLOWorld
153
+ self.model = YOLOWorld(self.model_name)
154
+ self.model_type = "yoloworld"
155
+ print("YOLOWorld model loaded successfully!")
 
 
 
 
 
156
  except Exception as e:
157
+ print(f"Error loading YOLOWorld model: {e}")
158
+ print("Falling back to standard YOLOv8 for detection...")
159
+ # Fallback to YOLOv8
160
+ self.model = YOLO("yolov8n.pt")
161
+ self.model_type = "yolov8"
162
+ print("YOLOv8 fallback model loaded successfully!")
163
 
164
  # Segmentation models
165
  self.seg_models = {}
 
171
 
172
  print(f"Loading {self.model_name} on {self.device}...")
173
  try:
174
+ # Try to load using Ultralytics YOLOWorld
175
  from ultralytics import YOLOWorld
176
  self.model = YOLOWorld(self.model_name)
177
+ self.model_type = "yoloworld"
178
+ print("YOLOWorld model loaded successfully!")
179
  except Exception as e:
180
  print(f"Error loading YOLOWorld model: {e}")
181
  print("Falling back to standard YOLOv8 for detection...")
182
+ # Fallback to YOLOv8
183
  self.model = YOLO("yolov8n.pt")
184
+ self.model_type = "yolov8"
185
+ print("YOLOv8 fallback model loaded successfully!")
186
  return f"Using {self.model_name} model"
187
 
188
  def load_seg_model(self, model_name):
 
196
  if image is None:
197
  return None, "No image provided"
198
 
 
 
 
 
 
 
 
199
  # Process the image
200
  if isinstance(image, str):
201
  img_for_json = cv2.imread(image)
202
  elif isinstance(image, np.ndarray):
203
  img_for_json = image.copy()
204
+ else:
205
+ # Convert PIL Image to numpy array if needed
206
+ img_for_json = np.array(image)
207
 
208
+ # Run inference based on model type
209
+ if self.model_type == "yoloworld":
210
+ try:
211
+ # YOLOWorld supports text prompts
212
+ results = self.model.predict(
213
+ source=image,
214
+ classes=text_prompt.split(','),
215
+ conf=confidence_threshold,
216
+ verbose=False
217
+ )
218
+ except Exception as e:
219
+ print(f"Error during YOLOWorld inference: {e}")
220
+ # If YOLOWorld inference fails, try to use it as standard YOLO
221
+ results = self.model.predict(
222
+ source=image,
223
+ conf=confidence_threshold,
224
+ verbose=False
225
+ )
226
  else:
227
  # Standard YOLO doesn't use text prompts
228
  results = self.model.predict(