Update app.py
app.py CHANGED
@@ -31,23 +31,16 @@ print(f"Using device: {device}")
 # Define the model path
 model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
 
-
-
-
-
-
-
-
-
-
-
-    model = model.to(device)
-
-    # Enable gradient checkpointing to save memory
-    model.gradient_checkpointing_enable()
-except Exception as e:
-    print(f"Error loading model: {e}")
-    tokenizer, model, image_processor, context_len = None, None, None, None
+kwargs = {"device_map": "auto"}
+kwargs['load_in_4bit'] = True
+kwargs['quantization_config'] = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type='nf4'
+)
+model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
 
 # Define the inference function
 def run_inference(image, question):
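In effect, the commit swaps the old fp16 loading path (a manual `model.to(device)` plus gradient checkpointing inside a try/except) for a 4-bit NF4 quantized load via bitsandbytes, with `device_map="auto"` handling device placement. Below is a minimal self-contained sketch of the new loading path, assuming the LLaVA code base (which provides `LlavaMistralForCausalLM`), `transformers`, `bitsandbytes`, and `accelerate` are installed; the `llava.model` import path is an assumption about this Space's environment, not part of the commit.

```python
# Minimal sketch of the new 4-bit loading path (assumptions noted inline).
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
# Assumption: LlavaMistralForCausalLM comes from the LLaVA repo;
# adjust the import to match your install.
from llava.model import LlavaMistralForCausalLM

model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"

kwargs = {
    "device_map": "auto",  # let accelerate place weights; replaces the manual .to(device)
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,                     # store weights in 4 bits
        bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
        bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
        bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
    ),
}
# Note: recent transformers versions reject a separate load_in_4bit=True kwarg
# when quantization_config is also passed, so the config alone carries the setting here.

model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
```

Gradient checkpointing only saves activation memory during training, so dropping it costs an inference Space nothing; it is the quantized load that shrinks the footprint, from roughly 14 GB of fp16 weights for a 7B model to around 4 GB in NF4.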