# OSUM/wenet/llm_asr/llmasr_model.py
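# LLMASR_Model couples a speech encoder with a causal LLM:
#   speech encoder -> LayerNorm -> adapter ('lyz' Conv1d subsampling, or 'gxl'
#   downsampler + small Transformer + linear projection) -> LLM (optionally
#   LoRA-adapted via peft). When speech_token_num > 0, an extra embedding table
#   and a `speaker_head` linear layer extend the LLM with discrete speech tokens.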
import logging
import os
import torchaudio
import torch
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer
from wenet.transformer.encoder import TransformerEncoder
from wenet.llm_asr.utils4llmasr import *
from gxl_ai_utils.utils import utils_file
from wenet.llm_asr.downsampler import get_downsampler, LyzConv1dSubsampling
from wenet.utils.mask import make_pad_mask
# import torch_npu
# from torch_npu.contrib import transfer_to_npu
# from msprobe.pytorch import seed_all,PrecisionDebugger
class LLMASR_Model(nn.Module):
def __init__(self,
encoder,
encoder_output_dim,
llm_path,
lora=True, lora_alpha=32, lora_rank=8, lora_dropout=0.1,
prompt_pattern="{}:<Speech><SpeechHere></Speech>",
# "USER: <Speech><SpeechHere></Speech> {}\nASSISTANT:"
is_inference=False,
downsample_rate=1,
llm_embed_dim=4096,
task_num=2,
adapter_type='lyz',
speech_token_num=0,
train_speech_out=False):
""""""
super().__init__()
self.downsample_rate = downsample_rate
self.encoder = encoder
self.ln_speech = nn.LayerNorm(encoder_output_dim)
        # adapter / connector layers (~51.6M params)
if adapter_type == 'gxl':
self.speech_transformer = TransformerEncoder(
input_size=encoder_output_dim,
output_size=encoder_output_dim,
attention_heads=4,
linear_units=2560,
num_blocks=4,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="linear",
pos_enc_layer_type="abs_pos",
normalize_before=True
)
else:
self.speech_transformer = None
        # LLM backbone
self.low_resource = False
if not self.low_resource:
self.llama_model = AutoModelForCausalLM.from_pretrained(
llm_path,
# torch_dtype=torch.float32 if is_inference else torch.float16,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
output_hidden_states=True,
)
else:
self.llama_model = AutoModelForCausalLM.from_pretrained(
llm_path,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto",
trust_remote_code=True,
output_hidden_states=True,
)
self.max_length = 300
self.min_length = 1
self.num_beams = 4
self.do_sample = True
self.top_p = 0.0
self.top_k = 0
self.repetition_penalty = 1.05
self.length_penalty = 1.0
self.temperature = 1.0
self.IGNORE_ID = -100
# lora
self.lora = lora
if lora:
utils_file.logging_limit_print("耿雪龙: 使用lora了")
#target_modules = ['w_pack', 'o_proj', 'gate_proj', 'down_proj']
target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj']
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=is_inference,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=target_modules,
            )
self.llama_model = get_peft_model(self.llama_model, self.peft_config)
# tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
llm_path, use_fast=False, trust_remote_code=True)
"""
设置分词器的pad_token和padding的方向。
"""
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.tokenizer.padding_side = "right"
if hasattr(self.llama_model.config, 'hidden_size'):
utils_file.logging_limit_print(
f"self.llama_model.config.hidden_size: {self.llama_model.config.hidden_size}")
if adapter_type == 'lyz':
self.down_sample_2 = LyzConv1dSubsampling(encoder_output_dim, self.llama_model.config.hidden_size)
elif adapter_type == 'gxl':
self.down_sample_2 = get_downsampler(downsample_rate, encoder_output_dim)
self.speech_llama_proj = nn.Linear(
encoder_output_dim, self.llama_model.config.hidden_size)
# self.task_embeddings = torch.nn.Embedding(task_num, self.llama_model.config.hidden_size)
else:
            raise NotImplementedError("self.llama_model.config has no attribute 'hidden_size'")
self.embed_tokens = self.llama_model.model.model.embed_tokens if self.lora else self.llama_model.model.embed_tokens
self.lm_head = self.llama_model.model.lm_head if self.lora else self.llama_model.lm_head
self.speech_token_num = speech_token_num
# init speech token module
        if speech_token_num > 0:
            utils_file.logging_info(f'Geng Xuelong: speech-token generation enabled, speech_token_num: {speech_token_num}')
            # +2 rows reserve embeddings for the speech start/end markers.
            self.speech_token_emded = torch.nn.Embedding(speech_token_num + 2, self.llama_model.config.hidden_size)
            self.speaker_head = torch.nn.Linear(self.llama_model.config.hidden_size, speech_token_num)
else:
            # no-op placeholders when speech-token output is disabled
self.speaker_head = nn.Identity()
self.speech_token_emded = nn.Identity()
self.train_speech_out = train_speech_out
        utils_file.logging_info(f'Geng Xuelong: train speech output: {self.train_speech_out}')
self.loss_fct = CrossEntropyLoss(reduction='mean')
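        # Combined output vocabulary used throughout this file:
        #   lm_head(h)       -> text logits,   ids [0, 152064)
        #   speaker_head(h)  -> speech logits, ids [152064, 152064 + speech_token_num)
        # 152064 is the hard-coded base (Qwen) text vocabulary size, so
        # speech-token ids are offset by 152064 wherever the vocabularies mix.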
# self.debugger = PrecisionDebugger(config_path='./do_align_test/config_gpu.json', model=self.encoder)
def get_label_embedding(self, labels, labels_lengths):
""""""
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
labels_embeds = self.embed_tokens(labels)
labels_target = labels.masked_fill(labels_pad_mask, self.IGNORE_ID) # B, L
labels_mask = ~labels_pad_mask
return labels_embeds, labels_target, labels_mask
def get_speech_token_label_embedding(self, speech_token_labels, speech_tokens_length):
""""""
speech_tokens_pad_mask = make_pad_mask(speech_tokens_length) # B, L
speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
speech_token_labels_embeds = self.speech_token_emded(speech_token_labels)
        utils_file.logging_limit_print('offsetting speech_token_labels, before:',
                                       speech_token_labels.shape, speech_token_labels[0][-1], speech_token_labels[0][0])
        # Offset speech-token ids past the text vocabulary (152064 = base LLM vocab size).
        speech_token_labels = speech_token_labels + 152064
        utils_file.logging_limit_print('offsetting speech_token_labels, after:',
                                       speech_token_labels.shape, speech_token_labels[0][-1], speech_token_labels[0][0])
speech_token_labels_target = speech_token_labels.masked_fill(speech_tokens_pad_mask, self.IGNORE_ID) # B, L
speech_token_labels_mask = ~speech_tokens_pad_mask
return speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask
def forward(self,
batch,
device,
):
""""""
rank = int(os.environ.get('RANK', 0))
output_type = batch['output_type']
assert output_type in ['text', 'speech2text_token', 'text2token'], f"output_type:{output_type} not support"
# speech inputs
        if output_type in ('text', 'speech2text_token'):
wavs = batch['feats'].to(device)
wavs_len = batch['feats_lengths'].to(device)
speech_embeds, speech_masks = self.get_embedding_from_wav(wavs, wavs_len)
speech_target = torch.full(speech_masks.shape, self.IGNORE_ID).to(
speech_embeds.device)
            utils_file.logging_limit_print('entering llmasr forward(), inspecting inputs')
            utils_file.logging_limit_print('wavs.shape:', wavs.shape)
            utils_file.logging_limit_print('wavs_len.shape:', wavs_len.shape)
            utils_file.logging_limit_print('wavs_len:', wavs_len)
            utils_file.logging_limit_print('output_type:', output_type)
            utils_file.logging_limit_print('speech_embeds:', speech_embeds.shape)
            utils_file.logging_limit_print('done inspecting inputs')
else:
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
            # text2token: use the text token sequence itself in place of speech input
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_embeds = self.embed_tokens(labels) # B, L, D
speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
speech_embeds.device)
speech_masks = ~labels_pad_mask
# add bos and eos
speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
1 + self.speech_token_num,
speech_embeds, speech_masks, speech_target)
# prompt
if 'prompt' in batch:
prompt = batch['prompt'].to(device)
prompt_lengths = batch['prompt_lengths'].to(device)
prompt_pad_mask = make_pad_mask(prompt_lengths) # B, L
prompt = prompt.masked_fill(prompt_pad_mask, self.tokenizer.eos_token_id)
prompt_embeds = self.embed_tokens(prompt) # B, L, D
prompt_target = torch.full(prompt.shape, self.IGNORE_ID).to(
speech_embeds.device) # B, L
prompt_mask = ~prompt_pad_mask
else:
raise ValueError('prompt is not in batch')
if output_type == 'speech2text_token':
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
speech_token_labels = batch['speech_tokens'].to(device)
speech_tokens_length = batch['speech_tokens_length'].to(device)
            utils_file.logging_limit_print('entering llmasr forward(), inspecting targets')
            utils_file.logging_limit_print('labels.shape:', labels.shape)
            utils_file.logging_limit_print('labels_lengths.shape:', labels_lengths.shape)
            utils_file.logging_limit_print('labels_lengths:', labels_lengths)
            utils_file.logging_limit_print('speech_token_labels.shape:', speech_token_labels.shape)
            utils_file.logging_limit_print('speech_tokens_length.shape:', speech_tokens_length.shape)
            utils_file.logging_limit_print('speech_tokens_length:', speech_tokens_length)
            utils_file.logging_limit_print('done inspecting targets')
labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding(
speech_token_labels, speech_tokens_length)
# concat
inputs_embeds = torch.cat([prompt_embeds, speech_embeds,
labels_embeds, speech_token_labels_embeds], dim=1)
attention_mask = torch.cat([prompt_mask, speech_masks,
labels_mask, speech_token_labels_mask], dim=1)
target = torch.cat([prompt_target, speech_target,
labels_target, speech_token_labels_target], dim=1)
elif output_type == "text2token":
speech_token_labels = batch['speech_tokens'].to(device)
speech_tokens_length = batch['speech_tokens_length'].to(device)
speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding(
speech_token_labels, speech_tokens_length)
inputs_embeds = torch.cat([prompt_embeds, speech_embeds,
speech_token_labels_embeds], dim=1)
attention_mask = torch.cat([prompt_mask, speech_masks,
speech_token_labels_mask], dim=1)
target = torch.cat([prompt_target, speech_target,
speech_token_labels_target], dim=1)
elif output_type == "text":
labels = batch['target'].to(device)
labels_lengths = batch['target_lengths'].to(device)
labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
# concat
inputs_embeds = torch.cat([prompt_embeds, speech_embeds,
labels_embeds], dim=1)
attention_mask = torch.cat([prompt_mask, speech_masks,
labels_mask], dim=1)
target = torch.cat([prompt_target, speech_target,
labels_target], dim=1)
else:
raise NotImplementedError(f'output_type {output_type} not support')
        utils_file.logging_limit_print(f'Geng Xuelong output_type: {output_type}')
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
# labels=target,
attention_mask=attention_mask,
position_ids=position_ids.to(inputs_embeds.device)
)
hidden_states = outputs['hidden_states'][-1]
logits = self.lm_head(hidden_states)
        logits2 = self.speaker_head(hidden_states)  # speech-token head
        # Concatenate text and speech-token logits into one combined vocabulary,
        # then apply the standard next-token shift for the causal LM loss.
        combined_logits = torch.cat([logits, logits2], dim=-1)
        shift_logits = combined_logits[..., :-1, :].contiguous()
        shift_target = target[..., 1:].contiguous()
        shift_logits = shift_logits.view(-1, combined_logits.shape[-1])  # width = text vocab + speech_token_num
shift_target = shift_target.view(-1)
shift_target = shift_target.to(shift_logits.device)
loss = self.loss_fct(shift_logits, shift_target)
loss.requires_grad_(True)
return {"loss": loss}
def generate(
self,
wavs,
wavs_len,
prompt,
):
speech_embeds, speech_masks = self.get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
speech_embeds, speech_masks, None)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            utils_file.logging_limit_print('generate(): embed_tokens weights are half precision; casting inputs to bfloat16')
            embeds = embeds.to(torch.bfloat16)
            atts = atts.to(torch.bfloat16)
outputs = self.llama_model.generate(
inputs_embeds=embeds,
max_new_tokens=self.max_length,
num_beams=self.num_beams,
do_sample=self.do_sample,
min_length=self.min_length,
top_p=self.top_p,
top_k=self.top_k,
repetition_penalty=self.repetition_penalty,
length_penalty=self.length_penalty,
temperature=self.temperature,
attention_mask=atts,
            eos_token_id=151643,  # Qwen '<|endoftext|>'
            pad_token_id=-100,
)
output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)
return output_text
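    # Usage sketch (hypothetical inputs; wavs/wavs_len are whatever self.encoder
    # expects, batched with one utterance here):
    #   text = model.generate(wavs, wavs_len, "Transcribe the speech.")[0]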
def generate4seech_token(
self,
wavs,
wavs_len,
prompt,
):
speech_embeds, speech_masks = self.get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
speech_embeds, speech_masks, None)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            utils_file.logging_limit_print('generate4seech_token(): embed_tokens weights are float16; casting inputs')
embeds = embeds.to(torch.float16)
atts = atts.half()
outputs = self.llama_model.generate(
inputs_embeds=embeds,
max_new_tokens=self.max_length,
num_beams=self.num_beams,
do_sample=self.do_sample,
min_length=self.min_length,
top_p=self.top_p,
top_k=self.top_k,
repetition_penalty=self.repetition_penalty,
length_penalty=self.length_penalty,
temperature=self.temperature,
attention_mask=atts,
            eos_token_id=151643,  # Qwen '<|endoftext|>'
pad_token_id=-100,
)
output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)
return output_text
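    # Speech feature pipeline below: encoder -> down_sample_2 -> (only when the
    # 'gxl' adapter built a speech_transformer) speech_transformer -> speech_llama_proj.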
def get_embedding_from_wav(self, wavs, wavs_len):
"""
return:
wav_embedding: (b, l, v)
wav_mask: (b, l), wav为有效值的位置为true
"""
rank = int(os.environ.get('RANK', 0))
# self.debugger.start()
encoder_out, encoder_mask = self.encoder(wavs, wavs_len)
# self.debugger.stop()
# self.debugger.step()
        if rank == 0:
            utils_file.logging_limit_print(
                f'encoder out shape: {encoder_out.shape}, first 20 values of the first frame:\n{encoder_out[0][0][:20]}')
speech_embeds, encoder_mask = self.down_sample_2(encoder_out, encoder_mask)
        if rank == 0:
            utils_file.logging_limit_print(
                f'out of down_sample_2 shape: {speech_embeds.shape}, first 20 values of the first frame:\n{speech_embeds[0][0][:20]}')
if self.speech_transformer is not None:
filled_wavs_len = encoder_mask.squeeze(1).sum(-1)
speech_embeds, encoder_mask = self.speech_transformer(speech_embeds, filled_wavs_len)
            if rank == 0:
                utils_file.logging_limit_print(
                    f'out of link shape: {speech_embeds.shape}, first 20 values of the first frame:\n {speech_embeds[0][0][:20]}')
speech_embeds = self.speech_llama_proj(speech_embeds)
            if rank == 0:
                utils_file.logging_limit_print(
                    f'out of speech_llama_proj shape: {speech_embeds.shape}, first 20 values of the first frame:\n {speech_embeds[0][0][:20]}')
return speech_embeds, encoder_mask.squeeze(1)
def get_embedding_from_text(self, text):
text_id = self.tokenizer(
text,
return_tensors="pt",
add_special_tokens=False
).to(
self.embed_tokens.weight.device).input_ids
text_embeds = self.embed_tokens(text_id)
return text_embeds
def get_embeds_from_wav_path(self, wav_path):
wav_i2_path = wav_path
utils_file.logging_limit_print('get_embeds_from_wav_path(): wav_i2_path:', wav_i2_path)
waveform_i2, _ = torchaudio.load(wav_i2_path)
utils_file.logging_limit_print('get_embeds_from_wav_path(): waveform_i2.shape:', waveform_i2.shape)
        if waveform_i2.dim() != 1:
            waveform_i2 = waveform_i2[0]  # keep the first channel
waveform_i2 = waveform_i2.to(self.embed_tokens.weight.device)
wavs_len_i2 = torch.tensor([len(waveform_i2)], device=self.embed_tokens.weight.device, dtype=torch.int32)
wavs_i2 = waveform_i2.unsqueeze(0)
        sample_i2_embeds, _ = self.get_embedding_from_wav(wavs_i2, wavs_len_i2)
        utils_file.logging_limit_print('get_embeds_from_wav_path(): sample_i2_embeds.shape:', sample_i2_embeds.shape)
return sample_i2_embeds
def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
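        # `bos`/`eos` index the speech-token embedding table; callers pass
        # speech_token_num and speech_token_num + 1, the two extra rows
        # reserved when the table was sized speech_token_num + 2 in __init__.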
B = len(inputs_embeds)
bos_eos_target = torch.full([B, 1], self.IGNORE_ID).to(inputs_embeds.device) # B,1
bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device) # B, 1
if bos is not None:
bos_embed = self.speech_token_emded(torch.full([B, 1],
bos).to(inputs_embeds.device)) # B, 1, D
inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1) # B, (1+T), D
attention_mask = torch.cat((bos_eos_mask, attention_mask), 1) # B, (1+T)
if target is not None:
target = torch.cat((bos_eos_target, target), 1) # B, (1+T), D
if eos is not None:
eos_embed = self.speech_token_emded(torch.full([B, 1],
eos).to(inputs_embeds.device)) # B, 1, D
inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1) # B, (1+T+1), D
attention_mask = torch.cat((attention_mask, bos_eos_mask), 1) # B, (1+T+1)
if target is not None:
target = torch.cat((target, bos_eos_target), 1) # B, (1+T+1), D
return inputs_embeds, attention_mask, target
def infer_for_speech2text_token( # speech2text-token
self,
wavs,
wavs_len,
prompt,
text=None,
):
if text is not None:
prompt = torch.cat((prompt, text), dim=1)
speech_embeds, speech_masks = self.get_embedding_from_wav(wavs, wavs_len)
speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, None,
speech_embeds, speech_masks, None)
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            utils_file.logging_limit_print('infer_for_speech2text_token(): embed_tokens weights are float16; casting inputs')
embeds = embeds.to(torch.float16)
atts = atts.half()
device = wavs.device
max_len = 300
hyps = torch.ones([1, 1], dtype=torch.int64,
device=device).fill_(1 + self.speech_token_num) # (B*N, 1)
llm_out = self.llama_model(
inputs_embeds=embeds,
past_key_values=None,
output_hidden_states=True
)
cache = llm_out.past_key_values
        utils_file.logging_limit_print('got the initial KV cache; starting the autoregressive decoding loop')
token_emb = self.speech_token_emded(hyps[:, -1:])
for i in range(max_len):
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]  # last layer
token_logits_1 = self.lm_head(hidden_states)
# utils_file.logging_limit_print(f'token_logits_1.shape:{token_logits_1.shape}')
token_logits_2 = self.speaker_head(hidden_states)
# utils_file.logging_limit_print(f'token_logits_2.shape:{token_logits_2.shape}')
big_logits = torch.cat([token_logits_1, token_logits_2], dim=-1)
# utils_file.logging_limit_print(f'big_logits.shape:{big_logits.shape}')
            logp = torch.nn.functional.log_softmax(big_logits[:, -1, :], dim=-1)  # last position only
# utils_file.logging_limit_print(f'logp.shape:{logp.shape}')
max_index = torch.argmax(logp, dim=-1, keepdim=True)
# utils_file.logging_limit_print(f'max_index.shape:{max_index.shape}')
utils_file.logging_limit_print(f'max_index:{max_index}')
hyps = torch.cat((hyps, max_index),
dim=1) # (B*N, i+1)
            if max_index < 152064:
                token_emb = self.embed_tokens(hyps[:, -1:])  # text token: use the LLM table
            else:
                if max_index == 152064 + 4096:  # hard-coded speech-token end marker
                    utils_file.logging_limit_print('Geng Xuelong: hit the speech-token end marker, stopping')
                    break
                # de-offset back into the speech-token embedding table
                token_emb = self.speech_token_emded(hyps[:, -1:] - 152064)
best_hyps = hyps[0, :]
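        # Split the hypothesis into text ids vs. offset speech-token ids and
        # return "decoded text | speech-token ids" as one string.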
text_res = []
token_res = []
for i in best_hyps[1:]:
if i < 152064:
text_res.append(i)
else:
token_res.append(str((i - 152064).item()))
str_i = self.tokenizer.decode(text_res, skip_special_tokens=True, add_special_tokens=False)
return [str_i + " | " + " ".join(token_res)]
def infer_for_text2token( # text2token
self,
wavs,
wavs_len,
prompt,
text=None,
):
if text is not None:
prompt = torch.cat((prompt, text), dim=1)
        # text is (1, L) token ids; use all but the last token as input labels
        labels_lengths = torch.tensor([text.size(1) - 1], dtype=torch.int64)
        labels = text[:, :-1]
labels_pad_mask = make_pad_mask(labels_lengths) # B, L
labels = labels.masked_fill(labels_pad_mask, 0)
speech_embeds = self.embed_tokens(labels) # B, L, D
speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
speech_embeds.device)
speech_masks = ~labels_pad_mask
prompt = self.tokenizer([prompt], return_tensors="pt"
)['input_ids'].to(speech_embeds.device)
prompt_embeds = self.embed_tokens(prompt)
embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            utils_file.logging_limit_print('infer_for_text2token(): embed_tokens weights are float16; casting inputs')
embeds = embeds.to(torch.float16)
atts = atts.half()
device = wavs.device
max_len = 300
        # Seed hypothesis; assumed to start from the same speech start marker as
        # infer_for_speech2text_token (the original called fill_() with no value).
        hyps = torch.ones([1, 1], dtype=torch.int64,
                          device=device).fill_(1 + self.speech_token_num)  # (B*N, 1)
llm_out = self.llama_model(
inputs_embeds=embeds,
past_key_values=None,
output_hidden_states=True
)
cache = llm_out.past_key_values
        utils_file.logging_limit_print('got the initial KV cache; starting the autoregressive decoding loop')
        # Assumed: embed the seed via the speech-token table, mirroring
        # infer_for_speech2text_token.
        token_emb = self.speech_token_emded(hyps[:, -1:])
for i in range(max_len):
llm_out = self.llama_model(
inputs_embeds=token_emb,
past_key_values=cache,
output_hidden_states=True
)
cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]  # last layer
token_logits_1 = self.lm_head(hidden_states)
# utils_file.logging_limit_print(f'token_logits_1.shape:{token_logits_1.shape}')
token_logits_2 = self.speaker_head(hidden_states)
# utils_file.logging_limit_print(f'token_logits_2.shape:{token_logits_2.shape}')
big_logits = torch.cat([token_logits_1, token_logits_2], dim=-1)
# utils_file.logging_limit_print(f'big_logits.shape:{big_logits.shape}')
            logp = torch.nn.functional.log_softmax(big_logits[:, -1, :], dim=-1)  # last position only
# utils_file.logging_limit_print(f'logp.shape:{logp.shape}')
max_index = torch.argmax(logp, dim=-1, keepdim=True)
# utils_file.logging_limit_print(f'max_index.shape:{max_index.shape}')
utils_file.logging_limit_print(f'max_index:{max_index}')
hyps = torch.cat((hyps, max_index),
dim=1) # (B*N, i+1)
            if max_index < 152064:
                token_emb = self.embed_tokens(hyps[:, -1:])  # text token: use the LLM table
            else:
                if max_index == 152064 + 4096:  # hard-coded speech-token end marker
                    utils_file.logging_limit_print('Geng Xuelong: hit the speech-token end marker, stopping')
                    break
                # de-offset back into the speech-token embedding table
                token_emb = self.speech_token_emded(hyps[:, -1:] - 152064)
best_hyps = hyps[0, :]
text_res = []
token_res = []
for i in best_hyps[1:]:
if i < 152064:
text_res.append(i)
else:
token_res.append(str((i - 152064).item()))
str_i = self.tokenizer.decode(text_res, skip_special_tokens=True, add_special_tokens=False)
return [str_i + " | " + " ".join(token_res)]