fix: add kwargs to model to ignore additional arguments
Browse files
model.py
CHANGED
@@ -31,6 +31,6 @@ class CSDModel(PreTrainedModel):
|
|
31 |
)
|
32 |
|
33 |
@torch.inference_mode()
|
34 |
-
def forward(self, pixel_values: torch.Tensor) -> CSDOutput:
|
35 |
image_embeds, style_embeds, content_embeds = self.model(pixel_values)
|
36 |
return CSDOutput(image_embeds=image_embeds, style_embeds=style_embeds, content_embeds=content_embeds)
|
|
|
31 |
)
|
32 |
|
33 |
@torch.inference_mode()
|
34 |
+
def forward(self, pixel_values: torch.Tensor, **kwargs) -> CSDOutput:
|
35 |
image_embeds, style_embeds, content_embeds = self.model(pixel_values)
|
36 |
return CSDOutput(image_embeds=image_embeds, style_embeds=style_embeds, content_embeds=content_embeds)
|