# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from tokenizers import Tokenizer
from torch.distributions import Dirichlet
from fourm.data.modality_transforms import get_transform_key
from fourm.utils import to_2tuple
from fourm.utils.tokenizer import get_sentinel_to_id_mapping
def sample_cosine(min_val: float = 0, max_val: float = 1) -> float:
"""Sample a value from a cosine distribution between min_val and max_val
Args:
min_val: Minimum value
max_val: Maximum value
Returns:
Sampled value
"""
return min_val + 0.5 * (max_val - min_val) * (1 + math.cos(math.pi * random.uniform(0, 1)))
def sample_uniform(min_val: float = 0, max_val: float = 1) -> float:
"""Sample a value from a uniform distribution between min_val and max_val
Args:
min_val: Minimum value
max_val: Maximum value
Returns:
Sampled value
"""
return random.uniform(min_val, max_val)
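# Note: sample_cosine concentrates samples near min_val and max_val (pushing a uniform
# variable through the cosine yields an arcsine-shaped density), whereas sample_uniform
# is flat over [min_val, max_val]. Comment-only illustration:
#
#   vals = [sample_cosine(0, 1) for _ in range(10_000)]
#   # Expect about twice as many values below 0.25 or above 0.75 as between them.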
def simple_span_masking(sequence: List[int], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
"""Span masking for a sequence
Args:
sequence: Sequence to mask
sentinel_to_id: Mapping from sentinel to id
keep_prob: Probability of keeping a token
Returns:
Masked input sequence and masked target sequence
"""
sequence_length = len(sequence)
# 0 for keep, 1 for mask
masks = torch.where(torch.rand(sequence_length) <= keep_prob, 0, 1).bool().tolist()
input_sequence = []
target_sequence = []
prev_mask = False
sentinel_count = 0
for token, mask in zip(sequence, masks):
if mask:
if not prev_mask:
sentinel_count += 1
input_sequence.append(sentinel_to_id[sentinel_count])
target_sequence.append(sentinel_to_id[sentinel_count])
prev_mask = True
target_sequence.append(token)
else:
prev_mask = False
input_sequence.append(token)
target_sequence.append(sentinel_to_id[sentinel_count + 1])
return input_sequence, target_sequence
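# Comment-only walk-through of simple_span_masking, assuming a toy sentinel mapping
# {1: S1, 2: S2, 3: S3} (S1, S2, S3 stand for the actual sentinel token ids):
#
#   sequence = [10, 11, 12, 13, 14]
#   # Suppose the sampled per-token decisions are [keep, mask, mask, keep, mask]. Then,
#   # T5-style, each contiguous masked span is replaced by one sentinel in the input and
#   # spelled out after that sentinel in the target:
#   #   input_sequence  = [10, S1, 13, S2]
#   #   target_sequence = [S1, 11, 12, S2, 14, S3]
#   # The trailing sentinel S3 marks the end of the target sequence.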
def chunk_span_masking(sequence_chunks: List[List[int]], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
"""Span masking where masking is performed at the chunk level.
Args:
sequence_chunks: Sequence chunks to mask
sentinel_to_id: Mapping from sentinel to id
keep_prob: Probability of keeping a token
Returns:
Masked input sequence and masked target sequence
"""
chunk_length = len(sequence_chunks)
# 0 for keep, 1 for mask
masks = torch.where(torch.rand(chunk_length) <= keep_prob, 0, 1).bool().tolist()
input_sequence = []
target_sequence = []
prev_mask = False
sentinel_count = 0
for chunk, mask in zip(sequence_chunks, masks):
if mask:
if not prev_mask:
sentinel_count += 1
input_sequence.append(sentinel_to_id[sentinel_count])
target_sequence.append(sentinel_to_id[sentinel_count])
prev_mask = True
target_sequence.extend(chunk)
else:
prev_mask = False
input_sequence.extend(chunk)
target_sequence.append(sentinel_to_id[sentinel_count + 1])
return input_sequence, target_sequence
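# Comment-only sketch of chunk_span_masking with the same toy sentinel mapping as above.
# Decisions are made per chunk, so each chunk is kept or masked as a whole:
#
#   sequence_chunks = [[10, 11], [12], [13, 14]]
#   # Suppose the sampled per-chunk decisions are [keep, mask, keep]. Then:
#   #   input_sequence  = [10, 11, S1, 13, 14]
#   #   target_sequence = [S1, 12, S2]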
class UnifiedMasking(object):
def __init__(self,
modality_info: Dict,
text_tokenizer: Optional[Tokenizer],
input_tokens_range: Union[int, Tuple[int, int]],
target_tokens_range: Optional[Union[int, Tuple[int, int]]],
max_tries: int = 100,
sampling_weights: Optional[List[float]] = None,):
"""Performs masking on a dict of modalities (both image based and sequence based modalities)
Args:
modality_info: Dict with the modalities and their corresponding information
text_tokenizer: Tokenizer to use for text modalities
input_tokens_range: Range of number of tokens to mask in the input
target_tokens_range: Range of number of tokens to mask in the target
            max_tries: Maximum number of tries to find a valid token budget
sampling_weights: Sampling weights for the mixture of Dirichlet distributions
"""
self.input_tokens_range = to_2tuple(input_tokens_range)
self.target_tokens_range = to_2tuple(target_tokens_range) if target_tokens_range is not None else None
self.modality_info = modality_info
self.num_modalities = len(modality_info)
self.max_tries = max_tries
self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])
# Dirichlet sampling (supports a mixture of multiple Dirichlet distributions)
eps = 1e-9
input_alphas = torch.tensor([mod["input_alphas"] for mod in modality_info.values()])
input_alphas = rearrange(input_alphas, "nmod nmix -> nmix nmod")
self.input_dirichlets = [Dirichlet(torch.clamp(input_alpha, min=eps)) for input_alpha in input_alphas]
target_alphas = torch.tensor([mod["target_alphas"] for mod in modality_info.values()])
target_alphas = rearrange(target_alphas, "nmod nmix -> nmix nmod")
self.target_dirichlets = [Dirichlet(torch.clamp(target_alpha, min=eps)) for target_alpha in target_alphas]
assert(len(self.input_dirichlets) == len(self.target_dirichlets))
self.num_dirichlets = len(self.input_dirichlets)
if sampling_weights is not None:
assert len(sampling_weights) == self.num_dirichlets
self.sampling_weights = torch.tensor(sampling_weights)
else:
self.sampling_weights = None
self.text_tokenizer = text_tokenizer
self.keep_prob_decay_factor = 0.9
self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
self.sentinel_ids = set(self.sentinel_to_id.values())
self.pad_id = text_tokenizer.token_to_id("[PAD]")
self.eos_id = text_tokenizer.token_to_id("[EOS]")
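    # Hedged sketch of the fields this class reads from each modality_info entry. The
    # real entries are defined in fourm's configuration; the modality names and numbers
    # below are illustrative assumptions only:
    #
    #   modality_info = {
    #       "rgb_tokens": {
    #           "type": "img",          # one of 'img', 'seq', 'seq_token', 'seq_emb'
    #           "min_tokens": 0,
    #           "max_tokens": 196,
    #           "input_alphas": [1.0],  # one alpha per Dirichlet mixture component
    #           "target_alphas": [1.0],
    #       },
    #       "caption": {
    #           "type": "seq",
    #           "min_tokens": 0,
    #           "max_tokens": 256,
    #           "input_alphas": [1.0],
    #           "target_alphas": [1.0],
    #           "keep": ["random"],     # optional: keep scheme per mixture component
    #       },
    #   }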
def input_token_budget(self, num_input_tokens, dir_idx=0):
"""Sample a token budget for the input
Args:
num_input_tokens: Number of tokens in the input
Returns:
Token budget for the input
"""
# Get the number of tokens for each modality
for i in range(self.max_tries):
input_token_budget = (self.input_dirichlets[dir_idx].sample() * num_input_tokens).floor().int()
diff = num_input_tokens - input_token_budget.sum()
# Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
# This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
input_token_budget += torch.bincount(self.input_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(input_token_budget))
# If token budget is over max tokens for a given modality, set it to max
input_token_budget = torch.clamp(input_token_budget, max=self.max_tokens)
if (input_token_budget >= self.min_tokens).all():
return input_token_budget.tolist()
print(f"More than max tries for input!")
return input_token_budget.tolist()
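    # Comment-only numeric illustration: with 3 modalities and num_input_tokens = 10, a
    # Dirichlet draw of [0.52, 0.31, 0.17] floors to [5, 3, 1] (9 tokens). The remaining
    # token is assigned by drawing from the same Dirichlet again and taking the argmax,
    # so modalities with near-zero alphas are very unlikely to receive it. The budget is
    # then clamped to each modality's max_tokens and re-sampled (up to max_tries) if any
    # modality ends up below its min_tokens.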
def target_token_budget(self, input_token_budget, num_target_tokens, dir_idx=0):
"""Sample a token budget for the target
Args:
input_token_budget: Token budget for the input
num_target_tokens: Number of tokens in the target
Returns:
Token budget for the target
"""
# We don't reduce the number of tokens for sequence based tasks
max_tokens_remaining = torch.where(self.mod_is_img, self.max_tokens - torch.tensor(input_token_budget), self.max_tokens)
max_tokens_remaining = torch.max(self.min_tokens, max_tokens_remaining)
for i in range(self.max_tries):
target_token_budget = (self.target_dirichlets[dir_idx].sample() * num_target_tokens).floor().int()
diff = num_target_tokens - target_token_budget.sum()
# Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
# This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
target_token_budget += torch.bincount(self.target_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(target_token_budget))
# If token budget is over max tokens for a given modality, set it to max
target_token_budget = torch.clamp(target_token_budget, max=max_tokens_remaining)
if (target_token_budget >= self.min_tokens).all():
return target_token_budget.tolist()
print(f"More than max tries for target!")
return target_token_budget.tolist()
def image_mask(self, tensor: torch.Tensor, num_tokens: int, input_budget: int, target_budget: int):
"""Applies input and target masking to an image tensor
Args:
tensor: Image tensor
num_tokens: Number of tokens in the tensor
input_budget: Token budget for the input
target_budget: Token budget for the target
Returns:
Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
"""
noise = torch.rand(num_tokens)
ids_shuffle = torch.argsort(noise, dim=0)
input_mask = torch.ones(num_tokens, dtype=torch.bool)
input_mask[:input_budget] = 0
input_mask = torch.gather(input_mask, dim=0, index=ids_shuffle)
if target_budget is None:
target_mask = ~input_mask
else:
target_mask = torch.ones(num_tokens, dtype=torch.bool)
target_mask[input_budget:input_budget + target_budget] = 0
target_mask = torch.gather(target_mask, dim=0, index=ids_shuffle)
decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
first_mask_token = torch.argmin(target_mask + torch.arange(target_mask.shape[0], device=target_mask.device) * 1e-6)
decoder_attention_mask[first_mask_token] = (~target_mask).sum() # Equiv. to target budget
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
    def sequence_token_mask(self, sequence_ids: torch.Tensor, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str, vocab_offset: int):
"""Applies input and target masking to a sequence of tokens (e.g. DINOv2 global tokens)
        The keep probability is sampled according to the given keep_scheme and does not depend on the number of tokens in the sequence.
If the keep probability results in a sequence that is too long, then it is lowered until the sequence is short enough.
Args:
sequence_ids: Sequence ids
max_tokens: Maximum number of tokens in the sequence
input_budget: Token budget for the input
target_budget: Token budget for the target
keep_scheme: Scheme for sampling the keep probability
vocab_offset: Offset to avoid overlap with sentinel tokens
Returns:
Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
"""
seq_ids = sequence_ids
        seq_ids = seq_ids + vocab_offset # Avoid overlap with sentinel tokens (needs to be subtracted after decoding)
# If input budget is 0, treat it as if the whole sequence is completely masked
if input_budget == 0:
keep_prob = 0.
input_seq_ids = []
_, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
else:
if keep_scheme == 'random':
keep_prob = sample_uniform(0, 1)
elif keep_scheme == 'all':
keep_prob = 1.0
elif keep_scheme == 'binary':
keep_prob = random.choice([0., 1.])
else:
raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
# Keep lowering the keep_prob while we are over-budget
while len(input_seq_ids) > input_budget:
keep_prob = keep_prob * self.keep_prob_decay_factor
input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
# Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
max_length = (max_tokens + 1) * 2
tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
input_mask = torch.ones(max_length, dtype=torch.bool)
target_mask = torch.ones(max_length, dtype=torch.bool)
decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
# Set input and input mask
tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
input_mask[:len(input_seq_ids)] = 0
if target_budget is None or len(target_seq_ids) <= target_budget:
tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
else:
# Randomly choose sentinel token.
sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
# If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
            # If the length starting at this sentinel is greater than the budget, truncate until the budget is reached
if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
# Otherwise, select earliest sentinel token such that we don't go over budget
# Note: We could also use the randomly chosen sentinel token, but that would waste budget
else:
for idx in sentinel_indices:
if len(target_seq_ids) - idx <= target_budget:
target_seq_ids = target_seq_ids[idx:]
break
tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
def sequence_mask(self, sequence: Union[str, List[str]], max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
"""Applies input and target masking to a sequence
        The keep probability is sampled according to the given keep_scheme and does not depend on the number of tokens in the sequence.
If the keep probability results in a sequence that is too long, then it is lowered until the sequence is short enough.
Args:
sequence: Sequence, can be either a str or list of strings
max_tokens: Maximum number of tokens in the sequence
input_budget: Token budget for the input
target_budget: Token budget for the target
keep_scheme: Scheme for sampling the keep probability
Returns:
Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
"""
if isinstance(sequence, str):
# Tokenize the sequence and get the ids
seq_ids: List[int] = self.text_tokenizer.encode(sequence).ids
# Add EOS to all sequences
seq_ids.append(self.eos_id)
# Truncate sequence
seq_ids = seq_ids[:max_tokens]
# Use default span masking
span_masking_fn = simple_span_masking
elif isinstance(sequence, list):
# Tokenize the sequence chunks and get the ids
encoded_seq_chunks = self.text_tokenizer.encode_batch(sequence)
seq_ids: List[List[int]] = [seq.ids for seq in encoded_seq_chunks]
# Add EOS as an extra chunk
seq_ids.append([self.eos_id])
# Truncate sequence to keep all chunks below max token length
cumulative_token_count = np.cumsum(np.array([len(chunk) for chunk in seq_ids]))
seq_ids = [chunk for (chunk, token_count) in zip(seq_ids, cumulative_token_count) if token_count <= max_tokens]
# Span mask over chunks
span_masking_fn = chunk_span_masking
else:
raise ValueError(f"Invalid sequence: {sequence}")
# If input budget is 0, treat it as if the whole sequence is completely masked
if input_budget == 0:
keep_prob = 0.
input_seq_ids = []
_, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
else:
if keep_scheme == 'random':
keep_prob = sample_uniform(0, 1)
elif keep_scheme == 'all':
keep_prob = 1.0
elif keep_scheme == 'binary':
keep_prob = random.choice([0., 1.])
else:
raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
# Keep lowering the keep_prob while we are over-budget
while len(input_seq_ids) > input_budget:
keep_prob = keep_prob * self.keep_prob_decay_factor
input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
# Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
max_length = (max_tokens + 1) * 2
tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
input_mask = torch.ones(max_length, dtype=torch.bool)
target_mask = torch.ones(max_length, dtype=torch.bool)
decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
# Set input and input mask
tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
input_mask[:len(input_seq_ids)] = 0
if target_budget is None or len(target_seq_ids) <= target_budget:
tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
else:
# Randomly choose sentinel token.
sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
# If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
            # If the length starting at this sentinel is greater than the budget, truncate until the budget is reached
if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
# Otherwise, select earliest sentinel token such that we don't go over budget
# Note: We could also use the randomly chosen sentinel token, but that would waste budget
else:
for idx in sentinel_indices:
if len(target_seq_ids) - idx <= target_budget:
target_seq_ids = target_seq_ids[idx:]
break
tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
def sequence_emb_mask_span(self, emb_tensor: torch.Tensor, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
"""Applies input masking to an sequence embedding tensor, target masking is not supported with sequence embeddings
Args:
emb_tensor: Sequence embedding tensor
max_tokens: Maximum number of tokens in the sequence
input_budget: Token budget for the input
target_budget: Token budget for the target (unused for now)
keep_scheme: Scheme for sampling the keep probability
Returns:
Dictionary containing the masked sequence embedding tensor, the input mask, the target mask, and the decoder attention mask
"""
# Only supported as input modality now
# Make fake seq ids for sequence embeddings to reuse simple_span_masking function
fake_seq_ids = []
emb_dict = {}
id_num = len(self.sentinel_ids)
emb_ind = 0
        while len(fake_seq_ids) < len(emb_tensor):
            if id_num not in self.sentinel_ids: # Skip ids that collide with T5 sentinel ids
fake_seq_ids.append(id_num)
emb_dict[id_num] = emb_tensor[emb_ind, :]
emb_ind += 1
id_num += 1
# Truncate sequence
fake_seq_ids = fake_seq_ids[:max_tokens]
# If input budget is 0, treat it as if the whole sequence is completely masked
if input_budget == 0:
keep_prob = 0.
fake_input_seq_ids = []
_, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
else:
if keep_scheme == 'random':
keep_prob = sample_uniform(0, 1)
elif keep_scheme == 'all':
keep_prob = 1.0
elif keep_scheme == 'binary':
keep_prob = random.choice([0., 1.])
else:
raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
# Keep lowering the keep_prob while we are over-budget
while len(fake_input_seq_ids) > input_budget:
keep_prob = keep_prob * self.keep_prob_decay_factor
fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
# Span masking can add up to max_tokens tokens for input
max_length = max_tokens
tensor = torch.zeros((max_length, emb_tensor.shape[1]), dtype=torch.float32)
input_mask = torch.ones(max_length, dtype=torch.bool)
target_mask = torch.ones(max_length, dtype=torch.bool)
decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
# Put tensor values back based on the fake seq ids
for i_, fake_id in enumerate(fake_input_seq_ids):
if fake_id in self.sentinel_ids:
                tensor[i_, :] = torch.zeros_like(emb_tensor[0,:]) # TODO: Replace with learned embeddings later
else:
tensor[i_, :] = emb_dict[fake_id]
# Set input and input mask
input_mask[:len(fake_input_seq_ids)] = 0
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
def __call__(self, mod_dict):
"""Applies input and target masking to a dictionary of modalities
Args:
mod_dict: Dictionary of modalities
Returns:
Dictionary containing the masked modalities
"""
if self.sampling_weights is not None:
# Sample masking scheme according to a list of weights
dir_idx = torch.multinomial(self.sampling_weights, 1).item()
else:
# Randomly sample masking scheme
dir_idx = random.randint(0, self.num_dirichlets - 1)
num_input_tokens = random.randint(*self.input_tokens_range)
num_target_tokens = random.randint(*self.target_tokens_range) if self.target_tokens_range is not None else None
input_token_budget = self.input_token_budget(num_input_tokens, dir_idx)
if num_target_tokens is not None:
target_token_budget = self.target_token_budget(input_token_budget, num_target_tokens, dir_idx)
else:
target_token_budget = [None] * self.num_modalities
masked_mod_dict = {}
for (mod_name, mod_info), input_budget, target_budget in zip(self.modality_info.items(), input_token_budget, target_token_budget):
mod_type = mod_info['type']
mod_name_load = mod_name if mod_name in mod_dict else get_transform_key(mod_name)
if mod_type == 'img':
masked_mod_dict[mod_name] = self.image_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget)
elif mod_type == 'seq':
keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
masked_mod_dict[mod_name] = self.sequence_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
elif mod_type == 'seq_token':
keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
vocab_offset = mod_info.get('vocab_offset', 0) # Check if any space is allocated to sentinel tokens and other special tokens
masked_mod_dict[mod_name] = self.sequence_token_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme, vocab_offset=vocab_offset)
elif mod_type == "seq_emb":
keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
masked_mod_dict[mod_name] = self.sequence_emb_mask_span(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
else:
raise ValueError(f"Invalid modality type: {mod_type}")
return masked_mod_dict
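# Minimal usage sketch for UnifiedMasking (comment only). Modality names, shapes, token
# budgets and the tokenizer path below are illustrative assumptions, not the canonical
# fourm configuration; the tokenizer is assumed to contain T5-style sentinel tokens:
#
#   text_tok = Tokenizer.from_file("text_tokenizer.json")  # hypothetical path
#   masking = UnifiedMasking(
#       modality_info=modality_info,          # dict shaped like the sketch in __init__
#       text_tokenizer=text_tok,
#       input_tokens_range=(128, 128),
#       target_tokens_range=(128, 128),
#   )
#   masked = masking({"rgb_tokens": rgb_token_tensor, "caption": "a dog on a beach"})
#   # Each masked[mod_name] is a dict with "tensor", "input_mask", "target_mask" and
#   # "decoder_attention_mask".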
class TransferMasking(object):
def __init__(self,
modality_info: Dict,
text_tokenizer: Optional[Tokenizer],
input_modalities: List[str],
target_modalities: List[str]):
"""Performs masking for transfer on a dict of modalities (both image based and sequence based modalities),
by specifying which modalities are inputs and which are targets.
Args:
modality_info: Dict with the modalities and their corresponding information
text_tokenizer: Tokenizer to use for text modalities
input_modalities: List of modalities to use as input
target_modalities: List of modalities to use as target
"""
self.modality_info = modality_info
self.num_modalities = len(modality_info)
self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])
self.input_modalities = set(input_modalities)
self.target_modalities = set(target_modalities)
# Tokenizer for text modalities
self.text_tokenizer = text_tokenizer
if self.text_tokenizer is not None:
self.keep_prob_decay_factor = 0.9
self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
self.sentinel_ids = set(self.sentinel_to_id.values())
self.pad_id = text_tokenizer.token_to_id("[PAD]")
self.eos_id = text_tokenizer.token_to_id("[EOS]")
def input_image(self, tensor: torch.Tensor, num_tokens: int):
"""Applies masking for an image given as input
Args:
tensor: Image tensor
num_tokens: Number of tokens in the tensor
Returns:
Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
"""
# Input mask
input_mask = torch.zeros(num_tokens, dtype=torch.bool)
# Target mask
target_mask = torch.ones(num_tokens, dtype=torch.bool)
# Decoder attention mask
decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
def target_image(self, tensor: torch.Tensor, num_tokens: int):
"""Applies masking for an image given as target
Args:
tensor: Image tensor
num_tokens: Number of tokens in the tensor
Returns:
Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
"""
# Input mask
input_mask = torch.ones(num_tokens, dtype=torch.bool)
# Target mask
target_mask = torch.zeros(num_tokens, dtype=torch.bool)
# Decoder attention mask
decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
decoder_attention_mask[0] = num_tokens
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
def input_sequence(self, sequence_str: str, max_tokens: int):
"""Applies masking for a sequence given as input
Args:
sequence_str: Sequence string
max_tokens: Maximum number of tokens in the sequence
Returns:
Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
"""
# Tokenize the text and get the ids
seq_ids = self.text_tokenizer.encode(sequence_str).ids
# Add EOS to all sequences
seq_ids.append(self.eos_id)
# Truncate sequence
seq_ids = seq_ids[:max_tokens]
keep_prob = 1.
input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
# Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
max_length = (max_tokens + 1) * 2
tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
input_mask = torch.ones(max_length, dtype=torch.bool)
target_mask = torch.ones(max_length, dtype=torch.bool)
decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
# Set input and input mask
tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
input_mask[:len(input_seq_ids)] = 0
tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
def target_sequence(self, sequence_str: str, max_tokens: int):
"""Applies masking for a sequence given as target
Args:
sequence_str: Sequence string
max_tokens: Maximum number of tokens in the sequence
Returns:
Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
"""
# Tokenize the text and get the ids
seq_ids = self.text_tokenizer.encode(sequence_str).ids
# Add EOS to all sequences
seq_ids.append(self.eos_id)
# Truncate sequence
seq_ids = seq_ids[:max_tokens]
keep_prob = 0.
input_seq_ids = []
_, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
# Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
max_length = (max_tokens + 1) * 2
tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
input_mask = torch.ones(max_length, dtype=torch.bool)
target_mask = torch.ones(max_length, dtype=torch.bool)
decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
# Set input and input mask
tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
input_mask[:len(input_seq_ids)] = 0
tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1
return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask,
"decoder_attention_mask": decoder_attention_mask}
def __call__(self, mod_dict):
"""Applies input and target masking to a dictionary of modalities
Args:
mod_dict: Dictionary of modalities
Returns:
Dictionary containing the masked modalities
"""
masked_mod_dict = {}
for mod_name, mod_info in self.modality_info.items():
mod_type = mod_info['type']
if mod_type == 'img' and mod_name in self.input_modalities:
masked_mod_dict[mod_name] = self.input_image(mod_dict[mod_name], mod_info['max_tokens'])
elif mod_type == 'img' and mod_name in self.target_modalities:
masked_mod_dict[mod_name] = self.target_image(mod_dict[mod_name], mod_info['max_tokens'])
elif mod_type == 'seq' and mod_name in self.input_modalities:
masked_mod_dict[mod_name] = self.input_sequence(mod_dict[mod_name], mod_info['max_tokens'])
elif mod_type == 'seq' and mod_name in self.target_modalities:
masked_mod_dict[mod_name] = self.target_sequence(mod_dict[mod_name], mod_info['max_tokens'])
else:
raise ValueError(f"Invalid modality type: {mod_type} or modality name not in input or target modalities: {mod_name}")
if 'mask_valid' in mod_dict:
masked_mod_dict['mask_valid'] = mod_dict['mask_valid']
return masked_mod_dict
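# Minimal usage sketch for TransferMasking (comment only; modality names are illustrative
# assumptions). With the image declared as input and the caption as target, the image is
# fully visible to the encoder and the caption is placed entirely on the target side:
#
#   transfer_masking = TransferMasking(
#       modality_info=modality_info,
#       text_tokenizer=text_tok,
#       input_modalities=["rgb_tokens"],
#       target_modalities=["caption"],
#   )
#   masked = transfer_masking({"rgb_tokens": rgb_token_tensor, "caption": "a dog on a beach"})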