import torch
import numpy as np
from PIL import Image, ImageDraw

from config import SAPIENS_LITE_MODELS_PATH


def load_model(task, version):
    """Load a TorchScript Sapiens Lite model for the requested task and version."""
    try:
        model_path = SAPIENS_LITE_MODELS_PATH[task][version]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Allow TF32 matmuls on Ampere (compute capability >= 8) and newer GPUs
        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        model = torch.jit.load(model_path)
        model.eval()
        model.to(device)

        return model, device
    except KeyError as e:
        print(f"Error: Invalid task or version. {e}")
        return None, None


def preprocess_image(image, input_shape):
    # Resize to the (width, height) expected by the model and convert HWC -> CHW
    img = image.resize((input_shape[2], input_shape[1]))
    img = np.array(img).transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    img = img[[2, 1, 0], ...]  # RGB to BGR

    # Normalize with ImageNet-style mean/std on the 0-255 scale
    mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
    img = (img - mean) / std

    return img.unsqueeze(0)


def udp_decode(heatmap, img_size, heatmap_size):
    # Simplified argmax-based decode. You might need to implement the full UDP decode logic.
    h, w = heatmap_size
    keypoints = np.zeros((heatmap.shape[0], 2))
    keypoint_scores = np.zeros(heatmap.shape[0])

    for i in range(heatmap.shape[0]):
        hm = heatmap[i]
        idx = np.unravel_index(np.argmax(hm), hm.shape)
        keypoints[i] = [idx[1] * img_size[1] / w, idx[0] * img_size[0] / h]
        keypoint_scores[i] = hm[idx]

    return keypoints, keypoint_scores


def visualize_keypoints(image, keypoints, keypoint_scores, threshold=0.3):
    draw = ImageDraw.Draw(image)
    for (x, y), score in zip(keypoints, keypoint_scores):
        if score > threshold:
            draw.ellipse([(x - 2, y - 2), (x + 2, y + 2)], fill='red', outline='red')
    return image


def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
    model, device = load_model(task, version)
    if model is None or device is None:
        return None

    input_shape = (3, 1024, 768)

    def process_frame(frame):
        if isinstance(frame, np.ndarray):
            frame = Image.fromarray(frame)

        if frame.mode == 'RGBA':
            frame = frame.convert('RGB')

        img = preprocess_image(frame, input_shape)

        with torch.no_grad():
            heatmap = model(img.to(device))

        keypoints, keypoint_scores = udp_decode(heatmap[0].cpu().float().numpy(),
                                                input_shape[1:],
                                                (input_shape[1] // 4, input_shape[2] // 4))

        # Scale keypoints from model input size back to the original frame size
        scale_x = frame.width / input_shape[2]
        scale_y = frame.height / input_shape[1]
        keypoints[:, 0] *= scale_x
        keypoints[:, 1] *= scale_y

        pose_image = visualize_keypoints(frame, keypoints, keypoint_scores)
        return pose_image

    if isinstance(input_data, np.ndarray):  # Video frame
        return process_frame(input_data)
    elif isinstance(input_data, Image.Image):  # PIL image
        return process_frame(input_data)
    else:
        print("Unsupported input type. Please provide a PIL image or a numpy video frame.")
        return None
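

if __name__ == "__main__":
    # Minimal usage sketch: run pose estimation on a single image.
    # The image path below is only an example; point it at a real file, and make sure
    # config.SAPIENS_LITE_MODELS_PATH maps 'pose'/'sapiens_1b' to a downloaded TorchScript checkpoint.
    example_path = 'assets/image.webp'
    pil_image = Image.open(example_path)
    result = process_image_or_video(pil_image, task='pose', version='sapiens_1b')
    if result is not None:
        result.save('pose_output.png')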