NourFakih commited on
Commit
2f582b2
·
verified ·
1 Parent(s): ac70fb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -18
app.py CHANGED
@@ -1,35 +1,35 @@
1
  import streamlit as st
 
 
2
  from PIL import Image
3
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
4
- from datetime import datetime
5
- import pandas as pd
6
- import tempfile
7
- import base64
8
  import nltk
 
 
9
  import spacy
10
  from spacy.cli import download
11
- from streamlit_option_menu import option_menu
12
- import torch
13
-
14
- # Set device
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- # Download necessary NLTK and spaCy data
18
  nltk.download('wordnet')
19
  nltk.download('omw-1.4')
20
  download("en_core_web_sm")
21
-
22
- # Load the models
23
  nlp = spacy.load("en_core_web_sm")
24
- model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-115k-12"
 
 
25
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
26
  feature_extractor = ViTImageProcessor.from_pretrained(model_name)
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
28
  tokenizer.pad_token = tokenizer.eos_token
 
29
  model.config.eos_token_id = tokenizer.eos_token_id
30
  model.config.decoder_start_token_id = tokenizer.bos_token_id
31
  model.config.pad_token_id = tokenizer.pad_token_id
32
- image_captioner = pipeline('image-to-text', model=model_name)
33
 
34
  model_sum_name = "google-t5/t5-base"
35
  tokenizer_sum = AutoTokenizer.from_pretrained("google-t5/t5-base")
@@ -40,8 +40,10 @@ if 'captured_images' not in st.session_state:
40
  st.session_state.captured_images = []
41
 
42
  def generate_caption(image):
43
- caption = image_captioner(image)
44
- return caption[0]['generated_text']
 
 
45
 
46
  def get_synonyms(word):
47
  synonyms = set()
@@ -88,9 +90,9 @@ def page_webcam_capture():
88
 
89
  if img_file:
90
  img = Image.open(img_file)
91
- img_array = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
92
  caption = generate_caption(img)
93
- capture_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
94
  add_image_to_state(img_array, caption, capture_time)
95
  st.image(img, caption=f"Caption: {caption}")
96
 
 
1
  import streamlit as st
2
+ import cv2
3
+ import pandas as pd
4
  from PIL import Image
5
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
 
 
 
 
6
  import nltk
7
+ import tempfile
8
+ from nltk.corpus import wordnet
9
  import spacy
10
  from spacy.cli import download
11
+ import base64
12
+ import numpy as np
13
+ import datetime
 
 
14
 
15
+ # Download necessary NLP models
16
  nltk.download('wordnet')
17
  nltk.download('omw-1.4')
18
  download("en_core_web_sm")
 
 
19
  nlp = spacy.load("en_core_web_sm")
20
+
21
+ # Load the pre-trained models for image captioning and summarization
22
+ model_name = "NourFakih/Vit-GPT2-COCO2017Flickr-85k-09"
23
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
24
  feature_extractor = ViTImageProcessor.from_pretrained(model_name)
25
  tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+
27
+ # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
28
  tokenizer.pad_token = tokenizer.eos_token
29
+ # update the model config
30
  model.config.eos_token_id = tokenizer.eos_token_id
31
  model.config.decoder_start_token_id = tokenizer.bos_token_id
32
  model.config.pad_token_id = tokenizer.pad_token_id
 
33
 
34
  model_sum_name = "google-t5/t5-base"
35
  tokenizer_sum = AutoTokenizer.from_pretrained("google-t5/t5-base")
 
40
  st.session_state.captured_images = []
41
 
42
def generate_caption(image):
    """Generate a natural-language caption for a PIL image.

    Runs the module-level ViT feature extractor to obtain pixel tensors,
    lets the vision encoder-decoder model generate token ids, and decodes
    them with the module-level tokenizer.

    Args:
        image: a PIL.Image (or anything the feature extractor accepts).

    Returns:
        The generated caption as a plain string, special tokens stripped.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    generated_ids = model.generate(inputs.pixel_values)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
47
 
48
  def get_synonyms(word):
49
  synonyms = set()
 
90
 
91
  if img_file:
92
  img = Image.open(img_file)
93
+ img_array = np.array(img)
94
  caption = generate_caption(img)
95
+ capture_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
96
  add_image_to_state(img_array, caption, capture_time)
97
  st.image(img, caption=f"Caption: {caption}")
98