File size: 3,286 Bytes
cbf3611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
from anomalib.engine import Engine
from pathlib import Path

# Import all possible model classes
from anomalib.models import (
    Cfa,
    Cflow,
    Csflow,
    Dfkde,
    Dfm,
    Draem,
    Dsr,
    EfficientAd,
    Fastflow,
    Fre,
    Ganomaly,
    Padim,
    Patchcore,
    ReverseDistillation,
    Rkde,
    Stfpm,
    Uflow,
    AiVad,
    WinClip,
)

# Lookup table from a checkpoint's filename prefix (the text before the first
# "_") to the anomalib model class that predict() instantiates for inference.
# Keyword-argument form keeps each key visibly identical to its class name.
model_mapping = dict(
    Cfa=Cfa,
    Cflow=Cflow,
    Csflow=Csflow,
    Dfkde=Dfkde,
    Dfm=Dfm,
    Draem=Draem,
    Dsr=Dsr,
    EfficientAd=EfficientAd,
    Fastflow=Fastflow,
    Fre=Fre,
    Ganomaly=Ganomaly,
    Padim=Padim,
    Patchcore=Patchcore,
    ReverseDistillation=ReverseDistillation,
    Rkde=Rkde,
    Stfpm=Stfpm,
    Uflow=Uflow,
    AiVad=AiVad,
    WinClip=WinClip,
)

# Define the inference function
def predict(image_path, model_path):
    """Run inference on one uploaded image with an uploaded anomalib checkpoint.

    The model class is chosen from the checkpoint's filename: the part of the
    stem before the first underscore (e.g. ``Padim_mvtec.ckpt`` -> ``Padim``)
    must name a key of ``model_mapping``. An exact match is tried first, then
    a case-insensitive one, so ``padim_mvtec.ckpt`` is accepted as well.

    Args:
        image_path: Path of the uploaded image (the ``gr.Image`` input uses
            ``type="filepath"``), or ``None`` if nothing was uploaded.
        model_path: The uploaded checkpoint from ``gr.File``. A path string is
            used as-is; a file-like wrapper exposing ``.name`` (what some
            Gradio versions pass) is unwrapped to its path. ``None`` if
            nothing was uploaded.

    Returns:
        The result image path,
        ``results/<model name>/latest/images/<image file name>``, as a string.

    Raises:
        gr.Error: If an input is missing, the filename prefix does not name a
            known model, or no result image exists at the expected path after
            inference. Unlike other exceptions, ``gr.Error`` messages are
            displayed to the user in the Gradio UI.
    """
    # Validate inputs first, so a missing upload fails fast with a readable
    # message instead of a TypeError from Path(None), and before paying for
    # Engine construction.
    if image_path is None:
        raise gr.Error("Please upload an image before running inference.")
    if model_path is None:
        raise gr.Error("Please upload a model file before running inference.")
    if not isinstance(model_path, (str, Path)):
        # File-like upload wrapper: its .name attribute is the on-disk path.
        model_path = getattr(model_path, "name", model_path)

    # The filename prefix before the first "_" selects the model class.
    model_type = Path(model_path).stem.split("_")[0]
    if model_type in model_mapping:
        model_name = model_type
    else:
        # Fall back to a case-insensitive match, resolving to the canonical
        # key so the results directory name stays consistent.
        wanted = model_type.lower()
        model_name = next(
            (name for name in model_mapping if name.lower() == wanted), None
        )
    if model_name is None:
        supported = ", ".join(model_mapping)
        raise gr.Error(
            f"Unknown model type: {model_type}. Please ensure the model file "
            f"name starts with one of: {supported} (e.g. Padim_model.ckpt)."
        )

    engine = Engine(
        pixel_metrics="AUROC",
        accelerator="auto",
        devices=1,
        logger=False,
    )
    model = model_mapping[model_name]()

    # NOTE(review): assumes the engine writes its visualization into this
    # directory under the image's original file name — confirm against the
    # installed anomalib version's visualizer/output settings.
    result_dir = Path("results") / model_name / "latest" / "images"
    result_dir.mkdir(parents=True, exist_ok=True)

    engine.predict(
        data_path=image_path,
        model=model,
        ckpt_path=model_path,
    )

    result_path = result_dir / Path(image_path).name
    # Surface a clear error instead of handing the UI a path that does not
    # exist. (A stale file from an earlier run with the same image name would
    # still pass this check.)
    if not result_path.is_file():
        raise gr.Error(
            f"Inference finished but no result image was found at {result_path}."
        )
    return str(result_path)


# Function to clear input fields
def clear_inputs():
    """Return one ``None`` per wired output (image, model file) to empty both inputs."""
    return (None,) * 2


# Define the Gradio interface.
# Components are placed in the order they are created inside the nested
# ``with`` contexts, so statement order below determines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Inference/Prediction")

    # Two columns side by side: inputs and controls on the left, result on the right.
    with gr.Row():
        # Left column (relative width 1): image and checkpoint uploads, then buttons.
        with gr.Column(scale=1):
            # type="filepath" makes predict() receive the image as a path string.
            image_input = gr.Image(label="Upload Image", type="filepath")
            # predict() derives the model class from this file's name prefix
            # (text before the first "_", e.g. "Padim_...").
            model_input = gr.File(label="Upload Model File")
            with gr.Row():
                predict_button = gr.Button("Run Inference")
                clear_button = gr.Button("Clear Inputs")

        with gr.Column(scale=3):  # 3x the relative width of the left column
            output_image = gr.Image(label="Output Image", elem_id="output_image", width="100%", height=600)  # Fixed 600px display height

    # Run predict() on the two uploads and show the returned image path in output_image.
    predict_button.click(
        predict,
        inputs=[image_input, model_input],
        outputs=output_image
    )

    # Reset both upload widgets; clear_inputs() returns one None per output listed here.
    clear_button.click(
        clear_inputs,
        outputs=[image_input, model_input]
    )

# Launch the Gradio app only when this file is executed as a script, so that
# importing the module (e.g. to reuse predict() or in tests) does not start a
# blocking web server as an import side effect.
if __name__ == "__main__":
    demo.launch()