# Spaces: Build error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
import os | |
from detectron2.structures import ImageList | |
class ViTMatte(nn.Module):
    """Wrap a backbone, a decoder and a criterion behind one ``forward``.

    ``forward`` normalizes the image, appends the trimap as an extra channel,
    zero-pads the result to a multiple of 32 in height and width, and runs
    ``backbone`` followed by ``decoder``.  In training mode the value returned
    by ``criterion`` is returned; otherwise the decoder outputs are returned
    with ``outputs['phas']`` cropped back to the unpadded size.

    All constructor arguments are keyword-only.

    Args:
        backbone: callable applied to the padded, batched input.
        criterion: callable invoked as ``criterion(sample_map, outputs,
            targets)`` in training mode.
        pixel_mean: per-channel mean; reshaped to ``(-1, 1, 1)`` and
            subtracted from the image.
        pixel_std: per-channel std; reshaped to ``(-1, 1, 1)``; the centered
            image is divided by it.  Must yield the same shape as
            ``pixel_mean``.
        input_format: stored on the instance; not used by this class.
        size_divisibility: stored on the instance; padding below uses a fixed
            multiple of 32 and does not consult this value.
        decoder: callable invoked as ``decoder(features, images)``; its result
            must be a mapping with a ``'phas'`` entry.
    """

    # Height and width are padded up to a multiple of this value.
    _PAD_MULTIPLE = 32

    def __init__(
        self,
        *,
        backbone,
        criterion,
        pixel_mean,
        pixel_std,
        input_format,
        size_divisibility,
        decoder,
    ):
        super().__init__()
        self.backbone = backbone
        self.criterion = criterion
        self.input_format = input_format
        self.size_divisibility = size_divisibility
        self.decoder = decoder
        # Third positional argument is ``persistent``: the statistics move
        # with the module but are left out of the state dict.
        self.register_buffer(
            "pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer(
            "pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False
        )
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    @property
    def device(self):
        """Device of the ``pixel_mean`` buffer.

        This has to be a property: the rest of the class passes
        ``self.device`` straight to ``Tensor.to``, which rejects a bound
        method.
        """
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """Run the model on one batch.

        Args:
            batched_inputs: mapping with ``'image'`` and ``'trimap'`` tensors
                and optionally ``'fg'`` and ``'alpha'`` (see
                ``preprocess_inputs``).

        Returns:
            In training mode, whatever ``criterion`` returns.  Otherwise the
            decoder outputs, with ``outputs['phas']`` cropped to the original
            height and width.
        """
        images, targets, H, W = self.preprocess_inputs(batched_inputs)
        features = self.backbone(images)
        outputs = self.decoder(features, images)

        if self.training:
            # NOTE(review): ``targets`` is always a dict, so this never fires;
            # ``targets['phas']`` is None when the batch has no 'alpha' entry.
            assert targets is not None
            # Channel 3 is the trimap appended in preprocess_inputs
            # (assumes a 3-channel image -- TODO confirm).
            trimap = images[:, 3:4]
            # Mark the pixels whose trimap value is exactly 0.5.
            sample_map = torch.zeros_like(trimap)
            sample_map[trimap == 0.5] = 1
            return self.criterion(sample_map, outputs, targets)

        # Drop the rows/columns that were only added as padding.
        outputs["phas"] = outputs["phas"][:, :, :H, :W]
        return outputs

    def preprocess_inputs(self, batched_inputs):
        """Normalize, pad and batch the input images.

        The image is normalized with ``pixel_mean`` / ``pixel_std``.  When the
        batch has an ``'fg'`` entry the trimap is quantized: values below 85
        become 0, values of 170 and above become 1, everything in between
        becomes 0.5.  The quantized trimap is a new tensor; the caller's
        tensor is left untouched.  Without ``'fg'`` the trimap is used as
        given.  Image and trimap are concatenated along dim 1 and zero-padded
        at the bottom/right so height and width are multiples of 32; a
        dimension that is already a multiple of 32 is not padded.

        Args:
            batched_inputs: mapping with ``'image'`` and ``'trimap'`` tensors
                (assumed ``(B, C, H, W)`` and ``(B, 1, H, W)`` -- TODO
                confirm), optionally ``'fg'`` and ``'alpha'``.

        Returns:
            Tuple ``(images, targets, H, W)``: the padded tensor,
            ``{'phas': alpha-or-None}``, and the height and width before
            padding.
        """
        device = self.device
        images = batched_inputs["image"].to(device)
        trimap = batched_inputs["trimap"].to(device)
        images = (images - self.pixel_mean) / self.pixel_std

        if "fg" in batched_inputs:
            # Build a fresh floating tensor instead of writing into the input:
            # ``.to`` may return the caller's own tensor, and an integer
            # trimap cannot hold 0.5.
            raw_trimap = trimap
            trimap = torch.full_like(raw_trimap, 0.5, dtype=images.dtype)
            trimap[raw_trimap < 85] = 0
            trimap[raw_trimap >= 170] = 1

        images = torch.cat((images, trimap), dim=1)

        H, W = images.shape[-2:]
        # ``-x % m`` is the distance from x up to the next multiple of m,
        # and 0 when x is already a multiple.
        pad_h = -H % self._PAD_MULTIPLE
        pad_w = -W % self._PAD_MULTIPLE
        if pad_h or pad_w:
            # Pad order is (left, right, top, bottom); default fill is zero.
            images = F.pad(images, (0, pad_w, 0, pad_h))

        if "alpha" in batched_inputs:
            phas = batched_inputs["alpha"].to(device)
        else:
            phas = None

        return images, dict(phas=phas), H, W