# Scraped VCS metadata (author / commit message / commit id): chong.zhang / update / e5d1aa9
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from inspiremusic.utils.mask import make_pad_mask
from inspiremusic.utils.hinter import hint_once
class QwenEncoder(nn.Module):
    """Encode token ids with a pretrained causal LM and return its last hidden states.

    The language model is loaded with ``AutoModelForCausalLM.from_pretrained``.
    Its final hidden states are layer-normalised and, when ``do_fusion_emb`` is
    set, fused with the first entry of ``hidden_states`` (named ``emb`` in the
    code) through a linear projection.

    Args:
        input_size: Stored on the instance as ``input_size``; not otherwise
            used by this class.
        pretrain_path: Model name or path handed to ``from_pretrained``.
        trainable: When ``False`` the LM is put in eval mode at construction
            time and its forward pass runs with gradients disabled.
        do_fusion_emb: Enable the fusion branch described above.
        fusion_drop_rate: Dropout probability used inside the fusion branch.
    """

    def __init__(
        self,
        input_size: int,
        pretrain_path: str = "Qwen/Qwen2.0-0.5B",
        trainable: bool = False,
        do_fusion_emb: bool = False,
        fusion_drop_rate: float = 0.0,
    ):
        super().__init__()
        self.input_size = input_size
        self.trainable = trainable
        self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="cpu")
        self._output_size = self.model.config.hidden_size
        self.do_fusion_emb = do_fusion_emb
        self.hidden_norm = torch.nn.LayerNorm(self._output_size)
        self.fusion_dropout = nn.Dropout(fusion_drop_rate)
        if do_fusion_emb:
            # The fusion layer consumes [LM output ; embedding] concatenated on
            # the last axis, hence the doubled input width.
            self.fusion_layer = torch.nn.Linear(self._output_size * 2, self._output_size)
            self.emb_norm = torch.nn.LayerNorm(self._output_size)
            self.fusion_norm = torch.nn.LayerNorm(self._output_size)
            from inspiremusic.transformer.activation import Swish
            # Instantiate without arguments: torch.nn.Module.__init__ rejects
            # positional arguments, so the former ``Swish(self)`` would raise
            # TypeError unless Swish defines its own constructor.
            # NOTE(review): assumes Swish is a parameter-free nn.Module -
            # confirm against inspiremusic.transformer.activation.
            self.fusion_act = Swish()
        if not self.trainable:
            # NOTE(review): eval mode is only set here; a later ``.train()`` on
            # this module or a parent would presumably switch the LM back.
            self.model.eval()

    def output_size(self) -> int:
        """Return the hidden size read from the loaded model's config."""
        return self._output_size

    def forward(
        self,
        input_ids: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Run the LM over ``input_ids`` and return normalised hidden states.

        Args:
            input_ids: Token ids; negative values are clamped to 0 before the
                LM call (presumably negative ids mark padding - TODO confirm).
            ilens: Lengths passed to ``make_pad_mask`` to build the attention
                mask; returned unchanged.

        Returns:
            ``(outs, ilens)`` where ``outs`` is the (optionally fused) layer-
            normalised last hidden state of the LM.
        """
        device = input_ids.device
        input_ids = torch.clamp(input_ids, min=0, max=None)
        input_masks = (~make_pad_mask(ilens)).to(device).long()

        # Frozen LM: always run without gradients. Trainable LM: keep whatever
        # grad mode the caller is already in (do not force gradients on inside
        # an outer ``torch.no_grad()``).
        with torch.set_grad_enabled(self.trainable and torch.is_grad_enabled()):
            model_outputs = self.model(
                input_ids=input_ids,
                attention_mask=input_masks,
                output_hidden_states=True,
            )

        outs = model_outputs.hidden_states[-1]
        outs = self.hidden_norm(outs)
        if self.do_fusion_emb:
            hint_once("fuse embedding and LM outputs", "fuse_emb")
            outs = self.fusion_dropout(self.fusion_act(outs))
            emb = model_outputs.hidden_states[0]
            emb = self.fusion_dropout(self.fusion_act(self.emb_norm(emb)))
            outs = self.fusion_layer(
                torch.cat([outs, emb], dim=-1)
            )
            outs = self.fusion_act(self.fusion_norm(outs))
        return outs, ilens
class QwenEmbeddingEncoder(nn.Module):
    """Wrap ``Qwen2ForCausalLM`` so it can be driven by pre-computed embeddings.

    Both entry points feed ``inputs_embeds`` to the wrapped model and hand back
    the last element of its ``hidden_states``.

    Args:
        input_size: Stored on the instance as ``input_size``; not otherwise
            used by this class.
        pretrain_path: Model name or path handed to ``from_pretrained``.
    """

    def __init__(
        self,
        input_size: int,
        pretrain_path: str = "Qwen/Qwen2.0-0.5B",
    ):
        super().__init__()
        self.input_size = input_size
        from transformers import Qwen2ForCausalLM
        # An ``attn_implementation="flash_attention_2"`` variant of this call
        # was previously kept here as commented-out code.
        self.model = Qwen2ForCausalLM.from_pretrained(
            pretrain_path,
            device_map="cpu",
        )
        self._output_size = self.model.config.hidden_size

    def output_size(self) -> int:
        """Return the hidden size read from the loaded model's config."""
        return self._output_size

    def forward(
        self,
        input_embeds: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Encode a padded batch of embeddings.

        Args:
            input_embeds: Embeddings forwarded to the model as ``inputs_embeds``.
            ilens: Lengths passed to ``make_pad_mask`` to build the attention
                mask.

        Returns:
            ``(last_hidden, input_masks)``: the last entry of the model's
            ``hidden_states`` and the integer attention mask that was used.
        """
        padding = make_pad_mask(ilens)
        input_masks = (~padding).to(input_embeds.device).long()
        lm_out = self.model(
            inputs_embeds=input_embeds,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden = lm_out.hidden_states[-1]
        return last_hidden, input_masks

    def forward_one_step(self, xs, masks, cache=None):
        """Run one cached decoding step.

        Args:
            xs: Embeddings forwarded to the model as ``inputs_embeds``.
            masks: Attention mask forwarded as-is and returned unchanged.
            cache: Value forwarded as ``past_key_values`` (``None`` by default).

        Returns:
            ``(hidden, masks, new_cache)`` where ``hidden`` is the last entry of
            ``hidden_states`` and ``new_cache`` is the model's
            ``past_key_values`` output.
        """
        step_out = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        return step_out.hidden_states[-1], masks, step_out.past_key_values
class QwenInputOnlyEncoder(nn.Module):
    """Expose only the token-embedding table of ``Qwen2ForCausalLM``.

    The full model is loaded during construction, but only
    ``model.model.embed_tokens`` is kept on the instance, with its parameters
    marked as not requiring gradients.

    Args:
        input_size: Stored on the instance as ``input_size``; not otherwise
            used by this class.
        pretrain_path: Model name or path handed to ``from_pretrained``.
    """

    def __init__(
        self,
        input_size: int,
        pretrain_path: str = "Qwen/Qwen2.0-0.5B",
    ):
        super().__init__()
        self.input_size = input_size
        from transformers import Qwen2ForCausalLM
        # An ``attn_implementation="flash_attention_2"`` variant of this call
        # was previously kept here as commented-out code.
        qwen = Qwen2ForCausalLM.from_pretrained(
            pretrain_path,
            device_map="cpu",
        )
        self.embed = qwen.model.embed_tokens
        # Keep the text embedding non-trainable.
        self.embed.requires_grad_(False)
        # ``qwen.model.rotary_emb`` was once considered as a ``post_embed``
        # attribute; it is intentionally not kept.
        self._output_size = qwen.config.hidden_size

    def output_size(self) -> int:
        """Return the hidden size read from the loaded model's config."""
        return self._output_size

    def forward(
        self,
        input_ids: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Look up embeddings for ``input_ids``.

        Args:
            input_ids: Token ids passed to the embedding table.
            ilens: Lengths passed to ``make_pad_mask`` to build the mask.

        Returns:
            ``(embeddings, input_masks)``: the embedding lookup result and the
            integer mask built from ``ilens`` on ``input_ids``' device.
        """
        padding = make_pad_mask(ilens)
        input_masks = (~padding).to(input_ids.device).long()
        return self.embed(input_ids), input_masks