# IADBE_Inference / app.py
# jinyao
# initialization
# cbf3611
import gradio as gr
from anomalib.engine import Engine
from pathlib import Path
# Import all possible model classes
from anomalib.models import (
Cfa,
Cflow,
Csflow,
Dfkde,
Dfm,
Draem,
Dsr,
EfficientAd,
Fastflow,
Fre,
Ganomaly,
Padim,
Patchcore,
ReverseDistillation,
Rkde,
Stfpm,
Uflow,
AiVad,
WinClip,
)
# Lookup table: checkpoint file-name prefix -> anomalib model class.
# Keys are matched exactly (case-sensitive) against the text before the
# first underscore of the uploaded checkpoint's file name.
model_mapping = dict(
    Cfa=Cfa,
    Cflow=Cflow,
    Csflow=Csflow,
    Dfkde=Dfkde,
    Dfm=Dfm,
    Draem=Draem,
    Dsr=Dsr,
    EfficientAd=EfficientAd,
    Fastflow=Fastflow,
    Fre=Fre,
    Ganomaly=Ganomaly,
    Padim=Padim,
    Patchcore=Patchcore,
    ReverseDistillation=ReverseDistillation,
    Rkde=Rkde,
    Stfpm=Stfpm,
    Uflow=Uflow,
    AiVad=AiVad,
    WinClip=WinClip,
)
def predict(image_path, model_path):
    """Run inference on one image with an uploaded model checkpoint.

    The model class is chosen from the checkpoint file name: the text before
    the first underscore of the file stem must be a key of ``model_mapping``
    (for example ``Padim_mvtec.ckpt`` selects ``Padim``).

    Args:
        image_path: Path of the image to run inference on.
        model_path: Path of the checkpoint file. An upload object that
            exposes the path through a ``.name`` attribute is accepted too.

    Returns:
        The expected location of the result image, as a string:
        ``results/<model type>/latest/images/<image file name>``.

    Raises:
        ValueError: If either input is missing, or if the checkpoint file
            name does not start with a known model type.
    """
    # Fail fast with a readable message; otherwise pressing the button before
    # both files are uploaded ends in an opaque TypeError from Path(None).
    if not image_path:
        raise ValueError("No image provided. Please upload an image.")
    if not model_path:
        raise ValueError("No model file provided. Please upload a model file.")

    # NOTE(review): depending on the Gradio version, gr.File may hand over a
    # file-like wrapper instead of a path string; its `.name` is taken to be
    # the path on disk -- confirm against the Gradio version in use.
    if not isinstance(model_path, (str, Path)):
        model_path = getattr(model_path, "name", model_path)

    # The first "_"-separated token of the file stem names the model type.
    model_type = Path(model_path).stem.split("_")[0]
    model_class = model_mapping.get(model_type)
    if model_class is None:
        raise ValueError(f"Unknown model type: {model_type}. Please ensure the model file name is correct.")

    # Build the engine only once the inputs are known to be usable, so an
    # invalid request does not pay for engine construction.
    engine = Engine(
        pixel_metrics="AUROC",
        accelerator="auto",
        devices=1,
        logger=False,
    )
    model = model_class()

    # Per-model output directory; created up front if it does not exist.
    result_dir = Path("results") / model_type / "latest" / "images"
    result_dir.mkdir(parents=True, exist_ok=True)

    engine.predict(
        data_path=image_path,
        model=model,
        ckpt_path=model_path,
    )

    # NOTE(review): assumes the engine saves its output under result_dir using
    # the input image's file name -- the file's existence is not checked here.
    return str(result_dir / Path(image_path).name)
def clear_inputs():
    """Return an empty value for each of the two input fields (image, model file)."""
    empty_image = None
    empty_model = None
    return empty_image, empty_model
# Build the Gradio UI: uploads and buttons in one column, result image in the other
with gr.Blocks() as demo:
    gr.Markdown("# Inference/Prediction")

    with gr.Row():
        # Input column (scale 1): image upload, model upload, action buttons
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Image", type="filepath")
            model_input = gr.File(label="Upload Model File")
            with gr.Row():
                predict_button = gr.Button(value="Run Inference")
                clear_button = gr.Button(value="Clear Inputs")

        # Output column (scale 3), given more room than the input column
        with gr.Column(scale=3):
            output_image = gr.Image(
                label="Output Image",
                elem_id="output_image",
                width="100%",
                height=600,  # fixed display height
            )

    # "Run Inference": pass both uploads to predict() and show what it returns
    predict_button.click(
        fn=predict,
        inputs=[image_input, model_input],
        outputs=output_image,
    )

    # "Clear Inputs": reset both upload fields
    clear_button.click(
        fn=clear_inputs,
        outputs=[image_input, model_input],
    )

# Launch the Gradio app
demo.launch()