rxavier commited on
Commit
ce7fa25
1 Parent(s): 23c3e28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -22
app.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  import gradio as gr
 
2
 
3
  from off_topic import OffTopicDetector, Translator
4
 
@@ -7,38 +10,74 @@ translator = Translator("Helsinki-NLP/opus-mt-roa-en")
7
  detector = OffTopicDetector("openai/clip-vit-base-patch32", image_size="V", translator=translator)
8
 
9
 
10
- def validate(item_id: str, use_title: bool, threshold: float):
11
  images, domain, probas, valid_probas, invalid_probas = detector.predict_probas_item(item_id, use_title=use_title)
12
  valid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() >= threshold]
13
  invalid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() < threshold]
14
  return f"## Domain: {domain}", valid_images, invalid_images
15
 
 
 
 
 
 
 
 
 
 
16
 
17
  with gr.Blocks() as demo:
18
  gr.Markdown("""
19
  # Off topic image detector
20
  ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
21
  Input an item ID or select one of the preloaded examples below.""")
22
- with gr.Row():
23
- item_id = gr.Textbox(label="Item ID")
24
- with gr.Column():
25
- use_title = gr.Checkbox(label="Use translated item title", value=True)
26
- threshold = gr.Number(label="Threshold", value=0.25, precision=2)
27
- submit = gr.Button("Submit")
28
- gr.HTML("<hr>")
29
- domain = gr.Markdown()
30
- valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
31
- gr.HTML("<hr>")
32
- invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
33
- submit.click(inputs=[item_id, use_title, threshold], outputs=[domain, valid, invalid], fn=validate)
34
- gr.HTML("<hr>")
35
- gr.Examples(
36
- examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
37
- ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
38
- inputs=[item_id, use_title, threshold],
39
- outputs=[domain, valid, invalid],
40
- fn=validate,
41
- cache_examples=True,
42
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  demo.launch()
 
1
+ from typing import Optional
2
+
3
  import gradio as gr
4
+ from PIL import Image
5
 
6
  from off_topic import OffTopicDetector, Translator
7
 
 
10
  detector = OffTopicDetector("openai/clip-vit-base-patch32", image_size="V", translator=translator)
11
 
12
 
13
+ def validate_item(item_id: str, use_title: bool, threshold: float):
14
  images, domain, probas, valid_probas, invalid_probas = detector.predict_probas_item(item_id, use_title=use_title)
15
  valid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() >= threshold]
16
  invalid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() < threshold]
17
  return f"## Domain: {domain}", valid_images, invalid_images
18
 
19
+ def validate_images(img_url_1, img_url_2, img_url_3, domain: str, title: str, threshold: float):
20
+ img_urls = [url for url in [img_url_1, img_url_2, img_url_3] if url != ""]
21
+ if title == "":
22
+ title = None
23
+ images, domain, probas, valid_probas, invalid_probas = detector.predict_probas_url(img_urls, domain, title)
24
+ valid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() >= threshold]
25
+ invalid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() < threshold]
26
+ return f"## Domain: {domain}", valid_images, invalid_images
27
+
28
 
29
  with gr.Blocks() as demo:
30
  gr.Markdown("""
31
  # Off topic image detector
32
  ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
33
  Input an item ID or select one of the preloaded examples below.""")
34
+ with gr.Tab("From item_id"):
35
+ with gr.Row():
36
+ item_id = gr.Textbox(label="Item ID")
37
+ with gr.Column():
38
+ use_title = gr.Checkbox(label="Use translated item title", value=True)
39
+ threshold = gr.Number(label="Threshold", value=0.25, precision=2)
40
+ submit = gr.Button("Submit")
41
+ gr.HTML("<hr>")
42
+ domain = gr.Markdown()
43
+ valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
44
+ gr.HTML("<hr>")
45
+ invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
46
+ submit.click(inputs=[item_id, use_title, threshold], outputs=[domain, valid, invalid], fn=validate_item)
47
+ gr.HTML("<hr>")
48
+ gr.Examples(
49
+ examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
50
+ ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
51
+ inputs=[item_id, use_title, threshold],
52
+ outputs=[domain, valid, invalid],
53
+ fn=validate,
54
+ cache_examples=True,
55
+ )
56
+ with gr.Tab("From image urls"):
57
+ with gr.Row():
58
+ with gr.Column():
59
+ pic_url_1 = gr.Textbox(label="Picture URL")
60
+ pic_url_1 = gr.Textbox(label="Picture URL")
61
+ pic_url_1 = gr.Textbox(label="Picture URL")
62
+ with gr.Column():
63
+ domain = gr.Textbox(label="Domain name", placeholder="Required")
64
+ title = gr.Textbox(label="Item title", placeholder="Optional")
65
+ threshold = gr.Number(label="Threshold", value=0.25, precision=2)
66
+ submit = gr.Button("Submit")
67
+ gr.HTML("<hr>")
68
+ domain = gr.Markdown()
69
+ valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
70
+ gr.HTML("<hr>")
71
+ invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
72
+ submit.click(inputs=[pic_url_1, pic_url_2, pic_url_3, domain, title, threshold], outputs=[domain, valid, invalid], fn=validate_images)
73
+ gr.HTML("<hr>")
74
+ #gr.Examples(
75
+ # examples=[["MLC572974424", True, 0.25], ["MLU449951849", True, 0.25], ["MLA1293465558", True, 0.25],
76
+ # ["MLB3184663685", True, 0.25], ["MLC1392230619", True, 0.25], ["MCO546152796", True, 0.25]],
77
+ # inputs=[item_id, use_title, threshold],
78
+ # outputs=[domain, valid, invalid],
79
+ # fn=validate,
80
+ # cache_examples=True,
81
+ #)
82
 
83
  demo.launch()