import gradio as gr
from off_topic import OffTopicDetector
# Shared detector instance, created once when the module is imported and
# used by validate() for every request.
# NOTE(review): the argument looks like a CLIP checkpoint identifier --
# confirm what OffTopicDetector expects here.
detector = OffTopicDetector("openai/clip-vit-base-patch32")
def validate(item_id: str, threshold: float):
    """Split an item's images into valid and invalid lists by threshold.

    Calls ``detector.predict_probas_item(item_id)`` and compares each
    image's entry in the returned "valid" probabilities against
    ``threshold``.

    Args:
        item_id: Identifier of the item whose images are classified.
        threshold: Minimum "valid" probability for an image to be
            treated as valid (inclusive).

    Returns:
        A 3-tuple ``(domain_text, valid_images, invalid_images)`` where
        ``domain_text`` is a string naming the item's domain and the two
        lists hold the images at/above and below the threshold.
    """
    # Only the images, the domain and the "valid" probabilities are used;
    # the other two returned values are intentionally ignored.
    images, domain, _probas, valid_probas, _invalid_probas = (
        detector.predict_probas_item(item_id)
    )

    valid_images = []
    invalid_images = []
    for index, image in enumerate(images):
        # Squeeze once per image instead of once per comparison.
        score = valid_probas[index].squeeze()
        if score >= threshold:
            valid_images.append(image)
        elif score < threshold:
            # elif rather than else: a score that satisfies neither
            # comparison (e.g. NaN) stays out of both lists.
            invalid_images.append(image)

    # Single-line literal with explicit newlines; a plain "..." string
    # cannot span physical lines.
    return f"\nDomain: {domain}\n", valid_images, invalid_images
# Build the Gradio UI: inputs, a submit button, two result galleries and a
# set of preloaded examples, all wired to validate().
with gr.Blocks() as demo:
    # The Markdown text is kept at column 0 on purpose so that its
    # rendering is not affected by Python indentation.
    gr.Markdown("""
# Off topic image detector
## This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed.
Input an item ID or select one of the preloaded examples below.""")

    item_id = gr.Textbox(label="Item ID")
    threshold = gr.Number(label="Threshold", value=0.5)
    submit = gr.Button("Submit")

    # NOTE(review): the HTML spacers below contain only a newline; confirm
    # whether some markup (e.g. a separator tag) was intended here.
    gr.HTML("\n")
    domain = gr.Markdown()
    valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
    gr.HTML("\n")
    invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")

    # validate() receives (item_id, threshold) and its three return values
    # populate the domain text and the two galleries, in that order.
    submit.click(
        inputs=[item_id, threshold],
        outputs=[domain, valid, invalid],
        fn=validate,
    )
    gr.HTML("\n")

    # Preloaded (item ID, threshold) pairs; cache_examples=True asks Gradio
    # to cache the outputs of validate() for these inputs.
    gr.Examples(
        examples=[
            ["MLU449951849", 0.3],
            ["MLA1293465558", 0.3],
            ["MLB3184663685", 0.3],
            ["MLC1392230619", 0.3],
        ],
        inputs=[item_id, threshold],
        outputs=[domain, valid, invalid],
        fn=validate,
        cache_examples=True,
    )

demo.launch()