"""Gradio app: caption images with a custom fine-tuned GIT model."""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Load the processor (from the base checkpoint) and the fine-tuned
# captioning model weights.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("sonukiller/git-base-cartoon")

# Move model to GPU if available; switch to eval mode so dropout and
# other train-only layers are disabled during inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()


def generate_caption(image):
    """
    Generate a caption for the given image using the custom model.

    Args:
        image: a PIL image supplied by the Gradio Image component, or
            None when the user clicks the button without uploading one.

    Returns:
        str: the generated caption, or a prompt asking for an image.
    """
    # Guard: Gradio passes None if no image has been uploaded yet;
    # without this the processor call raises an exception.
    if image is None:
        return "Please upload an image first."

    # Preprocess the image into pixel values on the target device.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Beam-search decode; no gradients needed for inference.
    with torch.no_grad():
        generated_ids = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=50,
            num_beams=4,
            early_stopping=True,
        )

    # Decode token ids to text (batch of one -> take the first item).
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Build the Gradio interface.
with gr.Blocks(title="Custom Image Captioning", css="footer {visibility: hidden}") as demo:
    gr.Markdown("# Custom Image Captioning Model")
    gr.Markdown("Upload an image and get a caption generated by a custom-trained model.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            caption_button = gr.Button("Generate Caption")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Caption")

    caption_button.click(
        fn=generate_caption,
        inputs=input_image,
        outputs=output_text,
    )

    # NOTE(review): cache_examples=True runs generate_caption at startup;
    # it assumes examples/example1.png exists — confirm the file ships
    # with the app.
    gr.Examples(
        examples=[
            "examples/example1.png",
        ],
        inputs=input_image,
        outputs=output_text,
        fn=generate_caption,
        cache_examples=True,
    )

# Launch the app only when executed as a script, so importing this
# module (e.g. for testing) does not start a server.
if __name__ == "__main__":
    demo.launch()