dptdepth/image_processor_dptdepth.py
from typing import Dict, Optional, Tuple
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL.Image import Image
from torch import Tensor
from transformers.image_processing_utils import BaseImageProcessor

INPUT_IMAGE_SIZE = (352, 352)

# Bicubic resize to the model input size, tensor conversion, and
# normalization of [0, 1] pixel values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(
            INPUT_IMAGE_SIZE,
            interpolation=TF.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5),
            (0.5, 0.5, 0.5),
        ),
    ]
)


class DPTDepthImageProcessor(BaseImageProcessor):
    model_input_names = ["dptdepth_preprocessor"]

    def __init__(self, testsize: Optional[int] = 352, **kwargs) -> None:
        super().__init__(**kwargs)
        self.testsize = testsize

    def preprocess(
        self, inputs: Dict[str, Image], **kwargs  # {'rgb': ... }
    ) -> Dict[str, Tensor]:
        # Apply the resize/normalize transform and add a batch dimension.
        rgb: Tensor = transform(inputs["rgb"])
        return dict(rgb=rgb.unsqueeze(0))

    def postprocess(
        self, logits: Tensor, size: Tuple[int, int], **kwargs
    ) -> np.ndarray:
        # Upsample the predicted depth map to the requested (height, width).
        logits = F.interpolate(
            logits, size=size, mode="bilinear", align_corners=False
        )
        res: np.ndarray = logits.squeeze().detach().cpu().numpy()
        # Optional min-max rescaling of the depth map to [0, 1]:
        # res = (res - res.min()) / (res.max() - res.min() + 1e-8)
        return res
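

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module):
# it runs preprocess on a placeholder image path ("example.jpg") and feeds a
# dummy 1x1x352x352 tensor into postprocess in place of real model logits.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from PIL import Image as PILImage  # PIL.Image module (has open()), unlike the Image class above

    processor = DPTDepthImageProcessor()

    image = PILImage.open("example.jpg").convert("RGB")  # placeholder path
    batch = processor.preprocess({"rgb": image})
    print(batch["rgb"].shape)  # torch.Size([1, 3, 352, 352])

    # Stand-in for real model output: one batch element, one depth channel.
    dummy_logits = torch.randn(1, 1, 352, 352)
    depth = processor.postprocess(dummy_logits, size=image.size[::-1])
    print(depth.shape)  # (original height, original width)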