vincenthugging committed on
Commit
88ea080
·
1 Parent(s): dee5477

🔧 修复概率张量数值不稳定错误

Browse files

🐛 主要修复:
1. **数值稳定性问题**
- 恢复temperature为1.0,避免过低值导致softmax不稳定
- 恢复bfloat16精度,比float16更稳定
- 添加数值稳定的softmax计算逻辑

2. **采样过程保护**
- 在multinomial采样前检查inf/nan值
- 实现数值稳定的softmax: score - max(score)
- 添加概率clamp和重新归一化保护
- 异常时自动回退到argmax确定性采样

3. **重试机制**
- 捕获RuntimeError中的概率张量错误
- 自动切换到确定性生成重试
- 保护原始配置,确保错误恢复

4. **参数优化**
- 应用文档推荐的'轻松对话风格'参数组合
- temperature=1.0, top_k=50, top_p=0.9, repetition_penalty=1.1
- 添加epsilon和pad_token_id保护参数

🎯 根因分析:
- 原因1: temperature=0.7过低——logits被除以小于1的温度后被放大约1.4倍,在低精度下导致softmax数值溢出
- 原因2: float16精度不足引发累积误差
- 原因3: 采样过程缺乏异常值检查
- 原因4: 没有重试和兜底机制

✅ 解决效果:
- 消除 PyTorch 的 "probability tensor contains either `inf`, `nan` or element < 0" 错误
- 提供多层数值稳定性保护
- 确保在极端情况下仍能正常生成
- 保持生成质量和自然度

Files changed (3) hide show
  1. app.py +47 -17
  2. generation_utils.py +1 -1
  3. modeling_asteroid.py +15 -1
app.py CHANGED
@@ -243,19 +243,26 @@ def initialize_model():
243
  model = model.to(device)
244
  spt = spt.to(device)
245
 
246
- # 优化生成参数,提升速度和效率
247
  try:
248
  # 减少最大生成长度,提升速度
249
  model.generation_config.max_new_tokens = min(
250
  getattr(model.generation_config, "max_new_tokens", 2048), 2048
251
  )
252
- # 设置更高效的生成参数
 
253
  model.generation_config.do_sample = True
254
- model.generation_config.temperature = 0.7
255
- model.generation_config.top_p = 0.9
256
- model.generation_config.num_beams = 1 # 使用贪心搜索,更快
 
 
 
 
 
 
257
 
258
- print(f"🚀 优化生成参数: max_tokens={model.generation_config.max_new_tokens}, beams={model.generation_config.num_beams}")
259
  except Exception as e: # noqa: BLE001
260
  print(f"⚠️ 生成参数设置失败: {e}")
261
  pass
