# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

from inspiremusic.utils.mask import make_pad_mask
from inspiremusic.utils.hinter import hint_once


class QwenEncoder(nn.Module):
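    """Text encoder that wraps a pretrained Qwen2 causal LM.

    The LM's final hidden states are layer-normalized and, optionally,
    fused with its input token embeddings before being returned.
    """
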
    def __init__(
            self,
            input_size: int,
            pretrain_path: str = "Qwen/Qwen2.0-0.5B",
            trainable: bool = False,
            do_fusion_emb: bool = False,
            fusion_drop_rate: float = 0.0,
    ):
        super(QwenEncoder, self).__init__()
        self.input_size = input_size
        self.trainable = trainable
        self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="cpu")
        self._output_size = self.model.config.hidden_size
        self.do_fusion_emb = do_fusion_emb
        self.hidden_norm = torch.nn.LayerNorm(self._output_size)
        self.fusion_dropout = nn.Dropout(fusion_drop_rate)
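        # Optionally fuse the LM's final hidden states with its input token
        # embeddings (hidden_states[0]) through a learned linear projection.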
        if do_fusion_emb:
            self.fusion_layer = torch.nn.Linear(self._output_size * 2, self._output_size)
            self.emb_norm = torch.nn.LayerNorm(self._output_size)
            self.fusion_norm = torch.nn.LayerNorm(self._output_size)
            from inspiremusic.transformer.activation import Swish
            self.fusion_act = Swish()
        if not self.trainable:
            self.model.eval()

    def output_size(self) -> int:
        return self._output_size

    def forward(
            self,
            input_ids: torch.Tensor,
            ilens: torch.Tensor,
    ):
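        """Encode padded token ids.

        Args:
            input_ids: padded token ids of shape (batch, time).
            ilens: valid lengths of shape (batch,).

        Returns:
            Tuple of encoded features (batch, time, hidden_size) and ilens.
        """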
        device = input_ids.device
        # Padded positions may contain negative ids; clamp to a valid index range.
        input_ids = torch.clamp(input_ids, min=0, max=None)
        input_masks = (~make_pad_mask(ilens)).to(device).long()
        if not self.trainable:
            with torch.no_grad():
                model_outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=input_masks,
                    output_hidden_states=True
                )
        else:
            model_outputs = self.model(
                input_ids=input_ids,
                attention_mask=input_masks,
                output_hidden_states=True
            )
        outs = model_outputs.hidden_states[-1]
        outs = self.hidden_norm(outs)
        if self.do_fusion_emb:
            hint_once("fuse embedding and LM outputs", "fuse_emb")
            outs = self.fusion_dropout(self.fusion_act(outs))
            emb = model_outputs.hidden_states[0]
            emb = self.fusion_dropout(self.fusion_act(self.emb_norm(emb)))
            outs = self.fusion_layer(
                torch.cat([outs, emb], dim=-1)
            )
            outs = self.fusion_act(self.fusion_norm(outs))
        return outs, ilens


class QwenEmbeddingEncoder(nn.Module):
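    """Encoder that feeds pre-computed input embeddings through a Qwen2 LM
    and returns the final hidden states."""
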
    def __init__(
            self,
            input_size: int,
            pretrain_path: str = "Qwen/Qwen2.0-0.5B",
    ):
        super(QwenEmbeddingEncoder, self).__init__()
        self.input_size = input_size
        from transformers import Qwen2ForCausalLM
        # self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2")
        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu")
        self._output_size = self.model.config.hidden_size

    def output_size(self) -> int:
        return self._output_size

    def forward(
            self,
            input_embeds: torch.Tensor,
            ilens: torch.Tensor,
    ):
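        """Encode padded input embeddings.

        Args:
            input_embeds: padded embeddings of shape (batch, time, hidden_size).
            ilens: valid lengths of shape (batch,).

        Returns:
            Tuple of final hidden states (batch, time, hidden_size) and the
            attention mask (batch, time).
        """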
        input_masks = (~make_pad_mask(ilens)).to(input_embeds.device).long()
        outs = self.model(
            inputs_embeds=input_embeds,
            attention_mask=input_masks,
            output_hidden_states=True,
            return_dict=True,
        )
        return outs.hidden_states[-1], input_masks

    def forward_one_step(self, xs, masks, cache=None):
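        """Run one incremental step, reusing and returning the LM KV cache."""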
        outs = self.model(
            inputs_embeds=xs,
            attention_mask=masks,
            output_hidden_states=True,
            return_dict=True,
            use_cache=True,
            past_key_values=cache,
        )
        xs = outs.hidden_states[-1]
        new_cache = outs.past_key_values
        return xs, masks, new_cache


class QwenInputOnlyEncoder(nn.Module):
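    """Lightweight encoder that only keeps the (frozen) Qwen2 token embedding
    table; the transformer layers themselves are not retained."""
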
    def __init__(
            self,
            input_size: int,
            pretrain_path: str = "Qwen/Qwen2.0-0.5B",
    ):
        super(QwenInputOnlyEncoder, self).__init__()
        self.input_size = input_size
        from transformers import Qwen2ForCausalLM
        # model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2")
        model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu")
        # set the text embedding table to non-trainable
        self.embed = model.model.embed_tokens
        for p in self.embed.parameters():
            p.requires_grad = False
        # self.post_embed = model.model.rotary_emb
        self._output_size = model.config.hidden_size

    def output_size(self) -> int:
        return self._output_size

    def forward(
            self,
            input_ids: torch.Tensor,
            ilens: torch.Tensor,
    ):
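        """Look up token embeddings for padded ids.

        Returns:
            Tuple of embeddings (batch, time, hidden_size) and the attention
            mask (batch, time).
        """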
        input_masks = (~make_pad_mask(ilens)).to(input_ids.device).long()
        outs = self.embed(input_ids)
        return outs, input_masks
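

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library):
    # encode a dummy batch with QwenInputOnlyEncoder. The input_size value
    # and the batch shapes below are arbitrary, and the default pretrain_path
    # must point to a Qwen2 checkpoint available locally or on the Hugging
    # Face Hub for this to run.
    encoder = QwenInputOnlyEncoder(input_size=896)
    dummy_ids = torch.randint(0, 1000, (2, 8))
    dummy_lens = torch.tensor([8, 5])
    feats, masks = encoder(dummy_ids, dummy_lens)
    print(feats.shape, masks)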