spoorthibhat commited on
Commit
5dfbd7d
·
verified ·
1 Parent(s): 025a03a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -7,6 +7,8 @@ print(torch.cuda.is_available())
7
 
8
  print(os.system('python -m bitsandbytes'))
9
 
 
 
10
  import warnings
11
  warnings.filterwarnings('ignore')
12
 
@@ -18,18 +20,32 @@ from llava.model.builder import load_pretrained_model
18
  from llava.mm_utils import get_model_name_from_path
19
  from llava.eval.run_llava import eval_model
20
 
 
 
 
 
21
  # Define the model path
22
  model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
23
 
24
  # Load the model
25
- tokenizer, model, image_processor, context_len = load_pretrained_model(
26
- model_path=model_path,
27
- model_base=None,
28
- model_name=get_model_name_from_path(model_path)
29
- )
 
 
 
 
 
 
 
30
 
31
  # Define the inference function
32
  def run_inference(image, question):
 
 
 
33
  args = type('Args', (), {
34
  "model_path": model_path,
35
  "model_base": None,
 
7
 
8
  print(os.system('python -m bitsandbytes'))
9
 
10
+ import os
11
+ import torch
12
  import warnings
13
  warnings.filterwarnings('ignore')
14
 
 
20
  from llava.mm_utils import get_model_name_from_path
21
  from llava.eval.run_llava import eval_model
22
 
23
+ # Check CUDA availability with error handling
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+ print(f"Using device: {device}")
26
+
27
  # Define the model path
28
  model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
29
 
30
  # Load the model
31
+ try:
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
33
+ model_path=model_path,
34
+ model_base=None,
35
+ model_name=get_model_name_from_path(model_path)
36
+ )
37
+
38
+ # Move model to appropriate device
39
+ model = model.to(device)
40
+ except Exception as e:
41
+ print(f"Error loading model: {e}")
42
+ tokenizer, model, image_processor, context_len = None, None, None, None
43
 
44
  # Define the inference function
45
  def run_inference(image, question):
46
+ if model is None:
47
+ return "Model failed to load. Please check the logs."
48
+
49
  args = type('Args', (), {
50
  "model_path": model_path,
51
  "model_base": None,