Spaces:
Runtime error
Runtime error
File size: 4,370 Bytes
ce7fa25 fe51ab9 ce7fa25 fe51ab9 75db47e 0d081dc 75db47e fe51ab9 ce7fa25 297a61e fe51ab9 2e8ba25 fe51ab9 ce7fa25 385ab0d 31e18ef ce7fa25 404d6dc ce7fa25 75db47e fe51ab9 b07e0da 2e8ba25 de37095 ce7fa25 2fb60a4 ce7fa25 27ec8e6 ce7fa25 385ab0d ce7fa25 27ec8e6 ce7fa25 27ec8e6 ce7fa25 fe51ab9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from typing import Optional
import gradio as gr
from PIL import Image
from off_topic import OffTopicDetector, Translator
# Checkpoint identifiers handed to the project's model wrappers (presumably
# Hugging Face Hub model ids, judging by their "org/name" form).
TRANSLATION_MODEL = "Helsinki-NLP/opus-mt-roa-en"
CLIP_MODEL = "openai/clip-vit-base-patch32"

# Built once at import time and shared by validate_item / validate_images.
translator = Translator(TRANSLATION_MODEL)
# NOTE(review): the meaning of image_size="V" is defined in off_topic — confirm there.
detector = OffTopicDetector(CLIP_MODEL, image_size="V", translator=translator)
def validate_item(item_id: str, use_title: bool, threshold: float):
    """Look up an item by ID and split its pictures into valid / invalid galleries.

    Args:
        item_id: Item identifier typed into the UI; surrounding whitespace is ignored.
        use_title: Forwarded to ``detector.predict_probas_item`` as ``use_title``
            (the UI labels it "Use translated item title").
        threshold: An image whose entry in ``valid_probas`` is ``>= threshold`` goes
            to the valid gallery; every other image goes to the invalid gallery.

    Returns:
        ``(markdown_header, valid_images, invalid_images)`` for the domain / valid /
        invalid outputs of the item-ID tab.

    Raises:
        gr.Error: If the item ID or the threshold is empty, so Gradio shows a
            readable message instead of a stack trace.
    """
    item_id = (item_id or "").strip()
    if not item_id:
        raise gr.Error("Please provide an item ID.")
    # gr.Number may hand back None when the field is cleared (version-dependent).
    if threshold is None:
        raise gr.Error("Please provide a threshold.")

    # Only the images, the domain and the per-image valid probabilities are used here.
    images, domain, _probas, valid_probas, _invalid_probas = detector.predict_probas_item(
        item_id, use_title=use_title
    )

    valid_images, invalid_images = [], []
    for image, proba in zip(images, valid_probas):
        # One comparison per image puts every image in exactly one gallery; the
        # previous separate ">=" / "<" filters dropped images with a NaN probability.
        if proba.squeeze() >= threshold:
            valid_images.append(image)
        else:
            invalid_images.append(image)
    return f"## Domain: {domain}", valid_images, invalid_images
def validate_images(img_url_1, img_url_2, img_url_3, domain: str, title: str, threshold: float):
    """Classify up to three pictures, given by URL, as valid / invalid for a domain.

    Args:
        img_url_1, img_url_2, img_url_3: Picture URLs; blank entries are ignored.
        domain: Domain ID of the form ``"<SITE>-<DOMAIN_NAME>"``. The site part and a
            lower-cased, underscore-to-space version of the name part are passed to
            ``detector.predict_probas_url``.
        title: Optional item title; a blank title is passed as ``None``.
        threshold: An image whose entry in ``valid_probas`` is ``>= threshold`` goes
            to the valid gallery; every other image goes to the invalid gallery.

    Returns:
        ``(markdown_header, valid_images, invalid_images)`` for the domain / valid /
        invalid outputs of the image-URL tab.

    Raises:
        gr.Error: If no URL is given, the threshold is empty, or the domain ID is not
            ``SITE-DOMAIN_NAME`` shaped; Gradio shows the message to the user.
    """
    # Textboxes yield "" when left empty; whitespace-only entries count as empty too.
    img_urls = [
        url.strip()
        for url in (img_url_1, img_url_2, img_url_3)
        if url and url.strip()
    ]
    if not img_urls:
        raise gr.Error("Please provide at least one picture URL.")
    # gr.Number may hand back None when the field is cleared (version-dependent).
    if threshold is None:
        raise gr.Error("Please provide a threshold.")

    # partition() splits on the first hyphen only. The previous two-way unpack of
    # split("-") raised an opaque ValueError for a missing, empty or extra-hyphen ID.
    site, separator, domain = (domain or "").strip().partition("-")
    if not (separator and site and domain):
        raise gr.Error('Domain ID must have the form "SITE-DOMAIN_NAME".')
    domain_text = domain.replace("_", " ").lower()

    # Blank titles are passed to the detector as None, as before.
    title = (title or "").strip() or None

    images, output = detector.predict_probas_url(img_urls, domain_text, site, title)
    _probas, valid_probas, _invalid_probas = output

    valid_images, invalid_images = [], []
    for image, proba in zip(images, valid_probas):
        # One comparison per image puts every image in exactly one gallery; the
        # previous separate ">=" / "<" filters dropped images with a NaN probability.
        if proba.squeeze() >= threshold:
            valid_images.append(image)
        else:
            invalid_images.append(image)
    return f"## Domain: {domain}", valid_images, invalid_images
