spoorthibhat committed
Commit f8ba981 · verified · 1 Parent(s): 462afd3

Update app.py

Files changed (1)
  1. app.py +54 -96
app.py CHANGED
@@ -1,132 +1,90 @@
import os
import torch
- import warnings
- import gradio as gr
- import io
- from contextlib import redirect_stdout
- from accelerate import Accelerator
- from transformers import AutoTokenizer

- # Set memory-related environment variables
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

- # Suppress warnings and optimize CUDA
warnings.filterwarnings('ignore')
- torch.backends.cudnn.benchmark = True
- torch.backends.cuda.matmul.allow_tf32 = True
-
- # Suppress specific pip install warnings
- os.system('pip install -q -e .')
- os.system('pip uninstall -y bitsandbytes')
- os.system('pip install bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl')

- # Import LLaVA specific modules
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

- # Initialize Accelerator with lower precision
- accelerator = Accelerator(mixed_precision="fp16")
-
- # Device setup with more robust checking
- def get_optimal_device():
-     if torch.cuda.is_available():
-         # Find GPU with most free memory
-         total_memory = [torch.cuda.get_device_properties(i).total_memory for i in range(torch.cuda.device_count())]
-         free_memory = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
-         best_gpu = free_memory.index(max(free_memory))
-         return torch.device(f'cuda:{best_gpu}')
-     return torch.device('cpu')
-
- device = get_optimal_device()
print(f"Using device: {device}")

- # Model loading with memory optimizations
- def load_model_safely(model_path):
-     try:
-         # Clear GPU cache
-         torch.cuda.empty_cache()
-         if torch.cuda.is_available():
-             torch.cuda.synchronize()
-
-         # Load model with device mapping
-         tokenizer, model, image_processor, context_len = load_pretrained_model(
-             model_path=model_path,
-             model_base=None,
-             model_name=get_model_name_from_path(model_path),
-             device_map="auto"  # Automatic device distribution
-         )
-
-         # Enable memory-efficient techniques
-         model.gradient_checkpointing_enable()
-
-         # Move to device and prepare with accelerator
-         model.to(device)
-
-         # Optional: Compile with memory-aware mode
-         try:
-             model = torch.compile(model, mode="reduce-overhead")
-         except Exception as compile_error:
-             print(f"Model compilation failed: {compile_error}. Proceeding without compilation.")
-
-         model = accelerator.prepare(model)
-
-         return tokenizer, model, image_processor, context_len
-
-     except Exception as e:
-         print(f"Error loading model: {e}")
-         return None, None, None, None
-
# Define the model path
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"

- # Load the model with safety checks
- tokenizer, model, image_processor, context_len = load_model_safely(model_path)

- # Inference function with error handling
def run_inference(image, question):
    if model is None:
        return "Model failed to load. Please check the logs."

-     try:
-         args = type('Args', (), {
-             "model_path": model_path,
-             "model_base": None,
-             "image_file": image,
-             "query": question,
-             "conv_mode": None,
-             "sep": ",",
-             "temperature": 0,
-             "top_p": None,
-             "num_beams": 1,
-             "max_new_tokens": 512
-         })()
-
-         # Capture the printed output of eval_model
-         f = io.StringIO()
-         with redirect_stdout(f):
-             eval_model(args)
-         output = f.getvalue()
-         return output
-
-     except Exception as e:
-         return f"Inference error: {str(e)}"

# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    with gr.Column(scale=1):
        gr.Markdown("<center><h1>LLaVA-Med</h1></center>")
        with gr.Row():
            image = gr.Image(type="filepath", scale=2)
            question = gr.Textbox(placeholder="Enter a question", scale=3)
        with gr.Row():
            answer = gr.Textbox(placeholder="Answer pops up here", scale=1)
        with gr.Row():
            btn = gr.Button("Run Inference", scale=1)
-         btn.click(fn=run_inference, inputs=[image, question], outputs=answer)

# Launch the app
if __name__ == "__main__":
-     print("Clearing GPU cache before app launch...")
-     torch.cuda.empty_cache()
    app.queue().launch(debug=True)
 
import os
+
+ os.system('pip install -q -e .')
+ os.system('pip uninstall bitsandbytes')
+ os.system('pip install bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl')
+
import torch
+ print(torch.cuda.is_available())

+ print(os.system('python -m bitsandbytes'))

+ import os
+ import torch
+ import warnings
warnings.filterwarnings('ignore')

+ import io
+ from contextlib import redirect_stdout
+ import gradio as gr
+ from transformers import AutoTokenizer
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

+ # Check CUDA availability with error handling
+ device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Define the model path
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"

+ # Load the model
+ try:
+     tokenizer, model, image_processor, context_len = load_pretrained_model(
+         model_path=model_path,
+         model_base=None,
+         model_name=get_model_name_from_path(model_path)
+     )
+
+     # Move model to appropriate device
+     model = model.to(device)
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     tokenizer, model, image_processor, context_len = None, None, None, None

+ # Define the inference function
def run_inference(image, question):
    if model is None:
        return "Model failed to load. Please check the logs."

+     args = type('Args', (), {
+         "model_path": model_path,
+         "model_base": None,
+         "image_file": image,
+         "query": question,
+         "conv_mode": None,
+         "sep": ",",
+         "temperature": 0,
+         "top_p": None,
+         "num_beams": 1,
+         "max_new_tokens": 512
+     })()
+
+     # Capture the printed output of eval_model
+     f = io.StringIO()
+     with redirect_stdout(f):
+         eval_model(args)
+     output = f.getvalue()
+     return output

# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Monochrome()) as app:
    with gr.Column(scale=1):
        gr.Markdown("<center><h1>LLaVA-Med</h1></center>")
+
        with gr.Row():
            image = gr.Image(type="filepath", scale=2)
            question = gr.Textbox(placeholder="Enter a question", scale=3)
+
        with gr.Row():
            answer = gr.Textbox(placeholder="Answer pops up here", scale=1)
+
        with gr.Row():
            btn = gr.Button("Run Inference", scale=1)
+
+         btn.click(fn=run_inference, inputs=[image, question], outputs=answer)

# Launch the app
if __name__ == "__main__":
    app.queue().launch(debug=True)
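
For context on the inference path both versions share: LLaVA's eval_model prints its generation to stdout rather than returning it, which is why run_inference wraps the call in contextlib.redirect_stdout and hands the captured text back to Gradio. A minimal, self-contained sketch of that capture pattern, where noisy_model_call is a hypothetical stand-in for eval_model(args):

import io
from contextlib import redirect_stdout

def noisy_model_call(prompt: str) -> None:
    # Hypothetical stand-in for eval_model(args): it prints its answer instead of returning it.
    print(f"Answer to: {prompt}")

def capture_output(prompt: str) -> str:
    # Redirect everything printed during the call into an in-memory buffer,
    # then return the captured text (the role run_inference plays for the Gradio callback).
    buf = io.StringIO()
    with redirect_stdout(buf):
        noisy_model_call(prompt)
    return buf.getvalue()

print(capture_output("What does the scan show?"))  # prints "Answer to: What does the scan show?"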