diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..aac877118c1a947e5c1a3b0d4c12f0ba80f2b39b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,5 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.txt.gz filter=lfs diff=lfs merge=lfs -text +weights/pretrained_all.pth filter=lfs diff=lfs merge=lfs -text +weights/pretrained_mvtec_colondb.pth filter=lfs diff=lfs merge=lfs -text +weights/pretrained_visa_clinicdb.pth filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ef40b4a9c445473cd6cf6409040ff13eebd2c948 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/result/ +/.idea/ +/__pycache__/ +/weights/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e32c0f629673371119307f7620dca662918c5eaa --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Yunkang Cao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 823edc95f950a3a64f4ce6e47055174986499e35..c102e1922c5ff21c1fda70d02a5b9fdd58e5e558 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,176 @@
----
-title: AdaCLIP
-emoji: 🌖
-colorFrom: pink
-colorTo: pink
-sdk: gradio
-sdk_version: 4.38.1
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# AdaCLIP (Detecting Anomalies for Novel Categories)
+[]()
+
+> [**ECCV 24**] [**AdaCLIP: Adapting CLIP with Hybrid Learnable Prompts for Zero-Shot Anomaly Detection**]().
+>
+> by [Yunkang Cao](https://caoyunkang.github.io/), [Jiangning Zhang](https://zhangzjn.github.io/), [Luca Frittoli](https://scholar.google.com/citations?user=cdML_XUAAAAJ),
+> [Yuqi Cheng](https://scholar.google.com/citations?user=02BC-WgAAAAJ&hl=en), [Weiming Shen](https://scholar.google.com/citations?user=FuSHsx4AAAAJ&hl=en), [Giacomo Boracchi](https://boracchi.faculty.polimi.it/)
+>
+
+## Introduction
+Zero-shot anomaly detection (ZSAD) targets the identification of anomalies within images from arbitrary novel categories.
+This study introduces AdaCLIP for the ZSAD task, leveraging a pre-trained vision-language model (VLM), CLIP.
+AdaCLIP incorporates learnable prompts into CLIP and optimizes them through training on auxiliary annotated anomaly detection data.
+Two types of learnable prompts are proposed: *static* and *dynamic*. Static prompts are shared across all images, serving to preliminarily adapt CLIP for ZSAD.
+In contrast, dynamic prompts are generated for each test image, providing CLIP with dynamic adaptation capabilities.
+The combination of static and dynamic prompts is referred to as hybrid prompts and yields enhanced ZSAD performance.
+Extensive experiments conducted across 14 real-world anomaly detection datasets from industrial and medical domains indicate that AdaCLIP outperforms other ZSAD methods and can generalize better to different categories and even domains.
+Finally, our analysis highlights the importance of diverse auxiliary data and optimized prompts for enhanced generalization capacity.
+
+## Overview of AdaCLIP
+
+
+## 🛠️ Getting Started
+
+### Installation
+To set up the AdaCLIP environment, follow one of the methods below:
+
+- Clone this repo:
+  ```shell
+  git clone https://github.com/caoyunkang/AdaCLIP.git && cd AdaCLIP
+  ```
+- You can use our provided installation script for an automated setup:
+  ```shell
+  sh install.sh
+  ```
+- If you prefer to construct the experimental environment manually, follow these steps:
+  ```shell
+  conda create -n AdaCLIP python=3.9.5 -y
+  conda activate AdaCLIP
+  pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+  pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4
+  pip install gradio # Optional, for app
+  ```
+- Remember to update the dataset root in `config.py` according to your preference:
+  ```python
+  DATA_ROOT = '../datasets' # Original setting
+  ```
+
+### Dataset Preparation
+Please download our processed visual anomaly detection datasets to your `DATA_ROOT` as needed.
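+
+Each processed dataset (and any custom dataset, see Custom Datasets below) is indexed by a `meta.json` file at its root, generated by the scripts in `./data_preprocess`. The snippet below is an illustrative sketch rather than repository code: the category and file names are hypothetical, but the keys follow what the preprocessing scripts write, and `dataset/base_dataset.py` reads only the `test` split of this index.
+
+```python
+# Sketch of the meta.json layout produced by the ./data_preprocess scripts.
+# All paths are relative to the dataset root; "my_class" and the file names
+# below are hypothetical examples.
+import json
+
+meta = {
+    "train": {
+        "my_class": [
+            {
+                "img_path": "my_class/train/good/000.png",
+                "mask_path": "",  # empty for normal samples
+                "cls_name": "my_class",
+                "specie_name": "good",
+                "anomaly": 0,
+            },
+        ],
+    },
+    "test": {
+        "my_class": [
+            {
+                "img_path": "my_class/test/defect/000.png",
+                "mask_path": "my_class/ground_truth/defect/000.png",
+                "cls_name": "my_class",
+                "specie_name": "defect",
+                "anomaly": 1,
+            },
+        ],
+    },
+}
+
+with open("meta.json", "w") as f:
+    f.write(json.dumps(meta, indent=4) + "\n")
+```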
+
+#### Industrial Visual Anomaly Detection Datasets
+Note: some download links are still being prepared.
+
+| Dataset | Google Drive | Baidu Drive | Task |
+|------------|------------------|------------------|------------------|
+| MVTec AD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
+| VisA | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
+| MPDD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
+| BTAD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
+| KSDD | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
+| DAGM | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
+| DTD-Synthetic | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection & Localization |
+
+
+
+
+#### Medical Visual Anomaly Detection Datasets
+| Dataset | Google Drive | Baidu Drive | Task |
+|------------|------------------|------------------|------------------|
+| HeadCT | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection |
+| BrainMRI | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection |
+| Br35H | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Detection |
+| ISIC | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
+| ColonDB | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
+| ClinicDB | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
+| TN3K | [Google Drive](链接) | [Baidu Drive](链接) | Anomaly Localization |
+
+#### Custom Datasets
+To use your custom dataset, follow these steps:
+
+1. Refer to the instructions in `./data_preprocess` to generate the JSON index (`meta.json`) for your dataset; its expected structure is sketched under Dataset Preparation above.
+2. Use `./dataset/base_dataset.py` to construct your own dataset (a minimal wrapper is sketched below, after the training commands).
+
+
+### Weight Preparation
+
+We offer pre-trained weights for several combinations of auxiliary training datasets.
+Please download the pre-trained weights into `./weights`.
+
+| Pre-trained Datasets | Google Drive | Baidu Drive |
+|------------|------------------|------------------|
+| MVTec AD & ColonDB | [Google Drive](https://drive.google.com/file/d/1xVXANHGuJBRx59rqPRir7iqbkYzq45W0/view?usp=drive_link) | [Baidu Drive](链接) |
+| VisA & ClinicDB | [Google Drive](https://drive.google.com/file/d/1QGmPB0ByPZQ7FucvGODMSz7r5Ke5wx9W/view?usp=drive_link) | [Baidu Drive](链接) |
+| All Datasets Mentioned Above | [Google Drive](https://drive.google.com/file/d/1Cgkfx3GAaSYnXPLolx-P7pFqYV0IVzZF/view?usp=drive_link) | [Baidu Drive](链接) |
+
+
+### Train
+
+By default, we use MVTec AD & ColonDB for training and VisA for validation (matching the `pretrained_mvtec_colondb.pth` checkpoint):
+```shell
+CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data mvtec colondb --testing_data visa
+```
+
+
+Alternatively, for evaluation on MVTec AD & ColonDB, we use VisA & ClinicDB for training and MVTec AD for validation (matching `pretrained_visa_clinicdb.pth`):
+```shell
+CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec
+```
+Since training uses half precision (FP16), the process can occasionally be unstable.
+We therefore recommend running the training several times and keeping the checkpoint that performs best
+on the validation set as the final model.
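+
+As a note on step 2 of the Custom Datasets section above: the bundled dataset wrappers (e.g. `dataset/br35h.py`) are thin subclasses of `BaseDataset`, and a custom dataset can follow the same pattern. The sketch below is illustrative only; `MyDataset`, `MY_CLS_NAMES`, and the `MyDataset_anomaly_detection` folder are hypothetical names, not part of the repository.
+
+```python
+# Illustrative custom dataset wrapper, mirroring dataset/br35h.py.
+# If the file lives inside the dataset/ package, use the relative import
+# `from .base_dataset import BaseDataset` as the bundled wrappers do.
+import os
+from config import DATA_ROOT
+from dataset.base_dataset import BaseDataset
+
+MY_CLS_NAMES = ['my_class']  # hypothetical category list
+MY_ROOT = os.path.join(DATA_ROOT, 'MyDataset_anomaly_detection')  # hypothetical root
+
+
+class MyDataset(BaseDataset):
+    def __init__(self, transform, target_transform, clsnames=MY_CLS_NAMES,
+                 aug_rate=0.0, root=MY_ROOT, training=True):
+        super(MyDataset, self).__init__(
+            clsnames=clsnames, transform=transform, target_transform=target_transform,
+            root=root, aug_rate=aug_rate, training=training
+        )
+```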
+ + +To construct a robust ZSAD model for demonstration, we also train our AdaCLIP on all AD datasets mentioned above: +```shell +CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True \ +--training_data \ +br35h brain_mri btad clinicdb colondb \ +dagm dtd headct isic mpdd mvtec sdd tn3k visa \ +--testing_data mvtec +``` + +### Test + +Manually select the best models from the validation set and place them in the `weights/` directory. Then, run the following testing script: +```shell +sh test.sh +``` + +If you want to test on a single image, you can refer to `test_single_image.sh`: +```shell +CUDA_VISIBLE_DEVICES=0 python test.py --testing_model image --ckt_path weights/pretrained_all.pth --save_fig True \ + --image_path asset/img.png --class_name candle --save_name test.png +``` + +## Main Results + +Due to differences in versions utilized, the reported performance may vary slightly compared to the detection performance +with the provided pre-trained weights. Some categories may show higher performance while others may show lower. + + + + + +### :page_facing_up: Demo App + +To run the demo application, use the following command: + +```bash +python app.py +``` + + + +## 💘 Acknowledgements +Our work is largely inspired by the following projects. Thanks for their admiring contribution. + +- [VAND-APRIL-GAN](https://github.com/ByChelsea/VAND-APRIL-GAN) +- [AnomalyCLIP](https://github.com/zqhang/AnomalyCLIP) +- [SAA](https://github.com/caoyunkang/Segment-Any-Anomaly) + + +## Stargazers over time +[](https://starchart.cc/caoyunkang/AdaCLIP) + + +## Citation + +If you find this project helpful for your research, please consider citing the following BibTeX entry. + +```BibTex + + + +``` diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..07de436c1ee7d67c5c990c59d5f1aa99354d0f7d --- /dev/null +++ b/app.py @@ -0,0 +1,133 @@ +import gradio as gr +from PIL import Image, ImageDraw, ImageFont +import warnings +import os +os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' +import json +import os +import torch +from scipy.ndimage import gaussian_filter +import cv2 +from method import AdaCLIP_Trainer +import numpy as np + +############ Init Model +ckt_path1 = 'weights/pretrained_mvtec_colondb.pth' +ckt_path2 = "weights/pretrained_visa_clinicdb.pth" +ckt_path3 = 'weights/pretrained_all.pth' + +# Configurations +image_size = 518 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +# device = 'cpu' +model = "ViT-L-14-336" +prompting_depth = 4 +prompting_length = 5 +prompting_type = 'SD' +prompting_branch = 'VL' +use_hsf = True +k_clusters = 20 + +config_path = os.path.join('./model_configs', f'{model}.json') + +# Prepare model +with open(config_path, 'r') as f: + model_configs = json.load(f) + +# Set up the feature hierarchy +n_layers = model_configs['vision_cfg']['layers'] +substage = n_layers // 4 +features_list = [substage, substage * 2, substage * 3, substage * 4] + +model = AdaCLIP_Trainer( + backbone=model, + feat_list=features_list, + input_dim=model_configs['vision_cfg']['width'], + output_dim=model_configs['embed_dim'], + learning_rate=0., + device=device, + image_size=image_size, + prompting_depth=prompting_depth, + prompting_length=prompting_length, + prompting_branch=prompting_branch, + prompting_type=prompting_type, + use_hsf=use_hsf, + k_clusters=k_clusters +).to(device) + + +def process_image(image, text, options): + # Load the model based on selected options + if 'MVTec AD+Colondb' in options: + model.load(ckt_path1) + elif 'VisA+Clinicdb' in 
options: + model.load(ckt_path2) + elif 'All' in options: + model.load(ckt_path3) + else: + # Default to 'All' if no valid option is provided + model.load(ckt_path3) + print('Invalid option. Defaulting to All.') + + # Ensure image is in RGB mode + image = image.convert('RGB') + + # Convert PIL image to NumPy array + np_image = np.array(image) + + # Convert RGB to BGR for OpenCV + np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR) + np_image = cv2.resize(np_image, (image_size, image_size)) + # Preprocess the image and run the model + img_input = model.preprocess(image).unsqueeze(0) + img_input = img_input.to(model.device) + + with torch.no_grad(): + anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True) + + # Process anomaly map + anomaly_map = anomaly_map[0, :, :].cpu().numpy() + anomaly_score = anomaly_score[0].cpu().numpy() + anomaly_map = gaussian_filter(anomaly_map, sigma=4) + anomaly_map = (anomaly_map * 255).astype(np.uint8) + + # Apply color map and blend with original image + heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET) + vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0) + + # Convert OpenCV image back to PIL image for Gradio + vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB)) + + return vis_map_pil, f'{anomaly_score:.3f}' + +# Define examples +examples = [ + ["asset/img.png", "candle", "MVTec AD+Colondb"], + ["asset/img2.png", "bottle", "VisA+Clinicdb"], + ["asset/img3.png", "button", "All"], +] + +# Gradio interface layout +demo = gr.Interface( + fn=process_image, + inputs=[ + gr.Image(type="pil", label="Upload Image"), + gr.Textbox(label="Class Name"), + gr.Radio(["MVTec AD+Colondb", + "VisA+Clinicdb", + "All"], + label="Pre-trained Datasets") + ], + outputs=[ + gr.Image(type="pil", label="Output Image"), + gr.Textbox(label="Anomaly Score"), + ], + examples=examples, + title="AdaCLIP -- Zero-shot Anomaly Detection", + description="Upload an image, enter class name, and select pre-trained datasets to do zero-shot anomaly detection" +) + +# Launch the demo +demo.launch() +# demo.launch(server_name="0.0.0.0", server_port=10002) + diff --git a/asset/Fig_app.png b/asset/Fig_app.png new file mode 100644 index 0000000000000000000000000000000000000000..96df975fe78d631ebba6b2c3b80a7da6be1f17cd --- /dev/null +++ b/asset/Fig_app.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f71ab8be0e45353c1660526ff450754e82ddf4a2b7f18bb5a33ac3b704b0d76b +size 268551 diff --git a/asset/Fig_detection_results.png b/asset/Fig_detection_results.png new file mode 100644 index 0000000000000000000000000000000000000000..bf173edfe3151b4ca788ac70272622ac3971fab4 --- /dev/null +++ b/asset/Fig_detection_results.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c00bd303a99d981d964b12e981bd1f2954d469766839523e76f7d7162fbb24cb +size 363123 diff --git a/asset/Table_industrial.png b/asset/Table_industrial.png new file mode 100644 index 0000000000000000000000000000000000000000..e0327380503f2fc33dceaa67288884eeb7262b21 --- /dev/null +++ b/asset/Table_industrial.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5fa4d9ab1ff1b3ca90b45f4b92ee7b12a89e5327cb22621d4081fb5f160d3d68 +size 401841 diff --git a/asset/Table_medical.png b/asset/Table_medical.png new file mode 100644 index 0000000000000000000000000000000000000000..59970a6f6fec926d3327aeb10537c3cb8d4ddbfb --- /dev/null +++ b/asset/Table_medical.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:d2424190619dbbd134b943ef9e38a6523635ab0d279f2445da6bdd266d3dafac +size 291004 diff --git a/asset/framework.png b/asset/framework.png new file mode 100644 index 0000000000000000000000000000000000000000..75e65d4138d43f257f0689af81848b2f2634f9d8 --- /dev/null +++ b/asset/framework.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3804c7f5ae141257dbe5dd43cb20f4216a1061051fd8754d6f0c730dd085ad7d +size 439936 diff --git a/asset/img.png b/asset/img.png new file mode 100644 index 0000000000000000000000000000000000000000..f6f91212a115c0e035c608ecae4bde2d936eab83 --- /dev/null +++ b/asset/img.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3eaff97d07132f9b06998737b976d4a0e0a3a2168b40aee43aad6e62d040f87e +size 1421232 diff --git a/asset/img2.png b/asset/img2.png new file mode 100644 index 0000000000000000000000000000000000000000..942e8707b4b3fb2b7c2124130e3f4295102564f0 --- /dev/null +++ b/asset/img2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3918b94553a8922b3c16d064ef73e9062710b35639a949c56d926037e4c0d0a +size 547657 diff --git a/asset/img3.png b/asset/img3.png new file mode 100644 index 0000000000000000000000000000000000000000..08cafcbe6d08ebcc150d11b38760422caf2f91d5 --- /dev/null +++ b/asset/img3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9394757293585aa9de542f3e70025788e5a3e1ad5a1277a8648f8050f8d7e868 +size 624200 diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ab7fc36bc73e27374399253a7d13c7e0f526fb --- /dev/null +++ b/config.py @@ -0,0 +1 @@ +DATA_ROOT = '../datasets' \ No newline at end of file diff --git a/data_preprocess/br35h.py b/data_preprocess/br35h.py new file mode 100644 index 0000000000000000000000000000000000000000..53385ac65df05ad00b7d605e6571bc737f6c0071 --- /dev/null +++ b/data_preprocess/br35h.py @@ -0,0 +1,50 @@ +import os +import json +import random +from config import DATA_ROOT + +Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection') +class Br35hSolver(object): + CLSNAMES = [ + 'br35h', + ] + + def __init__(self, root=Br35h_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + img_names.sort() + + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = Br35hSolver(root=Br35h_ROOT) + runner.run() diff --git a/data_preprocess/brain_mri.py b/data_preprocess/brain_mri.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9a23493ac89df6a683e1c6d7e98c19c5bb1523 --- /dev/null +++ b/data_preprocess/brain_mri.py @@ -0,0 +1,51 @@ +import os +import json +import random +from config import DATA_ROOT + +BrainMRI_ROOT = os.path.join(DATA_ROOT, 'BrainMRI') + 
+class BrainMRISolver(object): + CLSNAMES = [ + 'brain_mri', + ] + + def __init__(self, root=BrainMRI_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + img_names.sort() + + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = BrainMRISolver(root=BrainMRI_ROOT) + runner.run() diff --git a/data_preprocess/btad.py b/data_preprocess/btad.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb2fa22ad741f86ecab8d9ad5623167688a428c --- /dev/null +++ b/data_preprocess/btad.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +BTAD_ROOT = os.path.join(DATA_ROOT, 'BTech_Dataset_transformed') + +class BTADSolver(object): + CLSNAMES = [ + '01', '02', '03', + ] + + def __init__(self, root=BTAD_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['ok'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = BTADSolver(root=BTAD_ROOT) + runner.run() diff --git a/data_preprocess/clinicdb.py b/data_preprocess/clinicdb.py new file mode 100644 index 0000000000000000000000000000000000000000..69a2c20b7667e48eb7e21ae4ad6b5357d3906e09 --- /dev/null +++ b/data_preprocess/clinicdb.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +ClinicDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ClinicDB') + +class ClinicDBSolver(object): + CLSNAMES = [ + 'ClinicDB', + ] + + def __init__(self, root=ClinicDB_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + 
cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = ClinicDBSolver(root=ClinicDB_ROOT) + runner.run() diff --git a/data_preprocess/colondb.py b/data_preprocess/colondb.py new file mode 100644 index 0000000000000000000000000000000000000000..2939394077108884d0f63646d90fa0df912c5984 --- /dev/null +++ b/data_preprocess/colondb.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +ColonDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ColonDB') + +class ColonDBSolver(object): + CLSNAMES = [ + 'ColonDB', + ] + + def __init__(self, root=ColonDB_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = ColonDBSolver(root=ColonDB_ROOT) + runner.run() diff --git a/data_preprocess/dagm-pre.py b/data_preprocess/dagm-pre.py new file mode 100644 index 0000000000000000000000000000000000000000..d06889c628fc085e2179350e2f8e17c38e3a1f14 --- /dev/null +++ b/data_preprocess/dagm-pre.py @@ -0,0 +1,82 @@ +import os +import numpy as np +from sklearn.model_selection import train_test_split +import cv2 +import argparse +from config import DATA_ROOT + +dataset_root = os.path.join(DATA_ROOT, 'DAGM2007') + +class_names = os.listdir(dataset_root) + + +for class_name in class_names: + states = os.listdir(os.path.join(dataset_root, class_name)) + for state in states: + images = list() + mask = list() + files = os.listdir(os.path.join(dataset_root, class_name,state)) + for f in files: + if 'PNG' in f[-3:]: + images.append(f) + files = os.listdir(os.path.join(dataset_root, class_name, state,'Label')) + for f in files: + if 'PNG' in f[-3:]: + 
mask.append(f) + normal_image_path_train = list() + normal_image_path_test = list() + normal_image_path = list() + abnormal_image_path = list() + abnormal_image_label = list() + for f in images: + id = f[-8:-4] + flag = 0 + for y in mask: + if id in y: + abnormal_image_path.append(f) + abnormal_image_label.append(y) + flag = 1 + break + if flag == 0: + normal_image_path.append(f) + + if len(abnormal_image_path) != len(abnormal_image_label): + raise ValueError + length = len(abnormal_image_path) + + normal_image_path_test = normal_image_path[:length] + normal_image_path_train = normal_image_path[length:] + + target_root = '../datasets/DAGM_anomaly_detection' + + train_root = os.path.join(target_root, class_name, 'train','good') + if not os.path.exists(train_root): + os.makedirs(train_root) + for f in normal_image_path_train: + image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f)) + cv2.imwrite(os.path.join(train_root,f), image_data) + + test_root = os.path.join(target_root, class_name, 'test','good') + if not os.path.exists(test_root): + os.makedirs(test_root) + for f in normal_image_path_test: + image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f)) + cv2.imwrite(os.path.join(test_root,f), image_data) + + test_root = os.path.join(target_root, class_name, 'test','defect') + if not os.path.exists(test_root): + os.makedirs(test_root) + for f in abnormal_image_path: + image_data = cv2.imread(os.path.join(dataset_root, class_name, state,f)) + cv2.imwrite(os.path.join(test_root,f), image_data) + + test_root = os.path.join(target_root, class_name, 'ground_truth','defect') + if not os.path.exists(test_root): + os.makedirs(test_root) + for f in mask: + image_data = cv2.imread(os.path.join(dataset_root, class_name, state,'Label',f)) + cv2.imwrite(os.path.join(test_root,f), image_data) + + + +print("Done") \ No newline at end of file diff --git a/data_preprocess/dagm.py b/data_preprocess/dagm.py new file mode 100644 index 0000000000000000000000000000000000000000..23aea77fc1a39006057fa8fa6a1d6dc630695ec7 --- /dev/null +++ b/data_preprocess/dagm.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +DAGM_ROOT = os.path.join(DATA_ROOT, 'DAGM_anomaly_detection') + +class DAGMSolver(object): + CLSNAMES = [ + 'Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6','Class7','Class8','Class9','Class10', + ] + + def __init__(self, root=DAGM_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + 
f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = DAGMSolver(root=DAGM_ROOT) + runner.run() diff --git a/data_preprocess/dtd.py b/data_preprocess/dtd.py new file mode 100644 index 0000000000000000000000000000000000000000..a759416ce36ca6220d1c20d688f66bbb3b165386 --- /dev/null +++ b/data_preprocess/dtd.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +DTD_ROOT = os.path.join(DATA_ROOT, 'DTD-Synthetic') + +class DTDSolver(object): + CLSNAMES = [ + 'Blotchy_099', 'Fibrous_183', 'Marbled_078', 'Matted_069', 'Mesh_114','Perforated_037','Stratified_154','Woven_001','Woven_068','Woven_104','Woven_125','Woven_127', + ] + + def __init__(self, root=DTD_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = DTDSolver(root=DTD_ROOT) + runner.run() diff --git a/data_preprocess/endo.py b/data_preprocess/endo.py new file mode 100644 index 0000000000000000000000000000000000000000..4f44aa23ff1910331c52310e3e6d727a07e26a31 --- /dev/null +++ b/data_preprocess/endo.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +ENDO_ROOT = os.path.join(DATA_ROOT, 'EndoTect') + +class ENDOSolver(object): + CLSNAMES = [ + 'endo', + ] + + def __init__(self, root=ENDO_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == 
'__main__': + runner = ENDOSolver(root=ENDO_ROOT) + runner.run() diff --git a/data_preprocess/headct-pre.py b/data_preprocess/headct-pre.py new file mode 100644 index 0000000000000000000000000000000000000000..3672393f198d982f47476a036dbcb570f73d0c84 --- /dev/null +++ b/data_preprocess/headct-pre.py @@ -0,0 +1,41 @@ +import os +import numpy as np +from sklearn.model_selection import train_test_split +import shutil +import argparse + +from config import DATA_ROOT + +dataset_root = os.path.join(DATA_ROOT, 'head_ct') + +label_file = os.path.join(dataset_root, 'labels.csv') + +data = np.loadtxt(label_file, dtype=int, delimiter=',', skiprows=1) + +fnames = data[:, 0] +label = data[:, 1] + +normal_fnames = fnames[label==0] +outlier_fnames = fnames[label==1] + + +target_root = '../datasets/HeadCT_anomaly_detection/headct' +train_root = os.path.join(target_root, 'train/good') +if not os.path.exists(train_root): + os.makedirs(train_root) + +test_normal_root = os.path.join(target_root, 'test/good') +if not os.path.exists(test_normal_root): + os.makedirs(test_normal_root) +for f in normal_fnames: + source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f)) + shutil.copy(source, test_normal_root) + +test_outlier_root = os.path.join(target_root, 'test/defect') +if not os.path.exists(test_outlier_root): + os.makedirs(test_outlier_root) +for f in outlier_fnames: + source = os.path.join(dataset_root, 'head_ct/', '{:0>3d}.png'.format(f)) + shutil.copy(source, test_outlier_root) + +print('Done') \ No newline at end of file diff --git a/data_preprocess/headct.py b/data_preprocess/headct.py new file mode 100644 index 0000000000000000000000000000000000000000..0e427dc111e678eb6d82ffc171ac829f5569af41 --- /dev/null +++ b/data_preprocess/headct.py @@ -0,0 +1,52 @@ +import os +import json +import random +# from dataset import MPDD_ROOT +# from dataset.mpdd import MPDD_ROOT + + +HEADCT_ROOT = '../datasets/HeadCT_anomaly_detection' +class HEADCTSolver(object): + CLSNAMES = [ + 'headct', + ] + + def __init__(self, root=HEADCT_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + img_names.sort() + + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = HEADCTSolver(root=HEADCT_ROOT) + runner.run() diff --git a/data_preprocess/isic.py b/data_preprocess/isic.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d23c037d1e72eb4d3a7eb4588f173a09be8ef2 --- /dev/null +++ b/data_preprocess/isic.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +ISIC_ROOT = os.path.join(DATA_ROOT, 'ISIC') + +class ISICSolver(object): + CLSNAMES = [ + 'isic', + ] + + def __init__(self, root=ISIC_ROOT, train_ratio=0.5): + self.root = root + 
self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = ISICSolver(root=ISIC_ROOT) + runner.run() diff --git a/data_preprocess/mpdd.py b/data_preprocess/mpdd.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe1e6691856148ae5664b1fd45bb08f359c010f --- /dev/null +++ b/data_preprocess/mpdd.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +MPDD_ROOT = os.path.join(DATA_ROOT, 'MPDD') + +class MPDDSolver(object): + CLSNAMES = [ + 'bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate','tubes', + ] + + def __init__(self, root=MPDD_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = MPDDSolver(root=MPDD_ROOT) + runner.run() diff --git a/data_preprocess/mvtec.py b/data_preprocess/mvtec.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d5fe964dad6f360c8d9f4cfc7723008b5ae620 --- /dev/null +++ b/data_preprocess/mvtec.py @@ -0,0 +1,52 @@ +import os +import json +import random +from dataset import MVTEC_ROOT + +class MVTecSolver(object): + CLSNAMES = [ + 'bottle', 'cable', 'capsule', 'carpet', 'grid', + 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', + 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', + ] + + def __init__(self, root=MVTEC_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + 
self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = MVTecSolver(root=MVTEC_ROOT) + runner.run() diff --git a/data_preprocess/sdd-pre.py b/data_preprocess/sdd-pre.py new file mode 100644 index 0000000000000000000000000000000000000000..da9dc6ed246a51b912bc3f6cf624307168b8ba03 --- /dev/null +++ b/data_preprocess/sdd-pre.py @@ -0,0 +1,75 @@ +import os +import numpy as np +from sklearn.model_selection import train_test_split +import cv2 +import argparse + +from config import DATA_ROOT + +dataset_root = os.path.join(DATA_ROOT, 'KolektorSDD') + +dirs = os.listdir(dataset_root) +normal_images = list() +normal_labels = list() +normal_fname = list() +outlier_images = list() +outlier_labels = list() +outlier_fname = list() +for d in dirs: + files = os.listdir(os.path.join(dataset_root, d)) + images = list() + for f in files: + if 'jpg' in f[-3:]: + images.append(f) + + for image in images: + split_images = list() + split_labels = list() + image_name = image.split('.')[0] + image_data = cv2.imread(os.path.join(dataset_root, d, image)) + label_data = cv2.imread(os.path.join(dataset_root, d, image_name + '_label.bmp')) + if image_data.shape != label_data.shape: + raise ValueError + image_length = image_data.shape[0] + split_images.append(image_data[:image_length // 3, :, :]) + split_images.append(image_data[image_length // 3:image_length * 2 // 3, :, :]) + split_images.append(image_data[image_length * 2 // 3:, :, :]) + split_labels.append(label_data[:image_length // 3, :, :]) + split_labels.append(label_data[image_length // 3:image_length * 2 // 3, :, :]) + split_labels.append(label_data[image_length * 2 // 3:, :, :]) + for i, (im, la) in enumerate(zip(split_images, split_labels)): + if np.max(la) != 0: + outlier_images.append(im) + outlier_labels.append(la) + outlier_fname.append(d + '_' + image_name + '_' + str(i)) + else: + normal_images.append(im) + normal_labels.append(la) + normal_fname.append(d + '_' + image_name + '_' + str(i)) + +normal_train, normal_test, normal_name_train, normal_name_test = train_test_split(normal_images, normal_fname, test_size=0.25, random_state=42) + +target_root = '../datasets/SDD_anomaly_detection/SDD' +train_root = os.path.join(target_root, 'train/good') +if not os.path.exists(train_root): + os.makedirs(train_root) +for image, name in zip(normal_train, normal_name_train): + cv2.imwrite(os.path.join(train_root, name + '.png'), image) + +test_root = os.path.join(target_root, 'test/good') +if not os.path.exists(test_root): + os.makedirs(test_root) 
+for image, name in zip(normal_test, normal_name_test): + cv2.imwrite(os.path.join(test_root, name + '.png'), image) + +defect_root = os.path.join(target_root, 'test/defect') +label_root = os.path.join(target_root, 'ground_truth/defect') +if not os.path.exists(defect_root): + os.makedirs(defect_root) +if not os.path.exists(label_root): + os.makedirs(label_root) +for image, label, name in zip(outlier_images, outlier_labels, outlier_fname): + cv2.imwrite(os.path.join(defect_root, name + '.png'), image) + cv2.imwrite(os.path.join(label_root, name + '_mask.png'), label) + +print("Done") \ No newline at end of file diff --git a/data_preprocess/sdd.py b/data_preprocess/sdd.py new file mode 100644 index 0000000000000000000000000000000000000000..a04f543c53c105e756cca79442023a7c5d1699e4 --- /dev/null +++ b/data_preprocess/sdd.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection') + +class SDDSolver(object): + CLSNAMES = [ + 'SDD', + ] + + def __init__(self, root=SDD_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = SDDSolver(root=SDD_ROOT) + runner.run() diff --git a/data_preprocess/tn3k.py b/data_preprocess/tn3k.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc08832244b10074f7b609bc97656aa90ee155a --- /dev/null +++ b/data_preprocess/tn3k.py @@ -0,0 +1,52 @@ +import os +import json +import random +from config import DATA_ROOT + +TN3K_ROOT = os.path.join(DATA_ROOT, 'TN3K') + +class TN3KSolver(object): + CLSNAMES = [ + 'tn3k', + ] + + def __init__(self, root=TN3K_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + info = dict(train={}, test={}) + for cls_name in self.CLSNAMES: + cls_dir = f'{self.root}/{cls_name}' + for phase in ['train', 'test']: + cls_info = [] + species = os.listdir(f'{cls_dir}/{phase}') + for specie in species: + is_abnormal = True if specie not in ['good'] else False + img_names = os.listdir(f'{cls_dir}/{phase}/{specie}') + mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None + img_names.sort() + mask_names.sort() if mask_names is not None else None + for idx, img_name in enumerate(img_names): + info_img = dict( + img_path=f'{cls_name}/{phase}/{specie}/{img_name}', + 
mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '', + cls_name=cls_name, + specie_name=specie, + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + + info[phase][cls_name] = cls_info + + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + +if __name__ == '__main__': + runner = TN3KSolver(root=TN3K_ROOT) + runner.run() diff --git a/data_preprocess/visa.py b/data_preprocess/visa.py new file mode 100644 index 0000000000000000000000000000000000000000..323614308022ae475d278efced9fe2aed02a63a4 --- /dev/null +++ b/data_preprocess/visa.py @@ -0,0 +1,52 @@ +import os +import json +import pandas as pd +import random +from dataset import VISA_ROOT + +class VisASolver(object): + CLSNAMES = [ + 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', + 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', + 'pcb4', 'pipe_fryum', + ] + + def __init__(self, root=VISA_ROOT, train_ratio=0.5): + self.root = root + self.meta_path = f'{root}/meta.json' + self.phases = ['train', 'test'] + self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0) + self.train_ratio = train_ratio + + def run(self): + self.generate_meta_info() + + def generate_meta_info(self): + columns = self.csv_data.columns # [object, split, label, image, mask] + info = {phase: {} for phase in self.phases} + for cls_name in self.CLSNAMES: + cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name] + for phase in self.phases: + cls_info = [] + cls_data_phase = cls_data[cls_data[columns[1]] == phase] + cls_data_phase.index = list(range(len(cls_data_phase))) + for idx in range(cls_data_phase.shape[0]): + data = cls_data_phase.loc[idx] + is_abnormal = True if data[2] == 'anomaly' else False + info_img = dict( + img_path=data[3], + mask_path=data[4] if is_abnormal else '', + cls_name=cls_name, + specie_name='', + anomaly=1 if is_abnormal else 0, + ) + cls_info.append(info_img) + info[phase][cls_name] = cls_info + with open(self.meta_path, 'w') as f: + f.write(json.dumps(info, indent=4) + "\n") + + + +if __name__ == '__main__': + runner = VisASolver(root=VISA_ROOT) + runner.run() diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..881cb60d70e3fa4de1b1db53a57da50281d709bd --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,68 @@ +from .mvtec import MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT +from .visa import VISA_CLS_NAMES, VisaDataset, VISA_ROOT +from .mpdd import MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT +from .btad import BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT +from .sdd import SDD_CLS_NAMES, SDDDataset, SDD_ROOT +from .dagm import DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT +from .dtd import DTD_CLS_NAMES,DTDDataset,DTD_ROOT +from .isic import ISIC_CLS_NAMES,ISICDataset,ISIC_ROOT +from .colondb import ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT +from .clinicdb import ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT +from .tn3k import TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT +from .headct import HEADCT_CLS_NAMES,HEADCTDataset,HEADCT_ROOT +from .brain_mri import BrainMRI_CLS_NAMES,BrainMRIDataset,BrainMRI_ROOT +from .br35h import Br35h_CLS_NAMES,Br35hDataset,Br35h_ROOT +from torch.utils.data import ConcatDataset + +dataset_dict = { + 'br35h': (Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT), + 'brain_mri': (BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT), + 'btad': (BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT), + 'clinicdb': (ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT), + 'colondb': 
(ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT), + 'dagm': (DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT), + 'dtd': (DTD_CLS_NAMES, DTDDataset, DTD_ROOT), + 'headct': (HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT), + 'isic': (ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT), + 'mpdd': (MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT), + 'mvtec': (MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT), + 'sdd': (SDD_CLS_NAMES, SDDDataset, SDD_ROOT), + 'tn3k': (TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT), + 'visa': (VISA_CLS_NAMES, VisaDataset, VISA_ROOT), +} + +def get_data(dataset_type_list, transform, target_transform, training): + if not isinstance(dataset_type_list, list): + dataset_type_list = [dataset_type_list] + + dataset_cls_names_list = [] + dataset_instance_list = [] + dataset_root_list = [] + for dataset_type in dataset_type_list: + if dataset_dict.get(dataset_type, ''): + dataset_cls_names, dataset_instance, dataset_root = dataset_dict[dataset_type] + dataset_instance = dataset_instance( + clsnames=dataset_cls_names, + transform=transform, + target_transform=target_transform, + training=training + ) + + dataset_cls_names_list.append(dataset_cls_names) + dataset_instance_list.append(dataset_instance) + dataset_root_list.append(dataset_root) + + else: + print(f'Only support {list(dataset_dict.keys())}, but entered {dataset_type}...') + raise NotImplementedError + + if len(dataset_type_list) > 1: + dataset_instance = ConcatDataset(dataset_instance_list) + dataset_cls_names = dataset_cls_names_list + dataset_root = dataset_root_list + else: + dataset_instance = dataset_instance_list[0] + dataset_cls_names = dataset_cls_names_list[0] + dataset_root = dataset_root_list[0] + + return dataset_cls_names, dataset_instance, dataset_root \ No newline at end of file diff --git a/dataset/__pycache__/__init__.cpython-39.pyc b/dataset/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5826b4eca907aeaeb8c255768b822510e22307d Binary files /dev/null and b/dataset/__pycache__/__init__.cpython-39.pyc differ diff --git a/dataset/__pycache__/br35h.cpython-39.pyc b/dataset/__pycache__/br35h.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bab8c2009806fd165fc4440d2616e7cf4ee1c792 Binary files /dev/null and b/dataset/__pycache__/br35h.cpython-39.pyc differ diff --git a/dataset/__pycache__/brain_mri.cpython-39.pyc b/dataset/__pycache__/brain_mri.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac199549040338b67847dfd6b8019041e9b2b2ef Binary files /dev/null and b/dataset/__pycache__/brain_mri.cpython-39.pyc differ diff --git a/dataset/__pycache__/btad.cpython-39.pyc b/dataset/__pycache__/btad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b807457785f8dd738f50a7a10f47f843895b568d Binary files /dev/null and b/dataset/__pycache__/btad.cpython-39.pyc differ diff --git a/dataset/__pycache__/clinicdb.cpython-39.pyc b/dataset/__pycache__/clinicdb.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3026d398c798b88a7a95801c172d1beca37ae7c0 Binary files /dev/null and b/dataset/__pycache__/clinicdb.cpython-39.pyc differ diff --git a/dataset/__pycache__/colondb.cpython-39.pyc b/dataset/__pycache__/colondb.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97225e60df365de8467297105be18dc38c3fa6a4 Binary files /dev/null and b/dataset/__pycache__/colondb.cpython-39.pyc differ diff --git a/dataset/__pycache__/dagm.cpython-39.pyc 
b/dataset/__pycache__/dagm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f287d6f437efcc86ef5c6ba0c5c402042d67b510 Binary files /dev/null and b/dataset/__pycache__/dagm.cpython-39.pyc differ diff --git a/dataset/__pycache__/dtd.cpython-39.pyc b/dataset/__pycache__/dtd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..179f7e257b0e93bb15213c833914cc5e70e03632 Binary files /dev/null and b/dataset/__pycache__/dtd.cpython-39.pyc differ diff --git a/dataset/__pycache__/headct.cpython-39.pyc b/dataset/__pycache__/headct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..683659dfff8e6dcd1bb9f0561638521dae6e4505 Binary files /dev/null and b/dataset/__pycache__/headct.cpython-39.pyc differ diff --git a/dataset/__pycache__/isic.cpython-39.pyc b/dataset/__pycache__/isic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..066fe8d419662da76c7df9b1dbb85c9c41691c77 Binary files /dev/null and b/dataset/__pycache__/isic.cpython-39.pyc differ diff --git a/dataset/__pycache__/mpdd.cpython-39.pyc b/dataset/__pycache__/mpdd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b9d3df589fb7f346ef4cd0a4809e7c6efac4498 Binary files /dev/null and b/dataset/__pycache__/mpdd.cpython-39.pyc differ diff --git a/dataset/__pycache__/mvtec.cpython-39.pyc b/dataset/__pycache__/mvtec.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0607993c90b7508a55150951f2c1ed59a23f8e35 Binary files /dev/null and b/dataset/__pycache__/mvtec.cpython-39.pyc differ diff --git a/dataset/__pycache__/sdd.cpython-39.pyc b/dataset/__pycache__/sdd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aafa3c9144683ea58bfeb4d8ca5dfaad99f842ab Binary files /dev/null and b/dataset/__pycache__/sdd.cpython-39.pyc differ diff --git a/dataset/__pycache__/tn3k.cpython-39.pyc b/dataset/__pycache__/tn3k.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c1096267ca748afd679365dec62f7835e5e2023 Binary files /dev/null and b/dataset/__pycache__/tn3k.cpython-39.pyc differ diff --git a/dataset/__pycache__/visa.cpython-39.pyc b/dataset/__pycache__/visa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43f6c38bf6ebcf3897a2e93d9b9c226a7a85f49d Binary files /dev/null and b/dataset/__pycache__/visa.cpython-39.pyc differ diff --git a/dataset/base_dataset.py b/dataset/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f09178216a203d9cedafb0f143ce377cbc2fb861 --- /dev/null +++ b/dataset/base_dataset.py @@ -0,0 +1,138 @@ +""" +Base class for our zero-shot anomaly detection dataset +""" +import json +import os +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import cv2 +from config import DATA_ROOT + + +class DataSolver: + def __init__(self, root, clsnames): + self.root = root + self.clsnames = clsnames + self.path = os.path.join(root, 'meta.json') + + def run(self): + with open(self.path, 'r') as f: + info = json.load(f) + + info_required = dict(train={}, test={}) + for cls in self.clsnames: + for k in info.keys(): + info_required[k][cls] = info[k][cls] + + return info_required + + +class BaseDataset(data.Dataset): + def __init__(self, clsnames, transform, target_transform, root, aug_rate=0., training=True): + self.root = root + self.transform = transform + self.target_transform = target_transform + 
self.aug_rate = aug_rate + self.training = training + self.data_all = [] + self.cls_names = clsnames + + solver = DataSolver(root, clsnames) + meta_info = solver.run() + + self.meta_info = meta_info['test'] # Only utilize the test dataset for both training and testing + for cls_name in self.cls_names: + self.data_all.extend(self.meta_info[cls_name]) + + self.length = len(self.data_all) + + def __len__(self): + return self.length + + def combine_img(self, cls_name): + """ + From April-GAN: https://github.com/ByChelsea/VAND-APRIL-GAN + Here we combine four images into a single image for data augmentation. + """ + img_info = random.sample(self.meta_info[cls_name], 4) + + img_ls = [] + mask_ls = [] + + for data in img_info: + img_path = os.path.join(self.root, data['img_path']) + mask_path = os.path.join(self.root, data['mask_path']) + + img = Image.open(img_path).convert('RGB') + img_ls.append(img) + + if not data['anomaly']: + img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') + else: + img_mask = np.array(Image.open(mask_path).convert('L')) > 0 + img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') + + mask_ls.append(img_mask) + + # Image + image_width, image_height = img_ls[0].size + result_image = Image.new("RGB", (2 * image_width, 2 * image_height)) + for i, img in enumerate(img_ls): + row = i // 2 + col = i % 2 + x = col * image_width + y = row * image_height + result_image.paste(img, (x, y)) + + # Mask + result_mask = Image.new("L", (2 * image_width, 2 * image_height)) + for i, img in enumerate(mask_ls): + row = i // 2 + col = i % 2 + x = col * image_width + y = row * image_height + result_mask.paste(img, (x, y)) + + return result_image, result_mask + + def __getitem__(self, index): + data = self.data_all[index] + img_path = os.path.join(self.root, data['img_path']) + mask_path = os.path.join(self.root, data['mask_path']) + cls_name = data['cls_name'] + anomaly = data['anomaly'] + random_number = random.random() + + if self.training and random_number < self.aug_rate: + img, img_mask = self.combine_img(cls_name) + else: + if img_path.endswith('.tif'): + img = cv2.imread(img_path) + img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + else: + img = Image.open(img_path).convert('RGB') + if anomaly == 0: + img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') + else: + if data['mask_path']: + img_mask = np.array(Image.open(mask_path).convert('L')) > 0 + img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L') + else: + img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L') + # Transforms + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None and img_mask is not None: + img_mask = self.target_transform(img_mask) + if img_mask is None: + img_mask = [] + + return { + 'img': img, + 'img_mask': img_mask, + 'cls_name': cls_name, + 'anomaly': anomaly, + 'img_path': img_path + } diff --git a/dataset/br35h.py b/dataset/br35h.py new file mode 100644 index 0000000000000000000000000000000000000000..9e6b8ebcd0a9871937ddac17064c1bfa4e25acb5 --- /dev/null +++ b/dataset/br35h.py @@ -0,0 +1,18 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://www.kaggle.com/datasets/ahmedhamada0/brain-tumor-detection''' + +Br35h_CLS_NAMES = [ + 'br35h', +] +Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection') + +class Br35hDataset(BaseDataset): + def __init__(self, transform, target_transform, 
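For illustration (not part of the patch): `DataSolver` and `BaseDataset` above read a per-dataset `meta.json` and index only its `test` split. The sketch below shows the layout they appear to expect; the class name, paths, and file names are placeholders inferred from the fields accessed in `__getitem__`.

```python
# Hypothetical meta.json layout for one dataset root; all paths are placeholders.
import json

meta = {
    "train": {"br35h": []},  # DataSolver copies every top-level split for each class
    "test": {
        "br35h": [
            {   # anomalous sample with a pixel-level mask
                "img_path": "br35h/test/anomaly/0001.png",
                "mask_path": "br35h/ground_truth/anomaly/0001.png",
                "cls_name": "br35h",
                "anomaly": 1,
            },
            {   # normal sample; an all-zero mask is synthesized in __getitem__
                "img_path": "br35h/test/good/0002.png",
                "mask_path": "",
                "cls_name": "br35h",
                "anomaly": 0,
            },
        ]
    },
}

with open("meta.json", "w") as f:
    json.dump(meta, f, indent=2)
```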
clsnames=Br35h_CLS_NAMES, aug_rate=0.0, root=Br35h_ROOT, training=True): + super(Br35hDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + diff --git a/dataset/brain_mri.py b/dataset/brain_mri.py new file mode 100644 index 0000000000000000000000000000000000000000..176f656d0b3da3a9b4a3c9f06914559c2ac605ef --- /dev/null +++ b/dataset/brain_mri.py @@ -0,0 +1,16 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://www.kaggle.com/datasets/navoneel/brain-mri-images-for-brain-tumor-detection''' +BrainMRI_CLS_NAMES = [ + 'brain_mri', +] +BrainMRI_ROOT = os.path.join(DATA_ROOT, 'BrainMRI') + +class BrainMRIDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=BrainMRI_CLS_NAMES, aug_rate=0.0, root=BrainMRI_ROOT, training=True): + super(BrainMRIDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) diff --git a/dataset/btad.py b/dataset/btad.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8858fc686d9323e19b71418535fa433020d30f --- /dev/null +++ b/dataset/btad.py @@ -0,0 +1,16 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://avires.dimi.uniud.it/papers/btad/btad.zip''' +BTAD_CLS_NAMES = [ + '01', '02', '03', +] +BTAD_ROOT = os.path.join(DATA_ROOT, 'BTech_Dataset_transformed') + +class BTADDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=BTAD_CLS_NAMES, aug_rate=0.0, root=BTAD_ROOT, training=True): + super(BTADDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) diff --git a/dataset/clinicdb.py b/dataset/clinicdb.py new file mode 100644 index 0000000000000000000000000000000000000000..6376dd67bac837a578ebf723df594986ca004443 --- /dev/null +++ b/dataset/clinicdb.py @@ -0,0 +1,16 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://paperswithcode.com/dataset/cvc-clinicdb''' +ClinicDB_CLS_NAMES = [ + 'ClinicDB', +] +ClinicDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ClinicDB') + +class ClinicDBDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=ClinicDB_CLS_NAMES, aug_rate=0.0, root=ClinicDB_ROOT, training=True): + super(ClinicDBDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) diff --git a/dataset/colondb.py b/dataset/colondb.py new file mode 100644 index 0000000000000000000000000000000000000000..ed70e981681642b1868769ce5de5c1c15f7440a0 --- /dev/null +++ b/dataset/colondb.py @@ -0,0 +1,18 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: http://mv.cvc.uab.es/projects/colon-qa/cvccolondb''' +ColonDB_CLS_NAMES = [ + 'ColonDB', +] +ColonDB_ROOT = os.path.join(DATA_ROOT, 'CVC-ColonDB') + +class ColonDBDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=ColonDB_CLS_NAMES, aug_rate=0.0, root=ColonDB_ROOT, training=True): + super(ColonDBDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + + diff --git a/dataset/dagm.py 
b/dataset/dagm.py new file mode 100644 index 0000000000000000000000000000000000000000..1b06b31fbd88ecfafe7cbc10a99efcf90b141dc5 --- /dev/null +++ b/dataset/dagm.py @@ -0,0 +1,16 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://hci.iwr.uni-heidelberg.de/content/weakly-supervised-learning-industrial-optical-inspection''' +DAGM_CLS_NAMES = [ + 'Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6','Class7','Class8','Class9','Class10', +] +DAGM_ROOT = os.path.join(DATA_ROOT, 'DAGM_anomaly_detection') + +class DAGMDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=DAGM_CLS_NAMES, aug_rate=0.0, root=DAGM_ROOT, training=True): + super(DAGMDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) diff --git a/dataset/dtd.py b/dataset/dtd.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2fc4eb276ab62e21e7f9adebaff530b89bc477 --- /dev/null +++ b/dataset/dtd.py @@ -0,0 +1,16 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1''' +DTD_CLS_NAMES = [ + 'Blotchy_099', 'Fibrous_183', 'Marbled_078', 'Matted_069', 'Mesh_114','Perforated_037','Stratified_154','Woven_001','Woven_068','Woven_104','Woven_125','Woven_127', +] +DTD_ROOT = os.path.join(DATA_ROOT, 'DTD-Synthetic') + +class DTDDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=DTD_CLS_NAMES, aug_rate=0.0, root=DTD_ROOT, training=True): + super(DTDDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) diff --git a/dataset/headct.py b/dataset/headct.py new file mode 100644 index 0000000000000000000000000000000000000000..f794ab7e8ac87c408088343e701a27f4e6d8630f --- /dev/null +++ b/dataset/headct.py @@ -0,0 +1,18 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://www.kaggle.com/datasets/felipekitamura/head-ct-hemorrhage''' +HEADCT_CLS_NAMES = [ + 'headct', +] +HEADCT_ROOT = os.path.join(DATA_ROOT, 'HeadCT_anomaly_detection') + +class HEADCTDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=HEADCT_CLS_NAMES, aug_rate=0.0, root=HEADCT_ROOT, training=True): + super(HEADCTDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + + diff --git a/dataset/isic.py b/dataset/isic.py new file mode 100644 index 0000000000000000000000000000000000000000..d86492b8d634f8243672f2baede947499816abf8 --- /dev/null +++ b/dataset/isic.py @@ -0,0 +1,18 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://challenge.isic-archive.com/data/''' +ISIC_CLS_NAMES = [ + 'isic', +] +ISIC_ROOT = os.path.join(DATA_ROOT, 'ISIC') + +class ISICDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=ISIC_CLS_NAMES, aug_rate=0.0, root=ISIC_ROOT, training=True): + super(ISICDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + + diff --git a/dataset/mpdd.py b/dataset/mpdd.py new file mode 100644 index 
0000000000000000000000000000000000000000..4ca3de11a8205c164f7de896e8f498d2469da7e3 --- /dev/null +++ b/dataset/mpdd.py @@ -0,0 +1,17 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://github.com/stepanje/MPDD''' +MPDD_CLS_NAMES = [ + 'bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate','tubes', +] +MPDD_ROOT = os.path.join(DATA_ROOT, 'MPDD') + +class MPDDDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=MPDD_CLS_NAMES, aug_rate=0.0, root=MPDD_ROOT, training=True): + super(MPDDDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + diff --git a/dataset/mvtec.py b/dataset/mvtec.py new file mode 100644 index 0000000000000000000000000000000000000000..44f0e160ac87cdf57f31f3ed27d9b0bcaf2c5bed --- /dev/null +++ b/dataset/mvtec.py @@ -0,0 +1,19 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://paperswithcode.com/dataset/mvtecad''' + +MVTEC_CLS_NAMES = [ + 'bottle', 'cable', 'capsule', 'carpet', 'grid', + 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', + 'tile', 'toothbrush', 'transistor', 'wood', 'zipper', +] +MVTEC_ROOT = os.path.join(DATA_ROOT, 'mvtec_anomaly_detection') + +class MVTecDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=MVTEC_CLS_NAMES, aug_rate=0.2, root=MVTEC_ROOT, training=True): + super(MVTecDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) diff --git a/dataset/sdd.py b/dataset/sdd.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac6cc938505b3c2a51b2bd7c0230df3aa107dab --- /dev/null +++ b/dataset/sdd.py @@ -0,0 +1,18 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://data.vicos.si/datasets/KSDD/KolektorSDD.zip''' +SDD_CLS_NAMES = [ + 'SDD', +] +SDD_ROOT = os.path.join(DATA_ROOT, 'SDD_anomaly_detection') + + +class SDDDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=SDD_CLS_NAMES, aug_rate=0.0, root=SDD_ROOT, training=True): + super(SDDDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + diff --git a/dataset/tn3k.py b/dataset/tn3k.py new file mode 100644 index 0000000000000000000000000000000000000000..7ab3e972ec83bee01913b587682f53f7f39cd142 --- /dev/null +++ b/dataset/tn3k.py @@ -0,0 +1,18 @@ +import os +from .base_dataset import BaseDataset +from config import DATA_ROOT + +'''dataset source: https://ieeexplore.ieee.org/document/9434087/references#references''' +TN3K_CLS_NAMES = [ + 'tn3k', +] +TN3K_ROOT = os.path.join(DATA_ROOT, 'TN3K') + +class TN3KDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=TN3K_CLS_NAMES, aug_rate=0.0, root=TN3K_ROOT, training=True): + super(TN3KDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + + diff --git a/dataset/visa.py b/dataset/visa.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc9aaa3097e8275d2fccddfb34ed580e446457f --- /dev/null +++ b/dataset/visa.py @@ -0,0 +1,20 @@ +import os +from .base_dataset import BaseDataset +from config import 
DATA_ROOT + +'''dataset source: https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar''' +VISA_CLS_NAMES = [ + 'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', + 'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3', + 'pcb4', 'pipe_fryum', +] + +VISA_ROOT = os.path.join(DATA_ROOT, 'VisA_20220922') + +class VisaDataset(BaseDataset): + def __init__(self, transform, target_transform, clsnames=VISA_CLS_NAMES, aug_rate=0.0, root=VISA_ROOT, training=True): + super(VisaDataset, self).__init__( + clsnames=clsnames, transform=transform, target_transform=target_transform, + root=root, aug_rate=aug_rate, training=training + ) + diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..e50dca66b1a02486e4787bbb2b3489767b85d4c2 --- /dev/null +++ b/install.sh @@ -0,0 +1,15 @@ +# add dependencies +# python395_cuda113_pytorch1101 +# please change dataset root in ./config.py according to your specifications + +conda create -n AdaCLIP python=3.9.5 -y +conda activate AdaCLIP +pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html +pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4 +pip install gradio + + + + + + diff --git a/loss.py b/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8d62a45677e0fd0f7d2c04d4d94672fe5d93f956 --- /dev/null +++ b/loss.py @@ -0,0 +1,189 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from math import exp + +class FocalLoss(nn.Module): + """ + copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py + This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in + 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' + Focal_Loss= -1*alpha*(1-pt)*log(pt) + :param alpha: (tensor) 3D or 4D the scalar factor for this criterion + :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more + focus on hard misclassified example + :param smooth: (float,double) smooth value when cross entropy + :param balance_index: (int) balance class index, should be specific when alpha is float + :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. + """ + + def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): + super(FocalLoss, self).__init__() + self.apply_nonlin = apply_nonlin + self.alpha = alpha + self.gamma = gamma + self.balance_index = balance_index + self.smooth = smooth + self.size_average = size_average + + if self.smooth is not None: + if self.smooth < 0 or self.smooth > 1.0: + raise ValueError('smooth value should be in [0,1]') + + def forward(self, logit, target): + if self.apply_nonlin is not None: + logit = self.apply_nonlin(logit) + num_class = logit.shape[1] + + if logit.dim() > 2: + # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 
+ logit = logit.view(logit.size(0), logit.size(1), -1) + logit = logit.permute(0, 2, 1).contiguous() + logit = logit.view(-1, logit.size(-1)) + target = torch.squeeze(target, 1) + target = target.view(-1, 1) + alpha = self.alpha + + if alpha is None: + alpha = torch.ones(num_class, 1) + elif isinstance(alpha, (list, np.ndarray)): + assert len(alpha) == num_class + alpha = torch.FloatTensor(alpha).view(num_class, 1) + alpha = alpha / alpha.sum() + elif isinstance(alpha, float): + alpha = torch.ones(num_class, 1) + alpha = alpha * (1 - self.alpha) + alpha[self.balance_index] = self.alpha + + else: + raise TypeError('Unsupported alpha type') + + if alpha.device != logit.device: + alpha = alpha.to(logit.device) + + idx = target.cpu().long() + + one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() + one_hot_key = one_hot_key.scatter_(1, idx, 1) + if one_hot_key.device != logit.device: + one_hot_key = one_hot_key.to(logit.device) + + if self.smooth: + one_hot_key = torch.clamp( + one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth) + pt = (one_hot_key * logit).sum(1) + self.smooth + logpt = pt.log() + + gamma = self.gamma + + alpha = alpha[idx] + alpha = torch.squeeze(alpha) + loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt + + if self.size_average: + loss = loss.mean() + return loss + + +class BinaryDiceLoss(nn.Module): + def __init__(self): + super(BinaryDiceLoss, self).__init__() + + def forward(self, input, targets): + # batch size N + N = targets.size()[0] + # smoothing term + smooth = 1 + # flatten the spatial dimensions + input_flat = input.view(N, -1) + targets_flat = targets.view(N, -1) + + # intersection between prediction and target + intersection = input_flat * targets_flat + N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth) + # average Dice loss per image in the batch + loss = 1 - N_dice_eff.sum() / N + return loss + + + + +class ConADLoss(nn.Module): + """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. + It also supports the unsupervised contrastive loss in SimCLR""" + def __init__(self, contrast_mode='all', random_anchors=10): + super(ConADLoss, self).__init__() + assert contrast_mode in ['all', 'mean', 'random'] + self.contrast_mode = contrast_mode + self.random_anchors = random_anchors + def forward(self, features, labels): + """Compute the contrastive anomaly loss. + + Args: + features: hidden vectors of shape [bsz, C, ...]. + labels: ground truth of shape [bsz, 1, ...], where 1 denotes abnormal and 0 denotes normal + Returns: + A loss scalar.
+ """ + device = (torch.device('cuda') + if features.is_cuda + else torch.device('cpu')) + if len(features.shape) != len(labels.shape): + raise ValueError('`features` needs to have the same dimensions with labels') + + if len(features.shape) < 3: + raise ValueError('`features` needs to be [bsz, C, ...],' + 'at least 3 dimensions are required') + + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1) + labels = labels.view(labels.shape[0], labels.shape[1], -1) + + labels = labels.squeeze() + batch_size = features.shape[0] + + C = features.shape[1] + normal_feats = features[:, :, labels == 0] + abnormal_feats = features[:, :, labels == 1] + + normal_feats = normal_feats.permute((1, 0, 2)).contiguous().view(C, -1) + abnormal_feats = abnormal_feats.permute((1, 0, 2)).contiguous().view(C, -1) + + contrast_count = normal_feats.shape[1] + contrast_feature = normal_feats + + if self.contrast_mode == 'mean': + anchor_feature = torch.mean(normal_feats, dim=1) + anchor_feature = F.normalize(anchor_feature, dim=0, p=2) + anchor_count = 1 + elif self.contrast_mode == 'all': + anchor_feature = contrast_feature + anchor_count = contrast_count + elif self.contrast_mode == 'random': + dim_to_sample = 1 + num_samples = min(self.random_anchors, contrast_count) + permuted_indices = torch.randperm(normal_feats.size(dim_to_sample)).to(normal_feats.device) + selected_indices = permuted_indices[:num_samples] + anchor_feature = normal_feats.index_select(dim_to_sample, selected_indices) + else: + raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) + + # compute logits + # maximize similarity + anchor_dot_normal = torch.matmul(anchor_feature.T, normal_feats).mean() + + # minimize similarity + anchor_dot_abnormal = torch.matmul(anchor_feature.T, abnormal_feats).mean() + + loss = 0 + if normal_feats.shape[1] > 0: + loss -= anchor_dot_normal + if abnormal_feats.shape[1] > 0: + loss += anchor_dot_abnormal + + loss = torch.exp(loss) + + return loss diff --git a/method/__init__.py b/method/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f71b6f127b2fa66665b05d817a59acc5cfd0515 --- /dev/null +++ b/method/__init__.py @@ -0,0 +1,2 @@ +from .adaclip import AdaCLIP +from .trainer import AdaCLIP_Trainer \ No newline at end of file diff --git a/method/__pycache__/__init__.cpython-39.pyc b/method/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa53ec8b4a1a2607a3b2c07956ea589bbc668e1e Binary files /dev/null and b/method/__pycache__/__init__.cpython-39.pyc differ diff --git a/method/__pycache__/adaclip.cpython-39.pyc b/method/__pycache__/adaclip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41a4c03c170823cf9ccad9032d27811f9d4d626c Binary files /dev/null and b/method/__pycache__/adaclip.cpython-39.pyc differ diff --git a/method/__pycache__/clip_model.cpython-39.pyc b/method/__pycache__/clip_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..693c0004c2f95a435b52f398038b84c1b82c9326 Binary files /dev/null and b/method/__pycache__/clip_model.cpython-39.pyc differ diff --git a/method/__pycache__/custom_clip.cpython-39.pyc b/method/__pycache__/custom_clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afa77d23eb1a9669441eebe3ad99cc0c122b9873 Binary files /dev/null and b/method/__pycache__/custom_clip.cpython-39.pyc differ diff --git 
a/method/__pycache__/simple_tokenizer.cpython-39.pyc b/method/__pycache__/simple_tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73226b9ef5d352a6ce72af6816d6e0c8515a6d9d Binary files /dev/null and b/method/__pycache__/simple_tokenizer.cpython-39.pyc differ diff --git a/method/__pycache__/tokenizer.cpython-39.pyc b/method/__pycache__/tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cad49a5cae4bc7d2a6096670b2f0d8350e5af57f Binary files /dev/null and b/method/__pycache__/tokenizer.cpython-39.pyc differ diff --git a/method/__pycache__/trainer.cpython-39.pyc b/method/__pycache__/trainer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de317be3f5386b412ca19af7f68c5196222a4860 Binary files /dev/null and b/method/__pycache__/trainer.cpython-39.pyc differ diff --git a/method/__pycache__/transformer.cpython-39.pyc b/method/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..408c51ea67b5cad826a5f5a086731c276234e1c9 Binary files /dev/null and b/method/__pycache__/transformer.cpython-39.pyc differ diff --git a/method/__pycache__/utils.cpython-39.pyc b/method/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c97e0a12775ae2d688ecab5baf43db849043003 Binary files /dev/null and b/method/__pycache__/utils.cpython-39.pyc differ diff --git a/method/adaclip.py b/method/adaclip.py new file mode 100644 index 0000000000000000000000000000000000000000..2240e7887dcd9517ad9fc575ca24adde25b251d0 --- /dev/null +++ b/method/adaclip.py @@ -0,0 +1,583 @@ +from typing import Union, List, Optional +import numpy as np +import torch +from pkg_resources import packaging +from torch import nn +from torch.nn import functional as F +from .clip_model import CLIP +from .simple_tokenizer import SimpleTokenizer as _Tokenizer +from sklearn.cluster import KMeans + +class ProjectLayer(nn.Module): + def __init__(self, input_dim, output_dim, num_replicas, stack=False, is_array=True): + super(ProjectLayer, self).__init__() + + self.head = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_replicas)]) + self.num_replicas = num_replicas + self.stack = stack + self.is_array = is_array + + def forward(self, tokens): + out_tokens = [] + for i in range(self.num_replicas): + if self.is_array: + temp = self.head[i](tokens[i][:, 1:, :]) # for ViT, we exclude the class token and only extract patch tokens here. + else: + temp = self.head[i](tokens) + + out_tokens.append(temp) + + if self.stack: + out_tokens = torch.stack(out_tokens, dim=1) + + return out_tokens + +class PromptLayer(nn.Module): + def __init__(self, channel, length, depth, is_text, prompting_type, enabled=True): + super(PromptLayer, self).__init__() + + self.channel = channel + self.length = length + self.depth = depth + self.is_text = is_text + self.enabled = enabled + + self.prompting_type = prompting_type + + if self.enabled: # only when enabled, the parameters should be constructed + if 'S' in prompting_type: # static prompts + # learnable + self.static_prompts = nn.ParameterList( + [nn.Parameter(torch.empty(self.length, self.channel)) + for _ in range(self.depth)]) + + for single_para in self.static_prompts: + nn.init.normal_(single_para, std=0.02) + + if 'D' in prompting_type: # dynamic prompts + self.dynamic_prompts = [0.] 
# place holder + + def set_dynamic_prompts(self, dynamic_prompts): + self.dynamic_prompts = dynamic_prompts + + def forward_text(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None): + if self.enabled: + length = self.length + + # only prompt the first J layers + if indx < self.depth: + if 'S' in self.prompting_type and 'D' in self.prompting_type: # both + static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1) + textual_context = self.dynamic_prompts + static_prompts + elif 'S' in self.prompting_type: # static + static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1) + textual_context = static_prompts + elif 'D' in self.prompting_type: # dynamic + textual_context = self.dynamic_prompts + else: + print('You should at least choose one type of prompts when the prompting branches are not none.') + raise NotImplementedError + + if indx == 0: # for the first layer + x = x + else: + if indx < self.depth: # replace with learnalbe tokens + prefix = x[:1, :, :] + suffix = x[1 + length:, :, :] + textual_context = textual_context.permute(1, 0, 2).half() + x = torch.cat([prefix, textual_context, suffix], dim=0) + else: # keep the same + x = x + else: + x = x + + x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x= v_x, attn_mask=attn_mask) + + return x, attn_tmp + + def forward_visual(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None): + if self.enabled: + length = self.length + + # only prompt the first J layers + if indx < self.depth: + if 'S' in self.prompting_type and 'D' in self.prompting_type: # both + static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1) + visual_context = self.dynamic_prompts + static_prompts + elif 'S' in self.prompting_type: # static + static_prompts = self.static_prompts[indx].unsqueeze(0).expand(x.shape[1], -1, -1) + visual_context = static_prompts + elif 'D' in self.prompting_type: # dynamic + visual_context = self.dynamic_prompts + else: + print('You should at least choose one type of prompts when the prompting branches are not none.') + raise NotImplementedError + + + if indx == 0: # for the first layer + visual_context = visual_context.permute(1, 0, 2).half() + x = torch.cat([x, visual_context], dim=0) + else: + if indx < self.depth: # replace with learnalbe tokens + prefix = x[0:x.shape[0] - length, :, :] + visual_context = visual_context.permute(1, 0, 2).half() + x = torch.cat([prefix, visual_context], dim=0) + else: # keep the same + x = x + else: + x = x + + x, attn_tmp = resblock(q_x=x, k_x=k_x, v_x= v_x, attn_mask=attn_mask) + + if self.enabled: + tokens = x[:x.shape[0] - length, :, :] + else: + tokens = x + + return x, tokens, attn_tmp + + def forward(self, resblock, indx, x, k_x=None, v_x=None, attn_mask: Optional[torch.Tensor] = None): + if self.is_text: + return self.forward_text(resblock, indx, x, k_x, v_x, attn_mask) + else: + return self.forward_visual(resblock, indx, x, k_x, v_x, attn_mask) + + +class TextEmbebddingLayer(nn.Module): + def __init__(self, fixed): + super(TextEmbebddingLayer, self).__init__() + self.tokenizer = _Tokenizer() + self.ensemble_text_features = {} + self.prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', + '{} without defect', + '{} without damage'] + self.prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage'] + self.prompt_state = [self.prompt_normal, self.prompt_abnormal] + self.prompt_templates = ['a bad 
photo of a {}.', + 'a low resolution photo of the {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + ] + self.fixed = fixed + + def tokenize(self, texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[ + torch.IntTensor, torch.LongTensor]: + if isinstance(texts, str): + texts = [texts] + + sot_token = self.tokenizer.encoder["<|startoftext|>"] + eot_token = self.tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + ## TODO: text layeer is not compitable with multiple batches... + def forward(self, model, texts, device): + text_feature_list = [] + + for indx, text in enumerate(texts): + + if self.fixed: + if self.ensemble_text_features.get(text) is None: + text_features = self.encode_text(model, text, device) + self.ensemble_text_features[text] = text_features + else: + text_features = self.ensemble_text_features[text] + else: + text_features = self.encode_text(model, text, device) + self.ensemble_text_features[text] = text_features + + text_feature_list.append(text_features) + + text_features = torch.stack(text_feature_list, dim=0) + text_features = F.normalize(text_features, dim=1) + + return text_features + + def encode_text(self, model, text, device): + text_features = [] + for i in range(len(self.prompt_state)): + text = text.replace('-', ' ') + prompted_state = [state.format(text) for state in self.prompt_state[i]] + prompted_sentence = [] + for s in prompted_state: + for template in self.prompt_templates: + prompted_sentence.append(template.format(s)) + prompted_sentence = self.tokenize(prompted_sentence, context_length=77).to(device) + + class_embeddings = model.encode_text(prompted_sentence) + + class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) + class_embedding = class_embeddings.mean(dim=0) + class_embedding /= class_embedding.norm() + text_features.append(class_embedding) + + text_features = torch.stack(text_features, dim=1) + + return text_features + + +class HybridSemanticFusion(nn.Module): + def __init__(self, k_clusters): + super(HybridSemanticFusion, self).__init__() + self.k_clusters = k_clusters + self.n_aggregate_patch_tokens = k_clusters * 5 + self.cluster_performer = KMeans(n_clusters=self.k_clusters, n_init="auto") + + # @torch.no_grad() + def forward(self, patch_tokens: list, anomaly_maps: list): + anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1) + anomaly_map = torch.softmax(anomaly_map, dim=2)[:, :, 1] # B, L + + # extract most abnormal feats + selected_abnormal_tokens = [] + k = min(anomaly_map.shape[1], self.n_aggregate_patch_tokens) + top_k_indices = torch.topk(anomaly_map, k=k, dim=1).indices + for layer in range(len(patch_tokens)): + selected_tokens = patch_tokens[layer]. \ + gather(dim=1, index=top_k_indices.unsqueeze(-1). 
+ expand(-1, -1, patch_tokens[layer].shape[-1])) + selected_abnormal_tokens.append(selected_tokens) + + # use kmeans to extract these centriods + # Stack the data_preprocess + stacked_data = torch.cat(selected_abnormal_tokens, dim=2) + + batch_cluster_centers = [] + # Perform K-Means clustering + for b in range(stacked_data.shape[0]): + cluster_labels = self.cluster_performer.fit_predict(stacked_data[b, :, :].detach().cpu().numpy()) + + # Initialize a list to store the cluster centers + cluster_centers = [] + + # Extract cluster centers for each cluster + for cluster_id in range(self.k_clusters): + collected_cluster_data = [] + for abnormal_tokens in selected_abnormal_tokens: + cluster_data = abnormal_tokens[b, :, :][cluster_labels == cluster_id] + collected_cluster_data.append(cluster_data) + collected_cluster_data = torch.cat(collected_cluster_data, dim=0) + cluster_center = torch.mean(collected_cluster_data, dim=0, keepdim=True) + cluster_centers.append(cluster_center) + + # Normalize the cluster centers + cluster_centers = torch.cat(cluster_centers, dim=0) + cluster_centers = torch.mean(cluster_centers, dim=0) + batch_cluster_centers.append(cluster_centers) + + batch_cluster_centers = torch.stack(batch_cluster_centers, dim=0) + batch_cluster_centers = F.normalize(batch_cluster_centers, dim=1) + + return batch_cluster_centers + + # # preprocess + # # compute the anomaly map + # anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1) + # anomaly_map = torch.softmax(anomaly_map, dim=2)[:, :, 1] # B, L + # + # # compute the average multi-hierarchy patch embeddings + # avg_patch_tokens = torch.mean(torch.stack(patch_tokens, dim=0), dim=0) # B, L, C + # + # # Initialize a list to store the centroids of clusters with the largest anomaly scores + # cluster_centroids = [] + # + # # loop across the batch size + # for b in range(avg_patch_tokens.shape[0]): + # # step1: group features into clusters + # cluster_labels = self.cluster_performer.fit_predict(avg_patch_tokens[b, :, :].detach().cpu().numpy()) + # + # # step2: compute the anomaly scores for individual clusters via the anomaly map + # # Convert cluster labels back to tensor + # cluster_labels = torch.tensor(cluster_labels).to(avg_patch_tokens.device) + # cluster_anomaly_scores = {} + # for label in torch.unique(cluster_labels): + # cluster_indices = torch.where(cluster_labels == label)[0] + # cluster_anomaly_scores[label.item()] = anomaly_map[b, cluster_indices].mean().item() + # + # # step3: select the cluster with the largest anomaly score and then compute its centroid by averaging the + # # corresponding avg_patch_tokens + # # Find the cluster with the largest anomaly score + # largest_anomaly_cluster = max(cluster_anomaly_scores, key=cluster_anomaly_scores.get) + # + # # Get the indices of the tokens belonging to the largest anomaly cluster + # largest_anomaly_cluster_indices = torch.where(cluster_labels == largest_anomaly_cluster)[0] + # + # # Compute the centroid of the largest anomaly cluster by averaging the corresponding avg_patch_tokens + # centroid = avg_patch_tokens[b, largest_anomaly_cluster_indices, :].mean(dim=0) + # + # # Append the centroid to the list of cluster centroids + # cluster_centroids.append(centroid) + # + # # Convert the list of centroids to a tensor + # cluster_centroids = torch.stack(cluster_centroids, dim=0) + # cluster_centroids = F.normalize(cluster_centroids, dim=1) + + # return cluster_centroids + +class AdaCLIP(nn.Module): + def __init__(self, freeze_clip: CLIP, text_channel: int, 
visual_channel: int, + prompting_length: int, prompting_depth: int, prompting_branch: str, prompting_type: str, + use_hsf: bool, k_clusters: int, + output_layers: list, device: str, image_size: int): + super(AdaCLIP, self).__init__() + self.freeze_clip = freeze_clip + + self.visual = self.freeze_clip.visual + self.transformer = self.freeze_clip.transformer + self.token_embedding = self.freeze_clip.token_embedding + self.positional_embedding = self.freeze_clip.positional_embedding + self.ln_final = self.freeze_clip.ln_final + self.text_projection = self.freeze_clip.text_projection + self.attn_mask = self.freeze_clip.attn_mask + + self.output_layers = output_layers + + self.prompting_branch = prompting_branch + self.prompting_type = prompting_type + self.prompting_depth = prompting_depth + self.prompting_length = prompting_length + self.use_hsf = use_hsf + self.k_clusters = k_clusters + + if 'L' in self.prompting_branch: + self.enable_text_prompt = True + else: + self.enable_text_prompt = False + + if 'V' in self.prompting_branch: + self.enable_visual_prompt = True + else: + self.enable_visual_prompt = False + + self.text_embedding_layer = TextEmbebddingLayer(fixed=(not self.enable_text_prompt)) + self.text_prompter = PromptLayer(text_channel, prompting_length, prompting_depth, is_text=True, + prompting_type=prompting_type, + enabled=self.enable_text_prompt) + self.visual_prompter = PromptLayer(visual_channel, prompting_length, prompting_depth, is_text=False, + prompting_type=prompting_type, + enabled=self.enable_visual_prompt) + + self.patch_token_layer = ProjectLayer( + visual_channel, + text_channel, + len(output_layers), stack=False, is_array=True + ) + + self.cls_token_layer = ProjectLayer( + text_channel, + text_channel, + 1, stack=False, is_array=False + ) + + if 'D' in self.prompting_type: # dynamic prompts + self.dynamic_visual_prompt_generator = ProjectLayer(text_channel, + visual_channel, + prompting_length, + stack=True, + is_array=False) + self.dynamic_text_prompt_generator = ProjectLayer(text_channel, + text_channel, + prompting_length, + stack=True, + is_array=False) + + if self.use_hsf: + self.HSF = HybridSemanticFusion(k_clusters) + + self.image_size = image_size + self.device = device + + def generate_and_set_dynamic_promtps(self, image): + with torch.no_grad(): + # extract image features + image_features, _ = self.visual.forward(image, self.output_layers) + + dynamic_visual_prompts = self.dynamic_visual_prompt_generator(image_features) + dynamic_text_prompts = self.dynamic_text_prompt_generator(image_features) + + self.visual_prompter.set_dynamic_prompts(dynamic_visual_prompts) + self.text_prompter.set_dynamic_prompts(dynamic_text_prompts) + + + def encode_image(self, image): + + x = image + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.visual.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], + self.visual.grid_size[0], + self.visual.patch_size[0], + self.visual.grid_size[1], + self.visual.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.visual.grid_size[0] * self.visual.grid_size[1], -1) + x = self.visual.patchnorm_pre_ln(x) + x = self.visual.conv1(x) + else: + x = self.visual.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + 
[self.visual.class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + + x = x + self.visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.visual.patch_dropout(x) + x = self.visual.ln_pre(x) + + patch_embedding = x + + x = x.permute(1, 0, 2) # NLD -> LND + + patch_tokens = [] + + for indx, r in enumerate(self.visual.transformer.resblocks): + x, tokens, attn_tmp = self.visual_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=None) + + if (indx + 1) in self.output_layers: + patch_tokens.append(tokens) + + x = x.permute(1, 0, 2) # LND -> NLD + patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD + + if self.visual.attn_pool is not None: + x = self.visual.attn_pool(x) + x = self.visual.ln_post(x) + pooled, tokens = self.visual._global_pool(x) + else: + pooled, tokens = self.visual._global_pool(x) + pooled = self.visual.ln_post(pooled) + + if self.visual.proj is not None: + pooled = pooled @ self.visual.proj + + return pooled, patch_tokens, patch_embedding + + def proj_visual_tokens(self, image_features, patch_tokens): + + # for patch tokens + proj_patch_tokens = self.patch_token_layer(patch_tokens) + for layer in range(len(proj_patch_tokens)): + proj_patch_tokens[layer] /= proj_patch_tokens[layer].norm(dim=-1, keepdim=True) + + # for cls tokens + proj_cls_tokens = self.cls_token_layer(image_features)[0] + proj_cls_tokens /= proj_cls_tokens.norm(dim=-1, keepdim=True) + + return proj_cls_tokens, proj_patch_tokens + + def encode_text(self, text): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + + for indx, r in enumerate(self.transformer.resblocks): + # add prompt here + x, attn_tmp = self.text_prompter(r, indx, x, k_x=None, v_x=None, attn_mask=self.attn_mask) + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x + + def visual_text_similarity(self, image_feature, patch_token, text_feature, aggregation): + anomaly_maps = [] + + for layer in range(len(patch_token)): + anomaly_map = (100.0 * patch_token[layer] @ text_feature) + anomaly_maps.append(anomaly_map) + + if self.use_hsf: + alpha = 0.2 + clustered_feature = self.HSF.forward(patch_token, anomaly_maps) + # aggregate the class token and the clustered features for more comprehensive information + cur_image_feature = alpha * clustered_feature + (1 - alpha) * image_feature + cur_image_feature = F.normalize(cur_image_feature, dim=1) + else: + cur_image_feature = image_feature + + anomaly_score = (100.0 * cur_image_feature.unsqueeze(1) @ text_feature) + anomaly_score = anomaly_score.squeeze(1) + anomaly_score = torch.softmax(anomaly_score, dim=1) + + # NOTE: this bilinear interpolation is not unreproducible and may occasionally lead to unstable ZSAD performance. 
+ for i in range(len(anomaly_maps)): + B, L, C = anomaly_maps[i].shape + H = int(np.sqrt(L)) + anomaly_maps[i] = anomaly_maps[i].permute(0, 2, 1).view(B, 2, H, H) + anomaly_maps[i] = F.interpolate(anomaly_maps[i], size=self.image_size, mode='bilinear', align_corners=True) + + if aggregation: # in the test stage, we firstly aggregate logits from all hierarchies and then do the softmax normalization + anomaly_map = torch.mean(torch.stack(anomaly_maps, dim=1), dim=1) + anomaly_map = torch.softmax(anomaly_map, dim=1) + anomaly_map = (anomaly_map[:, 1:, :, :] + 1 - anomaly_map[:, 0:1, :, :]) / 2.0 + anomaly_score = anomaly_score[:, 1] + return anomaly_map, anomaly_score + else: # otherwise, we do the softmax normalization for individual hierarchies + for i in range(len(anomaly_maps)): + anomaly_maps[i] = torch.softmax(anomaly_maps[i], dim=1) + return anomaly_maps, anomaly_score + + def extract_feat(self, image, cls_name): + if 'D' in self.prompting_type: + self.generate_and_set_dynamic_promtps(image) # generate and set dynamic prompts for corresponding prompters + + if self.enable_visual_prompt: + image_features, patch_tokens, _ = self.encode_image(image) + else: + with torch.no_grad(): + image_features, patch_tokens, _ = self.encode_image(image) + + if self.enable_text_prompt: + text_features = self.text_embedding_layer(self, cls_name, self.device) + else: + with torch.no_grad(): + text_features = self.text_embedding_layer(self, cls_name, self.device) + + proj_cls_tokens, proj_patch_tokens = self.proj_visual_tokens(image_features, patch_tokens) + + return proj_cls_tokens, proj_patch_tokens, text_features + + @torch.cuda.amp.autocast() + def forward(self, image, cls_name, aggregation=True): + # extract features for images and texts + image_features, patch_tokens, text_features = self.extract_feat(image, cls_name) + anomaly_map, anomaly_score = self.visual_text_similarity(image_features, patch_tokens, text_features, aggregation) + + if aggregation: + anomaly_map = anomaly_map # tensor + anomaly_score = anomaly_score + anomaly_map = anomaly_map.squeeze(1) + + return anomaly_map, anomaly_score + else: + anomaly_maps = anomaly_map # list + anomaly_score = anomaly_score + + return anomaly_maps, anomaly_score + diff --git a/method/bpe_simple_vocab_16e6.txt.gz b/method/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/method/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/method/clip_model.py b/method/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..803ce8ae95819d0b2cf84aa6d20152c4470fac97 --- /dev/null +++ b/method/clip_model.py @@ -0,0 +1,412 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
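For orientation (not part of the patch): a sketch of a single zero-shot forward pass through the `AdaCLIP` module defined above. The CLIP backbone construction, channel widths, prompting hyper-parameters, and image resolution are all placeholders; the repository's actual settings live in its training and test scripts.

```python
# Hypothetical inference sketch; every numeric setting below is an assumption.
import torch
from method import AdaCLIP

# `clip_backbone` is assumed to be a CLIP instance built elsewhere
# (e.g. with the builders in method/clip_model.py).
model = AdaCLIP(
    freeze_clip=clip_backbone,
    text_channel=768, visual_channel=1024,   # ViT-L/14-style widths (assumed)
    prompting_length=5, prompting_depth=4,
    prompting_branch='VL',                   # prompt both visual (V) and language (L) towers
    prompting_type='SD',                     # hybrid prompts: static (S) + dynamic (D)
    use_hsf=True, k_clusters=20,
    output_layers=[6, 12, 18, 24],
    device='cuda', image_size=518,
).to('cuda').eval()

image = torch.randn(1, 3, 518, 518, device='cuda')
with torch.no_grad():
    anomaly_map, anomaly_score = model(image, ['capsule'], aggregation=True)
# anomaly_map: (1, 518, 518) pixel-level scores; anomaly_score: (1,) image-level scores
```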
+""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. # head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + output_tokens: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
+ act_layer = QuickGELU if quick_gelu else nn.GELU + + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer + ) + + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, out_layers): + + x = image + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.visual.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], + self.visual.grid_size[0], + self.visual.patch_size[0], + self.visual.grid_size[1], + self.visual.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.visual.grid_size[0] * self.visual.grid_size[1], -1) + x = self.visual.patchnorm_pre_ln(x) + x = self.visual.conv1(x) + else: + x = self.visual.conv1(x) # shape = [*, width, grid, grid] 
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + + # class embeddings and positional embeddings + x = torch.cat( + [self.visual.class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + + x = x + self.visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.visual.patch_dropout(x) + x = self.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + patch_tokens = [] + + idx = 0 + for r in self.visual.transformer.resblocks: + idx += 1 + # add prompt here + x, attn_tmp = r(x, attn_mask=None) + if idx in out_layers: + patch_tokens.append(x) + + x = x.permute(1, 0, 2) # LND -> NLD + patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD + + if self.visual.attn_pool is not None: + x = self.visual.attn_pool(x) + x = self.visual.ln_post(x) + pooled, tokens = self.visual._global_pool(x) + else: + pooled, tokens = self.visual._global_pool(x) + pooled = self.visual.ln_post(pooled) + + if self.visual.proj is not None: + pooled = pooled @ self.visual.proj + + if self.visual.output_tokens: + return pooled, patch_tokens + + return pooled, patch_tokens + + def encode_text(self, text): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + + for r in self.visual.transformer.resblocks: + # add prompt here + + x, attn_tmp = r(x, attn_mask=self.attn_mask) + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x + + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' 
+ k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic'): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + 
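+ # interpolate the patch position grid to the new resolution; the extra (class-token) embedding is carried over unchanged
+ # e.g. loading a 224-px ViT-L/14 checkpoint into a 336-px model resizes the 16x16 grid to 24x24 (257 -> 577 positions)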
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/method/custom_clip.py b/method/custom_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..627c00c2dcdb9744c95cf90082cb74599bd9cdfb --- /dev/null +++ b/method/custom_clip.py @@ -0,0 +1,716 @@ +# This file is largely borrowed from open clip +import hashlib +import json +import logging +import os +import re +import urllib +import warnings +from copy import deepcopy +from dataclasses import dataclass, asdict +from functools import partial +from pathlib import Path +from typing import Any, Optional, Tuple +from typing import Dict, Union +from typing import List +import torch +import torch.nn as nn +import torchvision.transforms.functional as F +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop +from tqdm import tqdm +from .clip_model import CLIP, convert_to_custom_text_state_dict, \ + resize_pos_embed +from .clip_model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .tokenizer import HFTokenizer, tokenize + +__version__ = '2.16.0' + +try: + from huggingface_hub import hf_hub_download + + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') +) + + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + + + +_PRETRAINED = { + "ViT-B-32": _VITB32, + "ViT-B-16": _VITB16, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + if 'openai' in model_pretrained.keys(): + tag = 'openai' + else: + tag = list(model_pretrained.keys())[0] + print('*' * 50) + print(f'Use pretrained model from {tag}...') + print('*' * 50) + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not 
match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith( + expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target + + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. 
Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn( + f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + 
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + cfg = get_pretrained_cfg(name, 'openai') + if cfg: + model_path = download_pretrained(cfg, cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_pretrained()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + # JIT -> Just In Time + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': + model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + 
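+ # apply the patch: aten::to nodes that cast to ScalarType 5 (torch.float16) are redirected to float32, so the traced graph runs fully in fp32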
model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent.parent / f"./model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + img_size: int, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, +): + if model_name.count('ViT') < 1: + print('only support ViT model..') + raise NotImplementedError + + # in which means, we can also use old naming rules. 
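+ # e.g. 'ViT-B/16' is normalized to 'ViT-B-16' below, matching the config file names under model_configs/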
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + # our default version are borrowed from openai + assert pretrained and pretrained.lower() == 'openai', 'only support openai module.' + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model_cfg = model_cfg or get_model_config(model_name) + + model_cfg['vision_cfg']['image_size'] = img_size + cast_dtype = get_cast_dtype(precision) + + model_pre = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + ) + state_dict = model_pre.state_dict() + + # to always output dict even if it is clip + if output_dict and hasattr(model_pre, "output_dict"): + model_pre.output_dict = True + + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + # mainly need to resize the position embeddings + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=True) + model.to(device=device) + if precision in ("fp16", "bf16"): + convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + # to always output dict even if it is clip + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_model_and_transforms( + model_name: str, + img_size: int, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, +): + ######### create the clip model + model = create_model( + model_name, + img_size, + pretrained, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + output_dict=output_dict, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess_train, preprocess_val diff --git a/method/simple_tokenizer.py b/method/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/method/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'</w>' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '</w>',) + pairs = get_pairs(word) + + if not pairs: + return token+'</w>' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') + return text diff --git a/method/tokenizer.py b/method/tokenizer.py new file mode 100644 index 
0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/method/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'</w>' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['<start_of_text>', '<end_of_text>'] + else: + special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '</w>',) + pairs = get_pairs(word) + + if not pairs: + return token+'</w>' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + 
break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<start_of_text>"] + eot_token = _tokenizer.encoder["<end_of_text>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/method/trainer.py b/method/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c9bd540f44d46df686140aacf44eaf0913466fbe --- /dev/null +++ b/method/trainer.py @@ -0,0 +1,225 @@ +import cv2 +import torchvision.transforms as transforms +from scipy.ndimage import gaussian_filter + +from loss import FocalLoss, BinaryDiceLoss +from tools import visualization, calculate_metric, calculate_average_metric +from .adaclip import 
* +from .custom_clip import create_model_and_transforms + + +class AdaCLIP_Trainer(nn.Module): + def __init__( + self, + # clip-related + backbone, feat_list, input_dim, output_dim, + + # learning-related + learning_rate, device, image_size, + + # model settings + prompting_depth=3, prompting_length=2, + prompting_branch='VL', prompting_type='SD', + use_hsf=True, k_clusters=20, + ): + + super(AdaCLIP_Trainer, self).__init__() + + self.device = device + self.feat_list = feat_list + self.image_size = image_size + self.prompting_branch = prompting_branch + self.prompting_type = prompting_type + + self.loss_focal = FocalLoss() + self.loss_dice = BinaryDiceLoss() + + ########### different model choices + freeze_clip, _, self.preprocess = create_model_and_transforms(backbone, image_size, + pretrained='openai') + freeze_clip = freeze_clip.to(device) + freeze_clip.eval() + + self.clip_model = AdaCLIP(freeze_clip=freeze_clip, + text_channel=output_dim, + visual_channel=input_dim, + prompting_length=prompting_length, + prompting_depth=prompting_depth, + prompting_branch=prompting_branch, + prompting_type=prompting_type, + use_hsf=use_hsf, + k_clusters=k_clusters, + output_layers=feat_list, + device=device, + image_size=image_size).to(device) + + self.transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.CenterCrop(image_size), + transforms.ToTensor() + ]) + + self.preprocess.transforms[0] = transforms.Resize(size=(image_size, image_size), + interpolation=transforms.InterpolationMode.BICUBIC, + max_size=None) + + self.preprocess.transforms[1] = transforms.CenterCrop(size=(image_size, image_size)) + + # update parameters + self.learnable_paramter_list = [ + 'text_prompter', + 'visual_prompter', + 'patch_token_layer', + 'cls_token_layer', + 'dynamic_visual_prompt_generator', + 'dynamic_text_prompt_generator' + ] + + self.params_to_update = [] + for name, param in self.clip_model.named_parameters(): + # print(name) + for update_name in self.learnable_paramter_list: + if update_name in name: + # print(f'updated parameters--{name}: {update_name}') + self.params_to_update.append(param) + + # build the optimizer + self.optimizer = torch.optim.AdamW(self.params_to_update, lr=learning_rate, betas=(0.5, 0.999)) + + def save(self, path): + self.save_dict = {} + for param, value in self.state_dict().items(): + for update_name in self.learnable_paramter_list: + if update_name in param: + # print(f'{param}: {update_name}') + self.save_dict[param] = value + break + + torch.save(self.save_dict, path) + + def load(self, path): + self.load_state_dict(torch.load(path, map_location=self.device), strict=False) + + def train_one_batch(self, items): + image = items['img'].to(self.device) + cls_name = items['cls_name'] + + # pixel level + anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=False) + + if not isinstance(anomaly_map, list): + anomaly_map = [anomaly_map] + + # losses + gt = items['img_mask'].to(self.device) + gt = gt.squeeze() + + gt[gt > 0.5] = 1 + gt[gt <= 0.5] = 0 + + is_anomaly = items['anomaly'].to(self.device) + is_anomaly[is_anomaly > 0.5] = 1 + is_anomaly[is_anomaly <= 0.5] = 0 + loss = 0 + + # classification loss + classification_loss = self.loss_focal(anomaly_score, is_anomaly.unsqueeze(1)) + loss += classification_loss + + # seg loss + seg_loss = 0 + for am, in zip(anomaly_map): + seg_loss += (self.loss_focal(am, gt) + self.loss_dice(am[:, 1, :, :], gt) + + self.loss_dice(am[:, 0, :, :], 1-gt)) + + loss += seg_loss + + 
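+ # total loss: focal loss on the image-level anomaly score plus, for each anomaly map, focal loss on the mask and Dice loss on the anomalous / normal channels
+ # only the prompt-related parameters gathered in params_to_update are updated by the optimizer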
self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return loss + + def train_epoch(self, loader): + self.clip_model.train() + loss_list = [] + for items in loader: + loss = self.train_one_batch(items) + loss_list.append(loss.item()) + + return np.mean(loss_list) + + @torch.no_grad() + def evaluation(self, dataloader, obj_list, save_fig, save_fig_dir=None): + self.clip_model.eval() + + results = {} + results['cls_names'] = [] + results['imgs_gts'] = [] + results['anomaly_scores'] = [] + results['imgs_masks'] = [] + results['anomaly_maps'] = [] + results['imgs'] = [] + results['names'] = [] + + with torch.no_grad(), torch.cuda.amp.autocast(): + image_indx = 0 + for indx, items in enumerate(dataloader): + if save_fig: + path = items['img_path'] + for _path in path: + vis_image = cv2.resize(cv2.imread(_path), (self.image_size, self.image_size)) + results['imgs'].append(vis_image) + cls_name = items['cls_name'] + for _cls_name in cls_name: + image_indx += 1 + results['names'].append('{:}-{:03d}'.format(_cls_name, image_indx)) + + image = items['img'].to(self.device) + cls_name = items['cls_name'] + results['cls_names'].extend(cls_name) + gt_mask = items['img_mask'] + gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0 + + for _gt_mask in gt_mask: + results['imgs_masks'].append(_gt_mask.squeeze(0).numpy()) # px + + # pixel level + anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=True) + + anomaly_map = anomaly_map.cpu().numpy() + anomaly_score = anomaly_score.cpu().numpy() + + for _anomaly_map, _anomaly_score in zip(anomaly_map, anomaly_score): + _anomaly_map = gaussian_filter(_anomaly_map, sigma=4) + results['anomaly_maps'].append(_anomaly_map) + results['anomaly_scores'].append(_anomaly_score) + + is_anomaly = np.array(items['anomaly']) + for _is_anomaly in is_anomaly: + results['imgs_gts'].append(_is_anomaly) + + # visualization + if save_fig: + print('saving fig.....') + visualization.plot_sample_cv2( + results['names'], + results['imgs'], + {'AdaCLIP': results['anomaly_maps']}, + results['imgs_masks'], + save_fig_dir + ) + + metric_dict = dict() + for obj in obj_list: + metric_dict[obj] = dict() + + for obj in obj_list: + metric = calculate_metric(results, obj) + obj_full_name = f'{obj}' + metric_dict[obj_full_name] = metric + + metric_dict['Average'] = calculate_average_metric(metric_dict) + + return metric_dict + diff --git a/method/transformer.py b/method/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f66abbe368f440c6c0db5061c89b3497035b8c --- /dev/null +++ b/method/transformer.py @@ -0,0 +1,615 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple +import numpy as np + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower 
than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = 
self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + idx: int = 12, + ): + super().__init__() + + self.idx = idx + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask + ) + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = q_x + self.ls_1(tmp) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x, attn + + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, + idx=idx) + for idx in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9], + 
attn_mask: Optional[torch.Tensor] = None): + idx = 0 + out_tokens = [] + for r in self.resblocks: + idx += 1 + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x, attn_tmp = r(x, attn_mask=attn_mask) + if idx in out_layers: + out_tokens.append(x) + return x, out_tokens + + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, + bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() + + self.ln_pre = norm_layer(width) + + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor, out_layers: list): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], + self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, patch_tokens = self.transformer(x, out_layers) + x = x.permute(1, 0, 2) # LND -> NLD + patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD + # patch_tokens = patch_tokens.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + # patch_pooled, patch_tokens = self._global_pool(patch_tokens) + # tokens = self.ln_post(tokens) + + if self.proj is not None: + pooled = pooled @ self.proj + # patch_tokens = patch_tokens @ self.proj # not sure whether this works + # tokens = tokens @ self.proj + + if self.output_tokens: + return pooled, patch_tokens + + return pooled, patch_tokens + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): +
cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x, attn, patch_tokens = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + diff --git a/method/utils.py b/method/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673 --- /dev/null +++ b/method/utils.py @@ -0,0 +1,60 @@ +from itertools import repeat +import collections.abc + +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
+ module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) diff --git a/model_configs/ViT-B-16.json b/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/model_configs/ViT-B-32.json b/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/model_configs/ViT-H-14.json b/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/model_configs/ViT-L-14-336.json b/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git 
a/model_configs/ViT-L-14.json b/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/model_configs/ViT-bigG-14.json b/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/model_configs/ViT-g-14.json b/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..b4986fbe9de5b6fc9af3b78a79b49bd3d49d04cd --- /dev/null +++ b/test.py @@ -0,0 +1,199 @@ +import warnings +warnings.filterwarnings("ignore", category=RuntimeWarning) +import os +os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' +from torch.utils.data import DataLoader +from tqdm import tqdm +import argparse +import json +import os +import torch +from scipy.ndimage import gaussian_filter +import cv2 + +# Importing from local modules +from tools import write2csv, setup_seed, Logger +from dataset import get_data, dataset_dict +from method import AdaCLIP_Trainer +from PIL import Image +import numpy as np + +setup_seed(111) + +def train(args): + assert os.path.isfile(args.ckt_path), f"Please check the path of pre-trained model, {args.ckt_path} is not valid." 
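As a quick orientation on the model config files above: both test.py and train.py derive the four-level feature hierarchy from the vision backbone depth, and the patch-token grid follows from the input resolution and patch size. The short sketch below is illustrative and not part of the repository; it assumes the ViT-L-14-336 config and the default 518x518 input used by the scripts.

```python
import json

# Illustrative only: how the vision config translates into the four-level
# feature hierarchy and the patch-token grid used by AdaCLIP.
with open('./model_configs/ViT-L-14-336.json', 'r') as f:
    cfg = json.load(f)

n_layers = cfg['vision_cfg']['layers']      # 24 for ViT-L-14-336
substage = n_layers // 4
features_list = [substage, substage * 2, substage * 3, substage * 4]
print(features_list)                        # [6, 12, 18, 24]

image_size = 518                            # default --image_size in train.py / test.py
patch_size = cfg['vision_cfg']['patch_size']
grid = image_size // patch_size             # 37, i.e. a 37x37 patch grid
print(grid * grid + 1)                      # 1370 tokens including the class token
```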
+ + # Configurations + batch_size = args.batch_size + image_size = args.image_size + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + save_fig = args.save_fig + + # Logger + logger = Logger('log.txt') + + # Print basic information + for key, value in sorted(vars(args).items()): + logger.info(f'{key} = {value}') + + + config_path = os.path.join('./model_configs', f'{args.model}.json') + + # Prepare model + with open(config_path, 'r') as f: + model_configs = json.load(f) + + # Set up the feature hierarchy + n_layers = model_configs['vision_cfg']['layers'] + substage = n_layers // 4 + features_list = [substage, substage * 2, substage * 3, substage * 4] + + model = AdaCLIP_Trainer( + backbone=args.model, + feat_list=features_list, + input_dim=model_configs['vision_cfg']['width'], + output_dim=model_configs['embed_dim'], + learning_rate=0., + device=device, + image_size=image_size, + prompting_depth=args.prompting_depth, + prompting_length=args.prompting_length, + prompting_branch=args.prompting_branch, + prompting_type=args.prompting_type, + use_hsf=args.use_hsf, + k_clusters=args.k_clusters + ).to(device) + + model.load(args.ckt_path) + + if args.testing_model == 'dataset': + assert args.testing_data in dataset_dict.keys(), f"You entered {args.testing_data}, but we only support " \ + f"{dataset_dict.keys()}" + + save_root = args.save_path + csv_root = os.path.join(save_root, 'csvs') + image_root = os.path.join(save_root, 'images') + csv_path = os.path.join(csv_root, f'{args.testing_data}.csv') + image_dir = os.path.join(image_root, f'{args.testing_data}') + os.makedirs(image_dir, exist_ok=True) + + test_data_cls_names, test_data, test_data_root = get_data( + dataset_type_list=args.testing_data, + transform=model.preprocess, + target_transform=model.transform, + training=False) + + test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False) + save_fig_flag = save_fig + + metric_dict = model.evaluation( + test_dataloader, + test_data_cls_names, + save_fig_flag, + image_dir, + ) + + for tag, data in metric_dict.items(): + logger.info( + '{:>15} \t\tI-Auroc:{:.2f} \tI-F1:{:.2f} \tI-AP:{:.2f} \tP-Auroc:{:.2f} \tP-F1:{:.2f} \tP-AP:{:.2f}'. 
+ format(tag, + data['auroc_im'], + data['f1_im'], + data['ap_im'], + data['auroc_px'], + data['f1_px'], + data['ap_px']) + ) + + + for k in metric_dict.keys(): + write2csv(metric_dict[k], test_data_cls_names, k, csv_path) + + elif args.testing_model == 'image': + assert os.path.isfile(args.image_path), f"Please verify the input image path: {args.image_path}" + ori_image = cv2.resize(cv2.imread(args.image_path), (args.image_size, args.image_size)) + pil_img = Image.open(args.image_path).convert('RGB') + + img_input = model.preprocess(pil_img).unsqueeze(0) + img_input = img_input.to(model.device) + + with torch.no_grad(): + anomaly_map, anomaly_score = model.clip_model(img_input, [args.class_name], aggregation=True) + + anomaly_map = anomaly_map[0, :, :] + anomaly_score = anomaly_score[0] + anomaly_map = anomaly_map.cpu().numpy() + anomaly_score = anomaly_score.cpu().numpy() + + anomaly_map = gaussian_filter(anomaly_map, sigma=4) + anomaly_map = anomaly_map * 255 + anomaly_map = anomaly_map.astype(np.uint8) + + heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET) + vis_map = cv2.addWeighted(heat_map, 0.5, ori_image, 0.5, 0) + + vis_map = cv2.hconcat([ori_image, vis_map]) + save_path = os.path.join(args.save_path, args.save_name) + print(f"Anomaly detection results are saved to {save_path}, with an image-level anomaly score of {anomaly_score:.3f}") + cv2.imwrite(save_path, vis_map) + +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + +if __name__ == '__main__': + parser = argparse.ArgumentParser("AdaCLIP", add_help=True) + + # Paths and configurations + parser.add_argument("--ckt_path", type=str, default='weights/pretrained_mvtec_colondb.pth', + help="Path to the pre-trained model (default: weights/pretrained_mvtec_colondb.pth)") + + parser.add_argument("--testing_model", type=str, default="dataset", choices=["dataset", "image"], + help="Testing mode: evaluate a whole dataset or a single image (default: 'dataset')") + + # for the dataset model + parser.add_argument("--testing_data", type=str, default="visa", help="Dataset for testing (default: 'visa')") + + # for the image model + parser.add_argument("--image_path", type=str, default="asset/img.png", + help="Path to the test image (default: 'asset/img.png')") + parser.add_argument("--class_name", type=str, default="candle", + help="The class name of the testing image (default: 'candle')") + parser.add_argument("--save_name", type=str, default="test.png", + help="Filename for the saved visualization (default: 'test.png')") + + + parser.add_argument("--save_path", type=str, default='./workspaces', + help="Directory to save results (default: './workspaces')") + + parser.add_argument("--model", type=str, default="ViT-L-14-336", + choices=["ViT-B-16", "ViT-B-32", "ViT-L-14", "ViT-L-14-336"], + help="The CLIP model to be used (default: 'ViT-L-14-336')") + + parser.add_argument("--save_fig", type=str2bool, default=False, + help="Save figures for visualizations (default: False)") + + # Hyper-parameters + parser.add_argument("--batch_size", type=int, default=1, help="Batch size (default: 1)") + parser.add_argument("--image_size", type=int, default=518, help="Size of the input images (default: 518)") + + # Prompting parameters + parser.add_argument("--prompting_depth", type=int, default=4, help="Depth of prompting (default: 4)") + parser.add_argument("--prompting_length", type=int, default=5, help="Length of prompting (default: 5)") + parser.add_argument("--prompting_type", type=str, default='SD', choices=['', 'S', 'D', 'SD'], + help="Type of prompting.
'S' for Static, 'D' for Dynamic, 'SD' for both (default: 'SD')") + parser.add_argument("--prompting_branch", type=str, default='VL', choices=['', 'V', 'L', 'VL'], + help="Branch of prompting. 'V' for Visual, 'L' for Language, 'VL' for both (default: 'VL')") + + parser.add_argument("--use_hsf", type=str2bool, default=True, + help="Use HSF for aggregation. If False, original class embedding is used (default: True)") + parser.add_argument("--k_clusters", type=int, default=20, help="Number of clusters (default: 20)") + + args = parser.parse_args() + + if args.batch_size != 1: + raise NotImplementedError( + "Currently, only batch size of 1 is supported due to unresolved bugs. Please set --batch_size to 1.") + + train(args) + diff --git a/test.sh b/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..0a6999781cb42d7f87cd6b2fd193702c21e7e0bd --- /dev/null +++ b/test.sh @@ -0,0 +1,26 @@ +# pre-trained from MVTec and ColonDB +ckt_path="weights/pretrained_mvtec_colondb.pth" +gpu_id=0 + +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data br35h +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data brain_mri +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data btad +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data clinicdb +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data dagm +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data dtd +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data headct +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data isic +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data mpdd +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data sdd +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data tn3k +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data visa + +# pre-trained from Visa and Clinicdb +ckt_path="weights/pretrained_visa_clinicdb.pth" +gpu_id=0 + +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data colondb +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model dataset --ckt_path $ckt_path --save_fig True --testing_data mvtec + + + diff --git a/test_single_image.sh b/test_single_image.sh new file mode 100644 index 0000000000000000000000000000000000000000..7723a325fa47b202319c1693c31101b26e2fe974 --- /dev/null +++ b/test_single_image.sh @@ -0,0 +1,6 @@ +ckt_path="weights/pretrained_all.pth" +gpu_id=0 + +# demo: do zero-shot anomaly detection for a single image +CUDA_VISIBLE_DEVICES=$gpu_id python test.py --testing_model image --ckt_path $ckt_path --save_fig True \ + --image_path asset/img.png --class_name candle --save_name test.png \ No newline at end of file diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..6da160913889c6e8cec2654ee437a284aa7bb401 --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,5 @@ +from .csv_tools import write2csv +from .logger import Logger, log_metrics +from .metrics import calculate_metric, calculate_average_metric +from .training_tools import setup_seed, setup_paths +from .visualization import plot_sample_cv2 \ No newline at end of file diff --git a/tools/__pycache__/__init__.cpython-39.pyc b/tools/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32b77dd7d0ede42e6f20793700a8cfca79bb557e Binary files /dev/null and b/tools/__pycache__/__init__.cpython-39.pyc differ diff --git a/tools/__pycache__/csv_tools.cpython-39.pyc b/tools/__pycache__/csv_tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36a6988ccaf8b3f53ad62d83c448cae289891e82 Binary files /dev/null and b/tools/__pycache__/csv_tools.cpython-39.pyc differ diff --git a/tools/__pycache__/logger.cpython-39.pyc b/tools/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a5f5bd9303e1bf7d4af5d63bf2139cbc8fbadc4 Binary files /dev/null and b/tools/__pycache__/logger.cpython-39.pyc differ diff --git a/tools/__pycache__/metrics.cpython-39.pyc b/tools/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ab33773e9fde9f213a6b61df1ded84998e78826 Binary files /dev/null and b/tools/__pycache__/metrics.cpython-39.pyc differ diff --git a/tools/__pycache__/training_tools.cpython-39.pyc b/tools/__pycache__/training_tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a124f410758eafb807725c0c30a3a50587bf1d6c Binary files /dev/null and b/tools/__pycache__/training_tools.cpython-39.pyc differ diff --git a/tools/__pycache__/visualization.cpython-39.pyc b/tools/__pycache__/visualization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2db2482f48ba34dc681980c9e5194ef795da91a Binary files /dev/null and b/tools/__pycache__/visualization.cpython-39.pyc differ diff --git a/tools/csv_tools.py b/tools/csv_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..515084889fc1462f4b444707a42188a3c2e79fac --- /dev/null +++ b/tools/csv_tools.py @@ -0,0 +1,28 @@ +import pandas as pd +import os + +def write2csv(results:dict, total_classes, cur_class, csv_path): + keys = list(results.keys()) + + if not os.path.exists(csv_path): + df_all = None + for class_name in total_classes: + r = dict() + for k in keys: + r[k] = 0.00 + df_temp = pd.DataFrame(r, index=[f'{class_name}']) + + if df_all is None: + df_all = df_temp + else: + df_all = pd.concat([df_all, df_temp], axis=0) + + df_all.to_csv(csv_path, header=True, float_format='%.2f') + + df = pd.read_csv(csv_path, index_col=0) + + for k in keys: + df.loc[f'{cur_class}', k] = results[k] + + df.to_csv(csv_path, header=True, float_format='%.2f') + diff --git a/tools/logger.py b/tools/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..abfbb4424d07fdf6f8e80b13cfc3a15cfb060209 --- /dev/null +++ b/tools/logger.py @@ -0,0 +1,74 @@ +import logging + +class Logger(object): + def __init__(self, txt_path): + root_logger = logging.getLogger() + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + root_logger.setLevel(logging.WARNING) + self.txt_path = txt_path + self.logger = logging.getLogger('train') + self.formatter = 
logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') + self.logger.setLevel(logging.INFO) + + def __console(self, level, message): + root_logger = logging.getLogger() + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + file_handler = logging.FileHandler(self.txt_path, mode='a') + console_handler = logging.StreamHandler() + + file_handler.setFormatter(self.formatter) + console_handler.setFormatter(self.formatter) + + self.logger.addHandler(file_handler) + self.logger.addHandler(console_handler) + + if level == 'info': + self.logger.info(message) + elif level == 'debug': + self.logger.debug(message) + elif level == 'warning': + self.logger.warning(message) + elif level == 'error': + self.logger.error(message) + + self.logger.removeHandler(file_handler) + self.logger.removeHandler(console_handler) + + file_handler.close() + + def debug(self, message): + self.__console('debug', message) + + def info(self, message): + self.__console('info', message) + + def warning(self, message): + self.__console('warning', message) + + def error(self, message): + self.__console('error', message) + +def log_metrics(metrics, logger, tensorboard_logger, epoch): + def log_single_class(data, tag): + logger.info( + '{:>15} \t\tI-Auroc:{:.2f} \tI-F1:{:.2f} \tI-AP:{:.2f} \tP-Auroc:{:.2f} \tP-F1:{:.2f} \tP-AP:{:.2f}'. + format(tag, + data['auroc_im'], + data['f1_im'], + data['ap_im'], + data['auroc_px'], + data['f1_px'], + data['ap_px']) + ) + # Adding scalar metrics to TensorBoard + for metric_name in ['auroc_im', 'f1_im', 'ap_im', 'auroc_px', 'f1_px', 'ap_px']: + tensorboard_logger.add_scalar(f'{tag}-{metric_name}', data[metric_name], epoch) + + for tag, data in metrics.items(): + log_single_class(data, tag) + + + diff --git a/tools/metrics.py b/tools/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..71dad0c8c9c6e867c809ee717999904939b84a97 --- /dev/null +++ b/tools/metrics.py @@ -0,0 +1,87 @@ +import numpy as np +from sklearn.metrics import auc, roc_auc_score, precision_recall_curve, average_precision_score + + +def rescale(x): + return (x - x.min()) / (x.max() - x.min()) + + +def is_one_class(gt: np.ndarray): + gt_ravel = gt.ravel() + return gt_ravel.sum() == 0 or gt_ravel.sum() == gt_ravel.shape[0] + + +def calculate_px_metrics(gt_px, pr_px): + if is_one_class(gt_px): # In case there are only normal pixels or no pixel-level labels + return 0, 0, 0 + + auroc_px = roc_auc_score(gt_px.ravel(), pr_px.ravel()) + precisions, recalls, _ = precision_recall_curve(gt_px.ravel(), pr_px.ravel()) + f1_scores = (2 * precisions * recalls) / (precisions + recalls) + f1_px = np.max(f1_scores[np.isfinite(f1_scores)]) + ap_px = average_precision_score(gt_px.ravel(), pr_px.ravel()) + + return auroc_px * 100, f1_px * 100, ap_px * 100 + + +def calculate_im_metrics(gt_im, pr_im): + if is_one_class(gt_im): # In case there are only normal samples or no image-level labels + return 0, 0, 0 + + auroc_im = roc_auc_score(gt_im.ravel(), pr_im.ravel()) + precisions, recalls, _ = precision_recall_curve(gt_im.ravel(), pr_im.ravel()) + f1_scores = (2 * precisions * recalls) / (precisions + recalls) + f1_im = np.max(f1_scores[np.isfinite(f1_scores)]) + ap_im = average_precision_score(gt_im, pr_im) + + return ap_im * 100, auroc_im * 100, f1_im * 100 + + +def calculate_average_metric(metrics: dict): + average = {} + for obj, metric in metrics.items(): + for k, v in metric.items(): + if k not in average: + average[k] = [] + 
average[k].append(v) + + for k, v in average.items(): + average[k] = np.mean(v) + + return average + + +def calculate_metric(results, obj): + gt_px = [] + pr_px = [] + + gt_im = [] + pr_im = [] + + for idx in range(len(results['cls_names'])): + if results['cls_names'][idx] == obj: + gt_px.append(results['imgs_masks'][idx]) + pr_px.append(results['anomaly_maps'][idx]) + + gt_im.append(results['imgs_gts'][idx]) + pr_im.append(results['anomaly_scores'][idx]) + + gt_px = np.array(gt_px) + pr_px = np.array(pr_px) + + gt_im = np.array(gt_im) + pr_im = np.array(pr_im) + + auroc_px, f1_px, ap_px = calculate_px_metrics(gt_px, pr_px) + ap_im, auroc_im, f1_im = calculate_im_metrics(gt_im, pr_im) + + metric = { + 'auroc_px': auroc_px, + 'auroc_im': auroc_im, + 'f1_px': f1_px, + 'f1_im': f1_im, + 'ap_px': ap_px, + 'ap_im': ap_im, + } + + return metric diff --git a/tools/training_tools.py b/tools/training_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2fb2456b92ac78633792d4e3e61d1f1dd66d95 --- /dev/null +++ b/tools/training_tools.py @@ -0,0 +1,67 @@ +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter +import os +import random +import torch +import numpy as np + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def setup_paths(args): + save_root = args.save_path + model_root = os.path.join(save_root, 'models') + log_root = os.path.join(save_root, 'logs') + csv_root = os.path.join(save_root, 'csvs') + image_root = os.path.join(save_root, 'images') + tensorboard_root = os.path.join(save_root, 'tensorboard') + + os.makedirs(model_root, exist_ok=True) + os.makedirs(log_root, exist_ok=True) + os.makedirs(csv_root, exist_ok=True) + os.makedirs(image_root, exist_ok=True) + os.makedirs(tensorboard_root, exist_ok=True) + + if args.use_hsf: + # prepare model name + model_name = f'{args.exp_indx}s-pretrained-{args.training_data}-{args.model}-' \ + f'{args.prompting_type}-{args.prompting_branch}-' \ + f'D{args.prompting_depth}-L{args.prompting_length}-HSF-K{args.k_clusters}' + else: + # prepare model name + model_name = f'{args.exp_indx}s-pretrained-{args.training_data}-{args.model}-' \ + f'{args.prompting_type}-{args.prompting_branch}-' \ + f'D{args.prompting_depth}-L{args.prompting_length}-WO-HSF' + + + # prepare model path + ckp_path = os.path.join(model_root, model_name) + + # prepare tensorboard dir + tensorboard_dir = os.path.join(tensorboard_root, f'{model_name}-{args.testing_data}') + if os.path.exists(tensorboard_dir): + import shutil + shutil.rmtree(tensorboard_dir) + tensorboard_logger = SummaryWriter(log_dir=tensorboard_dir) + + # prepare csv path + csv_path = os.path.join(csv_root, f'{model_name}-{args.testing_data}.csv') + + # prepare image path + image_dir = os.path.join(image_root, f'{model_name}-{args.testing_data}') + os.makedirs(image_dir, exist_ok=True) + + # prepare log path + log_path = os.path.join(log_root, f'{model_name}-{args.testing_data}.txt') + + return model_name, image_dir, csv_path, log_path, ckp_path, tensorboard_logger + + diff --git a/tools/visualization.py b/tools/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..899e413446f3ba3a91f0a80dc67bdf1816639b1f --- /dev/null +++ b/tools/visualization.py @@ -0,0 +1,117 @@ +import cv2 +import matplotlib + +matplotlib.use("Agg") +import 
matplotlib.pyplot as plt +import numpy as np +import os +import seaborn as sns + +## +from sklearn.manifold import TSNE +from sklearn.decomposition import PCA + +## +import matplotlib.ticker as mtick + + +def plot_sample_cv2(names, imgs, scores_: dict, gts, save_folder=None): + os.makedirs(save_folder, exist_ok=True) + + # get subplot number + total_number = len(imgs) + + scores = scores_.copy() + # normarlisze anomalies + for k, v in scores.items(): + max_value = np.max(v) + min_value = np.min(v) + + scores[k] = (scores[k] - min_value) / max_value * 255 + scores[k] = scores[k].astype(np.uint8) + # draw gts + mask_imgs = [] + for idx in range(total_number): + gts_ = gts[idx] + mask_imgs_ = imgs[idx].copy() + mask_imgs_[gts_ > 0.5] = (0, 0, 255) + mask_imgs.append(mask_imgs_) + + # save imgs + for idx in range(total_number): + + cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_ori.jpg'), imgs[idx]) + cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_gt.jpg'), mask_imgs[idx]) + + for key in scores: + heat_map = cv2.applyColorMap(scores[key][idx], cv2.COLORMAP_JET) + visz_map = cv2.addWeighted(heat_map, 0.5, imgs[idx], 0.5, 0) + cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_{key}.jpg'), + visz_map) + + + + +def plot_feat_cv2(names, feat, save_folder=None): + # get subplot number + total_number = len(feat) + + # save imgs + for idx in range(total_number): + feat[idx] = cv2.resize(feat[idx], (256, 256), interpolation=cv2.INTER_NEAREST) + cv2.imwrite(os.path.join(save_folder, f'{names[idx]}_feat.jpg'), feat[idx]) + + + +valid_feature_visualization_methods = ['TSNE', 'PCA'] + +def visualize_feature(features, labels, legends, n_components=3, method='TSNE'): + assert method in valid_feature_visualization_methods + assert n_components in [2, 3] + + if method == 'TSNE': + model = TSNE(n_components=n_components) + elif method == 'PCA': + model = PCA(n_components=n_components) + + else: + raise NotImplementedError + + feat_proj = model.fit_transform(features) + + if n_components == 2: + ax = scatter_2d(feat_proj, labels) + elif n_components == 3: + ax = scatter_3d(feat_proj, labels) + else: + raise NotImplementedError + + plt.legend(legends) + plt.axis('off') + + +def scatter_3d(feat_proj, label): + plt.clf() + ax1 = plt.axes(projection='3d') + + label_unique = np.unique(label) + + for l in label_unique: + ax1.scatter3D(feat_proj[label == l, 0], + feat_proj[label == l, 1], + feat_proj[label == l, 2], s=5) + + return ax1 + + +def scatter_2d(feat_proj, label): + plt.clf() + ax1 = plt.axes() + + label_unique = np.unique(label) + + for l in label_unique: + ax1.scatter(feat_proj[label == l, 0], + feat_proj[label == l, 1], s=5) + + return ax1 diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a9227ebf57e9a1904d3784b894bc73b5f80291ae --- /dev/null +++ b/train.py @@ -0,0 +1,184 @@ +import warnings +warnings.filterwarnings("ignore", category=RuntimeWarning) +import os +os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' +from torch.utils.data import DataLoader +from tqdm import tqdm +import argparse +import json +import os +import torch + +# Importing from local modules +from tools import write2csv, setup_paths, setup_seed, log_metrics, Logger +from dataset import get_data +from method import AdaCLIP_Trainer + +setup_seed(111) + +def train(args): + # Configurations + epochs = args.epoch + learning_rate = args.learning_rate + batch_size = args.batch_size + image_size = args.image_size + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + 
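Alongside plot_sample_cv2, tools/visualization.py above also defines feature-projection helpers (visualize_feature with TSNE or PCA backends, plus scatter_2d and scatter_3d) that are not called by the training or testing scripts shown here. The snippet below is a hypothetical usage sketch with random embeddings; the output file name and the plt.savefig call are additions for illustration, since the helper itself only draws onto the current figure.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # same non-interactive backend as tools/visualization.py
import matplotlib.pyplot as plt
from tools.visualization import visualize_feature

# Hypothetical data: 100 "normal" and 100 "anomalous" 768-d embeddings.
feats = np.random.randn(200, 768).astype(np.float32)
labels = np.concatenate([np.zeros(100, dtype=int), np.ones(100, dtype=int)])

# 2-D t-SNE projection; the helper draws the scatter plot and the legend.
visualize_feature(feats, labels, legends=['normal', 'anomalous'], n_components=2, method='TSNE')
plt.savefig('feature_projection.png', dpi=200)  # saving the figure is our addition
```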
save_fig = args.save_fig + + # Set up paths + model_name, image_dir, csv_path, log_path, ckp_path, tensorboard_logger = setup_paths(args) + # Logger + logger = Logger(log_path) + + # Print basic information + for key, value in sorted(vars(args).items()): + logger.info(f'{key} = {value}') + + logger.info('Model name: {:}'.format(model_name)) + + config_path = os.path.join('./model_configs', f'{args.model}.json') + + # Prepare model + with open(config_path, 'r') as f: + model_configs = json.load(f) + + # Set up the feature hierarchy + n_layers = model_configs['vision_cfg']['layers'] + substage = n_layers // 4 + features_list = [substage, substage * 2, substage * 3, substage * 4] + + model = AdaCLIP_Trainer( + backbone=args.model, + feat_list=features_list, + input_dim=model_configs['vision_cfg']['width'], + output_dim=model_configs['embed_dim'], + learning_rate=learning_rate, + device=device, + image_size=image_size, + prompting_depth=args.prompting_depth, + prompting_length=args.prompting_length, + prompting_branch=args.prompting_branch, + prompting_type=args.prompting_type, + use_hsf=args.use_hsf, + k_clusters=args.k_clusters + ).to(device) + + train_data_cls_names, train_data, train_data_root = get_data( + dataset_type_list=args.training_data, + transform=model.preprocess, + target_transform=model.transform, + training=True) + + test_data_cls_names, test_data, test_data_root = get_data( + dataset_type_list=args.testing_data, + transform=model.preprocess, + target_transform=model.transform, + training=False) + + logger.info('Data Root: training, {:}; testing, {:}'.format(train_data_root, test_data_root)) + + train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) + test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False) + + # Typically, we use MVTec or VisA as the validation set. The best model from this validation + # process is then used for zero-shot anomaly detection on novel categories. 
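Checkpoint selection in the loop below keys on the pixel-level F1 ('f1_px') averaged over the validation classes, which tools/metrics.py above computes as the maximum F1 along the precision-recall curve. A small self-contained illustration of that statistic on synthetic labels and scores:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Synthetic pixel labels and anomaly scores, purely for illustration.
gt = np.array([0, 0, 0, 1, 1, 0, 1, 0])
pr = np.array([0.1, 0.2, 0.3, 0.8, 0.7, 0.4, 0.6, 0.2])

precisions, recalls, _ = precision_recall_curve(gt, pr)
f1 = (2 * precisions * recalls) / (precisions + recalls)
best_f1 = np.max(f1[np.isfinite(f1)])  # same max-F1 statistic as 'f1_px' / 'f1_im'
print(round(best_f1, 3))               # 1.0 here, since a threshold separates the classes
```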
+ best_f1 = -1e1 + + for epoch in tqdm(range(epochs)): + loss = model.train_epoch(train_dataloader) + + # Logs + if (epoch + 1) % args.print_freq == 0: + logger.info('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, loss)) + tensorboard_logger.add_scalar('loss', loss, epoch) + + # Validation + if (epoch + 1) % args.valid_freq == 0 or (epoch == epochs - 1): + if epoch == epochs - 1: + save_fig_flag = save_fig + else: + save_fig_flag = False + + logger.info('=============================Testing ====================================') + metric_dict = model.evaluation( + test_dataloader, + test_data_cls_names, + save_fig_flag, + image_dir, + ) + + log_metrics( + metric_dict, + logger, + tensorboard_logger, + epoch + ) + + f1_px = metric_dict['Average']['f1_px'] + + # Save best + if f1_px > best_f1: + for k in metric_dict.keys(): + write2csv(metric_dict[k], test_data_cls_names, k, csv_path) + + ckp_path_best = ckp_path + '_best.pth' + model.save(ckp_path_best) + best_f1 = f1_px + + + +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + +if __name__ == '__main__': + parser = argparse.ArgumentParser("AdaCLIP", add_help=True) + + # Paths and configurations + parser.add_argument("--training_data", type=str, default=["mvtec", "colondb"], nargs='+', + help="Datasets for training (default: ['mvtec', 'colondb'])") + parser.add_argument("--testing_data", type=str, default="visa", help="Dataset for testing (default: 'visa')") + + parser.add_argument("--save_path", type=str, default='./workspaces', + help="Directory to save results (default: './workspaces')") + + parser.add_argument("--model", type=str, default="ViT-L-14-336", + choices=["ViT-B-16", "ViT-B-32", "ViT-L-14", "ViT-L-14-336"], + help="The CLIP model to be used (default: 'ViT-L-14-336')") + + parser.add_argument("--save_fig", type=str2bool, default=False, + help="Save figures for visualizations (default: False)") + parser.add_argument("--ckt_path", type=str, default='', help="Path to the pre-trained model (default: '')") + + # Hyper-parameters + parser.add_argument("--exp_indx", type=int, default=0, help="Index of the experiment (default: 0)") + parser.add_argument("--epoch", type=int, default=5, help="Number of epochs (default: 5)") + parser.add_argument("--learning_rate", type=float, default=0.01, help="Learning rate (default: 0.01)") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size (default: 1)") + + parser.add_argument("--image_size", type=int, default=518, help="Size of the input images (default: 518)") + parser.add_argument("--print_freq", type=int, default=1, help="Frequency of print statements (default: 1)") + parser.add_argument("--valid_freq", type=int, default=1, help="Frequency of validation (default: 1)") + + # Prompting parameters + parser.add_argument("--prompting_depth", type=int, default=4, help="Depth of prompting (default: 4)") + parser.add_argument("--prompting_length", type=int, default=5, help="Length of prompting (default: 5)") + parser.add_argument("--prompting_type", type=str, default='SD', choices=['', 'S', 'D', 'SD'], + help="Type of prompting. 'S' for Static, 'D' for Dynamic, 'SD' for both (default: 'SD')") + parser.add_argument("--prompting_branch", type=str, default='VL', choices=['', 'V', 'L', 'VL'], + help="Branch of prompting. 'V' for Visual, 'L' for Language, 'VL' for both (default: 'VL')") + + parser.add_argument("--use_hsf", type=str2bool, default=True, + help="Use HSF for aggregation. 
If False, original class embedding is used (default: True)") + parser.add_argument("--k_clusters", type=int, default=20, help="Number of clusters (default: 20)") + + args = parser.parse_args() + + if args.batch_size != 1: + raise NotImplementedError( + "Currently, only batch size of 1 is supported due to unresolved bugs. Please set --batch_size to 1.") + + train(args) + diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..e29cdc774576c3c50e8efef8b564b25de65ffefc --- /dev/null +++ b/train.sh @@ -0,0 +1,19 @@ +gpu_id=0 + +# Note: Since we have utilized half-precision (FP16) for training, the training process can occasionally be unstable. +# It is recommended to run the training process multiple times and choose the best model based on performance +# on the validation set as the final model. + +# pre-trained on MVTec and ColonDB +CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True --training_data mvtec colondb --testing_data visa + +# pre-trained on VisA and ClinicDB +CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec + +# This model is pre-trained on all available data to create a powerful Zero-Shot Anomaly Detection (ZSAD) model for demonstration purposes. +CUDA_VISIBLE_DEVICES=$gpu_id python train.py --save_fig True \ +--training_data \ +br35h brain_mri btad clinicdb colondb \ +dagm dtd headct isic mpdd mvtec sdd tn3k visa \ +--testing_data mvtec + diff --git a/weights/pretrained_all.pth b/weights/pretrained_all.pth new file mode 100644 index 0000000000000000000000000000000000000000..9c50f6392089ce48aeee1852d427cd5f814096fe --- /dev/null +++ b/weights/pretrained_all.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33e8d3db1cb4aab030866b8b70a46e10aa27ebf2c23b5463cb07f2574addd98c +size 42673907 diff --git a/weights/pretrained_mvtec_colondb.pth b/weights/pretrained_mvtec_colondb.pth new file mode 100644 index 0000000000000000000000000000000000000000..f18c1b331ac4352dbdf93ae350a0a4b6651522dc --- /dev/null +++ b/weights/pretrained_mvtec_colondb.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be51a42c052bd4cf060e54f503a1f5d0b2a3b899bc8dc2e243042f18b215427e +size 42673907 diff --git a/weights/pretrained_visa_clinicdb.pth b/weights/pretrained_visa_clinicdb.pth new file mode 100644 index 0000000000000000000000000000000000000000..2cc80e21556222f04d2f63b3413d2099fe3cf97d --- /dev/null +++ b/weights/pretrained_visa_clinicdb.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3deabbbaf1e412cfdfcb42923a500b986f4b9ee96ccbc7a735d89dbc87df44c8 +size 42673907
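Once the LFS checkpoints above have been fetched (for example with git lfs pull), the released models can also be used programmatically rather than through test.sh or test_single_image.sh. The sketch below mirrors the --testing_model image branch of test.py with the all-data checkpoint; asset/img.png and the class name 'candle' are the same demo values used by the scripts, and the snippet is a minimal illustration rather than part of the repository.

```python
import json
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter
from method import AdaCLIP_Trainer

# Build the trainer the same way test.py does for the default ViT-L-14-336 backbone.
with open('./model_configs/ViT-L-14-336.json', 'r') as f:
    cfg = json.load(f)
substage = cfg['vision_cfg']['layers'] // 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = AdaCLIP_Trainer(
    backbone='ViT-L-14-336',
    feat_list=[substage, substage * 2, substage * 3, substage * 4],
    input_dim=cfg['vision_cfg']['width'],
    output_dim=cfg['embed_dim'],
    learning_rate=0.,
    device=device,
    image_size=518,
    prompting_depth=4,
    prompting_length=5,
    prompting_branch='VL',
    prompting_type='SD',
    use_hsf=True,
    k_clusters=20,
).to(device)
model.load('weights/pretrained_all.pth')

# Single-image zero-shot inference, as in the image branch of test.py.
img = model.preprocess(Image.open('asset/img.png').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    anomaly_map, anomaly_score = model.clip_model(img, ['candle'], aggregation=True)

smoothed_map = gaussian_filter(anomaly_map[0].cpu().numpy(), sigma=4)  # pixel-level map
print('image-level anomaly score:', float(anomaly_score[0]))
```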