spoorthibhat committed (verified)
Commit 462afd3 · 1 Parent(s): 3fddf2c

Update app.py

Files changed (1)
  1. app.py +92 -72
app.py CHANGED
@@ -1,112 +1,132 @@
- import os
-
- os.system('pip install -q -e .')
- os.system('pip uninstall bitsandbytes')
- os.system('pip uninstall bitsandbytes-windows')
- os.system('pip install bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl')
- import torch
- print(torch.cuda.is_available())
-
- print(os.system('python -m bitsandbytes'))
-
  import os
  import torch
  import warnings
- from accelerate import Accelerator
+ import gradio as gr
  import io
  from contextlib import redirect_stdout
- import gradio as gr
+ from accelerate import Accelerator
  from transformers import AutoTokenizer
+
+ # Set memory-related environment variables
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+ # Suppress warnings and optimize CUDA
+ warnings.filterwarnings('ignore')
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ # Suppress specific pip install warnings
+ os.system('pip install -q -e .')
+ os.system('pip uninstall -y bitsandbytes')
+ os.system('pip install bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl')
+
+ # Import LLaVA specific modules
  from llava.model.builder import load_pretrained_model
  from llava.mm_utils import get_model_name_from_path
  from llava.eval.run_llava import eval_model

- warnings.filterwarnings('ignore')
-
- # Initialize Accelerator
- accelerator = Accelerator(mixed_precision="fp16") # Use "fp16" for half-precision or "bf16" for bfloat16
+ # Initialize Accelerator with lower precision
+ accelerator = Accelerator(mixed_precision="fp16")

- # Check GPU availability and define the device
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # Device setup with more robust checking
+ def get_optimal_device():
+     if torch.cuda.is_available():
+         # Find GPU with most free memory
+         total_memory = [torch.cuda.get_device_properties(i).total_memory for i in range(torch.cuda.device_count())]
+         free_memory = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
+         best_gpu = free_memory.index(max(free_memory))
+         return torch.device(f'cuda:{best_gpu}')
+     return torch.device('cpu')
+
+ device = get_optimal_device()
  print(f"Using device: {device}")

- # Clear GPU cache before loading the model
- print("Clearing GPU cache before model loading...")
- torch.cuda.empty_cache()
+ # Model loading with memory optimizations
+ def load_model_safely(model_path):
+     try:
+         # Clear GPU cache
+         torch.cuda.empty_cache()
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+
+         # Load model with device mapping
+         tokenizer, model, image_processor, context_len = load_pretrained_model(
+             model_path=model_path,
+             model_base=None,
+             model_name=get_model_name_from_path(model_path),
+             device_map="auto"  # Automatic device distribution
+         )
+
+         # Enable memory-efficient techniques
+         model.gradient_checkpointing_enable()
+
+         # Move to device and prepare with accelerator
+         model.to(device)
+
+         # Optional: Compile with memory-aware mode
+         try:
+             model = torch.compile(model, mode="reduce-overhead")
+         except Exception as compile_error:
+             print(f"Model compilation failed: {compile_error}. Proceeding without compilation.")
+
+         model = accelerator.prepare(model)
+
+         return tokenizer, model, image_processor, context_len
+
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         return None, None, None, None

  # Define the model path
  model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"

- # Load the model
- try:
-     tokenizer, model, image_processor, context_len = load_pretrained_model(
-         model_path=model_path,
-         model_base=None,
-         model_name=get_model_name_from_path(model_path)
-     )
-     # Enable Gradient Checkpointing
-     model.gradient_checkpointing_enable()
-
-     # Move the model to the correct device
-     model.to(device)
-
-     # Compile Model with PyTorch 2.0+
-     print("Compiling the model with torch.compile()...")
-     model = torch.compile(model, mode="max-autotune")  # Optimized for both speed and memory
-
-     # Prepare Model with Accelerator
-     model = accelerator.prepare(model)
-
-     print("Model successfully loaded and compiled!")
- except Exception as e:
-     print(f"Error loading model: {e}")
-     tokenizer, model, image_processor, context_len = None, None, None, None
-
- # Define the inference function
+ # Load the model with safety checks
+ tokenizer, model, image_processor, context_len = load_model_safely(model_path)
+
+ # Inference function with error handling
  def run_inference(image, question):
      if model is None:
          return "Model failed to load. Please check the logs."
-
-     args = type('Args', (), {
-         "model_path": model_path,
-         "model_base": None,
-         "image_file": image,
-         "query": question,
-         "conv_mode": None,
-         "sep": ",",
-         "temperature": 0,
-         "top_p": None,
-         "num_beams": 1,
-         "max_new_tokens": 512
-     })()
-
-     # Capture the printed output of eval_model
-     f = io.StringIO()
-     with redirect_stdout(f):
-         eval_model(args)
-     output = f.getvalue()
-     return output
+
+     try:
+         args = type('Args', (), {
+             "model_path": model_path,
+             "model_base": None,
+             "image_file": image,
+             "query": question,
+             "conv_mode": None,
+             "sep": ",",
+             "temperature": 0,
+             "top_p": None,
+             "num_beams": 1,
+             "max_new_tokens": 512
+         })()
+
+         # Capture the printed output of eval_model
+         f = io.StringIO()
+         with redirect_stdout(f):
+             eval_model(args)
+         output = f.getvalue()
+         return output
+
+     except Exception as e:
+         return f"Inference error: {str(e)}"

  # Create the Gradio interface
  with gr.Blocks(theme=gr.themes.Monochrome()) as app:
      with gr.Column(scale=1):
          gr.Markdown("<center><h1>LLaVA-Med</h1></center>")
-
          with gr.Row():
              image = gr.Image(type="filepath", scale=2)
              question = gr.Textbox(placeholder="Enter a question", scale=3)
-
          with gr.Row():
              answer = gr.Textbox(placeholder="Answer pops up here", scale=1)
-
          with gr.Row():
              btn = gr.Button("Run Inference", scale=1)
-
-         btn.click(fn=run_inference, inputs=[image, question], outputs=answer)
+             btn.click(fn=run_inference, inputs=[image, question], outputs=answer)

  # Launch the app
  if __name__ == "__main__":
      print("Clearing GPU cache before app launch...")
      torch.cuda.empty_cache()
-
      app.queue().launch(debug=True)
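
Note on the args construction: both the old and the new run_inference build the options object for eval_model with the type('Args', (), {...})() idiom, which creates a throwaway class whose attributes are then read by eval_model. A minimal sketch of the same idea using the standard-library types.SimpleNamespace is shown below; it is illustrative only, not part of this commit, and assumes eval_model only reads these attributes (which is what the existing idiom already relies on). The variables model_path, image, and question are the same ones already in scope inside run_inference.

    from types import SimpleNamespace

    # Drop-in equivalent for the type('Args', (), {...})() pattern:
    # SimpleNamespace gives plain attribute access (args.model_path, args.query, ...).
    args = SimpleNamespace(
        model_path=model_path,
        model_base=None,
        image_file=image,        # filepath from the gr.Image component
        query=question,          # text from the gr.Textbox component
        conv_mode=None,
        sep=",",
        temperature=0,
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
    )
    # eval_model(args) works the same with either construction.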