# Gradio Space: image captioning demo using a fine-tuned GIT-base model.
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Load the processor (image transforms + tokenizer) from the base checkpoint
# and the causal-LM weights from the fine-tuned cartoon-captioning checkpoint.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("sonukiller/git-base-cartoon")

# Run inference on the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
def generate_caption(image):
    """
    Generate a caption for the given image using the fine-tuned GIT model.

    Args:
        image: PIL image supplied by the Gradio ``Image`` component.

    Returns:
        str: The decoded caption text.
    """
    # Preprocess the image into pixel tensors and move them to the model's device.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Beam-search decode; no gradients needed at inference time.
    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=50,
            num_beams=4,
            early_stopping=True,
        )

    # Decode the first (and only) generated sequence, dropping special tokens.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Create Gradio interface | |
with gr.Blocks(title="Custom Image Captioning", css="footer {visibility: hidden}") as demo: | |
gr.Markdown("# Custom Image Captioning Model") | |
gr.Markdown("Upload an image and get a caption generated by a custom-trained model.") | |
with gr.Row(): | |
with gr.Column(): | |
input_image = gr.Image(type="pil", label="Input Image") | |
caption_button = gr.Button("Generate Caption") | |
with gr.Column(): | |
output_text = gr.Textbox(label="Generated Caption") | |
caption_button.click( | |
fn=generate_caption, | |
inputs=input_image, | |
outputs=output_text | |
) | |
gr.Examples( | |
examples=[ | |
"examples/example1.png", | |
], | |
inputs=input_image, | |
outputs=output_text, | |
fn=generate_caption, | |
cache_examples=True, | |
) | |
# Launch the app | |
demo.launch() |