# Copyright (C) 2023 Deforum LLC
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see https://www.gnu.org/licenses/.

# Contact the authors: https://deforum.github.io/

import torch
import cv2
import os
import numpy as np
import torchvision.transforms as transforms
from .general_utils import download_file_with_checksum
from leres.lib.multi_depth_model_woauxi import RelDepthModel
from leres.lib.net_tools import load_ckpt


class LeReSDepth:
    """Wrapper around the LeReS ``RelDepthModel`` for single-image depth prediction.

    Construction downloads the checkpoint (through
    ``download_file_with_checksum``), builds the model, moves it to the
    selected device and loads the checkpoint weights.
    """

    def __init__(self, width=448, height=448, models_path=None, checkpoint_name='res101.pth', backbone='resnext101'):
        """Download the checkpoint and prepare the depth model.

        Args:
            width: Width the input image is resized to before inference.
            height: Height the input image is resized to before inference.
            models_path: Folder the checkpoint is downloaded to and loaded from.
                NOTE(review): the default of ``None`` is passed straight to
                ``os.path.join`` below, which rejects ``None`` -- callers
                presumably always supply a real path; confirm.
            checkpoint_name: File name of the checkpoint inside ``models_path``.
            backbone: Backbone identifier forwarded to ``RelDepthModel``.
        """
        self.width = width
        self.height = height
        self.models_path = models_path
        self.checkpoint_name = checkpoint_name
        self.backbone = backbone

        download_file_with_checksum(
            url='https://cloudstor.aarnet.edu.au/plus/s/lTIJF4vrvHCAI31/download',
            expected_checksum='7fdc870ae6568cb28d56700d0be8fc45541e09cea7c4f84f01ab47de434cfb7463cacae699ad19fe40ee921849f9760dedf5e0dec04a62db94e169cf203f55b1',
            dest_folder=models_path,
            dest_filename=self.checkpoint_name,
        )

        self.depth_model = RelDepthModel(backbone=self.backbone)
        self.depth_model.eval()
        # Prefer the GPU when one is available, otherwise fall back to the CPU.
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.depth_model.to(self.DEVICE)
        load_ckpt(os.path.join(self.models_path, self.checkpoint_name), self.depth_model, None, None)

    @staticmethod
    def scale_torch(img):
        """Convert an image array into a torch tensor.

        A 2-D array first gets a new leading axis. If the third axis of the
        (possibly expanded) array has length 3, the array goes through
        ``ToTensor`` followed by ``Normalize`` with the mean/std constants
        below; otherwise it is cast to ``float32`` and wrapped with
        ``torch.from_numpy``.

        NOTE(review): for a 2-D input the third axis after expansion is the
        original width, so a 2-D array that is exactly 3 columns wide would
        take the normalisation branch -- presumably never the case for real
        inputs; confirm.

        Args:
            img: Array-like image exposing ``shape`` and ``astype``
                (presumably a NumPy array).

        Returns:
            The converted ``torch.Tensor``.
        """
        if len(img.shape) == 2:
            img = img[np.newaxis, :, :]
        if img.shape[2] == 3:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
            img = transform(img)
        else:
            img = img.astype(np.float32)
            img = torch.from_numpy(img)
        return img

    def predict(self, image):
        """Predict a depth map for ``image`` at the image's own resolution.

        The image is resized to ``(self.width, self.height)``, converted with
        ``scale_torch``, given a leading batch axis and passed to the model's
        ``inference`` method. The prediction is then resized back to the
        input's width and height.

        Args:
            image: Input image array; ``image.shape[0]`` and
                ``image.shape[1]`` are used as the output height and width.

        Returns:
            A ``torch.Tensor`` on ``self.DEVICE`` holding the resized depth
            map with one extra leading axis.
        """
        resized_image = cv2.resize(image, (self.width, self.height))
        # NOTE(review): the input tensor is not moved to self.DEVICE here;
        # presumably ``inference`` handles device placement itself -- confirm.
        img_torch = self.scale_torch(resized_image)[None, :, :, :]
        pred_depth = self.depth_model.inference(img_torch).cpu().numpy().squeeze()
        # cv2.resize takes the target size as (width, height).
        pred_depth_ori = cv2.resize(pred_depth, (image.shape[1], image.shape[0]))
        return torch.from_numpy(pred_depth_ori).unsqueeze(0).to(self.DEVICE)

    def save_raw_depth(self, depth, filepath):
        """Write ``depth`` to ``filepath`` as a 16-bit image.

        Values are rescaled so that the maximum of ``depth`` maps to 60000
        before being cast to ``uint16``. A depth map whose maximum is zero is
        written as an all-zero image instead of dividing by zero.

        Args:
            depth: Depth array exposing ``max`` and ``astype`` (presumably a
                NumPy array -- note ``predict`` returns a torch tensor, so
                callers must convert first; confirm against callers).
            filepath: Destination path handed to ``cv2.imwrite``.
        """
        peak = depth.max()
        if peak == 0:
            # 0 / 0 would yield NaN, and casting NaN to uint16 is undefined.
            depth_normalized = np.zeros_like(depth, dtype=np.uint16)
        else:
            depth_normalized = (depth / peak * 60000).astype(np.uint16)
        cv2.imwrite(filepath, depth_normalized)

    def to(self, device):
        """Move the depth model to ``device`` and remember it as ``self.DEVICE``."""
        self.DEVICE = device
        self.depth_model = self.depth_model.to(device)

    def delete(self):
        """Drop this wrapper's reference to the depth model."""
        del self.depth_model