ysmao's picture
add layout controlnet
4342954
raw
history blame
1.33 kB
import torch
import numpy as np
from PIL import Image
from transformers import DPTFeatureExtractor
from transformers import DPTForDepthEstimation
class DepthDetector:
    """Monocular depth estimator wrapping a Hugging Face DPT model.

    Calling an instance on an image returns a single-channel ``PIL.Image``
    depth map resized back to the input's width and height.
    """

    DEFAULT_MODEL_PATH = "Intel/dpt-hybrid-midas"

    def __init__(self, model_path=None):
        """Load the DPT depth model and its feature extractor.

        Args:
            model_path: Hub id or local directory passed to ``from_pretrained``
                for both the model and the feature extractor. Defaults to
                ``DEFAULT_MODEL_PATH`` ("Intel/dpt-hybrid-midas").
        """
        self.model_path = (
            model_path if model_path is not None else self.DEFAULT_MODEL_PATH
        )
        self.model = DPTForDepthEstimation.from_pretrained(self.model_path)
        # The model is only moved to this device inside __call__, and is moved
        # back to the CPU when the call finishes.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = DPTFeatureExtractor.from_pretrained(self.model_path)

    @torch.no_grad()
    def __call__(self, image):
        """Estimate a depth map for ``image``.

        Args:
            image: An ``H x W x C`` array (e.g. ``numpy.ndarray``), or a
                ``PIL.Image.Image``, which is converted to an RGB array first.

        Returns:
            A single-channel ``PIL.Image.Image`` of size ``(W, H)``. Its pixels
            are the model's ``predicted_depth`` scaled so the maximum maps to
            255 (negative values clamp to 0). If the model returns several
            maps, only the first one is used.

        Raises:
            ValueError: if ``image`` is not three-dimensional.
        """
        if isinstance(image, Image.Image):
            # PIL images have no ``.shape``; convert to the H x W x C layout.
            image = np.asarray(image.convert("RGB"))
        height, width, _ = image.shape

        # Keep the weights on the accelerator only for this call, and always
        # move them back (even if inference raises) so they do not stay
        # resident on the GPU between calls.
        self.model.to(self.device)
        try:
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
            predicted_depth = self.model(**inputs).predicted_depth
            depth = predicted_depth.squeeze().cpu().numpy()
        finally:
            self.model.to("cpu")

        # squeeze() drops the batch axis for a single image; restore a leading
        # axis so the first map can be selected uniformly below.
        if depth.ndim == 2:
            depth = depth[np.newaxis, :, :]
        formatted = self._depth_to_uint8(depth)
        return Image.fromarray(formatted[0, ...]).resize((width, height))

    @staticmethod
    def _depth_to_uint8(depth):
        """Scale ``depth`` so its maximum maps to 255 and cast it to ``uint8``."""
        peak = np.max(depth)
        if not np.isfinite(peak) or peak <= 0:
            # Nothing to normalise against (all-zero, non-positive or NaN map):
            # dividing would yield NaN/inf, whose uint8 cast is undefined.
            return np.zeros(depth.shape, dtype=np.uint8)
        scaled = depth * 255 / peak
        # Negative predictions would wrap around when cast to uint8; clamp them.
        return np.clip(scaled, 0, 255).astype("uint8")