"""This file contains the definition of utility functions for masking.""" import math from typing import Text, Tuple import torch def get_mask_tokens( tokens: torch.Tensor, mask_token: int, mode: Text = "arccos", min_masking_ratio: float = 0.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Get the masked tokens. Args: tokens -> torch.Tensor: The input tokens. mask_token -> int: The special `mask` token. mode -> Text: The masking function to use (default: "arccos"). Returns: masked_tokens -> torch.Tensor: The masked input tokens. Each masked token is set to mask_token. mask -> torch.Tensor: A boolean tensor mask indicating which tokens are masked. """ r = torch.rand(tokens.size(0)) * (1 - min_masking_ratio) if mode == "linear": val_to_mask = 1 - r elif mode == "square": val_to_mask = 1 - (r**2) elif mode == "cosine": val_to_mask = torch.cos(r * math.pi * 0.5) elif mode == "arccos": val_to_mask = torch.acos(r) / (math.pi * 0.5) else: raise ValueError( "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos'." ) masked_tokens = tokens.detach().clone() mask = torch.rand(tokens.size()) < val_to_mask.view(-1, 1, 1) masked_tokens[mask] = torch.full_like(masked_tokens[mask], mask_token) return masked_tokens, mask def get_masking_ratio(progress: float, mode: Text = "arccos") -> torch.Tensor: """Get masking ratio. Args: progress -> float: The percentage of iterations already done. mode -> Text: The masking function to use (default: "arccos"). Returns: val_to_mask -> torch.Tensor: The masking ratio. """ r = torch.tensor(progress) if mode == "root": val_to_mask = 1 - (r**0.5) elif mode == "square": val_to_mask = 1 - (r**2) elif mode == "cosine": val_to_mask = torch.cos(r * math.pi * 0.5) elif mode == "arccos": val_to_mask = torch.acos(r) / (math.pi * 0.5) elif mode == "linear": val_to_mask = 1 - r else: raise ValueError( "Invalid mode. Choose between 'linear','square', 'cosine', 'arccos', 'root'." ) val_to_mask = torch.clamp(val_to_mask, 1e-6, 1.0) return val_to_mask