Spaces:
Runtime error
Runtime error
File size: 1,230 Bytes
1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from typing import List
import torch
from isegm.inference.clicker import Click
from .base import BaseTransform
class AddHorizontalFlip(BaseTransform):
    """Test-time augmentation that adds a horizontally mirrored copy of the input.

    ``transform`` doubles the batch: the original images come first and their
    mirrored copies second. Every click's horizontal coordinate is mirrored so
    it matches the second half. ``inv_transform`` reverses the mirroring on the
    second half of the predictions and averages it with the first half. The
    transform keeps no state between calls.
    """

    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        """Append mirrored images and mirrored clicks to the batch.

        Args:
            image_nd: 4-D tensor. Axis 3 is treated as the width axis: it is
                the axis that gets flipped and whose size is used as the width.
            clicks_lists: one list of clicks per batch entry. Assumed to be a
                ``list``, since it is extended with ``+`` below.

        Returns:
            A tuple ``(images, clicks)``:
            - ``images`` is the original batch followed by its mirror,
              concatenated on dim 0.
            - ``clicks`` is the original click lists followed by their mirrored
              counterparts.
        """
        assert len(image_nd.shape) == 4
        mirrored = torch.flip(image_nd, dims=[3])
        image_nd = torch.cat([image_nd, mirrored], dim=0)

        # Concatenating on dim 0 leaves axis 3 untouched, so this is still the
        # width of every image in the batch.
        width = image_nd.shape[3]

        # coords[0] is kept and coords[1] is reflected about the width, so
        # index x maps to (width - 1 - x). Click.copy presumably returns a copy
        # with only `coords` replaced — TODO confirm against isegm's Click.
        mirrored_lists = [
            [
                click.copy(coords=(click.coords[0], width - click.coords[1] - 1))
                for click in clicks
            ]
            for clicks in clicks_lists
        ]
        return image_nd, clicks_lists + mirrored_lists

    def inv_transform(self, prob_map):
        """Merge predictions for the original and mirrored halves of the batch.

        Args:
            prob_map: 4-D tensor whose dim-0 size is even. The first half
                corresponds to the original images and the second half to the
                mirrored ones.

        Returns:
            The element-wise mean of the first half and the un-mirrored second
            half. It has half as many entries on dim 0 as ``prob_map``.
        """
        assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0
        half = prob_map.shape[0] // 2
        direct = prob_map[:half]
        # Flip axis 3 back so the mirrored predictions align with the originals.
        restored = torch.flip(prob_map[half:], dims=[3])
        return 0.5 * (direct + restored)

    def get_state(self):
        """Return ``None``: this transform has no state to save."""
        return None

    def set_state(self, state):
        """Ignore ``state``: there is nothing to restore."""
        pass

    def reset(self):
        """No-op: there is nothing to reset."""
        pass
|