Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from PIL import Image | |
from torchvision import transforms | |
import torch.nn.functional as F | |
from .dsine.dsine import DSINE | |
from .dsine import utils as dsine_utils | |
class NormalDetector:
    """Surface-normal estimator built around a DSINE model.

    The model is kept on the CPU between calls and is moved to
    ``self.device`` only for the duration of a single prediction.
    """

    def __init__(self, model_path):
        """Create the DSINE model and load its weights.

        Args:
            model_path: Checkpoint location handed to
                ``dsine_utils.load_checkpoint``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DSINE()
        self.model = dsine_utils.load_checkpoint(model_path, self.model)
        # This class only ever runs inference, so put normalisation/dropout
        # layers into evaluation mode once, up front.
        self.model.eval()
        # Per-channel mean/std applied to the [0, 1] RGB tensor (the usual
        # ImageNet statistics).
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        # Passed as ``new_fov`` when building camera intrinsics
        # (presumably degrees -- confirm against dsine_utils).
        self.fov = 60

    def __call__(self, image):
        """Predict a surface-normal map for ``image``.

        Args:
            image: PIL image or array-like with 0-255 values, laid out as
                (H, W, 3). Grayscale (H, W) input is replicated to three
                channels and the alpha channel of (H, W, 4) input is dropped.

        Returns:
            A PIL image of size (W, H) in which the predicted normals are
            mapped from [-1, 1] to [0, 255].
        """
        # Validate/convert the input before touching the device so that a bad
        # image fails without moving the model around.
        img = np.array(image).astype(np.float32) / 255.0
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = img[..., :3]

        self.model.to(self.device)
        self.model.pixel_coords = self.model.pixel_coords.to(self.device)
        try:
            # No gradients are needed for inference; skipping the autograd
            # graph keeps peak memory down.
            with torch.no_grad():
                img = (
                    torch.from_numpy(img)
                    .permute(2, 0, 1)
                    .unsqueeze(0)
                    .to(self.device)
                )
                _, _, orig_H, orig_W = img.shape

                # Pad to the size the network expects, then normalise.
                l, r, t, b = dsine_utils.pad_input(orig_H, orig_W)
                img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
                img = self.normalize(img)

                intrinsics = dsine_utils.get_intrins_from_fov(
                    new_fov=self.fov, H=orig_H, W=orig_W, device=self.device
                ).unsqueeze(0)
                # Left/top padding shifts the image origin, so shift entries
                # [0, 2] and [1, 2] (the principal point in a standard 3x3
                # intrinsics matrix) by the same amount.
                intrinsics[:, 0, 2] += l
                intrinsics[:, 1, 2] += t

                # The model's result is indexed with [-1]; the last entry is
                # presumably the final prediction -- confirm against DSINE.
                pred_norm = self.model(img, intrins=intrinsics)[-1]
                # Crop the padding back off.
                pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]

            pred_norm_np = (
                pred_norm.detach().cpu().numpy()[0, :, :, :].transpose(1, 2, 0)
            )  # (H, W, 3)
        finally:
            # Always release the accelerator, even if inference failed.
            self.model.to("cpu")
            self.model.pixel_coords = self.model.pixel_coords.to("cpu")

        # Map [-1, 1] -> [0, 255]; clip first so values marginally outside the
        # range cannot wrap around in the uint8 cast.
        pred_norm_np = np.clip((pred_norm_np + 1.0) / 2.0 * 255.0, 0.0, 255.0)
        pred_norm_np = pred_norm_np.astype(np.uint8)
        # The resize is a safeguard; it is a no-op when the crop already has
        # the original resolution.
        normal_img = Image.fromarray(pred_norm_np).resize((orig_W, orig_H))
        return normal_img
if __name__ == "__main__":
    # Manual smoke test: download a sample image, estimate its normals and
    # write the result next to the working directory.
    # NOTE: this module uses relative imports, so it must be executed as part
    # of its package (``python -m <package>.<module>``), not as a bare script.
    from diffusers.utils import load_image

    image = load_image(
        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
    )
    image = np.asarray(image)
    # NormalDetector.__init__ accepts only ``model_path``; the previously
    # passed ``efficientnet_path`` keyword raised TypeError at construction.
    normal_detector = NormalDetector(
        model_path="/juicefs/training/models/open_source/dsine/dsine.pt",
    )
    normal_image = normal_detector(image)
    normal_image.save("normal_image.jpg")