"""Gradio demo showing which of an item's images the off-topic detector accepts."""

import gradio as gr

from . import OffTopicDetector

# Cut-off applied to each image's "valid" probability. Used as the UI default
# and as the fallback when the threshold field is left empty.
DEFAULT_THRESHOLD = 0.5

detector = OffTopicDetector("openai/clip-vit-base-patch32")


def validate(item_id: str, threshold: float):
    """Split an item's images into valid and invalid lists.

    Args:
        item_id: Identifier forwarded to ``detector.predict_item_probas``.
        threshold: An image is valid when its squeezed "valid" probability is
            ``>= threshold``. ``None`` falls back to ``DEFAULT_THRESHOLD``.
            NOTE(review): Gradio may send ``None`` when the number field is
            cleared.

    Returns:
        A ``(valid_images, invalid_images)`` tuple of lists. Every image
        returned by the detector lands in exactly one of the two lists.
    """
    if threshold is None:
        threshold = DEFAULT_THRESHOLD
    # Only the images and their per-image "valid" probabilities are needed here.
    images, _domain, _probas, valid_probas, _invalid_probas = (
        detector.predict_item_probas(item_id)
    )
    valid_images, invalid_images = [], []
    for i, image in enumerate(images):
        # Use one comparison per image so the two lists form a partition.
        # A score that fails the test for any reason (including NaN) is
        # treated as invalid instead of being silently dropped.
        if valid_probas[i].squeeze() >= threshold:
            valid_images.append(image)
        else:
            invalid_images.append(image)
    return valid_images, invalid_images


with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("From Item ID"):
            item_id = gr.Textbox(label="Item ID")
            threshold = gr.Number(label="Threshold", value=DEFAULT_THRESHOLD)
            submit = gr.Button("Submit")
            # NOTE(review): `.style()` is the Gradio 3.x styling API. It was
            # deprecated in late 3.x and removed in Gradio 4, where the
            # equivalent is passing `columns=`/`height=` to `gr.Gallery(...)`.
            # Confirm the pinned Gradio version before changing it.
            valid = gr.Gallery(label="Valid images").style(grid=[1, 2, 3], height="auto")
            invalid = gr.Gallery(label="Invalid images").style(grid=[1, 2, 3], height="auto")
            submit.click(inputs=[item_id, threshold], outputs=[valid, invalid], fn=validate)

demo.launch()