---
license: mit
---


# DepthPro: Image Super Resolution

- This work is part of the [DepthPro: Beyond Depth Estimation](https://github.com/geetu040/depthpro-beyond-depth) repository, which further explores this model's capabilities on:
    - Image Segmentation - Human Segmentation
    - Image Super Resolution - 384px to 1536px (4x Upscaling)
    - Image Super Resolution - 256px to 1024px (4x Upscaling)

# Usage

Install the required libraries:
```bash
pip install -q numpy pillow torch torchvision
pip install -q git+https://github.com/geetu040/transformers.git@depth-pro-projects#egg=transformers
```

Import the required libraries:
```py
import requests
from PIL import Image
import torch
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt

# custom installation from this PR: https://github.com/huggingface/transformers/pull/34583
# !pip install git+https://github.com/geetu040/transformers.git@depth-pro-projects#egg=transformers
from transformers import DepthProConfig, DepthProImageProcessorFast, DepthProForDepthEstimation
```

Load the DepthProForDepthEstimation model:
```py
# create the DepthPro model, used as the backbone (trained weights are loaded later)
config = DepthProConfig(
    patch_size=32,
    patch_embeddings_size=4,
    num_hidden_layers=12,
    intermediate_hook_ids=[11, 8, 7, 5],
    intermediate_feature_dims=[256, 256, 256, 256],
    scaled_images_ratios=[0.5, 1.0],
    scaled_images_overlap_ratios=[0.5, 0.25],
    scaled_images_feature_dims=[1024, 512],
    use_fov_model=False,
)
depthpro_for_depth_estimation = DepthProForDepthEstimation(config)
```

Create the DepthProForSuperResolution model:
```py
# create DepthPro for super resolution:
# DepthPro backbone + lightweight convolutional heads for 4x upscaling
class DepthProForSuperResolution(torch.nn.Module):
    def __init__(self, depthpro_for_depth_estimation):
        super().__init__()

        self.depthpro_for_depth_estimation = depthpro_for_depth_estimation
        hidden_size = self.depthpro_for_depth_estimation.config.fusion_hidden_size

        # projects the low-resolution image into the fusion feature space (2x upsampling)
        self.image_head = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=config.num_channels,
                out_channels=hidden_size,
                kernel_size=4, stride=2, padding=1
            ),
            torch.nn.ReLU(),
        )

        # refines the combined features and upsamples 2x more, back to image channels
        self.head = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=4, stride=2, padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                in_channels=hidden_size,
                out_channels=self.depthpro_for_depth_estimation.config.num_channels,
                kernel_size=3, stride=1, padding=1
            ),
        )

    def forward(self, pixel_values):
        # x is the low-resolution image
        x = pixel_values
        # encode at multiple scales, then fuse with the backbone's fusion stage
        encoder_features = self.depthpro_for_depth_estimation.depth_pro(x).features
        fused_hidden_state = self.depthpro_for_depth_estimation.fusion_stage(encoder_features)[-1]
        # project the raw image and align it spatially with the fused features
        x = self.image_head(x)
        x = torch.nn.functional.interpolate(x, size=fused_hidden_state.shape[2:])
        x = x + fused_hidden_state
        x = self.head(x)
        return x
```
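
As an optional sanity check (a minimal sketch, assuming the patched encoder accepts 256px inputs as in the inference example below), the randomly initialized model should map a 256x256 image to a 1024x1024 output, matching the 4x upscaling the trained weights were fitted for:
```py
# hypothetical shape check on the randomly initialized model
# (the trained weights are loaded in the next step)
sr_model = DepthProForSuperResolution(depthpro_for_depth_estimation)
dummy = torch.randn(1, 3, 256, 256)  # one 256x256 RGB image
with torch.no_grad():
    upscaled = sr_model(dummy)
print(upscaled.shape)  # expected: torch.Size([1, 3, 1024, 1024]), i.e. 4x upscaling
```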

Load the model and image processor:
```py
# initialize the model
model = DepthProForSuperResolution(depthpro_for_depth_estimation)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# load weights
weights_path = hf_hub_download(repo_id="geetu040/DepthPro_SR_4x_256p", filename="model_weights.pth")
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

# load image processor
image_processor = DepthProImageProcessorFast(
    do_resize=False,
    do_rescale=True,
    do_normalize=True
)
```
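
For reference, a minimal sketch of what the processor returns, assuming its default per-channel normalization of mean 0.5 and std 0.5 (consistent with the `output * 0.5 + 0.5` de-normalization in the inference code below):
```py
import numpy as np

# hypothetical example: a blank 256x256 RGB image
dummy = Image.fromarray(np.zeros((256, 256, 3), dtype="uint8"))
batch = image_processor(images=dummy, return_tensors="pt")
# with do_resize=False the spatial size is preserved
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 256, 256])
```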

Inference:
```py
# inference
url = "https://huggingface.co/spaces/geetu040/DepthPro_SR_4x_256p/resolve/main/assets/examples/man_with_arms_open.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
image.thumbnail((256, 256))  # resize the image to fit within a 256x256 pixel box

# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)

# convert tensors to PIL.Image
output = outputs[0]                        # take the first (and only) image in the batch
output = output.cpu()                      # unload from cuda if used
output = torch.permute(output, (1, 2, 0))  # (C, H, W) -> (H, W, C)
output = output * 0.5 + 0.5                # undo normalization
output = output * 255.                     # undo scaling
output = output.clip(0, 255.)              # fix out-of-range values
output = output.numpy()                    # convert to numpy
output = output.astype('uint8')            # convert to PIL.Image compatible format
output = Image.fromarray(output)           # create PIL.Image object

# visualize the prediction
fig, axes = plt.subplots(1, 2, figsize=(20, 20))
axes[0].imshow(image)
axes[0].set_title(f'Low-Resolution (LR) {image.size}')
axes[0].axis('off')
axes[1].imshow(output)
axes[1].set_title(f'Super-Resolution (SR) {output.size}')
axes[1].axis('off')
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()
```
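
To keep the result, the upscaled `PIL.Image` can be saved directly (the filename here is only an example):
```py
# save the super-resolved image to disk (hypothetical output path)
output.save("man_with_arms_open_sr.png")
```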