import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from config import LABELS_TO_IDS, SAPIENS_LITE_MODELS_PATH
from utils.vis_utils import visualize_mask_with_overlay


def load_model(task, version):
    """Load the TorchScript model registered for ``task`` / ``version``.

    The checkpoint path is looked up as
    ``SAPIENS_LITE_MODELS_PATH[task][version]``.

    Args:
        task: First-level key into ``SAPIENS_LITE_MODELS_PATH``.
        version: Second-level key into ``SAPIENS_LITE_MODELS_PATH``.

    Returns:
        A ``(model, device)`` tuple with the model in eval mode on ``device``
        (CUDA when available, otherwise CPU), or ``(None, None)`` when the
        task/version is unknown or the checkpoint file does not exist.
    """
    # Only the registry lookup is guarded, so an unrelated KeyError raised
    # further down is not misreported as an invalid task/version.
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    if not os.path.exists(model_path):
        print(f"Advertencia: El archivo del modelo no existe en {model_path}")
        return None, None

    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_available else "cpu")

    # Enable TF32 matmul/cuDNN kernels on GPUs with compute capability >= 8.
    if cuda_available and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # map_location lets a checkpoint that was saved from a CUDA device be
    # loaded on a CPU-only host instead of failing during deserialization.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device


def process_image_or_video(input_data, task='seg', version='sapiens_0.3b'):
    """Run the model on one image/frame and return the mask overlay.

    Args:
        input_data: A ``PIL.Image.Image`` or a ``numpy.ndarray`` frame.
        task: Task key forwarded to ``load_model``.
        version: Model version key forwarded to ``load_model``.

    Returns:
        Whatever ``visualize_mask_with_overlay`` returns for the input frame
        and its predicted class mask, or ``None`` when the model could not be
        loaded or ``input_data`` has an unsupported type.
    """
    # NOTE(review): the model is loaded on every call; callers that process
    # many frames may want to load once and reuse it.
    model, device = load_model(task, version)
    if model is None or device is None:
        return None

    # Input preprocessing: resize to (height, width) = (1024, 768), convert
    # to a float tensor and normalize per channel (three-channel statistics).
    transform_fn = transforms.Compose([
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    def process_frame(frame):
        """Predict a per-pixel class mask for ``frame`` and blend it in."""
        if isinstance(frame, np.ndarray):
            frame = Image.fromarray(frame)

        # Normalize uses three-channel statistics, so every non-RGB mode
        # (RGBA, L, P, LA, CMYK, ...) must be converted, not just RGBA.
        if frame.mode != 'RGB':
            frame = frame.convert('RGB')

        input_tensor = transform_fn(frame).unsqueeze(0).to(device)

        with torch.inference_mode():
            output = model(input_tensor)
            # Bring the model output back to the original frame resolution
            # before picking the highest-scoring class per pixel.
            output = torch.nn.functional.interpolate(
                output,
                size=(frame.height, frame.width),
                mode="bilinear",
                align_corners=False,
            )
            preds = output.argmax(dim=1)

        mask = preds.squeeze(0).cpu().numpy()
        mask_image = Image.fromarray(mask.astype("uint8"))
        return visualize_mask_with_overlay(
            frame, mask_image, LABELS_TO_IDS, alpha=0.5
        )

    # A numpy array (video frame) and a PIL image go through the same path.
    if isinstance(input_data, (np.ndarray, Image.Image)):
        return process_frame(input_data)

    print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
    return None