Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,8 @@ print(torch.cuda.is_available())
|
|
7 |
|
8 |
print(os.system('python -m bitsandbytes'))
|
9 |
|
|
|
|
|
10 |
import warnings
|
11 |
warnings.filterwarnings('ignore')
|
12 |
|
@@ -18,18 +20,32 @@ from llava.model.builder import load_pretrained_model
|
|
18 |
from llava.mm_utils import get_model_name_from_path
|
19 |
from llava.eval.run_llava import eval_model
|
20 |
|
|
|
|
|
|
|
|
|
21 |
# Define the model path
|
22 |
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
|
23 |
|
24 |
# Load the model
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
# Define the inference function
|
32 |
def run_inference(image, question):
|
|
|
|
|
|
|
33 |
args = type('Args', (), {
|
34 |
"model_path": model_path,
|
35 |
"model_base": None,
|
|
|
7 |
|
8 |
print(os.system('python -m bitsandbytes'))
|
9 |
|
10 |
+
import os
|
11 |
+
import torch
|
12 |
import warnings
|
13 |
warnings.filterwarnings('ignore')
|
14 |
|
|
|
20 |
from llava.mm_utils import get_model_name_from_path
|
21 |
from llava.eval.run_llava import eval_model
|
22 |
|
23 |
+
# Check CUDA availability with error handling
|
24 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
+
print(f"Using device: {device}")
|
26 |
+
|
27 |
# Define the model path
|
28 |
model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
|
29 |
|
30 |
# Load the model
|
31 |
+
try:
|
32 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
33 |
+
model_path=model_path,
|
34 |
+
model_base=None,
|
35 |
+
model_name=get_model_name_from_path(model_path)
|
36 |
+
)
|
37 |
+
|
38 |
+
# Move model to appropriate device
|
39 |
+
model = model.to(device)
|
40 |
+
except Exception as e:
|
41 |
+
print(f"Error loading model: {e}")
|
42 |
+
tokenizer, model, image_processor, context_len = None, None, None, None
|
43 |
|
44 |
# Define the inference function
|
45 |
def run_inference(image, question):
|
46 |
+
if model is None:
|
47 |
+
return "Model failed to load. Please check the logs."
|
48 |
+
|
49 |
args = type('Args', (), {
|
50 |
"model_path": model_path,
|
51 |
"model_base": None,
|