jbilcke-hf's picture
jbilcke-hf HF Staff
upgrading finetrainers (and losing my extra code + improvements)
80ebcb3
raw
history blame
681 Bytes
from typing import List, Union
import torch
from .. import functional as FF
from .base import ProcessorMixin
class CaptionTextDropoutProcessor(ProcessorMixin):
    """Processor that applies caption dropout to raw text captions.

    The dropout logic itself lives in ``FF.dropout_caption``. This class only
    remembers the probability and hands it to that helper on every call.
    """

    def __init__(self, dropout_p: float = 0.0) -> None:
        # Probability forwarded to FF.dropout_caption.
        # NOTE(review): no range check here; presumably it is expected to lie
        # in [0, 1]. Confirm against FF.dropout_caption.
        self.dropout_p = dropout_p

    def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]:
        """Run caption dropout on *caption*.

        Args:
            caption: A single caption string or a list of caption strings.

        Returns:
            The result of ``FF.dropout_caption(caption, self.dropout_p)``,
            annotated as a string or a list of strings.
        """
        probability = self.dropout_p
        dropped = FF.dropout_caption(caption, probability)
        return dropped
class CaptionEmbeddingDropoutProcessor(ProcessorMixin):
    """Processor that applies dropout to caption embedding tensors.

    Work is delegated to ``FF.dropout_embeddings_to_zero``. Judging by that
    helper's name, dropped embeddings are presumably replaced with zeros
    (TODO confirm in FF).
    """

    def __init__(self, dropout_p: float = 0.0) -> None:
        # Probability forwarded to FF.dropout_embeddings_to_zero.
        # NOTE(review): not validated here; presumably expected in [0, 1].
        self.dropout_p = dropout_p

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """Run embedding dropout on *embedding*.

        Args:
            embedding: Caption embedding tensor. Its expected shape is not
                visible here; presumably (batch, ...). Confirm against callers.

        Returns:
            The tensor returned by
            ``FF.dropout_embeddings_to_zero(embedding, self.dropout_p)``.
        """
        probability = self.dropout_p
        return FF.dropout_embeddings_to_zero(embedding, probability)