sounar committed
Commit b37e8c8 · verified · 1 Parent(s): e212182

Update app.py

Files changed (1)
  1. app.py +28 -7
app.py CHANGED
@@ -3,7 +3,15 @@ import torch
 from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 import gradio as gr
 from PIL import Image
-from torchvision.transforms import ToTensor
+
+# First, let's check if flash-attn is installed
+try:
+    import flash_attn
+    FLASH_ATTN_AVAILABLE = True
+except ImportError:
+    FLASH_ATTN_AVAILABLE = False
+    print("Flash Attention is not installed. Using default attention mechanism.")
+    print("To install Flash Attention, run: pip install flash-attn --no-build-isolation")
 
 # Get API token from environment variable
 api_token = os.getenv("HF_TOKEN").strip()
@@ -16,15 +24,23 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16
 )
 
+# Initialize model with conditional Flash Attention
+model_args = {
+    "quantization_config": bnb_config,
+    "device_map": "auto",
+    "torch_dtype": torch.float16,
+    "trust_remote_code": True,
+    "token": api_token
+}
+
+# Only add flash attention if available
+if FLASH_ATTN_AVAILABLE:
+    model_args["attn_implementation"] = "flash_attention_2"
+
 # Initialize model and tokenizer
 model = AutoModel.from_pretrained(
     "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
-    quantization_config=bnb_config,
-    device_map="auto",
-    torch_dtype=torch.float16,
-    trust_remote_code=True,
-    attn_implementation="flash_attention_2",
-    token=api_token
+    **model_args
 )
 
 tokenizer = AutoTokenizer.from_pretrained(
@@ -84,6 +100,11 @@ demo = gr.Interface(
 
 # Launch the Gradio app
 if __name__ == "__main__":
+    # Print installation instructions if Flash Attention is not available
+    if not FLASH_ATTN_AVAILABLE:
+        print("\nTo enable Flash Attention 2 for better performance, please install it using:")
+        print("pip install flash-attn --no-build-isolation")
+
     demo.launch(
         share=True,
         server_name="0.0.0.0",
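
The change boils down to one pattern: probe for an optional dependency, collect the from_pretrained keyword arguments in a dict, and add attn_implementation only when flash-attn is usable. A minimal standalone sketch of that pattern — illustrative names only, no GPU or model download; importlib.util.find_spec is a stdlib alternative to the commit's try/except probe:

import importlib.util

# Probe without importing: a broken flash-attn build cannot crash this check,
# though unlike the try/except in app.py it does not prove the import succeeds.
FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None

# Mirror the conditional model_args construction from the commit.
model_args = {"trust_remote_code": True}
if FLASH_ATTN_AVAILABLE:
    model_args["attn_implementation"] = "flash_attention_2"

# **model_args expands the dict into keyword arguments, which is how
# AutoModel.from_pretrained(..., **model_args) receives them in app.py.
# fake_from_pretrained is a hypothetical stand-in for demonstration.
def fake_from_pretrained(model_id, **kwargs):
    print(model_id, kwargs)

fake_from_pretrained("ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1", **model_args)

Building the dict up front keeps a single from_pretrained call site instead of duplicating the whole call in an if/else branch, which is why the diff can replace six keyword lines with one **model_args line.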
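
One unchanged context line is worth flagging while touching startup behavior: os.getenv returns None when HF_TOKEN is unset, so api_token = os.getenv("HF_TOKEN").strip() raises AttributeError in that case. A defensive variant (a sketch, not part of this commit):

import os

# Fail fast with a clear message instead of AttributeError on None.
api_token = os.getenv("HF_TOKEN")
if api_token is None:
    raise RuntimeError("HF_TOKEN environment variable is not set")
api_token = api_token.strip()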