Update processor __call__ for single image/category
Browse files- processor.py +7 -3
processor.py
CHANGED
@@ -76,12 +76,15 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
76 |
"""
|
77 |
use_cats = categories is not None
|
78 |
|
79 |
-
# Single Image + Single category
|
80 |
if isinstance(images, Image.Image):
|
81 |
-
|
|
|
82 |
if use_cats:
|
83 |
-
|
|
|
84 |
|
|
|
85 |
data = {}
|
86 |
data["pixel_values"] = torch.stack([self.process_img(img) for img in images])
|
87 |
|
@@ -91,3 +94,4 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
91 |
)
|
92 |
|
93 |
return BatchFeature(data=data)
|
|
|
|
76 |
"""
|
77 |
use_cats = categories is not None
|
78 |
|
79 |
+
# Single Image (+ Single category)
|
80 |
if isinstance(images, Image.Image):
|
81 |
+
data = {}
|
82 |
+
data["pixel_values"] = self.process_img(images)
|
83 |
if use_cats:
|
84 |
+
data["category_indices"] = self.process_cat(categories)
|
85 |
+
return BatchFeature(data=data)
|
86 |
|
87 |
+
# Multiple Images (+ Multiple Categories)
|
88 |
data = {}
|
89 |
data["pixel_values"] = torch.stack([self.process_img(img) for img in images])
|
90 |
|
|
|
94 |
)
|
95 |
|
96 |
return BatchFeature(data=data)
|
97 |
+
|