|
import numpy |
|
|
|
|
|
class ChannelSelector(object):
    """Select a single channel from a multi-channel signal.

    The channel axis is removed from the result. During training the
    channel given by ``train_channel`` is used, otherwise the one given by
    ``eval_channel``. Either may be an index or the string ``"random"``,
    in which case a channel index is drawn uniformly with
    ``numpy.random.randint`` on every call.

    Args:
        train_channel: Channel index, or ``"random"``, used when
            ``train=True``.
        eval_channel: Channel index, or ``"random"``, used when
            ``train=False``.
        axis: Position of the channel axis. Negative values count from
            the last axis of the input.
    """

    def __init__(self, train_channel="random", eval_channel=0, axis=1):
        self.train_channel = train_channel
        self.eval_channel = eval_channel
        self.axis = axis

    def __repr__(self):
        return (
            "{name}(train_channel={train_channel}, "
            "eval_channel={eval_channel}, axis={axis})".format(
                name=self.__class__.__name__,
                train_channel=self.train_channel,
                eval_channel=self.eval_channel,
                axis=self.axis,
            )
        )

    def __call__(self, x, train=True):
        """Return ``x`` with one channel picked along the channel axis.

        Args:
            x: Array-like object exposing ``ndim``, ``shape`` and
                tuple-based indexing (e.g. ``numpy.ndarray``).
            train: Use ``train_channel`` if true, else ``eval_channel``.

        Returns:
            ``x`` indexed at the chosen channel, i.e. with the channel
            axis dropped.

        Raises:
            ValueError: If a negative ``axis`` is out of range for ``x``.
        """
        axis = self.axis
        if axis < 0:
            # Without normalisation a negative axis never matches any
            # position in the index tuple and x would be returned as-is.
            axis += x.ndim
            if axis < 0:
                raise ValueError(
                    "axis {} is out of bounds for array of dimension {}".format(
                        self.axis, x.ndim
                    )
                )

        if x.ndim <= axis:
            # Too few dimensions: append length-1 axes until the channel
            # axis exists (e.g. (T,) -> (T, 1) for axis=1).
            missing = axis + 1 - x.ndim
            x = x[(slice(None),) * x.ndim + (None,) * missing]

        channel = self.train_channel if train else self.eval_channel

        if channel == "random":
            ch = numpy.random.randint(0, x.shape[axis])
        else:
            ch = channel

        # Full slices up to the channel axis; trailing axes are implicit.
        return x[(slice(None),) * axis + (ch,)]
|
|