Spaces:
Running
Running
File size: 681 Bytes
80ebcb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from typing import List, Union
import torch
from .. import functional as FF
from .base import ProcessorMixin
class CaptionTextDropoutProcessor(ProcessorMixin):
    """Processor that routes captions through ``FF.dropout_caption``.

    The dropout probability is fixed at construction time and handed to the
    functional helper, unchanged, on every ``forward`` call.
    """

    # NOTE(review): ``ProcessorMixin.__init__`` is not invoked here. Confirm
    # against the base class whether it sets up state this processor relies on.
    def __init__(self, dropout_p: float = 0.0) -> None:
        """Remember ``dropout_p`` so that ``forward`` can pass it along.

        Args:
            dropout_p: Value forwarded as the second positional argument to
                ``FF.dropout_caption``. Presumably a probability in
                ``[0, 1]`` -- TODO confirm against the helper.
        """
        self.dropout_p = dropout_p

    def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]:
        """Apply ``FF.dropout_caption`` to ``caption``.

        Args:
            caption: A single caption or a list of captions.

        Returns:
            Whatever ``FF.dropout_caption`` returns for the given caption and
            the stored ``dropout_p``.
        """
        probability = self.dropout_p
        return FF.dropout_caption(caption, probability)
class CaptionEmbeddingDropoutProcessor(ProcessorMixin):
    """Processor that routes embeddings through ``FF.dropout_embeddings_to_zero``.

    The dropout probability is fixed at construction time and handed to the
    functional helper, unchanged, on every ``forward`` call.
    """

    # NOTE(review): ``ProcessorMixin.__init__`` is not invoked here. Confirm
    # against the base class whether it sets up state this processor relies on.
    def __init__(self, dropout_p: float = 0.0) -> None:
        """Remember ``dropout_p`` so that ``forward`` can pass it along.

        Args:
            dropout_p: Value forwarded as the second positional argument to
                ``FF.dropout_embeddings_to_zero``. Presumably a probability
                in ``[0, 1]`` -- TODO confirm against the helper.
        """
        self.dropout_p = dropout_p

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """Apply ``FF.dropout_embeddings_to_zero`` to ``embedding``.

        Args:
            embedding: Tensor passed through to the functional helper; its
                expected shape is not visible from this module.

        Returns:
            Whatever ``FF.dropout_embeddings_to_zero`` returns for the given
            embedding and the stored ``dropout_p``.
        """
        probability = self.dropout_p
        return FF.dropout_embeddings_to_zero(embedding, probability)
|