# Spaces status banner captured with this file: Sleeping
import gradio as gr | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
from transformers import T5Tokenizer, ViTFeatureExtractor | |
# --- Runtime setup ----------------------------------------------------------
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the saved captioning model from disk, mapping its storages onto
# `device`, and then move it there explicitly as well.
# NOTE(review): torch.load unpickles the checkpoint, so "model_vit_ai.pt" must
# come from a trusted source.
model = torch.load("model_vit_ai.pt", map_location=device)
model.to(device)

# Tokenizer used to decode generated ids, and the ViT feature extractor whose
# mean/std statistics drive the image normalisation below.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

# Image pipeline: resize to 224x224, convert to a tensor, then normalise with
# the feature extractor's per-channel mean and standard deviation.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            feature_extractor.image_mean,
            feature_extractor.image_std,
        ),
    ]
)
def preprocess_image(image):
    """Turn a NumPy image array into a batched input tensor for the model.

    Args:
        image: NumPy array of pixel values. It is cast to ``uint8`` before
            being handed to Pillow; 2-D grayscale, 3-channel RGB and
            4-channel RGBA layouts are accepted.

    Returns:
        The output of the module-level ``transform`` pipeline with a leading
        batch dimension of size 1 added via ``unsqueeze(0)``.
    """
    # Let Pillow infer the mode from the array's shape and convert afterwards.
    # Forcing mode "RGB" inside fromarray would treat the raw buffer as RGB
    # regardless of its real channel count, so grayscale or RGBA input would
    # be rejected or misread instead of converted.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    return transform(pil_image).unsqueeze(0)
def generate_caption(image, max_length=50):
    """Generate a caption for an image using greedy decoding.

    Args:
        image: NumPy image array, as passed in by the Gradio image input.
        max_length: Maximum number of tokens to generate after the decoder
            start token. Defaults to 50, the limit that was previously
            hard-coded in the loop.

    Returns:
        The decoded caption string with special tokens stripped.

    Raises:
        ValueError: If ``image`` is ``None`` (for example, the form was
            submitted without an image).
    """
    # Fail early with a readable message rather than an AttributeError raised
    # from inside the preprocessing step.
    if image is None:
        raise ValueError("No image provided; upload or select an image first.")

    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)

        # Decoding starts from a single start token for the one-image batch.
        decoder_input_ids = torch.full(
            (1, 1),
            model.decoder_start_token_id,
            dtype=torch.long,
            device=device,
        )

        for _ in range(max_length):
            outputs = model(images=image_tensor, decoder_ids=decoder_input_ids)

            # Greedy step: take the highest-scoring token at the last position.
            next_token_id = outputs.logits[:, -1, :].argmax(dim=1, keepdim=True)
            decoder_input_ids = torch.cat(
                [decoder_input_ids, next_token_id], dim=-1
            )

            # Stop once the end-of-sequence token has been produced.
            if torch.eq(next_token_id, tokenizer.eos_token_id).all():
                break

    return tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
# Example images offered in the UI; the paths are handed to Gradio as-is.
sample_images = [
    "sample_image1.jpg",
    "sample_image2.jpg",
    "sample_image3.jpg",
]

# Define Gradio interface.
# The image input uses the top-level `gr.Image` component instead of the
# legacy `gr.inputs.Image` namespace, which is deprecated and absent from
# newer Gradio releases. The version-specific `source`/`tool` arguments are
# left at the component's defaults so the call is not tied to one signature.
# NOTE(review): the label and description advertise webcam capture; whether
# that is offered depends on the installed Gradio version's default image
# sources — confirm against the pinned version.
interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="numpy", label="Upload an image or take a photo"),
    outputs="text",
    examples=sample_images,
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption.",
)

# Run the interface
interface.launch(debug=True)