#include "model_adapter.h"
#include "otherarch/utils.h"

#include "common.h"
#include "sampling.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <map>
#include <random>   // std::mt19937 used for the sampling RNGs below
#include <regex>
#include <sstream>  // std::ostringstream used when building the WAV payload
#include <string>
#include <thread>
#include <vector>

#include "src/llama-context.h"

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

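// OuteTTS prompt format version. v0.2 separates words with <|text_sep|>, while
// v0.3 uses <|space|> plus dedicated punctuation tokens; the loader detects the
// version from the vocabulary at load time.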
enum TTS_VER
{
    TTS_VER_2,
    TTS_VER_3
};

struct wav_header {
    char riff[4] = {'R', 'I', 'F', 'F'};
    uint32_t chunk_size;
    char wave[4] = {'W', 'A', 'V', 'E'};
    char fmt[4] = {'f', 'm', 't', ' '};
    uint32_t fmt_chunk_size = 16;
    uint16_t audio_format = 1;
    uint16_t num_channels = 1;
    uint32_t sample_rate;
    uint32_t byte_rate;
    uint16_t block_align;
    uint16_t bits_per_sample = 16;
    char data[4] = {'d', 'a', 't', 'a'};
    uint32_t data_size;
};

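// Serializes float samples in [-1, 1] into a 16-bit mono PCM WAV file in memory
// (using the wav_header above) and returns it base64-encoded for the API response.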
static std::string save_wav16_base64(const std::vector<float> &data, int sample_rate) {
    std::ostringstream oss;
    wav_header header;

    header.sample_rate = sample_rate;
    header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
    header.block_align = header.num_channels * (header.bits_per_sample / 8);
    header.data_size = data.size() * (header.bits_per_sample / 8);
    header.chunk_size = 36 + header.data_size;

    oss.write(reinterpret_cast<const char*>(&header), sizeof(header));

    for (const auto &sample : data) {
        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
        oss.write(reinterpret_cast<const char*>(&pcm_sample), sizeof(pcm_sample));
    }

    std::string wav_data = oss.str();
    return kcpp_base64_encode(wav_data);
}

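// Fills `output` with a Hann window of the given length; `periodic` selects the
// DFT-friendly variant (denominator N) over the symmetric one (denominator N-1).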
static void fill_hann_window(int length, bool periodic, float * output) {
    int offset = -1;
    if (periodic) {
        offset = 0;
    }
    for (int i = 0; i < length; i++) {
        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
    }
}

static void twiddle(float * real, float * imag, int k, int N) {
    float angle = 2 * M_PI * k / N;
    *real = cos(angle);
    *imag = sin(angle);
}

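// Naive O(n*N) inverse real FFT: `inp_cplx` holds n/2+1 interleaved (re, im) bins
// and `out_real` receives n time-domain samples. Slow but dependency-free.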
static void irfft(int n, const float * inp_cplx, float * out_real) {
    int N = n / 2 + 1;

    std::vector<float> real_input(N);
    std::vector<float> imag_input(N);
    for (int i = 0; i < N; ++i) {
        real_input[i] = inp_cplx[2 * i];
        imag_input[i] = inp_cplx[2 * i + 1];
    }

    std::vector<float> real_output(n);
    std::vector<float> imag_output(n);

    for (int k = 0; k < n; ++k) {
        real_output[k] = 0.0f;
        imag_output[k] = 0.0f;
        for (int m = 0; m < N; ++m) {
            float twiddle_real;
            float twiddle_imag;

            twiddle(&twiddle_real, &twiddle_imag, k * m, n);

            real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
            imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
        }
    }

    for (int i = 0; i < n; ++i) {
        out_real[i] = real_output[i] / N;
    }
}

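// Overlap-add: scatters consecutive frames of `data` (window n_win, hop n_hop)
// back onto a single timeline of length n_out, then trims n_pad samples of
// padding from each end.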
static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
    int64_t output_height = n_out;
    int64_t kernel_w = n_win;
    int64_t stride_w = n_hop;
    int64_t width = n_out;

    output.resize(width, 0.0f);

    int64_t col_idx = 0;
    for (int64_t w_col = 0; w_col < width; ++w_col) {
        int64_t start = w_col * stride_w - n_pad;
        int64_t end = start + kernel_w;

        for (int64_t w_im = start; w_im < end; ++w_im) {
            if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) {
                output[w_im] += data[col_idx];
            }
            col_idx++;
        }
    }

    output.resize(n_out - 2 * n_pad);
}

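// Converts WavTokenizer output embeddings into PCM audio. Each of the n_codes
// frames stores log-magnitude and phase halves of a 1280-bin spectrum; frames are
// recombined into complex spectra, inverse-FFT'd per frame across n_thread
// workers, Hann-windowed, and overlap-added with squared-window normalization.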
static std::vector<float> embd_to_audio(
        const float * embd,
        const int n_codes,
        const int n_embd,
        const int n_thread) {

    const int n_fft = 1280;
    const int n_hop = 320;
    const int n_win = 1280;
    const int n_pad = (n_win - n_hop)/2;
    const int n_out = (n_codes - 1)*n_hop + n_win;

    std::vector<float> hann(n_fft);

    fill_hann_window(hann.size(), true, hann.data());

    int n_spec = n_embd*n_codes;

    std::vector<float> E (n_spec);
    std::vector<float> S (n_spec);
    std::vector<float> ST(n_spec);

    for (int l = 0; l < n_codes; ++l) {
        for (int k = 0; k < n_embd; ++k) {
            E[k*n_codes + l] = embd[l*n_embd + k];
        }
    }

    for (int k = 0; k < n_embd/2; ++k) {
        for (int l = 0; l < n_codes; ++l) {
            float mag = E[(k           )*n_codes + l];
            float phi = E[(k + n_embd/2)*n_codes + l];

            mag = exp(mag);

            if (mag > 1e2) {
                mag = 1e2;
            }
            S[2*(k*n_codes + l) + 0] = mag*cosf(phi);
            S[2*(k*n_codes + l) + 1] = mag*sinf(phi);
        }
    }

    for (int l = 0; l < n_codes; ++l) {
        for (int k = 0; k < n_embd/2; ++k) {
            ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
            ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
        }
    }

    std::vector<float> res  (n_codes*n_fft);
    std::vector<float> hann2(n_codes*n_fft);

    std::vector<std::thread> workers(n_thread);
    for (int i = 0; i < n_thread; ++i) {
        workers[i] = std::thread([&, i]() {
            for (int l = i; l < n_codes; l += n_thread) {
                irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
                for (int j = 0; j < n_fft; ++j) {
                    res  [l*n_fft + j] *= hann[j];
                    hann2[l*n_fft + j]  = hann[j] * hann[j];
                }
            }
        });
    }
    for (int i = 0; i < n_thread; ++i) {
        workers[i].join();
    }

    std::vector<float> audio;
    std::vector<float> env;

    fold(res,   n_out, n_win, n_hop, n_pad, audio);
    fold(hann2, n_out, n_win, n_hop, n_pad, env);

    for (size_t i = 0; i < audio.size(); ++i) {
        audio[i] /= env[i];
    }

    return audio;
}

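// English number spelling used during text normalization: digits/teens, tens, and
// helpers to expand integers (up to billions) and decimal fractions into words.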
static const std::map<int, std::string> ones = {
    {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
    {5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
    {10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
    {15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
};

static const std::map<int, std::string> tens = {
    {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
    {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
};

static std::string convert_less_than_thousand(int num) {
    std::string result;

    if (num >= 100) {
        result += ones.at(num / 100) + " hundred ";
        num %= 100;
    }

    if (num >= 20) {
        result += tens.at(num / 10);
        if (num % 10 > 0) {
            result += "-" + ones.at(num % 10);
        }
    } else if (num > 0) {
        result += ones.at(num);
    }

    return result;
}

static std::string number_to_words(const std::string & number_str) {
    try {
        size_t decimal_pos = number_str.find('.');
        std::string integer_part = number_str.substr(0, decimal_pos);

        int int_number = std::stoi(integer_part);
        std::string result;

        if (int_number == 0) {
            result = "zero";
        } else {
            if (int_number >= 1000000000) {
                int billions = int_number / 1000000000;
                result += convert_less_than_thousand(billions) + " billion ";
                int_number %= 1000000000;
            }

            if (int_number >= 1000000) {
                int millions = int_number / 1000000;
                result += convert_less_than_thousand(millions) + " million ";
                int_number %= 1000000;
            }

            if (int_number >= 1000) {
                int thousands = int_number / 1000;
                result += convert_less_than_thousand(thousands) + " thousand ";
                int_number %= 1000;
            }

            if (int_number > 0) {
                result += convert_less_than_thousand(int_number);
            }
        }

        if (decimal_pos != std::string::npos) {
            result += " point";
            std::string decimal_part = number_str.substr(decimal_pos + 1);
            for (char digit : decimal_part) {
                result += " " + ones.at(digit - '0');
            }
        }

        return result;
    } catch (const std::exception& e) {
        return " ";
    }
}

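// Replaces every integer or decimal literal in the input with its spelled-out
// form so the TTS model never has to read raw digits.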
static std::string replace_numbers_with_words(const std::string & input_text) {
    std::regex number_pattern(R"(\d+(\.\d+)?)");
    std::string result;
    auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
    auto end = std::sregex_iterator();

    size_t last_pos = 0;
    for (std::sregex_iterator i = it; i != end; ++i) {
        const std::smatch& match = *i;
        result.append(input_text, last_pos, match.position() - last_pos);
        result.append(number_to_words(match.str()));
        last_pos = match.position() + match.length();
    }
    result.append(input_text, last_pos);

    return result;
}

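// Normalizes raw prompt text into the form OuteTTS expects: numbers are spelled
// out, text is lowercased, punctuation and special characters are stripped or
// mapped to control tokens, and words are joined with <|text_sep|> (v0.2) or
// <|space|> (v0.3).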
static std::string process_text(const std::string & text, TTS_VER ver) {

    std::string processed_text = replace_numbers_with_words(text);

    std::transform(processed_text.begin(), processed_text.end(),
                   processed_text.begin(), ::tolower);

    if(ver==TTS_VER_2)
    {
        processed_text = std::regex_replace(processed_text, std::regex(R"(([,.!?])\1+)"), "$1");

        processed_text = std::regex_replace(processed_text, std::regex(R"(([.,?!])([^\s]))"), "$1 $2");
        std::regex special_chars(R"([\(\)\[\]\{\}\:-_/,\.\\])");
        processed_text = std::regex_replace(processed_text, special_chars, " ");
        std::regex non_alpha(R"([^a-z\s])");
        processed_text = std::regex_replace(processed_text, non_alpha, "");
        std::regex multiple_spaces(R"(\s+)");
        processed_text = std::regex_replace(processed_text, multiple_spaces, " ");
        processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
    } else {
        std::regex special_chars(R"([\(\)\[\]\{\}\:-_/\\])");
        processed_text = std::regex_replace(processed_text, special_chars, " ");
        std::regex non_alpha(R"([^a-z\s.,?!])");
        processed_text = std::regex_replace(processed_text, non_alpha, "");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\s+)"), " ");
        processed_text = std::regex_replace(processed_text, std::regex(R"(([,.!?])\1+)"), "$1");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\s+([.,!?]))"), "$1");
        processed_text = std::regex_replace(processed_text, std::regex(R"(([.,?!])([^\s]))"), "$1 $2");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\,)"), "<|comma|>");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\.)"), "<|period|>");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\?)"), "<|question_mark|>");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\!)"), "<|exclamation_mark|>");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\s+)"), " ");
        processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
        processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|space|>");
    }

    return processed_text;
}

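// Small helpers for assembling the OuteTTS prompt: append raw tokens, tokenize and
// append a string, and reset the prompt to the "<|im_start|>\n<|text_start|>" header.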
static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
    prompt.insert(prompt.end(), tokens.begin(), tokens.end());
}
static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
    auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
    prompt_add(prompt, tmp);
}
static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
    prompt.clear();
    prompt_add(prompt, vocab, "<|im_start|>\n<|text_start|>", true, true);
}

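// Builds the list of guide tokens: the first token of every word in the processed
// prompt. During generation these are force-fed after each newline so the model
// stays aligned with the requested text.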
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string& str, TTS_VER ver)
{
    std::string delimiter = "<|text_sep|>";
    if(ver==TTS_VER_3)
    {
        delimiter = "<|space|>";
    }

    std::vector<llama_token> result;
    size_t start = 0;
    size_t end = str.find(delimiter);

    while (end != std::string::npos) {
        std::string current_word = str.substr(start, end - start);
        auto tmp = common_tokenize(vocab, current_word, false, true);
        result.push_back(tmp[0]);
        start = end + delimiter.length();
        end = str.find(delimiter, start);
    }

    std::string current_word = str.substr(start);
    if(current_word!="")
    {
        auto tmp = common_tokenize(vocab, current_word, false, true);
        if(tmp.size()>0){
            result.push_back(tmp[0]);
        }
    }

    return result;
}

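// Rewrites the hardcoded v0.2-style speaker transcripts for v0.3 models:
// <|code_start|> markers are removed and <|code_end|> becomes <|space|>.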
std::string format_audiotokens(const std::string& input, TTS_VER ver)
{
    if (ver == TTS_VER_2) {
        return input;
    } else {
        std::string clean = std::regex_replace(input, std::regex(R"(<\|code_start\|>)"), "");
        clean = std::regex_replace(clean, std::regex(R"(<\|code_end\|>)"), "<|space|>");
        return clean;
    }
}

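// Splits `input` on `separator`, keeps at most `maxWords` pieces, and rejoins them.
// Used to cap the prompt at a manageable number of words before synthesis.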
std::string trim_words(const std::string& input, const std::string& separator, size_t maxWords) {
    std::vector<std::string> words;
    size_t start = 0, end;
    while ((end = input.find(separator, start)) != std::string::npos) {
        std::string last = input.substr(start, end - start);
        if (last != "") {
            words.push_back(last);
        }
        start = end + separator.length();
    }
    std::string last = input.substr(start);
    if(last!="")
    {
        words.push_back(last);
    }

    if (words.size() > maxWords) {
        words.resize(maxWords);
    }

    std::ostringstream result;
    for (size_t i = 0; i < words.size(); ++i) {
        if (i > 0) result << separator;
        result << words[i];
    }

    return result.str();
}

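// Module state: the OuteTTS text-to-codes context (ttc) and WavTokenizer
// codes-to-speech context (cts), cached speaker codes, and default token IDs that
// are overwritten at load time for the detected vocabulary.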
static llama_context * ttc_ctx = nullptr;
static llama_context * cts_ctx = nullptr;

static TTS_VER ttsver = TTS_VER_2;
static int ttsdebugmode = 0;
static bool tts_is_quiet = false;
static std::string ttsplatformenv, ttsdeviceenv, ttsvulkandeviceenv;
static std::string last_generated_audio = "";
static std::string last_generation_settings_prompt = "";
static int last_generation_settings_speaker_seed;
static int last_generation_settings_audio_seed;
static std::vector<llama_token> last_speaker_codes;
static int last_speaker_seed = -999;
static int cts_offset = 151672;
static int space_id = 151670;
static int code_terminate_id = 151670;
static int nthreads = 4;
static int tts_max_len = 4096;

int total_tts_gens = 0;

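// Loads both GGUF models (OuteTTS text-to-codes and WavTokenizer vocoder), sets up
// OpenCL/Vulkan device hints from the load inputs, runs a tiny warmup decode, and
// probes the vocabulary to pick v0.2 vs v0.3 token conventions.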
bool ttstype_load_model(const tts_load_model_inputs inputs)
{
    tts_is_quiet = inputs.quiet;

    int cl_parseinfo = inputs.clblast_info;
    std::string usingclblast = "GGML_OPENCL_CONFIGURED="+std::to_string(cl_parseinfo>0?1:0);
    putenv((char*)usingclblast.c_str());
    cl_parseinfo = cl_parseinfo%100;
    int platform = cl_parseinfo/10;
    int devices = cl_parseinfo%10;
    ttsplatformenv = "GGML_OPENCL_PLATFORM="+std::to_string(platform);
    ttsdeviceenv = "GGML_OPENCL_DEVICE="+std::to_string(devices);
    putenv((char*)ttsplatformenv.c_str());
    putenv((char*)ttsdeviceenv.c_str());
    std::string vulkan_info_raw = inputs.vulkan_info;
    std::string vulkan_info_str = "";
    for (size_t i = 0; i < vulkan_info_raw.length(); ++i) {
        vulkan_info_str += vulkan_info_raw[i];
        if (i < vulkan_info_raw.length() - 1) {
            vulkan_info_str += ",";
        }
    }
    if(vulkan_info_str!="")
    {
        ttsvulkandeviceenv = "GGML_VK_VISIBLE_DEVICES="+vulkan_info_str;
        putenv((char*)ttsvulkandeviceenv.c_str());
    }

    llama_backend_init();

    std::string modelfile_ttc = inputs.ttc_model_filename;
    std::string modelfile_cts = inputs.cts_model_filename;
    printf("\nLoading TTS Model, OuteTTS: %s \nWavTokenizer: %s \n",modelfile_ttc.c_str(),modelfile_cts.c_str());

    ttsdebugmode = inputs.debugmode;

    llama_model_params tts_model_params = llama_model_default_params();
    llama_context_params tts_ctx_params = llama_context_default_params();

    nthreads = inputs.threads;

    tts_max_len = inputs.ttsmaxlen;

    tts_model_params.use_mmap = false;
    tts_model_params.use_mlock = false;
    tts_model_params.n_gpu_layers = inputs.gpulayers;
    tts_model_params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER;
    tts_ctx_params.n_ctx = 8192;
    tts_ctx_params.logits_all = false;
    tts_ctx_params.offload_kqv = true;
    tts_ctx_params.n_batch = 8192;
    tts_ctx_params.n_ubatch = 512;
    tts_ctx_params.n_threads = nthreads;
    tts_ctx_params.n_threads_batch = nthreads;
    tts_ctx_params.flash_attn = inputs.flash_attention;

    llama_model * ttcmodel = llama_model_load_from_file(modelfile_ttc.c_str(), tts_model_params);
    ttc_ctx = llama_new_context_with_model(ttcmodel, tts_ctx_params);

    if (ttc_ctx == nullptr) {
        printf("\nTTS Load Error: Failed to initialize ttc context!\n");
        return false;
    }

    llama_model * ctsmodel = llama_model_load_from_file(modelfile_cts.c_str(), tts_model_params);

    tts_ctx_params.embeddings = true;
    cts_ctx = llama_new_context_with_model(ctsmodel, tts_ctx_params);

    if (cts_ctx == nullptr) {
        printf("\nTTS Load Error: Failed to initialize cts context!\n");
        return false;
    }

    std::vector<int> tmp = {1, 2, 3, 4};
    llama_kv_cache_clear(ttc_ctx);
    auto er = llama_decode(ttc_ctx, llama_batch_get_one(tmp.data(), tmp.size()));
    if(er!=0)
    {
        printf("\nTTS Eval returned nonzero: %d\n",er);
        return false;
    }

    const llama_vocab * ttcvocab = llama_model_get_vocab(ttcmodel);
    llama_tokens testoks = common_tokenize(ttcvocab,"<|space|>",false,true);
    if (testoks.size() == 1) {
        ttsver = TTS_VER_3;
        printf("\nUsing v0.3 mode");

        space_id = testoks[0];
        testoks = common_tokenize(ttcvocab,"<|audio_end|>",false,true);
        if (testoks.size() == 1) {
            code_terminate_id = testoks[0];
        }
    } else {
        ttsver = TTS_VER_2;
        printf("\nUsing v0.2 mode");
    }

    testoks = common_tokenize(ttcvocab,"<|0|>",false,true);
    if (testoks.size() == 1) {
        cts_offset = testoks[0];
    }

    printf("\nTTS Load Complete.\n");
    return true;
}

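// Main entry point: normalizes the prompt, optionally generates (or reuses) a
// speaker profile from the speaker seed, autoregressively samples audio codes with
// the OuteTTS model while steering with guide tokens, then runs the WavTokenizer
// vocoder and returns the result as a base64 WAV string.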
tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs)
{
    tts_generation_outputs output;

    if(ttc_ctx==nullptr || cts_ctx==nullptr)
    {
        printf("\nWarning: KCPP TTS not initialized! Make sure both TTS and WavTokenizer models are loaded.\n");
        output.data = "";
        output.status = 0;
        return output;
    }

    std::vector<llama_token> codes;
    std::vector<llama_token> guide_tokens;
    const llama_model * model_ttc = &(ttc_ctx->model);
    const llama_vocab * ttcvocab = llama_model_get_vocab(model_ttc);
    const llama_model * model_cts = &(cts_ctx->model);
    const llama_vocab * ctsvocab = llama_model_get_vocab(model_cts);
    const int ttc_n_vocab = llama_vocab_n_tokens(ttcvocab);
    std::string prompt = inputs.prompt;
    const std::string sampletext = process_text("but that is what it is",ttsver);

    llama_kv_cache_clear(ttc_ctx);
    llama_kv_cache_clear(cts_ctx);
    std::vector<llama_token> prompt_inp;
    prompt_init(prompt_inp, ttcvocab);

    int speaker_seed = inputs.speaker_seed;
    int audio_seed = inputs.audio_seed;
    if (speaker_seed <= 0 || speaker_seed==0xFFFFFFFF)
    {
        speaker_seed = (((uint32_t)time(NULL)) % 1000000u);
    }
    if (audio_seed <= 0 || audio_seed==0xFFFFFFFF)
    {
        audio_seed = (((uint32_t)time(NULL)) % 1000000u);
    }
    if(ttsdebugmode==1 && !tts_is_quiet)
    {
        printf("\nUsing Speaker Seed: %d", speaker_seed);
        printf("\nUsing Audio Seed: %d", audio_seed);
    }

    std::mt19937 tts_rng(audio_seed);
    std::mt19937 speaker_rng(speaker_seed);

    int n_decode = 0;
    int n_predict = 2048;
    bool next_token_uses_guide_token = true;

    std::string prompt_clean = process_text(prompt,ttsver);
    bool empty_check = (process_text(prompt,TTS_VER_2).size()==0);

    prompt_clean = trim_words(prompt_clean,(ttsver==TTS_VER_3?"<|space|>":"<|text_sep|>"),300);

    if(empty_check) //nothing usable left after normalization
    {
        if(!tts_is_quiet)
        {
            printf("\nTTS sent empty input.\n");
        }
        last_generated_audio = "";
        output.data = last_generated_audio.c_str();
        output.status = 1;
        return output;
    }

    double ttstime = 0;
    timer_start();

    if(!tts_is_quiet && ttsdebugmode==1)
    {
        printf("\nInput: %s\n", prompt_clean.c_str());
    }

    llama_token newlineid = common_tokenize(ttcvocab,"\n",false,true)[0];

    if(speaker_seed>0)
    {
        if(last_speaker_seed==speaker_seed && !last_speaker_codes.empty())
        {
            if(!tts_is_quiet && ttsdebugmode==1)
            {
                printf("\nReuse speaker ID=%d (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
            }
        } else if (speaker_seed>=1 && speaker_seed<=5){
            std::string speaker = "";
            switch(speaker_seed)
            {
                case 1:
speaker = format_audiotokens("but<|t_0.31|><|code_start|><|1023|><|1474|><|17|><|121|><|1362|><|744|><|438|><|1319|><|744|><|1419|><|1246|><|923|><|1338|><|406|><|939|><|975|><|1491|><|965|><|1212|><|248|><|794|><|464|><|830|><|code_end|>\nthat<|t_0.13|><|code_start|><|1578|><|1773|><|660|><|1074|><|221|><|1803|><|142|><|914|><|798|><|485|><|code_end|>\nis<|t_0.11|><|code_start|><|737|><|794|><|1288|><|182|><|895|><|1653|><|448|><|471|><|code_end|>\nwhat<|t_0.12|><|code_start|><|1734|><|1306|><|779|><|490|><|525|><|1028|><|37|><|1633|><|1353|><|code_end|>\nit<|t_0.09|><|code_start|><|1343|><|898|><|270|><|1035|><|94|><|1409|><|388|><|code_end|>\nis<|t_0.23|><|code_start|><|694|><|695|><|577|><|692|><|1047|><|388|><|28|><|905|><|1155|><|50|><|1629|><|1775|><|1711|><|1729|><|404|><|1027|><|344|><|code_end|>",ttsver); |
                    break;
                case 2:
speaker = format_audiotokens("but<|t_0.45|><|code_start|><|920|><|1824|><|1138|><|1387|><|1096|><|1712|><|1642|><|810|><|1685|><|620|><|954|><|584|><|23|><|1467|><|509|><|659|><|1598|><|465|><|567|><|1440|><|3|><|476|><|740|><|288|><|419|><|1440|><|1477|><|254|><|25|><|811|><|882|><|476|><|246|><|246|><|code_end|>\nthat<|t_0.17|><|code_start|><|419|><|1690|><|208|><|1044|><|300|><|1100|><|375|><|1222|><|371|><|1045|><|637|><|1719|><|314|><|code_end|>\nis<|t_0.12|><|code_start|><|319|><|1131|><|794|><|1103|><|1296|><|1615|><|1587|><|233|><|863|><|code_end|>\nwhat<|t_0.16|><|code_start|><|793|><|902|><|391|><|946|><|437|><|95|><|1133|><|110|><|58|><|853|><|1283|><|449|><|code_end|>\nit<|t_0.12|><|code_start|><|774|><|239|><|974|><|213|><|1095|><|1612|><|101|><|1569|><|882|><|code_end|>\nis<|t_0.32|><|code_start|><|1131|><|529|><|1144|><|774|><|1114|><|483|><|693|><|648|><|1112|><|1470|><|1112|><|319|><|1294|><|1417|><|1660|><|729|><|1789|><|1413|><|1728|><|554|><|273|><|736|><|640|><|1549|><|code_end|>",ttsver); |
                    break;
                case 3:
speaker = format_audiotokens("but<|t_0.21|><|code_start|><|348|><|1776|><|1620|><|1262|><|118|><|288|><|258|><|1407|><|1331|><|1102|><|664|><|1300|><|1647|><|1536|><|71|><|23|><|code_end|> \nthat<|t_0.19|><|code_start|><|3|><|1740|><|1253|><|1122|><|549|><|715|><|718|><|657|><|1136|><|1247|><|517|><|1333|><|815|><|634|><|code_end|>\nis<|t_0.12|><|code_start|><|1330|><|839|><|753|><|1826|><|1602|><|50|><|1441|><|889|><|948|><|code_end|>\nwhat<|t_0.16|><|code_start|><|899|><|869|><|250|><|894|><|876|><|1471|><|1308|><|1436|><|1328|><|1700|><|1425|><|1330|><|code_end|>\nit<|t_0.12|><|code_start|><|1027|><|1162|><|1344|><|1170|><|86|><|1562|><|1575|><|176|><|1186|><|code_end|>\nis<|t_0.25|><|code_start|><|361|><|1533|><|1697|><|903|><|333|><|1232|><|1337|><|1611|><|1196|><|0|><|1328|><|1245|><|1718|><|1635|><|1616|><|1599|><|1363|><|962|><|328|><|code_end|>",ttsver); |
                    break;
                case 4:
speaker = format_audiotokens("but<|t_0.20|><|code_start|><|686|><|1288|><|1251|><|1428|><|481|><|702|><|1812|><|829|><|81|><|756|><|76|><|104|><|952|><|1723|><|1632|><|code_end|>\nthat<|t_0.20|><|code_start|><|1006|><|1067|><|1614|><|1810|><|887|><|43|><|1192|><|106|><|400|><|43|><|730|><|660|><|186|><|87|><|467|><|code_end|>\nis<|t_0.27|><|code_start|><|648|><|1625|><|9|><|685|><|243|><|106|><|996|><|990|><|228|><|809|><|1009|><|2|><|806|><|1325|><|1332|><|1766|><|202|><|725|><|416|><|822|><|code_end|>\nwhat<|t_0.36|><|code_start|><|1287|><|328|><|1241|><|1661|><|1651|><|1708|><|1740|><|1685|><|1715|><|1787|><|1381|><|197|><|1769|><|525|><|1000|><|234|><|364|><|115|><|212|><|632|><|1153|><|228|><|73|><|1002|><|1800|><|1277|><|1117|><|code_end|>\nit<|t_0.40|><|code_start|><|1830|><|1199|><|1282|><|1163|><|1195|><|1752|><|1092|><|1481|><|1003|><|513|><|1639|><|1805|><|1485|><|1645|><|195|><|1464|><|181|><|195|><|123|><|87|><|433|><|878|><|170|><|1265|><|375|><|1708|><|1739|><|1519|><|1185|><|1099|><|code_end|>\nis<|t_0.76|><|code_start|><|1748|><|1422|><|276|><|1337|><|1322|><|1519|><|1779|><|1067|><|1724|><|891|><|1205|><|1419|><|1144|><|1667|><|591|><|1003|><|1543|><|566|><|1390|><|426|><|1824|><|182|><|1138|><|52|><|129|><|1056|><|155|><|1056|><|1298|><|919|><|155|><|125|><|500|><|1022|><|571|><|315|><|400|><|100|><|617|><|295|><|757|><|324|><|592|><|1298|><|1310|><|57|><|876|><|1175|><|1353|><|1770|><|1649|><|1828|><|1637|><|362|><|1744|><|884|><|1027|><|code_end|>",ttsver); |
                    break;
                case 5:
speaker = format_audiotokens("but<|t_0.68|><|code_start|><|1761|><|1164|><|1543|><|1677|><|1120|><|1634|><|1496|><|1639|><|1717|><|1306|><|1016|><|1713|><|976|><|1474|><|1817|><|976|><|1595|><|1255|><|584|><|1440|><|1121|><|287|><|91|><|44|><|246|><|160|><|1233|><|247|><|776|><|44|><|246|><|12|><|1352|><|866|><|168|><|71|><|246|><|246|><|804|><|933|><|168|><|193|><|44|><|1663|><|1097|><|411|><|1393|><|1326|><|21|><|342|><|118|><|code_end|>\nthat<|t_0.17|><|code_start|><|220|><|1750|><|1160|><|260|><|1738|><|300|><|291|><|989|><|147|><|1150|><|947|><|803|><|930|><|code_end|>\nis<|t_0.15|><|code_start|><|798|><|1632|><|412|><|1084|><|1166|><|1014|><|416|><|1637|><|415|><|1|><|1660|><|code_end|>\nwhat<|t_0.21|><|code_start|><|1412|><|707|><|572|><|1092|><|898|><|673|><|770|><|1787|><|994|><|983|><|1096|><|221|><|924|><|1323|><|1726|><|387|><|code_end|>\nit<|t_0.12|><|code_start|><|798|><|665|><|513|><|695|><|1410|><|337|><|237|><|1717|><|1353|><|code_end|>\nis<|t_0.24|><|code_start|><|1355|><|1084|><|65|><|1422|><|674|><|1280|><|940|><|1752|><|396|><|1431|><|1761|><|957|><|1440|><|634|><|333|><|1627|><|821|><|788|><|code_end|>",ttsver); |
                    break;
            }
            last_speaker_codes = common_tokenize(ttcvocab, speaker, false, true);
            last_speaker_seed = speaker_seed;
            if(!tts_is_quiet && ttsdebugmode==1)
            {
                printf("\nSpecial ID=%d (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
            }
        } else {
            last_speaker_codes.clear();
            guide_tokens = prepare_guide_tokens(ttcvocab,sampletext,ttsver);
            if(!tts_is_quiet && ttsdebugmode==1)
            {
                printf("\nGuide Tokens (%d tokens):\n", guide_tokens.size());
                const std::string inp_txt = common_detokenize(ttc_ctx, guide_tokens, true);
                printf("%s,", inp_txt.c_str());
                printf("\n");
            }
            prompt_add(prompt_inp, ttcvocab, sampletext, false, true);
            prompt_add(prompt_inp, ttcvocab, "<|text_end|>\n<|audio_start|>\n", false, true);
            if(!tts_is_quiet && ttsdebugmode==1)
            {
                printf("\nPrepare new speaker (%d input tokens)...\n", prompt_inp.size());
                print_tok_vec(prompt_inp);
            }
            kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, false);
            auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
            if (!evalok) {
                printf("\nError: TTS prompt batch processing failed\n");
                output.data = "";
                output.status = 0;
                return output;
            }

            while (n_decode <= n_predict)
            {
                float * logits = llama_get_logits(ttc_ctx);

                const int topk = 20;
                const float top_p = 1.0f;
                const float temp = 1.2f;
                llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector<int32_t>(),1.0,top_p,topk,temp,speaker_rng);

                if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id))
                {
                    if(!guide_tokens.empty())
                    {
                        llama_token guide_token = guide_tokens[0];
                        guide_tokens.erase(guide_tokens.begin());
                        new_token_id = guide_token;
                    } else {
                        n_decode = n_predict;
                    }
                }

                next_token_uses_guide_token = (new_token_id == newlineid);
                last_speaker_codes.push_back(new_token_id);

                if (llama_vocab_is_eog(ttcvocab, new_token_id) || n_decode >= n_predict) {
                    break;
                }

                n_decode += 1;
                std::vector<llama_token> next = {new_token_id};
                llama_batch batch = llama_batch_get_one(next.data(), next.size());

                if (llama_decode(ttc_ctx, batch)) {
                    printf("\nError: TTS code generation failed!\n");
                    output.data = "";
                    output.status = 0;
                    return output;
                }
            }

            auto it = std::find(last_speaker_codes.rbegin(), last_speaker_codes.rend(), code_terminate_id);
            if (it != last_speaker_codes.rend()) {
                last_speaker_codes.erase(it.base(), last_speaker_codes.end());
                if(ttsver==TTS_VER_3 && last_speaker_codes.size()>2)
                {
                    last_speaker_codes.pop_back();
                    last_speaker_codes.pop_back();
                    last_speaker_codes.push_back(space_id);
                }
            }
            last_speaker_seed = speaker_seed;
            if(!tts_is_quiet && ttsdebugmode==1)
            {
                printf("\nNew speaker ID=%d created (%d tokens)...", last_speaker_seed, last_speaker_codes.size());
                const std::string inp_txt = common_detokenize(ttc_ctx, last_speaker_codes, true);
                printf("\n%s\n", inp_txt.c_str());
            }
        }
        guide_tokens.clear();
        llama_kv_cache_clear(ttc_ctx);
        prompt_init(prompt_inp, ttcvocab);
        next_token_uses_guide_token = true;
    }

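    // Build the real prompt: guide tokens from the user text, the optional speaker
    // profile codes, then autoregressively sample audio codes below.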
    guide_tokens = prepare_guide_tokens(ttcvocab,prompt_clean,ttsver);
    if(!tts_is_quiet && ttsdebugmode==1)
    {
        printf("\nGuide Tokens (%d tokens):\n", guide_tokens.size());
        const std::string inp_txt = common_detokenize(ttc_ctx, guide_tokens, true);
        printf("%s", inp_txt.c_str());
        printf("\n");
    }
    if(speaker_seed > 0)
    {
        prompt_clean = sampletext + (ttsver==TTS_VER_3?"<|space|>":"<|text_sep|>") + prompt_clean;
    }
    prompt_add(prompt_inp, ttcvocab, prompt_clean, false, true);

    if(!tts_is_quiet)
    {
        printf("\nTTS Processing (%d input tokens)...\n", prompt_inp.size());
    }

    prompt_add(prompt_inp, ttcvocab, "<|text_end|>\n<|audio_start|>\n", false, true);

    if(!last_speaker_codes.empty() && speaker_seed > 0)
    {
        prompt_add(prompt_inp, last_speaker_codes);
        prompt_add(prompt_inp, ttcvocab, "\n", false, true);
    }

    if(!tts_is_quiet && ttsdebugmode==1)
    {
        printf("\nDUMP TTS PROMPT (%d tokens):\n", prompt_inp.size());
        print_tok_vec(prompt_inp);
        const std::string inp_txt = common_detokenize(ttc_ctx, prompt_inp, true);
        printf("\n%s\n", inp_txt.c_str());
    }

    kcpp_embd_batch tts_batch = kcpp_embd_batch(prompt_inp, 0, false, false);

    auto evalok = (llama_decode(ttc_ctx, tts_batch.batch)==0);
    if (!evalok) {
        printf("\nError: TTS prompt batch processing failed\n");
        output.data = "";
        output.status = 0;
        return output;
    }

    n_decode = 0;
    n_predict = tts_max_len;

    while (n_decode <= n_predict)
    {
        float * logits = llama_get_logits(ttc_ctx);

        const int topk = 4;
        const float temp = 0.75f;
        const float top_p = 1.0f;
        // sample with the RNG seeded from audio_seed (tts_rng); speaker_rng is
        // only for the speaker-profile pass above, so audio_seed actually varies the output
        llama_token new_token_id = kcpp_quick_sample(logits,ttc_n_vocab,std::vector<int32_t>(),1.0,top_p,topk,temp,tts_rng);

        if(next_token_uses_guide_token && !llama_vocab_is_control(ttcvocab, new_token_id) && !llama_vocab_is_eog(ttcvocab, new_token_id))
        {
            if(!guide_tokens.empty())
            {
                llama_token guide_token = guide_tokens[0];
                guide_tokens.erase(guide_tokens.begin());
                new_token_id = guide_token;
            } else {
                n_decode = n_predict;
            }
        }

        next_token_uses_guide_token = (new_token_id == newlineid);
        codes.push_back(new_token_id);

        if (llama_vocab_is_eog(ttcvocab, new_token_id) || n_decode >= n_predict) {
            break;
        }

        n_decode += 1;
        std::vector<llama_token> next = {new_token_id};
        llama_batch batch = llama_batch_get_one(next.data(), next.size());

        if (llama_decode(ttc_ctx, batch)) {
            printf("\nError: TTS code generation failed!\n");
            output.data = "";
            output.status = 0;
            return output;
        }
        if(!tts_is_quiet)
        {
            printf("\rTTS Generating (%d outputs)", n_decode);
        }
    }

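    // Post-process: keep only tokens in the audio-code range, shift them down to
    // WavTokenizer IDs, run the vocoder, and pad with silence before encoding the WAV.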
    if(!tts_is_quiet && ttsdebugmode==1)
    {
        const std::string inp_txt = common_detokenize(ttc_ctx, codes, true);
        printf("\nGenerated %d Codes: '%s'\n",codes.size(), inp_txt.c_str());
    }

    codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < cts_offset || t > (cts_offset+4100); }), codes.end());

    for (auto & token : codes) {
        token -= cts_offset;
    }

    const int n_codes = codes.size();
    if(n_codes<=1)
    {
        printf("\nWarning: No Audio Tokens Produced!\n");
        last_generated_audio = "";
        output.data = last_generated_audio.c_str();
        output.status = 1;
        return output;
    }
    kcpp_embd_batch codebatch = kcpp_embd_batch(codes,0,false,true);
    printf("\nRunning Vocoder (%d AudioTokens)", codes.size());

    if (llama_decode(cts_ctx, codebatch.batch) != 0) {
        printf("\nError: TTS vocoder generation failed!\n");
        output.data = "";
        output.status = 0;
        return output;
    }
    else
    {
        const int n_embd = llama_model_n_embd(model_cts);
        const float * embd = llama_get_embeddings(cts_ctx);
        std::vector<float> audio = embd_to_audio(embd, n_codes, n_embd, nthreads);

        const int n_sr = 24000;
        const int t_sr = 24000;

        const int cutout = t_sr/4;

        if(audio.size()>cutout+16)
        {
            for (int i = 0; i < cutout; ++i) {
                audio[i] = 0.0f;
            }

            for (int i = 0; i < cutout; ++i) {
                audio.push_back(0.0f);
            }
        }
        else
        {
            printf("\nWarning: TTS vocoder generated nothing!\n");
            last_generated_audio = "";
            output.data = last_generated_audio.c_str();
            output.status = 1;
            return output;
        }

        last_generated_audio = save_wav16_base64(audio, t_sr);
        ttstime = timer_check();

        printf("\nTTS Generated %d audio tokens in %.2fs.\n",(int) codes.size(),ttstime);

        output.data = last_generated_audio.c_str();
        output.status = 1;

        last_generation_settings_audio_seed = inputs.audio_seed;
        last_generation_settings_speaker_seed = inputs.speaker_seed;
        last_generation_settings_prompt = std::string(inputs.prompt);
        total_tts_gens += 1;

        return output;
    }
}