from typing import Any

import einops
import torch.nn.functional as F
from torch import nn
from transformers.models.bert.configuration_bert import BertConfig

from src.model.blip import create_vit, init_tokenizer, load_checkpoint
from src.model.med import BertModel
from src.tools.utils import print_dist


class BLIPCirTextOnly(nn.Module):
    def __init__(
        self,
        loss: Any,
        med_config="configs/med_config.json",
        image_size=384,
        vit="large",
        vit_grad_ckpt=True,
        vit_ckpt_layer=12,
        embed_dim=256,
        train_vit=False,
    ):
""" | |
Args: | |
med_config (str): path for the mixture of encoder-decoder model's configuration file | |
image_size (int): input image size | |
vit (str): model size of vision transformer | |
""" | |
        super().__init__()

        self.loss = loss

        self.visual_encoder, vision_width = create_vit(
            vit, image_size, vit_grad_ckpt, vit_ckpt_layer
        )
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size
        self.text_proj = nn.Linear(text_width, embed_dim)

        assert not train_vit, "train_vit must be False when using BLIPCirTextOnly"
        self.train_vit = train_vit
        if not self.train_vit:
            # Do not train the visual encoder
            for p in self.visual_encoder.parameters():
                p.requires_grad = False

        # Text-only mode: the text encoder's cross-attention is never used in
        # forward(), so freeze it (along with any position_ids entries).
        for name, param in self.text_encoder.named_parameters():
            if "crossattention" in name:
                param.requires_grad = False
            if "position_ids" in name:
                param.requires_grad = False

        self.temp = 0.07
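
    # Expected batch layout for forward(): (ref_img, tar_feat, caption, _).
    # ref_img is only used to pick the device, tar_feat holds precomputed target
    # image features, caption is the list of modification texts, and the last
    # element is unused here.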
    def forward(self, batch, fabric):
        ref_img, tar_feat, caption, _ = batch
        device = ref_img.device

        # Target image features are precomputed; move them to the right device
        # and L2-normalize them.
        tar_feat = tar_feat.to(device)
        tar_img_feat = F.normalize(tar_feat, dim=-1)

        text = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(device)

        # Encode the modification text (text-only mode, no image cross-attention)
        query_embs = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode="text",
        )
        # Project the first ([CLS]) token embedding and normalize it to get the query feature
        query_feat = query_embs.last_hidden_state[:, 0, :]
        query_feat = F.normalize(self.text_proj(query_feat), dim=-1)

        if fabric.world_size > 1:
            # Gather features from all devices so the loss sees the global batch.
            # d: devices, b: batch size, e: embedding dim
            query_feat = fabric.all_gather(query_feat, sync_grads=True)
            query_feat = einops.rearrange(query_feat, "d b e -> (d b) e")

            tar_img_feat = fabric.all_gather(tar_img_feat, sync_grads=True)
            tar_img_feat = einops.rearrange(tar_img_feat, "d b e -> (d b) e")

        return self.loss(query_feat, tar_img_feat, self.temp)
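

# Illustrative only: a minimal contrastive loss with the call signature this model
# expects, loss(query_feat, tar_img_feat, temp). The actual loss object is passed
# in by the training code; the name and implementation below are assumptions, not
# part of this project.
def example_contrastive_loss(query_feat, tar_img_feat, temp):
    import torch  # local import: the rest of this module does not need torch directly

    # Cosine-similarity logits between every query and every target feature
    # (both inputs are already L2-normalized in forward()).
    logits = query_feat @ tar_img_feat.t() / temp
    # Matching query/target pairs sit on the diagonal.
    labels = torch.arange(query_feat.size(0), device=query_feat.device)
    return F.cross_entropy(logits, labels)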


def blip_cir_text(model, ckpt_path, **kwargs):
    if ckpt_path:
        model, msg = load_checkpoint(model, ckpt_path)
        print_dist("missing keys:")
        print_dist(msg.missing_keys)
    return model
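

# Minimal usage sketch (assumptions: the checkpoint path is a placeholder,
# example_contrastive_loss above is a stand-in for the project's real loss, and
# the default med_config path "configs/med_config.json" exists in the repo):
#
#     model = BLIPCirTextOnly(loss=example_contrastive_loss)
#     model = blip_cir_text(model, ckpt_path="path/to/blip_checkpoint.pth")
#     out = model(batch, fabric)  # batch and fabric are supplied by the training loop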