Update app.py
app.py CHANGED
@@ -9,38 +9,49 @@ from transformers import TextIteratorStreamer
 from threading import Thread

 import importlib.metadata
-from importlib import import_module
-from transformers.utils import is_flash_attn_2_available
 from packaging import version
+from transformers.utils import (
+    is_torch_available,
+    _is_package_available,
+    is_torch_mlu_available
+)
+
+def diagnose_flash_attn_2_availability():
+    if not is_torch_available():
+        return "PyTorch is not available."
+
+    if not _is_package_available("flash_attn"):
+        return "flash_attn package is not installed."
+
+    import torch
+
+    if not (torch.cuda.is_available() or is_torch_mlu_available()):
+        return "Neither CUDA nor MLU is available."
+
+    flash_attn_version = importlib.metadata.version("flash_attn")
+
+    if torch.version.cuda:
+        required_version = "2.1.0"
+        if version.parse(flash_attn_version) < version.parse(required_version):
+            return f"CUDA is available, but flash_attn version {flash_attn_version} is installed. Version >= {required_version} is required."
+    elif torch.version.hip:
+        required_version = "2.0.4"
+        if version.parse(flash_attn_version) < version.parse(required_version):
+            return f"HIP is available, but flash_attn version {flash_attn_version} is installed. Version >= {required_version} is required."
+    elif is_torch_mlu_available():
+        required_version = "2.3.3"
+        if version.parse(flash_attn_version) < version.parse(required_version):
+            return f"MLU is available, but flash_attn version {flash_attn_version} is installed. Version >= {required_version} is required."
+    else:
+        return "Unknown PyTorch backend."
+
+    return "All requirements for Flash Attention 2 are met."

-
-
-
-
-
-
-# Get the installed flash_attn version
-try:
-    installed_version = importlib.metadata.version("flash_attn")
-except importlib.metadata.PackageNotFoundError:
-    raise ImportError("flash_attn package is not installed.")
-
-# Parse the installed version and the required minimum version
-parsed_installed_version = version.parse(installed_version)
-required_version = version.parse("2.6.3")
-
-# Check whether the version requirement is met
-if parsed_installed_version < required_version:
-    raise ImportError(f"flash_attn version {installed_version} is installed, but version >= 2.6.3 is required.")
-
-print("All requirements for Flash Attention 2 are met.")
-
-# Use a try-except block to catch and display the specific error
-try:
-    check_flash_attention_2_requirements()
-except ImportError as e:
-    print(f"Error: {e}")
-    print("Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed.")
+# Run the diagnostic function
+result = diagnose_flash_attn_2_availability()
+if result != "All requirements for Flash Attention 2 are met.":
+    print(f"Flash Attention 2 is not available: {result}")
+    print("Using `flash_attention_2` requires having the correct version of `flash_attn` installed.")
 else:
     print("Flash Attention 2 can be used.")

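The diff above only reports whether Flash Attention 2 can be used. A minimal sketch, assuming app.py continues by loading a causal language model with transformers, of how the `result` string from the new diagnostic could pick the attention backend; the checkpoint id, dtype, and the `sdpa` fallback are illustrative assumptions and are not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Choose the attention implementation based on the diagnostic result computed above.
attn_implementation = (
    "flash_attention_2"
    if result == "All requirements for Flash Attention 2 are met."
    else "sdpa"  # PyTorch scaled-dot-product attention as a safe fallback
)

model_id = "org/placeholder-model"  # hypothetical checkpoint; the real id is not shown in this hunk
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # assumption; use whatever dtype the app actually loads with
    attn_implementation=attn_implementation,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Falling back to `sdpa` instead of raising keeps the app usable on hardware where flash_attn is missing or too old, which matches the non-fatal, print-based handling in the new diagnostic.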