# Scraped page header (not code): "File size: 3,823 Bytes", revision "a03c9b4".
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (Line-number gutter 1-123 left over from the page scrape removed.)
import numpy as np

# Goal: for a tensor shaped (batch, channel, dim), exchange selected channels
# between batch 0 and batch 1.
a = np.arange(12).reshape(2, 3, 2)  # (batch, channel, dim)
print(a)
# [[[ 0  1]  [ 2  3]  [ 4  5]]
#  [[ 6  7]  [ 8  9]  [10 11]]]

# Originally sketched API (not defined anywhere in this snippet):
#     swap_mat = create_swap_channel_mat(input_shape, swap_channel=(1, 2))
#     b = a @ swap_mat
# NOTE(review): right-multiplying by a matrix only mixes values along the last
# (dim) axis, so `a @ swap_mat` cannot move rows between batches. The swap has
# to act on the flattened (batch * channel) axis instead, i.e.
# swap_mat @ a.reshape(batch * channel, dim), reshaped back afterwards.

# Expected result of swapping channels 1 and 2 of batch 0 with channels 1 and
# 2 of batch 1.
expected = np.array([[[0, 1], [8, 9], [10, 11]], [[6, 7], [2, 3], [4, 5]]])

import torch


def swap_channels_between_batches(a_tensor, swap_channels):
    """Return a copy of ``a_tensor`` with channels exchanged between batches 0 and 1.

    Args:
        a_tensor: tensor indexed as ``[batch, channel, ...]``; it must have at
            least two entries along the batch axis.
        swap_channels: iterable of channel indices (any number, negative
            indices allowed) whose contents are exchanged between batch 0
            and batch 1.

    Returns:
        A new tensor. ``a_tensor`` itself is not modified, and batches beyond
        index 1 are copied unchanged.
    """
    # Clone so the caller's tensor is left untouched.
    result_tensor = a_tensor.clone()

    channels = list(swap_channels)
    if not channels:
        return result_tensor

    # Values are read from the untouched input, and advanced indexing on the
    # right-hand side already yields fresh tensors, so no extra clones are needed.
    result_tensor[0, channels] = a_tensor[1, channels]
    result_tensor[1, channels] = a_tensor[0, channels]

    return result_tensor


# Demo: a 2x3x2 float tensor holding 0..11, laid out as (batch, channel, dim).
a_tensor = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
swap_channels = (1, 2)  # channel indices exchanged between batch 0 and batch 1

swapped_tensor = swap_channels_between_batches(a_tensor, swap_channels)

# Show the input next to the result of the exchange.
print("Original Tensor 'a_tensor':", a_tensor, sep="\n")
print("\nTensor after swapping channels between batches:", swapped_tensor, sep="\n")

#-------------------------------------------------

import torch
from einops import rearrange


def shift(arr, num, fill_value=np.nan):
    """Shift ``arr`` along its first axis by ``num`` positions.

    A positive ``num`` moves elements toward higher indices and fills the
    first ``num`` slots with ``fill_value``. A negative ``num`` moves them
    toward lower indices and fills the last ``-num`` slots. ``num == 0``
    returns a copy.

    Args:
        arr: array-like to shift along axis 0.
        num: number of positions to shift (sign gives the direction).
        fill_value: value written into the vacated slots.

    Returns:
        A new array. When slots are filled, its dtype is the promotion of
        ``arr``'s dtype and ``fill_value``. For example, an integer array
        shifted with the default NaN fill comes back as float instead of
        raising.
    """
    # np.empty_like(int_array) cannot store NaN (raises ValueError), so widen
    # the dtype whenever fill_value will actually be written.
    dtype = np.result_type(np.asarray(arr), fill_value) if num != 0 else None
    result = np.empty_like(arr, dtype=dtype)
    if num > 0:
        result[:num] = fill_value
        result[num:] = arr[:-num]
    elif num < 0:
        result[num:] = fill_value
        result[:num] = arr[-num:]
    else:
        result[:] = arr
    return result