with gr.Blocks() as demo:
    gr.Markdown("""
# Off topic image detector
### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
Input an item ID or select one of the preloaded examples below.""")

    # Each tab uses its own prefixed component names. Previously the second tab rebound
    # `domain`, `threshold`, `submit`, `valid` and `invalid`, so any wiring placed after
    # it (e.g. the old commented-out examples) silently targeted the wrong tab.
    #
    # NOTE(review): Component.style() is deprecated in later Gradio 3.x releases and
    # removed in Gradio 4; if Gradio is upgraded, pass the layout options to the
    # gr.Gallery constructor instead (e.g. columns=..., height=...).

    # --- Tab 1: classify every picture of an item looked up by its ID. ---
    with gr.Tab("From item_id"):
        with gr.Row():
            item_id = gr.Textbox(label="Item ID")
            with gr.Column():
                use_title = gr.Checkbox(label="Use translated item title", value=True)
                item_threshold = gr.Number(label="Threshold", value=0.25, precision=2)
        item_submit = gr.Button("Submit")
        gr.HTML("<hr>")
        item_domain_output = gr.Markdown()
        item_valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        item_invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        item_outputs = [item_domain_output, item_valid, item_invalid]
        item_submit.click(
            inputs=[item_id, use_title, item_threshold],
            outputs=item_outputs,
            fn=validate_item,
        )
        gr.HTML("<hr>")
        # cache_examples=True makes Gradio run validate_item on each example when the
        # app starts, so startup depends on those calls succeeding.
        gr.Examples(
            examples=[
                ["MLC572974424", True, 0.25],
                ["MLU449951849", True, 0.25],
                ["MLA1293465558", True, 0.25],
                ["MLB3184663685", True, 0.25],
                ["MLC1392230619", True, 0.25],
                ["MCO546152796", True, 0.25],
            ],
            inputs=[item_id, use_title, item_threshold],
            outputs=item_outputs,
            fn=validate_item,
            cache_examples=True,
        )

    # --- Tab 2: classify up to three pictures given by URL against a domain ID. ---
    with gr.Tab("From image urls"):
        with gr.Row():
            with gr.Column():
                img_url_1 = gr.Textbox(label="Picture URL")
                img_url_2 = gr.Textbox(label="Picture URL")
                img_url_3 = gr.Textbox(label="Picture URL")
            with gr.Column():
                url_domain_id = gr.Textbox(label="Domain ID", placeholder="Required")
                url_title = gr.Textbox(label="Item title", placeholder="Optional")
                url_threshold = gr.Number(label="Threshold", value=0.25, precision=2)
        url_submit = gr.Button("Submit")
        gr.HTML("<hr>")
        url_domain_output = gr.Markdown()
        url_valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
        gr.HTML("<hr>")
        url_invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
        url_submit.click(
            inputs=[img_url_1, img_url_2, img_url_3, url_domain_id, url_title, url_threshold],
            outputs=[url_domain_output, url_valid, url_invalid],
            fn=validate_images,
        )
        gr.HTML("<hr>")
        # TODO: add preloaded examples for this tab (picture URLs + domain ID + threshold).

demo.launch()
|