# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from tokenizers import Tokenizer
from torch.distributions import Dirichlet

from fourm.data.modality_transforms import get_transform_key
from fourm.utils import to_2tuple
from fourm.utils.tokenizer import get_sentinel_to_id_mapping

def sample_cosine(min_val: float = 0, max_val: float = 1) -> float:
    """Sample a value from a cosine distribution between min_val and max_val

    Args:
        min_val: Minimum value
        max_val: Maximum value

    Returns:
        Sampled value
    """
    return min_val + 0.5 * (max_val - min_val) * (1 + math.cos(math.pi * random.uniform(0, 1)))


def sample_uniform(min_val: float = 0, max_val: float = 1) -> float:
    """Sample a value from a uniform distribution between min_val and max_val

    Args:
        min_val: Minimum value
        max_val: Maximum value

    Returns:
        Sampled value
    """
    return random.uniform(min_val, max_val)

def simple_span_masking(sequence: List[int], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
    """Span masking for a sequence

    Args:
        sequence: Sequence to mask
        sentinel_to_id: Mapping from sentinel to id
        keep_prob: Probability of keeping a token

    Returns:
        Masked input sequence and masked target sequence
    """
    sequence_length = len(sequence)
    # 0 for keep, 1 for mask
    masks = torch.where(torch.rand(sequence_length) <= keep_prob, 0, 1).bool().tolist()

    input_sequence = []
    target_sequence = []

    prev_mask = False
    sentinel_count = 0
    for token, mask in zip(sequence, masks):
        if mask:
            if not prev_mask:
                sentinel_count += 1
                input_sequence.append(sentinel_to_id[sentinel_count])
                target_sequence.append(sentinel_to_id[sentinel_count])
            prev_mask = True
            target_sequence.append(token)
        else:
            prev_mask = False
            input_sequence.append(token)

    target_sequence.append(sentinel_to_id[sentinel_count + 1])
    return input_sequence, target_sequence
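
# Illustrative sketch (not part of the original module): a minimal, self-contained demo of
# `simple_span_masking`, assuming a toy sentinel mapping with made-up token ids (a real mapping
# comes from `get_sentinel_to_id_mapping(text_tokenizer)`). With keep_prob=0.5, each masked span
# in the input is replaced by one sentinel, and the target lists each sentinel followed by the
# tokens it hides, terminated by one extra sentinel.
def _demo_simple_span_masking():
    toy_sentinel_to_id = {i: 1000 + i for i in range(1, 11)}  # hypothetical sentinel ids
    sequence = [5, 6, 7, 8, 9, 10]
    input_seq, target_seq = simple_span_masking(sequence, toy_sentinel_to_id, keep_prob=0.5)
    # One possible random outcome:
    # input_seq  = [5, 1001, 8, 9, 1002]
    # target_seq = [1001, 6, 7, 1002, 10, 1003]
    print(input_seq, target_seq)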

def chunk_span_masking(sequence_chunks: List[List[int]], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
    """Span masking where masking is performed at the chunk level.

    Args:
        sequence_chunks: Sequence chunks to mask
        sentinel_to_id: Mapping from sentinel to id
        keep_prob: Probability of keeping a token

    Returns:
        Masked input sequence and masked target sequence
    """
    chunk_length = len(sequence_chunks)
    # 0 for keep, 1 for mask
    masks = torch.where(torch.rand(chunk_length) <= keep_prob, 0, 1).bool().tolist()

    input_sequence = []
    target_sequence = []

    prev_mask = False
    sentinel_count = 0
    for chunk, mask in zip(sequence_chunks, masks):
        if mask:
            if not prev_mask:
                sentinel_count += 1
                input_sequence.append(sentinel_to_id[sentinel_count])
                target_sequence.append(sentinel_to_id[sentinel_count])
            prev_mask = True
            target_sequence.extend(chunk)
        else:
            prev_mask = False
            input_sequence.extend(chunk)

    target_sequence.append(sentinel_to_id[sentinel_count + 1])
    return input_sequence, target_sequence
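
# Illustrative sketch (not part of the original module): `chunk_span_masking` works like
# `simple_span_masking`, but the keep/mask decision is made once per chunk, so all tokens of a
# chunk are kept or masked together. The sentinel ids below are made up for the example.
def _demo_chunk_span_masking():
    toy_sentinel_to_id = {i: 1000 + i for i in range(1, 11)}  # hypothetical sentinel ids
    chunks = [[5, 6], [7], [8, 9, 10]]
    input_seq, target_seq = chunk_span_masking(chunks, toy_sentinel_to_id, keep_prob=0.5)
    # One possible random outcome:
    # input_seq  = [5, 6, 1001, 8, 9, 10]
    # target_seq = [1001, 7, 1002]
    print(input_seq, target_seq)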

class UnifiedMasking(object):
    def __init__(self,
                 modality_info: Dict,
                 text_tokenizer: Optional[Tokenizer],
                 input_tokens_range: Union[int, Tuple[int, int]],
                 target_tokens_range: Optional[Union[int, Tuple[int, int]]],
                 max_tries: int = 100,
                 sampling_weights: Optional[List[float]] = None,):
        """Performs masking on a dict of modalities (both image based and sequence based modalities)

        Args:
            modality_info: Dict with the modalities and their corresponding information
            text_tokenizer: Tokenizer to use for text modalities
            input_tokens_range: Range of number of tokens to mask in the input
            target_tokens_range: Range of number of tokens to mask in the target
            max_tries: Maximum number of tries to find a valid token budget
            sampling_weights: Sampling weights for the mixture of Dirichlet distributions
        """
        self.input_tokens_range = to_2tuple(input_tokens_range)
        self.target_tokens_range = to_2tuple(target_tokens_range) if target_tokens_range is not None else None
        self.modality_info = modality_info
        self.num_modalities = len(modality_info)
        self.max_tries = max_tries
        self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
        self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
        self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])

        # Dirichlet sampling (supports a mixture of multiple Dirichlet distributions)
        eps = 1e-9
        input_alphas = torch.tensor([mod["input_alphas"] for mod in modality_info.values()])
        input_alphas = rearrange(input_alphas, "nmod nmix -> nmix nmod")
        self.input_dirichlets = [Dirichlet(torch.clamp(input_alpha, min=eps)) for input_alpha in input_alphas]
        target_alphas = torch.tensor([mod["target_alphas"] for mod in modality_info.values()])
        target_alphas = rearrange(target_alphas, "nmod nmix -> nmix nmod")
        self.target_dirichlets = [Dirichlet(torch.clamp(target_alpha, min=eps)) for target_alpha in target_alphas]
        assert len(self.input_dirichlets) == len(self.target_dirichlets)
        self.num_dirichlets = len(self.input_dirichlets)

        if sampling_weights is not None:
            assert len(sampling_weights) == self.num_dirichlets
            self.sampling_weights = torch.tensor(sampling_weights)
        else:
            self.sampling_weights = None

        # Tokenizer for text modalities
        self.text_tokenizer = text_tokenizer
        self.keep_prob_decay_factor = 0.9
        self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
        self.sentinel_ids = set(self.sentinel_to_id.values())
        self.pad_id = text_tokenizer.token_to_id("[PAD]")
        self.eos_id = text_tokenizer.token_to_id("[EOS]")

    def input_token_budget(self, num_input_tokens, dir_idx=0):
        """Sample a token budget for the input

        Args:
            num_input_tokens: Number of tokens in the input

        Returns:
            Token budget for the input
        """
        # Get the number of tokens for each modality
        for i in range(self.max_tries):
            input_token_budget = (self.input_dirichlets[dir_idx].sample() * num_input_tokens).floor().int()
            diff = num_input_tokens - input_token_budget.sum()
            # Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
            # This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
            input_token_budget += torch.bincount(self.input_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(input_token_budget))

            # If token budget is over max tokens for a given modality, set it to max
            input_token_budget = torch.clamp(input_token_budget, max=self.max_tokens)

            if (input_token_budget >= self.min_tokens).all():
                return input_token_budget.tolist()

        print("More than max tries for input!")
        return input_token_budget.tolist()

    def target_token_budget(self, input_token_budget, num_target_tokens, dir_idx=0):
        """Sample a token budget for the target

        Args:
            input_token_budget: Token budget for the input
            num_target_tokens: Number of tokens in the target

        Returns:
            Token budget for the target
        """
        # We don't reduce the number of tokens for sequence based tasks
        max_tokens_remaining = torch.where(self.mod_is_img, self.max_tokens - torch.tensor(input_token_budget), self.max_tokens)
        max_tokens_remaining = torch.max(self.min_tokens, max_tokens_remaining)

        for i in range(self.max_tries):
            target_token_budget = (self.target_dirichlets[dir_idx].sample() * num_target_tokens).floor().int()
            diff = num_target_tokens - target_token_budget.sum()
            # Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
            # This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
            target_token_budget += torch.bincount(self.target_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(target_token_budget))

            # If token budget is over max tokens for a given modality, set it to max
            target_token_budget = torch.clamp(target_token_budget, max=max_tokens_remaining)

            if (target_token_budget >= self.min_tokens).all():
                return target_token_budget.tolist()

        print("More than max tries for target!")
        return target_token_budget.tolist()

    def image_mask(self, tensor: torch.Tensor, num_tokens: int, input_budget: int, target_budget: int):
        """Applies input and target masking to an image tensor

        Args:
            tensor: Image tensor
            num_tokens: Number of tokens in the tensor
            input_budget: Token budget for the input
            target_budget: Token budget for the target

        Returns:
            Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
        """
        noise = torch.rand(num_tokens)
        ids_shuffle = torch.argsort(noise, dim=0)

        input_mask = torch.ones(num_tokens, dtype=torch.bool)
        input_mask[:input_budget] = 0
        input_mask = torch.gather(input_mask, dim=0, index=ids_shuffle)

        if target_budget is None:
            target_mask = ~input_mask
        else:
            target_mask = torch.ones(num_tokens, dtype=torch.bool)
            target_mask[input_budget:input_budget + target_budget] = 0
            target_mask = torch.gather(target_mask, dim=0, index=ids_shuffle)

        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
        first_mask_token = torch.argmin(target_mask + torch.arange(target_mask.shape[0], device=target_mask.device) * 1e-6)
        decoder_attention_mask[first_mask_token] = (~target_mask).sum()  # Equiv. to target budget

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def sequence_token_mask(self, sequence_ids: str, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str, vocab_offset: int):
        """Applies input and target masking to a sequence of tokens (e.g. DINOv2 global tokens)

        The keep probability is sampled according to the given keep scheme and does not depend on the number of tokens in the sequence.
        If the sampled keep probability results in an input sequence that is too long, it is lowered until the sequence fits the budget.

        Args:
            sequence_ids: Sequence ids
            max_tokens: Maximum number of tokens in the sequence
            input_budget: Token budget for the input
            target_budget: Token budget for the target
            keep_scheme: Scheme for sampling the keep probability
            vocab_offset: Offset to avoid overlap with sentinel tokens

        Returns:
            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
        """
        seq_ids = sequence_ids
        seq_ids = seq_ids + vocab_offset  # Avoid overlap with sentinel tokens (needs to be subtracted after decoding)

        # If input budget is 0, treat it as if the whole sequence is completely masked
        if input_budget == 0:
            keep_prob = 0.
            input_seq_ids = []
            _, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
        else:
            if keep_scheme == 'random':
                keep_prob = sample_uniform(0, 1)
            elif keep_scheme == 'all':
                keep_prob = 1.0
            elif keep_scheme == 'binary':
                keep_prob = random.choice([0., 1.])
            else:
                raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")

            input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
            # Keep lowering the keep_prob while we are over-budget
            while len(input_seq_ids) > input_budget:
                keep_prob = keep_prob * self.keep_prob_decay_factor
                input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)

        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
        max_length = (max_tokens + 1) * 2
        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Set input and input mask
        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
        input_mask[:len(input_seq_ids)] = 0

        if target_budget is None or len(target_seq_ids) <= target_budget:
            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
        else:
            # Randomly choose a sentinel token.
            sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
            # If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
            chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
            # If the length starting at this sentinel is greater than the budget, truncate until the budget is reached
            if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
                target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
            # Otherwise, select the earliest sentinel token such that we don't go over budget
            # Note: We could also use the randomly chosen sentinel token, but that would waste budget
            else:
                for idx in sentinel_indices:
                    if len(target_seq_ids) - idx <= target_budget:
                        target_seq_ids = target_seq_ids[idx:]
                        break

            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def sequence_mask(self, sequence: Union[str, List[str]], max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
        """Applies input and target masking to a sequence

        The keep probability is sampled according to the given keep scheme and does not depend on the number of tokens in the sequence.
        If the sampled keep probability results in an input sequence that is too long, it is lowered until the sequence fits the budget.

        Args:
            sequence: Sequence, can be either a str or list of strings
            max_tokens: Maximum number of tokens in the sequence
            input_budget: Token budget for the input
            target_budget: Token budget for the target
            keep_scheme: Scheme for sampling the keep probability

        Returns:
            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
        """
        if isinstance(sequence, str):
            # Tokenize the sequence and get the ids
            seq_ids: List[int] = self.text_tokenizer.encode(sequence).ids
            # Add EOS to all sequences
            seq_ids.append(self.eos_id)
            # Truncate sequence
            seq_ids = seq_ids[:max_tokens]
            # Use default span masking
            span_masking_fn = simple_span_masking
        elif isinstance(sequence, list):
            # Tokenize the sequence chunks and get the ids
            encoded_seq_chunks = self.text_tokenizer.encode_batch(sequence)
            seq_ids: List[List[int]] = [seq.ids for seq in encoded_seq_chunks]
            # Add EOS as an extra chunk
            seq_ids.append([self.eos_id])
            # Truncate sequence to keep all chunks below max token length
            cumulative_token_count = np.cumsum(np.array([len(chunk) for chunk in seq_ids]))
            seq_ids = [chunk for (chunk, token_count) in zip(seq_ids, cumulative_token_count) if token_count <= max_tokens]
            # Span mask over chunks
            span_masking_fn = chunk_span_masking
        else:
            raise ValueError(f"Invalid sequence: {sequence}")

        # If input budget is 0, treat it as if the whole sequence is completely masked
        if input_budget == 0:
            keep_prob = 0.
            input_seq_ids = []
            _, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
        else:
            if keep_scheme == 'random':
                keep_prob = sample_uniform(0, 1)
            elif keep_scheme == 'all':
                keep_prob = 1.0
            elif keep_scheme == 'binary':
                keep_prob = random.choice([0., 1.])
            else:
                raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")

            input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
            # Keep lowering the keep_prob while we are over-budget
            while len(input_seq_ids) > input_budget:
                keep_prob = keep_prob * self.keep_prob_decay_factor
                input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)

        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
        max_length = (max_tokens + 1) * 2
        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Set input and input mask
        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
        input_mask[:len(input_seq_ids)] = 0

        if target_budget is None or len(target_seq_ids) <= target_budget:
            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
        else:
            # Randomly choose a sentinel token.
            sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
            # If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
            chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
            # If the length starting at this sentinel is greater than the budget, truncate until the budget is reached
            if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
                target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
            # Otherwise, select the earliest sentinel token such that we don't go over budget
            # Note: We could also use the randomly chosen sentinel token, but that would waste budget
            else:
                for idx in sentinel_indices:
                    if len(target_seq_ids) - idx <= target_budget:
                        target_seq_ids = target_seq_ids[idx:]
                        break

            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def sequence_emb_mask_span(self, emb_tensor: torch.Tensor, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
        """Applies input masking to a sequence embedding tensor; target masking is not supported with sequence embeddings

        Args:
            emb_tensor: Sequence embedding tensor
            max_tokens: Maximum number of tokens in the sequence
            input_budget: Token budget for the input
            target_budget: Token budget for the target (unused for now)
            keep_scheme: Scheme for sampling the keep probability

        Returns:
            Dictionary containing the masked sequence embedding tensor, the input mask, the target mask, and the decoder attention mask
        """
        # Only supported as input modality now

        # Make fake seq ids for sequence embeddings to reuse the simple_span_masking function
        fake_seq_ids = []
        emb_dict = {}
        id_num = len(self.sentinel_ids)
        emb_ind = 0
        while len(fake_seq_ids) < len(emb_tensor):
            if id_num not in self.sentinel_ids:  # skip ids that collide with the T5 sentinel ids
                fake_seq_ids.append(id_num)
                emb_dict[id_num] = emb_tensor[emb_ind, :]
                emb_ind += 1
            id_num += 1

        # Truncate sequence
        fake_seq_ids = fake_seq_ids[:max_tokens]

        # If input budget is 0, treat it as if the whole sequence is completely masked
        if input_budget == 0:
            keep_prob = 0.
            fake_input_seq_ids = []
            _, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
        else:
            if keep_scheme == 'random':
                keep_prob = sample_uniform(0, 1)
            elif keep_scheme == 'all':
                keep_prob = 1.0
            elif keep_scheme == 'binary':
                keep_prob = random.choice([0., 1.])
            else:
                raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")

            fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
            # Keep lowering the keep_prob while we are over-budget
            while len(fake_input_seq_ids) > input_budget:
                keep_prob = keep_prob * self.keep_prob_decay_factor
                fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)

        # Span masking can add up to max_tokens tokens for input
        max_length = max_tokens
        tensor = torch.zeros((max_length, emb_tensor.shape[1]), dtype=torch.float32)
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Put tensor values back based on the fake seq ids
        for i_, fake_id in enumerate(fake_input_seq_ids):
            if fake_id in self.sentinel_ids:
                tensor[i_, :] = torch.zeros_like(emb_tensor[0, :])  # TODO: replace with learned embeddings later
            else:
                tensor[i_, :] = emb_dict[fake_id]

        # Set input and input mask
        input_mask[:len(fake_input_seq_ids)] = 0

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def __call__(self, mod_dict):
        """Applies input and target masking to a dictionary of modalities

        Args:
            mod_dict: Dictionary of modalities

        Returns:
            Dictionary containing the masked modalities
        """
        if self.sampling_weights is not None:
            # Sample masking scheme according to a list of weights
            dir_idx = torch.multinomial(self.sampling_weights, 1).item()
        else:
            # Randomly sample masking scheme
            dir_idx = random.randint(0, self.num_dirichlets - 1)

        num_input_tokens = random.randint(*self.input_tokens_range)
        num_target_tokens = random.randint(*self.target_tokens_range) if self.target_tokens_range is not None else None

        input_token_budget = self.input_token_budget(num_input_tokens, dir_idx)

        if num_target_tokens is not None:
            target_token_budget = self.target_token_budget(input_token_budget, num_target_tokens, dir_idx)
        else:
            target_token_budget = [None] * self.num_modalities

        masked_mod_dict = {}
        for (mod_name, mod_info), input_budget, target_budget in zip(self.modality_info.items(), input_token_budget, target_token_budget):
            mod_type = mod_info['type']
            mod_name_load = mod_name if mod_name in mod_dict else get_transform_key(mod_name)
            if mod_type == 'img':
                masked_mod_dict[mod_name] = self.image_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget)
            elif mod_type == 'seq':
                keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
                masked_mod_dict[mod_name] = self.sequence_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
            elif mod_type == 'seq_token':
                keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
                vocab_offset = mod_info.get('vocab_offset', 0)  # Check if any space is allocated to sentinel tokens and other special tokens
                masked_mod_dict[mod_name] = self.sequence_token_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme, vocab_offset=vocab_offset)
            elif mod_type == "seq_emb":
                keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
                masked_mod_dict[mod_name] = self.sequence_emb_mask_span(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
            else:
                raise ValueError(f"Invalid modality type: {mod_type}")

        return masked_mod_dict
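
# Illustrative sketch (not part of the original module): how UnifiedMasking might be wired up.
# The modality names, token counts, alphas, and tokenizer path below are hypothetical; a real
# setup uses the project's modality_info definitions and its trained text tokenizer JSON, which
# must contain [PAD], [EOS], and the sentinel tokens. Only the fields read by UnifiedMasking are
# shown ('type', 'min_tokens', 'max_tokens', 'input_alphas', 'target_alphas').
def _demo_unified_masking(tokenizer_path="text_tokenizer.json"):  # hypothetical path
    toy_modality_info = {
        "tok_rgb": {"type": "img", "min_tokens": 0, "max_tokens": 196,
                    "input_alphas": [1.0], "target_alphas": [1.0]},
        "caption": {"type": "seq", "min_tokens": 0, "max_tokens": 64,
                    "input_alphas": [1.0], "target_alphas": [1.0]},
    }
    text_tokenizer = Tokenizer.from_file(tokenizer_path)
    masking = UnifiedMasking(
        modality_info=toy_modality_info,
        text_tokenizer=text_tokenizer,
        input_tokens_range=(128, 128),
        target_tokens_range=(128, 128),
    )
    mod_dict = {
        "tok_rgb": torch.randint(0, 1024, (196,)),  # pre-tokenized image, 196 token ids
        "caption": "a photo of a cat",              # raw text, tokenized inside sequence_mask
    }
    masked = masking(mod_dict)
    # Each entry holds "tensor", "input_mask", "target_mask", and "decoder_attention_mask"
    return masked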

class TransferMasking(object):
    def __init__(self,
                 modality_info: Dict,
                 text_tokenizer: Optional[Tokenizer],
                 input_modalities: List[str],
                 target_modalities: List[str]):
        """Performs masking for transfer on a dict of modalities (both image based and sequence based modalities),
        by specifying which modalities are inputs and which are targets.

        Args:
            modality_info: Dict with the modalities and their corresponding information
            text_tokenizer: Tokenizer to use for text modalities
            input_modalities: List of modalities to use as input
            target_modalities: List of modalities to use as target
        """
        self.modality_info = modality_info
        self.num_modalities = len(modality_info)
        self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
        self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
        self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])

        self.input_modalities = set(input_modalities)
        self.target_modalities = set(target_modalities)

        # Tokenizer for text modalities
        self.text_tokenizer = text_tokenizer
        if self.text_tokenizer is not None:
            self.keep_prob_decay_factor = 0.9
            self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
            self.sentinel_ids = set(self.sentinel_to_id.values())
            self.pad_id = text_tokenizer.token_to_id("[PAD]")
            self.eos_id = text_tokenizer.token_to_id("[EOS]")

    def input_image(self, tensor: torch.Tensor, num_tokens: int):
        """Applies masking for an image given as input

        Args:
            tensor: Image tensor
            num_tokens: Number of tokens in the tensor

        Returns:
            Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
        """
        # Input mask
        input_mask = torch.zeros(num_tokens, dtype=torch.bool)
        # Target mask
        target_mask = torch.ones(num_tokens, dtype=torch.bool)
        # Decoder attention mask
        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def target_image(self, tensor: torch.Tensor, num_tokens: int):
        """Applies masking for an image given as target

        Args:
            tensor: Image tensor
            num_tokens: Number of tokens in the tensor

        Returns:
            Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
        """
        # Input mask
        input_mask = torch.ones(num_tokens, dtype=torch.bool)
        # Target mask
        target_mask = torch.zeros(num_tokens, dtype=torch.bool)
        # Decoder attention mask
        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
        decoder_attention_mask[0] = num_tokens

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def input_sequence(self, sequence_str: str, max_tokens: int):
        """Applies masking for a sequence given as input

        Args:
            sequence_str: Sequence string
            max_tokens: Maximum number of tokens in the sequence

        Returns:
            Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
        """
        # Tokenize the text and get the ids
        seq_ids = self.text_tokenizer.encode(sequence_str).ids
        # Add EOS to all sequences
        seq_ids.append(self.eos_id)
        # Truncate sequence
        seq_ids = seq_ids[:max_tokens]

        keep_prob = 1.
        input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)

        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
        max_length = (max_tokens + 1) * 2
        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Set input and input mask
        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
        input_mask[:len(input_seq_ids)] = 0

        tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
        target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
        decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}

    def target_sequence(self, sequence_str: str, max_tokens: int):
        """Applies masking for a sequence given as target

        Args:
            sequence_str: Sequence string
            max_tokens: Maximum number of tokens in the sequence

        Returns:
            Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
        """
        # Tokenize the text and get the ids
        seq_ids = self.text_tokenizer.encode(sequence_str).ids
        # Add EOS to all sequences
        seq_ids.append(self.eos_id)
        # Truncate sequence
        seq_ids = seq_ids[:max_tokens]

        keep_prob = 0.
        input_seq_ids = []
        _, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)

        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
        max_length = (max_tokens + 1) * 2
        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
        input_mask = torch.ones(max_length, dtype=torch.bool)
        target_mask = torch.ones(max_length, dtype=torch.bool)
        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)

        # Set input and input mask
        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
        input_mask[:len(input_seq_ids)] = 0

        tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
        target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
        decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1

        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask,
                "decoder_attention_mask": decoder_attention_mask}

    def __call__(self, mod_dict):
        """Applies input and target masking to a dictionary of modalities

        Args:
            mod_dict: Dictionary of modalities

        Returns:
            Dictionary containing the masked modalities
        """
        masked_mod_dict = {}
        for mod_name, mod_info in self.modality_info.items():
            mod_type = mod_info['type']
            if mod_type == 'img' and mod_name in self.input_modalities:
                masked_mod_dict[mod_name] = self.input_image(mod_dict[mod_name], mod_info['max_tokens'])
            elif mod_type == 'img' and mod_name in self.target_modalities:
                masked_mod_dict[mod_name] = self.target_image(mod_dict[mod_name], mod_info['max_tokens'])
            elif mod_type == 'seq' and mod_name in self.input_modalities:
                masked_mod_dict[mod_name] = self.input_sequence(mod_dict[mod_name], mod_info['max_tokens'])
            elif mod_type == 'seq' and mod_name in self.target_modalities:
                masked_mod_dict[mod_name] = self.target_sequence(mod_dict[mod_name], mod_info['max_tokens'])
            else:
                raise ValueError(f"Invalid modality type: {mod_type} or modality name not in input or target modalities: {mod_name}")

        if 'mask_valid' in mod_dict:
            masked_mod_dict['mask_valid'] = mod_dict['mask_valid']

        return masked_mod_dict
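
# Illustrative sketch (not part of the original module): TransferMasking fully reveals the chosen
# input modalities and fully masks the chosen target modalities, which suits transfer / fine-tuning
# settings with a fixed input-output mapping. Modality names, token counts, and the tokenizer path
# below are hypothetical; the tokenizer JSON must contain [PAD], [EOS], and the sentinel tokens.
def _demo_transfer_masking(tokenizer_path="text_tokenizer.json"):  # hypothetical path
    toy_modality_info = {
        "tok_rgb": {"type": "img", "min_tokens": 0, "max_tokens": 196},
        "caption": {"type": "seq", "min_tokens": 0, "max_tokens": 64},
    }
    masking = TransferMasking(
        modality_info=toy_modality_info,
        text_tokenizer=Tokenizer.from_file(tokenizer_path),
        input_modalities=["tok_rgb"],
        target_modalities=["caption"],
    )
    mod_dict = {
        "tok_rgb": torch.randint(0, 1024, (196,)),  # pre-tokenized image, fully visible as input
        "caption": "a photo of a cat",              # fully masked, to be predicted as target
    }
    return masking(mod_dict)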