"""Swap selected channels between the batches of a (batch, channel, dim) tensor.

Problem statement (from the original notes)::

    a = np.arange(12).reshape(2, 3, 2)  # (batch, channel, dim)
    # a == [[[0, 1], [2, 3], [4, 5]],
    #       [[6, 7], [8, 9], [10, 11]]]
    # Exchanging channels 1 and 2 of batch 0 with channels 1 and 2 of batch 1
    # should give:
    #      [[[0, 1], [8, 9], [10, 11]],
    #       [[6, 7], [2, 3], [4, 5]]]

A right-multiplication ``a @ swap_mat`` cannot do this, because it only mixes
the last (dim) axis. The matrix approach below instead flattens
(batch, channel) into rows and left-multiplies by a permutation matrix.
"""
import numpy as np
import torch

try:
    # Kept importable for any code that relies on it; this module now uses
    # plain reshapes, so einops is no longer required.
    from einops import rearrange  # noqa: F401
except ImportError:
    rearrange = None


def shift(arr, num, fill_value=np.nan):
    """Return a copy of ``arr`` shifted by ``num`` positions along axis 0.

    Positive ``num`` shifts towards higher indices; negative towards lower.
    Vacated positions are set to ``fill_value``. The result has the same dtype
    as ``arr``, so the default NaN fill needs a floating-point array.
    """
    result = np.empty_like(arr)
    if num > 0:
        result[:num] = fill_value
        result[num:] = arr[:-num]
    elif num < 0:
        result[num:] = fill_value
        result[:num] = arr[-num:]
    else:
        result[:] = arr
    return result


def swap_channels_by_indexing(a_tensor, swap_channels):
    """Exchange the given channels between batch 0 and batch 1 via indexing.

    Args:
        a_tensor: tensor indexed as (batch, channel, ...); needs >= 2 batches.
        swap_channels: iterable of channel indices to exchange.

    Returns:
        A new tensor; ``a_tensor`` itself is not modified.
    """
    idx = torch.as_tensor(list(swap_channels), dtype=torch.long, device=a_tensor.device)
    result = a_tensor.clone()
    # Read from the untouched input so the two assignments cannot see each
    # other's writes.
    result[0, idx] = a_tensor[1, idx]
    result[1, idx] = a_tensor[0, idx]
    return result


def create_batch_swap_matrix(batch_size, channels, swap_channels):
    """Build a permutation matrix that rotates selected channels across batches.

    Rows and columns index the flattened (batch, channel) axis: row
    ``b * channels + c`` is channel ``c`` of batch ``b``. Left-multiplying the
    flattened tensor by this matrix makes channel ``c`` of batch ``b`` take the
    values of channel ``c`` of batch ``(b + 1) % batch_size``, for every ``c``
    in ``swap_channels``. All other rows are unchanged. With ``batch_size == 2``
    this is exactly a swap of those channels between the two batches.

    Raises:
        ValueError: if a channel index is outside ``[0, channels)``.
    """
    n = batch_size * channels
    swap_mat = np.eye(n)
    for c in swap_channels:
        if not 0 <= c < channels:
            raise ValueError(f"channel index {c} out of range for {channels} channels")
        rows = np.arange(c, n, channels)  # channel c in every batch
        cols = (rows + channels) % n      # same channel, next batch (wraps around)
        swap_mat[rows, rows] = 0
        # Only one 1 per row/column keeps this a permutation for any batch_size;
        # also setting swap_mat[cols, rows] would give rows two 1s when batch_size > 2.
        swap_mat[rows, cols] = 1
    return swap_mat


def swap_channels_between_batches(input_tensor, swap_matrix):
    """Apply a (batch*channel) x (batch*channel) swap matrix to ``input_tensor``.

    Args:
        input_tensor: 3-D tensor/array indexed as (batch, channel, dim).
        swap_matrix: square matrix, e.g. from ``create_batch_swap_matrix``.
            When ``input_tensor`` is a torch tensor, a numpy array is accepted
            and the matrix is cast to the input's dtype and device.

    Returns:
        Tensor of the same shape with rows permuted by ``swap_matrix``.
    """
    batch, chans, dim = input_tensor.shape
    flat = input_tensor.reshape(batch * chans, dim)
    if isinstance(flat, torch.Tensor):
        swap_matrix = torch.as_tensor(swap_matrix, dtype=flat.dtype, device=flat.device)
    swapped = swap_matrix @ flat
    return swapped.reshape(batch, chans, dim)


def _demo():
    """Run both implementations on the example from the module docstring."""
    a_tensor = torch.tensor(
        [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]],
        dtype=torch.float32,
    )
    swap_channels = (1, 2)  # channels to exchange between the batches

    # Approach 1: direct indexing.
    swapped_tensor = swap_channels_by_indexing(a_tensor, swap_channels)
    print("Original Tensor 'a_tensor':")
    print(a_tensor)
    print("\nTensor after swapping channels between batches:")
    print(swapped_tensor)

    # Approach 2: permutation matrix over the flattened (batch, channel) axis.
    batch_size, channels = a_tensor.shape[:2]
    swap_matrix = torch.Tensor(create_batch_swap_matrix(batch_size, channels, swap_channels))
    swapped_tensor = swap_channels_between_batches(a_tensor, swap_matrix)
    print("Original Tensor:")
    print(a_tensor)
    print("\nSwapped Tensor:")
    print(swapped_tensor)


if __name__ == "__main__":
    _demo()