File size: 2,323 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""This file contains the definition of utility functions for masking."""

import math
from typing import Text, Tuple
import torch


def get_mask_tokens(
    tokens: torch.Tensor,
    mask_token: int,
    mode: Text = "arccos",
    min_masking_ratio: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly mask a batch of tokens following a masking schedule.

    For every sample in the batch a value ``r`` is drawn uniformly from
    ``[0, 1 - min_masking_ratio)`` and mapped through the chosen schedule to a
    per-sample masking probability. Each token of that sample is then masked
    independently with that probability.

    Args:
        tokens -> torch.Tensor: The input tokens. The first dimension is the
            batch dimension; any number of trailing dimensions is supported
            (e.g. ``(B, S)`` or ``(B, H, W)``).
        mask_token -> int: The special `mask` token written into masked slots.
        mode -> Text: The masking schedule to use: "linear", "square",
            "cosine", "arccos" or "root" (default: "arccos").
        min_masking_ratio -> float: Shrinks the range of ``r`` to
            ``[0, 1 - min_masking_ratio)``. For "linear" this guarantees a
            per-sample masking probability strictly greater than
            ``min_masking_ratio``. For the other schedules, the lower bound is
            the schedule evaluated at ``1 - min_masking_ratio``.

    Returns:
        masked_tokens -> torch.Tensor: A detached copy of ``tokens`` in which
            every masked position is set to ``mask_token``.
        mask -> torch.Tensor: A boolean tensor with the same shape and device
            as ``tokens``; True marks a masked position.

    Raises:
        ValueError: If ``mode`` is not a supported schedule.
    """
    device = tokens.device
    r = torch.rand(tokens.size(0), device=device) * (1 - min_masking_ratio)
    if mode == "linear":
        val_to_mask = 1 - r
    elif mode == "square":
        val_to_mask = 1 - (r**2)
    elif mode == "cosine":
        val_to_mask = torch.cos(r * math.pi * 0.5)
    elif mode == "arccos":
        val_to_mask = torch.acos(r) / (math.pi * 0.5)
    elif mode == "root":
        val_to_mask = 1 - (r**0.5)
    else:
        raise ValueError(
            "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos', 'root'."
        )

    # Reshape the per-sample probability to (B, 1, ..., 1) so it broadcasts
    # over however many trailing dimensions `tokens` has.
    per_sample = val_to_mask.view(-1, *([1] * (tokens.dim() - 1)))
    mask = torch.rand(tokens.shape, device=device) < per_sample

    # masked_fill is out-of-place, so the caller's `tokens` is never modified.
    masked_tokens = tokens.detach().masked_fill(mask, mask_token)
    return masked_tokens, mask


def get_masking_ratio(progress: float, mode: Text = "arccos") -> torch.Tensor:
    """Get the masking ratio for a given training/decoding progress.

    Args:
        progress -> float: The fraction of iterations already done, expected
            in [0, 1]. Values outside that range are clamped to it, so that
            float overshoot cannot yield NaN from ``acos``/``sqrt``. A tensor
            may also be passed.
        mode -> Text: The masking schedule to use: "root", "square", "cosine",
            "arccos" or "linear" (default: "arccos").

    Returns:
        val_to_mask -> torch.Tensor: The masking ratio, clamped to
            ``[1e-6, 1.0]``.

    Raises:
        ValueError: If ``mode`` is not a supported schedule.
    """
    r = torch.as_tensor(progress)
    # Integer progress (e.g. 0 or 1) would otherwise keep the arithmetic in
    # int64; promote it to the default float dtype.
    if not r.is_floating_point():
        r = r.to(torch.get_default_dtype())
    r = torch.clamp(r, 0.0, 1.0)

    if mode == "root":
        val_to_mask = 1 - (r**0.5)
    elif mode == "square":
        val_to_mask = 1 - (r**2)
    elif mode == "cosine":
        val_to_mask = torch.cos(r * math.pi * 0.5)
    elif mode == "arccos":
        val_to_mask = torch.acos(r) / (math.pi * 0.5)
    elif mode == "linear":
        val_to_mask = 1 - r
    else:
        raise ValueError(
            "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos', 'root'."
        )

    # Keep the ratio strictly positive so callers always mask something.
    val_to_mask = torch.clamp(val_to_mask, 1e-6, 1.0)
    return val_to_mask