# Author: nishantguvvada
# Commit da868a5: "Update app.py"
# (Hugging Face file view: raw | history | blame — 1.63 kB)
import streamlit as st
import tensorflow as tf
import numpy as np
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
_MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"


@st.cache_resource
def _load_captioning_components():
    """Load the captioning model, image processor and tokenizer.

    Streamlit re-executes this whole script on every user interaction; the
    ``st.cache_resource`` decorator keeps one copy of the pretrained
    components instead of reloading the weights on each rerun.

    Returns:
        A ``(model, feature_extractor, tokenizer, device)`` tuple. The model
        has already been moved to ``device`` (CUDA when available, else CPU).
    """
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vision_model = VisionEncoderDecoderModel.from_pretrained(_MODEL_NAME)
    vision_model.to(torch_device)
    processor = ViTImageProcessor.from_pretrained(_MODEL_NAME)
    text_tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    return vision_model, processor, text_tokenizer, torch_device


# Module-level names kept so the rest of the script can use them unchanged.
model, feature_extractor, tokenizer, device = _load_captioning_components()
st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")

# Open the banner inside ``with`` so its file handle is released on every
# rerun; st.image is called while the file is still open.
with Image.open('./title.jpg') as image:
    st.image(image)

st.write("""
# Multi-Modal Machine Learning
"""
)

# "jpeg" is listed explicitly so files saved with that extension are accepted
# alongside "jpg", in case the installed Streamlit does not treat them as
# equivalent.
file = st.file_uploader("Upload an image to generate captions!", type=['png', 'jpg', 'jpeg'])

# Decoding settings forwarded to model.generate(): the generated sequence is
# capped at `max_length` tokens and produced by beam search with `num_beams`.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
def predict_step(image_paths):
    """Generate one caption per input image.

    Args:
        image_paths: An iterable of image sources accepted by
            ``PIL.Image.open`` (file paths or binary file-like objects such as
            Streamlit's ``UploadedFile``). A single path or file-like object is
            also accepted and treated as a one-element batch.

    Returns:
        A list of caption strings (whitespace-stripped), in input order. An
        empty input yields an empty list without invoking the model.
    """
    import os

    # Iterating a str yields characters and iterating a file object yields its
    # raw lines, so a single source must be wrapped into a batch of one.
    if hasattr(image_paths, "read") or isinstance(image_paths, (str, bytes, os.PathLike)):
        image_paths = [image_paths]

    images = []
    for source in image_paths:
        # ``with`` releases handles PIL opened from paths; convert() returns a
        # fully loaded RGB copy (a plain copy when the image is already RGB).
        with Image.open(source) as opened:
            images.append(opened.convert(mode="RGB"))

    if not images:
        return []

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]
def on_click():
    """Caption the uploaded image when the "Generate" button is pressed.

    Shows a prompt when nothing has been uploaded yet; otherwise runs
    ``predict_step`` on the uploaded file and writes each caption to the page.
    """
    if file is None:
        st.text("Please upload an image file")
        return
    # predict_step takes a batch; passing the bare UploadedFile would make it
    # iterate the file's byte lines instead of treating it as one image.
    for caption in predict_step([file]):
        st.write(caption)


st.button('Generate', on_click=on_click)