"""Gradio front-end for running anomalib inference on a single uploaded image.

The user uploads an image and a model checkpoint. The checkpoint's filename
prefix (the text before the first ``_``) selects which anomalib model class is
used to load it, e.g. ``Padim_bottle.ckpt`` -> ``Padim``.
"""

import os
from pathlib import Path

import gradio as gr
from anomalib.engine import Engine

# Import all possible model classes
from anomalib.models import (
    Cfa,
    Cflow,
    Csflow,
    Dfkde,
    Dfm,
    Draem,
    Dsr,
    EfficientAd,
    Fastflow,
    Fre,
    Ganomaly,
    Padim,
    Patchcore,
    ReverseDistillation,
    Rkde,
    Stfpm,
    Uflow,
    AiVad,
    WinClip,
)

# Mapping model filename prefixes to corresponding classes.
# Keys are spelled exactly like the class names.
model_mapping = {
    "Cfa": Cfa,
    "Cflow": Cflow,
    "Csflow": Csflow,
    "Dfkde": Dfkde,
    "Dfm": Dfm,
    "Draem": Draem,
    "Dsr": Dsr,
    "EfficientAd": EfficientAd,
    "Fastflow": Fastflow,
    "Fre": Fre,
    "Ganomaly": Ganomaly,
    "Padim": Padim,
    "Patchcore": Patchcore,
    "ReverseDistillation": ReverseDistillation,
    "Rkde": Rkde,
    "Stfpm": Stfpm,
    "Uflow": Uflow,
    "AiVad": AiVad,
    "WinClip": WinClip,
}

# Lower-cased view of ``model_mapping`` keys, so "padim_x.ckpt" also resolves
# to the canonical "Padim" key.
_MODEL_NAMES_BY_LOWER = {name.lower(): name for name in model_mapping}


def _resolve_model_type(model_path):
    """Return the canonical ``model_mapping`` key named by a checkpoint path.

    The model type is the part of the file stem before the first ``_``.
    Exact matches win; otherwise the prefix is matched case-insensitively.

    Raises:
        ValueError: if the prefix does not name a supported model.
    """
    prefix = Path(model_path).stem.split("_")[0]
    if prefix in model_mapping:
        return prefix
    canonical = _MODEL_NAMES_BY_LOWER.get(prefix.lower())
    if canonical is None:
        supported = ", ".join(sorted(model_mapping))
        raise ValueError(
            f"Unknown model type: {prefix}. "
            "Please ensure the model file name is correct. "
            f"Expected the file name to start with one of: {supported}"
        )
    return canonical


def _find_result_image(result_dir, image_filename):
    """Return the path of the visualization saved for ``image_filename``.

    Prefers ``result_dir / image_filename``. Otherwise it searches
    ``result_dir`` recursively for a file with exactly that name and returns
    the most recently modified match.

    Raises:
        FileNotFoundError: if no such file exists under ``result_dir``.
    """
    expected = result_dir / image_filename
    if expected.is_file():
        return expected
    # NOTE(review): depending on the anomalib version, the visualizer may nest
    # its output one directory deeper than ``result_dir``; confirm the layout.
    # Names are compared exactly (not via glob) so characters like "[" in the
    # uploaded file name are not treated as patterns.
    candidates = [
        p for p in result_dir.rglob("*") if p.name == image_filename and p.is_file()
    ]
    if not candidates:
        raise FileNotFoundError(
            f"Inference finished but no result image named {image_filename!r} "
            f"was found under {result_dir}."
        )
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Define the inference function
def predict(image_path, model_path):
    """Run anomalib inference on one image and return the saved visualization.

    Args:
        image_path: Filesystem path of the uploaded image
            (``gr.Image(type="filepath")``), or ``None`` if nothing was uploaded.
        model_path: The uploaded checkpoint. Either a path string, or an object
            whose ``.name`` attribute holds the path (the form older Gradio
            ``gr.File`` versions pass), or ``None`` if nothing was uploaded.

    Returns:
        The result image path as a ``str``, for display in ``gr.Image``.

    Raises:
        ValueError: if an input is missing, or the checkpoint file name does
            not start with a supported model type.
        FileNotFoundError: if no result image can be found after inference.
    """
    if image_path is None or model_path is None:
        raise ValueError("Please upload both an image and a model file.")
    if not isinstance(model_path, (str, os.PathLike)):
        # Tempfile-style upload wrapper: the on-disk path lives in ``.name``.
        model_path = model_path.name

    # Validate the checkpoint name before building the (heavier) Engine.
    model_type = _resolve_model_type(model_path)
    model = model_mapping[model_type]()

    # NOTE(review): assumes anomalib writes visualizations under its default
    # "results" root, in a workspace named after the model class (which the
    # ``model_mapping`` keys mirror); confirm against the installed version.
    result_dir = Path("results") / model_type / "latest" / "images"
    result_dir.mkdir(parents=True, exist_ok=True)  # Create directory if it doesn't exist

    engine = Engine(
        pixel_metrics="AUROC",
        accelerator="auto",
        devices=1,
        logger=False,
    )
    engine.predict(
        data_path=image_path,
        model=model,
        ckpt_path=model_path,
    )

    return str(_find_result_image(result_dir, Path(image_path).name))


# Function to clear input fields
def clear_inputs():
    """Reset the image and model upload widgets."""
    return None, None


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Inference/Prediction")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Image", type="filepath")
            model_input = gr.File(label="Upload Model File")
            with gr.Row():
                predict_button = gr.Button("Run Inference")
                clear_button = gr.Button("Clear Inputs")

        with gr.Column(scale=3):  # Wider column for the result image
            output_image = gr.Image(
                label="Output Image",
                elem_id="output_image",
                width="100%",
                height=600,
            )

    # Click the inference button to run the prediction function and output the result
    predict_button.click(
        predict,
        inputs=[image_input, model_input],
        outputs=output_image,
    )

    # Click the clear button to clear input fields
    clear_button.click(
        clear_inputs,
        outputs=[image_input, model_input],
    )

# Launch the Gradio app
demo.launch()