Layout-Control / annotator /dsine_local.py
ysmao's picture
update dependencies
ef0eb1c
raw
history blame
2.38 kB
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from .dsine.dsine import DSINE
from .dsine import utils as dsine_utils
class NormalDetector:
    """Predict a surface-normal map from an image with a DSINE network.

    The network is kept on the CPU between calls and is moved to
    ``self.device`` only for the duration of each ``__call__``.
    """

    def __init__(self, model_path, fov=60, device=None):
        """Build the DSINE network and load its weights.

        Args:
            model_path: Checkpoint path, passed to ``dsine_utils.load_checkpoint``.
            fov: Field of view passed as ``new_fov`` to
                ``dsine_utils.get_intrins_from_fov`` to synthesize camera
                intrinsics (presumably degrees — confirm against dsine utils).
                Defaults to 60, the previously hard-coded value.
            device: Inference device. Defaults to CUDA when available, else CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = DSINE()
        self.model = dsine_utils.load_checkpoint(model_path, self.model)
        # Inference only: switch dropout / batch-norm layers to evaluation
        # behaviour. eval() is idempotent, so this is safe even if the
        # checkpoint loader already did it.
        self.model.eval()
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.fov = fov

    @staticmethod
    def _to_rgb_array(image):
        """Return ``image`` as a float32 (H, W, 3) array divided by 255.

        PIL images (anything with a ``convert`` method) are converted to RGB.
        For arrays, a 2-D or single-channel input is replicated to three
        channels and a fourth (alpha) channel is dropped, because the network
        expects exactly three channels.

        Raises:
            ValueError: if the input cannot be interpreted as an RGB image.
        """
        if hasattr(image, "convert"):
            image = image.convert("RGB")
        arr = np.asarray(image)
        if arr.ndim == 2:
            arr = arr[..., None]
        if arr.ndim != 3:
            raise ValueError(
                f"expected an (H, W) or (H, W, C) image, got shape {arr.shape}"
            )
        channels = arr.shape[-1]
        if channels == 1:
            arr = np.repeat(arr, 3, axis=-1)
        elif channels == 4:
            arr = arr[..., :3]
        elif channels != 3:
            raise ValueError(f"expected 1, 3 or 4 channels, got {channels}")
        return arr.astype(np.float32) / 255.0

    @staticmethod
    def _normals_to_uint8(normals):
        """Map normal components from [-1, 1] to uint8 values in [0, 255].

        The scaled values are clipped before the (truncating) uint8 cast so
        that numerical overshoot cannot wrap around.
        """
        scaled = (normals + 1.0) / 2.0 * 255.0
        return np.clip(scaled, 0.0, 255.0).astype(np.uint8)

    @torch.no_grad()
    def __call__(self, image):
        """Predict the normal map of ``image``.

        Args:
            image: PIL image or (H, W[, C]) array with values in [0, 255].

        Returns:
            PIL.Image.Image: (H, W) RGB image whose channels encode the
            predicted normal components mapped from [-1, 1] to [0, 255].
        """
        img = torch.from_numpy(self._to_rgb_array(image)).permute(2, 0, 1).unsqueeze(0)
        _, _, orig_H, orig_W = img.shape

        # Move the network to the inference device only for this call, and
        # always move it back so device memory is released even on failure.
        self.model.to(self.device)
        self.model.pixel_coords = self.model.pixel_coords.to(self.device)
        try:
            img = img.to(self.device)
            # Pad by the amounts dsine_utils.pad_input requests (presumably to
            # a size the network's downsampling accepts — TODO confirm).
            l, r, t, b = dsine_utils.pad_input(orig_H, orig_W)
            img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
            img = self.normalize(img)
            intrinsics = dsine_utils.get_intrins_from_fov(
                new_fov=self.fov, H=orig_H, W=orig_W, device=self.device
            ).unsqueeze(0)
            # Shift the principal point by the left/top padding so the
            # intrinsics refer to the padded image.
            intrinsics[:, 0, 2] += l
            intrinsics[:, 1, 2] += t
            pred_norm = self.model(img, intrins=intrinsics)[-1]
            # Crop the padding back off.
            pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
            pred_norm_np = pred_norm[0].permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        finally:
            self.model.to("cpu")
            self.model.pixel_coords = self.model.pixel_coords.to("cpu")

        normal_img = Image.fromarray(self._normals_to_uint8(pred_norm_np))
        # No-op when the crop already has the input size; kept as a guard.
        return normal_img.resize((orig_W, orig_H))
if __name__ == "__main__":
    # Demo: download a sample image, predict its normal map and save it.
    from diffusers.utils import load_image

    image = load_image(
        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
    )
    image = np.asarray(image)
    # NormalDetector's constructor has no efficientnet_path parameter;
    # passing it raised TypeError, so only the DSINE checkpoint is given.
    normal_detector = NormalDetector(
        model_path="/juicefs/training/models/open_source/dsine/dsine.pt",
    )
    normal_image = normal_detector(image)
    normal_image.save("normal_image.jpg")