from typing import Dict, List, Optional, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import Tensor

from .configs.base_config import base_cfg
from .rgbd_model import RGBDModel


class ModelPL(pl.LightningModule):
    """Lightning wrapper around RGBDModel for RGB-D salient object detection."""

    def __init__(self, cfg: base_cfg):
        super().__init__()
        self.cfg = cfg
        self.model = RGBDModel(cfg)

    def forward(self, images: Tensor, depths: Tensor):
        return self.model.forward(images, depths)

    def __inference_v1(
        self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
    ) -> List[List[np.ndarray]]:
        """Single-map decoding: sigmoid + min-max normalization to uint8."""
        res_lst: List[List[np.ndarray]] = []
        for output, image_size in zip(outputs["sod"], image_sizes):
            # image_size is (width, height); F.interpolate expects (H, W).
            output = F.interpolate(
                output.unsqueeze(0),
                size=(image_size[1], image_size[0]),
                mode="bilinear",
                align_corners=False,
            )
            res: np.ndarray = output.sigmoid().data.cpu().numpy().squeeze()
            # Min-max normalize to [0, 1]; the epsilon guards against a flat map.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            if self.cfg.is_fp16:
                # Promote half-precision results to float32 before quantizing.
                res = np.float32(res)
            res_lst.append([(res * 255).astype(np.uint8)])
        return res_lst

    def __inference_v2(
        self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
    ) -> List[List[np.ndarray]]:
        """Multi-class decoding: channel-wise argmax yields a label map."""
        res_lst: List[List[np.ndarray]] = []
        for output, image_size in zip(outputs["sod"], image_sizes):
            # image_size is (width, height); F.interpolate expects (H, W).
            output = F.interpolate(
                output.unsqueeze(0),
                size=(image_size[1], image_size[0]),
                mode="bilinear",
                align_corners=False,
            )
            # Pick the highest-scoring class per pixel.
            res: np.ndarray = torch.argmax(output, dim=1).cpu().numpy().squeeze()
            res_lst.append([res])
        return res_lst

    def __inference_v3v5(
        self, outputs: Dict[str, Tensor], image_sizes: List[Tuple[int, int]]
    ) -> List[List[np.ndarray]]:
        """Per-class decoding: one sigmoid-normalized uint8 map per class."""
        res_lst: List[List[np.ndarray]] = []
        for bi, image_size in enumerate(image_sizes):
            res_lst_per_sample: List[np.ndarray] = []
            for i in range(self.cfg.num_classes):
                pred = outputs[f"sod{i}"][bi]
                # image_size is (width, height); F.interpolate expects (H, W).
                pred = F.interpolate(
                    pred.unsqueeze(0),
                    size=(image_size[1], image_size[0]),
                    mode="bilinear",
                    align_corners=False,
                )
                res: np.ndarray = pred.sigmoid().data.cpu().numpy().squeeze()
                # Min-max normalize to [0, 1]; the epsilon guards against a flat map.
                res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                if self.cfg.is_fp16:
                    # Promote half-precision results to float32 before quantizing.
                    res = np.float32(res)
                res_lst_per_sample.append((res * 255).astype(np.uint8))
            res_lst.append(res_lst_per_sample)
        return res_lst

    @torch.no_grad()
    def inference(
        self,
        image_sizes: List[Tuple[int, int]],
        images: Tensor,
        depths: Optional[Tensor],
        max_gts: Optional[List[int]],
    ) -> List[List[np.ndarray]]:
        """Run per-class inference and decode predictions to numpy masks."""
        self.model.eval()
        assert len(image_sizes) == len(
            images
        ), "The number of image_sizes must equal the number of images"
        gpu_images: Tensor = images.to(self.device)
        # depths is Optional; only move it to the device when provided.
        gpu_depths: Optional[Tensor] = (
            depths.to(self.device) if depths is not None else None
        )
        if self.cfg.ground_truth_version == 6:
            with torch.cuda.amp.autocast(enabled=self.cfg.is_fp16):
                # Query the model once per class and collect the per-class
                # saliency maps under keys "sod0", "sod1", ...
                outputs: Dict[str, Tensor] = dict()
                for i in range(self.cfg.num_classes):
                    outputs[f"sod{i}"] = self.model.inference(
                        gpu_images, gpu_depths, [i] * gpu_images.shape[0], max_gts
                    )["sod"]
            return self.__inference_v3v5(outputs, image_sizes)
        else:
            raise ValueError(
                f"Unsupported ground_truth_version {self.cfg.ground_truth_version}"
            )
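

# Usage sketch (not part of the original module). Everything below is an
# illustrative assumption: the concrete `base_cfg` fields, tensor shapes,
# and batch contents are hypothetical, not taken from this repo.
#
#   cfg = base_cfg(...)                    # populated config, ground_truth_version=6
#   model = ModelPL(cfg).to("cuda")
#   images = torch.rand(2, 3, 352, 352)    # RGB batch, NCHW
#   depths = torch.rand(2, 1, 352, 352)    # matching depth maps, or None
#   sizes = [(640, 480), (800, 600)]       # original (width, height) per image
#   masks = model.inference(sizes, images, depths, max_gts=None)
#   # masks[b][c] is a uint8 saliency map of shape (height, width) for
#   # sample b and class c.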