Spaces:
Build error
Build error
import torch | |
import numpy as np | |
from PIL import Image, ImageDraw | |
from torchvision import transforms | |
from config import SAPIENS_LITE_MODELS_PATH | |
def load_model(task, version):
    """Load the Sapiens-Lite TorchScript model registered for ``task``/``version``.

    The checkpoint path is read from ``SAPIENS_LITE_MODELS_PATH[task][version]``.
    The model is loaded onto CUDA when available (CPU otherwise) and put in
    eval mode.

    Args:
        task: First-level key of ``SAPIENS_LITE_MODELS_PATH``.
        version: Second-level key selecting the model variant for ``task``.

    Returns:
        ``(model, device)`` on success, or ``(None, None)`` (after printing an
        error) when ``task`` or ``version`` is not a registered key.
    """
    # Only the registry lookup can mean "invalid task/version"; keep the try
    # narrow so unrelated errors from loading are not misreported as such.
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Enable TF32 matmul/cuDNN only on GPUs with compute capability >= 8.
    if device.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # map_location remaps tensors saved on another device (e.g. a GPU-exported
    # checkpoint) so loading also works on CPU-only hosts.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)
    return model, device
def preprocess_image(image, input_shape):
    """Resize and normalize a PIL image into a single-item model input batch.

    Args:
        image: PIL image of any mode; non-RGB images are converted to RGB.
        input_shape: ``(channels, height, width)``; only height and width
            (indices 1 and 2) are used.

    Returns:
        float32 tensor of shape ``(1, 3, height, width)`` in RGB channel order,
        normalized with the ImageNet mean/std expressed on the 0-255 scale.
    """
    # The HxWx3 -> 3xHxW transpose below requires exactly three channels;
    # modes such as "L", "P", "LA" or "CMYK" must be normalized first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    height, width = input_shape[1], input_shape[2]
    resized = image.resize((width, height))  # PIL expects (width, height)
    img = torch.from_numpy(np.array(resized).transpose(2, 0, 1)).float()

    # No channel swap: PIL already yields RGB, and these statistics are listed
    # in R, G, B order (0.485/0.456/0.406 and 0.229/0.224/0.225, times 255).
    mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
    img = (img - mean) / std
    return img.unsqueeze(0)
def udp_decode(heatmap, img_size, heatmap_size):
    """Decode per-keypoint heatmaps into input-image coordinates.

    Simplified decoder: takes the integer argmax of each heatmap and maps it
    back with UDP (unbiased data processing) scaling. A full UDP decoder
    (e.g. MMPose's ``UDPHeatmap`` codec) additionally refines the peak to
    sub-pixel precision.

    Args:
        heatmap: array of shape ``(K, h, w)``, one heatmap per keypoint.
        img_size: ``(H, W)`` of the model input the heatmaps correspond to.
        heatmap_size: ``(h, w)`` of the heatmaps.

    Returns:
        keypoints: float64 array ``(K, 2)`` of ``(x, y)`` in input pixels.
        keypoint_scores: float64 array ``(K,)`` holding each heatmap's peak value.
    """
    heatmap = np.asarray(heatmap)
    h, w = heatmap_size
    img_h, img_w = img_size[0], img_size[1]

    # UDP encodes an input pixel p as p * (h - 1) / (H - 1) on the heatmap, so
    # decoding inverts with (H - 1) / (h - 1). Guard 1-pixel axes against /0.
    scale_x = (img_w - 1) / max(w - 1, 1)
    scale_y = (img_h - 1) / max(h - 1, 1)

    num_keypoints = heatmap.shape[0]
    # Explicit column count (not -1) so K == 0 reshapes cleanly.
    flat = heatmap.reshape(num_keypoints, int(np.prod(heatmap.shape[1:])))
    peak = flat.argmax(axis=1)  # first maximum wins, as with np.argmax
    ys, xs = np.unravel_index(peak, heatmap.shape[1:])

    keypoints = np.stack([xs * scale_x, ys * scale_y], axis=1).astype(np.float64)
    keypoint_scores = flat[np.arange(num_keypoints), peak].astype(np.float64)
    return keypoints, keypoint_scores
def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3):
    """Mark confident keypoints on ``image`` with small red dots.

    Drawing happens in place on ``image``; the same object is returned.

    Args:
        image: PIL image to annotate.
        keypoints: iterable of ``(x, y)`` pairs in ``image`` pixel coordinates.
        keypoint_scores: iterable of scores aligned with ``keypoints``.
        threshold: a keypoint is drawn only if its score is strictly greater.

    Returns:
        The annotated ``image``.
    """
    canvas = ImageDraw.Draw(image)
    radius = 2
    for (cx, cy), confidence in zip(keypoints, keypoint_scores):
        if not confidence > threshold:
            continue
        bounding_box = [(cx - radius, cy - radius), (cx + radius, cy + radius)]
        canvas.ellipse(bounding_box, fill='red', outline='red')
    return image
# Loaded (model, device) pairs keyed by (task, version). The key space is
# bounded by the entries of the model registry, so the cache cannot grow
# without limit.
_MODEL_CACHE = {}


def _get_cached_model(task, version):
    """Return ``(model, device)`` for ``task``/``version``, loading it at most once.

    Failed loads are not cached, so ``load_model`` reports them on every call.
    """
    key = (task, version)
    if key not in _MODEL_CACHE:
        model, device = load_model(task, version)
        if model is None or device is None:
            return None, None
        _MODEL_CACHE[key] = (model, device)
    return _MODEL_CACHE[key]


def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
    """Estimate keypoints on one image or video frame and draw them.

    Args:
        input_data: a PIL image, or a numpy array accepted by
            ``Image.fromarray`` (e.g. a single video frame).
        task: task key passed to ``load_model``.
        version: model-variant key passed to ``load_model``.

    Returns:
        A new RGB PIL image with keypoints above the default score threshold
        drawn as red dots. Returns ``None`` when the model cannot be loaded
        or the input type is unsupported. The caller's image is not modified.
    """
    # Cached so that per-frame video calls do not reload the TorchScript model.
    model, device = _get_cached_model(task, version)
    if model is None or device is None:
        return None

    input_shape = (3, 1024, 768)  # (channels, height, width) fed to the model
    # The code assumes the model emits heatmaps at 1/4 of the input resolution.
    heatmap_size = (input_shape[1] // 4, input_shape[2] // 4)

    def process_frame(frame):
        if isinstance(frame, np.ndarray):
            frame = Image.fromarray(frame)
        # Work on a fresh RGB image. convert() returns a new image for any
        # non-RGB mode (L, P, RGBA, ...), and copy() keeps an RGB input
        # supplied by the caller from being drawn on in place.
        frame = frame.convert('RGB') if frame.mode != 'RGB' else frame.copy()

        img = preprocess_image(frame, input_shape)
        with torch.no_grad():
            heatmap = model(img.to(device))
        # heatmap[0] selects the single batch item (presumably (K, h, w)).
        keypoints, keypoint_scores = udp_decode(
            heatmap[0].cpu().float().numpy(), input_shape[1:], heatmap_size
        )

        # Map from model-input pixels back to the original frame size.
        keypoints[:, 0] *= frame.width / input_shape[2]
        keypoints[:, 1] *= frame.height / input_shape[1]
        return visualize_keypoints(frame, keypoints, keypoint_scores)

    # Numpy video frames and PIL images share the same pipeline.
    if isinstance(input_data, (np.ndarray, Image.Image)):
        return process_frame(input_data)

    print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
    return None