Spaces:
Runtime error
Runtime error
''' | |
* Copyright (c) 2023, salesforce.com, inc. | |
* All rights reserved. | |
* SPDX-License-Identifier: BSD-3-Clause | |
* For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
* By Le Xue | |
''' | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from lavis.models.ulip_models.utils import utils | |
class ULIPWithImageLoss(nn.Module):
    """ULIP contrastive loss aligning point-cloud embeddings with text and image embeddings.

    The loss is the sum of two symmetric cross-entropy terms over scaled
    cosine-similarity logits: point cloud <-> text and point cloud <-> image.
    Local embeddings are compared against embeddings gathered from all
    processes via ``utils.all_gather_batch``.
    """

    def __init__(self):
        super().__init__()
        # Cached target indices for the current local batch; rebuilt lazily in
        # ``forward``. Stored as a plain attribute (not a registered buffer),
        # so ``Module.to()`` does not move it -- ``forward`` re-checks its device.
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        """Compute the point-cloud/text/image contrastive loss and accuracies.

        Args:
            outputs: mapping with keys ``'pc_embed'``, ``'text_embed'``,
                ``'image_embed'`` and ``'logit_scale'``. The embedding tensors
                are presumably (local_batch, dim) with matching batch sizes --
                TODO confirm against the producing model.

        Returns:
            dict with:
                ``'loss'`` / ``'ulip_loss'``: the same scalar loss tensor.
                ``'ulip_pc_image_acc'``: percent of local point clouds whose
                    highest-scoring gathered image is the labelled one.
                ``'ulip_pc_text_acc'``: same, for gathered texts.
        """
        pc_embed = outputs['pc_embed']
        text_embed = outputs['text_embed']
        image_embed = outputs['image_embed']
        logit_scale = outputs['logit_scale']

        local_batch_size = pc_embed.size(0)
        device = pc_embed.device

        # Rebuild the cached labels when the batch size changes OR when the
        # inputs now live on a different device than the cached tensor;
        # otherwise cross_entropy would fail with a device mismatch.
        if (
            self.labels is None
            or local_batch_size != self.last_local_batch_size
            or self.labels.device != device
        ):
            # Local sample i is labelled with global index
            # rank * local_batch_size + i. NOTE(review): this presumes the
            # gather concatenates ranks in order with equal per-rank batch
            # sizes -- confirm against utils.all_gather_batch.
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=device
            )
            self.last_local_batch_size = local_batch_size

        # L2-normalize so the dot products below are cosine similarities.
        pc_embed = F.normalize(pc_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)
        image_embed = F.normalize(image_embed, dim=-1, p=2)

        # Gather features from all processes so each local sample is
        # contrasted against the global batch.
        pc_embed_all, text_embed_all, image_embed_all = utils.all_gather_batch(
            [pc_embed, text_embed, image_embed]
        )

        # Scaled cosine similarities: local queries vs. gathered keys.
        logits_per_pc_text = logit_scale * pc_embed @ text_embed_all.t()
        logits_per_text_pc = logit_scale * text_embed @ pc_embed_all.t()
        logits_per_pc_image = logit_scale * pc_embed @ image_embed_all.t()
        logits_per_image_pc = logit_scale * image_embed @ pc_embed_all.t()

        # Each modality pair contributes the mean of its two directions.
        loss_pc_text = (
            F.cross_entropy(logits_per_pc_text, self.labels)
            + F.cross_entropy(logits_per_text_pc, self.labels)
        ) / 2
        loss_pc_image = (
            F.cross_entropy(logits_per_pc_image, self.labels)
            + F.cross_entropy(logits_per_image_pc, self.labels)
        ) / 2
        loss = loss_pc_text + loss_pc_image

        # Retrieval accuracy (percent) of point cloud -> text and -> image.
        with torch.no_grad():
            pred = torch.argmax(logits_per_pc_text, dim=-1)
            correct = pred.eq(self.labels).sum()
            pc_text_acc = 100 * correct / local_batch_size

            pred = torch.argmax(logits_per_pc_image, dim=-1)
            correct = pred.eq(self.labels).sum()
            pc_image_acc = 100 * correct / local_batch_size

        return {'loss': loss, 'ulip_loss': loss, 'ulip_pc_image_acc': pc_image_acc, 'ulip_pc_text_acc': pc_text_acc}