# Copyright 2024 The YourMT3 Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Please see the details in the LICENSE file. """lm_head.py""" import torch from torch import nn from typing import Optional, Dict class LMHead(nn.Module): """Language Model Head with tied weights.""" def __init__(self, decoder_config: Dict, init_factor: float = 1.0, tie_word_embeddings: bool = True): super().__init__() self.d_model = decoder_config["d_model"] self.init_factor = init_factor self.tie_word_embeddings = tie_word_embeddings self.lm_head = nn.Linear(decoder_config["d_model"], decoder_config["vocab_size"], bias=False) self._init_weights() def _init_weights(self): if self.tie_word_embeddings is False: self.lm_head.weight.data.normal_(mean=0.0, std=self.init_factor * 1.0) def forward(self, decoder_hs: torch.FloatTensor) -> torch.FloatTensor: if self.tie_word_embeddings is True: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 decoder_hs = decoder_hs * (self.d_model**-0.5) lm_logits = self.lm_head(decoder_hs) return lm_logits