Commit
·
a5a2048
1
Parent(s):
893d32f
🔧 修复FlashAttention2错误和场景加载问题
Browse files🐛 主要修复:
1. FlashAttention2 兼容性问题
- 添加自动回退机制:优先尝试 FlashAttention2,失败则使用标准SDPA
- 修改默认注意力实现为 'sdpa',适配HF Spaces环境
- 添加友好的日志提示,明确使用的注意力机制
2. 场景下拉菜单问题
- 使用预定义场景列表,确保界面稳定加载
- 修复动态获取场景时的竞态条件问题
- 特殊处理'默认示例'选项,直接调用默认音频加载
✨ 改进:
- 增强错误处理和用户反馈
- 优化场景加载逻辑,支持多种场景类型
- 确保即使在场景文件缺失的情况下也能正常工作
🎯 效果:
- 解决 ImportError: FlashAttention2 安装问题
- 修复场景选择器显示异常
- 提升 Space 在 CPU/GPU 环境下的兼容性
- 确保所有预设场景都能正确加载
- __pycache__/app.cpython-313.pyc +0 -0
- app.py +23 -5
- generation_utils.py +8 -2
__pycache__/app.cpython-313.pyc
ADDED
Binary file (25.7 kB). View file
|
|
app.py
CHANGED
@@ -433,13 +433,21 @@ def create_space_ui() -> gr.Blocks:
|
|
433 |
|
434 |
with gr.Group():
|
435 |
gr.Markdown("### 🚀 快速操作")
|
436 |
-
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
|
440 |
scenario_dropdown = gr.Dropdown(
|
441 |
-
choices=
|
442 |
-
value=
|
443 |
label="🎭 选择场景",
|
444 |
info="选择一个预设场景,自动填充对话文本和参考音频"
|
445 |
)
|
@@ -514,6 +522,16 @@ def create_space_ui() -> gr.Blocks:
|
|
514 |
gr.Warning("⚠️ 请先选择一个场景")
|
515 |
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
517 |
scenarios = get_scenario_examples()
|
518 |
if name not in scenarios:
|
519 |
gr.Error(f"❌ 场景不存在: {name}")
|
|
|
433 |
|
434 |
with gr.Group():
|
435 |
gr.Markdown("### 🚀 快速操作")
|
436 |
+
|
437 |
+
# 预定义场景选项,确保界面稳定
|
438 |
+
predefined_scenarios = [
|
439 |
+
"🎧 默认示例",
|
440 |
+
"🤖 科技播客 - AI发展趋势",
|
441 |
+
"📚 教育播客 - 高效学习方法",
|
442 |
+
"🍜 生活播客 - 美食文化探索",
|
443 |
+
"💼 商业播客 - 创业经验分享",
|
444 |
+
"🏃 健康播客 - 运动健身指南",
|
445 |
+
"🧠 心理播客 - 情绪管理技巧"
|
446 |
+
]
|
447 |
|
448 |
scenario_dropdown = gr.Dropdown(
|
449 |
+
choices=predefined_scenarios,
|
450 |
+
value=predefined_scenarios[0],
|
451 |
label="🎭 选择场景",
|
452 |
info="选择一个预设场景,自动填充对话文本和参考音频"
|
453 |
)
|
|
|
522 |
gr.Warning("⚠️ 请先选择一个场景")
|
523 |
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
524 |
|
525 |
+
# 处理默认示例的特殊情况
|
526 |
+
if name == "🎧 默认示例":
|
527 |
+
try:
|
528 |
+
result = load_default_audio()
|
529 |
+
gr.Info("✅ 成功加载默认示例")
|
530 |
+
return result
|
531 |
+
except Exception as e:
|
532 |
+
gr.Error(f"❌ 加载默认示例时出错: {str(e)}")
|
533 |
+
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
534 |
+
|
535 |
scenarios = get_scenario_examples()
|
536 |
if name not in scenarios:
|
537 |
gr.Error(f"❌ 场景不存在: {name}")
|
generation_utils.py
CHANGED
@@ -12,10 +12,16 @@ from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
|
|
12 |
MAX_CHANNELS = 8
|
13 |
SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
|
14 |
|
15 |
-
def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.bfloat16, attn_implementation="
|
16 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
|
21 |
|
|
|
12 |
MAX_CHANNELS = 8
|
13 |
SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
|
14 |
|
15 |
+
def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa"):
|
16 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
17 |
|
18 |
+
# 尝试使用 FlashAttention2,失败则回退到标准实现
|
19 |
+
try:
|
20 |
+
model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
|
21 |
+
print("✅ 使用 FlashAttention2")
|
22 |
+
except ImportError:
|
23 |
+
print("⚠️ FlashAttention2 不可用,使用标准注意力机制")
|
24 |
+
model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch_dtype, attn_implementation=attn_implementation)
|
25 |
|
26 |
spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
|
27 |
|