from typing import Optional import gradio as gr from PIL import Image from off_topic import OffTopicDetector, Translator translator = Translator("Helsinki-NLP/opus-mt-roa-en") detector = OffTopicDetector("openai/clip-vit-base-patch32", image_size="V", translator=translator) def validate_item(item_id: str, use_title: bool, threshold: float): images, domain, probas, valid_probas, invalid_probas = detector.predict_probas_item(item_id, use_title=use_title) valid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() >= threshold] invalid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() < threshold] return f"## Domain: {domain}", valid_images, invalid_images def validate_images(img_url_1, img_url_2, img_url_3, domain: str, title: str, threshold: float): img_urls = [url for url in [img_url_1, img_url_2, img_url_3] if url != ""] site, domain = domain.split("-") domain_text = domain.replace("_", " ").lower() if title == "": title = None images, output = detector.predict_probas_url(img_urls, domain_text, site, title) probas, valid_probas, invalid_probas = output valid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() >= threshold] invalid_images = [x for i, x in enumerate(images) if valid_probas[i].squeeze() < threshold] return f"## Domain: {domain}", valid_images, invalid_images with gr.Blocks() as demo: gr.Markdown(""" # Off topic image detector ### This app takes an item ID and classifies its pictures as valid/invalid depending on whether they relate to the domain in which it's been listed. Input an item ID or select one of the preloaded examples below.""") with gr.Tab("From item_id"): with gr.Row(): item_id = gr.Textbox(label="Item ID") with gr.Column(): use_title = gr.Checkbox(label="Use translated item title", value=True) threshold = gr.Number(label="Threshold", value=0.25, precision=2) submit = gr.Button("Submit") gr.HTML("