juanpablomesa
commited on
Commit
·
073334d
1
Parent(s):
0f7ecda
Added normalization using torch and shape check
Browse files- handler.py +9 -3
handler.py
CHANGED
@@ -118,9 +118,15 @@ class EndpointHandler:
|
|
118 |
self.logger.info("Squeezing tensor")
|
119 |
batch_emb = frame_embedding.squeeze(0)
|
120 |
|
121 |
-
#
|
122 |
-
self.logger.info("
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
self.logger.info("Converting into numpy array")
|
126 |
batch_emb = batch_emb.cpu().detach().numpy()
|
|
|
118 |
self.logger.info("Squeezing tensor")
|
119 |
batch_emb = frame_embedding.squeeze(0)
|
120 |
|
121 |
+
# Check the shape of the tensor
|
122 |
+
self.logger.info(f"Shape of the batch_emb tensor: {batch_emb.shape}")
|
123 |
+
|
124 |
+
# Normalize the embeddings if it's a 2D tensor
|
125 |
+
if batch_emb.dim() == 2:
|
126 |
+
self.logger.info("Normalizing embeddings")
|
127 |
+
batch_emb = torch.nn.functional.normalize(batch_emb, p=2, dim=1)
|
128 |
+
else:
|
129 |
+
self.logger.info("Skipping normalization due to tensor shape")
|
130 |
|
131 |
self.logger.info("Converting into numpy array")
|
132 |
batch_emb = batch_emb.cpu().detach().numpy()
|