juanpablomesa commited on
Commit
073334d
·
1 Parent(s): 0f7ecda

Added normalization using torch and shape check

Browse files
Files changed (1) hide show
  1. handler.py +9 -3
handler.py CHANGED
@@ -118,9 +118,15 @@ class EndpointHandler:
118
  self.logger.info("Squeezing tensor")
119
  batch_emb = frame_embedding.squeeze(0)
120
 
121
- # Normalize the embeddings
122
- self.logger.info("Normalizing embeddings")
123
- batch_emb = torch.nn.functional.normalize(batch_emb, p=2, dim=1)
 
 
 
 
 
 
124
 
125
  self.logger.info("Converting into numpy array")
126
  batch_emb = batch_emb.cpu().detach().numpy()
 
118
  self.logger.info("Squeezing tensor")
119
  batch_emb = frame_embedding.squeeze(0)
120
 
121
+ # Check the shape of the tensor
122
+ self.logger.info(f"Shape of the batch_emb tensor: {batch_emb.shape}")
123
+
124
+ # Normalize the embeddings if it's a 2D tensor
125
+ if batch_emb.dim() == 2:
126
+ self.logger.info("Normalizing embeddings")
127
+ batch_emb = torch.nn.functional.normalize(batch_emb, p=2, dim=1)
128
+ else:
129
+ self.logger.info("Skipping normalization due to tensor shape")
130
 
131
  self.logger.info("Converting into numpy array")
132
  batch_emb = batch_emb.cpu().detach().numpy()