Prasi21 commited on
Commit
a660aaa
·
verified ·
1 Parent(s): a39ea6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -1,29 +1,24 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import Blip2ForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
4
- from peft import PeftModel, PeftConfig
5
-
6
- # Load the PEFT model configuration and quantization settings
7
- peft_model_id = "Prasi21/blip2-opt-2.7b-strep-throat-caption-adapters3"
8
- config = PeftConfig.from_pretrained(peft_model_id)
9
- config.base_model_name_or_path = "Prasi21/blip2-opt-2.7b-strep-throat-caption-adapters3"
10
-
11
- # Enable 8-bit quantization for more efficient loading
12
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)
13
-
14
- # Load the base model with quantization
15
- model = Blip2ForConditionalGeneration.from_pretrained(
16
- config.base_model_name_or_path,
17
- quantization_config=quantization_config,
18
- device_map="auto"
19
- )
20
-
21
- # Load the fine-tuned PEFT model
22
- model = PeftModel.from_pretrained(model, peft_model_id)
23
 
24
  # Load the processor
25
  processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Define the prediction function
28
  def predict(image):
29
  # Preprocess the image
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
4
+ from peft import LoraConfig, get_peft_model, PeftModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Load the processor
7
  processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
8
 
9
+ # Load the base model from the original repository
10
+ base_model = Blip2ForConditionalGeneration.from_pretrained(
11
+ "ybelkada/blip2-opt-2.7b-fp16-sharded",
12
+ device_map="auto",
13
+ quantization_config=quantization_config
14
+ )
15
+
16
+ repo_id = "Prasi21/blip2-opt-2.7b-strep-throat-caption-adapters"
17
+
18
+ # Load the fine-tuned LoRA adapters from the Hugging Face Hub
19
+ model = PeftModel.from_pretrained(base_model, repo_id)
20
+
21
+
22
  # Define the prediction function
23
  def predict(image):
24
  # Preprocess the image