File size: 1,444 Bytes
fe51ab9
 
9c77dee
fe51ab9
 
 
 
 
 
 
fd9138f
fe51ab9
 
b07e0da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe51ab9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gradio as gr

from off_topic import OffTopicDetector

# Single detector instance created at import time and reused by every call to
# validate(). The argument looks like a CLIP checkpoint identifier — TODO
# confirm against OffTopicDetector's constructor.
detector = OffTopicDetector("openai/clip-vit-base-patch32")

def validate(item_id: str, threshold: float):
    """Split an item's pictures into valid and invalid groups.

    Asks the module-level ``detector`` for the item's predictions and compares
    each picture's "valid" score against ``threshold``.

    Args:
        item_id: Identifier forwarded unchanged to
            ``detector.predict_item_probas``.
        threshold: Cut-off value. A picture whose score is ``>= threshold`` is
            reported as valid; one whose score is ``< threshold`` as invalid.

    Returns:
        A 3-tuple ``(heading, valid_images, invalid_images)``: ``heading`` is
        an ``<h2>`` HTML snippet naming the domain, and the two lists hold the
        pictures in the order the detector returned them.
    """
    # Only the pictures, the domain and the per-picture "valid" scores are
    # used here; the other two returned values are ignored.
    pictures, domain, _, valid_scores, _ = detector.predict_item_probas(item_id)

    accepted = []
    rejected = []
    for position, picture in enumerate(pictures):
        # squeeze() presumably collapses a single-element tensor/array into a
        # scalar-like value — TODO confirm against the detector's output.
        score = valid_scores[position].squeeze()
        if score >= threshold:
            accepted.append(picture)
        # Deliberately a second independent test rather than `else`: a score
        # that satisfies neither comparison ends up in neither list.
        if score < threshold:
            rejected.append(picture)

    heading = f"<h2>Domain: {domain}</h2>"
    return heading, accepted, rejected

# Declarative UI: an item-ID textbox and a numeric threshold are the inputs to
# validate(); its three return values fill the domain heading and the two
# galleries. Components are created in the order they are declared here.
with gr.Blocks() as demo:
    gr.Markdown("""
                # Off topic image detector
                This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain where it belongs to.
                """)
    # Inputs, passed to validate() in this order: (item_id, threshold).
    item_id = gr.Textbox(label="Item ID")
    threshold = gr.Number(label="Threshold", value=0.5)
    submit = gr.Button("Submit")
    # Outputs, matching validate()'s return order:
    # domain heading, valid pictures, invalid pictures.
    domain = gr.Markdown()
    # NOTE(review): `.style(...)` availability depends on the installed Gradio
    # version — confirm it is still supported by the pinned release.
    valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
    invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
    # Clicking the button runs validate() with the current input values.
    submit.click(inputs=[item_id, threshold], outputs=[domain, valid, invalid], fn=validate)
    # Each example row is [item ID, threshold]. With cache_examples=True the
    # example outputs are presumably precomputed by calling validate() ahead
    # of time — TODO confirm when this happens for the Gradio version in use.
    gr.Examples(
        examples=[["MLU449951849", 0.5], ["MLA1293465558", 0.5]],
        inputs=[item_id, threshold],
        outputs=[domain, valid, invalid],
        fn=validate,
        cache_examples=True,
    )

# Launches the app unconditionally: there is no `__main__` guard, so this also
# runs if the module is imported.
demo.launch()