import time

import gradio as gr
from transformers import BlipForConditionalGeneration, BlipProcessor


def get_image_captioning_tab():
    """Build a Gradio tab that generates image captions with two BLIP checkpoints."""
    salesforce_model_name = "Salesforce/blip-image-captioning-base"
    salesforce_model = BlipForConditionalGeneration.from_pretrained(salesforce_model_name)
    salesforce_processor = BlipProcessor.from_pretrained(salesforce_model_name)

    noamrot_model_name = "noamrot/FuseCap_Image_Captioning"
    noamrot_model = BlipForConditionalGeneration.from_pretrained(noamrot_model_name)
    noamrot_processor = BlipProcessor.from_pretrained(noamrot_model_name)

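    # Map each dropdown choice to its (model, processor) pair so the click
    # handler can look up the selected checkpoint by name.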
    model_map = {
        salesforce_model_name: (salesforce_model, salesforce_processor),
        noamrot_model_name: (noamrot_model, noamrot_processor)
    }

    def gradio_process(model_name, image, text):
        model, processor = model_map[model_name]
        start = time.time()
        # An empty textbox means unconditional captioning; otherwise the text
        # is used as a prefix that conditions the generated caption.
        if text:
            inputs = processor(image, text, return_tensors="pt")
        else:
            inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs)
        result = processor.decode(out[0], skip_special_tokens=True)
        time_spent = time.time() - start

        return [result, time_spent]

    with gr.Blocks() as image_captioning_tab:
        gr.Markdown("# Image Captioning")

        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Caption prefix (optional)")
                model_selector = gr.Dropdown(
                    [salesforce_model_name, noamrot_model_name],
                    value=salesforce_model_name,
                    label="Select Model",
                )

                # Process button
                process_btn = gr.Button("Generate caption")

            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Generated caption")

        # Connect the input components to the processing function
        process_btn.click(
            fn=gradio_process,
            inputs=[
                model_selector,
                input_image,
                input_text
            ],
            outputs=[output_text, elapsed_result]
        )

    return image_captioning_tab
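

if __name__ == "__main__":
    # Minimal standalone usage sketch (assumption: this module is normally
    # imported by a larger app, so the demo only launches when run directly).
    demo = get_image_captioning_tab()
    demo.launch()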