def create_batch_swap_matrix(batch_size, channels, swap_channels):
    """Return a permutation matrix exchanging channels between batch 0 and batch 1.

    The matrix is ``batch_size * channels`` square and indexes the flattened
    ``(batch, channel)`` axis as ``batch * channels + channel``. For every
    ``c`` in ``swap_channels``, slot ``c`` (batch 0) and slot ``c + channels``
    (batch 1) are exchanged. ``batch_size`` only sets the matrix size; other
    batches are left as identity.

    NOTE(review): a second ``create_batch_swap_matrix`` is defined right after
    this one, so this version is shadowed once the module finishes loading.
    """
    n = batch_size * channels
    perm = np.eye(n)
    for ch in swap_channels:
        pair = [ch, ch + channels]  # batch-0 slot, batch-1 slot
        perm[pair, pair] = 0  # clear both diagonal entries
        perm[pair, pair[::-1]] = 1  # route each slot to the other one
    return perm


def create_batch_swap_matrix(batch_size, channels, swap_channels):
    """Build a permutation matrix that rotates selected channels across batches.

    Rows and columns index the flattened ``(batch, channel)`` axis at position
    ``b * channels + c``. For each channel ``c`` in ``swap_channels``, batch
    ``b`` receives channel ``c`` of batch ``(b + 1) % batch_size``. With
    ``batch_size == 2`` this is a plain exchange between the two batches; with
    ``batch_size == 1`` it is the identity. Left-multiplying a
    ``(batch_size * channels, dim)`` matrix applies the swap.

    Args:
        batch_size: number of batches.
        channels: number of channels per batch.
        swap_channels: channel indices to rotate, each in
            ``[-channels, channels)`` (negative indices count from the end).
            Repeated indices have no additional effect.

    Returns:
        ``np.ndarray`` of shape ``(batch_size * channels, batch_size * channels)``.

    Raises:
        ValueError: if a channel index is out of range.
    """
    size = batch_size * channels
    swap_mat = np.eye(size)

    for c in swap_channels:
        if not -channels <= c < channels:
            raise ValueError(f"channel index {c} out of range for {channels} channels")
        c %= channels  # accept Python-style negative indices
        rows = np.arange(c, size, channels)  # flat index of channel c in every batch
        cols = (rows + channels) % size  # same channel in the next batch (wraps around)

        # Only the forward link is set: also setting swap_mat[cols, rows] would
        # put two 1s in a row when batch_size >= 3 (mixing channels, not moving them).
        swap_mat[rows, rows] = 0
        swap_mat[rows, cols] = 1

    return swap_mat


def swap_channels_between_batches(input_tensor, swap_matrix):
    """Apply a flattened ``(batch * channel)`` permutation matrix to a 3-D tensor.

    Args:
        input_tensor: tensor of shape ``(batch, channel, dim)``.
        swap_matrix: matrix with ``batch * channel`` columns acting on the
            flattened ``(batch, channel)`` axis, for example one built by
            ``create_batch_swap_matrix``. It may be a torch tensor or a NumPy
            array; it is converted to the input's dtype and device.

    Returns:
        ``swap_matrix @ input_tensor.reshape(batch * channel, dim)``, reshaped
        back to ``(batch, -1, dim)``.
    """
    batch, channel, dim = input_tensor.shape
    # torch.matmul does not promote dtypes, so a float32 matrix against a
    # float64 input (or a raw NumPy matrix) would otherwise raise.
    swap_matrix = torch.as_tensor(swap_matrix, dtype=input_tensor.dtype, device=input_tensor.device)
    flat = input_tensor.reshape(batch * channel, dim)  # 'b c d -> (b c) d'
    swapped = swap_matrix @ flat
    return swapped.reshape(batch, -1, dim)  # '(b c) d -> b c d'


# 예제 νŒŒλΌλ―Έν„°
batch_size = 2
channels = 3
# swap_info  = {
#     : [1, 2] # batch_index: [channel_indices]
# }
swap_channels = [1, 2]  # κ΅ν™˜ν•  채널

# 예제 ν…μ„œ 생성
input_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)

# swap matrix 생성
swap_matrix = create_batch_swap_matrix(batch_size, channels, swap_channels)
swap_matrix = torch.Tensor(swap_matrix)

# 채널 κ΅ν™˜ μˆ˜ν–‰
swapped_tensor = swap_channels_between_batches(input_tensor, swap_matrix)

# κ²°κ³Ό 좜λ ₯
print("Original Tensor:")
print(input_tensor)
print("\nSwapped Tensor:")
print(swapped_tensor)