Spaces:
Sleeping
Sleeping
import gradio as gr | |
from anomalib.engine import Engine | |
from pathlib import Path | |
# Import all possible model classes | |
from anomalib.models import ( | |
Cfa, | |
Cflow, | |
Csflow, | |
Dfkde, | |
Dfm, | |
Draem, | |
Dsr, | |
EfficientAd, | |
Fastflow, | |
Fre, | |
Ganomaly, | |
Padim, | |
Patchcore, | |
ReverseDistillation, | |
Rkde, | |
Stfpm, | |
Uflow, | |
AiVad, | |
WinClip, | |
) | |
# Registry of model-type names -> anomalib model classes.
# Keys are compared against the text before the first "_" in an uploaded
# checkpoint's filename stem (e.g. "Padim" for "Padim_bottle.ckpt").
model_mapping = dict(
    Cfa=Cfa,
    Cflow=Cflow,
    Csflow=Csflow,
    Dfkde=Dfkde,
    Dfm=Dfm,
    Draem=Draem,
    Dsr=Dsr,
    EfficientAd=EfficientAd,
    Fastflow=Fastflow,
    Fre=Fre,
    Ganomaly=Ganomaly,
    Padim=Padim,
    Patchcore=Patchcore,
    ReverseDistillation=ReverseDistillation,
    Rkde=Rkde,
    Stfpm=Stfpm,
    Uflow=Uflow,
    AiVad=AiVad,
    WinClip=WinClip,
)
def _resolve_model_class(model_type):
    """Map a checkpoint filename prefix to a registered anomalib model class.

    An exact key in ``model_mapping`` wins; otherwise a case-insensitive match
    is tried so that, e.g., ``padim_bottle.ckpt`` resolves to ``Padim``.

    Returns:
        A ``(name, model_class)`` tuple where ``name`` is the canonical key
        from ``model_mapping``.

    Raises:
        ValueError: If no registered model name matches ``model_type``.
    """
    model_class = model_mapping.get(model_type)
    if model_class is not None:
        return model_type, model_class
    wanted = model_type.casefold()
    for name, candidate in model_mapping.items():
        if name.casefold() == wanted:
            return name, candidate
    raise ValueError(f"Unknown model type: {model_type}. Please ensure the model file name is correct.")


def _find_result_image(model_name, image_filename, expected_dir):
    """Locate the result image written for ``image_filename``.

    Prefers ``expected_dir / image_filename``. If that file is absent, searches
    ``results/<model_name>`` recursively for a file with the same name and
    returns the most recently modified one. NOTE(review): the exact output
    sub-directory presumably depends on the installed anomalib version, which
    is why the search exists; confirm against the deployed release. Falls
    back to the expected path when nothing is found.
    """
    expected = expected_dir / image_filename
    if expected.is_file():
        return expected
    search_root = Path("results") / model_name
    # Compare names directly instead of using rglob(pattern) so filenames
    # containing glob metacharacters such as "[" or "*" still match.
    candidates = [
        p for p in search_root.rglob("*")
        if p.is_file() and p.name == image_filename
    ]
    if not candidates:
        return expected
    return max(candidates, key=lambda p: p.stat().st_mtime)


def predict(image_path, model_path):
    """Run anomalib inference on one image and return the result image path.

    The model class is chosen from the checkpoint filename: the text before
    the first ``_`` in the file stem (e.g. ``Padim_bottle.ckpt`` -> ``Padim``).

    Args:
        image_path: Filesystem path of the uploaded image (``gr.Image`` with
            ``type="filepath"``).
        model_path: The uploaded checkpoint from ``gr.File``. This is either a
            path string or, on older Gradio releases, a temp-file object
            exposing ``.name`` (assumption; confirm against the deployed
            Gradio version).

    Returns:
        String path of the result image under ``results/``.

    Raises:
        ValueError: If an input is missing or the model type is unknown.
    """
    if image_path is None:
        raise ValueError("Please upload an image before running inference.")
    if model_path is None:
        raise ValueError("Please upload a model file before running inference.")
    # Older Gradio versions pass a tempfile wrapper instead of a path string.
    if not isinstance(model_path, (str, Path)):
        model_path = model_path.name
    checkpoint = str(model_path)

    # Resolve the model before building the Engine so a bad filename fails fast.
    model_type = Path(checkpoint).stem.split("_")[0]
    model_name, model_class = _resolve_model_class(model_type)
    model = model_class()

    engine = Engine(
        pixel_metrics="AUROC",
        accelerator="auto",
        devices=1,
        logger=False,
    )

    image_filename = Path(image_path).name
    # Directory where this app expects the result images for this model type.
    result_dir = Path(f"results/{model_name}/latest/images")
    result_dir.mkdir(parents=True, exist_ok=True)

    engine.predict(
        data_path=image_path,
        model=model,
        ckpt_path=checkpoint,
    )
    return str(_find_result_image(model_name, image_filename, result_dir))
def clear_inputs():
    """Return empty values that reset the image and model-file upload widgets."""
    cleared = (None, None)
    return cleared
# Gradio interface: upload inputs and controls on the left, result image on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Inference/Prediction")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload Image", type="filepath")
            model_input = gr.File(label="Upload Model File")
            with gr.Row():
                predict_button = gr.Button("Run Inference")
                clear_button = gr.Button("Clear Inputs")
        # scale=3 vs scale=1 gives the output column three times the input column's width.
        with gr.Column(scale=3):
            output_image = gr.Image(label="Output Image", elem_id="output_image", width="100%", height=600)

    # Run inference on the uploaded image/checkpoint and show the result image.
    predict_button.click(
        predict,
        inputs=[image_input, model_input],
        outputs=output_image,
    )
    # Reset both upload widgets.
    clear_button.click(
        clear_inputs,
        outputs=[image_input, model_input],
    )

# Launch only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()