ShreyMehra committed on
Commit
ac752a4
·
unverified ·
1 Parent(s): 306e386

Update app.py

Browse files
Files changed (1): app.py (+6 −1)
app.py CHANGED
@@ -3,6 +3,7 @@ from PIL import Image #Image Processing
3
  import numpy as np #Image Processing
4
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
5
  import torch
 
6
 
7
 
8
  #title
@@ -16,8 +17,12 @@ image = st.file_uploader(label = "Upload your image here",type=['png','jpg','jpe
16
 
17
  @st.cache
18
  def load_model():
 
 
 
 
 
19
  processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
20
- model = Blip2ForConditionalGeneration.from_pretrained("Shrey23/Image-Captioning", device_map="auto", )
21
  return processor, model
22
 
23
  processor, model = load_model() #load model
 
3
  import numpy as np #Image Processing
4
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
5
  import torch
6
+ from peft import PeftModel, PeftConfig
7
 
8
 
9
  #title
 
17
 
18
  @st.cache
19
  def load_model():
20
+ peft_model_id = "Shrey23/Image-Captioning"
21
+ config = PeftConfig.from_pretrained(peft_model_id)
22
+ model = Blip2ForConditionalGeneration.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map="auto")
23
+ model = PeftModel.from_pretrained(model, peft_model_id)
24
+
25
  processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
 
26
  return processor, model
27
 
28
  processor, model = load_model() #load model