Spaces:
Build error
Build error
import torch | |
import numpy as np | |
from PIL import Image | |
from torchvision import transforms | |
from config import LABELS_TO_IDS | |
from utils.vis_utils import visualize_mask_with_overlay | |
def load_model(task, version):
    """Load a TorchScript Sapiens Lite model for ``task`` / ``version``.

    The model path is looked up in ``config.SAPIENS_LITE_MODELS_PATH``.
    The model is loaded with ``torch.jit.load`` onto the selected device
    (CUDA when available, otherwise CPU) and switched to eval mode.

    Args:
        task: Task key in ``SAPIENS_LITE_MODELS_PATH`` (e.g. ``'seg'``).
        version: Version key under that task (e.g. ``'sapiens_0.3b'``).

    Returns:
        A ``(model, device)`` tuple, or ``(None, None)`` when the
        task/version pair is unknown or the model file does not exist.
    """
    import os

    from config import SAPIENS_LITE_MODELS_PATH

    # Guard only the config lookup: a KeyError raised while loading the
    # model itself must not be reported as an invalid task/version.
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    if not os.path.exists(model_path):
        print(f"Advertencia: El archivo del modelo no existe en {model_path}")
        return None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Enable TF32 matmul/cuDNN on GPUs with compute capability >= 8.0
    # (Ampere and newer); only device 0 is inspected.
    if device.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # map_location lets a model that was serialized on a GPU be loaded on a
    # CPU-only host, and places the weights on the target device directly.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device
def process_image_or_video(input_data, task='seg', version='sapiens_0.3b'):
    """Segment one image or video frame and return the mask overlaid on it.

    Args:
        input_data: A ``PIL.Image.Image``, or a ``numpy.ndarray`` frame that
            ``PIL.Image.fromarray`` accepts (presumably an RGB ``uint8``
            ``(H, W, 3)`` array -- TODO confirm channel order with the caller).
        task: Task key forwarded to ``load_model``.
        version: Model version key forwarded to ``load_model``.

    Returns:
        The result of ``visualize_mask_with_overlay`` for the predicted mask
        (``alpha=0.5``), or ``None`` when the input type is unsupported or
        the model could not be loaded.

    NOTE: the model is loaded from disk on every call; callers processing
    many frames pay that cost per frame.
    """
    # Reject unsupported inputs before paying the cost of loading the model.
    if not isinstance(input_data, (np.ndarray, Image.Image)):
        print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
        return None

    model, device = load_model(task, version)
    if model is None or device is None:
        return None

    # Resize to (1024, 768) and normalize with the ImageNet mean/std.
    transform_fn = transforms.Compose([
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    frame = input_data
    if isinstance(frame, np.ndarray):
        frame = Image.fromarray(frame)
    # Normalize needs exactly 3 channels, so coerce every non-RGB mode
    # (RGBA, L, LA, P, ...), not just RGBA.
    if frame.mode != 'RGB':
        frame = frame.convert('RGB')

    input_tensor = transform_fn(frame).unsqueeze(0).to(device)
    with torch.inference_mode():
        logits = model(input_tensor)
        # Upsample the model output back to the frame's resolution, then take
        # the per-pixel argmax over dim 1 (assumed to be the class dim).
        logits = torch.nn.functional.interpolate(
            logits,
            size=(frame.height, frame.width),
            mode="bilinear",
            align_corners=False,
        )
        preds = logits.argmax(dim=1)

    mask = preds.squeeze(0).cpu().numpy()
    mask_image = Image.fromarray(mask.astype("uint8"))
    return visualize_mask_with_overlay(frame, mask_image, LABELS_TO_IDS, alpha=0.5)