Spaces:
Sleeping
Sleeping
#imported all required libraries | |
import streamlit as st | |
import torch | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
from transformers import ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel | |
#used a pretrained model hosted on huggingface | |
loc = "ydshieh/vit-gpt2-coco-en" | |
feature_extractor = ViTFeatureExtractor.from_pretrained(loc) | |
tokenizer = AutoTokenizer.from_pretrained(loc) | |
model = VisionEncoderDecoderModel.from_pretrained(loc) | |
model.eval() | |
#defined a function for prediction | |
def predict(image): | |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values | |
with torch.no_grad(): | |
output_ids = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True).sequences | |
preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True) | |
preds = [pred.strip() for pred in preds] | |
return preds | |
#defined a function for Streamlit App | |
def app(): | |
st.title("ImaginateAI") | |
st.write("ViT and GPT2 are used to generate Image Caption for the uploaded image. COCO Dataset was used for training. This image captioning model might have some biases that I couldn’t figure during testing") | |
st.write("Upload an image or paste a URL to get predicted captions.") | |
upload_option = st.selectbox("Choose an option:", ("Upload Image", "Paste URL")) | |
if upload_option == "Upload Image": | |
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg"]) | |
if uploaded_file is not None: | |
image = Image.open(uploaded_file) | |
preds = predict(image) | |
st.image(image, caption="Uploaded Image", use_column_width=True) | |
st.write("Predicted Caption:", preds) | |
elif upload_option == "Paste URL": | |
image_url = st.text_input("Enter Image URL") | |
if st.button("Submit") and image_url: | |
try: | |
response = requests.get(image_url, stream=True) | |
image = Image.open(BytesIO(response.content)) | |
preds = predict(image) | |
st.image(image, caption="Image from URL", use_column_width=True) | |
st.write("Predicted Caption:", preds) | |
except: | |
st.write("Error: Invalid URL or unable to fetch image.") | |
if __name__ == "__main__": | |
app() |