ddoc's picture
Upload 170 files
bfe5c2a
import gc
import cv2
import numpy as np
import torch
from PIL import Image
from einops import rearrange, repeat
from modules import devices
from modules.shared import cmd_opts
from .depth_adabins import AdaBinsModel
from .depth_leres import LeReSDepth
from .depth_midas import MidasDepth
from .depth_zoe import ZoeDepth
from .general_utils import debug_print
class DepthModel:
_instance = None
def __new__(cls, *args, **kwargs):
keep_in_vram = kwargs.get('keep_in_vram', False)
depth_algorithm = kwargs.get('depth_algorithm', 'Midas-3-Hybrid')
Width, Height = kwargs.get('Width', 512), kwargs.get('Height', 512)
midas_weight = kwargs.get('midas_weight', 0.2)
model_switched = cls._instance and cls._instance.depth_algorithm != depth_algorithm
resolution_changed = cls._instance and (cls._instance.Width != Width or cls._instance.Height != Height)
zoe_algorithm = 'zoe' in depth_algorithm.lower()
model_deleted = cls._instance and cls._instance.should_delete
should_reload = (cls._instance is None or model_deleted or model_switched or (zoe_algorithm and resolution_changed))
if should_reload:
cls._instance = super().__new__(cls)
cls._instance._initialize(models_path=args[0], device=args[1], half_precision=not cmd_opts.no_half, keep_in_vram=keep_in_vram, depth_algorithm=depth_algorithm, Width=Width, Height=Height, midas_weight=midas_weight)
elif cls._instance.should_delete and keep_in_vram:
cls._instance._initialize(models_path=args[0], device=args[1], half_precision=not cmd_opts.no_half, keep_in_vram=keep_in_vram, depth_algorithm=depth_algorithm, Width=Width, Height=Height, midas_weight=midas_weight)
cls._instance.should_delete = not keep_in_vram
return cls._instance
def _initialize(self, models_path, device, half_precision=not cmd_opts.no_half, keep_in_vram=False, depth_algorithm='Midas-3-Hybrid', Width=512, Height=512, midas_weight=1.0):
self.models_path = models_path
self.device = device
self.half_precision = half_precision
self.keep_in_vram = keep_in_vram
self.depth_algorithm = depth_algorithm
self.Width, self.Height = Width, Height
self.midas_weight = midas_weight
self.depth_min, self.depth_max = 1000, -1000
self.adabins_helper = None
self._initialize_model()
def _initialize_model(self):
depth_algo = self.depth_algorithm.lower()
if depth_algo.startswith('zoe'):
self.zoe_depth = ZoeDepth(self.Width, self.Height)
if depth_algo == 'zoe+adabins (old)':
self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
self.adabins_helper = self.adabins_model.adabins_helper
elif depth_algo == 'leres':
self.leres_depth = LeReSDepth(width=448, height=448, models_path=self.models_path, checkpoint_name='res101.pth', backbone='resnext101')
elif depth_algo == 'adabins':
self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
self.adabins_helper = self.adabins_model.adabins_helper
elif depth_algo.startswith('midas'):
self.midas_depth = MidasDepth(self.models_path, self.device, half_precision=self.half_precision, midas_model_type=self.depth_algorithm)
if depth_algo == 'midas+adabins (old)':
self.adabins_model = AdaBinsModel(self.models_path, keep_in_vram=self.keep_in_vram)
self.adabins_helper = self.adabins_model.adabins_helper
else:
raise Exception(f"Unknown depth_algorithm: {self.depth_algorithm}")
def predict(self, prev_img_cv2, midas_weight, half_precision) -> torch.Tensor:
img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR))
if self.depth_algorithm.lower().startswith('zoe'):
depth_tensor = self.zoe_depth.predict(img_pil).to(self.device)
if self.depth_algorithm.lower() == 'zoe+adabins (old)' and midas_weight < 1.0:
use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
if use_adabins: # if there was no error in getting the adabins depth, align midas with adabins
depth_tensor = self.blend_and_align_with_adabins(depth_tensor, adabins_depth, midas_weight)
elif self.depth_algorithm.lower() == 'leres':
depth_tensor = self.leres_depth.predict(prev_img_cv2.astype(np.float32) / 255.0)
elif self.depth_algorithm.lower() == 'adabins':
use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
depth_tensor = torch.tensor(adabins_depth)
if use_adabins is False:
raise Exception("Error getting depth from AdaBins") # TODO: fallback to something else maybe?
elif self.depth_algorithm.lower().startswith('midas'):
depth_tensor = self.midas_depth.predict(prev_img_cv2, half_precision)
if self.depth_algorithm.lower() == 'midas+adabins (old)' and midas_weight < 1.0:
use_adabins, adabins_depth = AdaBinsModel._instance.predict(img_pil, prev_img_cv2)
if use_adabins: # if there was no error in getting the adabins depth, align midas with adabins
depth_tensor = self.blend_and_align_with_adabins(depth_tensor, adabins_depth, midas_weight)
else: # Unknown!
raise Exception(f"Unknown depth_algorithm passed to depth.predict function: {self.depth_algorithm}")
return depth_tensor
def blend_and_align_with_adabins(self, depth_tensor, adabins_depth, midas_weight):
depth_tensor = torch.subtract(50.0, depth_tensor) / 19.0 # align midas depth with adabins depth. Original alignment code from Disco Diffusion
blended_depth_map = (depth_tensor.cpu().numpy() * midas_weight + adabins_depth * (1.0 - midas_weight))
depth_tensor = torch.from_numpy(np.expand_dims(blended_depth_map, axis=0)).squeeze().to(self.device)
debug_print(f"Blended Midas Depth with AdaBins Depth")
return depth_tensor
def to(self, device):
self.device = device
if self.depth_algorithm.lower().startswith('zoe'):
self.zoe_depth.zoe.to(device)
elif self.depth_algorithm.lower() == 'leres':
self.leres_depth.to(device)
elif self.depth_algorithm.lower().startswith('midas'):
self.midas_depth.to(device)
if hasattr(self, 'adabins_model'):
self.adabins_model.to(device)
gc.collect()
torch.cuda.empty_cache()
def to_image(self, depth: torch.Tensor):
depth = depth.cpu().numpy()
depth = np.expand_dims(depth, axis=0) if len(depth.shape) == 2 else depth
self.depth_min, self.depth_max = min(self.depth_min, depth.min()), max(self.depth_max, depth.max())
denom = max(1e-8, self.depth_max - self.depth_min)
temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c')
return Image.fromarray(repeat(temp, 'h w 1 -> h w c', c=3).astype(np.uint8))
def save(self, filename: str, depth: torch.Tensor):
self.to_image(depth).save(filename)
def delete_model(self):
for attr in ['zoe_depth', 'leres_depth']:
if hasattr(self, attr):
getattr(self, attr).delete()
delattr(self, attr)
if hasattr(self, 'midas_depth'):
del self.midas_depth
if hasattr(self, 'adabins_model'):
self.adabins_model.delete_model()
gc.collect()
torch.cuda.empty_cache()
devices.torch_gc()