xxyyy123 committed
Commit a2c4296 · verified · 1 parent: fda50b3

Update app.py

Files changed (1): app.py (+41 -30)
app.py CHANGED
@@ -9,38 +9,49 @@ from transformers import TextIteratorStreamer
  from threading import Thread
 
  import importlib.metadata
- from importlib import import_module
- from transformers.utils import is_flash_attn_2_available
  from packaging import version
 
- def check_flash_attention_2_requirements():
-     # Check whether Flash Attention 2 is available
-     flash_attn_2_available = is_flash_attn_2_available()
-     if not flash_attn_2_available:
-         raise ImportError("Flash Attention 2 is not available.")
-
-     # Get the installed flash_attn version
-     try:
-         installed_version = importlib.metadata.version("flash_attn")
-     except importlib.metadata.PackageNotFoundError:
-         raise ImportError("flash_attn package is not installed.")
-
-     # Parse the installed version and the minimum required version
-     parsed_installed_version = version.parse(installed_version)
-     required_version = version.parse("2.6.3")
-
-     # Check whether the version requirement is met
-     if parsed_installed_version < required_version:
-         raise ImportError(f"flash_attn version {installed_version} is installed, but version >= 2.6.3 is required.")
-
-     print("All requirements for Flash Attention 2 are met.")
-
- # Use a try/except block to catch and report the specific error
- try:
-     check_flash_attention_2_requirements()
- except ImportError as e:
-     print(f"Error: {e}")
-     print("Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed.")
  else:
      print("Flash Attention 2 can be used.")
  from threading import Thread
 
  import importlib.metadata
  from packaging import version
+ from transformers.utils import (
+     is_torch_available,
+     _is_package_available,
+     is_torch_mlu_available
+ )
+
+ def diagnose_flash_attn_2_availability():
+     if not is_torch_available():
+         return "PyTorch is not available."
+
+     if not _is_package_available("flash_attn"):
+         return "flash_attn package is not installed."
+
+     import torch
+
+     if not (torch.cuda.is_available() or is_torch_mlu_available()):
+         return "Neither CUDA nor MLU is available."
+
+     flash_attn_version = importlib.metadata.version("flash_attn")
+
+     if torch.version.cuda:
+         required_version = "2.1.0"
+         if version.parse(flash_attn_version) < version.parse(required_version):
+             return f"CUDA is available, but flash_attn version {flash_attn_version} is installed. Version >= {required_version} is required."
+     elif torch.version.hip:
+         required_version = "2.0.4"
+         if version.parse(flash_attn_version) < version.parse(required_version):
+             return f"HIP is available, but flash_attn version {flash_attn_version} is installed. Version >= {required_version} is required."
+     elif is_torch_mlu_available():
+         required_version = "2.3.3"
+         if version.parse(flash_attn_version) < version.parse(required_version):
+             return f"MLU is available, but flash_attn version {flash_attn_version} is installed. Version >= {required_version} is required."
+     else:
+         return "Unknown PyTorch backend."
+
+     return "All requirements for Flash Attention 2 are met."
 
+ # Run the diagnostic function
+ result = diagnose_flash_attn_2_availability()
+ if result != "All requirements for Flash Attention 2 are met.":
+     print(f"Flash Attention 2 is not available: {result}")
+     print("Using `flash_attention_2` requires having the correct version of `flash_attn` installed.")
  else:
      print("Flash Attention 2 can be used.")
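
This hunk only prints the diagnosis; it does not show how the rest of app.py selects an attention backend. A minimal sketch, assuming the script loads a causal LM elsewhere (the model id, dtype, and the "sdpa" fallback below are placeholders, not taken from this commit), of how the result could gate the attn_implementation argument:

# Hypothetical usage sketch; not part of the commit.
import torch
from transformers import AutoModelForCausalLM

use_flash = diagnose_flash_attn_2_availability() == "All requirements for Flash Attention 2 are met."

model = AutoModelForCausalLM.from_pretrained(
    "org/model-id",               # placeholder model id
    torch_dtype=torch.bfloat16,   # placeholder dtype
    # Fall back to PyTorch's SDPA attention when Flash Attention 2 is unusable.
    attn_implementation="flash_attention_2" if use_flash else "sdpa",
)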