Commit f65fe2e (verified) · Stardust-minus committed · Parent(s): 492fb71

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -41,3 +41,6 @@ examples/English.wav filter=lfs diff=lfs merge=lfs -text
  examples/French.wav filter=lfs diff=lfs merge=lfs -text
  examples/German.wav filter=lfs diff=lfs merge=lfs -text
  examples/Spanish.wav filter=lfs diff=lfs merge=lfs -text
+ 022b2161-8f56-4432-a9ae-b4bd514e4821.mp3 filter=lfs diff=lfs merge=lfs -text
+ output.wav filter=lfs diff=lfs merge=lfs -text
+ ref.wav filter=lfs diff=lfs merge=lfs -text
022b2161-8f56-4432-a9ae-b4bd514e4821.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb37bcf53feb185993aeb8a7f9b96f055b60ed6d0d96fe5a6833db1c0efba0f0
+ size 647000
examples/Arabic.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:4a3c902c13fcf408c95353d91ab65f839d27584d8929c7345317956d1e9ea5bd
- size 131
+ oid sha256:79baad393ddae4d975e0a1e04065fe18d655104b6dd3db1e035e28f391c4d78f
+ size 128
examples/English.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ed744820849c8f16e03cb68e45b7d7d4b8697476a162d50ffe2cd6612a621aa6
- size 131
+ oid sha256:295ab67b022169527d1b3d564df6163900e8d45e39e069890f8b7b912f0bda5d
+ size 128
examples/French.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:dee830ddff631df6e0db0911a20099ddf6438a80d1da597536470ba36e2d645c
- size 131
+ oid sha256:d1db0708d546351aa9e757adad0f97f8376962f8fcbfd28dd7d574ec6929f3bb
+ size 128
examples/German.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:cc076529638f0a4bb8d19b509b7781372c26abadcc74a7dcbc5b72b6b1e680fd
- size 131
+ oid sha256:c23b9798f9e0eb659d0d1ae7d98a74dbc728cd06899e5e84e0fcc519c4613e70
+ size 128
examples/Japanese.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ba2a2c07770cb6ab36a5aa6ee953c9914773368e223359e4710897d425a25402
+ oid sha256:277101318b7c174690280daea1402701e4e176abce625853c585a256b776d685
  size 128
examples/Korean.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:09c122b25a3ad99247179be77deeaa6ead7d93b40092347801948fea34797e48
+ oid sha256:a91634e1c44008d2d9a01ff9d63f50551080c8102d6902c99e6ff00077e8d715
  size 128
examples/Nice English Ref.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b895ec0d49173630cf9253c70579888cde65129fbaeda167e3b4f91593715eca
+ oid sha256:7846d8d9cf1f149a4f9fb454040561c44401f2a9878c565ab4781d346e8a9436
  size 128
examples/Spanish.wav CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:c22d63058f58f46c6a65b6ced8faa969f403b065e822a274342b520e8e20b65f
- size 131
+ oid sha256:d434eb34ed3d103579f1e93a0cf87d84a9e0f70b19d08fdb5b32a4f6e40cc3e1
+ size 128
fish_speech/models/text2semantic/inference.py CHANGED
@@ -339,7 +339,7 @@ def generate_long(
      temperature: float = 0.8,
      compile: bool = False,
      iterative_prompt: bool = True,
-     chunk_length: int = 150,
+     chunk_length: int = 512,
      prompt_text: Optional[str | list[str]] = None,
      prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  ):
@@ -365,6 +365,24 @@ def generate_long(
      texts = split_text(text, chunk_length) if iterative_prompt else [text]
      max_length = model.config.max_seq_len
  
+     # if use_prompt:
+     #     base_content_sequence.append(
+     #         [
+     #             TextPart(text=prompt_text[0]),
+     #             VQPart(codes=prompt_tokens[0]),
+     #         ],
+     #         add_end=True,
+     #     )
+ 
+     # for text in texts:
+     #     content_sequence = ContentSequence(modality=None)
+     #     base_content_sequence.append(
+     #         [
+     #             TextPart(text=text),
+     #         ],
+     #         add_end=True,
+     #     )
+ 
      if use_prompt:
          for t, c in zip(prompt_text, prompt_tokens):
              base_content_sequence.append(
@@ -385,7 +403,7 @@
  
      encoded = []
      for text in texts:
-         content_sequence = ContentSequence(modality=None)
+         content_sequence = ContentSequence(modality="text")
          content_sequence.append(TextPart(text=text))
          encoded.append(
              content_sequence.encode_for_inference(
fish_speech/models/text2semantic/llama.py CHANGED
@@ -48,7 +48,7 @@ class BaseModelArgs:
  
      # Codebook configs
      codebook_size: int = 160
-     num_codebooks: int = 4
+     num_codebooks: int = 9
  
      # Gradient checkpointing
      use_gradient_checkpointing: bool = True
generate_cli.py ADDED
@@ -0,0 +1,227 @@
+ import os
+ import json
+ import queue
+ from pathlib import Path
+ from typing import Optional
+ 
+ import click
+ import torch
+ import soundfile as sf
+ from loguru import logger
+ 
+ from fish_speech.models.text2semantic.inference import (
+     CodebookSamplingParams,
+     SamplingParams,
+     generate_long,
+     launch_thread_safe_queue,
+     GenerateRequest,
+     WrappedGenerateResponse,
+ )
+ from fish_speech.models.text2semantic.llama import BaseTransformer
+ from fish_speech.models.dac.inference import load_model as load_decoder_model
+ from fish_speech.text import clean_text
+ from fish_speech.inference_engine.vq_manager import VQManager
+ from tools.api import load_audio
+ 
+ 
+ def load_llm_model(model_path: str, device: str, compile: bool = False):
+     """Load the LLM model."""
+     logger.info(f"Loading LLM model from {model_path}")
+     model = BaseTransformer.from_pretrained(
+         path=model_path,
+         load_weights=True,
+     )
+     model = model.to(device=device, dtype=torch.bfloat16)
+ 
+     if isinstance(model, model.__class__.__bases__[0].__subclasses__()[1]):  # DualARTransformer
+         from fish_speech.models.text2semantic.inference import decode_one_token_ar as decode_one_token
+         logger.info("Using DualARTransformer")
+     else:
+         from fish_speech.models.text2semantic.inference import decode_one_token_naive as decode_one_token
+         logger.info("Using NaiveTransformer")
+ 
+     if compile:
+         logger.info("Compiling decode function...")
+         decode_one_token = torch.compile(
+             decode_one_token,
+             fullgraph=True,
+             backend="inductor" if torch.cuda.is_available() else "aot_eager",
+             mode="reduce-overhead" if torch.cuda.is_available() else None,
+         )
+ 
+     return model.eval(), decode_one_token
+ 
+ 
+ def load_dac_model(config_name: str, checkpoint_path: str, device: str):
+     """Load the DAC model."""
+     logger.info(f"Loading DAC model from {checkpoint_path}")
+     model = load_decoder_model(
+         config_name=config_name,
+         checkpoint_path=checkpoint_path,
+         device=device,
+     )
+     return model
+ 
+ 
+ @click.command()
+ # @click.argument("text", type=str)
+ @click.option("--llm-model-path", type=str, required=True, help="Path to the LLM model")
+ @click.option("--dac-model-path", type=str, required=True, help="Path to the DAC model")
+ @click.option("--dac-config-name", type=str, default="modded_dac_vq", help="DAC model config name")
+ @click.option("--output-path", type=str, required=True, help="Path to save the output audio")
+ @click.option("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
+ @click.option("--max-new-tokens", type=int, default=4096, help="Maximum new tokens to generate")
+ @click.option("--chunk-length", type=int, default=1000, help="Chunk length for synthesis")
+ @click.option("--compile", is_flag=True, help="Whether to compile the model")
+ @click.option("--iterative-prompt", is_flag=True, help="Whether to use iterative prompt")
+ @click.option("--params-file", type=str, default="sampling_params_example.json", help="Path to JSON file containing sampling parameters")
+ @click.option(
+     "--ref-audio",
+     type=click.Path(path_type=Path, exists=True),
+     default="ref.wav",
+     help="Path to the reference audio file (defaults to ref.wav)",
+ )
+ def main(
+     # text: str,
+     llm_model_path: str,
+     dac_model_path: str,
+     dac_config_name: str,
+     output_path: str,
+     device: str,
+     max_new_tokens: int,
+     chunk_length: int,
+     compile: bool,
+     iterative_prompt: bool,
+     params_file: Optional[str],
+     ref_audio: Path,
+ ):
+     """Generate speech in two steps: LLM token generation, then DAC audio decoding."""
+ 
+     # Set precision
+     precision = torch.half if torch.cuda.is_available() else torch.bfloat16
+ 
+     # Load the LLM model (behind a thread-safe request queue)
+     logger.info("Loading LLM model...")
+     llama_queue = launch_thread_safe_queue(
+         checkpoint_path=llm_model_path,
+         device="cuda:0",
+         precision=precision,
+         compile=compile,
+     )
+     logger.info("LLM model loaded")
+ 
+     # Load the DAC model
+     logger.info("Loading DAC model...")
+     dac_model = load_decoder_model(
+         config_name=dac_config_name,
+         checkpoint_path=dac_model_path,
+         device="cuda:1",
+     )
+     logger.info("DAC model loaded")
+ 
+     # Load sampling parameters
+     if params_file:
+         with open(params_file, "r", encoding="utf-8") as f:
+             params_data = json.load(f)
+         text = params_data.get("text", "")
+ 
+         semantic_params = CodebookSamplingParams(**params_data.get("semantic", {}))
+         codebook_params = [
+             CodebookSamplingParams(**params) for params in params_data.get("codebooks", [])
+         ]
+         sampling_params = SamplingParams(
+             semantic=semantic_params,
+             codebooks=codebook_params,
+         )
+     else:
+         sampling_params = SamplingParams()
+ 
+     # Clean the text
+     text = clean_text(text)
+ 
+     # Load the reference audio
+     if ref_audio.exists():
+         ref_audio_data, ref_sr = sf.read(ref_audio)
+         logger.info(f"Loaded reference audio: {ref_audio}, shape={ref_audio_data.shape}, sr={ref_sr}")
+         # Encode the reference audio into prompt_tokens
+         vq_manager = VQManager()
+         vq_manager.decoder_model = dac_model
+         vq_manager.load_audio = load_audio
+         prompt_tokens = vq_manager.encode_reference(ref_audio, enable_reference_audio=True)
+         logger.info(f"Encoded reference audio to prompt_tokens, shape={prompt_tokens.shape if prompt_tokens is not None else None}")
+     else:
+         prompt_tokens = []
+         logger.warning(f"Reference audio {ref_audio} not found.")
+ 
+     # Generate speech
+     logger.info(f"Generating speech for text: {text}")
+     logger.info(f"Using sampling parameters: {sampling_params}")
+ 
+     output_path = Path(output_path)
+     if not output_path.suffix:
+         output_path = output_path.with_suffix(".wav")
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+ 
+     # Create the response queue
+     response_queue = queue.Queue()
+ 
+     # Prepare the request
+     request = dict(
+         device=device,
+         max_new_tokens=max_new_tokens,
+         text=text,
+         sampling_params=sampling_params,
+         compile=compile,
+         iterative_prompt=iterative_prompt,
+         chunk_length=chunk_length,
+         prompt_text=[],
+         prompt_tokens=[prompt_tokens] if prompt_tokens is not None and len(prompt_tokens) else [],
+         # prompt_text=["Through the dense morning fog that rolled across the peaceful valley, the distant church bells chimed their melodic song, echoing off ancient stone walls and mingling with the gentle rustling of maple leaves in the cool breeze. Inside the cozy lakeside cottage, fresh bread baked in the old clay oven filled every corner with its rich, comforting aroma, while steam rose lazily from ceramic mugs of fresh-brewed coffee on the handcrafted pine table. The persistent rain finally gave way to brilliant sunshine, transforming ordinary dewdrops into countless sparkling diamonds scattered across the vibrant garden flowers."],
+     )
+ 
+     # Send the request to the LLM model
+     llama_queue.put(GenerateRequest(request=request, response_queue=response_queue))
+ 
+     # Collect the generated tokens
+     all_tokens = []
+     while True:
+         wrapped_result: WrappedGenerateResponse = response_queue.get()
+ 
+         if wrapped_result.status == "error":
+             error = wrapped_result.response if isinstance(wrapped_result.response, Exception) else Exception("Unknown error")
+             logger.error(f"Error during generation: {error}")
+             break
+ 
+         result = wrapped_result.response
+         if result.action == "next":
+             break
+ 
+         all_tokens.append(result.codes)
+         logger.info(f"Generated chunk {len(all_tokens)}")
+ 
+     if not all_tokens:
+         logger.error("No tokens generated")
+         return
+ 
+     # Concatenate all generated tokens
+     if len(all_tokens) > 1:
+         tokens = torch.cat(all_tokens, dim=1)
+     else:
+         tokens = all_tokens[0]
+ 
+     # Decode the tokens to audio with the DAC model
+     logger.info("Converting tokens to audio...")
+     feature_lengths = torch.tensor([tokens.shape[1]], device=device)
+     audio, _ = dac_model.decode(
+         indices=tokens[None].to("cuda:1"),
+         feature_lengths=feature_lengths.to("cuda:1"),
+     )
+ 
+     # Save the audio
+     audio = audio[0, 0].detach().float().cpu().numpy()
+     sf.write(output_path, audio, dac_model.sample_rate)
+     logger.info(f"Saved audio to {output_path}")
+ 
+ 
+ if __name__ == "__main__":
+     main()
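
A quick way to exercise the new script end to end is click's test runner rather than a shell. The sketch below is illustrative only: the two checkpoint paths are placeholders, not files shipped with this commit, while the params file and reference audio defaults do come from this commit.

# Smoke-test sketch for generate_cli.py (checkpoint paths are placeholders).
from click.testing import CliRunner

from generate_cli import main

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "--llm-model-path", "checkpoints/llm",            # placeholder path
        "--dac-model-path", "checkpoints/dac/codec.pth",  # placeholder path
        "--output-path", "output.wav",
        "--params-file", "sampling_params_example.json",
        "--ref-audio", "ref.wav",
    ],
)
print(result.exit_code, result.output)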
output.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:043464cdadefbc6144a48155c38f69016d44a1d3eaab261a634719eb5d9162ee
+ size 888876
ref.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b1acb52e60f3b9eaa66bd289a37a3da7c7b5c64511f42cd0bc8245b57122f354
+ size 3566670
sampling_params_example.json ADDED
@@ -0,0 +1,55 @@
+ {
+     "text": "(excited, joyful tone) We're going to DISNEY WORLD! (squeal of delight) I've been saving for (emphasis) three years (breathless) and finally, FINALLY we can go! The look on your face right now is worth every extra shift I worked! (angry) After everything we've been through (break) I can't believe you would (emphasize) betray me like this. I gave you EVERYTHING! And now I'm left with nothing but memories and broken promises!",
+     "semantic": {
+         "temperature": 0.9,
+         "top_p": 0.9,
+         "repetition_penalty": 1.05
+     },
+     "codebooks": [
+         {
+             "temperature": 0.9,
+             "top_p": 0.9,
+             "repetition_penalty": 1.05
+         },
+         {
+             "temperature": 0.8,
+             "top_p": 0.8,
+             "repetition_penalty": 1.1
+         },
+         {
+             "temperature": 0.8,
+             "top_p": 0.8,
+             "repetition_penalty": 1.1
+         },
+         {
+             "temperature": 0.7,
+             "top_p": 0.7,
+             "repetition_penalty": 1.1
+         },
+         {
+             "temperature": 0.7,
+             "top_p": 0.7,
+             "repetition_penalty": 1.1
+         },
+         {
+             "temperature": 0.65,
+             "top_p": 0.65,
+             "repetition_penalty": 1.1
+         },
+         {
+             "temperature": 0.65,
+             "top_p": 0.65,
+             "repetition_penalty": 1.1
+         },
+         {
+             "temperature": 0.6,
+             "top_p": 0.6,
+             "repetition_penalty": 1.1
+         },
+         {
+             "temperature": 0.6,
+             "top_p": 0.4,
+             "repetition_penalty": 1.5
+         }
+     ]
+ }
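
The nine entries under "codebooks" appear intended to pair one sampling configuration with each of the nine codebooks configured in llama.py above; that pairing is an assumption about intent, not something the commit enforces. A throwaway consistency check under that assumption:

# Sanity-check sketch: one sampling entry per codebook (assumed convention).
import json

with open("sampling_params_example.json", "r", encoding="utf-8") as f:
    params = json.load(f)

assert len(params["codebooks"]) == 9, "expected one entry per codebook (num_codebooks = 9)"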
tools/api.py CHANGED
@@ -136,7 +136,7 @@ async def other_exception_handler(exc: "Exception"):
  
  
  def load_audio(reference_audio, sr):
-     if len(reference_audio) > 255 or not Path(reference_audio).exists():
+     if len(str(reference_audio)) > 255 or not Path(reference_audio).exists():
          audio_data = reference_audio
          reference_audio = io.BytesIO(audio_data)
  
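The str() wrapper matters because load_audio can receive either a filesystem path (including a pathlib.Path, as generate_cli.py now passes) or raw audio bytes, and len() is not defined for Path objects. A standalone illustration of the behaviour the fix relies on:

# Illustration only (not part of the commit): len() raises on pathlib.Path,
# so the length check has to stringify its argument first.
from pathlib import Path

ref = Path("ref.wav")
try:
    len(ref)  # TypeError: Path objects have no len()
except TypeError as err:
    print("len(Path) fails:", err)

print(len(str(ref)) > 255)  # False -> short string, treated as a file path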
 
tools/vqgan/inference.py CHANGED
@@ -14,7 +14,7 @@ from omegaconf import OmegaConf
  from tools.file import AUDIO_EXTENSIONS
  
  # register eval resolver
- OmegaConf.register_new_resolver("eval", eval)
+ #OmegaConf.register_new_resolver("eval", eval)
  
  
  def load_model(config_name, checkpoint_path, device="cuda"):
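
Commenting out the registration sidesteps the error OmegaConf raises when the "eval" resolver is registered a second time (for example, when another module has already registered it); that motivation is inferred, not stated in the commit. If the resolver is still wanted here, a guarded registration is one alternative, assuming OmegaConf >= 2.1 where register_new_resolver accepts replace:

# Alternative sketch (assumes OmegaConf >= 2.1): tolerate a prior registration
# instead of removing the resolver entirely.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval, replace=True)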