Spaces:
Build error
Build error
# import torch | |
# import numpy as np | |
# from PIL import Image | |
# from torchvision import transforms | |
# from config import LABELS_TO_IDS | |
# from utils.vis_utils import visualize_mask_with_overlay | |
# # Example usage | |
# TASK = 'pose' | |
# VERSION = 'sapiens_1b' | |
# model_path = get_model_path(TASK, VERSION) | |
# print(model_path) | |
# model = torch.jit.load(model_path) | |
# model.eval() | |
# model.to("cuda") | |
# def get_pose(image, pose_estimator, input_shape=(3, 1024, 768), device="cuda"): | |
# # Preprocess the image | |
# img = preprocess_image(image, input_shape) | |
# # Run the model | |
# with torch.no_grad(): | |
# heatmap = pose_estimator(img.to(device)) | |
# # Post-process the output | |
# keypoints, keypoint_scores = udp_decode(heatmap[0].cpu().float().numpy(), | |
# input_shape[1:], | |
# (input_shape[1] // 4, input_shape[2] // 4)) | |
# # Scale keypoints to original image size | |
# scale_x = image.width / input_shape[2] | |
# scale_y = image.height / input_shape[1] | |
# keypoints[:, 0] *= scale_x | |
# keypoints[:, 1] *= scale_y | |
# # Visualize the keypoints on the original image | |
# pose_image = visualize_keypoints(image, keypoints, keypoint_scores) | |
# return pose_image | |
# def preprocess_image(image, input_shape): | |
# # Resize and normalize the image | |
# img = image.resize((input_shape[2], input_shape[1])) | |
# img = np.array(img).transpose(2, 0, 1) | |
# img = torch.from_numpy(img).float() | |
# img = img[[2, 1, 0], ...] # RGB to BGR | |
# mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1) | |
# std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1) | |
# img = (img - mean) / std | |
# return img.unsqueeze(0) | |
# def udp_decode(heatmap, img_size, heatmap_size): | |
# # This is a simplified version. You might need to implement the full UDP decode logic | |
# h, w = heatmap_size | |
# keypoints = np.zeros((heatmap.shape[0], 2)) | |
# keypoint_scores = np.zeros(heatmap.shape[0]) | |
# for i in range(heatmap.shape[0]): | |
# hm = heatmap[i] | |
# idx = np.unravel_index(np.argmax(hm), hm.shape) | |
# keypoints[i] = [idx[1] * img_size[1] / w, idx[0] * img_size[0] / h] | |
# keypoint_scores[i] = hm[idx] | |
# return keypoints, keypoint_scores | |
# def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3): | |
# draw = ImageDraw.Draw(image) | |
# for (x, y), score in zip(keypoints, keypoint_scores): | |
# if score > threshold: | |
# draw.ellipse([(x-2, y-2), (x+2, y+2)], fill='red', outline='red') | |
# return image | |
# from utils.vis_utils import resize_image | |
# pil_image = Image.open('/home/user/app/assets/image.webp') | |
# if pil_image.mode == 'RGBA': | |
# pil_image = pil_image.convert('RGB') | |
# output_pose = get_pose(resized_pil_image, model) | |
# output_pose | |
import torch | |
import numpy as np | |
from PIL import Image, ImageDraw | |
from torchvision import transforms | |
from config import SAPIENS_LITE_MODELS_PATH | |
# Models that have already been loaded, keyed by (task, version).  Loading a
# TorchScript archive is expensive, and callers may ask for the same model
# repeatedly (for example once per video frame), so each one is loaded once.
_LOADED_MODELS = {}


def load_model(task, version):
    """Load the TorchScript model registered for ``task`` / ``version``.

    The model path is looked up in ``SAPIENS_LITE_MODELS_PATH[task][version]``.
    The model is put in eval mode and placed on CUDA when available, otherwise
    on the CPU.  Successful loads are cached, so later calls with the same
    arguments return the same ``(model, device)`` pair.

    Args:
        task: First-level key into ``SAPIENS_LITE_MODELS_PATH``.
        version: Second-level key into ``SAPIENS_LITE_MODELS_PATH``.

    Returns:
        ``(model, device)`` on success, or ``(None, None)`` when the task or
        version is not registered (an error message is printed in that case).
    """
    cache_key = (task, version)
    cached = _LOADED_MODELS.get(cache_key)
    if cached is not None:
        return cached

    # Only the registry lookup is expected to raise KeyError; keep the try
    # body minimal so unrelated errors are not misreported as a bad key.
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
    except KeyError as e:
        print(f"Error: Tarea o versión inválida. {e}")
        return None, None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Enable TF32 matmul/cuDNN kernels when GPU 0 reports a compute capability
    # major version of 8 or higher.
    if device.type == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # map_location remaps tensors that were saved on another device, so an
    # archive exported from a GPU can still be loaded on a CPU-only host.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    model.to(device)

    _LOADED_MODELS[cache_key] = (model, device)
    return model, device
