# cartoon-caption / app.py
# sonu
# Add application file
# eb3c81c
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Load the captioning stack. The processor (used for image preprocessing and
# for decoding generated token ids) comes from the base GIT checkpoint, while
# the weights come from the cartoon fine-tune.
# NOTE(review): presumably the fine-tune kept the base processor's settings
# unchanged — confirm against the model repository.
_PROCESSOR_CHECKPOINT = "microsoft/git-base"
_MODEL_CHECKPOINT = "sonukiller/git-base-cartoon"

processor = AutoProcessor.from_pretrained(_PROCESSOR_CHECKPOINT)

# Prefer a CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = AutoModelForCausalLM.from_pretrained(_MODEL_CHECKPOINT).to(device)
def generate_caption(
    image: "Image.Image | None",
    *,
    max_length: int = 50,
    num_beams: int = 4,
) -> str:
    """Generate a caption for ``image`` with the fine-tuned captioning model.

    Args:
        image: The uploaded picture as a PIL image (the Gradio input component
            is declared with ``type="pil"``), or ``None`` when the button is
            clicked before anything has been uploaded.
        max_length: Upper bound on the generated sequence length, passed
            straight through to ``model.generate``.
        num_beams: Beam-search width passed through to ``model.generate``.

    Returns:
        The decoded caption with surrounding whitespace removed, or a short
        message asking for an image when ``image`` is ``None``.
    """
    # Gradio hands us None when no image was uploaded; the processor would
    # reject that, so reply with a hint instead of surfacing an exception.
    if image is None:
        return "Please upload an image first."

    # Preprocess the image into tensors on the same device as the model.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )

    # A single image yields a single sequence; decode and return that row.
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption.strip()
# Build the Gradio UI: image input and trigger button on the left, the
# generated caption on the right, plus one example image below.
_HIDE_FOOTER_CSS = "footer {visibility: hidden}"

with gr.Blocks(title="Custom Image Captioning", css=_HIDE_FOOTER_CSS) as demo:
    gr.Markdown("# Custom Image Captioning Model")
    gr.Markdown("Upload an image and get a caption generated by a custom-trained model.")

    with gr.Row():
        # Left column: what the user provides.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            caption_button = gr.Button("Generate Caption")
        # Right column: what the model produces.
        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")

    # Clicking the button captions the currently uploaded image.
    caption_button.click(generate_caption, inputs=input_image, outputs=output_text)

    # cache_examples=True asks Gradio to precompute and store the output of
    # generate_caption for each example, so the model must already be loaded.
    gr.Examples(
        examples=["examples/example1.png"],
        inputs=input_image,
        outputs=output_text,
        fn=generate_caption,
        cache_examples=True,
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`). Importing the module — from tests or other
# tooling — then builds `demo` without launching a server as a side effect.
if __name__ == "__main__":
    demo.launch()