|
--- |
|
base_model: |
|
- HKUSTAudio/Llasa-3B |
|
--- |
|
|
|
# Sample Inference Script |
|
```py |
|
import re
from argparse import ArgumentParser

import torch
import torchaudio
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from jinja2 import Template
from rich import print
from torchaudio import functional as F
from xcodec2.modeling_xcodec2 import XCodec2Model
|
|
|
# Command-line interface, declared as a table of (flags, options) pairs and
# registered in one pass. Equivalent to a series of add_argument calls.
_ARG_SPECS = (
    (("-m", "--model"), {"required": True}),          # exllamav2 model directory
    (("-v", "--vocoder"), {"required": True}),        # XCodec2 vocoder checkpoint
    (("-i", "--input"), {"required": True}),          # text to synthesize
    (("-a", "--audio"), {"default": ""}),             # optional reference clip
    (("-t", "--transcript"), {"default": ""}),        # transcript of the clip
    (("-o", "--output"), {"default": "output.wav"}),  # output wav path
    (("-d", "--debug"), {"action": "store_true"}),    # stream tokens to stdout
    (("--sample_rate",), {"type": int, "default": 16000}),
    (("--max_seq_len",), {"type": int, "default": 2048}),
    (("--temperature",), {"type": float, "default": 0.8}),
    (("--top_p",), {"type": float, "default": 1.0}),
)

parser = ArgumentParser()
for flags, options in _ARG_SPECS:
    parser.add_argument(*flags, **options)

args = parser.parse_args()
|
|
|
# Load the Llasa LLM with exllamav2. The order here matters: weights and
# cache are created lazily, then load_autosplit materializes both while
# splitting layers across available GPUs.
with Timer() as timer:
    config = ExLlamaV2Config(args.model)
    # Cap the context window; also bounds max_new_tokens later on.
    config.max_seq_len = args.max_seq_len

    # lazy_load/lazy defer allocation until load_autosplit places each
    # layer (and its cache pages) on a device.
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache, progress=True)

    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)

print(f"Loaded model in {timer.interval:.2f} seconds.")
|
|
|
# Load the XCodec2 neural codec: it decodes generated speech tokens into a
# waveform and, for voice cloning, encodes the reference clip into tokens.
with Timer() as timer:
    vocoder = XCodec2Model.from_pretrained(args.vocoder)
    # Inference only — move to GPU and disable training-mode layers.
    vocoder = vocoder.cuda().eval()

print(f"Loaded vocoder in {timer.interval:.2f} seconds.")
|
|
|
# Optional voice cloning: when BOTH a reference clip and its transcript are
# supplied, the clip is encoded into XCodec2 speech-token strings and the
# pair is prepended to the prompt so generation continues in that voice.
if args.audio and args.transcript:
    with Timer() as timer:
        # Trailing space separates the reference transcript from the input text.
        transcript = f"{args.transcript} "
        audio, sample_rate = torchaudio.load(args.audio)
        audio = audio.cuda()

        # Downmix multi-channel audio to mono.
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # The codec expects args.sample_rate (16 kHz by default).
        if sample_rate != args.sample_rate:
            audio = F.resample(audio, sample_rate, args.sample_rate)

    print(f"Loaded audio in {timer.interval:.2f} seconds.")

    with Timer() as timer:
        # encode_code returns codes indexed [batch, codebook, frame]; take
        # the single stream and render each code as a <|s_N|> token.
        audio = vocoder.encode_code(audio)
        audio = audio[0, 0, :]
        audio = "".join(f"<|s_{a}|>" for a in audio)

    print(f"Encoded audio in {timer.interval:.2f} seconds.")
else:
    # Cloning needs audio AND transcript; warn when only one was given so a
    # half-specified reference is not silently ignored.
    if args.audio or args.transcript:
        print("Ignoring voice reference: both --audio and --transcript are required.")
    transcript = ""
    audio = ""
|
|
|
# Build the chat-formatted prompt. The assistant turn is pre-seeded with
# <|SPEECH_GENERATION_START|> (plus any reference speech tokens) so the model
# continues by emitting speech tokens rather than starting a new turn.
with Timer() as timer:
    messages = [
        {
            "role": "user",
            "content": (
                "Convert the text to speech:"
                "<|TEXT_UNDERSTANDING_START|>"
                f"{transcript}{args.input}"
                "<|TEXT_UNDERSTANDING_END|>"
            ),
        },
        {"role": "assistant", "content": f"<|SPEECH_GENERATION_START|>{audio}"},
    ]

    # Render the model's own chat template; eos_token is blanked so the
    # pre-seeded assistant turn stays open for generation.
    template = Template(tokenizer.tokenizer_config_dict["chat_template"])

    # Named `prompt` (was `input`) to avoid shadowing the builtin.
    prompt = template.render(messages=messages, eos_token="")
    input_ids = tokenizer.encode(prompt, add_bos=True, encode_special_tokens=True)

print(f"Encoded input in {timer.interval:.2f} seconds.")
|
|
|
# Stream speech tokens from the model until <|SPEECH_GENERATION_END|> fires
# or the remaining context window is exhausted.
with Timer() as timer:
    gen_settings = ExLlamaV2Sampler.Settings()
    gen_settings.temperature = args.temperature
    gen_settings.top_p = args.top_p

    generator.enqueue(
        ExLlamaV2DynamicJob(
            input_ids=input_ids,
            # Whatever context remains after the prompt.
            max_new_tokens=config.max_seq_len - input_ids.shape[-1],
            gen_settings=gen_settings,
            stop_conditions=["<|SPEECH_GENERATION_END|>"],
        )
    )

    output = []
    while generator.num_remaining_jobs():
        for result in generator.iterate():
            # Only streaming-stage results carry generated text.
            if result.get("stage") != "streaming":
                continue

            if text := result.get("text"):
                output.append(text)
                if args.debug:
                    print(text, end="", flush=True)

            if result.get("eos"):
                generator.clear_queue()

if args.debug:
    print()  # terminate the streamed-token line

print(f"Generated {len(output)} tokens in {timer.interval:.2f} seconds.")
|
|
|
# Decode the generated speech tokens back into a waveform and save it.
with Timer() as timer:
    # Streamed chunks are not guaranteed to align one-to-one with speech
    # tokens, so parse every <|s_N|> code out of the joined text instead of
    # slicing each chunk (the previous int(o[4:-2]) raised ValueError on any
    # chunk that was not exactly one token).
    codes = [int(c) for c in re.findall(r"<\|s_(\d+)\|>", "".join(output))]

    # decode_code expects [batch, codebook, frame]; result is indexed the
    # same way, so take the single waveform and add a channel dim for saving.
    output = torch.tensor([[codes]]).cuda()
    output = vocoder.decode_code(output)
    output = output[0, 0, :]
    output = output.unsqueeze(0).cpu()
    torchaudio.save(args.output, output, args.sample_rate)

print(f"Decoded audio in {timer.interval:.2f} seconds.")
|
``` |