SaiBrahmam committed on
Commit
5462c07
·
1 Parent(s): 4c155db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -41
app.py CHANGED
@@ -1,58 +1,45 @@
1
- # install requirements
2
- import requests
3
  from PIL import Image
 
4
  import torch
5
  from torchvision import transforms
6
  from torchvision.transforms.functional import InterpolationMode
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
def load_demo_image(image_size, device, img_url):
    """Download an image from *img_url* and return it as a normalized tensor.

    Parameters
    ----------
    image_size : int
        Side length the image is resized to (square) before normalization.
    device : torch.device
        Device the resulting tensor is moved to.
    img_url : str
        HTTP(S) URL of the image to fetch.

    Returns
    -------
    torch.Tensor
        A (1, 3, image_size, image_size) float tensor, normalized with the
        CLIP mean/std used by BLIP-style models.
    """
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

    # BUG FIX: ``display`` only exists inside IPython/Jupyter; calling it in a
    # plain script raised NameError. Show the preview only when available.
    w, h = raw_image.size
    try:
        display(raw_image.resize((w // 5, h // 5)))  # noqa: F821 — IPython builtin
    except NameError:
        pass  # not running under IPython — skip the inline preview

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # CLIP normalization constants (mean, std) expected by the model.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    image = transform(raw_image).unsqueeze(0).to(device)
    return image
24
 
25
def generate_captions(image, model):
    """Sample three captions for *image* with nucleus sampling and print them.

    Parameters
    ----------
    image : torch.Tensor
        Preprocessed (1, 3, H, W) image batch, e.g. from ``load_demo_image``.
    model : object
        Must expose a BLIP-style ``generate(image, sample=..., top_p=...,
        max_length=..., min_length=...)`` method returning a list of strings.

    Side effects: prints each caption ("caption N: ..."); returns None.
    """
    num_captions = 3
    # Inference only: no_grad avoids building an autograd graph per sample.
    with torch.no_grad():
        captions = [
            model.generate(image, sample=True, top_p=0.9,
                           max_length=20, min_length=5)[0]
            for _ in range(num_captions)
        ]
    for i, caption in enumerate(captions):
        print(f'caption {i+1}: {caption}')
36
-
37
# --- Model setup -----------------------------------------------------------
# NOTE(review): GPT-Neo is a text-only causal LM; its ``generate`` does not
# accept an image tensor, so ``generate_captions(image, model)`` cannot
# produce real captions — an image-captioning checkpoint (e.g. BLIP) is
# needed here. Flagged rather than silently replaced.
model_name = 'EleutherAI/gpt-neo-1.3B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()  # inference mode: disables dropout etc.

# --- Streamlit app ---------------------------------------------------------
import streamlit as st  # kept local to this section, as in the original layout

st.title('Image Caption Generator')

# Get user input (defaults to the BLIP demo image).
img_url = st.text_input('Enter image URL', 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg')

if img_url:
    # Load and preprocess the image.
    image_size = 384
    image = load_demo_image(image_size, device, img_url)

    # BUG FIX: the original passed raw pixel values to the text tokenizer
    # (``tokenizer(image.tolist()[0])``) — tokenizers accept strings, not
    # nested float lists, so this raised at runtime. Removed.
    generate_captions(image, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
 
2
  from PIL import Image
3
+ import requests
4
  import torch
5
  from torchvision import transforms
6
  from torchvision.transforms.functional import InterpolationMode
7
+ from models.blip import blip_decoder
8
 
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
def load_image(image_url, image_size, device):
    """Fetch an image over HTTP and preprocess it for the BLIP captioner.

    Parameters
    ----------
    image_url : str
        HTTP(S) URL of the image to download.
    image_size : int
        Side length the image is resized to (square) before normalization.
    device : torch.device
        Device the resulting tensor is moved to.

    Returns
    -------
    torch.Tensor
        A (1, 3, image_size, image_size) float tensor, normalized with the
        CLIP mean/std expected by BLIP.
    """
    raw_image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')

    # BUG FIX: ``display`` is an IPython-only builtin and raised NameError in
    # this Streamlit script; render the preview through Streamlit instead.
    w, h = raw_image.size
    st.image(raw_image.resize((w // 5, h // 5)))

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # CLIP normalization constants (mean, std) expected by BLIP.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    image = transform(raw_image).unsqueeze(0).to(device)
    return image
24
 
25
# Lazily-built singleton: downloading and initializing the BLIP checkpoint is
# expensive, so it must not be repeated on every Streamlit rerun.
_BLIP_MODEL = None


def _get_blip_model(image_size):
    """Build the pretrained BLIP caption decoder once and reuse it.

    PERF FIX: the original reloaded (and re-downloaded) the full checkpoint on
    every caption request; the model is now cached after the first call.
    """
    global _BLIP_MODEL
    if _BLIP_MODEL is None:
        model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
        model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
        model.eval()  # inference mode: disables dropout etc.
        _BLIP_MODEL = model.to(device)
    return _BLIP_MODEL


def generate_caption(image_url):
    """Return three nucleus-sampled captions for the image at *image_url*.

    Parameters
    ----------
    image_url : str
        HTTP(S) URL of the image to caption.

    Returns
    -------
    list[str]
        Three captions sampled with top-p (nucleus) sampling.
    """
    image_size = 384
    model = _get_blip_model(image_size)
    image = load_image(image_url, image_size, device)
    # Inference only: no autograd graph needed while sampling.
    with torch.no_grad():
        return [
            model.generate(image, sample=True, top_p=0.9,
                           max_length=20, min_length=5)[0]
            for _ in range(3)
        ]
38
+
39
# --- Streamlit UI: prompt for an image URL and show sampled captions. ------
st.title("Image Caption Generator")

image_url = st.text_input("Enter the image URL:")

# Only run the (expensive) captioner once the user has supplied a URL.
if image_url:
    for i, caption in enumerate(generate_caption(image_url)):
        st.write(f'caption {i+1}: {caption}')