xxyyy123 committed
Commit fda50b3 · verified · 1 Parent(s): 8878aaa

Update app.py

Files changed (1): app.py (+36, -0)
app.py CHANGED
@@ -8,6 +8,42 @@ from transformers import AutoModelForCausalLM
  from transformers import TextIteratorStreamer
  from threading import Thread

+ import importlib.metadata
+ from importlib import import_module
+ from transformers.utils import is_flash_attn_2_available
+ from packaging import version
+
+ def check_flash_attention_2_requirements():
+     # Check whether Flash Attention 2 is available
+     flash_attn_2_available = is_flash_attn_2_available()
+     if not flash_attn_2_available:
+         raise ImportError("Flash Attention 2 is not available.")
+
+     # Get the installed flash_attn version
+     try:
+         installed_version = importlib.metadata.version("flash_attn")
+     except importlib.metadata.PackageNotFoundError:
+         raise ImportError("flash_attn package is not installed.")
+
+     # Parse the installed version and the required minimum version
+     parsed_installed_version = version.parse(installed_version)
+     required_version = version.parse("2.6.3")
+
+     # Check whether the installed version meets the requirement
+     if parsed_installed_version < required_version:
+         raise ImportError(f"flash_attn version {installed_version} is installed, but version >= 2.6.3 is required.")
+
+     print("All requirements for Flash Attention 2 are met.")
+
+ # Use a try/except block to catch and report the specific error
+ try:
+     check_flash_attention_2_requirements()
+ except ImportError as e:
+     print(f"Error: {e}")
+     print("Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed.")
+ else:
+     print("Flash Attention 2 can be used.")
+
  model_name = 'AIDC-AI/Ovis2-16B'

  # load model
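
The check added in this commit only verifies that a sufficiently new flash_attn build (>= 2.6.3) is importable; whether Flash Attention 2 is actually used depends on how the model is loaded at the "# load model" step, which is outside this hunk. The snippet below is only a minimal sketch of such a load call, assuming the standard transformers attn_implementation argument, bfloat16 weights, and trust_remote_code for the Ovis2 custom modeling code; the actual arguments in app.py may differ.

    import torch
    from transformers import AutoModelForCausalLM

    # Sketch (assumption): request Flash Attention 2 at load time so the
    # flash_attn package checked above is actually exercised.
    model = AutoModelForCausalLM.from_pretrained(
        'AIDC-AI/Ovis2-16B',
        torch_dtype=torch.bfloat16,               # FA2 requires fp16/bf16 weights
        attn_implementation="flash_attention_2",  # assumed; may differ in app.py
        trust_remote_code=True,                   # Ovis2 ships custom modeling code
    )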