Update modelling_magiv2.py
Browse files- modelling_magiv2.py +2 -0
modelling_magiv2.py
CHANGED
@@ -103,6 +103,8 @@ class Magiv2Model(PreTrainedModel):
|
|
103 |
|
104 |
|
105 |
def assign_names_to_characters(self, images, character_bboxes, character_bank, character_clusters, eta=0.75):
|
|
|
|
|
106 |
chapter_wide_char_embeddings = self.predict_crop_embeddings(images, character_bboxes)
|
107 |
chapter_wide_char_embeddings = torch.cat(chapter_wide_char_embeddings, dim=0)
|
108 |
chapter_wide_char_embeddings = torch.nn.functional.normalize(chapter_wide_char_embeddings, p=2, dim=1).cpu().numpy()
|
|
|
103 |
|
104 |
|
105 |
def assign_names_to_characters(self, images, character_bboxes, character_bank, character_clusters, eta=0.75):
|
106 |
+
if len(character_bank["images"]) == 0:
|
107 |
+
return ["Other" for bboxes_for_image in character_bboxes for bbox in bboxes_for_image]
|
108 |
chapter_wide_char_embeddings = self.predict_crop_embeddings(images, character_bboxes)
|
109 |
chapter_wide_char_embeddings = torch.cat(chapter_wide_char_embeddings, dim=0)
|
110 |
chapter_wide_char_embeddings = torch.nn.functional.normalize(chapter_wide_char_embeddings, p=2, dim=1).cpu().numpy()
|