# coding: utf-8

"""
Face detection and alignment using XPose.
"""

import os
import pickle
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from torchvision.ops import nms

from src.models.XPose import transforms as T
from src.models.XPose.models import build_model
from src.models.XPose.predefined_keypoints import *
from src.models.XPose.util import box_ops
from src.models.XPose.util.config import Config


def clean_state_dict(state_dict):
    """Strip the `module.` prefix added by DataParallel/DistributedDataParallel."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            k = k[7:]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict


class XPoseRunner(object):
    def __init__(self, model_config_path, model_checkpoint_path, embeddings_cache_path=None, cpu_only=False, **kwargs):
        self.device_id = kwargs.get("device_id", 0)
        self.flag_use_half_precision = kwargs.get("flag_use_half_precision", True)
        self.device = f"cuda:{self.device_id}" if not cpu_only else "cpu"
        self.model = self.load_animal_model(model_config_path, model_checkpoint_path, self.device)

        # Load pre-computed CLIP text embeddings for the 9- and 68-keypoint prompts
        try:
            with open(f'{embeddings_cache_path}_9.pkl', 'rb') as f:
                self.ins_text_embeddings_9, self.kpt_text_embeddings_9 = pickle.load(f)
            with open(f'{embeddings_cache_path}_68.pkl', 'rb') as f:
                self.ins_text_embeddings_68, self.kpt_text_embeddings_68 = pickle.load(f)
            print("Loaded cached embeddings from file.")
        except Exception:
            raise ValueError("Could not load clip embeddings from file, please check your file path.")

    def load_animal_model(self, model_config_path, model_checkpoint_path, device):
        args = Config.fromfile(model_config_path)
        args.device = device
        model = build_model(args)
        checkpoint = torch.load(model_checkpoint_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        model.eval()
        return model

    def load_image(self, input_image):
        image_pil = input_image.convert("RGB")
        transform = T.Compose([
            T.RandomResize([800], max_size=1333),  # NOTE: fixed size to 800
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        image, _ = transform(image_pil, None)
        return image_pil, image

    def get_unipose_output(self, image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold):
        instance_list = instance_text_prompt.split(',')

        # Pick the cached text embeddings matching the keypoint layout
        if len(keypoint_text_prompt) == 9:
            # torch.Size([1, 512]), torch.Size([9, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_9, self.kpt_text_embeddings_9
        elif len(keypoint_text_prompt) == 68:
            # torch.Size([1, 512]), torch.Size([68, 512])
            ins_text_embeddings, kpt_text_embeddings = self.ins_text_embeddings_68, self.kpt_text_embeddings_68
        else:
            raise ValueError("Invalid number of keypoint embeddings.")

        # Pad keypoint embeddings and their visibility mask to the model's fixed budget of 100 keypoints
        target = {
            "instance_text_prompt": instance_list,
            "keypoint_text_prompt": keypoint_text_prompt,
            "object_embeddings_text": ins_text_embeddings.float(),
            "kpts_embeddings_text": torch.cat(
                (kpt_text_embeddings.float(),
                 torch.zeros(100 - kpt_text_embeddings.shape[0], 512, device=self.device)),
                dim=0),
            "kpt_vis_text": torch.cat(
                (torch.ones(kpt_text_embeddings.shape[0], device=self.device),
                 torch.zeros(100 - kpt_text_embeddings.shape[0], device=self.device)),
                dim=0),
        }

        self.model = self.model.to(self.device)
        image = image.to(self.device)

        with torch.no_grad():
            with torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=self.flag_use_half_precision):
                outputs = self.model(image[None], [target])

        logits = outputs["pred_logits"].sigmoid()[0]
        boxes = outputs["pred_boxes"][0]
        keypoints = outputs["pred_keypoints"][0][:, :2 * len(keypoint_text_prompt)]

        # Keep only queries whose best class score exceeds the box threshold
        logits_filt = logits.cpu().clone()
        boxes_filt = boxes.cpu().clone()
        keypoints_filt = keypoints.cpu().clone()
        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
        logits_filt = logits_filt[filt_mask]
        boxes_filt = boxes_filt[filt_mask]
        keypoints_filt = keypoints_filt[filt_mask]

        # Non-maximum suppression over the surviving boxes
        keep_indices = nms(box_ops.box_cxcywh_to_xyxy(boxes_filt), logits_filt.max(dim=1)[0], iou_threshold=IoU_threshold)

        filtered_boxes = boxes_filt[keep_indices]
        filtered_keypoints = keypoints_filt[keep_indices]

        return filtered_boxes, filtered_keypoints

    def run(self, input_image, instance_text_prompt, keypoint_text_example, box_threshold, IoU_threshold):
        # Resolve the keypoint definition (names + skeleton) from the predefined dictionaries
        if keypoint_text_example in globals():
            keypoint_dict = globals()[keypoint_text_example]
        elif instance_text_prompt in globals():
            keypoint_dict = globals()[instance_text_prompt]
        else:
            keypoint_dict = globals()["animal"]

        keypoint_text_prompt = keypoint_dict.get("keypoints")
        keypoint_skeleton = keypoint_dict.get("skeleton")

        image_pil, image = self.load_image(input_image)
        boxes_filt, keypoints_filt = self.get_unipose_output(
            image, instance_text_prompt, keypoint_text_prompt, box_threshold, IoU_threshold)

        # Rescale the normalized keypoints of the top detection back to pixel coordinates
        size = image_pil.size
        H, W = size[1], size[0]
        keypoints_filt = keypoints_filt[0].squeeze(0)
        kp = np.array(keypoints_filt.cpu())
        num_kpts = len(keypoint_text_prompt)
        Z = kp[:num_kpts * 2] * np.array([W, H] * num_kpts)
        Z = Z.reshape(num_kpts * 2)
        x = Z[0::2]
        y = Z[1::2]

        return np.stack((x, y), axis=1)

    def warmup(self):
        img_rgb = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
        self.run(img_rgb, 'face', 'face', box_threshold=0.0, IoU_threshold=0.0)