from PIL import Image
import torch
from huggingface_hub import hf_hub_download

# custom installation from this PR: https://github.com/huggingface/transformers/pull/34583
# !pip install git+https://github.com/geetu040/transformers.git@depth-pro-projects#egg=transformers
from transformers import DepthProConfig, DepthProImageProcessorFast, DepthProForDepthEstimation

# create the DepthPro depth-estimation model (randomly initialized here), used as the backbone
config = DepthProConfig(
    patch_size=192,
    patch_embeddings_size=16,
    num_hidden_layers=12,
    intermediate_hook_ids=[11, 8, 7, 5],
    intermediate_feature_dims=[256, 256, 256, 256],
    scaled_images_ratios=[0.5, 1.0],
    scaled_images_overlap_ratios=[0.5, 0.25],
    scaled_images_feature_dims=[1024, 512],
    use_fov_model=False,
)
depthpro_for_depth_estimation = DepthProForDepthEstimation(config)

# create DepthPro for super resolution
class DepthProForSuperResolution(torch.nn.Module):
    def __init__(self, depthpro_for_depth_estimation):
        super().__init__()

        self.depthpro_for_depth_estimation = depthpro_for_depth_estimation
        hidden_size = self.depthpro_for_depth_estimation.config.fusion_hidden_size

        self.image_head = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=self.depthpro_for_depth_estimation.config.num_channels,
                out_channels=hidden_size,
                kernel_size=4, stride=2, padding=1
            ),
            torch.nn.ReLU(),
        )

        self.head = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=4, stride=2, padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=hidden_size,
                out_channels=self.depthpro_for_depth_estimation.config.num_channels,
                kernel_size=3, stride=1, padding=1
            ),
        )

    def forward(self, pixel_values):
        # pixel_values is the low-resolution input image
        x = pixel_values

        # extract multi-scale features from the DepthPro encoder
        encoder_features = self.depthpro_for_depth_estimation.depth_pro(x).features
        # fuse the encoder features and keep the final fused hidden state
        fused_hidden_state = self.depthpro_for_depth_estimation.fusion_stage(encoder_features)[-1]

        # project the input image into the fusion hidden space
        x = self.image_head(x)
        # match the spatial size of the fused hidden state, then combine
        x = torch.nn.functional.interpolate(x, size=fused_hidden_state.shape[2:])
        x = x + fused_hidden_state

        # decode the combined features into the super-resolved image
        x = self.head(x)
        return x

# initialize the model
model = DepthProForSuperResolution(depthpro_for_depth_estimation)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load weights
weights_path = hf_hub_download(repo_id="geetu040/DepthPro_SR_4x_384p", filename="model_weights.pth")
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

# load image processor
image_processor = DepthProImageProcessorFast(
    do_resize=True,
    size={"width": 384, "height": 384},
    do_rescale=True,
    do_normalize=True
)

# define crop function to ensure square image
def crop_image(image):
    """
    Crops the image from the center to make aspect ratio 1:1.
    """
    width, height = image.size
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    right = left + min_dim
    bottom = top + min_dim
    image = image.crop((left, top, right, bottom))
    return image


def predict(image):
    """
    Runs inference on a single PIL image and returns the super-resolved PIL image.
    """
    # crop to a square and resize to the model's expected input resolution
    image = crop_image(image)
    image = image.resize((384, 384), Image.Resampling.BICUBIC)

    # prepare image for the model
    inputs = image_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # convert the output tensor to a PIL.Image
    output = outputs[0]                        # extract the first and only batch
    output = output.cpu()                      # move off the GPU if used
    output = torch.permute(output, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    output = output * 0.5 + 0.5                # undo normalization
    output = output * 255.                     # undo scaling
    output = output.clip(0, 255.)              # clamp out-of-range values
    output = output.numpy()                    # convert to numpy
    output = output.astype('uint8')            # convert to a PIL.Image-compatible dtype
    output = Image.fromarray(output)           # create the PIL.Image object

    return output