diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..aac877118c1a947e5c1a3b0d4c12f0ba80f2b39b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,35 +1,5 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.txt.gz filter=lfs diff=lfs merge=lfs -text
+weights/pretrained_all.pth filter=lfs diff=lfs merge=lfs -text
+weights/pretrained_mvtec_colondb.pth filter=lfs diff=lfs merge=lfs -text
+weights/pretrained_visa_clinicdb.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ef40b4a9c445473cd6cf6409040ff13eebd2c948
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+/result/
+/.idea/
+/__pycache__/
+/weights/
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e32c0f629673371119307f7620dca662918c5eaa
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Yunkang Cao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 823edc95f950a3a64f4ce6e47055174986499e35..c102e1922c5ff21c1fda70d02a5b9fdd58e5e558 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,176 @@
----
-title: AdaCLIP
-emoji: 🌖
-colorFrom: pink
-colorTo: pink
-sdk: gradio
-sdk_version: 4.38.1
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# AdaCLIP (Detecting Anomalies for Novel Categories)
+[![HuggingFace Space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)]()
+
+> [**ECCV 24**] [**AdaCLIP: Adapting CLIP with Hybrid Learnable Prompts for Zero-Shot Anomaly Detection**]().
+>
+> by [Yunkang Cao](https://caoyunkang.github.io/), [Jiangning Zhang](https://zhangzjn.github.io/),  [Luca Frittoli](https://scholar.google.com/citations?user=cdML_XUAAAAJ), 
+> [Yuqi Cheng](https://scholar.google.com/citations?user=02BC-WgAAAAJ&hl=en), [Weiming Shen](https://scholar.google.com/citations?user=FuSHsx4AAAAJ&hl=en), [Giacomo Boracchi](https://boracchi.faculty.polimi.it/) 
+> 
+
+## Introduction 
+Zero-shot anomaly detection (ZSAD) targets the identification of anomalies within images from arbitrary novel categories. 
+This study introduces AdaCLIP for the ZSAD task, leveraging a pre-trained vision-language model (VLM), CLIP. 
+AdaCLIP incorporates learnable prompts into CLIP and optimizes them through training on auxiliary annotated anomaly detection data. 
+Two types of learnable prompts are proposed: *static* and *dynamic*. Static prompts are shared across all images, serving to preliminarily adapt CLIP for ZSAD. 
+In contrast, dynamic prompts are generated for each test image, providing CLIP with dynamic adaptation capabilities. 
+The combination of static and dynamic prompts, referred to as hybrid prompts, yields enhanced ZSAD performance. 
+Extensive experiments conducted across 14 real-world anomaly detection datasets from industrial and medical domains indicate that AdaCLIP outperforms other ZSAD methods and can generalize better to different categories and even domains. 
+Finally, our analysis highlights the importance of diverse auxiliary data and optimized prompts for enhanced generalization capacity.
+
+## Overview of AdaCLIP
+![overview](asset/framework.png)
+
+## 🛠️ Getting Started
+
+### Installation
+To set up the AdaCLIP environment, follow one of the methods below:
+
+- Clone this repo:
+  ```shell
+  git clone https://github.com/caoyunkang/AdaCLIP.git && cd AdaCLIP
+  ```
+- You can use our provided installation script for an automated setup:
+  ```shell
+  sh install.sh
+  ```
+- If you prefer to construct the experimental environment manually, follow these steps:
+  ```shell
+  conda create -n AdaCLIP python=3.9.5 -y
+  conda activate AdaCLIP
+  pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+  pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4
+  pip install gradio # Optional, for app 
+  ```
+- Remember to update `DATA_ROOT` in `config.py` so that it points to your dataset directory:
+  ```python
+  DATA_ROOT = '../datasets' # Original setting
+  ```
+
+### Dataset Preparation 
+Please download our processed visual anomaly detection datasets to your `DATA_ROOT` as needed. 
+
+#### Industrial Visual Anomaly Detection Datasets
+Note: some download links are still being prepared...
+
+| Dataset | Google Drive | Baidu Drive | Task |
+|------------|------------------|------------------| ------------------|
+| MVTec AD    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
+| VisA    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
+| MPDD    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
+| BTAD    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
+| KSDD    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
+| DAGM    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
+| DTD-Synthetic    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
+
+
+
+
+#### Medical Visual Anomaly Detection Datasets
+| Dataset | Google Drive | Baidu Drive | Task |
+|------------|------------------|------------------|  ------------------|
+| HeadCT    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection |
+| BrainMRI    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection |
+| Br35H    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection |
+| ISIC    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |
+| ColonDB    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |
+| ClinicDB    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |
+| TN3K    | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |
+
+#### Custom Datasets
+To use your custom dataset, follow these steps:
+
+1. Refer to the instructions in `./data_preprocess` to generate the JSON file for your dataset.
+2. Use `./dataset/base_dataset.py` to construct your own dataset.
+
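+For reference, step 1 produces a `meta.json` with the same structure as the scripts in `./data_preprocess`
+(see e.g. `data_preprocess/br35h.py`). Below is a minimal sketch, assuming a hypothetical `MyDataset` laid out as
+`<class>/<train|test>/<good|defect-type>/<image>` with masks under `<class>/ground_truth/<defect-type>/`:
+
+```python
+import json
+import os
+
+DATASET_ROOT = '../datasets/MyDataset'  # hypothetical custom dataset root
+CLSNAMES = ['my_object']                # one entry per object category
+
+info = dict(train={}, test={})
+for cls_name in CLSNAMES:
+    cls_dir = os.path.join(DATASET_ROOT, cls_name)
+    for phase in ['train', 'test']:
+        samples = []
+        for specie in sorted(os.listdir(os.path.join(cls_dir, phase))):
+            is_abnormal = specie != 'good'
+            for img_name in sorted(os.listdir(os.path.join(cls_dir, phase, specie))):
+                samples.append(dict(
+                    img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                    # assumes each mask shares its image's filename; adapt to your layout
+                    mask_path=f'{cls_name}/ground_truth/{specie}/{img_name}' if is_abnormal else '',
+                    cls_name=cls_name,
+                    specie_name=specie,
+                    anomaly=1 if is_abnormal else 0,
+                ))
+        info[phase][cls_name] = samples
+
+with open(os.path.join(DATASET_ROOT, 'meta.json'), 'w') as f:
+    json.dump(info, f, indent=4)
+```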
+
+### Weight Preparation
+
+We offer pre-trained weights obtained on different auxiliary datasets. 
+Please download the desired weights to `./weights`.
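+After downloading, `./weights` should contain the checkpoints referenced elsewhere in this repo, e.g. `pretrained_mvtec_colondb.pth`, `pretrained_visa_clinicdb.pth`, and `pretrained_all.pth`.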
+
+| Pre-trained Datasets | Google Drive | Baidu Drive |
+|------------|------------------|------------------|  
+| MVTec AD & ColonDB    | [Google Drive](https://drive.google.com/file/d/1xVXANHGuJBRx59rqPRir7iqbkYzq45W0/view?usp=drive_link) | [Baidu Drive](link) | 
+| VisA & ClinicDB    | [Google Drive](https://drive.google.com/file/d/1QGmPB0ByPZQ7FucvGODMSz7r5Ke5wx9W/view?usp=drive_link) | [Baidu Drive](link) | 
+| All Datasets Mentioned Above   | [Google Drive](https://drive.google.com/file/d/1Cgkfx3GAaSYnXPLolx-P7pFqYV0IVzZF/view?usp=drive_link) | [Baidu Drive](link) |
+
+
+### Train
+
+By default, we use MVTec AD & ColonDB for training and VisA for validation:
+```shell
+CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data mvtec colondb --testing_data visa
+```
+
+
+Alternatively, for evaluation on MVTec AD & ColonDB, we use VisA & ClinicDB for training and MVTec AD for validation:
+```shell
+CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec
+```
+Since training uses half-precision (FP16), the process can occasionally be unstable.
+We recommend running the training multiple times and choosing, as the final model, the one
+that performs best on the validation set.
+
+
+To construct a robust ZSAD model for demonstration, we also train our AdaCLIP on all AD datasets mentioned above:
+```shell
+CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True \
+--training_data \
+br35h brain_mri btad clinicdb colondb \
+dagm dtd headct isic mpdd mvtec sdd tn3k visa \
+--testing_data mvtec
+```
+
+### Test
+
+Manually select the best models from the validation set and place them in the `weights/` directory. Then, run the following testing script:
+```shell
+sh test.sh
+```
+
+If you want to test on a single image, you can refer to `test_single_image.sh`:
+```shell
+CUDA_VISIBLE_DEVICES=0 python test.py --testing_model image --ckt_path weights/pretrained_all.pth --save_fig True \
+ --image_path asset/img.png --class_name candle --save_name test.png
+```
+
+## Main Results
+
+Due to differences in the software versions used, the detection performance obtained with the provided pre-trained weights
+may vary slightly from the reported results: some categories may score higher, while others may score lower.
+
+![Table_industrial](./asset/Table_industrial.png)
+![Table_medical](./asset/Table_medical.png)
+![Fig_detection_results](./asset/Fig_detection_results.png)
+
+## :page_facing_up: Demo App
+
+To run the demo application, use the following command:
+
+```bash
+python app.py
+```
+
+![Demo](./asset/Fig_app.png)
+
+## 💘 Acknowledgements
+Our work is largely inspired by the following projects. Thanks for their admirable contributions.
+
+- [VAND-APRIL-GAN](https://github.com/ByChelsea/VAND-APRIL-GAN)
+- [AnomalyCLIP](https://github.com/zqhang/AnomalyCLIP)
+- [SAA](https://github.com/caoyunkang/Segment-Any-Anomaly)
+
+
+## Stargazers over time
+[![Stargazers over time](https://starchart.cc/caoyunkang/AdaCLIP.svg?variant=adaptive)](https://starchart.cc/caoyunkang/AdaCLIP)
+
+
+## Citation
+
+If you find this project helpful for your research, please consider citing the following BibTeX entry.
+
+```BibTex
+
+
+
+```
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..07de436c1ee7d67c5c990c59d5f1aa99354d0f7d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,133 @@
+import gradio as gr
+from PIL import Image, ImageDraw, ImageFont
+import warnings
+import os
+os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+import json
+import torch
+from scipy.ndimage import gaussian_filter
+import cv2
+from method import AdaCLIP_Trainer
+import numpy as np
+
+############ Init Model
+ckt_path1 = 'weights/pretrained_mvtec_colondb.pth'
+ckt_path2 = "weights/pretrained_visa_clinicdb.pth"
+ckt_path3 = 'weights/pretrained_all.pth'
+
+# Configurations
+image_size = 518
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# device = 'cpu'
+model = "ViT-L-14-336"
+prompting_depth = 4
+prompting_length = 5
+prompting_type = 'SD'
+prompting_branch = 'VL'
+use_hsf = True
+k_clusters = 20
+
+config_path = os.path.join('./model_configs', f'{model}.json')
+
+# Prepare model
+with open(config_path, 'r') as f:
+    model_configs = json.load(f)
+
+# Set up the feature hierarchy
+n_layers = model_configs['vision_cfg']['layers']
+substage = n_layers // 4
+features_list = [substage, substage * 2, substage * 3, substage * 4]
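+# e.g., for the 24-layer ViT-L/14-336 vision backbone this selects layers [6, 12, 18, 24]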
+
+model = AdaCLIP_Trainer(
+    backbone=model,
+    feat_list=features_list,
+    input_dim=model_configs['vision_cfg']['width'],
+    output_dim=model_configs['embed_dim'],
+    learning_rate=0.,
+    device=device,
+    image_size=image_size,
+    prompting_depth=prompting_depth,
+    prompting_length=prompting_length,
+    prompting_branch=prompting_branch,
+    prompting_type=prompting_type,
+    use_hsf=use_hsf,
+    k_clusters=k_clusters
+).to(device)
+
+
+def process_image(image, text, options):
+    # Load the model based on selected options
+    if 'MVTec AD+Colondb' in options:
+        model.load(ckt_path1)
+    elif 'VisA+Clinicdb' in options:
+        model.load(ckt_path2)
+    elif 'All' in options:
+        model.load(ckt_path3)
+    else:
+        # Default to 'All' if no valid option is provided
+        model.load(ckt_path3)
+        print('Invalid option. Defaulting to All.')
+
+    # Ensure image is in RGB mode
+    image = image.convert('RGB')
+
+    # Convert PIL image to NumPy array
+    np_image = np.array(image)
+
+    # Convert RGB to BGR for OpenCV
+    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
+    np_image = cv2.resize(np_image, (image_size, image_size))
+    # Preprocess the image and run the model
+    img_input = model.preprocess(image).unsqueeze(0)
+    img_input = img_input.to(model.device)
+
+    with torch.no_grad():
+        anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True)
+
+    # Process anomaly map
+    anomaly_map = anomaly_map[0, :, :].cpu().numpy()
+    anomaly_score = anomaly_score[0].cpu().numpy()
+    anomaly_map = gaussian_filter(anomaly_map, sigma=4)
+    anomaly_map = (anomaly_map * 255).astype(np.uint8)
+
+    # Apply color map and blend with original image
+    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
+    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)
+
+    # Convert OpenCV image back to PIL image for Gradio
+    vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB))
+
+    return vis_map_pil, f'{anomaly_score:.3f}'
+
+# Define examples
+examples = [
+    ["asset/img.png", "candle", "MVTec AD+Colondb"],
+    ["asset/img2.png", "bottle", "VisA+Clinicdb"],
+    ["asset/img3.png", "button", "All"],
+]
+
+# Gradio interface layout
+demo = gr.Interface(
+    fn=process_image,
+    inputs=[
+        gr.Image(type="pil", label="Upload Image"),
+        gr.Textbox(label="Class Name"),
+        gr.Radio(["MVTec AD+Colondb",
+                  "VisA+Clinicdb",
+                  "All"],
+        label="Pre-trained Datasets")
+    ],
+    outputs=[
+        gr.Image(type="pil", label="Output Image"),
+        gr.Textbox(label="Anomaly Score"),
+    ],
+    examples=examples,
+    title="AdaCLIP -- Zero-shot Anomaly Detection",
+    description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection"
+)
+
+# Launch the demo
+demo.launch()
+# demo.launch(server_name="0.0.0.0", server_port=10002)
+
diff --git a/asset/Fig_app.png b/asset/Fig_app.png
new file mode 100644
index 0000000000000000000000000000000000000000..96df975fe78d631ebba6b2c3b80a7da6be1f17cd
--- /dev/null
+++ b/asset/Fig_app.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f71ab8be0e45353c1660526ff450754e82ddf4a2b7f18bb5a33ac3b704b0d76b
+size 268551
diff --git a/asset/Fig_detection_results.png b/asset/Fig_detection_results.png
new file mode 100644
index 0000000000000000000000000000000000000000..bf173edfe3151b4ca788ac70272622ac3971fab4
--- /dev/null
+++ b/asset/Fig_detection_results.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c00bd303a99d981d964b12e981bd1f2954d469766839523e76f7d7162fbb24cb
+size 363123
diff --git a/asset/Table_industrial.png b/asset/Table_industrial.png
new file mode 100644
index 0000000000000000000000000000000000000000..e0327380503f2fc33dceaa67288884eeb7262b21
--- /dev/null
+++ b/asset/Table_industrial.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5fa4d9ab1ff1b3ca90b45f4b92ee7b12a89e5327cb22621d4081fb5f160d3d68
+size 401841
diff --git a/asset/Table_medical.png b/asset/Table_medical.png
new file mode 100644
index 0000000000000000000000000000000000000000..59970a6f6fec926d3327aeb10537c3cb8d4ddbfb
--- /dev/null
+++ b/asset/Table_medical.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2424190619dbbd134b943ef9e38a6523635ab0d279f2445da6bdd266d3dafac
+size 291004
diff --git a/asset/framework.png b/asset/framework.png
new file mode 100644
index 0000000000000000000000000000000000000000..75e65d4138d43f257f0689af81848b2f2634f9d8
--- /dev/null
+++ b/asset/framework.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3804c7f5ae141257dbe5dd43cb20f4216a1061051fd8754d6f0c730dd085ad7d
+size 439936
diff --git a/asset/img.png b/asset/img.png
new file mode 100644
index 0000000000000000000000000000000000000000..f6f91212a115c0e035c608ecae4bde2d936eab83
--- /dev/null
+++ b/asset/img.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3eaff97d07132f9b06998737b976d4a0e0a3a2168b40aee43aad6e62d040f87e
+size 1421232
diff --git a/asset/img2.png b/asset/img2.png
new file mode 100644
index 0000000000000000000000000000000000000000..942e8707b4b3fb2b7c2124130e3f4295102564f0
--- /dev/null
+++ b/asset/img2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3918b94553a8922b3c16d064ef73e9062710b35639a949c56d926037e4c0d0a
+size 547657
diff --git a/asset/img3.png b/asset/img3.png
new file mode 100644
index 0000000000000000000000000000000000000000..08cafcbe6d08ebcc150d11b38760422caf2f91d5
--- /dev/null
+++ b/asset/img3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9394757293585aa9de542f3e70025788e5a3e1ad5a1277a8648f8050f8d7e868
+size 624200
diff --git a/config.py b/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ab7fc36bc73e27374399253a7d13c7e0f526fb
--- /dev/null
+++ b/config.py
@@ -0,0 +1 @@
+DATA_ROOT = '../datasets'
\ No newline at end of file
diff --git a/data_preprocess/br35h.py b/data_preprocess/br35h.py
new file mode 100644
index 0000000000000000000000000000000000000000..53385ac65df05ad00b7d605e6571bc737f6c0071
--- /dev/null
+++ b/data_preprocess/br35h.py
@@ -0,0 +1,50 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection')
+class Br35hSolver(object):
+    CLSNAMES = [
+        'br35h',
+    ]
+
+    def __init__(self, root=Br35h_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
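+        # Build a {phase: {class_name: [sample dicts]}} index: each sample records the relative image path,
+        # an empty mask path (Br35H provides image-level labels only), the class and sub-folder ('specie')
+        # names, and a 0/1 anomaly flag.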
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    img_names.sort()
+
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = Br35hSolver(root=Br35h_ROOT)
+    runner.run()
diff --git a/data_preprocess/brain_mri.py b/data_preprocess/brain_mri.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d9a23493ac89df6a683e1c6d7e98c19c5bb1523
--- /dev/null
+++ b/data_preprocess/brain_mri.py
@@ -0,0 +1,51 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+BrainMRI_ROOT = os.path.join(DATA_ROOT, 'BrainMRI')
+
+class BrainMRISolver(object):
+    CLSNAMES = [
+        'brain_mri',
+    ]
+
+    def __init__(self, root=BrainMRI_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    img_names.sort()
+
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = BrainMRISolver(root=BrainMRI_ROOT)
+    runner.run()
diff --git a/data_preprocess/btad.py b/data_preprocess/btad.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb2fa22ad741f86ecab8d9ad5623167688a428c
--- /dev/null
+++ b/data_preprocess/btad.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+BTAD_ROOT = os.path.join(DATA_ROOT, 'BTech_Dataset_transformed')
+
+class BTADSolver(object):
+    CLSNAMES = [
+        '01', '02', '03',
+    ]
+
+    def __init__(self, root=BTAD_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['ok'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = BTADSolver(root=BTAD_ROOT)
+    runner.run()
diff --git a/data_preprocess/clinicdb.py b/data_preprocess/clinicdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..69a2c20b7667e48eb7e21ae4ad6b5357d3906e09
--- /dev/null
+++ b/data_preprocess/clinicdb.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+ClinicDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ClinicDB')
+
+class ClinicDBSolver(object):
+    CLSNAMES = [
+        'ClinicDB',
+    ]
+
+    def __init__(self, root=ClinicDB_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = ClinicDBSolver(root=ClinicDB_ROOT)
+    runner.run()
diff --git a/data_preprocess/colondb.py b/data_preprocess/colondb.py
new file mode 100644
index 0000000000000000000000000000000000000000..2939394077108884d0f63646d90fa0df912c5984
--- /dev/null
+++ b/data_preprocess/colondb.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+ColonDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ColonDB')
+
+class ColonDBSolver(object):
+    CLSNAMES = [
+        'ColonDB',
+    ]
+
+    def __init__(self, root=ColonDB_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = ColonDBSolver(root=ColonDB_ROOT)
+    runner.run()
diff --git a/data_preprocess/dagm-pre.py b/data_preprocess/dagm-pre.py
new file mode 100644
index 0000000000000000000000000000000000000000..d06889c628fc085e2179350e2f8e17c38e3a1f14
--- /dev/null
+++ b/data_preprocess/dagm-pre.py
@@ -0,0 +1,82 @@
+import os
+import numpy as np
+from sklearn.model_selection import train_test_split
+import cv2
+import argparse
+from config import DATA_ROOT
+
+dataset_root = os.path.join(DATA_ROOT, 'DAGM2007')
+
+class_names = os.listdir(dataset_root)
+
+
+for class_name in class_names:
+    states = os.listdir(os.path.join(dataset_root, class_name))
+    for state in states:
+        images = list()
+        mask = list()
+        files = os.listdir(os.path.join(dataset_root, class_name,state))
+        for f in files:
+            if 'PNG' in f[-3:]:
+                images.append(f)
+        files = os.listdir(os.path.join(dataset_root, class_name, state,'Label'))
+        for f in files:
+            if 'PNG' in f[-3:]:
+                mask.append(f)
+        normal_image_path_train = list()
+        normal_image_path_test = list()
+        normal_image_path = list()
+        abnormal_image_path = list()
+        abnormal_image_label = list()
+        for f in images:
+            id = f[-8:-4]
+            flag = 0
+            for y in mask:
+                if id in y:
+                    abnormal_image_path.append(f)
+                    abnormal_image_label.append(y)
+                    flag = 1
+                    break
+            if flag == 0:
+                normal_image_path.append(f)
+
+        if len(abnormal_image_path) != len(abnormal_image_label):
+            raise ValueError
+        length = len(abnormal_image_path)
+
+        normal_image_path_test = normal_image_path[:length]
+        normal_image_path_train = normal_image_path[length:]
+
+        target_root = '../datasets/DAGM_anomaly_detection'
+
+        train_root = os.path.join(target_root, class_name, 'train','good')
+        if not os.path.exists(train_root):
+            os.makedirs(train_root)
+        for f in normal_image_path_train:
+            image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f))
+            cv2.imwrite(os.path.join(train_root,f), image_data)
+
+        test_root = os.path.join(target_root, class_name, 'test','good')
+        if not os.path.exists(test_root):
+            os.makedirs(test_root)
+        for f in normal_image_path_test:
+            image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f))
+            cv2.imwrite(os.path.join(test_root,f), image_data)
+
+        test_root = os.path.join(target_root, class_name, 'test','defect')
+        if not os.path.exists(test_root):
+            os.makedirs(test_root)
+        for f in abnormal_image_path:
+            image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f))
+            cv2.imwrite(os.path.join(test_root,f), image_data)
+
+        test_root = os.path.join(target_root, class_name, 'ground_truth','defect')
+        if not os.path.exists(test_root):
+            os.makedirs(test_root)
+        for f in mask:
+            image_data = cv2.imread(os.path.join(dataset_root, class_name, state,'Label',f))
+            cv2.imwrite(os.path.join(test_root,f), image_data)
+
+
+
+print("Done")
\ No newline at end of file
diff --git a/data_preprocess/dagm.py b/data_preprocess/dagm.py
new file mode 100644
index 0000000000000000000000000000000000000000..23aea77fc1a39006057fa8fa6a1d6dc630695ec7
--- /dev/null
+++ b/data_preprocess/dagm.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+DAGM_ROOT = os.path.join(DATA_ROOT, 'DAGM_anomaly_detection')
+
+class DAGMSolver(object):
+    CLSNAMES = [
+        'Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6','Class7','Class8','Class9','Class10',
+    ]
+
+    def __init__(self, root=DAGM_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = DAGMSolver(root=DAGM_ROOT)
+    runner.run()
diff --git a/data_preprocess/dtd.py b/data_preprocess/dtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a759416ce36ca6220d1c20d688f66bbb3b165386
--- /dev/null
+++ b/data_preprocess/dtd.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+DTD_ROOT = os.path.join(DATA_ROOT, 'DTD-Synthetic')
+
+class DTDSolver(object):
+    CLSNAMES = [
+        'Blotchy_099', 'Fibrous_183', 'Marbled_078', 'Matted_069', 'Mesh_114','Perforated_037','Stratified_154','Woven_001','Woven_068','Woven_104','Woven_125','Woven_127',
+    ]
+
+    def __init__(self, root=DTD_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = DTDSolver(root=DTD_ROOT)
+    runner.run()
diff --git a/data_preprocess/endo.py b/data_preprocess/endo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f44aa23ff1910331c52310e3e6d727a07e26a31
--- /dev/null
+++ b/data_preprocess/endo.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+ENDO_ROOT = os.path.join(DATA_ROOT, 'EndoTect')
+
+class ENDOSolver(object):
+    CLSNAMES = [
+        'endo',
+    ]
+
+    def __init__(self, root=ENDO_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = ENDOSolver(root=ENDO_ROOT)
+    runner.run()
diff --git a/data_preprocess/headct-pre.py b/data_preprocess/headct-pre.py
new file mode 100644
index 0000000000000000000000000000000000000000..3672393f198d982f47476a036dbcb570f73d0c84
--- /dev/null
+++ b/data_preprocess/headct-pre.py
@@ -0,0 +1,41 @@
+import os
+import numpy as np
+from sklearn.model_selection import train_test_split
+import shutil
+import argparse
+
+from config import DATA_ROOT
+
+dataset_root = os.path.join(DATA_ROOT, 'head_ct')
+
+label_file = os.path.join(dataset_root, 'labels.csv')
+
+data = np.loadtxt(label_file, dtype=int, delimiter=',', skiprows=1)
+
+fnames = data[:, 0]
+label = data[:, 1]
+
+normal_fnames = fnames[label==0]
+outlier_fnames = fnames[label==1]
+
+
+target_root = '../datasets/HeadCT_anomaly_detection/headct'
+train_root = os.path.join(target_root, 'train/good')
+if not os.path.exists(train_root):
+    os.makedirs(train_root)
+
+test_normal_root = os.path.join(target_root, 'test/good')
+if not os.path.exists(test_normal_root):
+    os.makedirs(test_normal_root)
+for f in normal_fnames:
+    source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f))
+    shutil.copy(source, test_normal_root)
+
+test_outlier_root = os.path.join(target_root, 'test/defect')
+if not os.path.exists(test_outlier_root):
+    os.makedirs(test_outlier_root)
+for f in outlier_fnames:
+    source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f))
+    shutil.copy(source, test_outlier_root)
+
+print('Done')
\ No newline at end of file
diff --git a/data_preprocess/headct.py b/data_preprocess/headct.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e427dc111e678eb6d82ffc171ac829f5569af41
--- /dev/null
+++ b/data_preprocess/headct.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+HEADCT_ROOT = os.path.join(DATA_ROOT, 'HeadCT_anomaly_detection')
+class HEADCTSolver(object):
+    CLSNAMES = [
+        'headct',
+    ]
+
+    def __init__(self, root=HEADCT_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    img_names.sort()
+
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = HEADCTSolver(root=HEADCT_ROOT)
+    runner.run()
diff --git a/data_preprocess/isic.py b/data_preprocess/isic.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1d23c037d1e72eb4d3a7eb4588f173a09be8ef2
--- /dev/null
+++ b/data_preprocess/isic.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+ISIC_ROOT = os.path.join(DATA_ROOT, 'ISIC')
+
+class ISICSolver(object):
+    CLSNAMES = [
+        'isic',
+    ]
+
+    def __init__(self, root=ISIC_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = ISICSolver(root=ISIC_ROOT)
+    runner.run()
diff --git a/data_preprocess/mpdd.py b/data_preprocess/mpdd.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fe1e6691856148ae5664b1fd45bb08f359c010f
--- /dev/null
+++ b/data_preprocess/mpdd.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+MPDD_ROOT = os.path.join(DATA_ROOT, 'MPDD')
+
+class MPDDSolver(object):
+    CLSNAMES = [
+        'bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate','tubes',
+    ]
+
+    def __init__(self, root=MPDD_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = MPDDSolver(root=MPDD_ROOT)
+    runner.run()
diff --git a/data_preprocess/mvtec.py b/data_preprocess/mvtec.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d5fe964dad6f360c8d9f4cfc7723008b5ae620
--- /dev/null
+++ b/data_preprocess/mvtec.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from dataset import MVTEC_ROOT
+
+class MVTecSolver(object):
+    CLSNAMES = [
+        'bottle', 'cable', 'capsule', 'carpet', 'grid',
+        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
+        'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
+    ]
+
+    def __init__(self, root=MVTEC_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = MVTecSolver(root=MVTEC_ROOT)
+    runner.run()
diff --git a/data_preprocess/sdd-pre.py b/data_preprocess/sdd-pre.py
new file mode 100644
index 0000000000000000000000000000000000000000..da9dc6ed246a51b912bc3f6cf624307168b8ba03
--- /dev/null
+++ b/data_preprocess/sdd-pre.py
@@ -0,0 +1,75 @@
+import os
+import numpy as np
+from sklearn.model_selection import train_test_split
+import cv2
+import argparse
+
+from config import DATA_ROOT
+
+dataset_root = os.path.join(DATA_ROOT, 'KolektorSDD')
+
+dirs = os.listdir(dataset_root)
+normal_images = list()
+normal_labels = list()
+normal_fname = list()
+outlier_images = list()
+outlier_labels = list()
+outlier_fname = list()
+for d in dirs:
+    files = os.listdir(os.path.join(dataset_root, d))
+    images = list()
+    for f in files:
+        if 'jpg' in f[-3:]:
+            images.append(f)
+
+    for image in images:
+        split_images = list()
+        split_labels = list()
+        image_name = image.split('.')[0]
+        image_data = cv2.imread(os.path.join(dataset_root, d, image))
+        label_data = cv2.imread(os.path.join(dataset_root, d, image_name + '_label.bmp'))
+        if image_data.shape != label_data.shape:
+            raise ValueError
+        image_length = image_data.shape[0]
+        split_images.append(image_data[:image_length // 3, :, :])
+        split_images.append(image_data[image_length // 3:image_length * 2 // 3, :, :])
+        split_images.append(image_data[image_length * 2 // 3:, :, :])
+        split_labels.append(label_data[:image_length // 3, :, :])
+        split_labels.append(label_data[image_length // 3:image_length * 2 // 3, :, :])
+        split_labels.append(label_data[image_length * 2 // 3:, :, :])
+        for i, (im, la) in enumerate(zip(split_images, split_labels)):
+            if np.max(la) != 0:
+                outlier_images.append(im)
+                outlier_labels.append(la)
+                outlier_fname.append(d + '_' + image_name + '_' + str(i))
+            else:
+                normal_images.append(im)
+                normal_labels.append(la)
+                normal_fname.append(d + '_' + image_name + '_' + str(i))
+
+normal_train, normal_test, normal_name_train, normal_name_test = train_test_split(normal_images, normal_fname, test_size=0.25, random_state=42)
+
+target_root = '../datasets/SDD_anomaly_detection/SDD'
+train_root = os.path.join(target_root, 'train/good')
+if not os.path.exists(train_root):
+    os.makedirs(train_root)
+for image, name in zip(normal_train, normal_name_train):
+    cv2.imwrite(os.path.join(train_root, name + '.png'), image)
+
+test_root = os.path.join(target_root, 'test/good')
+if not os.path.exists(test_root):
+    os.makedirs(test_root)
+for image, name in zip(normal_test, normal_name_test):
+    cv2.imwrite(os.path.join(test_root, name + '.png'), image)
+
+defect_root = os.path.join(target_root, 'test/defect')
+label_root = os.path.join(target_root, 'ground_truth/defect')
+if not os.path.exists(defect_root):
+    os.makedirs(defect_root)
+if not os.path.exists(label_root):
+    os.makedirs(label_root)
+for image, label, name in zip(outlier_images, outlier_labels, outlier_fname):
+    cv2.imwrite(os.path.join(defect_root, name + '.png'), image)
+    cv2.imwrite(os.path.join(label_root, name + '_mask.png'), label)
+
+print("Done")
\ No newline at end of file
diff --git a/data_preprocess/sdd.py b/data_preprocess/sdd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a04f543c53c105e756cca79442023a7c5d1699e4
--- /dev/null
+++ b/data_preprocess/sdd.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection')
+
+class SDDSolver(object):
+    CLSNAMES = [
+        'SDD',
+    ]
+
+    def __init__(self, root=SDD_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    mask_names.sort() if mask_names is not None else None
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = SDDSolver(root=SDD_ROOT)
+    runner.run()
diff --git a/data_preprocess/tn3k.py b/data_preprocess/tn3k.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cc08832244b10074f7b609bc97656aa90ee155a
--- /dev/null
+++ b/data_preprocess/tn3k.py
@@ -0,0 +1,52 @@
+import os
+import json
+import random
+from config import DATA_ROOT
+
+TN3K_ROOT = os.path.join(DATA_ROOT, 'TN3K')
+
+class TN3KSolver(object):
+    CLSNAMES = [
+        'tn3k',
+    ]
+
+    def __init__(self, root=TN3K_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        info = dict(train={}, test={})
+        for cls_name in self.CLSNAMES:
+            cls_dir = f'{self.root}/{cls_name}'
+            for phase in ['train', 'test']:
+                cls_info = []
+                species = os.listdir(f'{cls_dir}/{phase}')
+                for specie in species:
+                    is_abnormal = True if specie not in ['good'] else False
+                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
+                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
+                    img_names.sort()
+                    if mask_names is not None: mask_names.sort()
+                    for idx, img_name in enumerate(img_names):
+                        info_img = dict(
+                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
+                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
+                            cls_name=cls_name,
+                            specie_name=specie,
+                            anomaly=1 if is_abnormal else 0,
+                        )
+                        cls_info.append(info_img)
+
+                info[phase][cls_name] = cls_info
+
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+if __name__ == '__main__':
+    runner = TN3KSolver(root=TN3K_ROOT)
+    runner.run()
diff --git a/data_preprocess/visa.py b/data_preprocess/visa.py
new file mode 100644
index 0000000000000000000000000000000000000000..323614308022ae475d278efced9fe2aed02a63a4
--- /dev/null
+++ b/data_preprocess/visa.py
@@ -0,0 +1,52 @@
+import os
+import json
+import pandas as pd
+import random
+from dataset import VISA_ROOT
+
+class VisASolver(object):
+    CLSNAMES = [
+        'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
+        'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
+        'pcb4', 'pipe_fryum',
+    ]
+
+    def __init__(self, root=VISA_ROOT, train_ratio=0.5):
+        self.root = root
+        self.meta_path = f'{root}/meta.json'
+        self.phases = ['train', 'test']
+        self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)
+        self.train_ratio = train_ratio
+
+    def run(self):
+        self.generate_meta_info()
+
+    def generate_meta_info(self):
+        columns = self.csv_data.columns  # [object, split, label, image, mask]
+        info = {phase: {} for phase in self.phases}
+        for cls_name in self.CLSNAMES:
+            cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name]
+            for phase in self.phases:
+                cls_info = []
+                cls_data_phase = cls_data[cls_data[columns[1]] == phase]
+                cls_data_phase.index = list(range(len(cls_data_phase)))
+                for idx in range(cls_data_phase.shape[0]):
+                    data = cls_data_phase.loc[idx]
+                    is_abnormal = data[2] == 'anomaly'
+                    info_img = dict(
+                        img_path=data[3],
+                        mask_path=data[4] if is_abnormal else '',
+                        cls_name=cls_name,
+                        specie_name='',
+                        anomaly=1 if is_abnormal else 0,
+                    )
+                    cls_info.append(info_img)
+                info[phase][cls_name] = cls_info
+        with open(self.meta_path, 'w') as f:
+            f.write(json.dumps(info, indent=4) + "\n")
+
+
+
+if __name__ == '__main__':
+    runner = VisASolver(root=VISA_ROOT)
+    runner.run()
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..881cb60d70e3fa4de1b1db53a57da50281d709bd
--- /dev/null
+++ b/dataset/__init__.py
@@ -0,0 +1,68 @@
+from .mvtec import MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT
+from .visa import VISA_CLS_NAMES, VisaDataset, VISA_ROOT
+from .mpdd import MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT
+from .btad import BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT
+from .sdd import SDD_CLS_NAMES, SDDDataset, SDD_ROOT
+from .dagm import DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT
+from .dtd import DTD_CLS_NAMES, DTDDataset, DTD_ROOT
+from .isic import ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT
+from .colondb import ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT
+from .clinicdb import ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT
+from .tn3k import TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT
+from .headct import HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT
+from .brain_mri import BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT
+from .br35h import Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT
+from torch.utils.data import ConcatDataset
+
+dataset_dict = {
+    'br35h': (Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT),
+    'brain_mri': (BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT),
+    'btad': (BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT),
+    'clinicdb': (ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT),
+    'colondb': (ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT),
+    'dagm': (DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT),
+    'dtd': (DTD_CLS_NAMES, DTDDataset, DTD_ROOT),
+    'headct': (HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT),
+    'isic': (ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT),
+    'mpdd': (MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT),
+    'mvtec': (MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT),
+    'sdd': (SDD_CLS_NAMES, SDDDataset, SDD_ROOT),
+    'tn3k': (TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT),
+    'visa': (VISA_CLS_NAMES, VisaDataset, VISA_ROOT),
+}
+
+def get_data(dataset_type_list, transform, target_transform, training):
+    if not isinstance(dataset_type_list, list):
+        dataset_type_list = [dataset_type_list]
+
+    dataset_cls_names_list = []
+    dataset_instance_list = []
+    dataset_root_list = []
+    for dataset_type in dataset_type_list:
+        if dataset_type in dataset_dict:
+            dataset_cls_names, dataset_instance, dataset_root = dataset_dict[dataset_type]
+            dataset_instance = dataset_instance(
+                clsnames=dataset_cls_names,
+                transform=transform,
+                target_transform=target_transform,
+                training=training
+            )
+
+            dataset_cls_names_list.append(dataset_cls_names)
+            dataset_instance_list.append(dataset_instance)
+            dataset_root_list.append(dataset_root)
+
+        else:
+            print(f'Only {list(dataset_dict.keys())} are supported, but got {dataset_type}...')
+            raise NotImplementedError
+
+    if len(dataset_type_list) > 1:
+        dataset_instance = ConcatDataset(dataset_instance_list)
+        dataset_cls_names = dataset_cls_names_list
+        dataset_root = dataset_root_list
+    else:
+        dataset_instance = dataset_instance_list[0]
+        dataset_cls_names = dataset_cls_names_list[0]
+        dataset_root = dataset_root_list[0]
+
+    return dataset_cls_names, dataset_instance, dataset_root
\ No newline at end of file
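
A hypothetical way to call get_data; the resize size and the choice of dataset names are arbitrary, and the corresponding dataset folders (with their meta.json files) must already exist under DATA_ROOT for this to run.

```python
# Sketch only: transforms and dataset names are illustrative.
from torchvision import transforms
from dataset import get_data

transform = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])
target_transform = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])

# A single name returns (class names, dataset instance, dataset root).
cls_names, test_data, root = get_data('mvtec', transform, target_transform, training=False)

# A list of names returns a ConcatDataset over all of them.
cls_names, train_data, roots = get_data(['mvtec', 'colondb'], transform, target_transform, training=True)
```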
diff --git a/dataset/__pycache__/__init__.cpython-39.pyc b/dataset/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5826b4eca907aeaeb8c255768b822510e22307d
Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-39.pyc differ
diff --git a/dataset/__pycache__/br35h.cpython-39.pyc b/dataset/__pycache__/br35h.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bab8c2009806fd165fc4440d2616e7cf4ee1c792
Binary files /dev/null and b/dataset/__pycache__/br35h.cpython-39.pyc differ
diff --git a/dataset/__pycache__/brain_mri.cpython-39.pyc b/dataset/__pycache__/brain_mri.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac199549040338b67847dfd6b8019041e9b2b2ef
Binary files /dev/null and b/dataset/__pycache__/brain_mri.cpython-39.pyc differ
diff --git a/dataset/__pycache__/btad.cpython-39.pyc b/dataset/__pycache__/btad.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b807457785f8dd738f50a7a10f47f843895b568d
Binary files /dev/null and b/dataset/__pycache__/btad.cpython-39.pyc differ
diff --git a/dataset/__pycache__/clinicdb.cpython-39.pyc b/dataset/__pycache__/clinicdb.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3026d398c798b88a7a95801c172d1beca37ae7c0
Binary files /dev/null and b/dataset/__pycache__/clinicdb.cpython-39.pyc differ
diff --git a/dataset/__pycache__/colondb.cpython-39.pyc b/dataset/__pycache__/colondb.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97225e60df365de8467297105be18dc38c3fa6a4
Binary files /dev/null and b/dataset/__pycache__/colondb.cpython-39.pyc differ
diff --git a/dataset/__pycache__/dagm.cpython-39.pyc b/dataset/__pycache__/dagm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f287d6f437efcc86ef5c6ba0c5c402042d67b510
Binary files /dev/null and b/dataset/__pycache__/dagm.cpython-39.pyc differ
diff --git a/dataset/__pycache__/dtd.cpython-39.pyc b/dataset/__pycache__/dtd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..179f7e257b0e93bb15213c833914cc5e70e03632
Binary files /dev/null and b/dataset/__pycache__/dtd.cpython-39.pyc differ
diff --git a/dataset/__pycache__/headct.cpython-39.pyc b/dataset/__pycache__/headct.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..683659dfff8e6dcd1bb9f0561638521dae6e4505
Binary files /dev/null and b/dataset/__pycache__/headct.cpython-39.pyc differ
diff --git a/dataset/__pycache__/isic.cpython-39.pyc b/dataset/__pycache__/isic.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..066fe8d419662da76c7df9b1dbb85c9c41691c77
Binary files /dev/null and b/dataset/__pycache__/isic.cpython-39.pyc differ
diff --git a/dataset/__pycache__/mpdd.cpython-39.pyc b/dataset/__pycache__/mpdd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b9d3df589fb7f346ef4cd0a4809e7c6efac4498
Binary files /dev/null and b/dataset/__pycache__/mpdd.cpython-39.pyc differ
diff --git a/dataset/__pycache__/mvtec.cpython-39.pyc b/dataset/__pycache__/mvtec.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0607993c90b7508a55150951f2c1ed59a23f8e35
Binary files /dev/null and b/dataset/__pycache__/mvtec.cpython-39.pyc differ
diff --git a/dataset/__pycache__/sdd.cpython-39.pyc b/dataset/__pycache__/sdd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aafa3c9144683ea58bfeb4d8ca5dfaad99f842ab
Binary files /dev/null and b/dataset/__pycache__/sdd.cpython-39.pyc differ
diff --git a/dataset/__pycache__/tn3k.cpython-39.pyc b/dataset/__pycache__/tn3k.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c1096267ca748afd679365dec62f7835e5e2023
Binary files /dev/null and b/dataset/__pycache__/tn3k.cpython-39.pyc differ
diff --git a/dataset/__pycache__/visa.cpython-39.pyc b/dataset/__pycache__/visa.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43f6c38bf6ebcf3897a2e93d9b9c226a7a85f49d
Binary files /dev/null and b/dataset/__pycache__/visa.cpython-39.pyc differ
diff --git a/dataset/base_dataset.py b/dataset/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f09178216a203d9cedafb0f143ce377cbc2fb861
--- /dev/null
+++ b/dataset/base_dataset.py
@@ -0,0 +1,138 @@
+"""
+Base class for our zero-shot anomaly detection dataset
+"""
+import json
+import os
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import cv2
+from config import DATA_ROOT
+
+
+class DataSolver:
+    def __init__(self, root, clsnames):
+        self.root = root
+        self.clsnames = clsnames
+        self.path = os.path.join(root, 'meta.json')
+
+    def run(self):
+        with open(self.path, 'r') as f:
+            info = json.load(f)
+
+        info_required = dict(train={}, test={})
+        for cls in self.clsnames:
+            for k in info.keys():
+                info_required[k][cls] = info[k][cls]
+
+        return info_required
+
+
+class BaseDataset(data.Dataset):
+    def __init__(self, clsnames, transform, target_transform, root, aug_rate=0., training=True):
+        self.root = root
+        self.transform = transform
+        self.target_transform = target_transform
+        self.aug_rate = aug_rate
+        self.training = training
+        self.data_all = []
+        self.cls_names = clsnames
+
+        solver = DataSolver(root, clsnames)
+        meta_info = solver.run()
+
+        self.meta_info = meta_info['test']  # Only the test split is used, for both training and testing
+        for cls_name in self.cls_names:
+            self.data_all.extend(self.meta_info[cls_name])
+
+        self.length = len(self.data_all)
+
+    def __len__(self):
+        return self.length
+
+    def combine_img(self, cls_name):
+        """
+        From April-GAN: https://github.com/ByChelsea/VAND-APRIL-GAN
+        Here we combine four images into a single image for data augmentation.
+        """
+        img_info = random.sample(self.meta_info[cls_name], 4)
+
+        img_ls = []
+        mask_ls = []
+
+        for data in img_info:
+            img_path = os.path.join(self.root, data['img_path'])
+            mask_path = os.path.join(self.root, data['mask_path'])
+
+            img = Image.open(img_path).convert('RGB')
+            img_ls.append(img)
+
+            if not data['anomaly']:
+                img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')  # (H, W) zero mask
+            else:
+                img_mask = np.array(Image.open(mask_path).convert('L')) > 0
+                img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
+
+            mask_ls.append(img_mask)
+
+        # Image
+        image_width, image_height = img_ls[0].size
+        result_image = Image.new("RGB", (2 * image_width, 2 * image_height))
+        for i, img in enumerate(img_ls):
+            row = i // 2
+            col = i % 2
+            x = col * image_width
+            y = row * image_height
+            result_image.paste(img, (x, y))
+
+        # Mask
+        result_mask = Image.new("L", (2 * image_width, 2 * image_height))
+        for i, img in enumerate(mask_ls):
+            row = i // 2
+            col = i % 2
+            x = col * image_width
+            y = row * image_height
+            result_mask.paste(img, (x, y))
+
+        return result_image, result_mask
+
+    def __getitem__(self, index):
+        data = self.data_all[index]
+        img_path = os.path.join(self.root, data['img_path'])
+        mask_path = os.path.join(self.root, data['mask_path'])
+        cls_name = data['cls_name']
+        anomaly = data['anomaly']
+        random_number = random.random()
+
+        if self.training and random_number < self.aug_rate:
+            img, img_mask = self.combine_img(cls_name)
+        else:
+            if img_path.endswith('.tif'):
+                img = cv2.imread(img_path)
+                img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+            else:
+                img = Image.open(img_path).convert('RGB')
+            if anomaly == 0:
+                img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')  # (H, W) zero mask
+            else:
+                if data['mask_path']:
+                    img_mask = np.array(Image.open(mask_path).convert('L')) > 0
+                    img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
+                else:
+                    img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')  # (H, W) zero mask
+        # Transforms
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None and img_mask is not None:
+            img_mask = self.target_transform(img_mask)
+        if img_mask is None:
+            img_mask = []
+
+        return {
+            'img': img,
+            'img_mask': img_mask,
+            'cls_name': cls_name,
+            'anomaly': anomaly,
+            'img_path': img_path
+        }
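
A minimal sketch of consuming one of the BaseDataset subclasses with a DataLoader, assuming the MVTec data and its meta.json are already prepared; the image size and batch size are arbitrary. Each item is the dict returned by __getitem__ above.

```python
# Sketch only: assumes the MVTec data is prepared under DATA_ROOT.
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset.mvtec import MVTecDataset

transform = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])
target_transform = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])

dataset = MVTecDataset(transform=transform, target_transform=target_transform, training=False)
loader = DataLoader(dataset, batch_size=8, shuffle=False)

for batch in loader:
    imgs = batch['img']            # [B, 3, H, W] input images
    masks = batch['img_mask']      # [B, 1, H, W] pixel-level ground truth
    names = batch['cls_name']      # list of class names
    labels = batch['anomaly']      # [B] image-level labels (0 normal, 1 anomalous)
    break
```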
diff --git a/dataset/br35h.py b/dataset/br35h.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e6b8ebcd0a9871937ddac17064c1bfa4e25acb5
--- /dev/null
+++ b/dataset/br35h.py
@@ -0,0 +1,18 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection'''
+
+Br35h_CLS_NAMES = [
+    'br35h',
+]
+Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection')
+
+class Br35hDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=Br35h_CLS_NAMES, aug_rate=0.0, root=Br35h_ROOT, training=True):
+        super(Br35hDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
diff --git a/dataset/brain_mri.py b/dataset/brain_mri.py
new file mode 100644
index 0000000000000000000000000000000000000000..176f656d0b3da3a9b4a3c9f06914559c2ac605ef
--- /dev/null
+++ b/dataset/brain_mri.py
@@ -0,0 +1,16 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection'''
+BrainMRI_CLS_NAMES = [
+    'brain_mri',
+]
+BrainMRI_ROOT = os.path.join(DATA_ROOT, 'BrainMRI')
+
+class BrainMRIDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=BrainMRI_CLS_NAMES, aug_rate=0.0, root=BrainMRI_ROOT, training=True):
+        super(BrainMRIDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
diff --git a/dataset/btad.py b/dataset/btad.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec8858fc686d9323e19b71418535fa433020d30f
--- /dev/null
+++ b/dataset/btad.py
@@ -0,0 +1,16 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://avires.dimi.uniud.it/papers/btad/btad.zip'''
+BTAD_CLS_NAMES = [
+    '01', '02', '03',
+]
+BTAD_ROOT = os.path.join(DATA_ROOT, 'BTech_Dataset_transformed')
+
+class BTADDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=BTAD_CLS_NAMES, aug_rate=0.0, root=BTAD_ROOT, training=True):
+        super(BTADDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
diff --git a/dataset/clinicdb.py b/dataset/clinicdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..6376dd67bac837a578ebf723df594986ca004443
--- /dev/null
+++ b/dataset/clinicdb.py
@@ -0,0 +1,16 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://paperswithcode.com/dataset/cvc-clinicdb'''
+ClinicDB_CLS_NAMES = [
+    'ClinicDB',
+]
+ClinicDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ClinicDB')
+
+class ClinicDBDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=ClinicDB_CLS_NAMES, aug_rate=0.0, root=ClinicDB_ROOT, training=True):
+        super(ClinicDBDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
diff --git a/dataset/colondb.py b/dataset/colondb.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed70e981681642b1868769ce5de5c1c15f7440a0
--- /dev/null
+++ b/dataset/colondb.py
@@ -0,0 +1,18 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: http://mv.cvc.uab.es/projects/colon-qa/cvccolondb'''
+ColonDB_CLS_NAMES = [
+    'ColonDB',
+]
+ColonDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ColonDB')
+
+class ColonDBDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=ColonDB_CLS_NAMES, aug_rate=0.0, root=ColonDB_ROOT, training=True):
+        super(ColonDBDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
+
diff --git a/dataset/dagm.py b/dataset/dagm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b06b31fbd88ecfafe7cbc10a99efcf90b141dc5
--- /dev/null
+++ b/dataset/dagm.py
@@ -0,0 +1,16 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://hci.iwr.uni-heidelberg.de/content/weakly-supervised-learning-industrial-optical-inspection'''
+DAGM_CLS_NAMES = [
+    'Class1', 'Class2', 'Class3', 'Class4', 'Class5', 'Class6', 'Class7', 'Class8', 'Class9', 'Class10',
+]
+DAGM_ROOT = os.path.join(DATA_ROOT, 'DAGM_anomaly_detection')
+
+class DAGMDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=DAGM_CLS_NAMES, aug_rate=0.0, root=DAGM_ROOT, training=True):
+        super(DAGMDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
diff --git a/dataset/dtd.py b/dataset/dtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b2fc4eb276ab62e21e7f9adebaff530b89bc477
--- /dev/null
+++ b/dataset/dtd.py
@@ -0,0 +1,16 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1'''
+DTD_CLS_NAMES = [
+    'Blotchy_099', 'Fibrous_183', 'Marbled_078', 'Matted_069', 'Mesh_114', 'Perforated_037', 'Stratified_154', 'Woven_001', 'Woven_068', 'Woven_104', 'Woven_125', 'Woven_127',
+]
+DTD_ROOT = os.path.join(DATA_ROOT, 'DTD-Synthetic')
+
+class DTDDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=DTD_CLS_NAMES, aug_rate=0.0, root=DTD_ROOT, training=True):
+        super(DTDDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
diff --git a/dataset/headct.py b/dataset/headct.py
new file mode 100644
index 0000000000000000000000000000000000000000..f794ab7e8ac87c408088343e701a27f4e6d8630f
--- /dev/null
+++ b/dataset/headct.py
@@ -0,0 +1,18 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://www.kaggle.com/datasets/felipekitamura/head-ct-hemorrhage'''
+HEADCT_CLS_NAMES = [
+    'headct',
+]
+HEADCT_ROOT = os.path.join(DATA_ROOT, 'HeadCT_anomaly_detection')
+
+class HEADCTDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=HEADCT_CLS_NAMES, aug_rate=0.0, root=HEADCT_ROOT, training=True):
+        super(HEADCTDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
+
diff --git a/dataset/isic.py b/dataset/isic.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86492b8d634f8243672f2baede947499816abf8
--- /dev/null
+++ b/dataset/isic.py
@@ -0,0 +1,18 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://challenge.isic-archive.com/data/'''
+ISIC_CLS_NAMES = [
+    'isic',
+]
+ISIC_ROOT = os.path.join(DATA_ROOT, 'ISIC')
+
+class ISICDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=ISIC_CLS_NAMES, aug_rate=0.0, root=ISIC_ROOT, training=True):
+        super(ISICDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
+
diff --git a/dataset/mpdd.py b/dataset/mpdd.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ca3de11a8205c164f7de896e8f498d2469da7e3
--- /dev/null
+++ b/dataset/mpdd.py
@@ -0,0 +1,17 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://github.com/stepanje/MPDD'''
+MPDD_CLS_NAMES = [
+    'bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate', 'tubes',
+]
+MPDD_ROOT = os.path.join(DATA_ROOT, 'MPDD')
+
+class MPDDDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=MPDD_CLS_NAMES, aug_rate=0.0, root=MPDD_ROOT, training=True):
+        super(MPDDDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
diff --git a/dataset/mvtec.py b/dataset/mvtec.py
new file mode 100644
index 0000000000000000000000000000000000000000..44f0e160ac87cdf57f31f3ed27d9b0bcaf2c5bed
--- /dev/null
+++ b/dataset/mvtec.py
@@ -0,0 +1,19 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://paperswithcode.com/dataset/mvtecad'''
+
+MVTEC_CLS_NAMES = [
+    'bottle', 'cable', 'capsule', 'carpet', 'grid',
+    'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
+    'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
+]
+MVTEC_ROOT = os.path.join(DATA_ROOT, 'mvtec_anomaly_detection')
+
+class MVTecDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=MVTEC_CLS_NAMES, aug_rate=0.2, root=MVTEC_ROOT, training=True):
+        super(MVTecDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
diff --git a/dataset/sdd.py b/dataset/sdd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac6cc938505b3c2a51b2bd7c0230df3aa107dab
--- /dev/null
+++ b/dataset/sdd.py
@@ -0,0 +1,18 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://data.vicos.si/datasets/KSDD/KolektorSDD.zip'''
+SDD_CLS_NAMES = [
+    'SDD',
+]
+SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection')
+
+
+class SDDDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=SDD_CLS_NAMES, aug_rate=0.0, root=SDD_ROOT, training=True):
+        super(SDDDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
diff --git a/dataset/tn3k.py b/dataset/tn3k.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ab3e972ec83bee01913b587682f53f7f39cd142
--- /dev/null
+++ b/dataset/tn3k.py
@@ -0,0 +1,18 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://ieeexplore.ieee.org/document/9434087/references#references'''
+TN3K_CLS_NAMES = [
+    'tn3k',
+]
+TN3K_ROOT = os.path.join(DATA_ROOT, 'TN3K')
+
+class TN3KDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=TN3K_CLS_NAMES, aug_rate=0.0, root=TN3K_ROOT, training=True):
+        super(TN3KDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
+
diff --git a/dataset/visa.py b/dataset/visa.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cc9aaa3097e8275d2fccddfb34ed580e446457f
--- /dev/null
+++ b/dataset/visa.py
@@ -0,0 +1,20 @@
+import os
+from .base_dataset import BaseDataset
+from config import DATA_ROOT
+
+'''dataset source: https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar'''
+VISA_CLS_NAMES = [
+    'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
+    'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
+    'pcb4', 'pipe_fryum',
+]
+
+VISA_ROOT = os.path.join(DATA_ROOT, 'VisA_20220922')
+
+class VisaDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=VISA_CLS_NAMES, aug_rate=0.0, root=VISA_ROOT, training=True):
+        super(VisaDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e50dca66b1a02486e4787bbb2b3489767b85d4c2
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,15 @@
+# add dependencies
+# python395_cuda113_pytorch1101
+# please change dataset root in ./config.py according to your specifications
+
+conda create -n AdaCLIP python=3.9.5 -y
+conda activate AdaCLIP
+pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4
+pip install gradio
+
+
+
+
+
+
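
The dataset and preprocessing modules all do `from config import DATA_ROOT`, which is the value install.sh asks you to adjust. A minimal sketch of that setting, with a purely illustrative path:

```python
# config.py (sketch): DATA_ROOT should point at the directory that contains the
# individual dataset folders (mvtec_anomaly_detection, VisA_20220922, ...).
DATA_ROOT = '/path/to/datasets'
```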
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d62a45677e0fd0f7d2c04d4d94672fe5d93f956
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,189 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from math import exp
+
+class FocalLoss(nn.Module):
+    """
+    Copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
+    This is an implementation of Focal Loss with label-smoothed cross entropy, as proposed in
+    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
+        Focal_Loss = -1 * alpha * (1 - pt)**gamma * log(pt)
+    :param alpha: (tensor) scalar weighting factor for this criterion
+    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
+                    putting more focus on hard, misclassified examples
+    :param smooth: (float, double) label-smoothing value used in the cross entropy
+    :param balance_index: (int) index of the class to balance; should be specified when alpha is a float
+    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
+    """
+
+    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
+        super(FocalLoss, self).__init__()
+        self.apply_nonlin = apply_nonlin
+        self.alpha = alpha
+        self.gamma = gamma
+        self.balance_index = balance_index
+        self.smooth = smooth
+        self.size_average = size_average
+
+        if self.smooth is not None:
+            if self.smooth < 0 or self.smooth > 1.0:
+                raise ValueError('smooth value should be in [0,1]')
+
+    def forward(self, logit, target):
+        if self.apply_nonlin is not None:
+            logit = self.apply_nonlin(logit)
+        num_class = logit.shape[1]
+
+        if logit.dim() > 2:
+            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
+            logit = logit.view(logit.size(0), logit.size(1), -1)
+            logit = logit.permute(0, 2, 1).contiguous()
+            logit = logit.view(-1, logit.size(-1))
+        target = torch.squeeze(target, 1)
+        target = target.view(-1, 1)
+        alpha = self.alpha
+
+        if alpha is None:
+            alpha = torch.ones(num_class, 1)
+        elif isinstance(alpha, (list, np.ndarray)):
+            assert len(alpha) == num_class
+            alpha = torch.FloatTensor(alpha).view(num_class, 1)
+            alpha = alpha / alpha.sum()
+        elif isinstance(alpha, float):
+            alpha = torch.ones(num_class, 1)
+            alpha = alpha * (1 - self.alpha)
+            alpha[self.balance_index] = self.alpha
+
+        else:
+            raise TypeError('Not support alpha type')
+
+        if alpha.device != logit.device:
+            alpha = alpha.to(logit.device)
+
+        idx = target.cpu().long()
+
+        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
+        one_hot_key = one_hot_key.scatter_(1, idx, 1)
+        if one_hot_key.device != logit.device:
+            one_hot_key = one_hot_key.to(logit.device)
+
+        if self.smooth:
+            one_hot_key = torch.clamp(
+                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
+        pt = (one_hot_key * logit).sum(1) + self.smooth
+        logpt = pt.log()
+
+        gamma = self.gamma
+
+        alpha = alpha[idx]
+        alpha = torch.squeeze(alpha)
+        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
+
+        if self.size_average:
+            loss = loss.mean()
+        return loss
+
+
+class BinaryDiceLoss(nn.Module):
+    def __init__(self):
+        super(BinaryDiceLoss, self).__init__()
+
+    def forward(self, input, targets):
+        # Batch size N
+        N = targets.size()[0]
+        # Smoothing term to avoid division by zero
+        smooth = 1
+        # Flatten the spatial dimensions
+        input_flat = input.view(N, -1)
+        targets_flat = targets.view(N, -1)
+
+        # Intersection between prediction and target
+        intersection = input_flat * targets_flat
+        N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
+        # Average the Dice loss over the images in the batch
+        loss = 1 - N_dice_eff.sum() / N
+        return loss
+
+
+
+
+class ConADLoss(nn.Module):
+    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+    It also supports the unsupervised contrastive loss in SimCLR"""
+    def __init__(self, contrast_mode='all',random_anchors=10):
+        super(ConADLoss, self).__init__()
+        assert contrast_mode in ['all', 'mean', 'random']
+        self.contrast_mode = contrast_mode
+        self.random_anchors = random_anchors
+    def forward(self, features, labels):
+        """Compute loss for model. If both `labels` and `mask` are None,
+        it degenerates to SimCLR unsupervised loss:
+        https://arxiv.org/pdf/2002.05709.pdf
+
+        Args:
+            features: hidden vector of shape [bsz, C, ...].
+            labels: ground truth of shape [bsz, 1, ...]., where 1 denotes to abnormal, and 0 denotes to normal
+        Returns:
+            A loss scalar.
+        """
+        device = (torch.device('cuda')
+                  if features.is_cuda
+                  else torch.device('cpu'))
+        if len(features.shape) != len(labels.shape):
+            raise ValueError('`features` needs to have the same dimensions with labels')
+
+        if len(features.shape) < 3:
+            raise ValueError('`features` needs to be [bsz, C, ...],'
+                             'at least 3 dimensions are required')
+
+        if len(features.shape) > 3:
+            features = features.view(features.shape[0], features.shape[1], -1)
+            labels = labels.view(labels.shape[0], labels.shape[1], -1)
+
+        labels = labels.squeeze()
+        batch_size = features.shape[0]
+
+        C = features.shape[1]
+        normal_feats = features[:, :, labels == 0]
+        abnormal_feats = features[:, :, labels == 1]
+
+        normal_feats = normal_feats.permute((1, 0, 2)).contiguous().view(C, -1)
+        abnormal_feats = abnormal_feats.permute((1, 0, 2)).contiguous().view(C, -1)
+
+        contrast_count = normal_feats.shape[1]
+        contrast_feature = normal_feats
+
+        if self.contrast_mode == 'mean':
+            anchor_feature = torch.mean(normal_feats, dim=1)
+            anchor_feature = F.normalize(anchor_feature, dim=0, p=2)
+            anchor_count = 1
+        elif self.contrast_mode == 'all':
+            anchor_feature = contrast_feature
+            anchor_count = contrast_count
+        elif self.contrast_mode == 'random':
+            dim_to_sample = 1
+            num_samples = min(self.random_anchors, contrast_count)
+            permuted_indices = torch.randperm(normal_feats.size(dim_to_sample)).to(normal_feats.device)
+            selected_indices = permuted_indices[:num_samples]
+            anchor_feature = normal_feats.index_select(dim_to_sample, selected_indices)
+        else:
+            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
+
+        # compute logits
+        # maximize similarity
+        anchor_dot_normal = torch.matmul(anchor_feature.T, normal_feats).mean()
+
+        # minimize similarity
+        anchor_dot_abnormal = torch.matmul(anchor_feature.T, abnormal_feats).mean()
+
+        loss = 0
+        if normal_feats.shape[1] > 0:
+            loss -= anchor_dot_normal
+        if abnormal_feats.shape[1] > 0:
+            loss += anchor_dot_abnormal
+
+        loss = torch.exp(loss)
+
+        return loss
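
A toy sanity check of the three losses on random tensors; the shapes, values, and the batch size of one used for ConADLoss are illustrative only.

```python
# Sketch only: random inputs, no real model outputs.
import torch
from loss import FocalLoss, BinaryDiceLoss, ConADLoss

focal = FocalLoss()
dice = BinaryDiceLoss()
con = ConADLoss(contrast_mode='mean')

# Two-class pixel probabilities [B, 2, H, W] and a binary ground-truth mask [B, 1, H, W].
probs = torch.softmax(torch.randn(2, 2, 64, 64), dim=1)
mask = (torch.rand(2, 1, 64, 64) > 0.5).float()

loss_focal = focal(probs, mask)
loss_dice = dice(probs[:, 1, :, :], mask.squeeze(1))

# Patch embeddings [bsz, C, L] with per-patch labels (1 = abnormal, 0 = normal).
feats = torch.randn(1, 8, 1024)
labels = (torch.rand(1, 1, 1024) > 0.5).float()
loss_con = con(feats, labels)

print(loss_focal.item(), loss_dice.item(), loss_con.item())
```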
diff --git a/method/__init__.py b/method/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f71b6f127b2fa66665b05d817a59acc5cfd0515
--- /dev/null
+++ b/method/__init__.py
@@ -0,0 +1,2 @@
+from .adaclip import AdaCLIP
+from .trainer import AdaCLIP_Trainer
\ No newline at end of file
diff --git a/method/__pycache__/__init__.cpython-39.pyc b/method/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa53ec8b4a1a2607a3b2c07956ea589bbc668e1e
Binary files /dev/null and b/method/__pycache__/__init__.cpython-39.pyc differ
diff --git a/method/__pycache__/adaclip.cpython-39.pyc b/method/__pycache__/adaclip.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41a4c03c170823cf9ccad9032d27811f9d4d626c
Binary files /dev/null and b/method/__pycache__/adaclip.cpython-39.pyc differ
diff --git a/method/__pycache__/clip_model.cpython-39.pyc b/method/__pycache__/clip_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..693c0004c2f95a435b52f398038b84c1b82c9326
Binary files /dev/null and b/method/__pycache__/clip_model.cpython-39.pyc differ
diff --git a/method/__pycache__/custom_clip.cpython-39.pyc b/method/__pycache__/custom_clip.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afa77d23eb1a9669441eebe3ad99cc0c122b9873
Binary files /dev/null and b/method/__pycache__/custom_clip.cpython-39.pyc differ
diff --git a/method/__pycache__/simple_tokenizer.cpython-39.pyc b/method/__pycache__/simple_tokenizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73226b9ef5d352a6ce72af6816d6e0c8515a6d9d
Binary files /dev/null and b/method/__pycache__/simple_tokenizer.cpython-39.pyc differ
diff --git a/method/__pycache__/tokenizer.cpython-39.pyc b/method/__pycache__/tokenizer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cad49a5cae4bc7d2a6096670b2f0d8350e5af57f
Binary files /dev/null and b/method/__pycache__/tokenizer.cpython-39.pyc differ
diff --git a/method/__pycache__/trainer.cpython-39.pyc b/method/__pycache__/trainer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de317be3f5386b412ca19af7f68c5196222a4860
Binary files /dev/null and b/method/__pycache__/trainer.cpython-39.pyc differ
diff --git a/method/__pycache__/transformer.cpython-39.pyc b/method/__pycache__/transformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..408c51ea67b5cad826a5f5a086731c276234e1c9
Binary files /dev/null and b/method/__pycache__/transformer.cpython-39.pyc differ
diff --git a/method/__pycache__/utils.cpython-39.pyc b/method/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c97e0a12775ae2d688ecab5baf43db849043003
Binary files /dev/null and b/method/__pycache__/utils.cpython-39.pyc differ
diff --git a/method/adaclip.py b/method/adaclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..2240e7887dcd9517ad9fc575ca24adde25b251d0
--- /dev/null
+++ b/method/adaclip.py
@@ -0,0 +1,583 @@
+from typing import Union, List, Optional
+import numpy as np
+import torch
+from pkg_resources import packaging
+from torch import nn
+from torch.nn import functional as F
+from .clip_model import CLIP
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from sklearn.cluster import KMeans
+
+class ProjectLayer(nn.Module):
+    def __init__(self, input_dim, output_dim, num_replicas, stack=False, is_array=True):
+        super(ProjectLayer, self).__init__()
+
+        self.head = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_replicas)])
+        self.num_replicas = num_replicas
+        self.stack = stack
+        self.is_array = is_array
+
+    def forward(self, tokens):
+        out_tokens = []
+        for i in range(self.num_replicas):
+            if self.is_array:
+                temp = self.head[i](tokens[i][:, 1:, :]) # for ViT, we exclude the class token and only extract patch tokens here.
+            else:
+                temp = self.head[i](tokens)
+
+            out_tokens.append(temp)
+
+        if self.stack:
+            out_tokens = torch.stack(out_tokens, dim=1)
+
+        return out_tokens
+
+class PromptLayer(nn.Module):
+    def __init__(self, channel, length, depth, is_text, prompting_type, enabled=True):
+        super(PromptLayer, self).__init__()
+
+        self.channel = channel
+        self.length = length
+        self.depth = depth
+        self.is_text = is_text
+        self.enabled = enabled
+
+        self.prompting_type = prompting_type
+
+        if self.enabled: # only when enabled, the parameters should be constructed
+            if 'S' in prompting_type: # static prompts
+                # learnable
+                self.static_prompts = nn.ParameterList(
+                    [nn.Parameter(torch.empty(self.length, self.channel))
+                     for _ in range(self.depth)])
+
+                for single_para in self.static_prompts:
+                    nn.init.normal_(single_para, std=0.02)
+
+            if 'D' in prompting_type: # dynamic prompts
+                self.dynamic_prompts = [0.]  # placeholder
+
+    def set_dynamic_prompts(self, dynamic_prompts):
+        self.dynamic_prompts = dynamic_prompts
+
+    def forward_text(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
+        if self.enabled:
+            length = self.length
+
+            # only prompt the first J layers
+            if indx < self.depth:
+                if 'S' in self.prompting_type and 'D' in self.prompting_type: # both
+                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
+                    textual_context = self.dynamic_prompts + static_prompts
+                elif 'S' in self.prompting_type:  # static
+                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
+                    textual_context = static_prompts
+                elif 'D' in self.prompting_type:  # dynamic
+                    textual_context = self.dynamic_prompts
+                else:
+                    print('You should choose at least one type of prompt when the prompting branch is enabled.')
+                    raise NotImplementedError
+
+            if indx == 0:  # for the first layer
+                x = x
+            else:
+                if indx < self.depth:  # replace with learnable tokens
+                    prefix = x[:1, :, :]
+                    suffix = x[1 + length:, :, :]
+                    textual_context = textual_context.permute(1, 0, 2).half()
+                    x = torch.cat([prefix, textual_context, suffix], dim=0)
+                else:  # keep the same
+                    x = x
+        else:
+            x = x
+
+        x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x= v_x, attn_mask=attn_mask)
+
+        return x, attn_tmp
+
+    def forward_visual(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
+        if self.enabled:
+            length = self.length
+
+            # only prompt the first J layers
+            if indx < self.depth:
+                if 'S' in self.prompting_type and 'D' in self.prompting_type: # both
+                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
+                    visual_context = self.dynamic_prompts + static_prompts
+                elif 'S' in self.prompting_type:  # static
+                    static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1)
+                    visual_context = static_prompts
+                elif 'D' in self.prompting_type:  # dynamic
+                    visual_context = self.dynamic_prompts
+                else:
+                    print('You should choose at least one type of prompt when the prompting branch is enabled.')
+                    raise NotImplementedError
+
+
+            if indx == 0:  # for the first layer
+                visual_context = visual_context.permute(1, 0, 2).half()
+                x = torch.cat([x, visual_context], dim=0)
+            else:
+                if indx < self.depth:  # replace with learnable tokens
+                    prefix = x[0:x.shape[0] - length, :, :]
+                    visual_context = visual_context.permute(1, 0, 2).half()
+                    x = torch.cat([prefix, visual_context], dim=0)
+                else:  # keep the same
+                    x = x
+        else:
+            x = x
+
+        x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x= v_x, attn_mask=attn_mask)
+
+        if self.enabled:
+            tokens = x[:x.shape[0] - length, :, :]
+        else:
+            tokens = x
+
+        return x, tokens, attn_tmp
+
+    def forward(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None):
+        if self.is_text:
+            return self.forward_text(resblock, indx, x, k_x, v_x, attn_mask)
+        else:
+            return self.forward_visual(resblock, indx, x, k_x, v_x, attn_mask)
+
+
+class TextEmbebddingLayer(nn.Module):
+    def __init__(self, fixed):
+        super(TextEmbebddingLayer, self).__init__()
+        self.tokenizer = _Tokenizer()
+        self.ensemble_text_features = {}
+        self.prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw',
+                              '{} without defect',
+                              '{} without damage']
+        self.prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
+        self.prompt_state = [self.prompt_normal, self.prompt_abnormal]
+        self.prompt_templates = ['a bad photo of a {}.',
+                                 'a low resolution photo of the {}.',
+                                 'a bad photo of the {}.',
+                                 'a cropped photo of the {}.',
+                                 ]
+        self.fixed = fixed
+
+    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[
+        torch.IntTensor, torch.LongTensor]:
+        if isinstance(texts, str):
+            texts = [texts]
+
+        sot_token = self.tokenizer.encoder["<|startoftext|>"]
+        eot_token = self.tokenizer.encoder["<|endoftext|>"]
+        all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
+        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+        else:
+            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate:
+                    tokens = tokens[:context_length]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
+    ## TODO: the text embedding layer is not compatible with multiple batches...
+    def forward(self, model, texts, device):
+        text_feature_list = []
+
+        for indx, text in enumerate(texts):
+
+            if self.fixed:
+                if self.ensemble_text_features.get(text) is None:
+                    text_features = self.encode_text(model, text, device)
+                    self.ensemble_text_features[text] = text_features
+                else:
+                    text_features = self.ensemble_text_features[text]
+            else:
+                text_features = self.encode_text(model, text, device)
+                self.ensemble_text_features[text] = text_features
+
+            text_feature_list.append(text_features)
+
+        text_features = torch.stack(text_feature_list, dim=0)
+        text_features = F.normalize(text_features, dim=1)
+
+        return text_features
+
+    def encode_text(self, model, text, device):
+        text_features = []
+        for i in range(len(self.prompt_state)):
+            text = text.replace('-', ' ')
+            prompted_state = [state.format(text) for state in self.prompt_state[i]]
+            prompted_sentence = []
+            for s in prompted_state:
+                for template in self.prompt_templates:
+                    prompted_sentence.append(template.format(s))
+            prompted_sentence = self.tokenize(prompted_sentence, context_length=77).to(device)
+
+            class_embeddings = model.encode_text(prompted_sentence)
+
+            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+            class_embedding = class_embeddings.mean(dim=0)
+            class_embedding /= class_embedding.norm()
+            text_features.append(class_embedding)
+
+        text_features = torch.stack(text_features, dim=1)
+
+        return text_features
+
+
+class HybridSemanticFusion(nn.Module):
+    def __init__(self, k_clusters):
+        super(HybridSemanticFusion, self).__init__()
+        self.k_clusters = k_clusters
+        self.n_aggregate_patch_tokens = k_clusters * 5
+        self.cluster_performer = KMeans(n_clusters=self.k_clusters, n_init="auto")
+
+    # @torch.no_grad()
+    def forward(self, patch_tokens: list, anomaly_maps: list):
+        anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
+        anomaly_map = torch.softmax(anomaly_map, dim=2)[:, :, 1] # B, L
+
+        # extract most abnormal feats
+        selected_abnormal_tokens = []
+        k = min(anomaly_map.shape[1], self.n_aggregate_patch_tokens)
+        top_k_indices = torch.topk(anomaly_map, k=k, dim=1).indices
+        for layer in range(len(patch_tokens)):
+            selected_tokens = patch_tokens[layer]. \
+                gather(dim=1, index=top_k_indices.unsqueeze(-1).
+                       expand(-1, -1, patch_tokens[layer].shape[-1]))
+            selected_abnormal_tokens.append(selected_tokens)
+
+        # Use k-means to extract the centroids of the selected abnormal tokens
+        # Concatenate the selected tokens from all layers along the channel dimension
+        stacked_data = torch.cat(selected_abnormal_tokens, dim=2)
+
+        batch_cluster_centers = []
+        # Perform K-Means clustering
+        for b in range(stacked_data.shape[0]):
+            cluster_labels = self.cluster_performer.fit_predict(stacked_data[b, :, :].detach().cpu().numpy())
+
+            # Initialize a list to store the cluster centers
+            cluster_centers = []
+
+            # Extract cluster centers for each cluster
+            for cluster_id in range(self.k_clusters):
+                collected_cluster_data = []
+                for abnormal_tokens in selected_abnormal_tokens:
+                    cluster_data = abnormal_tokens[b, :, :][cluster_labels == cluster_id]
+                    collected_cluster_data.append(cluster_data)
+                collected_cluster_data = torch.cat(collected_cluster_data, dim=0)
+                cluster_center = torch.mean(collected_cluster_data, dim=0, keepdim=True)
+                cluster_centers.append(cluster_center)
+
+            # Normalize the cluster centers
+            cluster_centers = torch.cat(cluster_centers, dim=0)
+            cluster_centers = torch.mean(cluster_centers, dim=0)
+            batch_cluster_centers.append(cluster_centers)
+
+        batch_cluster_centers = torch.stack(batch_cluster_centers, dim=0)
+        batch_cluster_centers = F.normalize(batch_cluster_centers, dim=1)
+
+        return batch_cluster_centers
+
+        # # preprocess
+        # # compute the anomaly map
+        # anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
+        # anomaly_map = torch.softmax(anomaly_map, dim=2)[:, :, 1] # B, L
+        #
+        # # compute the average multi-hierarchy patch embeddings
+        # avg_patch_tokens = torch.mean(torch.stack(patch_tokens, dim=0), dim=0) # B, L, C
+        #
+        # # Initialize a list to store the centroids of clusters with the largest anomaly scores
+        # cluster_centroids = []
+        #
+        # # loop across the batch size
+        # for b in range(avg_patch_tokens.shape[0]):
+        #     # step1: group features into clusters
+        #     cluster_labels = self.cluster_performer.fit_predict(avg_patch_tokens[b, :, :].detach().cpu().numpy())
+        #
+        #     # step2: compute the anomaly scores for individual clusters via the anomaly map
+        #     # Convert cluster labels back to tensor
+        #     cluster_labels = torch.tensor(cluster_labels).to(avg_patch_tokens.device)
+        #     cluster_anomaly_scores = {}
+        #     for label in torch.unique(cluster_labels):
+        #         cluster_indices = torch.where(cluster_labels == label)[0]
+        #         cluster_anomaly_scores[label.item()] = anomaly_map[b, cluster_indices].mean().item()
+        #
+        #     # step3: select the cluster with the largest anomaly score and then compute its centroid by averaging the
+        #     # corresponding avg_patch_tokens
+        #     # Find the cluster with the largest anomaly score
+        #     largest_anomaly_cluster = max(cluster_anomaly_scores, key=cluster_anomaly_scores.get)
+        #
+        #     # Get the indices of the tokens belonging to the largest anomaly cluster
+        #     largest_anomaly_cluster_indices = torch.where(cluster_labels == largest_anomaly_cluster)[0]
+        #
+        #     # Compute the centroid of the largest anomaly cluster by averaging the corresponding avg_patch_tokens
+        #     centroid = avg_patch_tokens[b, largest_anomaly_cluster_indices, :].mean(dim=0)
+        #
+        #     # Append the centroid to the list of cluster centroids
+        #     cluster_centroids.append(centroid)
+        #
+        # # Convert the list of centroids to a tensor
+        # cluster_centroids = torch.stack(cluster_centroids, dim=0)
+        # cluster_centroids = F.normalize(cluster_centroids, dim=1)
+
+        # return cluster_centroids
+
+class AdaCLIP(nn.Module):
+    def __init__(self, freeze_clip: CLIP, text_channel: int, visual_channel: int,
+                 prompting_length: int, prompting_depth: int, prompting_branch: str, prompting_type: str,
+                 use_hsf: bool, k_clusters: int,
+                 output_layers: list, device: str, image_size: int):
+        super(AdaCLIP, self).__init__()
+        self.freeze_clip = freeze_clip
+
+        self.visual = self.freeze_clip.visual
+        self.transformer = self.freeze_clip.transformer
+        self.token_embedding = self.freeze_clip.token_embedding
+        self.positional_embedding = self.freeze_clip.positional_embedding
+        self.ln_final = self.freeze_clip.ln_final
+        self.text_projection = self.freeze_clip.text_projection
+        self.attn_mask = self.freeze_clip.attn_mask
+
+        self.output_layers = output_layers
+
+        self.prompting_branch = prompting_branch
+        self.prompting_type = prompting_type
+        self.prompting_depth = prompting_depth
+        self.prompting_length = prompting_length
+        self.use_hsf = use_hsf
+        self.k_clusters = k_clusters
+
+        if 'L' in self.prompting_branch:
+            self.enable_text_prompt = True
+        else:
+            self.enable_text_prompt = False
+
+        if 'V' in self.prompting_branch:
+            self.enable_visual_prompt = True
+        else:
+            self.enable_visual_prompt = False
+
+        self.text_embedding_layer = TextEmbebddingLayer(fixed=(not self.enable_text_prompt))
+        self.text_prompter = PromptLayer(text_channel, prompting_length, prompting_depth, is_text=True,
+                                         prompting_type=prompting_type,
+                                         enabled=self.enable_text_prompt)
+        self.visual_prompter = PromptLayer(visual_channel, prompting_length, prompting_depth, is_text=False,
+                                           prompting_type=prompting_type,
+                                           enabled=self.enable_visual_prompt)
+
+        self.patch_token_layer = ProjectLayer(
+            visual_channel,
+            text_channel,
+            len(output_layers), stack=False, is_array=True
+        )
+
+        self.cls_token_layer = ProjectLayer(
+            text_channel,
+            text_channel,
+            1, stack=False, is_array=False
+        )
+
+        if 'D' in self.prompting_type: # dynamic prompts
+            self.dynamic_visual_prompt_generator = ProjectLayer(text_channel,
+                                                                visual_channel,
+                                                                prompting_length,
+                                                                stack=True,
+                                                                is_array=False)
+            self.dynamic_text_prompt_generator = ProjectLayer(text_channel,
+                                                              text_channel,
+                                                              prompting_length,
+                                                              stack=True,
+                                                              is_array=False)
+
+        if self.use_hsf:
+            self.HSF = HybridSemanticFusion(k_clusters)
+
+        self.image_size = image_size
+        self.device = device
+
+    def generate_and_set_dynamic_prompts(self, image):
+        with torch.no_grad():
+            # extract image features
+            image_features, _ = self.visual.forward(image, self.output_layers)
+
+        dynamic_visual_prompts = self.dynamic_visual_prompt_generator(image_features)
+        dynamic_text_prompts = self.dynamic_text_prompt_generator(image_features)
+
+        self.visual_prompter.set_dynamic_prompts(dynamic_visual_prompts)
+        self.text_prompter.set_dynamic_prompts(dynamic_text_prompts)
+
+
+    def encode_image(self, image):
+
+        x = image
+        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+        if self.visual.input_patchnorm:
+            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+            x = x.reshape(x.shape[0], x.shape[1],
+                          self.visual.grid_size[0],
+                          self.visual.patch_size[0],
+                          self.visual.grid_size[1],
+                          self.visual.patch_size[1])
+            x = x.permute(0, 2, 4, 1, 3, 5)
+            x = x.reshape(x.shape[0], self.visual.grid_size[0] * self.visual.grid_size[1], -1)
+            x = self.visual.patchnorm_pre_ln(x)
+            x = self.visual.conv1(x)
+        else:
+            x = self.visual.conv1(x)  # shape = [*, width, grid, grid]
+            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+        # class embeddings and positional embeddings
+        x = torch.cat(
+            [self.visual.class_embedding.to(x.dtype) +
+             torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+
+        x = x + self.visual.positional_embedding.to(x.dtype)
+
+        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+        x = self.visual.patch_dropout(x)
+        x = self.visual.ln_pre(x)
+
+        patch_embedding = x
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+
+        patch_tokens = []
+
+        for indx, r in enumerate(self.visual.transformer.resblocks):
+            x, tokens, attn_tmp = self.visual_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=None)
+
+            if (indx + 1) in self.output_layers:
+                patch_tokens.append(tokens)
+
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))]  # LND -> NLD
+
+        if self.visual.attn_pool is not None:
+            x = self.visual.attn_pool(x)
+            x = self.visual.ln_post(x)
+            pooled, tokens = self.visual._global_pool(x)
+        else:
+            pooled, tokens = self.visual._global_pool(x)
+            pooled = self.visual.ln_post(pooled)
+
+        if self.visual.proj is not None:
+            pooled = pooled @ self.visual.proj
+
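+        # pooled: [B, embed_dim] projected CLS feature; patch_tokens: list of [B, 1 + grid**2, width]
+        # hidden states taken at self.output_layers; patch_embedding: pre-transformer embedding [B, 1 + grid**2, width]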
+        return pooled, patch_tokens, patch_embedding
+
+    def proj_visual_tokens(self, image_features, patch_tokens):
+
+        # for patch tokens
+        proj_patch_tokens = self.patch_token_layer(patch_tokens)
+        for layer in range(len(proj_patch_tokens)):
+            proj_patch_tokens[layer] /= proj_patch_tokens[layer].norm(dim=-1, keepdim=True)
+
+        # for cls tokens
+        proj_cls_tokens = self.cls_token_layer(image_features)[0]
+        proj_cls_tokens /= proj_cls_tokens.norm(dim=-1, keepdim=True)
+
+        return proj_cls_tokens, proj_patch_tokens
+
+    def encode_text(self, text):
+        cast_dtype = self.transformer.get_cast_dtype()
+
+        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding.to(cast_dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+
+        for indx, r in enumerate(self.transformer.resblocks):
+            # add prompt here
+            x, attn_tmp = self.text_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=self.attn_mask)
+
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
+
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+        return x
+
+    def visual_text_similarity(self, image_feature, patch_token, text_feature, aggregation):
+        anomaly_maps = []
+
+        for layer in range(len(patch_token)):
+            anomaly_map = (100.0 * patch_token[layer] @ text_feature)
+            anomaly_maps.append(anomaly_map)
+
+        if self.use_hsf:
+            alpha = 0.2
+            clustered_feature = self.HSF.forward(patch_token, anomaly_maps)
+            # aggregate the class token and the clustered features for more comprehensive information
+            cur_image_feature = alpha * clustered_feature + (1 - alpha) * image_feature
+            cur_image_feature = F.normalize(cur_image_feature, dim=1)
+        else:
+            cur_image_feature = image_feature
+
+        anomaly_score = (100.0 * cur_image_feature.unsqueeze(1) @ text_feature)
+        anomaly_score = anomaly_score.squeeze(1)
+        anomaly_score = torch.softmax(anomaly_score, dim=1)
+
+        # NOTE: this bilinear interpolation is not deterministic on GPU and may occasionally lead to unstable ZSAD performance.
+        for i in range(len(anomaly_maps)):
+            B, L, C = anomaly_maps[i].shape
+            H = int(np.sqrt(L))
+            anomaly_maps[i] = anomaly_maps[i].permute(0, 2, 1).view(B, 2, H, H)
+            anomaly_maps[i] = F.interpolate(anomaly_maps[i], size=self.image_size, mode='bilinear', align_corners=True)
+
+        if aggregation: # at test time, we first aggregate the logits from all hierarchies and then apply softmax normalization
+            anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1)
+            anomaly_map = torch.softmax(anomaly_map, dim=1)
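+            # fuse the two channels: average the 'abnormal' probability with (1 - 'normal' probability)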
+            anomaly_map = (anomaly_map[:, 1:, :, :] + 1 - anomaly_map[:, 0:1, :, :]) / 2.0
+            anomaly_score = anomaly_score[:, 1]
+            return anomaly_map, anomaly_score
+        else: # otherwise, we apply softmax normalization to each hierarchy individually
+            for i in range(len(anomaly_maps)):
+                anomaly_maps[i] = torch.softmax(anomaly_maps[i], dim=1)
+            return anomaly_maps, anomaly_score
+
+    def extract_feat(self, image, cls_name):
+        if 'D' in self.prompting_type:
+            self.generate_and_set_dynamic_prompts(image)  # generate and set dynamic prompts for the corresponding prompters
+
+        if self.enable_visual_prompt:
+            image_features, patch_tokens, _ = self.encode_image(image)
+        else:
+            with torch.no_grad():
+                image_features, patch_tokens, _ = self.encode_image(image)
+
+        if self.enable_text_prompt:
+            text_features = self.text_embedding_layer(self, cls_name, self.device)
+        else:
+            with torch.no_grad():
+                text_features = self.text_embedding_layer(self, cls_name, self.device)
+
+        proj_cls_tokens, proj_patch_tokens = self.proj_visual_tokens(image_features, patch_tokens)
+
+        return proj_cls_tokens, proj_patch_tokens, text_features
+
+    @torch.cuda.amp.autocast()
+    def forward(self, image, cls_name, aggregation=True):
+        # extract features for images and texts
+        image_features, patch_tokens, text_features = self.extract_feat(image, cls_name)
+        anomaly_map, anomaly_score = self.visual_text_similarity(image_features, patch_tokens, text_features, aggregation)
+
+        if aggregation:
+            # anomaly_map is a single fused tensor; drop the channel dimension
+            anomaly_map = anomaly_map.squeeze(1)
+            return anomaly_map, anomaly_score
+        else:
+            # anomaly_map is a list of per-hierarchy maps
+            return anomaly_map, anomaly_score
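+    # Minimal usage sketch (illustrative only; the argument values below are assumptions,
+    # not the repository defaults):
+    #   model = AdaCLIP(freeze_clip=clip_model, text_channel=768, visual_channel=1024,
+    #                   prompting_length=2, prompting_depth=4, prompting_branch='VL',
+    #                   prompting_type='SD', use_hsf=True, k_clusters=20,
+    #                   output_layers=[6, 12, 18, 24], device='cuda', image_size=518)
+    #   anomaly_map, anomaly_score = model(image, cls_name, aggregation=True)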
+
diff --git a/method/bpe_simple_vocab_16e6.txt.gz b/method/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/method/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/method/clip_model.py b/method/clip_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..803ce8ae95819d0b2cf84aa6d20152c4470fac97
--- /dev/null
+++ b/method/clip_model.py
@@ -0,0 +1,412 @@
+""" CLIP Model
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+from dataclasses import dataclass
+import logging
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
+from .utils import to_2tuple
+
+@dataclass
+class CLIPVisionCfg:
+    layers: Union[Tuple[int, int, int, int], int] = 12
+    width: int = 768
+    head_width: int = 64
+    mlp_ratio: float = 4.0
+    patch_size: int = 16
+    image_size: Union[Tuple[int, int], int] = 224
+    ls_init_value: Optional[float] = None  # layer scale initial value
+    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+    input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
+    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+    attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
+    n_queries: int = 256 # n_queries for attentional pooler
+    attn_pooler_heads: int = 8 # n heads for attentional_pooling
+    timm_model_name: str = None  # a valid model name overrides layers, width, patch_size
+    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
+    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
+    timm_proj_bias: bool = False  # enable bias final projection
+    timm_drop: float = 0.  # head dropout
+    timm_drop_path: Optional[float] = None  # backbone stochastic depth
+    output_tokens: bool = False
+
+
+@dataclass
+class CLIPTextCfg:
+    context_length: int = 77
+    vocab_size: int = 49408
+    width: int = 512
+    heads: int = 8
+    layers: int = 12
+    ls_init_value: Optional[float] = None  # layer scale initial value
+    hf_model_name: str = None
+    hf_tokenizer_name: str = None
+    hf_model_pretrained: bool = True
+    proj: str = 'mlp'
+    pooler_type: str = 'mean_pooler'
+    embed_cls: bool = False
+    pad_id: int = 0
+    output_tokens: bool = False
+
+
+def get_cast_dtype(precision: str):
+    cast_dtype = None
+    if precision == 'bf16':
+        cast_dtype = torch.bfloat16
+    elif precision == 'fp16':
+        cast_dtype = torch.float16
+    return cast_dtype
+
+
+def _build_vision_tower(
+        embed_dim: int,
+        vision_cfg: CLIPVisionCfg,
+        quick_gelu: bool = False,
+        cast_dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(vision_cfg, dict):
+        vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+    # memory efficient in recent PyTorch releases (>= 1.10).
+    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+    act_layer = QuickGELU if quick_gelu else nn.GELU
+
+    vision_heads = vision_cfg.width // vision_cfg.head_width
+    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+    visual = VisionTransformer(
+        image_size=vision_cfg.image_size,
+        patch_size=vision_cfg.patch_size,
+        width=vision_cfg.width,
+        layers=vision_cfg.layers,
+        heads=vision_heads,
+        mlp_ratio=vision_cfg.mlp_ratio,
+        ls_init_value=vision_cfg.ls_init_value,
+        patch_dropout=vision_cfg.patch_dropout,
+        input_patchnorm=vision_cfg.input_patchnorm,
+        global_average_pool=vision_cfg.global_average_pool,
+        attentional_pool=vision_cfg.attentional_pool,
+        n_queries=vision_cfg.n_queries,
+        attn_pooler_heads=vision_cfg.attn_pooler_heads,
+        output_tokens=vision_cfg.output_tokens,
+        output_dim=embed_dim,
+        act_layer=act_layer,
+        norm_layer=norm_layer
+    )
+
+    return visual
+
+
+def _build_text_tower(
+        embed_dim: int,
+        text_cfg: CLIPTextCfg,
+        quick_gelu: bool = False,
+        cast_dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(text_cfg, dict):
+        text_cfg = CLIPTextCfg(**text_cfg)
+
+    act_layer = QuickGELU if quick_gelu else nn.GELU
+    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+
+    text = TextTransformer(
+        context_length=text_cfg.context_length,
+        vocab_size=text_cfg.vocab_size,
+        width=text_cfg.width,
+        heads=text_cfg.heads,
+        layers=text_cfg.layers,
+        ls_init_value=text_cfg.ls_init_value,
+        output_dim=embed_dim,
+        embed_cls=text_cfg.embed_cls,
+        output_tokens=text_cfg.output_tokens,
+        pad_id=text_cfg.pad_id,
+        act_layer=act_layer,
+        norm_layer=norm_layer
+    )
+
+    return text
+
+
+class CLIP(nn.Module):
+    output_dict: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            embed_dim: int,
+            vision_cfg: CLIPVisionCfg,
+            text_cfg: CLIPTextCfg,
+            quick_gelu: bool = False,
+            cast_dtype: Optional[torch.dtype] = None,
+            output_dict: bool = False,
+    ):
+        super().__init__()
+        self.output_dict = output_dict
+        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+        self.transformer = text.transformer
+        self.vocab_size = text.vocab_size
+        self.token_embedding = text.token_embedding
+        self.positional_embedding = text.positional_embedding
+        self.ln_final = text.ln_final
+        self.text_projection = text.text_projection
+        self.register_buffer('attn_mask', text.attn_mask, persistent=False)
+
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+
+    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.visual.set_grad_checkpointing(enable)
+        self.transformer.grad_checkpointing = enable
+
+    def encode_image(self, image, out_layers):
+
+        x = image
+        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+        if self.visual.input_patchnorm:
+            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+            x = x.reshape(x.shape[0], x.shape[1],
+                          self.visual.grid_size[0],
+                          self.visual.patch_size[0],
+                          self.visual.grid_size[1],
+                          self.visual.patch_size[1])
+            x = x.permute(0, 2, 4, 1, 3, 5)
+            x = x.reshape(x.shape[0], self.visual.grid_size[0] * self.visual.grid_size[1], -1)
+            x = self.visual.patchnorm_pre_ln(x)
+            x = self.visual.conv1(x)
+        else:
+            x = self.visual.conv1(x)  # shape = [*, width, grid, grid]
+            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+
+        # class embeddings and positional embeddings
+        x = torch.cat(
+            [self.visual.class_embedding.to(x.dtype) +
+             torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+
+        x = x + self.visual.positional_embedding.to(x.dtype)
+
+        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+        x = self.visual.patch_dropout(x)
+        x = self.visual.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+
+        patch_tokens = []
+
+        idx = 0
+        for r in self.visual.transformer.resblocks:
+            idx += 1
+            # add prompt here
+            x, attn_tmp = r(x, attn_mask=None)
+            if idx in out_layers:
+                patch_tokens.append(x)
+
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))]  # LND -> NLD
+
+        if self.visual.attn_pool is not None:
+            x = self.visual.attn_pool(x)
+            x = self.visual.ln_post(x)
+            pooled, tokens = self.visual._global_pool(x)
+        else:
+            pooled, tokens = self.visual._global_pool(x)
+            pooled = self.visual.ln_post(pooled)
+
+        if self.visual.proj is not None:
+            pooled = pooled @ self.visual.proj
+
+        return pooled, patch_tokens
+
+    def encode_text(self, text):
+        cast_dtype = self.transformer.get_cast_dtype()
+
+        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding.to(cast_dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+
+        for r in self.transformer.resblocks:
+            x, attn_tmp = r(x, attn_mask=self.attn_mask)
+
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
+
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+        return x
+
+
+
+def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
+    """Convert applicable model parameters to low-precision (bf16 or fp16)"""
+
+    def _convert_weights(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.to(dtype)
+            if l.bias is not None:
+                l.bias.data = l.bias.data.to(dtype)
+
+        if isinstance(l, (nn.MultiheadAttention, Attention)):
+            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.to(dtype)
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.to(dtype)
+
+    model.apply(_convert_weights)
+
+
+convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
+
+
+# used to maintain checkpoint compatibility
+def convert_to_custom_text_state_dict(state_dict: dict):
+    if 'text_projection' in state_dict:
+        # old format state_dict, move text tower -> .text
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if any(k.startswith(p) for p in (
+                'text_projection',
+                'positional_embedding',
+                'token_embedding',
+                'transformer',
+                'ln_final',
+            )):
+                k = 'text.' + k
+            new_state_dict[k] = v
+        return new_state_dict
+    return state_dict
+
+
+def build_model_from_openai_state_dict(
+        state_dict: dict,
+        quick_gelu=True,
+        cast_dtype=torch.float16,
+):
+    vit = "visual.proj" in state_dict
+
+    if vit:
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len(
+            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+        image_size = vision_patch_size * grid_size
+    else:
+        counts: list = [
+            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+        vision_layers = tuple(counts)
+        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+        vision_patch_size = None
+        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+        image_size = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+    vision_cfg = CLIPVisionCfg(
+        layers=vision_layers,
+        width=vision_width,
+        patch_size=vision_patch_size,
+        image_size=image_size,
+    )
+    text_cfg = CLIPTextCfg(
+        context_length=context_length,
+        vocab_size=vocab_size,
+        width=transformer_width,
+        heads=transformer_heads,
+        layers=transformer_layers,
+    )
+    model = CLIP(
+        embed_dim,
+        vision_cfg=vision_cfg,
+        text_cfg=text_cfg,
+        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
+        cast_dtype=cast_dtype,
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        state_dict.pop(key, None)
+
+    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
+    model.load_state_dict(state_dict)
+    return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device('cpu')):
+    model.eval()
+    image_size = model.visual.image_size
+    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
+    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
+    model = torch.jit.trace_module(
+        model,
+        inputs=dict(
+            forward=(example_images, example_text),
+            encode_text=(example_text,),
+            encode_image=(example_images,)
+        ))
+    model.visual.image_size = image_size
+    return model
+
+
+def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic'):
+    # Rescale the grid of position embeddings when loading from state_dict
+    old_pos_embed = state_dict.get('visual.positional_embedding', None)
+    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
+        return
+    grid_size = to_2tuple(model.visual.grid_size)
+    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
+    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+    if new_seq_len == old_pos_embed.shape[0]:
+        return
+
+    if extra_tokens:
+        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+    else:
+        pos_emb_tok, pos_emb_img = None, old_pos_embed
+    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
+    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+    pos_emb_img = F.interpolate(
+        pos_emb_img,
+        size=grid_size,
+        mode=interpolation,
+        align_corners=False,
+    )
+    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+    if pos_emb_tok is not None:
+        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+    else:
+        new_pos_embed = pos_emb_img
+    state_dict['visual.positional_embedding'] = new_pos_embed
diff --git a/method/custom_clip.py b/method/custom_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..627c00c2dcdb9744c95cf90082cb74599bd9cdfb
--- /dev/null
+++ b/method/custom_clip.py
@@ -0,0 +1,716 @@
+# This file is largely borrowed from open clip
+import hashlib
+import json
+import logging
+import os
+import re
+import urllib
+import warnings
+from copy import deepcopy
+from dataclasses import dataclass, asdict
+from functools import partial
+from pathlib import Path
+from typing import Any, Optional, Tuple
+from typing import Dict, Union
+from typing import List
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
+    CenterCrop
+from tqdm import tqdm
+from .clip_model import CLIP, convert_to_custom_text_state_dict, \
+    resize_pos_embed
+from .clip_model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
+from .tokenizer import HFTokenizer, tokenize
+
+__version__ = '2.16.0'
+
+try:
+    from huggingface_hub import hf_hub_download
+
+    hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
+    _has_hf_hub = True
+except ImportError:
+    hf_hub_download = None
+    _has_hf_hub = False
+
+
+def _pcfg(url='', hf_hub='', mean=None, std=None):
+    return dict(
+        url=url,
+        hf_hub=hf_hub,
+        mean=mean,
+        std=std,
+    )
+
+
+_VITB32 = dict(
+    openai=_pcfg(
+        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+    laion400m_e31=_pcfg(
+        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+    laion400m_e32=_pcfg(
+        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+    laion2b_e16=_pcfg(
+        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
+    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
+)
+
+
+_VITB16 = dict(
+    openai=_pcfg(
+        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
+    laion400m_e31=_pcfg(
+        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
+    laion400m_e32=_pcfg(
+        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
+    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
+)
+
+_VITL14 = dict(
+    openai=_pcfg(
+        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
+    laion400m_e31=_pcfg(
+        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
+    laion400m_e32=_pcfg(
+        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
+    laion2b_s32b_b82k=_pcfg(
+        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+)
+
+_VITL14_336 = dict(
+    openai=_pcfg(
+        "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
+)
+
+_VITH14 = dict(
+    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
+)
+
+_VITg14 = dict(
+    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
+    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
+)
+
+_VITbigG14 = dict(
+    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
+)
+
+
+
+_PRETRAINED = {
+    "ViT-B-32": _VITB32,
+    "ViT-B-16": _VITB16,
+    "ViT-L-14": _VITL14,
+    "ViT-L-14-336": _VITL14_336,
+    "ViT-H-14": _VITH14,
+    "ViT-g-14": _VITg14,
+    "ViT-bigG-14": _VITbigG14,
+}
+
+
+def _clean_tag(tag: str):
+    # normalize pretrained tags
+    return tag.lower().replace('-', '_')
+
+
+def list_pretrained(as_str: bool = False):
+    """ returns list of pretrained models
+    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
+    """
+    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
+
+
+def list_pretrained_models_by_tag(tag: str):
+    """ return all models having the specified pretrain tag """
+    models = []
+    tag = _clean_tag(tag)
+    for k in _PRETRAINED.keys():
+        if tag in _PRETRAINED[k]:
+            models.append(k)
+    return models
+
+
+def list_pretrained_tags_by_model(model: str):
+    """ return all pretrain tags for the specified model architecture """
+    tags = []
+    if model in _PRETRAINED:
+        tags.extend(_PRETRAINED[model].keys())
+    return tags
+
+
+def is_pretrained_cfg(model: str, tag: str):
+    if model not in _PRETRAINED:
+        return False
+    return _clean_tag(tag) in _PRETRAINED[model]
+
+
+def get_pretrained_cfg(model: str, tag: str):
+    if model not in _PRETRAINED:
+        return {}
+    model_pretrained = _PRETRAINED[model]
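+    # NOTE: the requested tag is overridden below, preferring the 'openai' weights when available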
+    if 'openai' in model_pretrained.keys():
+        tag = 'openai'
+    else:
+        tag = list(model_pretrained.keys())[0]
+    print('*' * 50)
+    print(f'Use pretrained model from {tag}...')
+    print('*' * 50)
+    return model_pretrained.get(_clean_tag(tag), {})
+
+
+def get_pretrained_url(model: str, tag: str):
+    cfg = get_pretrained_cfg(model, _clean_tag(tag))
+    return cfg.get('url', '')
+
+
+def download_pretrained_from_url(
+        url: str,
+        cache_dir: Union[str, None] = None,
+):
+    if not cache_dir:
+        cache_dir = os.path.expanduser("~/.cache/clip")
+    os.makedirs(cache_dir, exist_ok=True)
+    filename = os.path.basename(url)
+
+    if 'openaipublic' in url:
+        expected_sha256 = url.split("/")[-2]
+    elif 'mlfoundations' in url:
+        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
+    else:
+        expected_sha256 = ''
+
+    download_target = os.path.join(cache_dir, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        if expected_sha256:
+            if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+                return download_target
+            else:
+                warnings.warn(
+                    f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+        else:
+            return download_target
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(
+            expected_sha256):
+        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+    return download_target
+
+
+def has_hf_hub(necessary=False):
+    if not _has_hf_hub and necessary:
+        # if no HF Hub module installed, and it is necessary to continue, raise error
+        raise RuntimeError(
+            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+    return _has_hf_hub
+
+
+def download_pretrained_from_hf(
+        model_id: str,
+        filename: str = 'open_clip_pytorch_model.bin',
+        revision=None,
+        cache_dir: Union[str, None] = None,
+):
+    has_hf_hub(True)
+    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
+    return cached_file
+
+
+def download_pretrained(
+        cfg: Dict,
+        force_hf_hub: bool = False,
+        cache_dir: Union[str, None] = None,
+):
+    target = ''
+    if not cfg:
+        return target
+
+    download_url = cfg.get('url', '')
+    download_hf_hub = cfg.get('hf_hub', '')
+    if download_hf_hub and force_hf_hub:
+        # use HF hub even if url exists
+        download_url = ''
+
+    if download_url:
+        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
+    elif download_hf_hub:
+        has_hf_hub(True)
+        # we assume the hf_hub entries in pretrained config combine model_id + filename in
+        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
+        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
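+        # e.g. 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K/' -> model id only, default filename is used
+        #      'org/model_name/open_clip_pytorch_model.bin' -> model id plus explicit filename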
+        model_id, filename = os.path.split(download_hf_hub)
+        if filename:
+            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
+        else:
+            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+
+    return target
+
+
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+@dataclass
+class AugmentationCfg:
+    scale: Tuple[float, float] = (0.9, 1.0)
+    ratio: Optional[Tuple[float, float]] = None
+    color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
+    interpolation: Optional[str] = None
+    re_prob: Optional[float] = None
+    re_count: Optional[int] = None
+    use_timm: bool = False
+
+
+class ResizeMaxSize(nn.Module):
+
+    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
+        super().__init__()
+        if not isinstance(max_size, int):
+            raise TypeError(f"Size should be int. Got {type(max_size)}")
+        self.max_size = max_size
+        self.interpolation = interpolation
+        self.fn = min if fn == 'min' else max
+        self.fill = fill
+
+    def forward(self, img):
+        if isinstance(img, torch.Tensor):
+            height, width = img.shape[:2]
+        else:
+            width, height = img.size
+        scale = self.max_size / float(max(height, width))
+        if scale != 1.0:
+            new_size = tuple(round(dim * scale) for dim in (height, width))
+            img = F.resize(img, new_size, self.interpolation)
+            pad_h = self.max_size - new_size[0]
+            pad_w = self.max_size - new_size[1]
+            img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
+        return img
+
+
+def _convert_to_rgb(image):
+    return image.convert('RGB')
+
+
+def image_transform(
+        image_size: int,
+        is_train: bool,
+        mean: Optional[Tuple[float, ...]] = None,
+        std: Optional[Tuple[float, ...]] = None,
+        resize_longest_max: bool = False,
+        fill_color: int = 0,
+        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+):
+    mean = mean or OPENAI_DATASET_MEAN
+    if not isinstance(mean, (list, tuple)):
+        mean = (mean,) * 3
+
+    std = std or OPENAI_DATASET_STD
+    if not isinstance(std, (list, tuple)):
+        std = (std,) * 3
+
+    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+        image_size = image_size[0]
+
+    if isinstance(aug_cfg, dict):
+        aug_cfg = AugmentationCfg(**aug_cfg)
+    else:
+        aug_cfg = aug_cfg or AugmentationCfg()
+    normalize = Normalize(mean=mean, std=std)
+    if is_train:
+        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
+        use_timm = aug_cfg_dict.pop('use_timm', False)
+        if use_timm:
+            from timm.data import create_transform  # timm can still be optional
+            if isinstance(image_size, (tuple, list)):
+                assert len(image_size) >= 2
+                input_size = (3,) + image_size[-2:]
+            else:
+                input_size = (3, image_size, image_size)
+            # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
+            aug_cfg_dict.setdefault('interpolation', 'random')
+            aug_cfg_dict.setdefault('color_jitter', None)  # disable by default
+            train_transform = create_transform(
+                input_size=input_size,
+                is_training=True,
+                hflip=0.,
+                mean=mean,
+                std=std,
+                re_mode='pixel',
+                **aug_cfg_dict,
+            )
+        else:
+            train_transform = Compose([
+                RandomResizedCrop(
+                    image_size,
+                    scale=aug_cfg_dict.pop('scale'),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                _convert_to_rgb,
+                ToTensor(),
+                normalize,
+            ])
+            if aug_cfg_dict:
+                warnings.warn(
+                    f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
+        return train_transform
+    else:
+        if resize_longest_max:
+            transforms = [
+                ResizeMaxSize(image_size, fill=fill_color)
+            ]
+        else:
+            transforms = [
+                Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+                CenterCrop(image_size),
+            ]
+        transforms.extend([
+            _convert_to_rgb,
+            ToTensor(),
+            normalize,
+        ])
+        return Compose(transforms)
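+# Minimal usage sketch (illustrative): build the default evaluation preprocessing
+#   preprocess_val = image_transform(image_size=224, is_train=False)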
+
+
+def list_openai_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list_pretrained_models_by_tag('openai')
+
+
+def load_openai_model(
+        name: str,
+        precision: Optional[str] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        jit: bool = True,
+        cache_dir: Optional[str] = None,
+):
+    """Load a CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+    precision: str
+        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
+    device : Union[str, torch.device]
+        The device to put the loaded model
+    jit : bool
+        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+    cache_dir : Optional[str]
+        The directory to cache the downloaded model weights
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    if precision is None:
+        precision = 'fp32' if device == 'cpu' else 'fp16'
+
+    cfg = get_pretrained_cfg(name, 'openai')
+    if cfg:
+        model_path = download_pretrained(cfg, cache_dir=cache_dir)
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {list_pretrained()}")
+
+    try:
+        # loading JIT archive
+        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+        state_dict = None
+    except RuntimeError:
+        # loading saved state dict
+        if jit:
+            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+            jit = False
+        state_dict = torch.load(model_path, map_location="cpu")
+
+    # JIT -> Just In Time
+    if not jit:
+        # Build a non-jit model from the OpenAI jitted model state dict
+        cast_dtype = get_cast_dtype(precision)
+        try:
+            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
+        except KeyError:
+            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
+
+        # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
+        model = model.to(device)
+        if precision.startswith('amp') or precision == 'fp32':
+            model.float()
+        elif precision == 'bf16':
+            convert_weights_to_lp(model, dtype=torch.bfloat16)
+
+        return model
+
+    # patch the device names
+    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+    def patch_device(module):
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
+        if hasattr(module, "forward1"):
+            graphs.append(module.forward1.graph)
+
+        for graph in graphs:
+            for node in graph.findAllNodes("prim::Constant"):
+                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+                    node.copyAttributes(device_node)
+
+    model.apply(patch_device)
+    patch_device(model.encode_image)
+    patch_device(model.encode_text)
+
+    # patch dtype to float32 (typically for CPU)
+    if precision == 'fp32':
+        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+        float_node = float_input.node()
+
+        def patch_float(module):
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
+            if hasattr(module, "forward1"):
+                graphs.append(module.forward1.graph)
+
+            for graph in graphs:
+                for node in graph.findAllNodes("aten::to"):
+                    inputs = list(node.inputs())
+                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+                        if inputs[i].node()["value"] == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+        model.apply(patch_float)
+        patch_float(model.encode_image)
+        patch_float(model.encode_text)
+        model.float()
+
+    # ensure image_size attr available at consistent location for both jit and non-jit
+    model.visual.image_size = model.input_resolution.item()
+    return model
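+# Minimal usage sketch (illustrative; 'ViT-L-14-336' is one of the OpenAI models registered above):
+#   model = load_openai_model('ViT-L-14-336', precision='fp16', device='cuda', jit=False)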
+
+
+HF_HUB_PREFIX = 'hf-hub:'
+_MODEL_CONFIG_PATHS = [Path(__file__).parent.parent / "model_configs/"]
+_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def _rescan_model_configs():
+    global _MODEL_CONFIGS
+
+    config_ext = ('.json',)
+    config_files = []
+    for config_path in _MODEL_CONFIG_PATHS:
+        if config_path.is_file() and config_path.suffix in config_ext:
+            config_files.append(config_path)
+        elif config_path.is_dir():
+            for ext in config_ext:
+                config_files.extend(config_path.glob(f'*{ext}'))
+
+    for cf in config_files:
+        with open(cf, 'r') as f:
+            model_cfg = json.load(f)
+            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
+                _MODEL_CONFIGS[cf.stem] = model_cfg
+
+    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
+
+
+_rescan_model_configs()  # initial populate of model config registry
+
+
+def list_models():
+    """ enumerate available model architectures based on config files """
+    return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+    """ add model config path or file and update registry """
+    if not isinstance(path, Path):
+        path = Path(path)
+    _MODEL_CONFIG_PATHS.append(path)
+    _rescan_model_configs()
+
+
+def get_model_config(model_name):
+    if model_name in _MODEL_CONFIGS:
+        return deepcopy(_MODEL_CONFIGS[model_name])
+    else:
+        return None
+
+
+def get_tokenizer(model_name):
+    if model_name.startswith(HF_HUB_PREFIX):
+        tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
+    else:
+        config = get_model_config(model_name)
+        tokenizer = HFTokenizer(
+            config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
+    return tokenizer
+
+
+def load_state_dict(checkpoint_path: str, map_location='cpu'):
+    checkpoint = torch.load(checkpoint_path, map_location=map_location)
+    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+    if next(iter(state_dict.items()))[0].startswith('module'):
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+    return state_dict
+
+
+def load_checkpoint(model, checkpoint_path, strict=True):
+    state_dict = load_state_dict(checkpoint_path)
+    # detect old format and make compatible with new format
+    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
+        state_dict = convert_to_custom_text_state_dict(state_dict)
+    resize_pos_embed(state_dict, model)
+    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
+    return incompatible_keys
+
+
+def create_model(
+        model_name: str,
+        img_size: int,
+        pretrained: Optional[str] = None,
+        precision: str = 'fp32',
+        device: Union[str, torch.device] = 'cpu',
+        jit: bool = False,
+        cache_dir: Optional[str] = None,
+        output_dict: Optional[bool] = None,
+):
+    if 'ViT' not in model_name:
+        raise NotImplementedError('only ViT-based CLIP models are supported')
+
+    # also accept the old naming convention that uses '/' in ViT names (e.g. 'ViT-L/14')
+    model_name = model_name.replace('/', '-')
+    checkpoint_path = None
+    pretrained_cfg = {}
+    model_cfg = None
+
+    if isinstance(device, str):
+        device = torch.device(device)
+
+    # the default pretrained weights are borrowed from OpenAI
+    assert pretrained and pretrained.lower() == 'openai', 'only OpenAI pretrained weights are supported'
+    logging.info(f'Loading pretrained {model_name} from OpenAI.')
+    model_cfg = model_cfg or get_model_config(model_name)
+
+    model_cfg['vision_cfg']['image_size'] = img_size
+    cast_dtype = get_cast_dtype(precision)
+
+    model_pre = load_openai_model(
+        model_name,
+        precision=precision,
+        device=device,
+        jit=jit,
+        cache_dir=cache_dir,
+    )
+    state_dict = model_pre.state_dict()
+
+    # to always output dict even if it is clip
+    if output_dict and hasattr(model_pre, "output_dict"):
+        model_pre.output_dict = True
+
+    model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+
+    # mainly need to resize the position embeddings
+    resize_pos_embed(state_dict, model)
+    incompatible_keys = model.load_state_dict(state_dict, strict=True)
+    model.to(device=device)
+    if precision in ("fp16", "bf16"):
+        convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
+
+    # set image / mean metadata from pretrained_cfg if available, or use default
+    model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
+    model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
+
+    # to always output dict even if it is clip
+    if output_dict and hasattr(model, "output_dict"):
+        model.output_dict = True
+
+    if jit:
+        model = torch.jit.script(model)
+
+    return model
+
+
+def create_model_and_transforms(
+        model_name: str,
+        img_size: int,
+        pretrained: Optional[str] = None,
+        precision: str = 'fp32',
+        device: Union[str, torch.device] = 'cpu',
+        jit: bool = False,
+        image_mean: Optional[Tuple[float, ...]] = None,
+        image_std: Optional[Tuple[float, ...]] = None,
+        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+        cache_dir: Optional[str] = None,
+        output_dict: Optional[bool] = None,
+):
+    ######### create the clip model
+    model = create_model(
+        model_name,
+        img_size,
+        pretrained,
+        precision=precision,
+        device=device,
+        jit=jit,
+        cache_dir=cache_dir,
+        output_dict=output_dict,
+    )
+
+    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+    image_std = image_std or getattr(model.visual, 'image_std', None)
+    preprocess_train = image_transform(
+        model.visual.image_size,
+        is_train=True,
+        mean=image_mean,
+        std=image_std,
+        aug_cfg=aug_cfg,
+    )
+    preprocess_val = image_transform(
+        model.visual.image_size,
+        is_train=False,
+        mean=image_mean,
+        std=image_std,
+    )
+
+    return model, preprocess_train, preprocess_val
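+# Minimal usage sketch (illustrative; the image size below is an assumption, not a fixed default):
+#   model, preprocess_train, preprocess_val = create_model_and_transforms(
+#       'ViT-L-14-336', img_size=518, pretrained='openai', precision='fp16', device='cuda')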
diff --git a/method/simple_tokenizer.py b/method/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91
--- /dev/null
+++ b/method/simple_tokenizer.py
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
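
For readers new to byte-pair encoding, the sketch below (illustrative only, not part of the patch) isolates the pair-extraction step that drives the `bpe()` merge loop above; the example word is made up.

```python
# Standalone illustration of get_pairs(): a word is a tuple of symbols and the
# function returns the set of adjacent symbol pairs, i.e. the merge candidates.
def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

word = ('l', 'o', 'w', 'e', 'r</w>')
print(get_pairs(word))
# -> the adjacent pairs ('l','o'), ('o','w'), ('w','e'), ('e','r</w>') in set order;
# bpe() fuses the lowest-ranked pair, e.g. ('l','o') -> 'lo', then repeats on
# ('lo', 'w', 'e', 'r</w>') until no ranked pair remains.
```
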
diff --git a/method/tokenizer.py b/method/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282
--- /dev/null
+++ b/method/tokenizer.py
@@ -0,0 +1,214 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+# https://stackoverflow.com/q/62691279
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a mapping from utf-8 bytes to unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large number of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    That is a significant percentage of a normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
+    and we avoid mapping to the whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        if not special_tokens:
+            special_tokens = ['<start_of_text>', '<end_of_text>']
+        else:
+            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
+        vocab.extend(special_tokens)
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {t:t for t in special_tokens}
+        special = "|".join(special_tokens)
+        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+        self.vocab_size = len(self.encoder)
+        self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
+
+
+_tokenizer = SimpleTokenizer()
+
+def decode(output_ids: torch.Tensor):
+    output_ids = output_ids.cpu().numpy()
+    return _tokenizer.decode(output_ids)
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+    """
+    Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    sot_token = _tokenizer.encoder["<start_of_text>"]
+    eot_token = _tokenizer.encoder["<end_of_text>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+    for i, tokens in enumerate(all_tokens):
+        if len(tokens) > context_length:
+            tokens = tokens[:context_length]  # Truncate
+            tokens[-1] = eot_token
+        result[i, :len(tokens)] = torch.tensor(tokens)
+
+    return result
+
+
+class HFTokenizer:
+    """HuggingFace tokenizer wrapper"""
+
+    def __init__(self, tokenizer_name: str):
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+    def save_pretrained(self, dest):
+        self.tokenizer.save_pretrained(dest)
+
+    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
+        # Same cleaning as the default tokenizer, except without lowercasing;
+        # lowercasing here (for case-sensitive tokenizers) would be more robust but less sensitive to nuance.
+        if isinstance(texts, str):
+            texts = [texts]
+        texts = [whitespace_clean(basic_clean(text)) for text in texts]
+        input_ids = self.tokenizer(
+            texts,
+            return_tensors='pt',
+            max_length=context_length,
+            padding='max_length',
+            truncation=True,
+        ).input_ids
+        return input_ids
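
As a quick orientation, a minimal usage sketch of this tokenizer module follows. The prompt strings are placeholders, and the imports assume the repository root is on `PYTHONPATH` with the bundled `bpe_simple_vocab_16e6.txt.gz` sitting next to `method/tokenizer.py`, as this patch lays it out.

```python
# Illustrative usage of method/tokenizer.py (assumptions noted above).
from method.tokenizer import tokenize, _tokenizer

ids = tokenize(["a photo of a normal candle",
                "a photo of a damaged candle"])
print(ids.shape)  # torch.Size([2, 77]); rows are zero-padded after <end_of_text>

# Decoding the non-padding ids recovers the cleaned, lowercased text wrapped in the special tokens.
print(_tokenizer.decode(ids[0][ids[0] > 0].tolist()))
```
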
diff --git a/method/trainer.py b/method/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9bd540f44d46df686140aacf44eaf0913466fbe
--- /dev/null
+++ b/method/trainer.py
@@ -0,0 +1,225 @@
+import cv2
+import torchvision.transforms as transforms
+from scipy.ndimage import gaussian_filter
+
+from loss import FocalLoss, BinaryDiceLoss
+from tools import visualization, calculate_metric, calculate_average_metric
+from .adaclip import *
+from .custom_clip import create_model_and_transforms
+
+
+class AdaCLIP_Trainer(nn.Module):
+    def __init__(
+            self,
+            # clip-related
+            backbone, feat_list, input_dim, output_dim,
+
+            # learning-related
+            learning_rate, device, image_size,
+
+            # model settings
+            prompting_depth=3, prompting_length=2,
+            prompting_branch='VL', prompting_type='SD',
+            use_hsf=True, k_clusters=20,
+    ):
+
+        super(AdaCLIP_Trainer, self).__init__()
+
+        self.device = device
+        self.feat_list = feat_list
+        self.image_size = image_size
+        self.prompting_branch = prompting_branch
+        self.prompting_type = prompting_type
+
+        self.loss_focal = FocalLoss()
+        self.loss_dice = BinaryDiceLoss()
+
+        ########### different model choices
+        freeze_clip, _, self.preprocess = create_model_and_transforms(backbone, image_size,
+                                                                      pretrained='openai')
+        freeze_clip  = freeze_clip.to(device)
+        freeze_clip.eval()
+
+        self.clip_model = AdaCLIP(freeze_clip=freeze_clip,
+                                  text_channel=output_dim,
+                                  visual_channel=input_dim,
+                                  prompting_length=prompting_length,
+                                  prompting_depth=prompting_depth,
+                                  prompting_branch=prompting_branch,
+                                  prompting_type=prompting_type,
+                                  use_hsf=use_hsf,
+                                  k_clusters=k_clusters,
+                                  output_layers=feat_list,
+                                  device=device,
+                                  image_size=image_size).to(device)
+
+        self.transform = transforms.Compose([
+            transforms.Resize((image_size, image_size)),
+            transforms.CenterCrop(image_size),
+            transforms.ToTensor()
+        ])
+
+        self.preprocess.transforms[0] = transforms.Resize(size=(image_size, image_size),
+                                                          interpolation=transforms.InterpolationMode.BICUBIC,
+                                                          max_size=None)
+
+        self.preprocess.transforms[1] = transforms.CenterCrop(size=(image_size, image_size))
+
+        # update parameters
+        self.learnable_paramter_list = [
+            'text_prompter',
+            'visual_prompter',
+            'patch_token_layer',
+            'cls_token_layer',
+            'dynamic_visual_prompt_generator',
+            'dynamic_text_prompt_generator'
+        ]
+
+        self.params_to_update = []
+        for name, param in self.clip_model.named_parameters():
+            # print(name)
+            for update_name in self.learnable_paramter_list:
+                if update_name in name:
+                    # print(f'updated parameters--{name}: {update_name}')
+                    self.params_to_update.append(param)
+
+        # build the optimizer
+        self.optimizer = torch.optim.AdamW(self.params_to_update, lr=learning_rate, betas=(0.5, 0.999))
+
+    def save(self, path):
+        self.save_dict = {}
+        for param, value in self.state_dict().items():
+            for update_name in self.learnable_paramter_list:
+                if update_name in param:
+                    # print(f'{param}: {update_name}')
+                    self.save_dict[param] = value
+                    break
+
+        torch.save(self.save_dict, path)
+
+    def load(self, path):
+        self.load_state_dict(torch.load(path, map_location=self.device), strict=False)
+
+    def train_one_batch(self, items):
+        image = items['img'].to(self.device)
+        cls_name = items['cls_name']
+
+        # pixel level
+        anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=False)
+
+        if not isinstance(anomaly_map, list):
+            anomaly_map = [anomaly_map]
+
+        # losses
+        gt = items['img_mask'].to(self.device)
+        gt = gt.squeeze()
+
+        gt[gt > 0.5] = 1
+        gt[gt <= 0.5] = 0
+
+        is_anomaly = items['anomaly'].to(self.device)
+        is_anomaly[is_anomaly > 0.5] = 1
+        is_anomaly[is_anomaly <= 0.5] = 0
+        loss = 0
+
+        # classification loss
+        classification_loss = self.loss_focal(anomaly_score, is_anomaly.unsqueeze(1))
+        loss += classification_loss
+
+        # seg loss
+        seg_loss = 0
+        for am in anomaly_map:
+            seg_loss += (self.loss_focal(am, gt) + self.loss_dice(am[:, 1, :, :], gt) +
+                         self.loss_dice(am[:, 0, :, :], 1-gt))
+
+        loss += seg_loss
+
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        return loss
+
+    def train_epoch(self, loader):
+        self.clip_model.train()
+        loss_list = []
+        for items in loader:
+            loss = self.train_one_batch(items)
+            loss_list.append(loss.item())
+
+        return np.mean(loss_list)
+
+    @torch.no_grad()
+    def evaluation(self, dataloader, obj_list, save_fig, save_fig_dir=None):
+        self.clip_model.eval()
+
+        results = {}
+        results['cls_names'] = []
+        results['imgs_gts'] = []
+        results['anomaly_scores'] = []
+        results['imgs_masks'] = []
+        results['anomaly_maps'] = []
+        results['imgs'] = []
+        results['names'] = []
+
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            image_indx = 0
+            for indx, items in enumerate(dataloader):
+                if save_fig:
+                    path = items['img_path']
+                    for _path in path:
+                        vis_image = cv2.resize(cv2.imread(_path), (self.image_size, self.image_size))
+                        results['imgs'].append(vis_image)
+                    cls_name = items['cls_name']
+                    for _cls_name in cls_name:
+                        image_indx += 1
+                        results['names'].append('{:}-{:03d}'.format(_cls_name, image_indx))
+
+                image = items['img'].to(self.device)
+                cls_name = items['cls_name']
+                results['cls_names'].extend(cls_name)
+                gt_mask = items['img_mask']
+                gt_mask[gt_mask > 0.5] = 1
+                gt_mask[gt_mask <= 0.5] = 0
+
+                for _gt_mask in gt_mask:
+                    results['imgs_masks'].append(_gt_mask.squeeze(0).numpy())  # px
+
+                # pixel level
+                anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=True)
+
+                anomaly_map = anomaly_map.cpu().numpy()
+                anomaly_score = anomaly_score.cpu().numpy()
+
+                for _anomaly_map, _anomaly_score in zip(anomaly_map, anomaly_score):
+                    _anomaly_map = gaussian_filter(_anomaly_map, sigma=4)
+                    results['anomaly_maps'].append(_anomaly_map)
+                    results['anomaly_scores'].append(_anomaly_score)
+
+                is_anomaly = np.array(items['anomaly'])
+                for _is_anomaly in is_anomaly:
+                    results['imgs_gts'].append(_is_anomaly)
+
+        # visualization
+        if save_fig:
+            print('saving fig.....')
+            visualization.plot_sample_cv2(
+                results['names'],
+                results['imgs'],
+                {'AdaCLIP': results['anomaly_maps']},
+                results['imgs_masks'],
+                save_fig_dir
+            )
+
+        metric_dict = dict()
+        for obj in obj_list:
+            metric_dict[obj] = dict()
+
+        for obj in obj_list:
+            metric = calculate_metric(results, obj)
+            obj_full_name = f'{obj}'
+            metric_dict[obj_full_name] = metric
+
+        metric_dict['Average'] = calculate_average_metric(metric_dict)
+
+        return metric_dict
+
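
To show how this trainer is meant to be driven end to end, here is a minimal sketch. The hyper-parameters mirror the `ViT-L-14-336` config and the defaults in `test.py`; the dataloaders, class list, and checkpoint name are placeholders rather than part of the patch.

```python
# Minimal driving sketch for AdaCLIP_Trainer (placeholders noted above).
import torch
from method import AdaCLIP_Trainer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = AdaCLIP_Trainer(
    backbone='ViT-L-14-336',
    feat_list=[6, 12, 18, 24],        # quarter-depth feature hierarchy for a 24-layer ViT
    input_dim=1024, output_dim=768,   # vision width / embed_dim from model_configs/ViT-L-14-336.json
    learning_rate=1e-3, device=device, image_size=518,
).to(device)

# train_loader, test_loader and class_names would come from dataset.get_data (not shown):
# for epoch in range(num_epochs):
#     mean_loss = trainer.train_epoch(train_loader)
#     metrics = trainer.evaluation(test_loader, class_names, save_fig=False)
# trainer.save('weights/my_checkpoint.pth')
```
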
diff --git a/method/transformer.py b/method/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f66abbe368f440c6c0db5061c89b3497035b8c
--- /dev/null
+++ b/method/transformer.py
@@ -0,0 +1,615 @@
+from collections import OrderedDict
+import math
+from typing import Callable, Optional, Sequence, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+from .utils import to_2tuple
+import numpy as np
+
+
+class LayerNormFp32(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
+        return x.to(orig_type)
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchDropout(nn.Module):
+    """
+    https://arxiv.org/abs/2212.00794
+    """
+
+    def __init__(self, prob, exclude_first_token=True):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.exclude_first_token = exclude_first_token  # exclude CLS token
+
+    def forward(self, x):
+        if not self.training or self.prob == 0.:
+            return x
+
+        if self.exclude_first_token:
+            cls_tokens, x = x[:, :1], x[:, 1:]
+        else:
+            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+        batch = x.size()[0]
+        num_tokens = x.size()[1]
+
+        batch_indices = torch.arange(batch)
+        batch_indices = batch_indices[..., None]
+
+        keep_prob = 1 - self.prob
+        num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+        rand = torch.randn(batch, num_tokens)
+        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+        x = x[batch_indices, patch_indices_keep]
+
+        if self.exclude_first_token:
+            x = torch.cat((cls_tokens, x), dim=1)
+
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            qkv_bias=True,
+            scaled_cosine=False,
+            scale_heads=False,
+            logit_scale_max=math.log(1. / 0.01),
+            attn_drop=0.,
+            proj_drop=0.
+    ):
+        super().__init__()
+        self.scaled_cosine = scaled_cosine
+        self.scale_heads = scale_heads
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.logit_scale_max = logit_scale_max
+
+        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
+        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
+        if qkv_bias:
+            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
+        else:
+            self.in_proj_bias = None
+
+        if self.scaled_cosine:
+            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
+        else:
+            self.logit_scale = None
+        self.attn_drop = nn.Dropout(attn_drop)
+        if self.scale_heads:
+            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
+        else:
+            self.head_scale = None
+        self.out_proj = nn.Linear(dim, dim)
+        self.out_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
+        L, N, C = x.shape
+        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+        q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+        k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+        v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+
+        if self.logit_scale is not None:
+            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
+            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
+            attn = attn.view(N, self.num_heads, L, L) * logit_scale
+            attn = attn.view(-1, L, L)
+        else:
+            q = q * self.scale
+            attn = torch.bmm(q, k.transpose(-1, -2))
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+                attn_mask = new_attn_mask
+            attn += attn_mask
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = torch.bmm(attn, v)
+        if self.head_scale is not None:
+            x = x.view(N, self.num_heads, L, C) * self.head_scale
+            x = x.view(-1, L, C)
+        x = x.transpose(0, 1).reshape(L, N, C)
+        x = self.out_proj(x)
+        x = self.out_drop(x)
+        return x
+
+
+class AttentionalPooler(nn.Module):
+    def __init__(
+            self,
+            d_model: int,
+            context_dim: int,
+            n_head: int = 8,
+            n_queries: int = 256,
+            norm_layer: Callable = LayerNorm
+    ):
+        super().__init__()
+        self.query = nn.Parameter(torch.randn(n_queries, d_model))
+        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
+        self.ln_q = norm_layer(d_model)
+        self.ln_k = norm_layer(context_dim)
+
+    def forward(self, x: torch.Tensor):
+        x = self.ln_k(x).permute(1, 0, 2)  # NLD -> LND
+        N = x.shape[1]
+        q = self.ln_q(self.query)
+        out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
+        return out.permute(1, 0, 2)  # LND -> NLD
+
+    def _repeat(self, query, N: int):
+        return query.unsqueeze(1).repeat(1, N, 1)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(
+            self,
+            d_model: int,
+            n_head: int,
+            mlp_ratio: float = 4.0,
+            ls_init_value: float = None,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = LayerNorm,
+            is_cross_attention: bool = False,
+            idx: int = 12,
+    ):
+        super().__init__()
+
+        self.idx = idx
+
+        self.ln_1 = norm_layer(d_model)
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+        if is_cross_attention:
+            self.ln_1_kv = norm_layer(d_model)
+
+        self.ln_2 = norm_layer(d_model)
+        mlp_width = int(d_model * mlp_ratio)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, mlp_width)),
+            ("gelu", act_layer()),
+            ("c_proj", nn.Linear(mlp_width, d_model))
+        ]))
+        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+    def attention(
+            self,
+            q_x: torch.Tensor,
+            k_x: Optional[torch.Tensor] = None,
+            v_x: Optional[torch.Tensor] = None,
+            attn_mask: Optional[torch.Tensor] = None,
+    ):
+        k_x = k_x if k_x is not None else q_x
+        v_x = v_x if v_x is not None else q_x
+
+        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
+        return self.attn(
+            q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask
+        )
+
+    def forward(
+            self,
+            q_x: torch.Tensor,
+            k_x: Optional[torch.Tensor] = None,
+            v_x: Optional[torch.Tensor] = None,
+            attn_mask: Optional[torch.Tensor] = None,
+    ):
+        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
+        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
+
+        tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
+        x = q_x + self.ls_1(tmp)
+        x = x + self.ls_2(self.mlp(self.ln_2(x)))
+        return x, attn
+
+
+
+class Transformer(nn.Module):
+    def __init__(
+            self,
+            width: int,
+            layers: int,
+            heads: int,
+            mlp_ratio: float = 4.0,
+            ls_init_value: float = None,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = LayerNorm,
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.grad_checkpointing = False
+
+        self.resblocks = nn.ModuleList([
+            ResidualAttentionBlock(
+                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer,
+                idx=idx)
+            for idx in range(layers)
+        ])
+
+    def get_cast_dtype(self) -> torch.dtype:
+        return self.resblocks[0].mlp.c_fc.weight.dtype
+
+    def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9],
+                attn_mask: Optional[torch.Tensor] = None):
+        idx = 0
+        out_tokens = []
+        for r in self.resblocks:
+            idx += 1
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+                # the block returns (x, attn), so unpack here as well
+                x, attn_tmp = checkpoint(r, x, None, None, attn_mask)
+            else:
+                x, attn_tmp = r(x, attn_mask=attn_mask)
+            if idx in out_layers:
+                out_tokens.append(x)
+        return x, out_tokens
+
+
+
+class VisionTransformer(nn.Module):
+    output_tokens: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            image_size: int,
+            patch_size: int,
+            width: int,
+            layers: int,
+            heads: int,
+            mlp_ratio: float,
+            ls_init_value: float = None,
+            global_average_pool: bool = False,
+            attentional_pool: bool = False,
+            n_queries: int = 256,
+            attn_pooler_heads: int = 8,
+            output_dim: int = 512,
+            patch_dropout: float = 0.,
+            input_patchnorm: bool = False,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = LayerNorm,
+            output_tokens: bool = False,
+    ):
+        super().__init__()
+        self.output_tokens = output_tokens
+        image_height, image_width = self.image_size = to_2tuple(image_size)
+        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
+        self.grid_size = (image_height // patch_height, image_width // patch_width)
+        self.output_dim = output_dim
+
+        # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
+        self.input_patchnorm = input_patchnorm
+
+        if input_patchnorm:
+            patch_input_dim = patch_height * patch_width * 3
+            self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
+            self.conv1 = nn.Linear(patch_input_dim, width)
+        else:
+            self.patchnorm_pre_ln = nn.Identity()
+            self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size,
+                                   bias=False)
+
+        # class embeddings and positional embeddings
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
+
+        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
+
+        self.ln_pre = norm_layer(width)
+
+        self.transformer = Transformer(
+            width,
+            layers,
+            heads,
+            mlp_ratio,
+            ls_init_value=ls_init_value,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+        )
+
+        self.global_average_pool = global_average_pool
+        if attentional_pool:
+            self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
+            self.ln_post = norm_layer(output_dim)
+            self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
+        else:
+            self.attn_pool = None
+            self.ln_post = norm_layer(width)
+            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+        self.init_parameters()
+
+    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+        for param in self.parameters():
+            param.requires_grad = False
+
+        if unlocked_groups != 0:
+            groups = [
+                [
+                    self.conv1,
+                    self.class_embedding,
+                    self.positional_embedding,
+                    self.ln_pre,
+                ],
+                *self.transformer.resblocks[:-1],
+                [
+                    self.transformer.resblocks[-1],
+                    self.ln_post,
+                ],
+                self.proj,
+            ]
+
+            def _unlock(x):
+                if isinstance(x, Sequence):
+                    for g in x:
+                        _unlock(g)
+                else:
+                    if isinstance(x, torch.nn.Parameter):
+                        x.requires_grad = True
+                    else:
+                        for p in x.parameters():
+                            p.requires_grad = True
+
+            _unlock(groups[-unlocked_groups:])
+
+    def init_parameters(self):
+        # FIXME OpenAI CLIP did not define an init for the VisualTransformer
+        # TODO experiment if default PyTorch init, below, or alternate init is best.
+
+        # nn.init.normal_(self.class_embedding, std=self.scale)
+        # nn.init.normal_(self.positional_embedding, std=self.scale)
+        #
+        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+        # attn_std = self.transformer.width ** -0.5
+        # fc_std = (2 * self.transformer.width) ** -0.5
+        # for block in self.transformer.resblocks:
+        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+        #
+        # if self.text_projection is not None:
+        #     nn.init.normal_(self.text_projection, std=self.scale)
+        pass
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.transformer.grad_checkpointing = enable
+
+    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.global_average_pool:
+            return x.mean(dim=1), x
+        else:
+            return x[:, 0], x[:, 1:]
+
+    def forward(self, x: torch.Tensor, out_layers: list):
+
+        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+        if self.input_patchnorm:
+            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+            x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1],
+                          self.patch_size[1])
+            x = x.permute(0, 2, 4, 1, 3, 5)
+            x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
+            x = self.patchnorm_pre_ln(x)
+            x = self.conv1(x)
+        else:
+            x = self.conv1(x)  # shape = [*, width, grid, grid]
+            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+        # class embeddings and positional embeddings
+        x = torch.cat(
+            [self.class_embedding.to(x.dtype) +
+             torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+
+        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+        x = self.patch_dropout(x)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x, patch_tokens = self.transformer(x, out_layers)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))]  # LND -> NLD
+        # patch_tokens = patch_tokens.permute(1, 0, 2)  # LND -> NLD
+
+        if self.attn_pool is not None:
+            x = self.attn_pool(x)
+            x = self.ln_post(x)
+            pooled, tokens = self._global_pool(x)
+        else:
+            pooled, tokens = self._global_pool(x)
+            pooled = self.ln_post(pooled)
+            # patch_pooled, patch_tokens = self._global_pool(patch_tokens)
+            # tokens = self.ln_post(tokens)
+
+        if self.proj is not None:
+            pooled = pooled @ self.proj
+            # patch_tokens = patch_tokens @ self.proj  # not sure whether this would work
+            # tokens = tokens @ self.proj
+
+        # both branches returned the same pair, so the output_tokens flag has no effect here
+        return pooled, patch_tokens
+
+
+class TextTransformer(nn.Module):
+    output_tokens: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            context_length: int = 77,
+            vocab_size: int = 49408,
+            width: int = 512,
+            heads: int = 8,
+            layers: int = 12,
+            ls_init_value: float = None,
+            output_dim: int = 512,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = LayerNorm,
+            embed_cls: bool = False,
+            pad_id: int = 0,
+            output_tokens: bool = False,
+    ):
+        super().__init__()
+        self.output_tokens = output_tokens
+        self.num_pos = self.context_length = context_length
+        self.vocab_size = vocab_size
+        self.width = width
+        self.output_dim = output_dim
+        self.heads = heads
+        self.pad_id = pad_id
+
+        self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+        if embed_cls:
+            self.cls_emb = nn.Parameter(torch.empty(width))
+            self.num_pos += 1
+        else:
+            self.cls_emb = None
+
+        self.token_embedding = nn.Embedding(vocab_size, width)
+        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
+
+        self.transformer = Transformer(
+            width=width,
+            layers=layers,
+            heads=heads,
+            ls_init_value=ls_init_value,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+        )
+
+        self.ln_final = norm_layer(width)
+
+        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
+
+        self.init_parameters()
+
+    def init_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+        if self.cls_emb is not None:
+            nn.init.normal_(self.cls_emb, std=0.01)
+
+        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.transformer.grad_checkpointing = enable
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.num_pos, self.num_pos)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def build_cls_mask(self, text, cast_dtype: torch.dtype):
+        cls_mask = (text != self.pad_id).unsqueeze(1)
+        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
+        additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
+        additive_mask.fill_(0)
+        additive_mask.masked_fill_(~cls_mask, float("-inf"))
+        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+        return additive_mask
+
+    def _repeat(self, t, N: int):
+        return t.reshape(1, 1, -1).repeat(N, 1, 1)
+
+    def forward(self, text):
+        cast_dtype = self.transformer.get_cast_dtype()
+        seq_len = text.shape[1]
+
+        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
+        attn_mask = self.attn_mask
+        if self.cls_emb is not None:
+            seq_len += 1
+            x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
+            cls_mask = self.build_cls_mask(text, cast_dtype)
+            attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
+
+        x = x + self.positional_embedding[:seq_len].to(cast_dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x, patch_tokens = self.transformer(x, attn_mask=attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        if self.cls_emb is not None:
+            pooled, tokens = x[:, -1], x[:, :-1]
+            pooled = self.ln_final(pooled)
+        else:
+            x = self.ln_final(x)
+            pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
+
+        if self.text_projection is not None:
+            pooled = pooled @ self.text_projection
+
+        if self.output_tokens:
+            return pooled, tokens
+
+        return pooled
+
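
As a sanity check on the causal masking in `TextTransformer.build_attention_mask`, the standalone snippet below reproduces the mask for a toy sequence length; it is for illustration only.

```python
# The additive causal mask: zeros on and below the diagonal, -inf strictly above,
# so position i can only attend to positions <= i.
import torch

num_pos = 4
mask = torch.empty(num_pos, num_pos)
mask.fill_(float("-inf"))
mask.triu_(1)  # keep -inf only strictly above the diagonal
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```
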
diff --git a/method/utils.py b/method/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673
--- /dev/null
+++ b/method/utils.py
@@ -0,0 +1,60 @@
+from itertools import repeat
+import collections.abc
+
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=''):
+    """
+    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+    returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+    Args:
+        module (torch.nn.Module): Any PyTorch module.
+        module_match (dict): Dictionary of full module names to freeze (all if empty)
+        name (str): Full module name (prefix)
+
+    Returns:
+        torch.nn.Module: Resulting module
+
+    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+    """
+    res = module
+    is_match = True
+    if module_match:
+        is_match = name in module_match
+    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
+        res = FrozenBatchNorm2d(module.num_features)
+        res.num_features = module.num_features
+        res.affine = module.affine
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+    else:
+        for child_name, child in module.named_children():
+            full_child_name = '.'.join([name, child_name]) if name else child_name
+            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+            if new_child is not child:
+                res.add_module(child_name, new_child)
+    return res
+
+
+# From PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = lambda n, x: _ntuple(n)(x)
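
A quick illustration of the `_ntuple` helpers (assuming the repository root is on `PYTHONPATH`): scalars are broadcast into an n-tuple, while iterables pass through unchanged.

```python
from method.utils import to_2tuple

print(to_2tuple(224))         # (224, 224) - scalar broadcast to both spatial dims
print(to_2tuple((336, 448)))  # (336, 448) - iterables are returned as-is
```
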
diff --git a/model_configs/ViT-B-16.json b/model_configs/ViT-B-16.json
new file mode 100644
index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c
--- /dev/null
+++ b/model_configs/ViT-B-16.json
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 16
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/model_configs/ViT-B-32.json b/model_configs/ViT-B-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f
--- /dev/null
+++ b/model_configs/ViT-B-32.json
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/model_configs/ViT-H-14.json b/model_configs/ViT-H-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74
--- /dev/null
+++ b/model_configs/ViT-H-14.json
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 32,
+        "width": 1280,
+        "head_width": 80,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 24
+    }
+}
\ No newline at end of file
diff --git a/model_configs/ViT-L-14-336.json b/model_configs/ViT-L-14-336.json
new file mode 100644
index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97
--- /dev/null
+++ b/model_configs/ViT-L-14-336.json
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 336,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/model_configs/ViT-L-14.json b/model_configs/ViT-L-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241
--- /dev/null
+++ b/model_configs/ViT-L-14.json
@@ -0,0 +1,16 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/model_configs/ViT-bigG-14.json b/model_configs/ViT-bigG-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7
--- /dev/null
+++ b/model_configs/ViT-bigG-14.json
@@ -0,0 +1,18 @@
+{
+    "embed_dim": 1280,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 48,
+        "width": 1664,
+        "head_width": 104,
+        "mlp_ratio": 4.9231,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1280,
+        "heads": 20,
+        "layers": 32
+    }
+}
\ No newline at end of file
diff --git a/model_configs/ViT-g-14.json b/model_configs/ViT-g-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091
--- /dev/null
+++ b/model_configs/ViT-g-14.json
@@ -0,0 +1,18 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 40,
+        "width": 1408,
+        "head_width": 88,
+        "mlp_ratio": 4.3637,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 24
+    }
+}
\ No newline at end of file
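
These JSON files are consumed as plain dictionaries; the sketch below mirrors the logic in `test.py` further down, where the vision depth is split into quarters to pick the feature-extraction layers and `width`/`embed_dim` become the trainer's `input_dim`/`output_dim`.

```python
import json

with open('model_configs/ViT-L-14-336.json') as f:
    cfg = json.load(f)

n_layers = cfg['vision_cfg']['layers']               # 24
substage = n_layers // 4
features_list = [substage * i for i in range(1, 5)]  # [6, 12, 18, 24]
print(features_list)
print(cfg['vision_cfg']['width'], cfg['embed_dim'])  # 1024 768
```
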
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4986fbe9de5b6fc9af3b78a79b49bd3d49d04cd
--- /dev/null
+++ b/test.py
@@ -0,0 +1,199 @@
+import warnings
+warnings.filterwarnings("ignore", category=RuntimeWarning)
+import os
+os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import argparse
+import json
+import torch
+from scipy.ndimage import gaussian_filter
+import cv2
+
+# Importing from local modules
+from tools import write2csv, setup_seed, Logger
+from dataset import get_data, dataset_dict
+from method import AdaCLIP_Trainer
+from PIL import Image
+import numpy as np
+
+setup_seed(111)
+
+def test(args):
+    assert os.path.isfile(args.ckt_path), f"Please check the path of the pre-trained model: {args.ckt_path} is not a valid file."
+
+    # Configurations
+    batch_size = args.batch_size
+    image_size = args.image_size
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    save_fig = args.save_fig
+
+    # Logger
+    logger = Logger('log.txt')
+
+    # Print basic information
+    for key, value in sorted(vars(args).items()):
+        logger.info(f'{key} = {value}')
+
+
+    config_path = os.path.join('./model_configs', f'{args.model}.json')
+
+    # Prepare model
+    with open(config_path, 'r') as f:
+        model_configs = json.load(f)
+
+    # Set up the feature hierarchy
+    n_layers = model_configs['vision_cfg']['layers']
+    substage = n_layers // 4
+    features_list = [substage, substage * 2, substage * 3, substage * 4]
+
+    model = AdaCLIP_Trainer(
+        backbone=args.model,
+        feat_list=features_list,
+        input_dim=model_configs['vision_cfg']['width'],
+        output_dim=model_configs['embed_dim'],
+        learning_rate=0.,
+        device=device,
+        image_size=image_size,
+        prompting_depth=args.prompting_depth,
+        prompting_length=args.prompting_length,
+        prompting_branch=args.prompting_branch,
+        prompting_type=args.prompting_type,
+        use_hsf=args.use_hsf,
+        k_clusters=args.k_clusters
+    ).to(device)
+
+    model.load(args.ckt_path)
+
+    if args.testing_model == 'dataset':
+        assert args.testing_data in dataset_dict.keys(), f"You entered {args.testing_data}, but we only support " \
+                                                         f"{dataset_dict.keys()}"
+
+        save_root = args.save_path
+        csv_root = os.path.join(save_root, 'csvs')
+        image_root = os.path.join(save_root, 'images')
+        csv_path = os.path.join(csv_root, f'{args.testing_data}.csv')
+        image_dir = os.path.join(image_root, f'{args.testing_data}')
+        os.makedirs(image_dir, exist_ok=True)
+
+        test_data_cls_names, test_data, test_data_root = get_data(
+            dataset_type_list=args.testing_data,
+            transform=model.preprocess,
+            target_transform=model.transform,
+            training=False)
+
+        test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
+        save_fig_flag = save_fig
+
+        metric_dict = model.evaluation(
+            test_dataloader,
+            test_data_cls_names,
+            save_fig_flag,
+            image_dir,
+        )
+
+        for tag, data in metric_dict.items():
+            logger.info(
+                '{:>15} \t\tI-Auroc:{:.2f} \tI-F1:{:.2f} \tI-AP:{:.2f} \tP-Auroc:{:.2f} \tP-F1:{:.2f} \tP-AP:{:.2f}'.
+                    format(tag,
+                           data['auroc_im'],
+                           data['f1_im'],
+                           data['ap_im'],
+                           data['auroc_px'],
+                           data['f1_px'],
+                           data['ap_px'])
+            )
+
+
+        for k in metric_dict.keys():
+            write2csv(metric_dict[k], test_data_cls_names, k, csv_path)
+
+    elif args.testing_model == 'image':
+        assert os.path.isfile(args.image_path), f"Please verify the input image path: {args.image_path}"
+        ori_image = cv2.resize(cv2.imread(args.image_path), (args.image_size, args.image_size))
+        pil_img = Image.open(args.image_path).convert('RGB')
+
+        img_input = model.preprocess(pil_img).unsqueeze(0)
+        img_input = img_input.to(model.device)
+
+        with torch.no_grad():
+            anomaly_map, anomaly_score = model.clip_model(img_input, [args.class_name], aggregation=True)
+
+        anomaly_map = anomaly_map[0, :, :]
+        anomaly_score = anomaly_score[0]
+        anomaly_map = anomaly_map.cpu().numpy()
+        anomaly_score = anomaly_score.cpu().numpy()
+
+        anomaly_map = gaussian_filter(anomaly_map, sigma=4)
+        anomaly_map = anomaly_map * 255
+        anomaly_map = anomaly_map.astype(np.uint8)
+
+        heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
+        vis_map = cv2.addWeighted(heat_map, 0.5, ori_image, 0.5, 0)
+
+        vis_map = cv2.hconcat([ori_image, vis_map])
+        save_path = os.path.join(args.save_path, args.save_name)
+        print(f"Anomaly detection results are saved in {save_path}, with an anomaly of {anomaly_score:.3f} ")
+        cv2.imwrite(save_path, vis_map)
+
+def str2bool(v):
+    return v.lower() in ("yes", "true", "t", "1")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("AdaCLIP", add_help=True)
+
+    # Paths and configurations
+    parser.add_argument("--ckt_path", type=str, default='weights/pretrained_mvtec_colondb.pth',
+                        help="Path to the pre-trained model (default: weights/pretrained_mvtec_colondb.pth)")
+
+    parser.add_argument("--testing_model", type=str, default="dataset", choices=["dataset", "image"],
+                        help="Model for testing (default: 'dataset')")
+
+    # for the dataset model
+    parser.add_argument("--testing_data", type=str, default="visa", help="Dataset for testing (default: 'visa')")
+
+    # for the image model
+    parser.add_argument("--image_path", type=str, default="asset/img.png",
+                        help="Model for testing (default: 'asset/img.png')")
+    parser.add_argument("--class_name", type=str, default="candle",
+                        help="The class name of the testing image (default: 'candle')")
+    parser.add_argument("--save_name", type=str, default="test.png",
+                        help="Model for testing (default: 'dataset')")
+
+
+    parser.add_argument("--save_path", type=str, default='./workspaces',
+                        help="Directory to save results (default: './workspaces')")
+
+    parser.add_argument("--model", type=str, default="ViT-L-14-336",
+                        choices=["ViT-B-16", "ViT-B-32", "ViT-L-14", "ViT-L-14-336"],
+                        help="The CLIP model to be used (default: 'ViT-L-14-336')")
+
+    parser.add_argument("--save_fig", type=str2bool, default=False,
+                        help="Save figures for visualizations (default: False)")
+
+    # Hyper-parameters
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size (default: 1)")
+    parser.add_argument("--image_size", type=int, default=518, help="Size of the input images (default: 518)")
+
+    # Prompting parameters
+    parser.add_argument("--prompting_depth", type=int, default=4, help="Depth of prompting (default: 4)")
+    parser.add_argument("--prompting_length", type=int, default=5, help="Length of prompting (default: 5)")
+    parser.add_argument("--prompting_type", type=str, default='SD', choices=['', 'S', 'D', 'SD'],
+                        help="Type of prompting. 'S' for Static, 'D' for Dynamic, 'SD' for both (default: 'SD')")
+    parser.add_argument("--prompting_branch", type=str, default='VL', choices=['', 'V', 'L', 'VL'],
+                        help="Branch of prompting. 'V' for Visual, 'L' for Language, 'VL' for both (default: 'VL')")
+
+    parser.add_argument("--use_hsf", type=str2bool, default=True,
+                        help="Use HSF for aggregation. If False, original class embedding is used (default: True)")
+    parser.add_argument("--k_clusters", type=int, default=20, help="Number of clusters (default: 20)")
+
+    args = parser.parse_args()
+
+    if args.batch_size != 1:
+        raise NotImplementedError(
+            "Currently, only batch size of 1 is supported due to unresolved bugs. Please set --batch_size to 1.")
+
+    test(args)
+
diff --git a/test.sh b/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0a6999781cb42d7f87cd6b2fd193702c21e7e0bd
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,26 @@
+# pre-trained from MVTec and ColonDB
+ckt_path="weights/pretrained_mvtec_colondb.pth"
+gpu_id=0
+
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data br35h
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data brain_mri
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data btad
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data clinicdb
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data dagm
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data dtd
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data headct
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data isic
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data mpdd
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data sdd
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data tn3k
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data visa
+
+# pre-trained on VisA and ClinicDB
+ckt_path="weights/pretrained_visa_clinicdb.pth"
+gpu_id=0
+
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data colondb
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data mvtec
+
+
+
diff --git a/test_single_image.sh b/test_single_image.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7723a325fa47b202319c1693c31101b26e2fe974
--- /dev/null
+++ b/test_single_image.sh
@@ -0,0 +1,6 @@
+ckt_path="weights/pretrained_all.pth"
+gpu_id=0
+
+# Demo: zero-shot anomaly detection on a single image
+CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model image --ckt_path $ckt_path --save_fig True \
+ --image_path asset/img.png --class_name candle --save_name test.png
\ No newline at end of file
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6da160913889c6e8cec2654ee437a284aa7bb401
--- /dev/null
+++ b/tools/__init__.py
@@ -0,0 +1,5 @@
+from .csv_tools import write2csv
+from .logger import Logger, log_metrics
+from .metrics import calculate_metric, calculate_average_metric
+from .training_tools import setup_seed, setup_paths
+from .visualization import plot_sample_cv2
\ No newline at end of file
diff --git a/tools/__pycache__/__init__.cpython-39.pyc b/tools/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32b77dd7d0ede42e6f20793700a8cfca79bb557e
Binary files /dev/null and b/tools/__pycache__/__init__.cpython-39.pyc differ
diff --git a/tools/__pycache__/csv_tools.cpython-39.pyc b/tools/__pycache__/csv_tools.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..36a6988ccaf8b3f53ad62d83c448cae289891e82
Binary files /dev/null and b/tools/__pycache__/csv_tools.cpython-39.pyc differ
diff --git a/tools/__pycache__/logger.cpython-39.pyc b/tools/__pycache__/logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a5f5bd9303e1bf7d4af5d63bf2139cbc8fbadc4
Binary files /dev/null and b/tools/__pycache__/logger.cpython-39.pyc differ
diff --git a/tools/__pycache__/metrics.cpython-39.pyc b/tools/__pycache__/metrics.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ab33773e9fde9f213a6b61df1ded84998e78826
Binary files /dev/null and b/tools/__pycache__/metrics.cpython-39.pyc differ
diff --git a/tools/__pycache__/training_tools.cpython-39.pyc b/tools/__pycache__/training_tools.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a124f410758eafb807725c0c30a3a50587bf1d6c
Binary files /dev/null and b/tools/__pycache__/training_tools.cpython-39.pyc differ
diff --git a/tools/__pycache__/visualization.cpython-39.pyc b/tools/__pycache__/visualization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b2db2482f48ba34dc681980c9e5194ef795da91a
Binary files /dev/null and b/tools/__pycache__/visualization.cpython-39.pyc differ
diff --git a/tools/csv_tools.py b/tools/csv_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..515084889fc1462f4b444707a42188a3c2e79fac
--- /dev/null
+++ b/tools/csv_tools.py
@@ -0,0 +1,28 @@
+import pandas as pd
+import os
+
+def write2csv(results: dict, total_classes, cur_class, csv_path):
+    keys = list(results.keys())
+
+    if not os.path.exists(csv_path):
+        df_all = None
+        for class_name in total_classes:
+            r = dict()
+            for k in keys:
+                r[k] = 0.00
+            df_temp = pd.DataFrame(r, index=[f'{class_name}'])
+
+            if df_all is None:
+                df_all = df_temp
+            else:
+                df_all = pd.concat([df_all, df_temp], axis=0)
+
+        df_all.to_csv(csv_path, header=True, float_format='%.2f')
+
+    df = pd.read_csv(csv_path, index_col=0)
+
+    for k in keys:
+        df.loc[f'{cur_class}', k] = results[k]
+
+    df.to_csv(csv_path, header=True, float_format='%.2f')
+
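For orientation, a minimal sketch of how write2csv is presumably called from the evaluation loop; the class names, metric keys, and CSV path below are made up for illustration:

# Hypothetical usage of tools.csv_tools.write2csv (names and values are illustrative).
from tools import write2csv

total_classes = ['bottle', 'cable']
results = {'auroc_px': 92.31, 'f1_px': 55.10, 'ap_px': 48.77}

# The first call creates one zero-initialised row per class, then fills in the 'bottle' row.
write2csv(results, total_classes, 'bottle', 'demo_metrics.csv')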
diff --git a/tools/logger.py b/tools/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..abfbb4424d07fdf6f8e80b13cfc3a15cfb060209
--- /dev/null
+++ b/tools/logger.py
@@ -0,0 +1,74 @@
+import logging
+
+class Logger(object):
+    def __init__(self, txt_path):
+        root_logger = logging.getLogger()
+        for handler in root_logger.handlers[:]:
+            root_logger.removeHandler(handler)
+        root_logger.setLevel(logging.WARNING)
+        self.txt_path = txt_path
+        self.logger = logging.getLogger('train')
+        self.formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
+        self.logger.setLevel(logging.INFO)
+
+    def __console(self, level, message):
+        root_logger = logging.getLogger()
+        for handler in root_logger.handlers[:]:
+            root_logger.removeHandler(handler)
+
+        file_handler = logging.FileHandler(self.txt_path, mode='a')
+        console_handler = logging.StreamHandler()
+
+        file_handler.setFormatter(self.formatter)
+        console_handler.setFormatter(self.formatter)
+
+        self.logger.addHandler(file_handler)
+        self.logger.addHandler(console_handler)
+
+        if level == 'info':
+            self.logger.info(message)
+        elif level == 'debug':
+            self.logger.debug(message)
+        elif level == 'warning':
+            self.logger.warning(message)
+        elif level == 'error':
+            self.logger.error(message)
+
+        self.logger.removeHandler(file_handler)
+        self.logger.removeHandler(console_handler)
+
+        file_handler.close()
+
+    def debug(self, message):
+        self.__console('debug', message)
+
+    def info(self, message):
+        self.__console('info', message)
+
+    def warning(self, message):
+        self.__console('warning', message)
+
+    def error(self, message):
+        self.__console('error', message)
+
+def log_metrics(metrics, logger, tensorboard_logger, epoch):
+    def log_single_class(data, tag):
+        logger.info(
+            '{:>15} \t\tI-Auroc:{:.2f} \tI-F1:{:.2f} \tI-AP:{:.2f} \tP-Auroc:{:.2f} \tP-F1:{:.2f} \tP-AP:{:.2f}'.
+            format(tag,
+                   data['auroc_im'],
+                   data['f1_im'],
+                   data['ap_im'],
+                   data['auroc_px'],
+                   data['f1_px'],
+                   data['ap_px'])
+        )
+        # Adding scalar metrics to TensorBoard
+        for metric_name in ['auroc_im', 'f1_im', 'ap_im', 'auroc_px', 'f1_px', 'ap_px']:
+            tensorboard_logger.add_scalar(f'{tag}-{metric_name}', data[metric_name], epoch)
+
+    for tag, data in metrics.items():
+        log_single_class(data, tag)
+
+
+
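A small sketch of how Logger and log_metrics fit together; the metric values and paths are invented for illustration:

# Hypothetical usage of tools.logger (values and paths are illustrative).
from torch.utils.tensorboard import SummaryWriter
from tools import Logger, log_metrics

logger = Logger('demo_log.txt')          # appends to the text file and echoes to the console
logger.info('starting evaluation')

tb_logger = SummaryWriter(log_dir='demo_tb')
metrics = {'candle': {'auroc_im': 95.1, 'f1_im': 90.2, 'ap_im': 97.3,
                      'auroc_px': 88.4, 'f1_px': 41.5, 'ap_px': 35.6}}
log_metrics(metrics, logger, tb_logger, epoch=0)   # one log line and six TensorBoard scalars per class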
diff --git a/tools/metrics.py b/tools/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..71dad0c8c9c6e867c809ee717999904939b84a97
--- /dev/null
+++ b/tools/metrics.py
@@ -0,0 +1,87 @@
+import numpy as np
+from sklearn.metrics import auc, roc_auc_score, precision_recall_curve, average_precision_score
+
+
+def rescale(x):
+    return (x - x.min()) / (x.max() - x.min())
+
+
+def is_one_class(gt: np.ndarray):
+    gt_ravel = gt.ravel()
+    return gt_ravel.sum() == 0 or gt_ravel.sum() == gt_ravel.shape[0]
+
+
+def calculate_px_metrics(gt_px, pr_px):
+    if is_one_class(gt_px):  # In case there are only normal pixels or no pixel-level labels
+        return 0, 0, 0
+
+    auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel())
+    precisions, recalls, _ = precision_recall_curve(gt_px.ravel(), pr_px.ravel())
+    f1_scores = (2 * precisions * recalls) / (precisions + recalls)
+    f1_px = np.max(f1_scores[np.isfinite(f1_scores)])
+    ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel())
+
+    return auroc_px * 100, f1_px * 100, ap_px * 100
+
+
+def calculate_im_metrics(gt_im, pr_im):
+    if is_one_class(gt_im):  # In case there are only normal samples or no image-level labels
+        return 0, 0, 0
+
+    auroc_im = roc_auc_score(gt_im.ravel(), pr_im.ravel())
+    precisions, recalls, _ = precision_recall_curve(gt_im.ravel(), pr_im.ravel())
+    f1_scores = (2 * precisions * recalls) / (precisions + recalls)
+    f1_im = np.max(f1_scores[np.isfinite(f1_scores)])
+    ap_im = average_precision_score(gt_im, pr_im)
+
+    return ap_im * 100, auroc_im * 100, f1_im * 100
+
+
+def calculate_average_metric(metrics: dict):
+    average = {}
+    for obj, metric in metrics.items():
+        for k, v in metric.items():
+            if k not in average:
+                average[k] = []
+            average[k].append(v)
+
+    for k, v in average.items():
+        average[k] = np.mean(v)
+
+    return average
+
+
+def calculate_metric(results, obj):
+    gt_px = []
+    pr_px = []
+
+    gt_im = []
+    pr_im = []
+
+    for idx in range(len(results['cls_names'])):
+        if results['cls_names'][idx] == obj:
+            gt_px.append(results['imgs_masks'][idx])
+            pr_px.append(results['anomaly_maps'][idx])
+
+            gt_im.append(results['imgs_gts'][idx])
+            pr_im.append(results['anomaly_scores'][idx])
+
+    gt_px = np.array(gt_px)
+    pr_px = np.array(pr_px)
+
+    gt_im = np.array(gt_im)
+    pr_im = np.array(pr_im)
+
+    auroc_px, f1_px, ap_px = calculate_px_metrics(gt_px, pr_px)
+    ap_im, auroc_im, f1_im = calculate_im_metrics(gt_im, pr_im)
+
+    metric = {
+        'auroc_px': auroc_px,
+        'auroc_im': auroc_im,
+        'f1_px': f1_px,
+        'f1_im': f1_im,
+        'ap_px': ap_px,
+        'ap_im': ap_im,
+    }
+
+    return metric
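The expected layout of the results dictionary is easiest to see with a toy example; the arrays below are random placeholders rather than real model outputs:

# Hypothetical input layout for tools.metrics.calculate_metric (random placeholder data).
import numpy as np
from tools import calculate_metric, calculate_average_metric

np.random.seed(0)
results = {
    'cls_names': ['candle', 'candle'],
    'imgs_masks': [np.random.randint(0, 2, (64, 64)) for _ in range(2)],  # pixel-level GT masks
    'anomaly_maps': [np.random.rand(64, 64) for _ in range(2)],           # predicted anomaly maps
    'imgs_gts': [0, 1],                                                   # image-level labels
    'anomaly_scores': [0.1, 0.9],                                         # predicted image scores
}

metric = calculate_metric(results, 'candle')              # per-class metrics, in percent
average = calculate_average_metric({'candle': metric})    # mean of each metric over classes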
diff --git a/tools/training_tools.py b/tools/training_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c2fb2456b92ac78633792d4e3e61d1f1dd66d95
--- /dev/null
+++ b/tools/training_tools.py
@@ -0,0 +1,67 @@
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+import os
+import random
+import torch
+import numpy as np
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+def setup_paths(args):
+    save_root = args.save_path
+    model_root = os.path.join(save_root, 'models')
+    log_root = os.path.join(save_root, 'logs')
+    csv_root = os.path.join(save_root, 'csvs')
+    image_root = os.path.join(save_root, 'images')
+    tensorboard_root = os.path.join(save_root, 'tensorboard')
+
+    os.makedirs(model_root, exist_ok=True)
+    os.makedirs(log_root, exist_ok=True)
+    os.makedirs(csv_root, exist_ok=True)
+    os.makedirs(image_root, exist_ok=True)
+    os.makedirs(tensorboard_root, exist_ok=True)
+
+    if args.use_hsf:
+        # prepare model name
+        model_name = f'{args.exp_indx}s-pretrained-{args.training_data}-{args.model}-' \
+                     f'{args.prompting_type}-{args.prompting_branch}-' \
+                     f'D{args.prompting_depth}-L{args.prompting_length}-HSF-K{args.k_clusters}'
+    else:
+        # prepare model name
+        model_name = f'{args.exp_indx}s-pretrained-{args.training_data}-{args.model}-' \
+                     f'{args.prompting_type}-{args.prompting_branch}-' \
+                     f'D{args.prompting_depth}-L{args.prompting_length}-WO-HSF'
+
+
+    # prepare model path
+    ckp_path = os.path.join(model_root, model_name)
+
+    # prepare tensorboard dir
+    tensorboard_dir = os.path.join(tensorboard_root, f'{model_name}-{args.testing_data}')
+    if os.path.exists(tensorboard_dir):
+        import shutil
+        shutil.rmtree(tensorboard_dir)
+    tensorboard_logger = SummaryWriter(log_dir=tensorboard_dir)
+
+    # prepare csv path
+    csv_path = os.path.join(csv_root, f'{model_name}-{args.testing_data}.csv')
+
+    # prepare image path
+    image_dir = os.path.join(image_root, f'{model_name}-{args.testing_data}')
+    os.makedirs(image_dir, exist_ok=True)
+
+    # prepare log path
+    log_path = os.path.join(log_root, f'{model_name}-{args.testing_data}.txt')
+
+    return model_name, image_dir, csv_path, log_path, ckp_path, tensorboard_logger
+
+
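To make the directory layout concrete, here is a sketch of calling setup_paths with a Namespace that mirrors the argparse defaults in train.py; only the fields read by setup_paths are included, and the values are illustrative:

# Hypothetical call to tools.training_tools.setup_paths (fields mirror train.py's defaults).
from argparse import Namespace
from tools import setup_seed, setup_paths

setup_seed(111)
args = Namespace(save_path='./workspaces', exp_indx=0,
                 training_data=['mvtec', 'colondb'], testing_data='visa',
                 model='ViT-L-14-336', prompting_type='SD', prompting_branch='VL',
                 prompting_depth=4, prompting_length=5, use_hsf=True, k_clusters=20)

model_name, image_dir, csv_path, log_path, ckp_path, tb_logger = setup_paths(args)
# Creates ./workspaces/{models,logs,csvs,images,tensorboard}/ and derives all paths from model_name.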
diff --git a/tools/visualization.py b/tools/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..899e413446f3ba3a91f0a80dc67bdf1816639b1f
--- /dev/null
+++ b/tools/visualization.py
@@ -0,0 +1,117 @@
+import cv2
+import matplotlib
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import seaborn as sns
+
+##
+from sklearn.manifold import TSNE
+from sklearn.decomposition import PCA
+
+##
+import matplotlib.ticker as mtick
+
+
+def plot_sample_cv2(names, imgs, scores_: dict, gts, save_folder=None):
+    os.makedirs(save_folder, exist_ok=True)
+
+    # get subplot number
+    total_number = len(imgs)
+
+    scores = scores_.copy()
+    # scale anomaly maps to uint8 for visualization
+    for k, v in scores.items():
+        max_value = np.max(v)
+        min_value = np.min(v)
+
+        scores[k] = (scores[k] - min_value) / max_value * 255
+        scores[k] = scores[k].astype(np.uint8)
+    # draw gts
+    mask_imgs = []
+    for idx in range(total_number):
+        gts_ = gts[idx]
+        mask_imgs_ = imgs[idx].copy()
+        mask_imgs_[gts_ > 0.5] = (0, 0, 255)
+        mask_imgs.append(mask_imgs_)
+
+    # save imgs
+    for idx in range(total_number):
+
+        cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_ori.jpg'), imgs[idx])
+        cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_gt.jpg'), mask_imgs[idx])
+
+        for key in scores:
+            heat_map = cv2.applyColorMap(scores[key][idx], cv2.COLORMAP_JET)
+            visz_map = cv2.addWeighted(heat_map, 0.5, imgs[idx], 0.5, 0)
+            cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_{key}.jpg'),
+                        visz_map)
+
+
+
+
+def plot_feat_cv2(names, feat, save_folder=None):
+    # get subplot number
+    total_number = len(feat)
+
+    # save imgs
+    for idx in range(total_number):
+        feat[idx] = cv2.resize(feat[idx], (256, 256), interpolation=cv2.INTER_NEAREST)
+        cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_feat.jpg'), feat[idx])
+
+
+
+valid_feature_visualization_methods = ['TSNE', 'PCA']
+
+def visualize_feature(features, labels, legends, n_components=3, method='TSNE'):
+    assert method in valid_feature_visualization_methods
+    assert n_components in [2, 3]
+
+    if method == 'TSNE':
+        model = TSNE(n_components=n_components)
+    elif method == 'PCA':
+        model = PCA(n_components=n_components)
+
+    else:
+        raise NotImplementedError
+
+    feat_proj = model.fit_transform(features)
+
+    if n_components == 2:
+        ax = scatter_2d(feat_proj, labels)
+    elif n_components == 3:
+        ax = scatter_3d(feat_proj, labels)
+    else:
+        raise NotImplementedError
+
+    plt.legend(legends)
+    plt.axis('off')
+
+
+def scatter_3d(feat_proj, label):
+    plt.clf()
+    ax1 = plt.axes(projection='3d')
+
+    label_unique = np.unique(label)
+
+    for l in label_unique:
+        ax1.scatter3D(feat_proj[label == l, 0],
+                      feat_proj[label == l, 1],
+                      feat_proj[label == l, 2], s=5)
+
+    return ax1
+
+
+def scatter_2d(feat_proj, label):
+    plt.clf()
+    ax1 = plt.axes()
+
+    label_unique = np.unique(label)
+
+    for l in label_unique:
+        ax1.scatter(feat_proj[label == l, 0],
+                    feat_proj[label == l, 1], s=5)
+
+    return ax1
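plot_sample_cv2 expects BGR uint8 images, binary ground-truth masks, and a dict mapping a method name to a stacked array of anomaly maps. A toy invocation with dummy data (illustrative only):

# Hypothetical call to tools.visualization.plot_sample_cv2 (dummy data, illustrative only).
import numpy as np
from tools import plot_sample_cv2

imgs = [np.full((256, 256, 3), 127, dtype=np.uint8) for _ in range(2)]  # BGR images
gts = [np.zeros((256, 256)), np.ones((256, 256))]                       # binary GT masks
scores = {'AdaCLIP': np.random.rand(2, 256, 256)}                       # stacked anomaly maps

plot_sample_cv2(['im0', 'im1'], imgs, scores, gts, save_folder='./demo_vis')
# Writes <name>_ori.jpg, <name>_gt.jpg, and one <name>_<key>.jpg heat-map overlay per score key.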
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9227ebf57e9a1904d3784b894bc73b5f80291ae
--- /dev/null
+++ b/train.py
@@ -0,0 +1,182 @@
+import warnings
+warnings.filterwarnings("ignore", category=RuntimeWarning)
+import os
+os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import argparse
+import json
+import os
+import torch
+
+# Importing from local modules
+from tools import write2csv, setup_paths, setup_seed, log_metrics, Logger
+from dataset import get_data
+from method import AdaCLIP_Trainer
+
+setup_seed(111)
+
+def train(args):
+    # Configurations
+    epochs = args.epoch
+    learning_rate = args.learning_rate
+    batch_size = args.batch_size
+    image_size = args.image_size
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    save_fig = args.save_fig
+
+    # Set up paths
+    model_name, image_dir, csv_path, log_path, ckp_path, tensorboard_logger = setup_paths(args)
+    # Logger
+    logger = Logger(log_path)
+
+    # Print basic information
+    for key, value in sorted(vars(args).items()):
+        logger.info(f'{key} = {value}')
+
+    logger.info('Model name: {:}'.format(model_name))
+
+    config_path = os.path.join('./model_configs', f'{args.model}.json')
+
+    # Prepare model
+    with open(config_path, 'r') as f:
+        model_configs = json.load(f)
+
+    # Set up the feature hierarchy
+    n_layers = model_configs['vision_cfg']['layers']
+    substage = n_layers // 4
+    features_list = [substage, substage * 2, substage * 3, substage * 4]
+
+    model = AdaCLIP_Trainer(
+        backbone=args.model,
+        feat_list=features_list,
+        input_dim=model_configs['vision_cfg']['width'],
+        output_dim=model_configs['embed_dim'],
+        learning_rate=learning_rate,
+        device=device,
+        image_size=image_size,
+        prompting_depth=args.prompting_depth,
+        prompting_length=args.prompting_length,
+        prompting_branch=args.prompting_branch,
+        prompting_type=args.prompting_type,
+        use_hsf=args.use_hsf,
+        k_clusters=args.k_clusters
+    ).to(device)
+
+    train_data_cls_names, train_data, train_data_root = get_data(
+        dataset_type_list=args.training_data,
+        transform=model.preprocess,
+        target_transform=model.transform,
+        training=True)
+
+    test_data_cls_names, test_data, test_data_root = get_data(
+        dataset_type_list=args.testing_data,
+        transform=model.preprocess,
+        target_transform=model.transform,
+        training=False)
+
+    logger.info('Data Root: training, {:}; testing, {:}'.format(train_data_root, test_data_root))
+
+    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
+    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
+
+    # Typically, we use MVTec or VisA as the validation set. The best model from this validation
+    # process is then used for zero-shot anomaly detection on novel categories.
+    best_f1 = -1e1
+
+    for epoch in tqdm(range(epochs)):
+        loss = model.train_epoch(train_dataloader)
+
+        # Logs
+        if (epoch + 1) % args.print_freq == 0:
+            logger.info('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, loss))
+            tensorboard_logger.add_scalar('loss', loss, epoch)
+
+        # Validation
+        if (epoch + 1) % args.valid_freq == 0 or (epoch == epochs - 1):
+            if epoch == epochs - 1:
+                save_fig_flag = save_fig
+            else:
+                save_fig_flag = False
+
+            logger.info('=============================Testing ====================================')
+            metric_dict = model.evaluation(
+                test_dataloader,
+                test_data_cls_names,
+                save_fig_flag,
+                image_dir,
+            )
+
+            log_metrics(
+                metric_dict,
+                logger,
+                tensorboard_logger,
+                epoch
+            )
+
+            f1_px = metric_dict['Average']['f1_px']
+
+            # Save best
+            if f1_px > best_f1:
+                for k in metric_dict.keys():
+                    write2csv(metric_dict[k], test_data_cls_names, k, csv_path)
+
+                ckp_path_best = ckp_path + '_best.pth'
+                model.save(ckp_path_best)
+                best_f1 = f1_px
+
+
+
+def str2bool(v):
+    return v.lower() in ("yes", "true", "t", "1")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("AdaCLIP", add_help=True)
+
+    # Paths and configurations
+    parser.add_argument("--training_data", type=str, default=["mvtec", "colondb"], nargs='+',
+                        help="Datasets for training (default: ['mvtec', 'colondb'])")
+    parser.add_argument("--testing_data", type=str, default="visa", help="Dataset for testing (default: 'visa')")
+
+    parser.add_argument("--save_path", type=str, default='./workspaces',
+                        help="Directory to save results (default: './workspaces')")
+
+    parser.add_argument("--model", type=str, default="ViT-L-14-336",
+                        choices=["ViT-B-16", "ViT-B-32", "ViT-L-14", "ViT-L-14-336"],
+                        help="The CLIP model to be used (default: 'ViT-L-14-336')")
+
+    parser.add_argument("--save_fig", type=str2bool, default=False,
+                        help="Save figures for visualizations (default: False)")
+    parser.add_argument("--ckt_path", type=str, default='', help="Path to the pre-trained model (default: '')")
+
+    # Hyper-parameters
+    parser.add_argument("--exp_indx", type=int, default=0, help="Index of the experiment (default: 0)")
+    parser.add_argument("--epoch", type=int, default=5, help="Number of epochs (default: 5)")
+    parser.add_argument("--learning_rate", type=float, default=0.01, help="Learning rate (default: 0.01)")
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size (default: 1)")
+
+    parser.add_argument("--image_size", type=int, default=518, help="Size of the input images (default: 518)")
+    parser.add_argument("--print_freq", type=int, default=1, help="Frequency of print statements (default: 1)")
+    parser.add_argument("--valid_freq", type=int, default=1, help="Frequency of validation (default: 1)")
+
+    # Prompting parameters
+    parser.add_argument("--prompting_depth", type=int, default=4, help="Depth of prompting (default: 4)")
+    parser.add_argument("--prompting_length", type=int, default=5, help="Length of prompting (default: 5)")
+    parser.add_argument("--prompting_type", type=str, default='SD', choices=['', 'S', 'D', 'SD'],
+                        help="Type of prompting. 'S' for Static, 'D' for Dynamic, 'SD' for both (default: 'SD')")
+    parser.add_argument("--prompting_branch", type=str, default='VL', choices=['', 'V', 'L', 'VL'],
+                        help="Branch of prompting. 'V' for Visual, 'L' for Language, 'VL' for both (default: 'VL')")
+
+    parser.add_argument("--use_hsf", type=str2bool, default=True,
+                        help="Use HSF for aggregation. If False, original class embedding is used (default: True)")
+    parser.add_argument("--k_clusters", type=int, default=20, help="Number of clusters (default: 20)")
+
+    args = parser.parse_args()
+
+    if args.batch_size != 1:
+        raise NotImplementedError(
+            "Currently, only batch size of 1 is supported due to unresolved bugs. Please set --batch_size to 1.")
+
+    train(args)
+
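For readers unfamiliar with the config files, the feature-hierarchy computation in train above boils down to the following; the JSON values shown are the usual open_clip numbers for ViT-L-14-336 and are given here for illustration rather than read from this repository:

# Illustrative values; the real numbers come from ./model_configs/ViT-L-14-336.json.
model_configs = {'embed_dim': 768, 'vision_cfg': {'layers': 24, 'width': 1024}}

n_layers = model_configs['vision_cfg']['layers']                        # 24 transformer blocks
substage = n_layers // 4                                                # 6
features_list = [substage, substage * 2, substage * 3, substage * 4]    # [6, 12, 18, 24]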
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e29cdc774576c3c50e8efef8b564b25de65ffefc
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,19 @@
+gpu_id=0
+
+# Note: training uses half-precision (FP16), so the optimization can occasionally be unstable.
+# It is recommended to run the training several times and keep the checkpoint that performs best
+# on the validation set as the final model.
+
+# pre-trained on MVTec and ColonDB
+CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True --training_data mvtec colondb --testing_data visa
+
+# pre-trained on VisA and ClinicDB
+CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec
+
+# This model is pre-trained on all available data to create a powerful Zero-Shot Anomaly Detection (ZSAD) model for demonstration purposes.
+CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True \
+--training_data \
+br35h brain_mri btad clinicdb colondb \
+dagm dtd headct isic mpdd mvtec sdd tn3k visa \
+--testing_data mvtec
+
diff --git a/weights/pretrained_all.pth b/weights/pretrained_all.pth
new file mode 100644
index 0000000000000000000000000000000000000000..9c50f6392089ce48aeee1852d427cd5f814096fe
--- /dev/null
+++ b/weights/pretrained_all.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33e8d3db1cb4aab030866b8b70a46e10aa27ebf2c23b5463cb07f2574addd98c
+size 42673907
diff --git a/weights/pretrained_mvtec_colondb.pth b/weights/pretrained_mvtec_colondb.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f18c1b331ac4352dbdf93ae350a0a4b6651522dc
--- /dev/null
+++ b/weights/pretrained_mvtec_colondb.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be51a42c052bd4cf060e54f503a1f5d0b2a3b899bc8dc2e243042f18b215427e
+size 42673907
diff --git a/weights/pretrained_visa_clinicdb.pth b/weights/pretrained_visa_clinicdb.pth
new file mode 100644
index 0000000000000000000000000000000000000000..2cc80e21556222f04d2f63b3413d2099fe3cf97d
--- /dev/null
+++ b/weights/pretrained_visa_clinicdb.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3deabbbaf1e412cfdfcb42923a500b986f4b9ee96ccbc7a735d89dbc87df44c8
+size 42673907