|
import pytest |
|
import torch |
|
|
|
from mmseg.core import OHEMPixelSampler |
|
from mmseg.models.decode_heads import FCNHead |
|
|
|
|
|
def _context_for_ohem():
    """Return the FCNHead instance the sampler tests pass as ``context``.

    The head is constructed with ``in_channels=32``, ``channels=16`` and
    ``num_classes=19``.
    """
    head_kwargs = dict(in_channels=32, channels=16, num_classes=19)
    return FCNHead(**head_kwargs)
|
|
|
|
|
def test_ohem_sampler():
    """Exercise ``OHEMPixelSampler.sample`` in three scenarios.

    1. A 45x45 logit map paired with an 89x89 label map must raise
       ``AssertionError``.
    2. With ``thresh=0.7`` and ``min_kept=200`` the returned weight matches
       the logits' first dimension and their dimensions after the second,
       and sums to more than 200.
    3. With only ``min_kept=200`` the returned weight has the same shape
       relationship and sums to exactly 200.
    """
    # Build the sampler and the inputs outside ``pytest.raises`` so that only
    # an AssertionError raised by ``sample`` itself can satisfy the check; an
    # assertion failing during setup now fails the test instead of silently
    # passing it.
    sampler = OHEMPixelSampler(context=_context_for_ohem())
    seg_logit = torch.randn(1, 19, 45, 45)
    seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
    with pytest.raises(AssertionError):
        sampler.sample(seg_logit, seg_label)

    # Scenario 2: thresh=0.7 together with min_kept=200.
    sampler = OHEMPixelSampler(
        context=_context_for_ohem(), thresh=0.7, min_kept=200)
    seg_logit = torch.randn(1, 19, 45, 45)
    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
    seg_weight = sampler.sample(seg_logit, seg_label)
    assert seg_weight.shape[0] == seg_logit.shape[0]
    assert seg_weight.shape[1:] == seg_logit.shape[2:]
    assert seg_weight.sum() > 200

    # Scenario 3: no thresh, only min_kept=200.
    sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
    seg_logit = torch.randn(1, 19, 45, 45)
    seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
    seg_weight = sampler.sample(seg_logit, seg_label)
    assert seg_weight.shape[0] == seg_logit.shape[0]
    assert seg_weight.shape[1:] == seg_logit.shape[2:]
    assert seg_weight.sum() == 200
|
|