File size: 2,252 Bytes
d02f182
 
 
 
 
e7db843
d02f182
 
 
 
 
 
 
 
 
 
 
 
 
 
e7db843
 
 
70fc45e
e7db843
 
 
 
 
 
d345da7
 
d02f182
48aa533
d02f182
 
48aa533
bdff02d
48aa533
 
 
 
d02f182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
from gradio_client import Client

# Clients for the three remote Hugging Face Spaces whose captioning models are
# compared. They are created once, at import time, and reused by `compare`.
# NOTE(review): gradio_client presumably contacts each Space while the Client is
# being constructed, so an unavailable Space may make the whole app fail at
# startup — confirm before relying on partial availability.
fuse_client = Client("https://noamrot-fusecap-image-captioning.hf.space/")  # FuseCap
clipi_client = Client("https://fffiloni-clip-interrogator-2.hf.space/")  # CLIP Interrogator 2
coca_client = Client("https://fffiloni-coca-clone.hf.space/")  # CoCa

def compare(image):
    """Caption one image with three remote models and return their captions.

    The three Spaces are queried concurrently: each request is started with
    the non-blocking ``Client.submit`` and only then awaited with
    ``Job.result()``. Total latency is therefore roughly that of the slowest
    Space instead of the sum of all three. (``Client.predict`` is simply
    ``submit(...).result()``, so each individual call behaves as before.)

    Args:
        image: Filepath of the uploaded image, as produced by the
            ``gr.Image(type="filepath")`` input, or ``None`` when the user
            clicks the button without uploading anything.

    Returns:
        A 3-tuple ``(clip_interrogator_caption, coca_caption, fusecap_caption)``
        in the same order as the output textboxes it feeds.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard a missing upload is forwarded as None to all three
    # Spaces and only surfaces as an opaque remote failure.
    if image is None:
        raise gr.Error("Please upload an image to caption.")

    # Start every remote job before waiting on any of them.
    ci_job = clipi_client.submit(
        image,   # filepath (or URL) of the image
        "best",  # 'Select mode' radio
        2,       # 'best mode max flavors' slider (2-24)
        api_name="/clipi2",
    )
    fuse_job = fuse_client.submit(
        image,  # 'raw_image' input
        api_name="/predict",
    )
    coca_job = coca_client.submit(
        image,               # image filepath
        "Nucleus sampling",  # decoding method: "Beam search" or "Nucleus sampling"
        1,                   # repeat penalty (1.0-5.0; larger prevents repetition)
        0.5,                 # top-p, used with nucleus sampling (0.0-1.0)
        5,                   # minimum sequence length
        20,                  # maximum sequence length (must exceed the minimum)
        api_name="/inference_caption",
    )

    ci_cap = ci_job.result()
    fuse_cap = fuse_job.result()
    coca_cap = coca_job.result()

    print(f"coca: {coca_cap}")

    # The Clip Interrogator endpoint returns a sequence; only its first
    # element is shown in the UI.
    return ci_cap[0], coca_cap, fuse_cap


# Stylesheet centering the main layout column (the gr.Column with
# elem_id="col-container") horizontally on the page.
# NOTE(review): 5810px is an unusually large max-width (effectively no cap);
# possibly a typo for a smaller value — confirm the intended width.
css = (
    "\n"
    "#col-container {max-width: 5810px; margin-left: auto; margin-right: auto;}\n"
)

with gr.Blocks(css=css) as demo:
    # Everything lives in one column styled by the "col-container" CSS rule.
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # Caption compare
        """)

        with gr.Row():
            # Left column: the image to caption and the button that triggers it.
            with gr.Column():
                image_in = gr.Image(label="Image to caption", type="filepath")
                submit_btn = gr.Button("Compare !")
            # Right column: one read-out per model, listed in the same order
            # as the values returned by `compare`.
            with gr.Column():
                clip_int_out = gr.Textbox(label="Clip Interrogator")
                coca_out = gr.Textbox(label="CoCa")
                fuse_out = gr.Textbox(label="Fuse Cap")

    # Clicking the button passes the image filepath to `compare` and fills the
    # three textboxes from its 3-tuple result.
    submit_btn.click(
        fn=compare,
        inputs=[image_in],
        outputs=[clip_int_out, coca_out, fuse_out],
    )

# Enable the request queue, then launch the app.
demo.queue().launch()