spoorthibhat committed
Commit 087cd4e · verified · 1 Parent(s): d5e6700

Update app.py

Files changed (1)
  1. app.py +10 -17
app.py CHANGED
@@ -31,23 +31,16 @@ print(f"Using device: {device}")
 # Define the model path
 model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"
 
-# Load the model
-try:
-
-    tokenizer, model, image_processor, context_len = load_pretrained_model(
-        model_path=model_path,
-        model_base=None,
-        model_name=get_model_name_from_path(model_path)
-    )
-
-    # Move model to appropriate device
-    model = model.to(device)
-
-    # Enable gradient checkpointing to save memory
-    model.gradient_checkpointing_enable()
-except Exception as e:
-    print(f"Error loading model: {e}")
-    tokenizer, model, image_processor, context_len = None, None, None, None
+kwargs = {"device_map": "auto"}
+kwargs['load_in_4bit'] = True
+kwargs['quantization_config'] = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type='nf4'
+)
+model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
 
 # Define the inference function
 def run_inference(image, question):
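
For reference, the new loading path only works if app.py also imports the names it uses (torch, BitsAndBytesConfig, AutoTokenizer, LlavaMistralForCausalLM). Below is a minimal, self-contained sketch of the quantized load; the `from llava.model import LlavaMistralForCausalLM` path is an assumption based on the upstream LLaVA package layout, and the sketch sets `load_in_4bit` only inside `BitsAndBytesConfig`, since some transformers versions reject it being passed both there and as a top-level kwarg. `device_map="auto"` additionally requires the accelerate package.

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaMistralForCausalLM  # assumed import path (upstream LLaVA layout)

model_path = "Veda0718/llava-med-v1.5-mistral-7b-finetuned"

# 4-bit NF4 quantization: weights are stored in 4 bits, matmuls run in fp16;
# double quantization also compresses the quantization constants themselves.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = LlavaMistralForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    device_map="auto",  # let accelerate place layers across available devices
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

Compared with the removed load_pretrained_model path, this trades the LLaVA helper (which also returned an image_processor and context_len) for a direct from_pretrained call, so the image processor now has to be obtained separately, typically from the model's vision tower in LLaVA-style code.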