# Step-Audio-TTS-3B / stepaudio.py
import os
import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM
from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS
from utils import load_audio, speech_adjust, volumn_adjust
class StepAudio:
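    """End-to-end Step-Audio pipeline: audio tokenizer (encoder), chat LLM, and TTS decoder."""
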
def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
        # Load optimus_ths for flash attention; make sure LD_LIBRARY_PATH contains `nvidia/cuda_nvrtc/lib`.
        # If it does not, set LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib manually.
try:
if torch.__version__ >= "2.5":
torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so'))
elif torch.__version__ >= "2.3":
torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so'))
elif torch.__version__ >= "2.2":
torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so'))
print("Load optimus_ths successfully and flash attn would be enabled")
except Exception as err:
print(f"Fail to load optimus_ths and flash attn is disabled: {err}")
self.llm_tokenizer = AutoTokenizer.from_pretrained(
llm_path, trust_remote_code=True
)
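        # Audio tokenizer (speech -> discrete tokens) and TTS decoder (tokens + text -> waveform).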
self.encoder = StepAudioTokenizer(tokenizer_path)
self.decoder = StepAudioTTS(tts_path, self.encoder)
self.llm = AutoModelForCausalLM.from_pretrained(
llm_path,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
def __call__(
self,
messages: list,
speaker_id: str,
speed_ratio: float = 1.0,
volumn_ratio: float = 1.0,
):
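        """Run one chat turn: generate a text reply with the LLM, then synthesize it to speech.

        Returns a tuple of (output_text, output_audio, sample_rate).
        """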
text_with_audio = self.apply_chat_template(messages)
token_ids = self.llm_tokenizer.encode(text_with_audio, return_tensors="pt")
outputs = self.llm.generate(
token_ids, max_new_tokens=2048, temperature=0.7, top_p=0.9, do_sample=True
)
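        # Keep only the newly generated tokens, dropping the prompt and the trailing end-of-turn token.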
output_token_ids = outputs[:, token_ids.shape[-1] : -1].tolist()[0]
output_text = self.llm_tokenizer.decode(output_token_ids)
output_audio, sr = self.decoder(output_text, speaker_id)
if speed_ratio != 1.0:
output_audio = speech_adjust(output_audio, sr, speed_ratio)
if volumn_ratio != 1.0:
output_audio = volumn_adjust(output_audio, volumn_ratio)
return output_text, output_audio, sr
def encode_audio(self, audio_path):
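        """Load an audio file and encode it into Step-Audio discrete audio tokens."""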
audio_wav, sr = load_audio(audio_path)
audio_tokens = self.encoder(audio_wav, sr)
return audio_tokens
def apply_chat_template(self, messages: list):
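        """Serialize chat messages into the Step-Audio prompt format and append the assistant turn header."""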
text_with_audio = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
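            # Step-Audio prompts tag the user role as "human".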
if role == "user":
role = "human"
if isinstance(content, str):
text_with_audio += f"<|BOT|>{role}\n{content}<|EOT|>"
elif isinstance(content, dict):
if content["type"] == "text":
text_with_audio += f"<|BOT|>{role}\n{content['text']}<|EOT|>"
elif content["type"] == "audio":
audio_tokens = self.encode_audio(content["audio"])
text_with_audio += f"<|BOT|>{role}\n{audio_tokens}<|EOT|>"
elif content is None:
text_with_audio += f"<|BOT|>{role}\n"
else:
raise ValueError(f"Unsupported content type: {type(content)}")
if not text_with_audio.endswith("<|BOT|>assistant\n"):
text_with_audio += "<|BOT|>assistant\n"
return text_with_audio
if __name__ == "__main__":
    model = StepAudio(
        tokenizer_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-encoder",
        tts_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-decoder",
        llm_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-v18",
    )
    os.makedirs("output", exist_ok=True)  # ensure the output directory exists before saving
text, audio, sr = model(
[{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}],
"闫雨婷",
)
torchaudio.save("output/output_e2e_tqta.wav", audio, sr)
text, audio, sr = model(
[
{
"role": "user",
"content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"},
}
],
"闫雨婷",
)
torchaudio.save("output/output_e2e_aqta.wav", audio, sr)