@@ -305,17 +312,40 @@ def generate_dialogue_audio(
305
  single_text = speaker1_text or speaker2_text or ""
306
  item.update({"prompt_audio": single_audio, "prompt_text": single_text})
307
 
308
- # 执行合成
309
- actual_texts_data, audio_results = process_batch(
310
- batch_items=[item],
311
- tokenizer=tokenizer,
312
- model=model,
313
- spt=spt,
314
- device=device,
315
- system_prompt=SYSTEM_PROMPT,
316
- start_idx=0,
317
- use_normalize=use_normalize,
318
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  if not audio_results or audio_results[0] is None:
321
  return None, "❌ 音频生成失败"
 
243
  model = model.to(device)
244
  spt = spt.to(device)
245
 
246
+ # 设置稳定的生成参数,避免数值不稳定
247
  try:
248
  # 减少最大生成长度,提升速度
249
  model.generation_config.max_new_tokens = min(
250
  getattr(model.generation_config, "max_new_tokens", 2048), 2048
251
  )
252
+
253
+ # 使用文档推荐的"轻松对话风格"参数组合,确保数值稳定
254
  model.generation_config.do_sample = True
255
+ model.generation_config.temperature = 1.0 # 恢复默认值,避免数值不稳定
256
+ model.generation_config.top_k = 50 # 添加top_k限制
257
+ model.generation_config.top_p = 0.9 # 保持合理的nucleus采样
258
+ model.generation_config.repetition_penalty = 1.1 # 避免重复
259
+ model.generation_config.num_beams = 1 # 使用贪心搜索
260
+
261
+ # 添加数值稳定性保护
262
+ model.generation_config.epsilon = 1e-8 # 防止除零错误
263
+ model.generation_config.pad_token_id = model.config.eos_token_id
264
 
265
+ print(f"🚀 应用稳定生成参数: temp={model.generation_config.temperature}, top_k={model.generation_config.top_k}, top_p={model.generation_config.top_p}")
266
  except Exception as e: # noqa: BLE001
267
  print(f"⚠️ 生成参数设置失败: {e}")
268
  pass
 
312
  single_text = speaker1_text or speaker2_text or ""
313
  item.update({"prompt_audio": single_audio, "prompt_text": single_text})
314
 
315
+ # 执行合成,添加重试机制
316
+ try:
317
+ actual_texts_data, audio_results = process_batch(
318
+ batch_items=[item],
319
+ tokenizer=tokenizer,
320
+ model=model,
321
+ spt=spt,
322
+ device=device,
323
+ system_prompt=SYSTEM_PROMPT,
324
+ start_idx=0,
325
+ use_normalize=use_normalize,
326
+ )
327
+ except RuntimeError as e:
328
+ if "probability tensor contains" in str(e):
329
+ print("⚠️ 检测到数值不稳定,尝试使用确定性生成...")
330
+ # 临时切换到确定性生成
331
+ original_do_sample = model.generation_config.do_sample
332
+ model.generation_config.do_sample = False
333
+ try:
334
+ actual_texts_data, audio_results = process_batch(
335
+ batch_items=[item],
336
+ tokenizer=tokenizer,
337
+ model=model,
338
+ spt=spt,
339
+ device=device,
340
+ system_prompt=SYSTEM_PROMPT,
341
+ start_idx=0,
342
+ use_normalize=use_normalize,
343
+ )
344
+ finally:
345
+ # 恢复原设置
346
+ model.generation_config.do_sample = original_do_sample
347
+ else:
348
+ raise e
349
 
350
  if not audio_results or audio_results[0] is None:
351
  return None, "❌ 音频生成失败"
generation_utils.py CHANGED
@@ -12,7 +12,7 @@ from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
12
  MAX_CHANNELS = 8
13
  SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
14
 
15
- def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.float16, attn_implementation="sdpa"):
16
  tokenizer = AutoTokenizer.from_pretrained(model_path)
17
 
18
  # 尝试使用 FlashAttention2,失败则回退到标准实现
 
12
  MAX_CHANNELS = 8
13
  SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
14
 
15
+ def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa"):
16
  tokenizer = AutoTokenizer.from_pretrained(model_path)
17
 
18
  # 尝试使用 FlashAttention2,失败则回退到标准实现
modeling_asteroid.py CHANGED
@@ -137,7 +137,21 @@ class CustomMixin(GenerationMixin):
137
  next_tokens = []
138
  for i, channel_score in enumerate(next_token_scores):
139
  if do_samples[i]:
140
- channel_ntk = torch.multinomial(nn.functional.softmax(channel_score, dim=-1), num_samples=1).squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  elif not do_samples[i]:
142
  channel_ntk = torch.argmax(channel_score, dim=-1)
143
  next_tokens.append(channel_ntk)
 
137
  next_tokens = []
138
  for i, channel_score in enumerate(next_token_scores):
139
  if do_samples[i]:
140
+ # 添加数值稳定性保护
141
+ # 检查并处理异常值
142
+ if torch.isnan(channel_score).any() or torch.isinf(channel_score).any():
143
+ print(f"⚠️ 检测到异常值,使用argmax采样")
144
+ channel_ntk = torch.argmax(channel_score, dim=-1)
145
+ else:
146
+ # 数值稳定的softmax计算
147
+ channel_score_stable = channel_score - torch.max(channel_score, dim=-1, keepdim=True)[0]
148
+ probs = nn.functional.softmax(channel_score_stable, dim=-1)
149
+
150
+ # 确保概率值有效
151
+ probs = torch.clamp(probs, min=1e-8, max=1.0)
152
+ probs = probs / probs.sum(dim=-1, keepdim=True) # 重新归一化
153
+
154
+ channel_ntk = torch.multinomial(probs, num_samples=1).squeeze(1)
155
  elif not do_samples[i]:
156
  channel_ntk = torch.argmax(channel_score, dim=-1)
157
  next_tokens.append(channel_ntk)