jinyao commited on
Commit
cbf3611
·
1 Parent(s): 56135b4

initialization

Browse files
Files changed (2) hide show
  1. app.py +125 -0
  2. requirements.txt +257 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from anomalib.engine import Engine
3
+ from pathlib import Path
4
+
5
+ # Import all possible model classes
6
+ from anomalib.models import (
7
+ Cfa,
8
+ Cflow,
9
+ Csflow,
10
+ Dfkde,
11
+ Dfm,
12
+ Draem,
13
+ Dsr,
14
+ EfficientAd,
15
+ Fastflow,
16
+ Fre,
17
+ Ganomaly,
18
+ Padim,
19
+ Patchcore,
20
+ ReverseDistillation,
21
+ Rkde,
22
+ Stfpm,
23
+ Uflow,
24
+ AiVad,
25
+ WinClip,
26
+ )
27
+
28
+ # Mapping model filename prefixes to corresponding classes
29
+ model_mapping = {
30
+ "Cfa": Cfa,
31
+ "Cflow": Cflow,
32
+ "Csflow": Csflow,
33
+ "Dfkde": Dfkde,
34
+ "Dfm": Dfm,
35
+ "Draem": Draem,
36
+ "Dsr": Dsr,
37
+ "EfficientAd": EfficientAd,
38
+ "Fastflow": Fastflow,
39
+ "Fre": Fre,
40
+ "Ganomaly": Ganomaly,
41
+ "Padim": Padim,
42
+ "Patchcore": Patchcore,
43
+ "ReverseDistillation": ReverseDistillation,
44
+ "Rkde": Rkde,
45
+ "Stfpm": Stfpm,
46
+ "Uflow": Uflow,
47
+ "AiVad": AiVad,
48
+ "WinClip": WinClip,
49
+ }
50
+
51
+ # Define the inference function
52
+ def predict(image_path, model_path):
53
+ # Initialize the engine
54
+ engine = Engine(
55
+ pixel_metrics="AUROC",
56
+ accelerator="auto",
57
+ devices=1,
58
+ logger=False,
59
+ )
60
+
61
+ # Get the model filename prefix to determine the model type
62
+ model_filename = Path(model_path).stem # Get the filename without extension
63
+ model_type = model_filename.split("_")[0] # Use the first part of the filename as the model type
64
+
65
+ # Select the corresponding model class based on the filename
66
+ model_class = model_mapping.get(model_type)
67
+ if model_class is None:
68
+ raise ValueError(f"Unknown model type: {model_type}. Please ensure the model file name is correct.")
69
+
70
+ # Initialize the model
71
+ model = model_class()
72
+
73
+ # Get the image filename
74
+ image_filename = Path(image_path).name
75
+
76
+ # Dynamically set the result save path, replacing "Padim" with the extracted model type
77
+ result_dir = Path(f"results/{model_type}/latest/images")
78
+ result_dir.mkdir(parents=True, exist_ok=True) # Create directory if it doesn't exist
79
+
80
+ # Perform inference
81
+ engine.predict(
82
+ data_path=image_path,
83
+ model=model,
84
+ ckpt_path=model_path,
85
+ )
86
+
87
+ result_path = result_dir / image_filename
88
+ return str(result_path)
89
+
90
+
91
+ # Function to clear input fields
92
+ def clear_inputs():
93
+ return None, None
94
+
95
+
96
+ # Define the Gradio interface
97
+ with gr.Blocks() as demo:
98
+ gr.Markdown("# Inference/Prediction")
99
+
100
+ with gr.Row():
101
+ with gr.Column(scale=1):
102
+ image_input = gr.Image(label="Upload Image", type="filepath")
103
+ model_input = gr.File(label="Upload Model File")
104
+ with gr.Row():
105
+ predict_button = gr.Button("Run Inference")
106
+ clear_button = gr.Button("Clear Inputs")
107
+
108
+ with gr.Column(scale=3): # Increase the right column scale
109
+ output_image = gr.Image(label="Output Image", elem_id="output_image", width="100%", height=600) # Set height
110
+
111
+ # Click the inference button to run the prediction function and output the result
112
+ predict_button.click(
113
+ predict,
114
+ inputs=[image_input, model_input],
115
+ outputs=output_image
116
+ )
117
+
118
+ # Click the clear button to clear input fields
119
+ clear_button.click(
120
+ clear_inputs,
121
+ outputs=[image_input, model_input]
122
+ )
123
+
124
+ # Launch the Gradio app
125
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ about-time==4.2.1
2
+ absl-py==2.1.0
3
+ aiofiles==23.2.1
4
+ aiohttp==3.9.5
5
+ aiosignal==1.3.1
6
+ alive-progress==3.1.5
7
+ altair==5.3.0
8
+ annotated-types==0.7.0
9
+ anomalib==1.1.0
10
+ antlr4-python3-runtime==4.9.3
11
+ anyio==4.3.0
12
+ argon2-cffi==23.1.0
13
+ argon2-cffi-bindings==21.2.0
14
+ arrow==1.3.0
15
+ asttokens==2.4.1
16
+ async-lru==2.0.4
17
+ async-timeout==4.0.3
18
+ attrs==23.2.0
19
+ autograd==1.6.2
20
+ av==12.0.0
21
+ Babel==2.15.0
22
+ beautifulsoup4==4.12.3
23
+ bleach==6.1.0
24
+ cachetools==5.3.3
25
+ certifi==2024.2.2
26
+ cffi==1.16.0
27
+ cfgv==3.4.0
28
+ chardet==5.2.0
29
+ charset-normalizer==3.3.2
30
+ click==8.1.7
31
+ cma==3.2.2
32
+ colorama==0.4.6
33
+ comet-ml==3.42.1
34
+ comm==0.2.2
35
+ configobj==5.0.8
36
+ contourpy==1.2.1
37
+ coverage==7.5.1
38
+ cycler==0.12.1
39
+ debugpy==1.8.1
40
+ decorator==5.1.1
41
+ defusedxml==0.7.1
42
+ Deprecated==1.2.14
43
+ dill==0.3.8
44
+ distlib==0.3.8
45
+ dnspython==2.6.1
46
+ docker-pycreds==0.4.0
47
+ docstring_parser==0.16
48
+ dulwich==0.22.1
49
+ einops==0.8.0
50
+ email_validator==2.1.1
51
+ everett==3.1.0
52
+ exceptiongroup==1.2.1
53
+ execnet==2.1.1
54
+ executing==2.0.1
55
+ fastapi==0.111.0
56
+ fastapi-cli==0.0.4
57
+ fastjsonschema==2.19.1
58
+ ffmpy==0.3.2
59
+ filelock==3.14.0
60
+ fonttools==4.51.0
61
+ fqdn==1.5.1
62
+ FrEIA==0.2
63
+ frozenlist==1.4.1
64
+ fsspec==2024.5.0
65
+ ftfy==6.2.0
66
+ future==1.0.0
67
+ gitdb==4.0.11
68
+ GitPython==3.1.43
69
+ gradio==4.31.5
70
+ gradio_client==0.16.4
71
+ grapheme==0.6.0
72
+ grpcio==1.64.0
73
+ h11==0.14.0
74
+ httpcore==1.0.5
75
+ httptools==0.6.1
76
+ httpx==0.27.0
77
+ huggingface-hub==0.23.1
78
+ identify==2.5.36
79
+ idna==3.7
80
+ imageio==2.34.1
81
+ imgaug==0.4.0
82
+ importlib_resources==6.4.0
83
+ iniconfig==2.0.0
84
+ ipykernel==6.29.4
85
+ ipython==8.24.0
86
+ ipywidgets==8.1.2
87
+ isoduration==20.11.0
88
+ jedi==0.19.1
89
+ Jinja2==3.1.4
90
+ joblib==1.4.2
91
+ json5==0.9.25
92
+ jsonargparse==4.29.0
93
+ jsonpointer==2.4
94
+ jsonschema==4.22.0
95
+ jsonschema-specifications==2023.12.1
96
+ jstyleson==0.0.2
97
+ jupyter-events==0.10.0
98
+ jupyter-lsp==2.2.5
99
+ jupyter_client==8.6.2
100
+ jupyter_core==5.7.2
101
+ jupyter_server==2.14.0
102
+ jupyter_server_terminals==0.5.3
103
+ jupyterlab==4.2.1
104
+ jupyterlab_pygments==0.3.0
105
+ jupyterlab_server==2.27.2
106
+ jupyterlab_widgets==3.0.10
107
+ kiwisolver==1.4.5
108
+ kornia==0.6.9
109
+ lazy_loader==0.4
110
+ lightning==2.1.4
111
+ lightning-utilities==0.11.2
112
+ Markdown==3.6
113
+ markdown-it-py==3.0.0
114
+ MarkupSafe==2.1.5
115
+ matplotlib==3.9.0
116
+ matplotlib-inline==0.1.7
117
+ mdurl==0.1.2
118
+ mistune==3.0.2
119
+ mpmath==1.3.0
120
+ multidict==6.0.5
121
+ natsort==8.4.0
122
+ nbclient==0.10.0
123
+ nbconvert==7.16.4
124
+ nbformat==5.10.4
125
+ nest-asyncio==1.6.0
126
+ networkx==3.1
127
+ ninja==1.11.1.1
128
+ nncf==2.10.0
129
+ nodeenv==1.8.0
130
+ notebook==7.2.0
131
+ notebook_shim==0.2.4
132
+ numpy==1.26.4
133
+ omegaconf==2.3.0
134
+ onnx==1.16.1
135
+ open-clip-torch==2.24.0
136
+ opencv-python==4.9.0.80
137
+ openvino==2024.1.0
138
+ openvino-dev==2024.1.0
139
+ openvino-telemetry==2024.1.0
140
+ orjson==3.10.3
141
+ overrides==7.7.0
142
+ packaging==24.0
143
+ pandas==2.2.2
144
+ pandocfilters==1.5.1
145
+ parso==0.8.4
146
+ pathtools==0.1.2
147
+ pexpect==4.9.0
148
+ pillow==10.3.0
149
+ platformdirs==4.2.2
150
+ pluggy==1.5.0
151
+ pre-commit==3.7.1
152
+ prometheus_client==0.20.0
153
+ promise==2.3
154
+ prompt-toolkit==3.0.43
155
+ protobuf==3.20.3
156
+ psutil==5.9.8
157
+ ptyprocess==0.7.0
158
+ pure-eval==0.2.2
159
+ py-cpuinfo==9.0.0
160
+ pycparser==2.22
161
+ pydantic==2.7.1
162
+ pydantic_core==2.18.2
163
+ pydot==2.0.0
164
+ pydub==0.25.1
165
+ Pygments==2.18.0
166
+ pymoo==0.6.1.1
167
+ pyparsing==3.1.2
168
+ pyproject-api==1.6.1
169
+ pytest==8.2.1
170
+ pytest-cov==5.0.0
171
+ pytest-mock==3.14.0
172
+ pytest-sugar==1.0.0
173
+ pytest-xdist==3.6.1
174
+ python-box==6.1.0
175
+ python-dateutil==2.9.0.post0
176
+ python-dotenv==1.0.1
177
+ python-json-logger==2.0.7
178
+ python-multipart==0.0.9
179
+ pytorch-lightning==2.2.5
180
+ pytz==2024.1
181
+ PyYAML==6.0.1
182
+ pyzmq==26.0.3
183
+ referencing==0.35.1
184
+ regex==2024.5.15
185
+ requests==2.32.2
186
+ requests-toolbelt==1.0.0
187
+ rfc3339-validator==0.1.4
188
+ rfc3986-validator==0.1.1
189
+ rich==13.7.1
190
+ rich-argparse==1.4.0
191
+ rpds-py==0.18.1
192
+ ruff==0.4.5
193
+ scikit-image==0.23.2
194
+ scikit-learn==1.5.0
195
+ scipy==1.13.1
196
+ seaborn==0.13.2
197
+ semantic-version==2.10.0
198
+ Send2Trash==1.8.3
199
+ sentencepiece==0.2.0
200
+ sentry-sdk==2.3.1
201
+ setproctitle==1.3.3
202
+ shapely==2.0.4
203
+ shellingham==1.5.4
204
+ shortuuid==1.0.13
205
+ simplejson==3.19.2
206
+ six==1.16.0
207
+ smmap==5.0.1
208
+ sniffio==1.3.1
209
+ soupsieve==2.5
210
+ stack-data==0.6.3
211
+ starlette==0.37.2
212
+ sympy==1.12
213
+ tabulate==0.9.0
214
+ tensorboard==2.16.2
215
+ tensorboard-data-server==0.7.2
216
+ termcolor==2.4.0
217
+ terminado==0.18.1
218
+ threadpoolctl==3.5.0
219
+ tifffile==2024.5.22
220
+ timm==0.6.13
221
+ tinycss2==1.3.0
222
+ tomli==2.0.1
223
+ tomlkit==0.12.0
224
+ toolz==0.12.1
225
+ torch==2.1.2+cu121
226
+ torchmetrics==1.4.0.post0
227
+ torchvision==0.16.2+cu121
228
+ tornado==6.4
229
+ tox==4.15.0
230
+ tqdm==4.66.4
231
+ traitlets==5.14.3
232
+ triton==2.1.0
233
+ typer==0.12.3
234
+ types-python-dateutil==2.9.0.20240316
235
+ typeshed_client==2.5.1
236
+ typing_extensions==4.12.0
237
+ tzdata==2024.1
238
+ ujson==5.10.0
239
+ ultralytics==8.2.71
240
+ ultralytics-thop==2.0.0
241
+ uri-template==1.3.0
242
+ urllib3==2.2.1
243
+ uvicorn==0.29.0
244
+ uvloop==0.19.0
245
+ virtualenv==20.26.2
246
+ wandb==0.12.17
247
+ watchfiles==0.21.0
248
+ wcwidth==0.2.13
249
+ webcolors==1.13
250
+ webencodings==0.5.1
251
+ websocket-client==1.8.0
252
+ websockets==11.0.3
253
+ Werkzeug==3.0.3
254
+ widgetsnbextension==4.0.10
255
+ wrapt==1.16.0
256
+ wurlitzer==3.1.0
257
+ yarl==1.9.4