Malaji71 committed on
Commit
659c8b0
verified
1 Parent(s): d99e44d

Update optimizer.py

Browse files
Files changed (1) hide show
  1. optimizer.py +15 -19
optimizer.py CHANGED
@@ -190,27 +190,23 @@ class UltraSupremeOptimizer:
190
  def run_clip_inference(self, image: Image.Image) -> Tuple[str, str, str]:
191
  """Solo la inferencia CLIP usa GPU"""
192
  try:
193
- # Preparar modelos para GPU
194
- self._prepare_models_for_gpu()
195
-
196
- # Usar autocast para manejar precisión mixta
197
- with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
198
- # Convertir imagen a tensor y asegurar que esté en half precision
199
- from torchvision import transforms
200
- preprocess = transforms.Compose([
201
- transforms.Resize((224, 224)),
202
- transforms.ToTensor(),
203
- transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
204
- std=[0.26862954, 0.26130258, 0.27577711]),
205
- ])
206
 
207
- # Procesar imagen manualmente para controlar la precisión
208
- image_tensor = preprocess(image).unsqueeze(0).half().to("cuda")
209
 
210
- # Ejecutar inferencias con manejo especial
211
- full_prompt = self._safe_interrogate(image, 'interrogate')
212
- clip_fast = self._safe_interrogate(image, 'interrogate_fast')
213
- clip_classic = self._safe_interrogate(image, 'interrogate_classic')
 
 
 
 
 
 
214
 
215
  return full_prompt, clip_fast, clip_classic
216
 
 
190
  def run_clip_inference(self, image: Image.Image) -> Tuple[str, str, str]:
191
  """Solo la inferencia CLIP usa GPU"""
192
  try:
193
+ # NO usar half precision - mantener float32 para compatibilidad
194
+ if hasattr(self.interrogator, 'caption_model'):
195
+ self.interrogator.caption_model = self.interrogator.caption_model.to("cuda")
 
 
 
 
 
 
 
 
 
 
196
 
197
+ if hasattr(self.interrogator, 'clip_model'):
198
+ self.interrogator.clip_model = self.interrogator.clip_model.to("cuda")
199
 
200
+ if hasattr(self.interrogator, 'blip_model'):
201
+ self.interrogator.blip_model = self.interrogator.blip_model.to("cuda")
202
+
203
+ self.interrogator.config.device = "cuda"
204
+ logger.info("Models moved to GPU with float32 (full precision)")
205
+
206
+ # Ejecutar inferencias sin autocast para evitar problemas de half precision
207
+ full_prompt = self.interrogator.interrogate(image)
208
+ clip_fast = self.interrogator.interrogate_fast(image)
209
+ clip_classic = self.interrogator.interrogate_classic(image)
210
 
211
  return full_prompt, clip_fast, clip_classic
212