Spaces:
Running
Running
import torch | |
import torch.distributed.nn | |
from torch import distributed as dist, nn as nn | |
from torch.nn import functional as F | |
import numpy as np | |
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score | |
try: | |
import horovod.torch as hvd | |
except ImportError: | |
hvd = None | |
def gather_features( | |
audio_features, | |
text_features, | |
audio_features_mlp=None, | |
text_features_mlp=None, | |
local_loss=False, | |
gather_with_grad=False, | |
rank=0, | |
world_size=1, | |
use_horovod=False, | |
mlp_loss=False, | |
): | |
if use_horovod: | |
assert hvd is not None, "Please install horovod" | |
if gather_with_grad: | |
all_audio_features = hvd.allgather(audio_features) | |
all_text_features = hvd.allgather(text_features) | |
if mlp_loss: | |
all_audio_features_mlp = hvd.allgather(audio_features_mlp) | |
all_text_features_mlp = hvd.allgather(text_features_mlp) | |
else: | |
with torch.no_grad(): | |
all_audio_features = hvd.allgather(audio_features) | |
all_text_features = hvd.allgather(text_features) | |
if mlp_loss: | |
all_audio_features_mlp = hvd.allgather(audio_features_mlp) | |
all_text_features_mlp = hvd.allgather(text_features_mlp) | |
if not local_loss: | |
# ensure grads for local rank when all_* features don't have a gradient | |
gathered_audio_features = list( | |
all_audio_features.chunk(world_size, dim=0) | |
) | |
gathered_text_features = list( | |
all_text_features.chunk(world_size, dim=0) | |
) | |
gathered_audio_features[rank] = audio_features | |
gathered_text_features[rank] = text_features | |
all_audio_features = torch.cat(gathered_audio_features, dim=0) | |
all_text_features = torch.cat(gathered_text_features, dim=0) | |
if mlp_loss: | |
gathered_audio_features_mlp = list( | |
all_audio_features_mlp.chunk(world_size, dim=0) | |
) | |
gathered_text_features_mlp = list( | |
all_text_features_mlp.chunk(world_size, dim=0) | |
) | |
gathered_audio_features_mlp[rank] = audio_features_mlp | |
gathered_text_features_mlp[rank] = text_features_mlp | |
all_audio_features_mlp = torch.cat( | |
gathered_audio_features_mlp, dim=0 | |
) | |
all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) | |
else: | |
# We gather tensors from all gpus | |
if gather_with_grad: | |
all_audio_features = torch.cat( | |
torch.distributed.nn.all_gather(audio_features), dim=0 | |
) | |
all_text_features = torch.cat( | |
torch.distributed.nn.all_gather(text_features), dim=0 | |
) | |
if mlp_loss: | |
all_audio_features_mlp = torch.cat( | |
torch.distributed.nn.all_gather(audio_features_mlp), dim=0 | |
) | |
all_text_features_mlp = torch.cat( | |
torch.distributed.nn.all_gather(text_features_mlp), dim=0 | |
) | |
else: | |
gathered_audio_features = [ | |
torch.zeros_like(audio_features) for _ in range(world_size) | |
] | |
gathered_text_features = [ | |
torch.zeros_like(text_features) for _ in range(world_size) | |
] | |
dist.all_gather(gathered_audio_features, audio_features) | |
dist.all_gather(gathered_text_features, text_features) | |
if mlp_loss: | |
gathered_audio_features_mlp = [ | |
torch.zeros_like(audio_features_mlp) for _ in range(world_size) | |
] | |
gathered_text_features_mlp = [ | |
torch.zeros_like(text_features_mlp) for _ in range(world_size) | |
] | |
dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) | |
dist.all_gather(gathered_text_features_mlp, text_features_mlp) | |
if not local_loss: | |
# ensure grads for local rank when all_* features don't have a gradient | |
gathered_audio_features[rank] = audio_features | |
gathered_text_features[rank] = text_features | |
if mlp_loss: | |
gathered_audio_features_mlp[rank] = audio_features_mlp | |
gathered_text_features_mlp[rank] = text_features_mlp | |
all_audio_features = torch.cat(gathered_audio_features, dim=0) | |
all_text_features = torch.cat(gathered_text_features, dim=0) | |
if mlp_loss: | |
all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) | |
all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) | |
if mlp_loss: | |
return ( | |
all_audio_features, | |
all_text_features, | |
all_audio_features_mlp, | |
all_text_features_mlp, | |
) | |
else: | |
return all_audio_features, all_text_features | |
class ClipLoss(nn.Module): | |
def __init__( | |
self, | |
local_loss=False, | |
gather_with_grad=False, | |
cache_labels=False, | |
rank=0, | |
world_size=1, | |
use_horovod=False, | |
mlp_loss=False, | |
weight_loss_kappa=0, | |
): | |
super().__init__() | |
self.local_loss = local_loss | |
self.gather_with_grad = gather_with_grad | |
self.cache_labels = cache_labels | |
self.rank = rank | |
self.world_size = world_size | |
self.use_horovod = use_horovod | |
self.mlp_loss = mlp_loss | |
self.weighted_loss = bool(weight_loss_kappa != 0) | |
self.weight_loss_kappa = weight_loss_kappa | |
# cache state | |
self.prev_num_logits = 0 | |
self.labels = {} | |
def forward( | |
self, | |
audio_features, | |
text_features, | |
logit_scale_a, | |
logit_scale_t=None, | |
audio_features_mlp=None, | |
text_features_mlp=None, | |
): | |
device = audio_features.device | |
if self.mlp_loss: | |
if self.world_size > 1: | |
( | |
all_audio_features, | |
all_text_features, | |
all_audio_features_mlp, | |
all_text_features_mlp, | |
) = gather_features( | |
audio_features=audio_features, | |
text_features=text_features, | |
audio_features_mlp=audio_features_mlp, | |
text_features_mlp=text_features_mlp, | |
local_loss=self.local_loss, | |
gather_with_grad=self.gather_with_grad, | |
rank=self.rank, | |
world_size=self.world_size, | |
use_horovod=self.use_horovod, | |
mlp_loss=self.mlp_loss, | |
) | |
if self.local_loss: | |
a_logits_per_audio = ( | |
logit_scale_a * audio_features @ all_text_features_mlp.T | |
) | |
a_logits_per_text = ( | |
logit_scale_a * text_features_mlp @ all_audio_features.T | |
) | |
t_logits_per_audio = ( | |
logit_scale_t * audio_features_mlp @ all_text_features.T | |
) | |
t_logits_per_text = ( | |
logit_scale_t * text_features @ all_audio_features_mlp.T | |
) | |
else: | |
a_logits_per_audio = ( | |
logit_scale_a * all_audio_features @ all_text_features_mlp.T | |
) | |
a_logits_per_text = a_logits_per_audio.T | |
t_logits_per_audio = ( | |
logit_scale_t * all_audio_features_mlp @ all_text_features.T | |
) | |
t_logits_per_text = t_logits_per_audio.T | |
else: | |
a_logits_per_audio = ( | |
logit_scale_a * audio_features @ text_features_mlp.T | |
) | |
a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T | |
t_logits_per_audio = ( | |
logit_scale_t * audio_features_mlp @ text_features.T | |
) | |
t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T | |
# calculated ground-truth and cache if enabled | |
num_logits = a_logits_per_audio.shape[0] | |
if self.prev_num_logits != num_logits or device not in self.labels: | |
labels = torch.arange(num_logits, device=device, dtype=torch.long) | |
if self.world_size > 1 and self.local_loss: | |
labels = labels + num_logits * self.rank | |
if self.cache_labels: | |
self.labels[device] = labels | |
self.prev_num_logits = num_logits | |
else: | |
labels = self.labels[device] | |
if not self.weighted_loss: | |
total_loss = ( | |
F.cross_entropy(a_logits_per_audio, labels) | |
+ F.cross_entropy(a_logits_per_text, labels) | |
+ F.cross_entropy(t_logits_per_audio, labels) | |
+ F.cross_entropy(t_logits_per_text, labels) | |
) / 4 | |
else: | |
audio_weight = (audio_features @ audio_features.T).detach() | |
audio_weight = ( | |
torch.exp( | |
torch.sum(audio_weight, axis=1) | |
/ (self.weight_loss_kappa * len(audio_weight)) | |
) | |
).detach() | |
text_weight = (text_features @ text_features.T).detach() | |
text_weight = ( | |
torch.exp( | |
torch.sum(text_weight, axis=1) | |
/ (self.weight_loss_kappa * len(text_features)) | |
) | |
).detach() | |
total_loss = ( | |
F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) | |
+ F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) | |
+ F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) | |
+ F.cross_entropy(t_logits_per_text, labels, weight=text_weight) | |
) / 4 | |
else: | |
if self.world_size > 1: | |
all_audio_features, all_text_features = gather_features( | |
audio_features=audio_features, | |
text_features=text_features, | |
local_loss=self.local_loss, | |
gather_with_grad=self.gather_with_grad, | |
rank=self.rank, | |
world_size=self.world_size, | |
use_horovod=self.use_horovod, | |
mlp_loss=self.mlp_loss, | |
) | |
if self.local_loss: | |
logits_per_audio = ( | |
logit_scale_a * audio_features @ all_text_features.T | |
) | |
logits_per_text = ( | |
logit_scale_a * text_features @ all_audio_features.T | |
) | |
else: | |
logits_per_audio = ( | |
logit_scale_a * all_audio_features @ all_text_features.T | |
) | |
logits_per_text = logits_per_audio.T | |
else: | |
logits_per_audio = logit_scale_a * audio_features @ text_features.T | |
logits_per_text = logit_scale_a * text_features @ audio_features.T | |
# calculated ground-truth and cache if enabled | |
num_logits = logits_per_audio.shape[0] | |
if self.prev_num_logits != num_logits or device not in self.labels: | |
labels = torch.arange(num_logits, device=device, dtype=torch.long) | |
if self.world_size > 1 and self.local_loss: | |
labels = labels + num_logits * self.rank | |
if self.cache_labels: | |
self.labels[device] = labels | |
self.prev_num_logits = num_logits | |
else: | |
labels = self.labels[device] | |
if not self.weighted_loss: | |
total_loss = ( | |
F.cross_entropy(logits_per_audio, labels) | |
+ F.cross_entropy(logits_per_text, labels) | |
) / 2 | |
else: | |
audio_weight = (all_audio_features @ all_audio_features.T).detach() | |
audio_weight = ( | |
torch.exp( | |
torch.sum(audio_weight, axis=1) | |
/ (self.weight_loss_kappa * len(all_audio_features)) | |
) | |
).detach() | |
text_weight = (all_text_features @ all_text_features.T).detach() | |
text_weight = ( | |
torch.exp( | |
torch.sum(text_weight, axis=1) | |
/ (self.weight_loss_kappa * len(all_text_features)) | |
) | |
).detach() | |
total_loss = ( | |
F.cross_entropy(logits_per_audio, labels, weight=text_weight) | |
+ F.cross_entropy(logits_per_text, labels, weight=audio_weight) | |
) / 2 | |
return total_loss | |
def lp_gather_features(pred, target, world_size=1, use_horovod=False): | |
if use_horovod: | |
assert hvd is not None, "Please install horovod" | |
with torch.no_grad(): | |
all_preds = hvd.allgather(pred) | |
all_targets = hvd.allgath(target) | |
else: | |
gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)] | |
gathered_targets = [torch.zeros_like(target) for _ in range(world_size)] | |
dist.all_gather(gathered_preds, pred) | |
dist.all_gather(gathered_targets, target) | |
all_preds = torch.cat(gathered_preds, dim=0) | |
all_targets = torch.cat(gathered_targets, dim=0) | |
return all_preds, all_targets | |
def get_map(pred, target): | |
pred = torch.sigmoid(pred).numpy() | |
target = target.numpy() | |
return np.mean(average_precision_score(target, pred, average=None)) | |
def get_acc(pred, target): | |
pred = torch.argmax(pred, 1).numpy() | |
target = torch.argmax(target, 1).numpy() | |
return accuracy_score(target, pred) | |
def get_mauc(pred, target): | |
pred = torch.sigmoid(pred).numpy() | |
target = target.numpy() | |
return np.mean(roc_auc_score(target, pred, average=None)) | |
class LPMetrics(object): | |
def __init__(self, metric_names=["map", "acc", "mauc"]): | |
self.metrics = [] | |
for name in metric_names: | |
self.metrics.append(self.get_metric(name)) | |
self.metric_names = metric_names | |
def get_metric(self, name): | |
if name == "map": | |
return get_map | |
elif name == "acc": | |
return get_acc | |
elif name == "mauc": | |
return get_mauc | |
else: | |
raise ValueError(f"the metric should be at least one of [map, acc, mauc]") | |
def evaluate_mertics(self, pred, target): | |
metric_dict = {} | |
for i in range(len(self.metric_names)): | |
metric_dict[self.metric_names[i]] = self.metrics[i](pred, target) | |
return metric_dict | |
def calc_celoss(pred, target): | |
target = torch.argmax(target, 1).long() | |
return nn.CrossEntropyLoss()(pred, target) | |
class LPLoss(nn.Module): | |
def __init__(self, loss_name): | |
super().__init__() | |
if loss_name == "bce": | |
self.loss_func = nn.BCEWithLogitsLoss() | |
elif loss_name == "ce": | |
self.loss_func = calc_celoss | |
elif loss_name == "mse": | |
self.loss_func = nn.MSELoss() | |
else: | |
raise ValueError(f"the loss func should be at least one of [bce, ce, mse]") | |
def forward(self, pred, target): | |
loss = self.loss_func(pred, target) | |
return loss | |