File size: 4,370 Bytes
ce7fa25
 
fe51ab9
ce7fa25
fe51ab9
75db47e
 
 
0d081dc
75db47e
fe51ab9
 
ce7fa25
297a61e
fe51ab9
 
2e8ba25
fe51ab9
ce7fa25
 
385ab0d
31e18ef
ce7fa25
 
404d6dc
 
ce7fa25
 
 
 
75db47e
fe51ab9
b07e0da
 
2e8ba25
de37095
ce7fa25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fb60a4
ce7fa25
 
 
 
 
27ec8e6
 
 
ce7fa25
385ab0d
ce7fa25
 
 
 
27ec8e6
ce7fa25
 
 
27ec8e6
ce7fa25
 
 
 
 
 
 
 
 
fe51ab9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Optional

import gradio as gr
from PIL import Image

from off_topic import OffTopicDetector, Translator


# Model identifiers handed to the project helpers imported from `off_topic`.
_TRANSLATION_MODEL = "Helsinki-NLP/opus-mt-roa-en"
_CLIP_MODEL = "openai/clip-vit-base-patch32"

# Both objects are created once at import time and shared by every request
# handler defined below.
translator = Translator(_TRANSLATION_MODEL)
detector = OffTopicDetector(
    _CLIP_MODEL,
    image_size="V",  # NOTE(review): the meaning of "V" is defined by OffTopicDetector — confirm
    translator=translator,
)


def validate_item(item_id: str, use_title: bool, threshold: float):
    """Split the pictures of an item into valid and invalid ones.

    Args:
        item_id: Identifier forwarded to ``detector.predict_probas_item``.
        use_title: Forwarded to the detector as its ``use_title`` keyword.
        threshold: A picture is valid when its "valid" score is greater than
            or equal to this value, and invalid when the score is lower.

    Returns:
        A tuple ``(heading, valid_images, invalid_images)`` where ``heading``
        is a Markdown string naming the domain reported by the detector.
    """
    prediction = detector.predict_probas_item(item_id, use_title=use_title)
    # The detector returns five values; only the images, the domain and the
    # per-image "valid" scores are needed here.
    images, domain, _probas, valid_probas, _invalid_probas = prediction

    accepted, rejected = [], []
    for position, picture in enumerate(images):
        score = valid_probas[position].squeeze()
        if score >= threshold:
            accepted.append(picture)
        elif score < threshold:
            # Explicit second comparison: a score that compares false both
            # ways (e.g. NaN) ends up in neither list.
            rejected.append(picture)

    heading = f"## Domain: {domain}"
    return heading, accepted, rejected

def validate_images(img_url_1, img_url_2, img_url_3, domain: str, title: str, threshold: float):
    """Classify up to three pictures, given by URL, as valid or invalid for a domain.

    Args:
        img_url_1: Picture URL; ``None``, empty or whitespace-only values are skipped.
        img_url_2: Same as ``img_url_1``.
        img_url_3: Same as ``img_url_1``.
        domain: Domain ID in ``<SITE>-<DOMAIN>`` form. It is split at the first
            hyphen; the part after it has underscores replaced by spaces and is
            lower-cased before being passed to the detector.
        title: Optional item title. A missing or blank title is passed to the
            detector as ``None``.
        threshold: A picture is valid when its "valid" score is greater than or
            equal to this value, and invalid when the score is lower.

    Returns:
        A tuple ``(heading, valid_images, invalid_images)`` where ``heading`` is
        a Markdown string naming the domain part of ``domain``. When no usable
        URL is supplied both image lists are empty and the detector is not called.

    Raises:
        ValueError: If ``domain`` is not of the form ``<SITE>-<DOMAIN>``.
    """
    # partition() never fails to unpack, so a malformed ID can be reported with
    # a clear message instead of a "not enough values to unpack" error.
    site, hyphen, domain_name = (domain or "").strip().partition("-")
    if not (site and hyphen and domain_name):
        raise ValueError(
            f"Domain ID must look like '<SITE>-<DOMAIN>' (e.g. 'MLA-CELLPHONES'), got {domain!r}"
        )
    heading = f"## Domain: {domain_name}"

    img_urls = [
        url.strip()
        for url in (img_url_1, img_url_2, img_url_3)
        if url and url.strip()
    ]
    if not img_urls:
        # Nothing to classify: skip the detector rather than feed it an empty batch.
        return heading, [], []

    domain_text = domain_name.replace("_", " ").lower()
    if not title or not title.strip():
        title = None

    # The detector returns the images plus a triple of scores; only the
    # per-image "valid" scores are needed here.
    images, (_probas, valid_probas, _invalid_probas) = detector.predict_probas_url(
        img_urls, domain_text, site, title
    )

    valid_images, invalid_images = [], []
    for position, picture in enumerate(images):
        score = valid_probas[position].squeeze()
        if score >= threshold:
            valid_images.append(picture)
        elif score < threshold:
            # Explicit second comparison: a score that compares false both
            # ways (e.g. NaN) ends up in neither list, as before.
            invalid_images.append(picture)
    return heading, valid_images, invalid_images


