import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from .dsine.dsine import DSINE
from .dsine import utils as dsine_utils


class NormalDetector:
    """Predict a per-pixel surface-normal map for an image with a DSINE model.

    Between calls the model lives on the CPU.  It is moved to ``self.device``
    only for the duration of a single ``__call__`` and moved back afterwards
    (even if inference raises).
    """

    def __init__(self, model_path, fov=60, device=None):
        """Build a DSINE model and load its checkpoint.

        Args:
            model_path: Path to a DSINE checkpoint, passed to
                ``dsine_utils.load_checkpoint``.
            fov: Field of view forwarded to
                ``dsine_utils.get_intrins_from_fov`` to synthesise camera
                intrinsics (presumably in degrees -- confirm against that
                helper).  Defaults to 60, the previously hard-coded value.
            device: Torch device used for inference.  Defaults to ``"cuda"``
                when available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = DSINE()
        self.model = dsine_utils.load_checkpoint(model_path, self.model)
        # Inference only: switch layers whose behaviour differs between
        # training and evaluation (e.g. dropout / batch norm, if present).
        self.model.eval()
        # ImageNet mean/std normalisation applied to the [0, 1] input.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.fov = fov

    @torch.no_grad()
    def __call__(self, image):
        """Return the predicted normal map of ``image`` as an RGB PIL image.

        Args:
            image: A PIL image, or an array-like of shape (H, W), (H, W, 1),
                (H, W, 3) or (H, W, 4).  Values are assumed to be 8-bit
                (0-255).  Grayscale is replicated to 3 channels and an alpha
                channel is dropped.

        Returns:
            A ``PIL.Image`` of size (W, H).  Each RGB channel holds one normal
            component mapped linearly from [-1, 1] to [0, 255].

        Raises:
            ValueError: If ``image`` cannot be interpreted as an RGB image.
        """
        rgb = self._to_rgb_array(image)
        orig_h, orig_w = rgb.shape[:2]
        # (H, W, 3) -> (1, 3, H, W)
        img = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)

        try:
            self.model.to(self.device)
            self.model.pixel_coords = self.model.pixel_coords.to(self.device)
            pred_norm = self._predict(img.to(self.device), orig_h, orig_w)
        finally:
            # Always offload, so a failed call does not leave the model
            # occupying accelerator memory.
            self.model.to("cpu")
            self.model.pixel_coords = self.model.pixel_coords.to("cpu")

        return self._to_image(pred_norm, (orig_w, orig_h))

    @staticmethod
    def _to_rgb_array(image):
        """Convert ``image`` to a float32 (H, W, 3) array scaled to [0, 1]."""
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        arr = np.asarray(image)
        if arr.ndim == 2:
            arr = arr[..., None]
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = np.repeat(arr, 3, axis=-1)
        elif arr.ndim == 3 and arr.shape[-1] == 4:
            arr = arr[..., :3]
        if arr.ndim != 3 or arr.shape[-1] != 3:
            raise ValueError(
                f"expected an RGB image of shape (H, W, 3), got {arr.shape}"
            )
        # Assumes 8-bit input values.
        return arr.astype(np.float32) / 255.0

    def _predict(self, img, orig_h, orig_w):
        """Run the model on a (1, 3, H, W) tensor and crop back to (H, W)."""
        # Pad the image; presumably to a size the network's downsampling
        # accepts (see dsine_utils.pad_input).
        left, right, top, bottom = dsine_utils.pad_input(orig_h, orig_w)
        img = F.pad(img, (left, right, top, bottom), mode="constant", value=0.0)
        img = self.normalize(img)

        intrinsics = dsine_utils.get_intrins_from_fov(
            new_fov=self.fov, H=orig_h, W=orig_w, device=self.device
        ).unsqueeze(0)
        # Padding shifts the image origin, so shift the (0, 2) / (1, 2)
        # entries -- the principal point of a 3x3 intrinsics matrix -- with it.
        intrinsics[:, 0, 2] += left
        intrinsics[:, 1, 2] += top

        # The last element of the model output is used as the prediction.
        pred_norm = self.model(img, intrins=intrinsics)[-1]
        return pred_norm[:, :, top : top + orig_h, left : left + orig_w]

    @staticmethod
    def _to_image(pred_norm, size):
        """Encode the first normal map of the batch as an 8-bit RGB image."""
        normals = pred_norm.detach().cpu().numpy()[0].transpose(1, 2, 0)  # (H, W, 3)
        # Clip before the uint8 cast: values a hair outside [-1, 1] would map
        # outside [0, 255], where a float -> uint8 cast is platform-dependent.
        normals = np.clip((normals + 1.0) / 2.0 * 255.0, 0.0, 255.0)
        return Image.fromarray(normals.astype(np.uint8)).resize(size)


if __name__ == "__main__":
    from diffusers.utils import load_image

    image = load_image(
        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
    )
    image = np.asarray(image)
    # NormalDetector takes only the DSINE checkpoint path; the former
    # `efficientnet_path=` argument is not accepted and raised TypeError.
    normal_detector = NormalDetector(
        model_path="/juicefs/training/models/open_source/dsine/dsine.pt",
    )
    normal_image = normal_detector(image)
    normal_image.save("normal_image.jpg")