gaur3009 commited on
Commit
580136d
·
verified ·
1 Parent(s): 1edd3bd

Create preprocessing.py

Browse files
Files changed (1) hide show
  1. preprocessing.py +10 -0
preprocessing.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ def pad_to_22_channels(input_tensor):
5
+ if input_tensor.shape[1] == 3: # RGB input
6
+ return torch.cat([input_tensor]*7 + [input_tensor[:,0:1]], dim=1)
7
+ return input_tensor
8
+
9
+ # Modify where you prepare inputs (usually near gmm() call)
10
+ inputA = pad_to_22_channels(your_rgb_input) # Before passing to GMM