# Author: joselobenitezg
# Commit a92daf2: "update seg inference"
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from config import LABELS_TO_IDS
from utils.vis_utils import visualize_mask_with_overlay
def load_model(task, version):
    """Load a TorchScript Sapiens-Lite model and choose the device to run it on.

    Args:
        task: First-level key into ``config.SAPIENS_LITE_MODELS_PATH``.
        version: Second-level key into ``SAPIENS_LITE_MODELS_PATH[task]``.

    Returns:
        ``(model, device)`` with the model in eval mode and moved to ``device``,
        or ``(None, None)`` if the task/version pair is unknown or the model
        file does not exist on disk.
    """
    from config import SAPIENS_LITE_MODELS_PATH
    import os

    # Keep the try narrow: only the config lookup should be reported as an
    # invalid task/version; errors from loading the model itself propagate.
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    if not os.path.exists(model_path):
        print(f"Advertencia: El archivo del modelo no existe en {model_path}")
        return None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Enable TF32 for matmuls and cuDNN convolutions on GPUs with compute
    # capability >= 8 (checked on device 0 only).
    if device.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # map_location lets a model that was serialized on a GPU load on a
    # CPU-only machine instead of failing inside torch.jit.load.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device
# Single-slot cache holding the most recently loaded model as
# {(task, version): (model, device)}. It avoids reloading the model from disk
# for every video frame. Only one entry is kept, so switching between model
# versions does not accumulate several models in (GPU) memory.
_MODEL_CACHE = {}


def process_image_or_video(input_data, task='seg', version='sapiens_0.3b'):
    """Segment one image or video frame and return it blended with the mask.

    Args:
        input_data: A ``PIL.Image.Image`` or a ``numpy.ndarray`` video frame
            (any array accepted by ``Image.fromarray``). Non-RGB inputs
            (RGBA, grayscale, palette, ...) are converted to RGB first.
        task: Task key forwarded to ``load_model``.
        version: Model version key forwarded to ``load_model``.

    Returns:
        Whatever ``visualize_mask_with_overlay`` returns for the (RGB) frame
        and the predicted per-pixel label mask at ``alpha=0.5``, or ``None``
        if the input type is unsupported or the model cannot be loaded.
    """
    # Validate the input before paying the cost of loading a model.
    if not isinstance(input_data, (np.ndarray, Image.Image)):
        print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
        return None

    key = (task, version)
    if key not in _MODEL_CACHE:
        # Drop any previously cached model before loading a new one to keep
        # peak memory at a single model.
        _MODEL_CACHE.clear()
        model, device = load_model(task, version)
        if model is None or device is None:
            return None
        _MODEL_CACHE[key] = (model, device)
    model, device = _MODEL_CACHE[key]

    # Fixed 1024x768 network input with ImageNet mean/std normalization.
    transform_fn = transforms.Compose([
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    frame = Image.fromarray(input_data) if isinstance(input_data, np.ndarray) else input_data
    # Normalize uses 3-channel statistics, so every non-RGB mode (RGBA, L,
    # LA, P, ...) must be converted, not only RGBA.
    if frame.mode != 'RGB':
        frame = frame.convert('RGB')

    input_tensor = transform_fn(frame).unsqueeze(0).to(device)
    with torch.inference_mode():
        logits = model(input_tensor)
        # Upsample the logits (not the labels) back to the frame's original
        # size, then take the per-pixel argmax over the class dimension.
        logits = torch.nn.functional.interpolate(
            logits, size=(frame.height, frame.width), mode="bilinear", align_corners=False
        )
        preds = logits.argmax(dim=1)

    mask = preds.squeeze(0).cpu().numpy().astype("uint8")
    mask_image = Image.fromarray(mask)
    return visualize_mask_with_overlay(frame, mask_image, LABELS_TO_IDS, alpha=0.5)