# Page residue from the Hugging Face file view (uploader "Jeney",
# commit message "Upload folder using huggingface_hub", commit 6a89c74).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os
from detectron2.structures import ImageList
class ViTMatte(nn.Module):
    """Trimap-conditioned matting model: backbone followed by a decoder.

    The image is normalized with ``pixel_mean`` / ``pixel_std``, the trimap is
    appended as an extra channel, the result is zero-padded on the bottom/right
    so height and width are multiples of ``_PAD_MULTIPLE``, and the padded
    tensor is fed to ``backbone`` and then ``decoder``. In training mode the
    output is scored by ``criterion``; in eval mode the predicted ``"phas"`` is
    cropped back to the unpadded size.
    """

    # Spatial multiple the 4-channel input is padded to in ``preprocess_inputs``.
    # NOTE(review): ``size_divisibility`` is stored in ``__init__`` but never read
    # by this class; presumably it is meant to equal this value — confirm.
    _PAD_MULTIPLE = 32

    def __init__(
        self,
        *,
        backbone,
        criterion,
        pixel_mean,
        pixel_std,
        input_format,
        size_divisibility,
        decoder,
    ):
        """
        Args:
            backbone: callable invoked as ``backbone(images)`` on the padded
                4-channel input.
            criterion: callable invoked as ``criterion(sample_map, outputs, targets)``
                in training mode; its return value is returned by ``forward``.
            pixel_mean: per-channel mean subtracted from the image.
            pixel_std: per-channel std the image is divided by.
            input_format: stored as ``self.input_format``; not read by this class.
            size_divisibility: stored as ``self.size_divisibility``; not read by
                this class.
            decoder: callable invoked as ``decoder(features, images)``; must return
                a dict containing a ``"phas"`` tensor.

        Raises:
            ValueError: if ``pixel_mean`` and ``pixel_std`` have different shapes.
        """
        super().__init__()
        self.backbone = backbone
        self.criterion = criterion
        self.input_format = input_format
        self.size_divisibility = size_divisibility
        self.decoder = decoder
        # Non-persistent buffers follow the module across devices but are not
        # written to the state_dict.
        self.register_buffer(
            "pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), persistent=False
        )
        if self.pixel_mean.shape != self.pixel_std.shape:
            raise ValueError(
                f"{self.pixel_mean} and {self.pixel_std} have different shapes!"
            )

    @property
    def device(self):
        """Device of the normalization buffers, used as the model's device."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """Run the model on one batch.

        Args:
            batched_inputs: mapping with ``"image"`` and ``"trimap"`` tensors and,
                in training mode, an ``"alpha"`` tensor (see ``preprocess_inputs``).

        Returns:
            Training mode: whatever ``criterion`` returns.
            Eval mode: the decoder's output dict, with ``"phas"`` cropped to the
            spatial size the input had before padding.

        Raises:
            ValueError: in training mode when ``batched_inputs`` has no ``"alpha"``.
        """
        images, targets, H, W = self.preprocess_inputs(batched_inputs)
        # ``targets`` is always a dict, so validate the alpha entry itself, and do
        # it before spending a backbone pass on an unusable batch.
        if self.training and targets["phas"] is None:
            raise ValueError(
                "ViTMatte training requires an 'alpha' entry in batched_inputs"
            )

        features = self.backbone(images)
        outputs = self.decoder(features, images)

        if not self.training:
            # Remove the bottom/right padding added by preprocess_inputs.
            outputs["phas"] = outputs["phas"][:, :, :H, :W]
            return outputs

        # Channel 3 of the network input is the trimap; mark its 0.5 pixels.
        trimap = images[:, 3:4]
        sample_map = (trimap == 0.5).to(trimap.dtype)
        return self.criterion(sample_map, outputs, targets)

    def preprocess_inputs(self, batched_inputs):
        """Normalize the image, append the trimap as a channel, and pad.

        Args:
            batched_inputs: mapping with ``"image"`` and ``"trimap"`` tensors and
                optional ``"alpha"`` / ``"fg"`` entries. Presumably ``"image"`` is
                (B, 3, H, W) and ``"trimap"`` is (B, 1, H, W) — confirm against the
                data loader; ``forward`` reads channel 3 of the result as the trimap.

        Returns:
            ``(images, targets, H, W)``: the padded image+trimap tensor on
            ``self.device``, ``{"phas": alpha-or-None}``, and the spatial size
            before padding.
        """
        images = batched_inputs["image"].to(self.device)
        trimap = batched_inputs["trimap"].to(self.device)
        images = (images - self.pixel_mean) / self.pixel_std

        # NOTE(review): the presence of "fg" is used as the signal that the trimap
        # is on a 0-255 scale (presumably training samples) — confirm with the
        # data pipeline.
        if "fg" in batched_inputs:
            trimap = self._quantize_trimap(trimap)

        images = torch.cat((images, trimap), dim=1)
        H, W = images.shape[-2:]
        images = self._pad_to_multiple(images, self._PAD_MULTIPLE)

        phas = batched_inputs["alpha"].to(self.device) if "alpha" in batched_inputs else None
        return images, dict(phas=phas), H, W

    @staticmethod
    def _quantize_trimap(trimap):
        """Map a 0-255 trimap to {0, 0.5, 1} without modifying ``trimap``.

        Values < 85 become 0, values >= 170 become 1, all others 0.5. Integer
        trimaps are promoted to the default float dtype so 0.5 is representable.
        """
        dtype = trimap.dtype if trimap.is_floating_point() else torch.get_default_dtype()
        quantized = torch.full_like(trimap, 0.5, dtype=dtype)
        quantized[trimap < 85] = 0
        quantized[trimap >= 170] = 1
        return quantized

    @staticmethod
    def _pad_to_multiple(images, multiple):
        """Zero-pad the last two dims up to the next multiple of ``multiple``.

        Padding goes on the bottom/right only; returns ``images`` unchanged when
        both dims are already aligned.
        """
        H, W = images.shape[-2:]
        pad_h = -H % multiple
        pad_w = -W % multiple
        if pad_h == 0 and pad_w == 0:
            return images
        # F.pad lists the last dimension first: (left, right, top, bottom).
        return F.pad(images, (0, pad_w, 0, pad_h))