SaiBrahmam committed on
Commit
fd16bc7
·
1 Parent(s): 5462c07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -32
app.py CHANGED
@@ -1,45 +1,60 @@
1
  import streamlit as st
2
- from PIL import Image
3
- import requests
4
  import torch
 
 
5
  from torchvision import transforms
6
  from torchvision.transforms.functional import InterpolationMode
 
7
  from models.blip import blip_decoder
8
 
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
def load_image(image_url, image_size, device):
    """Download an image from *image_url* and preprocess it for the BLIP model.

    Args:
        image_url: HTTP(S) URL of the image to fetch.
        image_size: Side length (pixels) the image is resized to (square).
        device: torch.device the resulting tensor is moved to.

    Returns:
        A float tensor of shape (1, 3, image_size, image_size), normalized
        with the CLIP/BLIP mean and std.
    """
    # stream=True + .raw lets PIL read straight from the HTTP response body.
    raw_image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
    # BUGFIX: the original called display(raw_image.resize((w//5, h//5))).
    # display() is an IPython/notebook builtin and raises NameError when this
    # file runs as a plain Streamlit script, so the call (and the unused
    # w, h unpack feeding it) is removed.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # CLIP/BLIP channel-wise normalization constants.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    image = transform(raw_image).unsqueeze(0).to(device)
    return image
-
25
def generate_caption(image_url, num_captions=3):
    """Generate sampled captions for the image at *image_url* with BLIP.

    Args:
        image_url: HTTP(S) URL of the image to caption.
        num_captions: How many captions to sample (default 3, matching the
            original hard-coded behavior; added as a backward-compatible
            generalization).

    Returns:
        A list of *num_captions* caption strings.
    """
    image_size = 384
    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    # NOTE(review): the model checkpoint is downloaded and rebuilt on every
    # call; callers invoking this per Streamlit rerun should cache the model.
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()
    model = model.to(device)
    image = load_image(image_url, image_size, device)
    captions = []
    with torch.no_grad():
        for _ in range(num_captions):
            # sample=True performs stochastic nucleus (top-p) sampling, so
            # each generated caption can differ.
            caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
            captions.append(caption[0])
    return captions
38
 
39
- st.title("Image Caption Generator")
40
-
41
- image_url = st.text_input("Enter the image URL:")
42
- if image_url:
43
- captions = generate_caption(image_url)
44
- for i, caption in enumerate(captions):
45
- st.write(f'caption {i+1}: {caption}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
 
2
  import torch
3
+ import requests
4
+ from PIL import Image
5
  from torchvision import transforms
6
  from torchvision.transforms.functional import InterpolationMode
7
+
8
  from models.blip import blip_decoder
9
 
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
@st.cache(show_spinner=False)
def load_demo_image(image_size, device):
    """Fetch the bundled BLIP demo image and preprocess it for the model.

    Args:
        image_size: Side length (pixels) the image is resized to (square).
        device: torch.device the resulting tensor is moved to.

    Returns:
        A pair ``(tensor, preview)`` where ``tensor`` has shape
        (1, 3, image_size, image_size) and ``preview`` is the raw PIL image
        scaled down to one fifth of its original dimensions.
    """
    img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
    response = requests.get(img_url, stream=True)
    raw_image = Image.open(response.raw).convert('RGB')
    width, height = raw_image.size

    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # CLIP/BLIP channel-wise normalization constants.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    tensor = preprocess(raw_image).unsqueeze(0).to(device)
    preview = raw_image.resize((width // 5, height // 5))
    return tensor, preview
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_model(image_size):
    """Download the pretrained BLIP captioning checkpoint once and cache it.

    allow_output_mutation is needed because the returned torch Module is
    mutated in place (.eval(), .to(device)) and cannot be hashed by st.cache.
    """
    model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
    model.eval()
    return model.to(device)


def main():
    """Streamlit entry point: upload an image, show three sampled BLIP captions."""
    st.set_page_config(page_title="Image Captioning App")
    st.title("Image Captioning App")
    st.write("This app generates captions for images using a pre-trained model.")

    # Load image
    image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if image_file is not None:
        # BUGFIX: convert to RGB — an uploaded PNG with an alpha channel is
        # 4-channel and would break the 3-channel Normalize below.
        image = Image.open(image_file).convert('RGB')
        image_size = 384
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # CLIP/BLIP channel-wise normalization constants.
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        image = transform(image).unsqueeze(0).to(device)

        # Generate captions.
        # BUGFIX: the original re-downloaded and rebuilt the model inside this
        # branch on every upload/rerun; it is now loaded once via _load_model.
        model = _load_model(image_size)
        with torch.no_grad():
            num_captions = 3
            captions = []
            for i in range(num_captions):
                # sample=True performs stochastic nucleus (top-p) sampling,
                # so each of the three captions can differ.
                caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
                captions.append(caption[0])
            for i, caption in enumerate(captions):
                st.write(f'Caption {i+1}: {caption}')

        # Display uploaded image
        st.image(image_file, caption='Uploaded image', use_column_width=True)


if __name__ == "__main__":
    main()