import math

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.cuda.amp import autocast

from uniperceiver.config import configurable
from uniperceiver.modeling.layers import FP16LayerNorm
from uniperceiver.utils import comm

from .build import PREDICTOR_REGISTRY

__all__ = ["EmbedClsAsRetrievalPredictor"]
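
# This predictor casts classification / MLM / captioning as retrieval: query
# features and candidate ("class word") embeddings come from the same encoder,
# and the logits are temperature-scaled cosine similarities, CLIP-style.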


def gelu_accurate(x):
    """Tanh approximation of GELU; the constant sqrt(2/pi) is cached on the
    function object."""
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


def gelu(x: torch.Tensor) -> torch.Tensor:
    # Exact GELU computed in fp32 for numerical stability, cast back to the
    # input dtype.
    return torch.nn.functional.gelu(x.float()).type_as(x)
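

# Quick illustrative check: the tanh approximation tracks the exact GELU
# closely, e.g.
#
#   x = torch.linspace(-4, 4, steps=9)
#   assert torch.allclose(gelu_accurate(x), gelu(x), atol=1e-3)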


@PREDICTOR_REGISTRY.register()
class EmbedClsAsRetrievalPredictor(nn.Module):

    # Per-task temperature names; each is stored as an attribute
    # 'logit_scale_<name>'.
    TEMP_NAMES = [
        'img_cls', 'video_cls', 'text_mlm', 'text_caption', 'caption', 'mlm',
        'retrieve', 'tqa_mlm', 'tqa_caption', 'tqa_retrieve', 'downstream',
    ]

    @configurable
    def __init__(
        self,
        temperature,
        use_norm,
        temp_learn,
        mb_list,
        queue_len,
        feat_dim,
        task2tempname,
        fc_prompt_feature_index,
        output_proj,
        cfg,
    ):
        super().__init__()
        self.cfg = cfg
        self.use_norm = use_norm
        self.temp_learn = temp_learn

        # Each task group gets its own logit scale, stored as log(1 / temperature)
        # so that `.exp()` recovers the multiplicative scale applied to the
        # logits. With `temp_learn` the scales are trainable parameters;
        # otherwise they are fixed CUDA tensors.
        init_value = np.log(1 / temperature)
        for name in self.TEMP_NAMES:
            if temp_learn:
                setattr(self, 'logit_scale_' + name, nn.Parameter(torch.ones([]) * init_value))
            else:
                setattr(self, 'logit_scale_' + name, torch.ones([]).cuda() * init_value)

        self.task2tempname = task2tempname
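
        # Example: with PRED_TEMPERATURE = 0.07 the stored value is
        # log(1 / 0.07) ≈ 2.659, and `.exp()` gives a scale of ≈ 14.3 that
        # multiplies the cosine similarities; the clamp at 100 in `_forward`
        # bounds the learned scale (temperature >= 0.01), mirroring CLIP.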

        # Feature queues: for tasks listed in `mb_list`, keep ring buffers of
        # past normalized features as extra negatives for the contrastive loss.
        self.memory_save = []
        self.queue_len = queue_len
        self.feat_dim = feat_dim
        self.fc_prompt_feature_index = fc_prompt_feature_index
        for task_name in mb_list:
            self.memory_save.append(task_name)
            self.register_buffer('queue_h1_{}'.format(task_name), torch.randn(queue_len, feat_dim))
            self.register_buffer('queue_h2_{}'.format(task_name), torch.randn(queue_len, feat_dim))
            setattr(self, 'queue_h1_{}'.format(task_name), nn.functional.normalize(getattr(self, 'queue_h1_{}'.format(task_name)), dim=1))
            setattr(self, 'queue_h2_{}'.format(task_name), nn.functional.normalize(getattr(self, 'queue_h2_{}'.format(task_name)), dim=1))

            self.register_buffer("queue_ptr1_{}".format(task_name), torch.zeros(1, dtype=torch.long))
            self.register_buffer("queue_ptr2_{}".format(task_name), torch.zeros(1, dtype=torch.long))
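
        # Note: `queue_h1_*` is refilled with target-side features and
        # `queue_h2_*` with input-side features (see `_dequeue_and_enqueue`),
        # so each retrieval direction draws its queue negatives from the
        # opposite side's history, in the spirit of MoCo's memory bank.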

        self.output_proj = output_proj
        if self.output_proj:
            # Post-projection head: LayerNorm (optionally forced to fp16)
            # followed by a linear projection, applied to both sides.
            if self.cfg.SOLVER.FORCE_LN_FP16:
                self.ln_post = FP16LayerNorm(feat_dim)
            else:
                self.ln_post = nn.LayerNorm(feat_dim)
            self.proj = nn.Linear(feat_dim, feat_dim)

        # Gather features across ranks before the contrastive loss: either
        # forced by config (requires the 'turn' dataloader strategy) or when
        # training a single task with FEATURE_GATHER enabled.
        if cfg.MODEL.FEATURE_GATHER_FORCE:
            assert cfg.DATALOADER.STRATEGY == 'turn'
            self.gather_feature = True
        else:
            self.gather_feature = len(cfg.TASKS) == 1 and getattr(cfg.MODEL, "FEATURE_GATHER", False)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, q1, q2, task_name):
        """Replace the oldest entries of this task's feature queues with the
        current batch and advance the ring-buffer pointers."""
        batch_size1 = q1.shape[0]
        batch_size2 = q2.shape[0]

        ptr1 = int(getattr(self, "queue_ptr1_{}".format(task_name)))
        ptr2 = int(getattr(self, "queue_ptr2_{}".format(task_name)))

        # For simplicity the queue length must be a multiple of the batch size,
        # so a batch never wraps around the end of the buffer.
        assert self.queue_len % batch_size1 == 0

        getattr(self, 'queue_h1_{}'.format(task_name))[ptr1:ptr1 + batch_size1, :] = q1
        getattr(self, 'queue_h2_{}'.format(task_name))[ptr2:ptr2 + batch_size2, :] = q2

        ptr1 = (ptr1 + batch_size1) % self.queue_len
        ptr2 = (ptr2 + batch_size2) % self.queue_len

        getattr(self, "queue_ptr1_{}".format(task_name))[0] = ptr1
        getattr(self, "queue_ptr2_{}".format(task_name))[0] = ptr2
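
    # Worked example for the queue above: with queue_len = 1024 and a per-rank
    # batch of 128, the pointer visits 0, 128, ..., 896 and then wraps back to
    # 0, so each buffer always holds the 1024 most recent features.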

    def replace_weight(self, weight):
        # No-op; kept for interface compatibility with other predictors.
        pass

    def replace_module_hidden(self, dense, layer_norm):
        # No-op; kept for interface compatibility with other predictors.
        pass

    @classmethod
    def from_config(cls, cfg):
        # Collect the tasks that request a memory bank, and map every task to
        # the name of the logit scale it should use.
        mb_list = []
        task2tempname = {}
        if len(cfg.TASKS) > 0:
            for task_config in cfg.TASKS:
                if 'MODEL' in task_config and task_config['MODEL'].get('MEMORY_BANK', False):
                    mb_list.append(task_config['NAME'])
                task2tempname[task_config['NAME']] = task_config['MODEL']['TEMP_NAME']

        ret = {
            "temperature": cfg.MODEL.PRED_TEMPERATURE,
            "use_norm": cfg.MODEL.PRED_USE_NORM,
            "temp_learn": getattr(cfg.MODEL, "LEARN_TEMP", False),
            "mb_list": mb_list,
            "queue_len": cfg.MODEL.QUEUE_LEN,
            "feat_dim": cfg.MODEL.ENCODER_DIM,
            "task2tempname": task2tempname,
            "fc_prompt_feature_index": cfg.MODEL.FC_PROMPT_INDEX,
            "output_proj": cfg.MODEL.OUTPUT_PROJ,
            "cfg": cfg,
        }
        print(f'********* using temperature {cfg.MODEL.PRED_TEMPERATURE} **********')

        return ret
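
    # Illustrative config snippet for the fields consumed above (example
    # values, not verified defaults):
    #
    #   MODEL:
    #     PRED_TEMPERATURE: 0.07
    #     PRED_USE_NORM: True
    #     LEARN_TEMP: True
    #     QUEUE_LEN: 1024
    #     ENCODER_DIM: 768
    #     FC_PROMPT_INDEX: 0
    #     OUTPUT_PROJ: True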

    @classmethod
    def add_config(cls, cfg):
        pass

    def test_forward(self, logits):
        # At inference time the logits are returned directly as the output.
        return {"output": logits}

    def postproj(self, hidden_states):
        # LayerNorm followed by a linear projection (built in __init__ when
        # `output_proj` is enabled).
        x = self.ln_post(hidden_states)
        if self.proj is not None:
            x = self.proj(x)
        return x

    def forward(self,
                input_sample_list,
                target_sample_list,
                shared_target_sets,
                target_set_list,
                target_idx_list,
                task_info,
                **kwargs):

        # Target-side features: either provided with the batch, or looked up
        # from the shared target sets (e.g. class-name/vocabulary embeddings).
        if len(target_sample_list) > 0:
            q2_hidden_states = target_sample_list[0]['data']
        else:
            if len(target_set_list) > 1:
                raise NotImplementedError('only one target supported now')
            target_set_name = target_set_list[0]
            q2_hidden_states = shared_target_sets[target_set_name][0]['data']

        q1_hidden_states = input_sample_list[0]['data']

        # Targets are represented by their first ([CLS]-style) token.
        q2_hidden_states = q2_hidden_states[:, 0]

        # Select the query features according to the task type.
        task_type = task_info.get('task_type')

        if task_type in ['image_classification', 'video_classification']:
            q1_hidden_states = q1_hidden_states[:, 0]
        elif task_type in ['image_retrieval', 'video_retrieval']:
            q1_hidden_states = q1_hidden_states[:, 0]
        elif task_type == 'text_mlm':
            # Keep only the masked positions; -1 marks unmasked tokens.
            mask_tokens = target_idx_list[0].ne(-1)
            q1_hidden_states = q1_hidden_states[:, -mask_tokens.size(1):][mask_tokens]
            target_idx_list[0] = target_idx_list[0][mask_tokens]
        elif task_type in ['image_caption', 'video_caption']:
            if self.training:
                sample_info = input_sample_list[0]['sample_info']
                if isinstance(sample_info, list):
                    sample_info = sample_info[0]
                # The text portion of the sequence is twice the caption length;
                # keep the second half, then only the masked positions.
                text_length = sample_info['data_length'][-1] // 2
                q1_hidden_states = q1_hidden_states[:, -text_length:, :]
                mask_tokens = target_idx_list[0].ne(-1)
                q1_hidden_states = q1_hidden_states[mask_tokens]
                target_idx_list[0] = target_idx_list[0][mask_tokens]
            else:
                # Autoregressive decoding: predict from the last position.
                q1_hidden_states = q1_hidden_states[:, -1]
        elif task_type in ['text_classification', 'vqa']:
            sample_info = input_sample_list[0]['sample_info']
            if isinstance(sample_info, list):
                sample_info = sample_info[0]

            sample_infos = sample_info if isinstance(sample_info, list) else sample_info['sample_info_per_sample'][-1]
            if 'spe_index' in sample_infos[0]:
                # Gather the hidden state at each sample's special-token index.
                text_length = sample_info['data_length'][-1]
                q1_hidden_states = q1_hidden_states[:, -text_length:, :]

                spe_index = torch.tensor([si['spe_index'] for si in sample_infos],
                                         device=q1_hidden_states.device).view(-1, 1, 1).expand(-1, -1, q1_hidden_states.size(2))
                q1_hidden_states = torch.gather(q1_hidden_states, 1, spe_index)[:, 0]
            else:
                q1_hidden_states = q1_hidden_states[:, 0]

        else:
            raise NotImplementedError

        if self.output_proj:
            q1_hidden_states = self.postproj(q1_hidden_states)
            q2_hidden_states = self.postproj(q2_hidden_states)

        feat = q1_hidden_states

        if len(target_sample_list) == 0:
            # Classification-as-retrieval against a fixed target set.
            logits = self._forward(q1_hidden_states, q2_hidden_states, task_name=task_info.get("task_name", None))

            ret = {"logits": [logits], "feats": [feat], "loss_names": ['']}
            if len(target_idx_list) > 0:
                ret.update({"targets": [target_idx_list[0]]})

            if not self.training:
                ret_test = self.test_forward(logits)
                ret.update(ret_test)

        else:
            # Paired retrieval between input and target features.
            if not self.training:
                # For evaluation, return the normalized features; similarities
                # are computed over the whole evaluation set elsewhere.
                return {
                    "input_feats": q1_hidden_states / q1_hidden_states.norm(dim=-1, keepdim=True),
                    "tgt_feats": q2_hidden_states / q2_hidden_states.norm(dim=-1, keepdim=True),
                }

            if self.gather_feature:
                local_q1 = q1_hidden_states
                local_q2 = q2_hidden_states
                packed_feature = torch.cat([local_q1, local_q2], dim=1).float()

                gathered_features = [torch.zeros_like(packed_feature) for _ in range(comm.get_world_size())]

                dist.all_gather(gathered_features, packed_feature)

                # Put the local (gradient-carrying) features first, followed by
                # the copies gathered from all other ranks.
                all_features = torch.cat([packed_feature] +
                                         gathered_features[:comm.get_rank()] +
                                         gathered_features[comm.get_rank() + 1:]).to(local_q1)

                q1_hidden_states, q2_hidden_states = torch.split(all_features, [local_q1.size(1), local_q2.size(1)], dim=1)
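
                # Note: `dist.all_gather` does not propagate gradients, so
                # listing `packed_feature` first keeps gradients flowing to the
                # local batch while the gathered copies only serve as extra
                # negatives; this is the standard CLIP-style gather trick.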

            if task_info.get("task_name", None) in self.memory_save:
                # Augment the in-batch negatives with the task's feature queue.
                logits, logits_per_cls = self._forward_with_mb(
                    q1_hidden_states,
                    q2_hidden_states,
                    task_name=task_info.get("task_name", None))
            else:
                logits, logits_per_cls = self._forward(q1_hidden_states,
                                                       q2_hidden_states,
                                                       task_name=task_info.get("task_name", None),
                                                       mutual_retrieval=True)

            # The positives lie on the diagonal: sample i matches target i.
            target = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
            target_per_cls = target

            ret = {
                "logits": [logits, logits_per_cls],
                "targets": [target, target_per_cls],
                "loss_names": ['i2t', 't2i'],
            }

        return ret

    def _forward(self, g, cls_name, task_name, mutual_retrieval=False):
        # `temperature` is the multiplicative logit scale, i.e. 1/T.
        temperature = self.temperature_task(task_name)
        if self.cfg.SOLVER.FORCE_TEMP_FP16:
            temperature = temperature.half()
        if self.temp_learn and temperature > 100.0:
            # Clamp the learned scale at 100 (temperature >= 0.01), as in CLIP.
            temperature = 100.0

        if self.use_norm:
            if not self.cfg.SOLVER.FORCE_NORM_FP16:
                g = g / g.norm(dim=-1, keepdim=True)
                cls_name = cls_name / cls_name.norm(dim=-1, keepdim=True)
            else:
                # Leave autocast so the normalization runs in the tensors' own
                # (possibly fp16) dtype.
                with autocast(enabled=False):
                    g = g / g.norm(dim=-1, keepdim=True)
                    cls_name = cls_name / cls_name.norm(dim=-1, keepdim=True)

        logits = (g @ cls_name.t()) * temperature

        if mutual_retrieval:
            logits_per_cls = logits.transpose(0, 1)
            return logits, logits_per_cls
        return logits
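
    # Shape note for `_forward`: with g of shape (N, D) and cls_name of shape
    # (M, D), `logits` is (N, M); for paired retrieval M == N and row i's
    # positive sits at column i.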

    def _forward_with_mb(self, g, cls_name, task_name):
        temperature = self.temperature_task(task_name)
        if self.temp_learn and temperature > 100.0:
            temperature = 100.0

        if self.use_norm:
            g = g / g.norm(dim=-1, keepdim=True)
            cls_name = cls_name / cls_name.norm(dim=-1, keepdim=True)

        # In-batch logits for both directions.
        logits_per_image = (g @ cls_name.t()) * temperature
        logits_per_cls = logits_per_image.transpose(0, 1)

        # Extra negatives from the (detached) feature queues.
        logits_per_image_neg = (g @ getattr(self, 'queue_h1_{}'.format(task_name)).clone().detach().t()) * temperature
        logits_per_cls_neg = (cls_name @ getattr(self, 'queue_h2_{}'.format(task_name)).clone().detach().t()) * temperature

        self._dequeue_and_enqueue(cls_name, g, task_name)

        return torch.cat([logits_per_image, logits_per_image_neg], dim=-1), torch.cat([logits_per_cls, logits_per_cls_neg], dim=-1)
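
    # The enqueue happens only after the negative logits are computed, so the
    # current batch never appears among its own queue negatives; and because
    # the in-batch logits come first in the concatenation, the diagonal targets
    # built in `forward` remain valid for the widened logits.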

    @property
    def temperature_dict(self):
        # Effective temperature (1 / exp(logit_scale)) per task group, e.g.
        # for logging.
        return {
            'temperature/{}'.format(name): 1 / getattr(self, 'logit_scale_' + name).exp()
            for name in self.TEMP_NAMES
        }

    def temperature_task(self, taskname):
        # Resolve the task's scale name (e.g. 'logit_scale_retrieve') and
        # return the multiplicative logit scale exp(log(1/T)) = 1/T.
        return getattr(self, self.task2tempname[taskname]).exp()
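

# Minimal usage sketch (illustrative; the registry/`@configurable` pattern lets
# the class be built directly from a config object, detectron2-style):
#
#   predictor = PREDICTOR_REGISTRY.get("EmbedClsAsRetrievalPredictor")(cfg)
#   out = predictor(input_sample_list, target_sample_list, shared_target_sets,
#                   target_set_list, target_idx_list, task_info)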