File size: 4,247 Bytes
ce7fa25
 
fe51ab9
ce7fa25
fe51ab9
75db47e
 
 
0d081dc
75db47e
fe51ab9
 
ce7fa25
297a61e
fe51ab9
 
2e8ba25
fe51ab9
ce7fa25
 
 
 
 
 
 
 
 
75db47e
fe51ab9
b07e0da
 
2e8ba25
de37095
ce7fa25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fb60a4
ce7fa25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe51ab9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import Optional

import gradio as gr
from PIL import Image

from off_topic import OffTopicDetector, Translator


# Module-level singletons shared by both Gradio callbacks below; they are
# constructed once at import time.
# NOTE(review): the checkpoint name suggests a Romance-languages -> English
# translation model -- confirm against the `off_topic.Translator` implementation.
translator = Translator("Helsinki-NLP/opus-mt-roa-en")
# NOTE(review): the meaning of image_size="V" is defined inside
# `off_topic.OffTopicDetector` and is not visible here -- confirm before changing.
detector = OffTopicDetector("openai/clip-vit-base-patch32", image_size="V", translator=translator)


def validate_item(item_id: str, use_title: bool, threshold: float):
    """Split the pictures of an item into valid and invalid galleries.

    Runs ``detector.predict_probas_item`` for ``item_id`` and compares the
    entry of the returned valid-probabilities sequence for each image against
    ``threshold``.

    Args:
        item_id: Identifier forwarded to the detector.
        use_title: Forwarded to the detector as the ``use_title`` keyword.
        threshold: Cut-off applied to each image's valid probability.

    Returns:
        A 3-tuple ``(markdown_heading, valid_images, invalid_images)`` where
        the heading is ``"## Domain: <domain>"``, ``valid_images`` holds the
        images scoring ``>= threshold`` and ``invalid_images`` those scoring
        ``< threshold``.
    """
    prediction = detector.predict_probas_item(item_id, use_title=use_title)
    images, domain, _probas, valid_probas, _invalid_probas = prediction

    accepted = []
    rejected = []
    for idx, image in enumerate(images):
        score = valid_probas[idx].squeeze()
        # Two independent checks (not if/else) so an incomparable score such
        # as NaN ends up in neither gallery.
        if score >= threshold:
            accepted.append(image)
        if score < threshold:
            rejected.append(image)

    heading = f"## Domain: {domain}"
    return heading, accepted, rejected

def validate_images(img_url_1, img_url_2, img_url_3, domain: str, title: str, threshold: float):
    """Split up to three pictures given by URL into valid and invalid galleries.

    Empty-string URLs are dropped and an empty ``title`` is passed to the
    detector as ``None``. The remaining URLs go to
    ``detector.predict_probas_url`` and each image's valid probability is
    compared against ``threshold``.

    Args:
        img_url_1: First picture URL, or ``""`` to skip it.
        img_url_2: Second picture URL, or ``""`` to skip it.
        img_url_3: Third picture URL, or ``""`` to skip it.
        domain: Domain name forwarded to the detector.
        title: Item title; ``""`` means no title.
        threshold: Cut-off applied to each image's valid probability.

    Returns:
        A 3-tuple ``(markdown_heading, valid_images, invalid_images)`` where
        the heading is ``"## Domain: <domain>"`` built from the domain the
        detector returns, ``valid_images`` holds the images scoring
        ``>= threshold`` and ``invalid_images`` those scoring ``< threshold``.
    """
    img_urls = []
    for candidate in (img_url_1, img_url_2, img_url_3):
        if candidate != "":
            img_urls.append(candidate)

    title = None if title == "" else title

    prediction = detector.predict_probas_url(img_urls, domain, title)
    images, domain, _probas, valid_probas, _invalid_probas = prediction

    accepted = []
    rejected = []
    for idx, image in enumerate(images):
        score = valid_probas[idx].squeeze()
        # Two independent checks (not if/else) so an incomparable score such
        # as NaN ends up in neither gallery.
        if score >= threshold:
            accepted.append(image)
        if score < threshold:
            rejected.append(image)

    heading = f"## Domain: {domain}"
    return heading, accepted, rejected


# Build the two-tab UI. Each tab owns its own input and output components;
# component variables are deliberately NOT reused across tabs or between
# inputs and outputs, so every event handler is wired to the widgets it is
# meant to read from and write to.
with gr.Blocks() as demo:
    gr.Markdown("""
                # Off topic image detector
                ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
                Input an item ID or select one of the preloaded examples below.""")

    # ---- Tab 1: classify the pictures of an existing item by its ID ----
    with gr.Tab("From item_id"):
        with gr.Row():
            item_id = gr.Textbox(label="Item ID")
            with gr.Column():
                use_title = gr.Checkbox(label="Use translated item title", value=True)
                threshold = gr.Number(label="Threshold", value=0.25, precision=2)
            submit = gr.Button("Submit")
        gr.HTML("<hr>")
        domain = gr.Markdown()
        valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        submit.click(inputs=[item_id, use_title, threshold], outputs=[domain, valid, invalid], fn=validate_item)
        gr.HTML("<hr>")
        gr.Examples(
            examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
                      ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
            inputs=[item_id, use_title, threshold],
            outputs=[domain, valid, invalid],
            fn=validate_item,
            cache_examples=True,
        )

    # ---- Tab 2: classify up to three pictures supplied as URLs ----
    with gr.Tab("From image urls"):
        with gr.Row():
            with gr.Column():
                # Three distinct text boxes; validate_images skips the ones left empty.
                pic_url_1 = gr.Textbox(label="Picture URL")
                pic_url_2 = gr.Textbox(label="Picture URL")
                pic_url_3 = gr.Textbox(label="Picture URL")
            with gr.Column():
                # Named apart from the output Markdown below so the click
                # handler reads the domain typed by the user.
                url_domain_input = gr.Textbox(label="Domain name", placeholder="Required")
                url_title = gr.Textbox(label="Item title", placeholder="Optional")
                url_threshold = gr.Number(label="Threshold", value=0.25, precision=2)
            url_submit = gr.Button("Submit")
        gr.HTML("<hr>")
        url_domain_output = gr.Markdown()
        url_valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        url_invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        url_submit.click(
            inputs=[pic_url_1, pic_url_2, pic_url_3, url_domain_input, url_title, url_threshold],
            outputs=[url_domain_output, url_valid, url_invalid],
            fn=validate_images,
        )
        gr.HTML("<hr>")
        # TODO: add preloaded URL examples for this tab (gr.Examples wired to validate_images).

demo.launch()