gaur3009 commited on
Commit
97b4c52
·
verified ·
1 Parent(s): 4da5a6a

Update preprocessing.py

Browse files
Files changed (1) hide show
  1. preprocessing.py +2 -2
preprocessing.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
- import numpy as np
3
 
4
  def pad_to_22_channels(input_tensor):
5
  if input_tensor.shape[1] == 3: # RGB input
6
- return torch.cat([input_tensor]*7 + [input_tensor[:,0:1]], dim=1)
 
7
  return input_tensor
 
1
  import torch
 
2
 
3
  def pad_to_22_channels(input_tensor):
4
  if input_tensor.shape[1] == 3: # RGB input
5
+ # Repeat channels to make 22 channels
6
+ return torch.cat([input_tensor] * 7 + [input_tensor[:, 0:1]], dim=1)
7
  return input_tensor