# Build the Gradio UI: one tab classifies the pictures of an item ID, the other
# classifies pictures supplied directly as URLs.
with gr.Blocks() as demo:
    gr.Markdown("""
                # Off topic image detector
                ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
                Input an item ID or select one of the preloaded examples below.""")
    with gr.Tab("From item_id"):
        # Inputs: the item ID plus the two options forwarded to validate_item.
        with gr.Row():
            item_id = gr.Textbox(label="Item ID")
            with gr.Column():
                use_title = gr.Checkbox(label="Use translated item title", value=True)
                threshold = gr.Number(label="Threshold", value=0.25, precision=2)
            submit = gr.Button("Submit")
        gr.HTML("<hr>")
        # Outputs: a Markdown heading for the domain and one gallery per class.
        domain = gr.Markdown()
        valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        # validate_item returns (heading, valid images, invalid images), matching
        # the order of `outputs`.
        submit.click(inputs=[item_id, use_title, threshold], outputs=[domain, valid, invalid], fn=validate_item)
        gr.HTML("<hr>")
        # Preloaded item IDs, each paired with use_title=True and threshold=0.25.
        # NOTE(review): cache_examples=True presumably makes Gradio run
        # validate_item on every example ahead of time — confirm for the
        # installed Gradio version.
        gr.Examples(
            examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
                      ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
            inputs=[item_id, use_title, threshold],
            outputs=[domain, valid, invalid],
            fn=validate_item,
            cache_examples=True,
        )
    with gr.Tab("From image urls"):
        # NOTE: `domain`, `threshold`, `submit`, `valid` and `invalid` are rebound
        # in this tab. The first tab's click handler and examples were wired
        # above with the earlier component objects, so they are unaffected.
        with gr.Row():
            with gr.Column():
                img_url_1 = gr.Textbox(label="Picture URL")
                img_url_2 = gr.Textbox(label="Picture URL")
                img_url_3 = gr.Textbox(label="Picture URL")
            with gr.Column():
                # Here `domain` is the input textbox; validate_images splits its
                # value on "-" into a site and a domain part.
                domain = gr.Textbox(label="Domain ID", placeholder="Required")
                title = gr.Textbox(label="Item title", placeholder="Optional")
                threshold = gr.Number(label="Threshold", value=0.25, precision=2)
            submit = gr.Button("Submit")
        gr.HTML("<hr>")
        # Outputs: a Markdown heading for the domain and one gallery per class.
        domain_output = gr.Markdown()
        valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        # validate_images returns (heading, valid images, invalid images),
        # matching the order of `outputs`.
        submit.click(inputs=[img_url_1, img_url_2, img_url_3, domain, title, threshold], outputs=[domain_output, valid, invalid], fn=validate_images)
        gr.HTML("<hr>")
        # TODO: add preloaded examples for this tab. They must supply picture
        # URLs, a domain ID, a title and a threshold; the item-ID examples of the
        # first tab do not fit these inputs.

# Start the app as soon as the script is executed.
demo.launch()