|
import torch |
|
|
|
|
|
def fasterrcnn_reshape_transform(x, target_key='pool'):
    """Merge a dict of feature maps into one tensor at a common resolution.

    Every value in ``x`` (in the dict's iteration order) is passed through
    ``torch.abs``. It is then bilinearly resized to the spatial size (last two
    dims) of ``x[target_key]``. The results are concatenated along dim 1.

    Args:
        x: Mapping of name -> feature tensor. Bilinear ``interpolate``
            requires 4-D input, so presumably each value is (N, C, H, W),
            e.g. the output of an FPN backbone -- confirm against callers.
        target_key: Key of the entry whose spatial size every map is resized
            to. Defaults to ``'pool'``.

    Returns:
        A tensor whose dim 1 is the sum of all inputs' dim-1 sizes and whose
        last two dims equal those of ``x[target_key]``.

    Raises:
        KeyError: If ``target_key`` is not present in ``x``.
    """
    target_size = x[target_key].size()[-2:]
    # NOTE(review): abs() discards the sign of every activation before
    # resizing -- presumably intentional for CAM-style use; confirm.
    activations = [
        torch.nn.functional.interpolate(
            torch.abs(value),
            target_size,
            mode='bilinear',
            # Same as the default (None -> False); stated explicitly.
            align_corners=False,
        )
        for value in x.values()
    ]
    return torch.cat(activations, dim=1)
|
|
|
|
|
def swinT_reshape_transform(tensor, height=7, width=7):
    """Turn a token sequence into a channels-first spatial grid.

    ``tensor`` is reshaped to ``(tensor.size(0), height, width,
    tensor.size(2))`` and its last axis is moved to position 1. The result
    has shape ``(batch, channels, height, width)``.

    Presumably ``tensor`` is (batch, height * width, channels) -- the reshape
    fails unless the element count matches. Confirm against callers.
    """
    batch_size = tensor.size(0)
    channels = tensor.size(2)
    grid = tensor.reshape(batch_size, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W): a view, no data copy.
    return grid.permute(0, 3, 1, 2)
|
|
|
|
|
def vit_reshape_transform(tensor, height=14, width=14):
    """Drop the first token and lay the remaining tokens out as a grid.

    The first entry along dim 1 is discarded. The rest are reshaped to
    ``(tensor.size(0), height, width, tensor.size(2))``, and the last axis is
    then moved to position 1, giving ``(batch, channels, height, width)``.

    Presumably the dropped entry is the ViT [CLS] token and ``tensor`` is
    (batch, 1 + height * width, channels) -- confirm against callers.
    """
    patch_tokens = tensor[:, 1:, :]
    batch_size = tensor.size(0)
    channels = tensor.size(2)
    grid = patch_tokens.reshape(batch_size, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W): a view, no data copy.
    return grid.permute(0, 3, 1, 2)
|
|