File size: 681 Bytes
80ebcb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from typing import List, Union

import torch

from .. import functional as FF
from .base import ProcessorMixin


class CaptionTextDropoutProcessor(ProcessorMixin):
    def __init__(self, dropout_p: float = 0.0) -> None:
        self.dropout_p = dropout_p

    def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]:
        return FF.dropout_caption(caption, self.dropout_p)


class CaptionEmbeddingDropoutProcessor(ProcessorMixin):
    def __init__(self, dropout_p: float = 0.0) -> None:
        self.dropout_p = dropout_p

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        return FF.dropout_embeddings_to_zero(embedding, self.dropout_p)