import gradio as gr
import torch
import open_clip
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    BlipForConditionalGeneration,
    Blip2ForConditionalGeneration,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the BLIP-2 model in 8-bit precision (requires bitsandbytes and a CUDA GPU)
preprocessor_blip2_8_bit = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
model_blip2_8_bit = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-6.7b", device_map="auto", load_in_8bit=True
)

# Load the BLIP base model
preprocessor_blip_base = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip_base = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load the BLIP large model
preprocessor_blip_large = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model_blip_large = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# Load the GIT large COCO model
preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

# Load the open_clip CoCa model
model_oc_coca, _, transform_oc_coca = open_clip.create_model_and_transforms(
    model_name="coca_ViT-L-14",
    pretrained="mscoco_finetuned_laion2B-s13B-b90k",
)

# Transfer the models to the device.
# The 8-bit BLIP-2 model is already dispatched by device_map="auto", so it is not moved again.
model_blip_base.to(device)
model_blip_large.to(device)
model_git_large_coco.to(device)
model_oc_coca.to(device)


def generate_caption(
    preprocessor,
    model,
    image,
    tokenizer=None,
    use_float_16=False,
    max_length=32,
):
    """
    Generate a caption for the given image with a Hugging Face captioning model.

    Parameters
    ----------
    preprocessor: AutoProcessor
        The preprocessor for the model.
    model: PreTrainedModel
        The captioning model to use (BLIP, BLIP-2, or GIT).
    image: PIL.Image
        The image to generate a caption for.
    tokenizer: AutoTokenizer
        The tokenizer used for decoding. If None, the preprocessor's tokenizer is used.
    use_float_16: bool
        Whether to cast the inputs to float16. This can speed up inference, but may lead to worse results.
    max_length: int
        The maximum length of the generated caption.

    Returns
    -------
    str
        The generated caption.
    """
    inputs = preprocessor(image, return_tensors="pt").to(device)

    if use_float_16:
        inputs = inputs.to(torch.float16)

    generated_ids = model.generate(
        pixel_values=inputs.pixel_values,
        # attention_mask=inputs.attention_mask,
        max_length=max_length,
        use_cache=True,
    )

    if tokenizer is None:
        generated_caption = preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    else:
        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_caption


def generate_captions_clip(model, transform, image, seq_len=32, temperature=0.9):
    """
    Generate a caption for the given image using the open_clip CoCa model.

    Parameters
    ----------
    model: open_clip CoCa model
        The CoCa captioning model to use.
    transform: Callable
        The transform to apply to the image before passing it to the model.
    image: PIL.Image
        The image to generate a caption for.
    seq_len: int
        The maximum length of the generated caption.
    temperature: float
        The sampling temperature.

    Returns
    -------
    str
        The generated caption.
    """
    img = transform(image).unsqueeze(0).to(device)

    with torch.no_grad(), torch.cuda.amp.autocast():
        generated = model.generate(
            img, seq_len=seq_len, temperature=temperature, generation_type="top_p"
        )

    # Decode the token ids and strip the start/end-of-text markers
    return (
        open_clip.decode(generated[0].detach())
        .split("<end_of_text>")[0]
        .replace("<start_of_text>", "")
    )


def generate_captions(image, max_length, temperature):
    """
    Generate captions for the given image with every loaded model.

    Parameters
    ----------
    image: PIL.Image
        The image to generate captions for.
    max_length: int
        The maximum caption length.
    temperature: float
        The sampling temperature (used by the CoCa model).

    Returns
    -------
    tuple of str
        The caption generated by each model.
    """
    max_length = int(max_length)

    # Generate a caption with the BLIP-2 8-bit model
    caption_blip2_8_bit = generate_caption(
        preprocessor_blip2_8_bit, model_blip2_8_bit, image, use_float_16=True, max_length=max_length
    ).strip()

    # Generate a caption with the BLIP base model
    caption_blip_base = generate_caption(
        preprocessor_blip_base, model_blip_base, image, max_length=max_length
    ).strip()

    # Generate a caption with the BLIP large model
    caption_blip_large = generate_caption(
        preprocessor_blip_large, model_blip_large, image, max_length=max_length
    ).strip()

    # Generate a caption with the GIT large COCO model
    caption_git_large_coco = generate_caption(
        preprocessor_git_large_coco, model_git_large_coco, image, max_length=max_length
    ).strip()

    # Generate a caption with the open_clip CoCa model
    caption_oc_coca = generate_captions_clip(
        model_oc_coca, transform_oc_coca, image, seq_len=max_length, temperature=temperature
    ).strip()

    return caption_blip2_8_bit, caption_blip_base, caption_blip_large, caption_git_large_coco, caption_oc_coca


# Create the interface
iface = gr.Interface(
    fn=generate_captions,
    # Inputs: the image, a slider for max caption length, and a slider for the sampling temperature
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Slider(minimum=16, maximum=64, step=2, value=32, label="Max Length"),
        gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Temperature"),
    ],
    # Outputs: one textbox per model
    outputs=[
        gr.Textbox(label="BLIP-2 8-bit"),
        gr.Textbox(label="BLIP base"),
        gr.Textbox(label="BLIP large"),
        gr.Textbox(label="GIT large COCO"),
        gr.Textbox(label="CoCa (open_clip)"),
    ],
    title="Image Captioning",
    description=(
        "Generate captions for images using the BLIP-2 8-bit model, the BLIP base model, "
        "the BLIP large model, the GIT large COCO model, and the open_clip CoCa model."
    ),
)

# Launch the interface with request queueing enabled
iface.queue().launch()