# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import defaultdict
from typing import Union, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from tqdm import tqdm

from fourm.utils import get_sentinel_to_id_mapping, merge_span_masking
from fourm.utils.generation import cosine_schedule, linear_schedule, onex_temp_schedule, linear_temp_schedule, continue_schedule
def empty_img_modality(mod_dict, key): | |
# Input mask | |
mod_dict[key]['input_mask'][:] = True | |
# Target Mask | |
mod_dict[key]['target_mask'][:] = False | |
return mod_dict | |
def empty_seq_modality(mod_dict, key, s1_id=5): | |
    # To create an empty sequence, we assume an input budget of 1 token and assign the rest to the target
# Input tensor | |
# Input is [S_1], target is [S_1] ...... [S_2] | |
# (so [S_1] [S_1] ..... [S_2] when combined) | |
mod_dict[key]['tensor'][:] = 0 | |
mod_dict[key]['tensor'][:,[0,1]] = s1_id # s1_id is id of the first sentinel token ([S_1]) | |
mod_dict[key]['tensor'][:,-1] = s1_id + 1 | |
# Input mask | |
# Set first token to input (i.e. 0), rest to target (i.e. 1) | |
mod_dict[key]['input_mask'][:] = True | |
mod_dict[key]['input_mask'][:,0] = False | |
# Target Mask | |
mod_dict[key]['target_mask'] = ~mod_dict[key]['input_mask'] | |
# Decoder attn mask | |
# WARNING: Not needed / used in GenerationSampler, where causal mask is enforced | |
# First token is input, not part of target | |
mod_dict[key]['decoder_attention_mask'][:] = 1 | |
mod_dict[key]['decoder_attention_mask'][:, 0] = 0 | |
return mod_dict | |
def empty_seq_emb_modality(mod_dict, key): | |
# Tensor | |
mod_dict[key]['tensor'] = torch.zeros_like(mod_dict[key]['tensor']) | |
# Input mask | |
mod_dict[key]['input_mask'] = torch.ones_like(mod_dict[key]['input_mask']) | |
    # It is crucial to set the input mask like this; classifier-free guidance (CFG) will not work otherwise!
mod_dict[key]['input_mask'][:, 0] = False | |
# Target Mask | |
mod_dict[key]['target_mask'] = torch.ones_like(mod_dict[key]['target_mask']) | |
# Decoder attn mask | |
mod_dict[key]['decoder_attention_mask'][:] = False | |
return mod_dict | |
def init_empty_target_modality(mod_dict, modality_info, domain, batch_size, num_tokens, device): | |
""" | |
Initializes an empty target modality dictionary for a given domain. | |
Used to initialize target modality dictionaries for generation. | |
""" | |
if modality_info[domain]['type'] == 'img': | |
# Initialize mod dict | |
mod_dict[domain] = { | |
'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int64, device=device), | |
'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device), | |
'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device), | |
} | |
# Set it to the correct values | |
mod_dict = empty_img_modality(mod_dict, domain) | |
elif modality_info[domain]['type'] in ['seq', 'seq_token', 'seq_emb']: | |
# Initialize mod dict | |
num_tokens = max(num_tokens, 2) | |
mod_dict[domain] = { | |
'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int32, device=device), | |
'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device), | |
'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device), | |
'decoder_attention_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device), | |
} | |
# Set it to the correct values | |
if modality_info[domain]['type'] in ['seq', 'seq_token']: | |
mod_dict = empty_seq_modality(mod_dict, domain) | |
elif modality_info[domain]['type'] == 'seq_emb': | |
mod_dict = empty_seq_emb_modality(mod_dict, domain) | |
else: | |
        raise ValueError(f"Unsupported modality type: {modality_info[domain]['type']}")
return mod_dict | |
def init_full_input_modality(mod_dict, modality_info, domain, device, eos_id=3): | |
if domain.startswith('rgb'): | |
batch_size, _, H, W = mod_dict[domain]['tensor'].shape | |
patch_size = modality_info[domain]['patch_size'] | |
num_tokens = (H // patch_size) * (W // patch_size) | |
shape = (batch_size, num_tokens) | |
else: | |
shape = mod_dict[domain]['tensor'].shape | |
if 'input_mask' not in mod_dict[domain]: | |
mod_dict[domain]['input_mask'] = torch.zeros(shape, dtype=torch.bool, device=device) | |
if 'target_mask' not in mod_dict[domain]: | |
mod_dict[domain]['target_mask'] = torch.ones(shape, dtype=torch.bool, device=device) | |
if 'decoder_attention_mask' not in mod_dict[domain]: | |
mod_dict[domain]['decoder_attention_mask'] = torch.zeros(shape, dtype=torch.bool, device=device) | |
if modality_info[domain]['type'] == 'img': | |
mod_dict[domain]['input_mask'][:] = False | |
mod_dict[domain]['target_mask'][:] = True | |
elif modality_info[domain]['type'] in ['seq', 'seq_token']: | |
if eos_id in mod_dict[domain]['tensor']: | |
eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item() | |
else: | |
mod_dict[domain]['tensor'][:,0] = eos_id | |
eos_idx = 0 | |
mod_dict[domain]['input_mask'][:,:eos_idx+1] = False | |
mod_dict[domain]['input_mask'][:,eos_idx+1:] = True | |
mod_dict[domain]['target_mask'][:] = True | |
elif modality_info[domain]['type'] in ['seq_emb']: | |
# T5 caption has the valid mask saved alongside the embeddings | |
mod_dict[domain]['input_mask'] = ~mod_dict[domain]['mask_valid'] | |
mod_dict[domain]['target_mask'] = torch.ones_like(mod_dict[domain]['mask_valid']) | |
mod_dict[domain]['decoder_attention_mask'] = torch.zeros_like(mod_dict[domain]['mask_valid']) | |
return mod_dict | |
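
# Example (sketch): assembling a mod_dict for generation from one fully visible conditioning
# modality and one empty target modality. The domain names ('rgb@224', 'caption') and the
# modality_info layout (keys such as 'type', 'patch_size', 'max_tokens') are illustrative
# assumptions following fourm conventions; adapt them to the loaded model.
def _example_prepare_generation_inputs(rgb_image, modality_info, device):
    # Raw RGB image of shape (B, 3, H, W) acts as the conditioning input
    mod_dict = {'rgb@224': {'tensor': rgb_image.to(device)}}
    # Mark the RGB image as a fully visible input (nothing to predict for it)
    mod_dict = init_full_input_modality(mod_dict, modality_info, 'rgb@224', device)
    # Allocate an empty caption target that the sampler will fill in later
    mod_dict = init_empty_target_modality(
        mod_dict, modality_info, 'caption',
        batch_size=rgb_image.shape[0],
        num_tokens=modality_info['caption']['max_tokens'],
        device=device,
    )
    return mod_dict
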
def custom_text(sample, input_text, eos_token, key, device, text_tokenizer, target_max_len=50, start_token="[S_1]"): | |
input_ids = text_tokenizer.encode(input_text).ids | |
input_ids = torch.tensor(input_ids).unsqueeze(0) | |
target_text = [start_token] | |
target_text.extend(["[PAD]"] * (target_max_len - 2)) | |
target_text.append(eos_token) | |
target_text = " ".join(target_text) | |
target_ids = text_tokenizer.encode(target_text).ids | |
target_ids = torch.tensor(target_ids).unsqueeze(0) | |
all_ids = torch.cat([input_ids, target_ids], dim=1) | |
input_mask = torch.cat([ | |
torch.zeros_like(input_ids, dtype=torch.bool), | |
torch.ones_like(target_ids, dtype=torch.bool), | |
], dim=1) | |
target_mask = torch.cat([ | |
torch.ones_like(input_ids, dtype=torch.bool), | |
torch.zeros_like(target_ids, dtype=torch.bool), | |
], dim=1) | |
sample[key] = {} | |
sample[key]['tensor'] = all_ids.to(device) | |
sample[key]['input_mask'] = input_mask.to(device) | |
sample[key]['target_mask'] = target_mask.to(device) | |
sample[key]['decoder_attention_mask'] = torch.zeros(all_ids.shape, dtype=torch.bool, device=device) | |
return sample | |
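
# Example (sketch): turning a text prompt into a conditioning 'caption' entry with custom_text.
# The key name, the '[EOS]' token, and the prompt formatting are assumptions that must match
# the text tokenizer and modality names of the loaded model.
def _example_caption_condition(text_tokenizer, device):
    sample = {}
    sample = custom_text(sample, input_text='a photo of a dog', eos_token='[EOS]',
                         key='caption', device=device, text_tokenizer=text_tokenizer)
    return sample
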
def expand_to_batch(mod_dict, batch_size): | |
for mod, d in mod_dict.items(): | |
for k, v in d.items(): | |
if k in ['tensor', 'input_mask', 'target_mask', 'decoder_attention_mask', 'mask_valid']: | |
B = v.shape[0] | |
if B == 1: | |
mod_dict[mod][k] = repeat(v, "1 ... -> b ...", b=batch_size) | |
elif B != batch_size: | |
raise ValueError(f"Invalid batch size: {B} instead of {batch_size}") | |
return mod_dict | |
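
# Example (sketch): `expand_to_batch(copy.deepcopy(mod_dict), batch_size=8)` replicates a
# single-sample mod_dict so that 8 generations can be sampled in parallel.
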
def build_chained_generation_schedules( | |
cond_domains: List[str], | |
target_domains: List[str], | |
tokens_per_target: List[int], | |
autoregression_schemes: List[str], | |
decoding_steps: List[int], | |
token_decoding_schedules: List[str], | |
temps: List[float], | |
temp_schedules: List[float], | |
cfg_scales: List[float], | |
cfg_schedules: List[str], | |
cfg_grow_conditioning: bool = False, | |
modality_info: Optional[dict] = None, | |
): | |
""" | |
    Builds a flat list of chained generation schedule steps, where each step is a dict of the form:
    {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}
Args: | |
cond_domains: List of conditioning domains | |
target_domains: List of target domains | |
tokens_per_target: List of number of tokens to decode for each target domain | |
autoregression_schemes: List of autoregression schemes for each target domain. maskgit, roar, or autoregressive | |
decoding_steps: List of number of maskgit steps for each target domain (if applicable) | |
token_decoding_schedules: List of maskgit token schedules for each target domain (if applicable). cosine or linear | |
temps: List of starting temperatures for each target domain | |
temp_schedules: List of temperature schedules for each target domain. linear, constant, or onex:{min_t}:{power} | |
cfg_scales: List of classifier-free guidance scales for each target domain | |
cfg_schedules: List of classifier-free guidance schedules for each target domain. constant or cosine | |
cfg_grow_conditioning: After every completed modality, add them to classifier-free guidance conditioning | |
modality_info: Dictionary with metadata for each modality, optionally used to verify that the schedule is compatible with the modality | |
""" | |
    # List of {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains} dicts
chained_schedules = [] | |
cond_domains = cond_domains.copy() | |
for target_idx in range(len(target_domains)): | |
scheme = autoregression_schemes[target_idx] | |
target_domain = target_domains[target_idx] | |
ntoks = tokens_per_target[target_idx] | |
maskgit_token_schedule_name = token_decoding_schedules[target_idx] | |
temp = temps[target_idx] | |
temp_schedule_name = temp_schedules[target_idx] | |
cfg_scale = cfg_scales[target_idx] | |
cfg_schedule_name = cfg_schedules[target_idx] | |
# Auto-regressive (caption, detection, ...) | |
if scheme == 'autoregressive': | |
chained_schedules.append({ | |
'target_domain': target_domain, | |
'scheme': scheme, | |
'num_tokens': None, | |
'temperature': temp, | |
'cfg_scale': cfg_scale, | |
'cfg_cond_domains': cond_domains.copy() | |
}) | |
continue | |
# Use modality info for (optional) assert if provided | |
if modality_info is not None: | |
            assert modality_info[target_domain]['type'] not in ['seq', 'seq_token'], f'Non-autoregressive scheme {scheme} is not supported for sequence-type target domain {target_domain}'
# Token schedule | |
if scheme == 'maskgit': | |
# MaskGIT token schedule setup | |
num_steps = decoding_steps[target_idx] | |
if maskgit_token_schedule_name == 'cosine': | |
token_schedule = cosine_schedule(num_steps, (ntoks)) | |
elif maskgit_token_schedule_name == 'linear': | |
token_schedule = linear_schedule(num_steps, (ntoks)) | |
else: | |
raise ValueError(f'Illegal MaskGIT token schedule {maskgit_token_schedule_name}') | |
elif scheme == 'roar': | |
# ROAR token schedule setup (one-by-one, but random order) | |
num_steps = decoding_steps[target_idx] | |
token_schedule = linear_schedule(num_steps, ntoks) | |
else: | |
raise ValueError(f'Illegal decoding scheme {scheme}') | |
# Temperature schedule | |
if temp_schedule_name == 'linear': | |
temp_schedule = linear_temp_schedule(temp, token_schedule) | |
elif temp_schedule_name == 'constant': | |
temp_schedule = temp * np.ones(num_steps) | |
elif 'onex' in temp_schedule_name: | |
# onex temperature schedule has to be formatted like onex:{min_t}:{power} | |
min_t, power = [float(f) for f in temp_schedule_name.split(':')[1:]] | |
temp_schedule = onex_temp_schedule(max_t=temp, min_t=min_t, token_schedule=token_schedule, power=power) | |
else: | |
raise ValueError(f'Illegal temperature schedule {temp_schedule_name}') | |
# Classifier-free guidance scale schedule | |
if cfg_schedule_name == 'constant': | |
if isinstance(cfg_scale, float): | |
cfg_schedule = cfg_scale * np.ones(num_steps) | |
elif isinstance(cfg_scale, list): | |
cfg_schedule = np.array(cfg_scale) * np.ones(num_steps).reshape(-1, 1) | |
elif cfg_schedule_name == 'cosine': | |
raise NotImplementedError() | |
else: | |
raise ValueError(f'Illegal guidance schedule {cfg_schedule_name}') | |
# Concatenate schedule for this modality with previous ones | |
schedule = [ | |
{ | |
'target_domain': target_domain, | |
'scheme': scheme, | |
'num_tokens': tok, | |
'temperature': temp, | |
'cfg_scale': cfg, | |
'cfg_cond_domains': cond_domains.copy() | |
} | |
for tok, temp, cfg in zip(token_schedule, temp_schedule, cfg_schedule) | |
] | |
chained_schedules.extend(schedule) | |
# Optionally add this new modality to the ones affected by classifier-free guidance | |
if cfg_grow_conditioning: | |
cond_domains.append(target_domain) | |
return chained_schedules | |
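
# Example (sketch): a two-stage chain that first decodes RGB tokens with MaskGIT and then
# depth tokens with ROAR, conditioning the second stage on the first via cfg_grow_conditioning.
# Domain names, token counts, and hyper-parameters below are illustrative assumptions.
def _example_chained_schedule():
    return build_chained_generation_schedules(
        cond_domains=['caption'],
        target_domains=['tok_rgb@224', 'tok_depth@224'],
        tokens_per_target=[196, 196],
        autoregression_schemes=['maskgit', 'roar'],
        decoding_steps=[25, 25],
        token_decoding_schedules=['cosine', 'linear'],
        temps=[3.0, 1.0],
        temp_schedules=['linear', 'constant'],
        cfg_scales=[3.0, 2.0],
        cfg_schedules=['constant', 'constant'],
        cfg_grow_conditioning=True,
    )
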
class GenerationSampler(nn.Module): | |
"""Sampler that wraps a trained 4M model for generation use cases. | |
Implements standard autoregressive, MaskGIT, and ROAR generation schemes with chaining and weighted guidance.""" | |
def __init__(self, model): | |
super().__init__() | |
self.model = model | |
def top_k_top_p_filtering(self, logits, top_k=0.0, top_p=0.0): | |
# Compatible with batching | |
# From https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 | |
if top_k > 0.0: | |
if isinstance(top_k, int): | |
k = min(top_k, logits.shape[-1]) | |
elif isinstance(top_k, float): | |
k = min(int(top_k * logits.shape[-1]), logits.shape[-1]) | |
else: | |
raise ValueError(f"Invalid value for top_k: {top_k}") | |
# Remove all tokens with a probability less than the last token of the top-k | |
indices_to_remove = logits < torch.topk(logits, k)[0][..., -1, None] | |
logits[indices_to_remove] = float("-inf") | |
if top_p > 0.0: | |
sorted_logits, sorted_indices = torch.sort(logits, dim=1, descending=True) | |
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |
sorted_indices_to_remove = cum_probs > top_p | |
# Shift the indices to the right to keep also the first token above the threshold | |
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |
sorted_indices_to_remove[..., 0] = 0 | |
restore_indices = torch.argsort(sorted_indices, dim=-1) | |
indices_to_remove = torch.gather(sorted_indices_to_remove, dim=-1, index=restore_indices) | |
logits[indices_to_remove] = float("-inf") | |
return logits | |
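    # Worked example (assumed numbers): for logits [2.0, 1.0, 0.5, -1.0], top_k=2 sets the last
    # two entries to -inf; top_p=0.9 keeps tokens in descending probability order until their
    # cumulative softmax mass first exceeds 0.9 and masks the rest, so sampling only draws
    # from the surviving logits.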
def sample_tokens(self, logits, temperature=1.0, top_k=0.0, top_p=0.0): | |
if np.isclose(temperature, 0, atol=1e-10): | |
samples = torch.argmax(logits, dim=-1) | |
# Since argmax is used, all sampled_probs will be 1 as we're selecting the max probability | |
sampled_probs = torch.ones_like(samples, dtype=torch.float32) | |
else: | |
filtered_logits = self.top_k_top_p_filtering(logits, top_k, top_p) | |
probs = F.softmax(filtered_logits / temperature, dim=-1) | |
samples = torch.multinomial(probs, 1)[:, 0] | |
sampled_probs = probs[torch.arange(len(samples)), samples] | |
return samples, sampled_probs | |
def sample_tokens_batched(self, logits, temperature=1.0, top_k=0.0, top_p=0.0): | |
if logits.ndim > 2: | |
B, N = logits.shape[0], logits.shape[1] | |
logits = rearrange(logits, 'b n v -> (b n) v') | |
samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p) | |
samples = rearrange(samples, '(b n) -> b n', b=B, n=N) | |
sampled_probs = rearrange(sampled_probs, '(b n) -> b n', b=B, n=N) | |
return samples, sampled_probs | |
else: | |
return self.sample_tokens(logits, temperature, top_k, top_p) | |
def select_tokens(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False): | |
samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p) | |
top_indices = torch.topk(sampled_probs, num_select)[1] | |
top_samples = samples[top_indices] | |
if return_all_samples: | |
return top_samples, top_indices, samples | |
else: | |
return top_samples, top_indices | |
def select_tokens_batched(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False): | |
if logits.ndim > 2: | |
samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k, top_p) # both of shape (B, N) | |
top_indices = torch.topk(sampled_probs, num_select, dim=-1)[1] | |
# Need to switch to gather instead of indexing here | |
top_samples = torch.gather(samples, dim=-1, index=top_indices) | |
if return_all_samples: | |
return top_samples, top_indices, samples | |
else: | |
return top_samples, top_indices | |
else: | |
            return self.select_tokens(logits, num_select, temperature, top_k, top_p, return_all_samples)
def forward_mask_encoder_generation(self, encoder_mod_dict): | |
"""Modification of forward_mask_encoder adapted for generation, with support for batching | |
""" | |
# Form input | |
B = list(encoder_mod_dict.values())[0]['tensor'].shape[0] | |
encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.model.cat_encoder_tensors(encoder_mod_dict) | |
        # Take the max number of encoder tokens across the batch (although assuming it's the same everywhere would be better)
num_encoder_tokens = (~encoder_mask_all.reshape(B, -1)).sum(dim=1).max() | |
# Add arange multiplied by small constant to mask so they get sorted in a deterministic way | |
mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6 | |
ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1) | |
# ids_restore = torch.argsort(ids_shuffle, dim=1) | |
ids_keep = ids_shuffle[:, :num_encoder_tokens] | |
encoder_tokens = torch.gather(encoder_tokens_all, dim=1, | |
index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2])) | |
encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2])) | |
encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep) | |
mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep) | |
if self.model.num_register_tokens > 0: | |
prompt_tokens = repeat(self.prompt_tokens, '() n d -> b n d', b=B) | |
# We add prompt tokens at the beginning of the sequence | |
encoder_tokens = torch.cat([prompt_tokens, encoder_tokens], dim=1) | |
encoder_emb = torch.cat([torch.zeros_like(prompt_tokens), encoder_emb], dim=1) | |
encoder_mask = torch.cat([torch.zeros((B, prompt_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1) | |
mod_mask = torch.cat([torch.full((B, prompt_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1) | |
encoder_tokens[encoder_mask] = 0. | |
encoder_emb[encoder_mask] = 0. | |
mod_mask[encoder_mask] = -1 | |
# Mask could be of shape 'b n1 n2' but not needed for masked_fill | |
# This means this mask can then be re-used for decoder cross-attention | |
encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2') | |
return encoder_tokens, encoder_emb, encoder_mask, mod_mask | |
def forward_mask_decoder_maskgit(self, mod_dict, target_mod, seed=None): | |
"""Modification of forward_mask_decoder for MaskGIT generation, with support for batching | |
""" | |
if seed is not None: | |
torch.manual_seed(seed) | |
d = mod_dict[target_mod] | |
decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token | |
emb_all = d['emb'] | |
decoder_mask_all = d['target_mask'] | |
B = decoder_tokens_all.shape[0] # Get batch size | |
mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16) | |
mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0) | |
mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B) # Added: Expansion for batching | |
num_decoder_tokens = (~decoder_mask_all[0]).sum() # Adapted for batching / Assumes num_decoder_tokens is the same across the batch | |
# Add arange multiplied by small constant to mask so they get sorted in a deterministic way | |
mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6 | |
ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1) | |
# ids_restore = torch.argsort(ids_shuffle, dim=1) | |
ids_keep = ids_shuffle[:, :num_decoder_tokens] | |
decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2])) | |
decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2])) | |
decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep) | |
mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep) | |
mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep) | |
decoder_tokens[decoder_mask] = 0. | |
decoder_emb[decoder_mask] = 0. | |
mod_mask[decoder_mask] = -1 | |
return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos | |
def forward_mask_decoder_roar(self, mod_dict, target_mod, num_select, seed=None): | |
"""Modification of forward_mask_decoder for ROAR generation, with support for batching | |
""" | |
if seed is not None: | |
torch.manual_seed(seed) | |
d = mod_dict[target_mod] | |
decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token | |
emb_all = d['emb'] | |
decoder_mask_all = d['target_mask'] | |
B = decoder_tokens_all.shape[0] # Get batch size | |
mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16) | |
mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0) | |
mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B) # Added: Expansion for batching | |
# Only keep the first num_select tokens | |
num_decoder_tokens = min(num_select, (~decoder_mask_all[0]).sum()) # Adapted for batching / Assumes num_decoder_tokens is the same across the batch | |
# Add a small random number to the mask so they get sorted in a random way, but keeping the masked tokens first | |
mask_rand = torch.rand(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6 | |
ids_shuffle = torch.argsort(decoder_mask_all + mask_rand, dim=1) | |
# ids_restore = torch.argsort(ids_shuffle, dim=1) | |
# Only keep the first num_select_tokens | |
ids_keep = ids_shuffle[:, :num_decoder_tokens] | |
decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2])) | |
decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2])) | |
decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep) | |
mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep) | |
mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep) | |
decoder_tokens[decoder_mask] = 0. | |
decoder_emb[decoder_mask] = 0. | |
mod_mask[decoder_mask] = -1 | |
return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos | |
def forward_mask_decoder_autoregressive(self, mod_dict, target_mod, seed=None): | |
# Adapted for batching | |
if seed is not None: | |
torch.manual_seed(seed) | |
# This is the concatenation part | |
d = mod_dict[target_mod] | |
decoder_ids_all = d['ids'] | |
emb_all = d['emb'] | |
decoder_mask_all = d['target_mask'] | |
B = decoder_ids_all.shape[0] # Get batch size | |
mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16) | |
mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0) | |
mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B) | |
num_decoder_tokens = (~decoder_mask_all[0]).sum() # Adapted for batching, but assumes num_decoder_tokens is the same across the batch | |
# Add arange multiplied by small constant to mask so they get sorted in a deterministic way | |
mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6 | |
ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1) | |
# ids_restore = torch.argsort(ids_shuffle, dim=1) | |
ids_keep = ids_shuffle[:, :num_decoder_tokens] | |
# Same as in forward_mask_decoder | |
decoder_ids = torch.gather(decoder_ids_all, dim=1, index=ids_keep) | |
decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2])) | |
decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep) | |
mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep) | |
mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep) | |
decoder_ids[decoder_mask] = 0 | |
decoder_emb[decoder_mask] = 0. | |
mod_mask[decoder_mask] = -1 | |
return decoder_ids, decoder_emb, decoder_mask, mod_mask, mod_pos | |
def merge_sequences(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"): | |
device = mod_dict[target_mod]['tensor'].device | |
# Get input ids | |
input_ids = mod_dict[target_mod]['tensor'].squeeze().detach().cpu() | |
input_ids = input_ids[mod_dict[target_mod]['input_mask'].squeeze().detach().cpu() == 0] | |
input_ids = input_ids.tolist() | |
if len(input_ids) == 0: | |
input_ids = [text_tokenizer.get_vocab()[default_sentinel]] | |
# Get predicted ids | |
pred_ids = pred_ids.squeeze().detach().cpu().tolist() | |
if isinstance(pred_ids, int): | |
pred_ids = [pred_ids] | |
# Get sentinel ids using the tokenizer | |
sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values()) | |
# Perform merging | |
merged_ids = merge_span_masking(input_ids, pred_ids, sentinel_ids) | |
merged_ids = torch.tensor(merged_ids).unsqueeze(0) | |
# Create new dict | |
new_input_mask = torch.zeros_like(merged_ids, dtype=torch.bool) | |
new_target_mask = torch.ones_like(merged_ids, dtype=torch.bool) | |
new_dict = {'tensor': merged_ids.to(device), | |
'input_mask': new_input_mask.to(device), | |
'target_mask': new_target_mask.to(device)} | |
new_dict['decoder_attention_mask'] = torch.zeros_like(new_target_mask, dtype=torch.bool) | |
mod_dict[target_mod] = new_dict | |
return mod_dict | |
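    # Illustration (assumed token strings): with the visible input "[S_1] sat on [S_2]" and the
    # prediction "[S_1] the cat [S_2] the mat", merge_span_masking splices each predicted span
    # back into its sentinel slot, recovering "the cat sat on the mat".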
def merge_sequences_batched(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"): | |
        # Unbatches, calls merge_sequences on each sample, then regroups the results into a batch
pad_id = text_tokenizer.token_to_id("[PAD]") | |
B = mod_dict[target_mod]['tensor'].shape[0] | |
device = mod_dict[target_mod]['tensor'].device | |
tensors = torch.split(mod_dict[target_mod]['tensor'], 1) | |
input_masks = torch.split(mod_dict[target_mod]['input_mask'], 1) | |
pred_ids = torch.split(pred_ids, 1) | |
input_dicts = [] | |
for t, im in zip(tensors, input_masks): | |
d = {target_mod: {'tensor': t, 'input_mask': im}} | |
input_dicts.append(d) | |
merged_tensors = [] | |
merged_input_masks = [] | |
merged_target_masks = [] | |
merged_seq_lens = [] | |
for input_d, pi in zip(input_dicts, pred_ids): | |
# Output of merge_sequences is mod_dict with modified target mod | |
merged_d = self.merge_sequences(input_d, pi, target_mod, text_tokenizer, default_sentinel)[target_mod] | |
merged_tensors.append(merged_d['tensor']) | |
merged_input_masks.append(merged_d['input_mask']) | |
            merged_target_masks.append(merged_d['target_mask'])
merged_seq_lens.append(merged_d['tensor'].shape[1]) | |
max_seq_len = max(merged_seq_lens) | |
for i in range(len(merged_tensors)): | |
# Right pad all tensors | |
p1d = (0, max_seq_len - merged_seq_lens[i]) | |
            merged_tensors[i] = F.pad(merged_tensors[i], p1d, "constant", pad_id)
merged_input_masks[i] = F.pad(merged_input_masks[i], p1d, "constant", True) | |
merged_target_masks[i] = F.pad(merged_target_masks[i], p1d, "constant", True) | |
new_dict = {'tensor': torch.cat(merged_tensors, dim=0).to(device), | |
'input_mask': torch.cat(merged_input_masks, dim=0).to(device), | |
'target_mask': torch.cat(merged_target_masks, dim=0).to(device)} | |
new_dict['decoder_attention_mask'] = torch.zeros_like(new_dict['target_mask'], dtype=torch.bool) | |
mod_dict[target_mod] = new_dict | |
return mod_dict | |
def forward_enc_dec_maskgit_batched(self, mod_dict, target_mod, seed=None): | |
# Encoder | |
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) | |
for mod, d in mod_dict.items() | |
if mod in self.model.encoder_embeddings} | |
encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) | |
x = encoder_tokens + encoder_emb | |
x = self.model.forward_encoder(x, encoder_mask) | |
# Decoder | |
context = self.model.decoder_proj_context(x) + encoder_emb | |
decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} | |
decoder_tokens, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_maskgit(decoder_mod_dict, target_mod, seed=seed) | |
y = decoder_tokens + decoder_emb | |
y = self.model.forward_decoder(y, context, encoder_mask, None) | |
B, N, D = y.shape | |
logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod] | |
logits = logits.reshape(B, N, -1) | |
return logits, mod_pos | |
def maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None): | |
logits, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed) | |
# MaskGIT sampling | |
top_samples, top_indices = self.select_tokens_batched(logits, num_select, | |
temperature=temperature, top_k=top_k, top_p=top_p) | |
# Update mod dict | |
# We rely on gather / scatter for batched operations | |
top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select) | |
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples) | |
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool)) | |
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool)) | |
return mod_dict | |
def guided_maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, | |
conditioning=[], guidance_scale=1.0, seed=None, write_all_predictions=False): | |
### 1 - First pass, with conditioning | |
logits_cond, _ = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed) | |
### 2 - Second pass, without conditioning | |
mod_dict_uncond = copy.deepcopy(mod_dict) | |
for mod in conditioning: | |
if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']: | |
mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod) | |
elif self.model.modality_info[mod]['type'] in ['seq_emb']: | |
mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod) | |
else: | |
mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod) | |
logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict_uncond, target_mod, seed=seed) | |
### 3 - Classifier-free guidance | |
logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale | |
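        # guidance_scale = 1.0 recovers the purely conditional logits, 0.0 the unconditional ones,
        # and values > 1.0 extrapolate past the conditional prediction (standard classifier-free
        # guidance, see https://arxiv.org/abs/2207.12598).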
### 4 - MaskGIT sampling | |
top_samples, top_indices, all_samples = self.select_tokens_batched( | |
logits, num_select, | |
temperature=temperature, top_k=top_k, top_p=top_p, | |
return_all_samples=True | |
) | |
### 5 - Update mod dict | |
# We rely on gather / scatter for batched operations | |
top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select) | |
if write_all_predictions: | |
mod_dict[target_mod]['tensor'][:, mod_pos] = all_samples | |
else: | |
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples) | |
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool)) | |
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool)) | |
return mod_dict | |
def multi_guided_maskgit_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod, num_select, | |
temperature, top_k, top_p, seed=None, write_all_predictions=False): | |
### 1 - Conditional forward passes (one for each guided condition) | |
logits_cond_all = [] | |
for cond_dict in cond_dicts: | |
logits_cond_i, _ = self.forward_enc_dec_maskgit_batched(cond_dict, target_mod, seed=seed) | |
logits_cond_all.append(logits_cond_i) | |
### 2 - Unconditional forward pass | |
logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(uncond_dict, target_mod, seed=seed) | |
### 3 Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)} | |
# See https://arxiv.org/abs/2206.01714 | |
logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0) | |
### 4 - MaskGIT sampling | |
top_samples, top_indices, all_samples = self.select_tokens_batched( | |
logits, num_select, | |
temperature=temperature, top_k=top_k, top_p=top_p, | |
return_all_samples=True | |
) | |
### 5 - Update mod dict with newly generated tokens | |
# We rely on gather / scatter for batched operations | |
top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select) | |
if write_all_predictions: | |
uncond_dict[target_mod]['tensor'][:, mod_pos] = all_samples | |
else: | |
uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, top_pos, top_samples) | |
uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool)) | |
uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool)) | |
# Update conditioning dicts | |
for i in range(len(cond_dicts)): | |
cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, top_pos, top_samples) | |
cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool)) | |
cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool)) | |
return uncond_dict, cond_dicts | |
def forward_enc_dec_roar_batched(self, mod_dict, target_mod, num_select, seed=None): | |
# Encoder | |
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) | |
for mod, d in mod_dict.items() | |
if mod in self.model.encoder_embeddings} | |
encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) | |
x = encoder_tokens + encoder_emb | |
x = self.model.forward_encoder(x, encoder_mask) | |
# Decoder | |
context = self.model.decoder_proj_context(x) + encoder_emb | |
decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} | |
decoder_tokens, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_roar(decoder_mod_dict, target_mod, num_select, seed=seed) | |
y = decoder_tokens + decoder_emb | |
y = self.model.forward_decoder(y, context, encoder_mask, None) | |
B, N, D = y.shape | |
logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod] | |
logits = logits.reshape(B, N, -1) | |
return logits, mod_pos | |
def roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None): | |
"""ROAR = Random Order Autoregression""" | |
logits, mod_pos = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed) | |
# Simple sampling | |
samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p) | |
# Update mod dict | |
# We rely on scatter for batched operations | |
select_pos = mod_pos | |
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples) | |
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool)) | |
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool)) | |
return mod_dict | |
def guided_roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, | |
conditioning=[], guidance_scale=1.0, seed=None): | |
"""ROAR = Random Order Autoregression""" | |
### 1 - First pass, with conditioning | |
logits_cond, _ = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed) | |
### 2 - Second pass, without conditioning | |
mod_dict_uncond = copy.deepcopy(mod_dict) | |
for mod in conditioning: | |
if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']: | |
mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod) | |
elif self.model.modality_info[mod]['type'] in ['seq_emb']: | |
mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod) | |
else: | |
mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod) | |
logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(mod_dict_uncond, target_mod, num_select, seed=seed) | |
### 3 - Classifier-free guidance | |
logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale | |
### 4 - Simple sampling | |
samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p) | |
### 5 - Update mod dict | |
# We rely on gather / scatter for batched operations | |
select_pos = mod_pos | |
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples) | |
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool)) | |
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool)) | |
return mod_dict | |
def multi_guided_roar_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod, | |
num_select, temperature, top_k, top_p, seed=None): | |
### 1 - Conditional forward passes (one for each guided condition) | |
logits_cond_all = [] | |
for cond_dict in cond_dicts: | |
logits_cond_i, _ = self.forward_enc_dec_roar_batched(cond_dict, target_mod, num_select, seed=seed) | |
logits_cond_all.append(logits_cond_i) | |
### 2 - Unconditional forward pass | |
logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(uncond_dict, target_mod, num_select, seed=seed) | |
### 3 Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)} | |
# See https://arxiv.org/abs/2206.01714 | |
logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0) | |
### 4 - Simple sampling | |
samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p) | |
### 5 - Update mod dict | |
# We rely on gather / scatter for batched operations | |
select_pos = mod_pos | |
uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, select_pos, samples) | |
uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool)) | |
uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool)) | |
# Update conditioning dicts | |
for i in range(len(cond_dicts)): | |
cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, select_pos, samples) | |
            # Newly generated tokens become visible inputs and are no longer targets, as in the MaskGIT variant above
            cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
            cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
return uncond_dict, cond_dicts | |
def autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float, | |
use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None, seed=None): | |
# Encoder | |
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) | |
for mod, d in mod_dict.items() | |
if mod in self.model.encoder_embeddings} | |
encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) | |
x = encoder_tokens + encoder_emb | |
x = self.model.forward_encoder(x, encoder_mask) # B, N, D | |
# Get batch size | |
B = x.shape[0] | |
# Decoder | |
context = self.model.decoder_proj_context(x) + encoder_emb | |
decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} | |
decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict, target_mod, seed=seed) | |
device = decoder_ids.device | |
seq_len = self.model.modality_info[target_mod]['max_tokens'] | |
if use_eos and eos_token is None: | |
# The eos_token is the final sentinel token provided | |
eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all | |
if use_eos: | |
eos_token = eos_token.to(device) | |
# If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token) | |
out = decoder_ids[:, :1] if start_tokens is None else start_tokens.to(device) | |
# Set decoder_tokens to None, we do not use them for decoding | |
decoder_ids = None | |
# If all samples of the batch have eos, return early | |
if use_eos and (out == eos_token).any(dim=-1).all(): | |
return out | |
y_emb = decoder_emb[:, :seq_len] | |
seq_len = y_emb.shape[1] | |
# Auto-regressive decoding and sampling | |
for i in range(seq_len): | |
cur_len = out.shape[1] | |
# Convert ids into word embeddings and add corresponding posembs + modemb | |
y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len] | |
# Build causal mask | |
causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1) | |
causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B) | |
y = self.model.forward_decoder(y, context, encoder_mask, causal_mask) | |
logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask[:, :cur_len])[target_mod] | |
logits = rearrange(logits, "(b n) d -> b n d", b=B, n=cur_len) | |
last_logits = logits[:, -1] | |
# Sample token for the newly generated logit | |
if np.isclose(temperature, 0, atol=1e-10): | |
sample = torch.argmax(last_logits, dim=-1, keepdim=True) | |
else: | |
filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p) | |
probs = F.softmax(filtered_logits / temperature, dim=-1) | |
sample = torch.multinomial(probs, 1) | |
out = torch.cat((out, sample), dim=-1) | |
if use_eos and (out == eos_token).any(dim=-1).all(): | |
break | |
mod_dict = self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer) | |
return mod_dict | |
def guided_autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float, | |
use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None, | |
conditioning=[], guidance_scale=1.0, seed=None): | |
### 1 - Encoder forward pass, with conditioning | |
# Encoder | |
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) | |
for mod, d in mod_dict.items() | |
if mod in self.model.encoder_embeddings} | |
encoder_tokens, encoder_emb, encoder_mask_cond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) | |
x = encoder_tokens + encoder_emb | |
x = self.model.forward_encoder(x, encoder_mask_cond) # B, N, D | |
# Get batch size | |
B = x.shape[0] | |
# Decoder | |
context_cond = self.model.decoder_proj_context(x) + encoder_emb | |
decoder_mod_dict_cond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} | |
decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_cond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_cond, target_mod, seed=seed) | |
device = decoder_ids.device | |
seq_len = self.model.modality_info[target_mod]['max_tokens'] | |
### 2 - Encoder forward pass, without conditioning | |
mod_dict_uncond = copy.deepcopy(mod_dict) | |
for mod in conditioning: | |
if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']: | |
mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod) | |
elif self.model.modality_info[mod]['type'] in ['seq_emb']: | |
mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod) | |
else: | |
mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod) | |
# Encoder | |
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d) | |
for mod, d in mod_dict_uncond.items() | |
if mod in self.model.encoder_embeddings} | |
encoder_tokens, encoder_emb, encoder_mask_uncond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict) | |
x = encoder_tokens + encoder_emb | |
x = self.model.forward_encoder(x, encoder_mask_uncond) # B, N, D | |
# Decoder | |
context_uncond = self.model.decoder_proj_context(x) + encoder_emb | |
decoder_mod_dict_uncond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])} | |
decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_uncond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_uncond, target_mod, seed=seed) | |
if use_eos and eos_token is None: | |
# The eos_token is the final sentinel token provided | |
eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all | |
if use_eos: | |
eos_token = eos_token.to(device) | |
# If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token) | |
out = decoder_ids[:, :1] if start_tokens is None else start_tokens.to(device) | |
# Set decoder_tokens to None, we do not use them for decoding | |
decoder_ids = None | |
# If all samples of the batch have eos, return early | |
if use_eos and (out == eos_token).any(dim=-1).all(): | |
return out | |
y_emb = decoder_emb[:, :seq_len] | |
seq_len = y_emb.shape[1] | |
### 3 - Auto-regressive decoding and sampling | |
for i in range(seq_len): | |
cur_len = out.shape[1] | |
# Convert ids into word embeddings and add corresponding posembs + modemb | |
y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len] | |
# Build causal mask | |
causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1) | |
causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B) | |
### 3a - Decoder forward pass, with conditioning | |
y_cond = self.model.forward_decoder(y, context_cond, encoder_mask_cond, causal_mask) | |
logits_cond = self.model.forward_logits(y_cond, decoder_mod_dict_cond, decoder_mod_mask_cond[:, :cur_len])[target_mod] | |
logits_cond = rearrange(logits_cond, "(b n) d -> b n d", b=B, n=cur_len) | |
last_logits_cond = logits_cond[:, -1] | |
### 3b - Decoder forward pass, without conditioning | |
y_uncond = self.model.forward_decoder(y, context_uncond, encoder_mask_uncond, causal_mask) | |
logits_uncond = self.model.forward_logits(y_uncond, decoder_mod_dict_uncond, decoder_mod_mask_uncond[:, :cur_len])[target_mod] | |
logits_uncond = rearrange(logits_uncond, "(b n) d -> b n d", b=B, n=cur_len) | |
last_logits_uncond = logits_uncond[:, -1] | |
### 3c - Classifier-free guidance | |
last_logits = last_logits_uncond + (last_logits_cond - last_logits_uncond) * guidance_scale | |
# Sample token for the newly generated logit | |
if np.isclose(temperature, 0, atol=1e-10): | |
sample = torch.argmax(last_logits, dim=-1, keepdim=True) | |
else: | |
filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p) | |
probs = F.softmax(filtered_logits / temperature, dim=-1) | |
sample = torch.multinomial(probs, 1) | |
out = torch.cat((out, sample), dim=-1) | |
if use_eos and (out == eos_token).any(dim=-1).all(): | |
break | |
mod_dict = self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer) | |
return mod_dict | |
def generate(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None): | |
""" Generates a sequence of tokens from the input modalities. | |
:param mod_dict: Dictionary of modalities. | |
:param schedule: Schedule of modalities to use. | |
List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}. | |
:param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering). | |
:param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering). | |
:param text_tokenizer: Text tokenizer. | |
:param verbose: Whether to print progress. | |
:param seed: Random seed. | |
:return: Generated mod dict. | |
""" | |
# Input embedding -> tokenizes the modalities - Many are placeholder for now | |
mod_dict = copy.deepcopy(mod_dict) | |
for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose): | |
target_mod = schedule_step_info['target_domain'] | |
temp = schedule_step_info['temperature'] | |
cfg_scale = schedule_step_info.get('cfg_scale', 1.0) | |
cfg_conditioning = schedule_step_info.get('cfg_cond_domains', []) | |
seed_i = seed + step if seed is not None else None | |
if self.model.modality_info[target_mod]['type'] == 'img': | |
scheme = schedule_step_info['scheme'] | |
num_select = schedule_step_info['num_tokens'] | |
if scheme.lower() == 'maskgit': | |
if cfg_scale == 1.0 or len(cfg_conditioning) == 0: | |
mod_dict = self.maskgit_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, | |
top_k=top_k, top_p=top_p, seed=seed_i | |
) | |
else: | |
mod_dict = self.guided_maskgit_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, | |
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i | |
) | |
elif scheme.lower() == 'roar': | |
if cfg_scale == 1.0 or len(cfg_conditioning) == 0: | |
mod_dict = self.roar_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, | |
top_k=top_k, top_p=top_p, seed=seed_i | |
) | |
else: | |
mod_dict = self.guided_roar_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, | |
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i | |
) | |
else: | |
raise ValueError("Invalid sampling scheme") | |
elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']: | |
if cfg_scale == 1.0 or len(cfg_conditioning) == 0: | |
mod_dict = self.autoregressive_step_batched( | |
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, | |
text_tokenizer=text_tokenizer, seed=seed_i | |
) | |
else: | |
mod_dict = self.guided_autoregressive_step_batched( | |
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, | |
text_tokenizer=text_tokenizer, conditioning=cfg_conditioning, | |
guidance_scale=cfg_scale, seed=seed_i | |
) | |
else: | |
raise ValueError("Invalid schedule") | |
return mod_dict | |
def generate_iter(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None): | |
""" Iterator that generates a sequence of tokens from the input modalities step by step. | |
:param mod_dict: Dictionary of modalities. | |
:param schedule: Schedule of modalities to use. | |
List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}. | |
:param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering). | |
:param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering). | |
:param text_tokenizer: Text tokenizer. | |
:param verbose: Whether to print progress. | |
:param seed: Random seed. | |
:return: Iterator of generated mod dict. | |
""" | |
# Input embedding -> tokenizes the modalities - Many are placeholder for now | |
mod_dict = copy.deepcopy(mod_dict) | |
for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose): | |
target_mod = schedule_step_info['target_domain'] | |
temp = schedule_step_info['temperature'] | |
cfg_scale = schedule_step_info.get('cfg_scale', 1.0) | |
cfg_conditioning = schedule_step_info.get('cfg_cond_domains', []) | |
seed_i = seed + step if seed is not None else None | |
if self.model.modality_info[target_mod]['type'] == 'img': | |
scheme = schedule_step_info['scheme'] | |
num_select = schedule_step_info['num_tokens'] | |
if scheme.lower() == 'maskgit': | |
if cfg_scale == 1.0 or len(cfg_conditioning) == 0: | |
mod_dict = self.maskgit_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, | |
top_k=top_k, top_p=top_p, seed=seed_i | |
) | |
else: | |
mod_dict = self.guided_maskgit_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, | |
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i, | |
write_all_predictions=True | |
) | |
elif scheme.lower() == 'roar': | |
if cfg_scale == 1.0 or len(cfg_conditioning) == 0: | |
mod_dict = self.roar_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, | |
top_k=top_k, top_p=top_p, seed=seed_i | |
) | |
else: | |
mod_dict = self.guided_roar_step_batched( | |
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p, | |
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i | |
) | |
else: | |
raise ValueError("Invalid sampling scheme") | |
elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']: | |
if cfg_scale == 1.0 or len(cfg_conditioning) == 0: | |
mod_dict = self.autoregressive_step_batched( | |
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, | |
text_tokenizer=text_tokenizer, seed=seed_i | |
) | |
else: | |
mod_dict = self.guided_autoregressive_step_batched( | |
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p, | |
text_tokenizer=text_tokenizer, conditioning=cfg_conditioning, | |
guidance_scale=cfg_scale, seed=seed_i | |
) | |
else: | |
raise ValueError("Invalid schedule") | |
yield mod_dict | |
def generate_multi_guided(self, uncond_dict, cond_dicts, schedule, top_k=0.0, top_p=0.0, | |
text_tokenizer=None, verbose=False, seed=None): | |
# Generation function for multiple weighted conditions | |
# To detect when a modality has finished generating, we keep track of the current target modality | |
cur_target_mod = schedule[0]['target_domain'] | |
uncond_dict = copy.deepcopy(uncond_dict) | |
cond_dicts = copy.deepcopy(cond_dicts) | |
# Add the to-be-generated modality to the conditional dicts | |
for i in range(len(cond_dicts)): | |
cond_dicts[i][cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod]) | |
for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose): | |
target_mod = schedule_step_info['target_domain'] | |
temp = schedule_step_info['temperature'] | |
num_select = schedule_step_info['num_tokens'] | |
cond_weights = schedule_step_info['cfg_scale'] | |
# Once a modality is fully generated, add it as a new condition | |
if cur_target_mod != target_mod: | |
for i in range(len(cond_dicts)): | |
# Remove the previously generated modality from the conditionings | |
del cond_dicts[i][cur_target_mod] | |
# Add the next modality to be generated to the conditionings | |
cond_dicts[i][target_mod] = copy.deepcopy(uncond_dict[target_mod]) | |
# Remove the fully generated modality from the unconditional dict inputs | |
uncond_dict[cur_target_mod]['input_mask'][:] = True | |
# Add the previously generated modality as an additional condition | |
new_cond = {} | |
new_cond[cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod]) | |
new_cond[cur_target_mod]['input_mask'][:] = False | |
new_cond[cur_target_mod]['target_mask'][:] = True | |
new_cond[target_mod] = copy.deepcopy(uncond_dict[target_mod]) | |
cond_dicts.append(new_cond) | |
cur_target_mod = target_mod | |
if self.model.modality_info[target_mod]['type'] == 'img': | |
scheme = schedule_step_info['scheme'] | |
if scheme.lower() == 'maskgit': | |
uncond_dict, cond_dicts = self.multi_guided_maskgit_step_batched( | |
uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed | |
) | |
elif scheme.lower() == 'roar': | |
uncond_dict, cond_dicts = self.multi_guided_roar_step_batched( | |
uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed | |
) | |
else: | |
raise ValueError("Invalid sampling scheme") | |
else: | |
raise NotImplementedError("Only image modalities are supported for now") | |
return uncond_dict | |
def generate_sam_dense(self, mod_dict, schedule, text_tokenizer, batch_size=16, | |
key='sam_instance', top_k=0.0, top_p=0.0, seed=None, verbose=False): | |
# Generation function for dense SAM instance prediction | |
device = mod_dict[list(mod_dict.keys())[0]]['tensor'].device | |
mod_dict = copy.deepcopy(mod_dict) | |
# Repeat the input batch to match the batch size | |
expanded_batch = expand_to_batch(copy.deepcopy(mod_dict), batch_size=batch_size) | |
# Filter the schedule to only include the key domain | |
schedule = [s for s in schedule if s['target_domain'] == key] | |
out_dict = self.generate( | |
expanded_batch, schedule, text_tokenizer=text_tokenizer, | |
verbose=verbose, seed=seed, | |
top_p=top_p, top_k=top_k, | |
) | |
# Merge the batch generated sequences into one sequence | |
sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values()) | |
merged_seq = [] | |
for i in range(batch_size): | |
input_seq = out_dict[key]['tensor'][i] | |
input_seq = input_seq[out_dict[key]['input_mask'][i] == 0] | |
input_seq = input_seq.tolist() | |
target_seq = out_dict[key]['tensor'][i] | |
target_seq = target_seq[out_dict[key]['target_mask'][i] == 0] | |
target_seq = target_seq.tolist() | |
merged_seq.extend(merge_span_masking(input_seq, target_seq, sentinel_ids=sentinel_ids)) | |
merged_seq = torch.tensor(merged_seq, device=device).unsqueeze(0) | |
mod_dict[key] = { | |
'tensor': merged_seq, | |
'input_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device), | |
'target_mask': torch.ones(merged_seq.shape, dtype=torch.bool, device=device), | |
'decoder_attention_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device), | |
} | |
return mod_dict |
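
# Example (sketch): end-to-end generation with a trained 4M model. The model/tokenizer loading,
# modality_info contents, and domain names are illustrative assumptions; only GenerationSampler,
# build_chained_generation_schedules, and the mod_dict helpers above come from this file.
#
#   sampler = GenerationSampler(model).eval().to(device)
#   schedule = build_chained_generation_schedules(
#       cond_domains=['caption'], target_domains=['tok_rgb@224'], tokens_per_target=[196],
#       autoregression_schemes=['maskgit'], decoding_steps=[25], token_decoding_schedules=['cosine'],
#       temps=[3.0], temp_schedules=['linear'], cfg_scales=[3.0], cfg_schedules=['constant'])
#   mod_dict = custom_text({}, 'a photo of a sunny beach', '[EOS]', 'caption', device, text_tokenizer)
#   mod_dict = init_empty_target_modality(mod_dict, modality_info, 'tok_rgb@224', 1, 196, device)
#   out = sampler.generate(mod_dict, schedule, text_tokenizer=text_tokenizer, top_p=0.9, top_k=0.0)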