#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F

from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask
import funasr_detach.models.lora.layers as lora


class CosineDistanceAttention(nn.Module):
    """Compute cosine distance between the speaker decoder output and speaker profiles.

    The speaker profiles are passed directly to ``forward``; this layer only
    applies a softmax over the cosine similarities and holds no learnable
    parameters.
    """

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, spk_decoder_out, profile, profile_lens=None):
        """
        Args:
            spk_decoder_out (torch.Tensor): speaker decoder output, (B, L, D)
            profile (torch.Tensor): speaker profiles, (B, N, D)
            profile_lens (torch.Tensor, optional): number of valid profiles
                per utterance, (B,)

        Returns:
            spk_embedding (torch.Tensor): attended speaker embedding,
                (B, L, D), or (B, 1, D) when ``profile_lens`` is None
            weights (torch.Tensor): attention weights over profiles,
                (B, L, N), or (B, 1, N) when ``profile_lens`` is None
        """
        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
        if profile_lens is not None:
            # Mask out padded profiles so they receive zero attention weight.
            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)  # (B, 1, N)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
            )
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1), dim=-1
            ).masked_fill(mask, min_value)  # (B, L, N)
            weights = self.softmax(weights_not_softmax).masked_fill(
                mask, 0.0
            )  # (B, L, N)
        else:
            # No length information: only the latest decoder step attends to
            # the (unpadded) profiles.
            x = x[:, -1:, :, :]
            weights_not_softmax = F.cosine_similarity(
                x, profile.unsqueeze(1).to(x.device), dim=-1
            )
            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
        # Weighted sum of the profiles gives the attended speaker embedding.
        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
        return spk_embedding, weights
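

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving CosineDistanceAttention with random dummy
# tensors. All sizes and values below are assumptions chosen for illustration.
if __name__ == "__main__":
    batch, dec_len, num_profiles, dim = 2, 5, 4, 192  # hypothetical sizes
    attn = CosineDistanceAttention()

    spk_decoder_out = torch.randn(batch, dec_len, dim)
    profile = torch.randn(batch, num_profiles, dim)
    # Number of valid speaker profiles per utterance; the rest is padding.
    profile_lens = torch.tensor([4, 2])

    # Padded case: padded profiles are excluded from the softmax.
    spk_embedding, weights = attn(spk_decoder_out, profile, profile_lens)
    print(spk_embedding.shape, weights.shape)  # (2, 5, 192), (2, 5, 4)

    # No-lengths case: only the last decoder step is used.
    spk_embedding_last, weights_last = attn(spk_decoder_out, profile)
    print(spk_embedding_last.shape, weights_last.shape)  # (2, 1, 192), (2, 1, 4)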