def preprocess_image(image, input_shape):
    """Turn a PIL image into a normalized ``(1, 3, H, W)`` float tensor.

    Args:
        image: PIL image.  Any mode is accepted; non-RGB images are converted
            to RGB first (an alpha channel, if present, is discarded).
        input_shape: ``(channels, height, width)`` expected by the model; only
            the height and width entries are used here.

    Returns:
        A float tensor with a leading batch dimension of size 1.
    """
    # Grayscale / palette images decode to 2-D arrays, which would make the
    # HWC -> CHW transpose below fail, so force three channels up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # PIL expects (width, height), the reverse of the (C, H, W) layout.
    resized = image.resize((input_shape[2], input_shape[1]))

    chw = np.array(resized).transpose(2, 0, 1)  # HWC -> CHW
    tensor = torch.from_numpy(chw).float()
    tensor = tensor[[2, 1, 0], ...]  # reverse channel order (RGB -> BGR)

    # NOTE(review): these look like the usual RGB-ordered ImageNet statistics,
    # yet they are applied after the channel reversal above.  Confirm which
    # channel order the exported model expects.
    mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
    tensor = (tensor - mean) / std

    return tensor.unsqueeze(0)
def udp_decode(heatmap, img_size, heatmap_size):
    """Decode per-keypoint heatmaps into image coordinates and scores.

    This is a simplified decoder: each keypoint is placed at the location of
    its heatmap maximum, rescaled from heatmap to image coordinates.  It does
    not implement the full UDP decoding logic.

    Args:
        heatmap: Array of shape ``(num_keypoints, H, W)``.
        img_size: ``(height, width)`` of the image the keypoints refer to.
        heatmap_size: ``(height, width)`` used as the heatmap resolution when
            rescaling coordinates.

    Returns:
        ``(keypoints, keypoint_scores)`` where ``keypoints`` is a float64
        array of shape ``(num_keypoints, 2)`` holding ``(x, y)`` pairs and
        ``keypoint_scores`` is a float64 array of shape ``(num_keypoints,)``
        holding the heatmap value at each maximum.

    Raises:
        ValueError: If ``heatmap`` is not 3-dimensional.
    """
    heatmap = np.asarray(heatmap)
    if heatmap.ndim != 3:
        # A batched (4-D) input would otherwise be decoded along the wrong
        # axes and silently give wrong coordinates.
        raise ValueError(
            f"heatmap must have shape (num_keypoints, H, W), got {heatmap.shape}"
        )

    h, w = heatmap_size
    num_keypoints = heatmap.shape[0]
    if num_keypoints == 0:
        return np.zeros((0, 2)), np.zeros(0)

    # One argmax over the flattened spatial axes replaces the per-keypoint
    # loop; ties resolve to the first maximum in row-major order.
    flat = heatmap.reshape(num_keypoints, -1)
    peak = flat.argmax(axis=1)
    rows, cols = np.unravel_index(peak, heatmap.shape[1:])

    keypoints = np.stack(
        (cols * img_size[1] / w, rows * img_size[0] / h), axis=1
    ).astype(np.float64)
    keypoint_scores = flat[np.arange(num_keypoints), peak].astype(np.float64)
    return keypoints, keypoint_scores
def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3, *, radius=2, color='red'):
    """Draw a filled dot for every keypoint whose score exceeds ``threshold``.

    The dots are drawn directly onto ``image``, which is also returned; pass a
    copy if the original must stay untouched.

    Args:
        image: PIL image to draw on.
        keypoints: Iterable of ``(x, y)`` positions in image coordinates.
        keypoint_scores: Iterable of scores, paired with ``keypoints`` in order.
        threshold: Keypoints with a score at or below this value are skipped.
        radius: Half-width of each dot in pixels (keyword-only).
        color: Fill and outline colour of each dot (keyword-only).

    Returns:
        The same ``image`` object, with the dots drawn on it.
    """
    draw = ImageDraw.Draw(image)
    for (x, y), score in zip(keypoints, keypoint_scores):
        if score > threshold:
            bounding_box = [(x - radius, y - radius), (x + radius, y + radius)]
            draw.ellipse(bounding_box, fill=color, outline=color)
    return image
def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
    """Run pose estimation on a single image or video frame and draw the result.

    Args:
        input_data: A PIL image, or a numpy array holding one frame.
        task: Task key passed to ``load_model``.
        version: Model version key passed to ``load_model``.

    Returns:
        A new RGB PIL image with the detected keypoints drawn on it, or
        ``None`` when the input type is unsupported or the model could not be
        loaded.  The caller's input is not modified.
    """
    # Reject unsupported inputs before paying for the model load.
    if not isinstance(input_data, (np.ndarray, Image.Image)):
        print("Tipo de entrada no soportado. Por favor, proporcione una imagen PIL o un frame de video numpy.")
        return None

    model, device = load_model(task, version)
    if model is None or device is None:
        return None

    input_shape = (3, 1024, 768)  # (channels, height, width) fed to the model

    frame = Image.fromarray(input_data) if isinstance(input_data, np.ndarray) else input_data
    # Normalise every mode (RGBA, grayscale, palette, ...) to RGB.  convert()
    # always returns a new image, so drawing below never touches the caller's.
    frame = frame.convert('RGB')

    img = preprocess_image(frame, input_shape)
    with torch.no_grad():
        heatmap = model(img.to(device))

    # Take the first batch element and decode using the resolution the model
    # actually produced rather than assuming a fixed downsampling factor.
    heatmap_np = heatmap[0].cpu().float().numpy()
    keypoints, keypoint_scores = udp_decode(
        heatmap_np,
        input_shape[1:],
        heatmap_np.shape[-2:],
    )

    # Map keypoints from model-input coordinates back to the frame size.
    keypoints[:, 0] *= frame.width / input_shape[2]
    keypoints[:, 1] *= frame.height / input_shape[1]

    return visualize_keypoints(frame, keypoints, keypoint_scores)