"""Sapiens-Lite pose inference helpers.

Loads a TorchScript checkpoint listed in ``SAPIENS_LITE_MODELS_PATH``, runs
it on a PIL image or numpy frame, decodes the per-keypoint heatmaps, and
draws the detected keypoints back onto the frame.
"""
import functools

import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision import transforms  # NOTE(review): currently unused in this module

from config import SAPIENS_LITE_MODELS_PATH

# (C, H, W) model input shape used by process_image_or_video.
_INPUT_SHAPE = (3, 1024, 768)

# ImageNet normalization statistics, in R, G, B channel order.
_PIXEL_MEAN = (123.675, 116.28, 103.53)
_PIXEL_STD = (58.395, 57.12, 57.375)

# Successfully loaded (model, device) pairs keyed by (task, version), so that
# repeated calls (e.g. one per video frame) do not reload the checkpoint.
# Failed loads are not cached, so their error message is printed every time.
_MODEL_CACHE = {}


def load_model(task, version):
    """Load the TorchScript checkpoint configured for ``task``/``version``.

    Args:
        task: First-level key into ``SAPIENS_LITE_MODELS_PATH`` (e.g. ``'pose'``).
        version: Second-level key (e.g. ``'sapiens_1b'``).

    Returns:
        ``(model, device)`` with the model in eval mode on ``device``, or
        ``(None, None)`` (after printing an error) when the task/version
        pair is not present in ``SAPIENS_LITE_MODELS_PATH``.
    """
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # TF32 tensor cores exist only on compute capability >= 8.0 (Ampere+).
    if device.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device


def _get_model(task, version):
    """Return a cached ``load_model(task, version)`` result, loading on first use."""
    key = (task, version)
    cached = _MODEL_CACHE.get(key)
    if cached is not None:
        return cached
    model, device = load_model(task, version)
    if model is None or device is None:
        return None, None
    _MODEL_CACHE[key] = (model, device)
    return model, device


def preprocess_image(image, input_shape):
    """Resize and normalize a PIL image into a single-image model batch.

    Args:
        image: PIL image. Non-RGB modes are converted to RGB first.
        input_shape: ``(C, H, W)`` model input shape; only H and W are used.

    Returns:
        float32 tensor of shape ``(1, 3, H, W)``, normalized with the
        ImageNet mean/std in RGB channel order.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # PIL's resize takes (width, height).
    resized = image.resize((input_shape[2], input_shape[1]))
    img = torch.from_numpy(np.array(resized).transpose(2, 0, 1)).float()
    # No channel flip: the PIL image is already RGB, and the statistics
    # below are in R, G, B order. Flipping to BGR would pair each channel
    # with the wrong mean/std.
    mean = torch.tensor(_PIXEL_MEAN).view(3, 1, 1)
    std = torch.tensor(_PIXEL_STD).view(3, 1, 1)
    return ((img - mean) / std).unsqueeze(0)


def udp_decode(heatmap, img_size, heatmap_size):
    """Decode per-keypoint heatmaps into argmax locations in image space.

    This is a simplified decoder. It takes the integer argmax of each
    heatmap and rescales it linearly; it does not perform the sub-pixel
    refinement of a full UDP decoder.

    NOTE(review): full UDP decoders (e.g. MMPose's ``UDPHeatmap``) scale by
    ``input_size / (heatmap_size - 1)``; this keeps the simpler
    ``img_size / heatmap_size`` mapping. Confirm which convention the
    checkpoint was trained with.

    Args:
        heatmap: Array of shape ``(K, h, w)``, one heatmap per keypoint.
        img_size: ``(height, width)`` of the image space to map into.
        heatmap_size: ``(h, w)`` used as the heatmap-to-image scale divisor.

    Returns:
        ``(keypoints, keypoint_scores)``: a float64 ``(K, 2)`` array of
        ``(x, y)`` coordinates and a float64 ``(K,)`` array of the peak
        heatmap values.
    """
    h, w = heatmap_size
    num_keypoints, hm_h, hm_w = heatmap.shape
    if num_keypoints == 0:
        return np.zeros((0, 2)), np.zeros(0)

    flat = heatmap.reshape(num_keypoints, hm_h * hm_w)
    # Row-major flat argmax == np.unravel_index(np.argmax(hm), hm.shape).
    peak = flat.argmax(axis=1)
    rows, cols = np.divmod(peak, hm_w)

    keypoints = np.stack(
        [cols * img_size[1] / w, rows * img_size[0] / h], axis=1
    ).astype(np.float64)
    keypoint_scores = flat[np.arange(num_keypoints), peak].astype(np.float64)
    return keypoints, keypoint_scores


def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3, radius=2):
    """Draw a red dot on ``image`` for each keypoint scoring above ``threshold``.

    Args:
        image: PIL image to draw on. It is modified in place.
        keypoints: Iterable of ``(x, y)`` pixel coordinates.
        keypoint_scores: Iterable of scores aligned with ``keypoints``.
        threshold: Keypoints with ``score <= threshold`` are skipped.
        radius: Dot radius in pixels.

    Returns:
        The same ``image`` object, now annotated.
    """
    draw = ImageDraw.Draw(image)
    for (x, y), score in zip(keypoints, keypoint_scores):
        if score > threshold:
            draw.ellipse(
                [(x - radius, y - radius), (x + radius, y + radius)],
                fill="red",
                outline="red",
            )
    return image


def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
    """Run pose estimation on one image or video frame and draw the keypoints.

    Args:
        input_data: PIL image, or a numpy array accepted by
            ``Image.fromarray``. Numpy frames are assumed to be RGB.
        task: Task key into ``SAPIENS_LITE_MODELS_PATH``.
        version: Model version key into ``SAPIENS_LITE_MODELS_PATH``.

    Returns:
        A new RGB PIL image with keypoints drawn (the caller's image is not
        modified), or ``None`` if the input type is unsupported or the
        model could not be resolved.
    """
    # Validate the input before paying for a model load.
    if not isinstance(input_data, (np.ndarray, Image.Image)):
        print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
        return None

    model, device = _get_model(task, version)
    if model is None or device is None:
        return None

    frame = Image.fromarray(input_data) if isinstance(input_data, np.ndarray) else input_data
    # convert() always returns a new image. This normalizes RGBA, grayscale
    # and palette inputs to 3 channels, and keeps the caller's image from
    # being drawn on.
    frame = frame.convert("RGB")

    img = preprocess_image(frame, _INPUT_SHAPE)
    with torch.no_grad():
        output = model(img.to(device))
    heatmaps = output[0].cpu().float().numpy()

    # Use the model's actual heatmap resolution instead of assuming stride 4.
    keypoints, keypoint_scores = udp_decode(
        heatmaps, _INPUT_SHAPE[1:], tuple(heatmaps.shape[-2:])
    )

    # Keypoints are in model-input pixels; map them back to the frame size.
    keypoints[:, 0] *= frame.width / _INPUT_SHAPE[2]
    keypoints[:, 1] *= frame.height / _INPUT_SHAPE[1]

    return visualize_keypoints(frame, keypoints, keypoint_scores)