Upload processor
Browse files- processor.py +4 -2
processor.py
CHANGED
@@ -71,7 +71,7 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
71 |
BatchFeature
|
72 |
pixel_values : torch.Tensor
|
73 |
Processed image tensor (B C H W)
|
74 |
-
|
75 |
Categories indices (B)
|
76 |
"""
|
77 |
use_cats = categories is not None
|
@@ -86,6 +86,8 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
86 |
data["pixel_values"] = torch.stack([self.process_img(img) for img in images])
|
87 |
|
88 |
if use_cats:
|
89 |
-
data["
|
|
|
|
|
90 |
|
91 |
return BatchFeature(data=data)
|
|
|
71 |
BatchFeature
|
72 |
pixel_values : torch.Tensor
|
73 |
Processed image tensor (B C H W)
|
74 |
+
category_indices : torch.Tensor
|
75 |
Categories indices (B)
|
76 |
"""
|
77 |
use_cats = categories is not None
|
|
|
86 |
data["pixel_values"] = torch.stack([self.process_img(img) for img in images])
|
87 |
|
88 |
if use_cats:
|
89 |
+
data["category_indices"] = torch.stack(
|
90 |
+
[self.process_cat(c) for c in categories]
|
91 |
+
)
|
92 |
|
93 |
return BatchFeature(data=data)
|