vincenthugging commited on
Commit
a5a2048
·
1 Parent(s): 893d32f

🔧 修复FlashAttention2错误和场景加载问题

Browse files

🐛 主要修复:
1. FlashAttention2 兼容性问题
- 添加自动回退机制:优先尝试 FlashAttention2,失败则使用标准SDPA
- 修改默认注意力实现为 'sdpa',适配HF Spaces环境
- 添加友好的日志提示,明确使用的注意力机制

2. 场景下拉菜单问题
- 使用预定义场景列表,确保界面稳定加载
- 修复动态获取场景时的竞态条件问题
- 特殊处理'默认示例'选项,直接调用默认音频加载

✨ 改进:
- 增强错误处理和用户反馈
- 优化场景加载逻辑,支持多种场景类型
- 确保即使在场景文件缺失的情况下也能正常工作

🎯 效果:
- 解决 ImportError: FlashAttention2 安装问题
- 修复场景选择器显示异常
- 提升 Space 在 CPU/GPU 环境下的兼容性
- 确保所有预设场景都能正确加载

Files changed (3) hide show
  1. __pycache__/app.cpython-313.pyc +0 -0
  2. app.py +23 -5
  3. generation_utils.py +8 -2
__pycache__/app.cpython-313.pyc ADDED
Binary file (25.7 kB). View file
 
app.py CHANGED
@@ -433,13 +433,21 @@ def create_space_ui() -> gr.Blocks:
433
 
434
  with gr.Group():
435
  gr.Markdown("### 🚀 快速操作")
436
- # 获取场景选项,设置第一个为默认值
437
- scenario_choices = list(get_scenario_examples().keys())
438
- default_scenario = scenario_choices[0] if scenario_choices else None
 
 
 
 
 
 
 
 
439
 
440
  scenario_dropdown = gr.Dropdown(
441
- choices=scenario_choices,
442
- value=default_scenario,
443
  label="🎭 选择场景",
444
  info="选择一个预设场景,自动填充对话文本和参考音频"
445
  )
@@ -514,6 +522,16 @@ def create_space_ui() -> gr.Blocks:
514
  gr.Warning("⚠️ 请先选择一个场景")
515
  return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
516
 
 
 
 
 
 
 
 
 
 
 
517
  scenarios = get_scenario_examples()
518
  if name not in scenarios:
519
  gr.Error(f"❌ 场景不存在: {name}")
 
433
 
434
  with gr.Group():
435
  gr.Markdown("### 🚀 快速操作")
436
+
437
+ # 预定义场景选项,确保界面稳定
438
+ predefined_scenarios = [
439
+ "🎧 默认示例",
440
+ "🤖 科技播客 - AI发展趋势",
441
+ "📚 教育播客 - 高效学习方法",
442
+ "🍜 生活播客 - 美食文化探索",
443
+ "💼 商业播客 - 创业经验分享",
444
+ "🏃 健康播客 - 运动健身指南",
445
+ "🧠 心理播客 - 情绪管理技巧"
446
+ ]
447
 
448
  scenario_dropdown = gr.Dropdown(
449
+ choices=predefined_scenarios,
450
+ value=predefined_scenarios[0],
451
  label="🎭 选择场景",
452
  info="选择一个预设场景,自动填充对话文本和参考音频"
453
  )
 
522
  gr.Warning("⚠️ 请先选择一个场景")
523
  return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
524
 
525
+ # 处理默认示例的特殊情况
526
+ if name == "🎧 默认示例":
527
+ try:
528
+ result = load_default_audio()
529
+ gr.Info("✅ 成功加载默认示例")
530
+ return result
531
+ except Exception as e:
532
+ gr.Error(f"❌ 加载默认示例时出错: {str(e)}")
533
+ return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
534
+
535
  scenarios = get_scenario_examples()
536
  if name not in scenarios:
537
  gr.Error(f"❌ 场景不存在: {name}")
generation_utils.py CHANGED
@@ -12,10 +12,16 @@ from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
12
  MAX_CHANNELS = 8
13
  SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
14
 
15
- def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"):
16
  tokenizer = AutoTokenizer.from_pretrained(model_path)
17
 
18
- model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch_dtype, attn_implementation=attn_implementation)
 
 
 
 
 
 
19
 
20
  spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
21
 
 
12
  MAX_CHANNELS = 8
13
  SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
14
 
15
+ def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa"):
16
  tokenizer = AutoTokenizer.from_pretrained(model_path)
17
 
18
+ # 尝试使用 FlashAttention2,失败则回退到标准实现
19
+ try:
20
+ model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
21
+ print("✅ 使用 FlashAttention2")
22
+ except ImportError:
23
+ print("⚠️ FlashAttention2 不可用,使用标准注意力机制")
24
+ model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch_dtype, attn_implementation=attn_implementation)
25
 
26
  spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
27