ShreyMehra committed on
Commit 3ac1bdf · unverified · 1 Parent(s): 3542c9b

Update app.py

Files changed (1):
  1. app.py +11 -6
app.py CHANGED
@@ -14,8 +14,6 @@ st.markdown("Link to the model - [Image-to-Caption-App on 🤗 Spaces](https://h
 #image uploader
 image = st.file_uploader(label = "Upload your image here",type=['png','jpg','jpeg'])
 
-
-@st.cache
 def load_model():
     peft_model_id = "Shrey23/Image-Captioning"
     config = PeftConfig.from_pretrained(peft_model_id)
@@ -24,8 +22,15 @@ def load_model():
 
     processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
     return processor, model
+
+
+if "model" not in st.session_state:
+    processor, model = load_model() #load model
+    st.session_state.dict = {}
+    st.session_state.dict['processor'] = processor
+    st.session_state.dict['model'] = model
+
 
-processor, model = load_model() #load model
 
 if image is not None:
 
@@ -34,12 +39,12 @@ if image is not None:
 
     with st.spinner("🤖 AI is at Work! "):
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+        inputs = st.session_state.dict['processor'](images=image, return_tensors="pt").to(device, torch.float16)
         pixel_values = inputs.pixel_values
 
 
-        generated_ids = model.generate(pixel_values=pixel_values, max_length=25)
-        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        generated_ids = st.session_state.dict['model'].generate(pixel_values=pixel_values, max_length=25)
+        generated_caption = st.session_state.dict['processor'].batch_decode(generated_ids, skip_special_tokens=True)[0]
 
     st.write(generated_caption)
 
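
The change replaces the @st.cache decorator with an explicit st.session_state check, so the PEFT processor and model are loaded once per browser session instead of on every Streamlit rerun. Below is a minimal sketch of that caching pattern. The middle of load_model() is not shown in the diff, so the Blip2ForConditionalGeneration and PeftModel lines are illustrative assumptions, and the sketch stores the objects directly under st.session_state keys rather than in the nested dict the commit uses.

# Minimal sketch of caching a PEFT-adapted BLIP-2 model in st.session_state.
# NOTE: the base-model / adapter loading lines are assumptions for illustration;
# the diff only shows PeftConfig.from_pretrained and AutoProcessor.from_pretrained.
import streamlit as st
from peft import PeftConfig, PeftModel
from transformers import AutoProcessor, Blip2ForConditionalGeneration

def load_model():
    peft_model_id = "Shrey23/Image-Captioning"
    config = PeftConfig.from_pretrained(peft_model_id)
    # Assumed: load the base checkpoint named in the adapter config, then
    # attach the fine-tuned PEFT adapter weights on top of it.
    base = Blip2ForConditionalGeneration.from_pretrained(config.base_model_name_or_path)
    model = PeftModel.from_pretrained(base, peft_model_id)
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    return processor, model

# st.session_state persists across script reruns within a session, so the
# expensive load_model() call only happens the first time the guard fails.
if "model" not in st.session_state:
    st.session_state["processor"], st.session_state["model"] = load_model()

processor = st.session_state["processor"]
model = st.session_state["model"]

Streamlit's st.cache_resource decorator is the documented replacement for caching models after the deprecation of st.cache, and would give similar load-once behaviour shared across sessions rather than per session.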