diff --git a/.gitattributes b/.gitattributes
index 74adac1baf347119a1079bb2f694106b71d8eba1..31aa969e20b07dbc8bfb6f9f157312e0c1f32139 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
MiniGPT-Med-github/Med_examples_v2/5f4e8079-8225a5d2-1b0c3c46-4394a094-f285db0e.jpg filter=lfs diff=lfs merge=lfs -text
+Med_examples_v2/5f4e8079-8225a5d2-1b0c3c46-4394a094-f285db0e.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.1495.1517874291.249176.jpg b/Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.1495.1517874291.249176.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..63da4dc1d3a2764bb8c1308f98faae4e58e473b4
Binary files /dev/null and b/Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.1495.1517874291.249176.jpg differ
diff --git a/Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.16254.1517874395.786150.jpg b/Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.16254.1517874395.786150.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2549de7bde8bf9ef04e15326f7e239d2a8706c38
Binary files /dev/null and b/Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.16254.1517874395.786150.jpg differ
diff --git a/Med_examples_v2/1.2.840.113654.2.55.48339325922382839066544590341580673064.png b/Med_examples_v2/1.2.840.113654.2.55.48339325922382839066544590341580673064.png
new file mode 100644
index 0000000000000000000000000000000000000000..85a84a04f70609c18aa9b5fbfdb492216fb45ba8
Binary files /dev/null and b/Med_examples_v2/1.2.840.113654.2.55.48339325922382839066544590341580673064.png differ
diff --git a/Med_examples_v2/1.3.6.1.4.1.14519.5.2.1.7009.9004.242286124999058976921785904029.png b/Med_examples_v2/1.3.6.1.4.1.14519.5.2.1.7009.9004.242286124999058976921785904029.png
new file mode 100644
index 0000000000000000000000000000000000000000..f509928858520ac71deafd8a6ffd3fec1bc9796d
Binary files /dev/null and b/Med_examples_v2/1.3.6.1.4.1.14519.5.2.1.7009.9004.242286124999058976921785904029.png differ
diff --git a/Med_examples_v2/5f4e8079-8225a5d2-1b0c3c46-4394a094-f285db0e.jpg b/Med_examples_v2/5f4e8079-8225a5d2-1b0c3c46-4394a094-f285db0e.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5a5c59a1b99e27a81f8363e38fc1d341cb61c774
--- /dev/null
+++ b/Med_examples_v2/5f4e8079-8225a5d2-1b0c3c46-4394a094-f285db0e.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94a8259f7b596eb34fd18375913ec0b17d0ae1e2bdb56467a236aa5cc7557ec1
+size 1951301
diff --git a/Med_examples_v2/synpic33889.jpg b/Med_examples_v2/synpic33889.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2765fff4060a4898c51780f94514c6033fdfba5c
Binary files /dev/null and b/Med_examples_v2/synpic33889.jpg differ
diff --git a/Med_examples_v2/synpic50958.jpg b/Med_examples_v2/synpic50958.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f1c07287d09b70f0f986a8d7c1051ac68cc41e5d
Binary files /dev/null and b/Med_examples_v2/synpic50958.jpg differ
diff --git a/Med_examples_v2/synpic56061.jpg b/Med_examples_v2/synpic56061.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ebfa7ed14307aa1aabdf1b5f5831b01059447a2a
Binary files /dev/null and b/Med_examples_v2/synpic56061.jpg differ
diff --git a/Med_examples_v2/synpic58547.jpg b/Med_examples_v2/synpic58547.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..544fabb1833ddbad8595850f63f7eab742a4927d
Binary files /dev/null and b/Med_examples_v2/synpic58547.jpg differ
diff --git a/Med_examples_v2/synpic60423.jpg b/Med_examples_v2/synpic60423.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3eb94743d37f81525ae53ba910f572839fb8ff58
Binary files /dev/null and b/Med_examples_v2/synpic60423.jpg differ
diff --git a/Med_examples_v2/synpic676.jpg b/Med_examples_v2/synpic676.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..520e71a24df1f6620cef1bc552ca53267012f146
Binary files /dev/null and b/Med_examples_v2/synpic676.jpg differ
diff --git a/Med_examples_v2/xmlab149/source.jpg b/Med_examples_v2/xmlab149/source.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b2fcc3bac58e12687ac7acf977a6360f918d68b6
Binary files /dev/null and b/Med_examples_v2/xmlab149/source.jpg differ
diff --git a/Med_examples_v2/xmlab589/source.jpg b/Med_examples_v2/xmlab589/source.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..270d3b3ca35e05c388220c64ee4e203eee410b5c
Binary files /dev/null and b/Med_examples_v2/xmlab589/source.jpg differ
diff --git a/README.md b/README.md
index 8abfe6df19545236b25999189e0b91dd6b6ddb75..58761db7923134c2bae138f7baec96ae8b39e86d 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,51 @@
----
-title: MiniGPT Med
-emoji: 🌍
-colorFrom: pink
-colorTo: green
-sdk: gradio
-sdk_version: 4.41.0
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# MiniGPT-Med: Large Language Model as a General Interface for Radiology Diagnosis
+Asma Alkhaldi, Raneem Alnajim, Layan Alabdullatef, Rawan Alyahya, Jun Chen, Deyao Zhu, Ahmed Alsinan, Mohamed Elhoseiny
+
+*Saudi Data and Artificial Intelligence Authority (SDAIA) and King Abdullah University of Science and Technology (KAUST)*
+
+## Installation
+```
+git clone https://github.com/Vision-CAIR/MiniGPT-Med
+cd MiniGPT-Med
+conda env create -f environment.yml
+conda activate miniGPT-Med
+```
+
+## Download miniGPT-Med trained model weights
+
+* miniGPT-Med's weights [miniGPT-Med Model](https://drive.google.com/file/d/1kjGLk6s9LsBmXfLWQFCdlwF3aul08Cl8/view?usp=sharing)
+
+* Then modify line 8 at miniGPT-Med/eval_configs/minigptv2_eval.yaml to be the path of miniGPT-Med weight.
+
+## Prepare weight for LLMs
+
+### Llama2 Version
+
+```shell
+git clone https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
+```
+
+Then modify line 14 at miniGPT-Med/minigpt4/configs/models/minigpt_v2.yaml to be the path of Llama-2-13b-chat-hf.
+
+## Launching Demo Locally
+
+```
+python demo.py --cfg-path eval_configs/minigptv2_eval.yaml --gpu-id 0
+```
+
+## Dataset
+| Dataset | Images | json file|
+|---------|---------|----------|
+| MIMIC |[Download](https://physionet.org/content/mimiciii/1.4/) | [Download](https://drive.google.com/drive/folders/1nZhdfNoh7fkx7CWvf0_47_OLv3tA2m3o?usp=sharing) |
+| NLST |[Download](https://wiki.cancerimagingarchive.net/display/NLST)| [Downlaod](https://drive.google.com/drive/folders/1OKgMTaGLu_dWRuco6JipYzezw3oNwgaz?usp=sharing) |
+|SLAKE |[Downlaod](https://www.med-vqa.com/slake/) |[Download](https://drive.google.com/drive/folders/1vstjmfRbKahSAsi_b6FmTQiuolvgO8oC?usp=sharing)|
+|RSNA |[Downlaod](https://www.rsna.org/rsnai/ai-image-challenge/rsna-pneumonia-detection-challenge-2018) | [Download](https://drive.google.com/drive/folders/1wkXPvUNqda6jWAIduyiVJkS3Tx7P7td8?usp=sharing) |
+|Rad-VQA |[Downalod](https://osf.io/89kps/) |[Download](https://drive.google.com/drive/folders/1ING6Dodwk2DU_t4GHQYudNFMMg9OMfBQ?usp=sharing) |
+
+## Acknowledgement
+
+- MiniGPT-4
+- Lavis
+- Vicuna
+- Falcon
+- Llama 2
diff --git a/dcgm/bash/34649895/dcgm-gpu-stats-gpu202-02-r-34649895.out b/dcgm/bash/34649895/dcgm-gpu-stats-gpu202-02-r-34649895.out
new file mode 100644
index 0000000000000000000000000000000000000000..539d921bf117bc18e85836b2905a3ef8ce508772
--- /dev/null
+++ b/dcgm/bash/34649895/dcgm-gpu-stats-gpu202-02-r-34649895.out
@@ -0,0 +1,39 @@
+Successfully retrieved statistics for job: 34649895.
++------------------------------------------------------------------------------+
+| GPU ID: 0 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Tue Jul 9 09:29:46 2024 |
+| End Time | Wed Jul 10 09:30:32 2024 |
+| Total Execution Time (sec) | 86445.3 |
+| No. of Processes | 1 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 232291 |
+| Power Usage (Watts) | Avg: 65.6704, Max: 84.315, Min: 61.555 |
+| Max GPU Memory Used (bytes) | 10104078336 |
+| SM Clock (MHz) | Avg: 595, Max: 1155, Min: 210 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++-- Compute Process Utilization ---+-----------------------------------------+
+| PID | 1548651 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34673507/dcgm-gpu-stats-gpu201-23-l-34673507.out b/dcgm/bash/34673507/dcgm-gpu-stats-gpu201-23-l-34673507.out
new file mode 100644
index 0000000000000000000000000000000000000000..a93a98a71be9b14e54b59d82a202ae9102c0ffe3
--- /dev/null
+++ b/dcgm/bash/34673507/dcgm-gpu-stats-gpu201-23-l-34673507.out
@@ -0,0 +1,39 @@
+Successfully retrieved statistics for job: 34673507.
++------------------------------------------------------------------------------+
+| GPU ID: 1 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Fri Jul 12 11:48:45 2024 |
+| End Time | Sat Jul 13 11:49:39 2024 |
+| Total Execution Time (sec) | 86454.5 |
+| No. of Processes | 1 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 252136 |
+| Power Usage (Watts) | Avg: 69.7762, Max: 70.022, Min: 69.151 |
+| Max GPU Memory Used (bytes) | 10104078336 |
+| SM Clock (MHz) | Avg: 1157, Max: 1410, Min: 1155 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++-- Compute Process Utilization ---+-----------------------------------------+
+| PID | 2527521 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34676162/dcgm-gpu-stats-gpu201-23-l-34676162.out b/dcgm/bash/34676162/dcgm-gpu-stats-gpu201-23-l-34676162.out
new file mode 100644
index 0000000000000000000000000000000000000000..6c3de57d50c0fe21b5d28923efbdd45eac11b242
--- /dev/null
+++ b/dcgm/bash/34676162/dcgm-gpu-stats-gpu201-23-l-34676162.out
@@ -0,0 +1,39 @@
+Successfully retrieved statistics for job: 34676162.
++------------------------------------------------------------------------------+
+| GPU ID: 3 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Sun Jul 14 07:57:08 2024 |
+| End Time | Mon Jul 15 07:57:59 2024 |
+| Total Execution Time (sec) | 86450.6 |
+| No. of Processes | 1 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 249997 |
+| Power Usage (Watts) | Avg: 82.8167, Max: 86.615, Min: 70.491 |
+| Max GPU Memory Used (bytes) | 10104078336 |
+| SM Clock (MHz) | Avg: 1352, Max: 1410, Min: 1080 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++-- Compute Process Utilization ---+-----------------------------------------+
+| PID | 3048225 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34691276/dcgm-gpu-stats-gpu201-09-l-34691276.out b/dcgm/bash/34691276/dcgm-gpu-stats-gpu201-09-l-34691276.out
new file mode 100644
index 0000000000000000000000000000000000000000..ce61c3e3b1877129eba211f608254d0aea89681c
--- /dev/null
+++ b/dcgm/bash/34691276/dcgm-gpu-stats-gpu201-09-l-34691276.out
@@ -0,0 +1,42 @@
+Successfully retrieved statistics for job: 34691276.
++------------------------------------------------------------------------------+
+| GPU ID: 0 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Tue Jul 16 08:21:43 2024 |
+| End Time | Tue Jul 16 21:44:34 2024 |
+| Total Execution Time (sec) | 48170.9 |
+| No. of Processes | 2 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 222759 |
+| Power Usage (Watts) | Avg: 61.4158, Max: 61.683, Min: 61.349 |
+| Max GPU Memory Used (bytes) | 10806624256 |
+| SM Clock (MHz) | Avg: 210, Max: 225, Min: 210 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++-- Compute Process Utilization ---+-----------------------------------------+
+| PID | 1958147 |
+| Avg SM Utilization (%) | 1 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 2068287 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34709014/dcgm-gpu-stats-gpu109-16-l-34709014.out b/dcgm/bash/34709014/dcgm-gpu-stats-gpu109-16-l-34709014.out
new file mode 100644
index 0000000000000000000000000000000000000000..affa24b89d0a68b8a25e2a386917e68b037442be
--- /dev/null
+++ b/dcgm/bash/34709014/dcgm-gpu-stats-gpu109-16-l-34709014.out
@@ -0,0 +1,39 @@
+Successfully retrieved statistics for job: 34709014.
++------------------------------------------------------------------------------+
+| GPU ID: 3 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Thu Jul 18 07:54:11 2024 |
+| End Time | Fri Jul 19 07:55:07 2024 |
+| Total Execution Time (sec) | 86456.3 |
+| No. of Processes | 1 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 245376 |
+| Power Usage (Watts) | Avg: 67.8347, Max: 68.156, Min: 67.563 |
+| Max GPU Memory Used (bytes) | 10582228992 |
+| SM Clock (MHz) | Avg: 1161, Max: 1410, Min: 1155 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++-- Compute Process Utilization ---+-----------------------------------------+
+| PID | 4005887 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34721198/dcgm-gpu-stats-gpu203-23-r-34721198.out b/dcgm/bash/34721198/dcgm-gpu-stats-gpu203-23-r-34721198.out
new file mode 100644
index 0000000000000000000000000000000000000000..adfe7578ac4468b53712ec94df5cc6c3df09a1f2
--- /dev/null
+++ b/dcgm/bash/34721198/dcgm-gpu-stats-gpu203-23-r-34721198.out
@@ -0,0 +1,57 @@
+Successfully retrieved statistics for job: 34721198.
++------------------------------------------------------------------------------+
+| GPU ID: 1 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Fri Jul 19 21:34:44 2024 |
+| End Time | Sat Jul 20 00:01:06 2024 |
+| Total Execution Time (sec) | 8782.23 |
+| No. of Processes | 7 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 225540 |
+| Power Usage (Watts) | Avg: 75.9496, Max: 87.541, Min: 65.792 |
+| Max GPU Memory Used (bytes) | 13356761088 |
+| SM Clock (MHz) | Avg: 210, Max: 210, Min: 210 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++-- Compute Process Utilization ---+-----------------------------------------+
+| PID | 866309 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 866955 |
+| Avg SM Utilization (%) | 1 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 868076 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 868638 |
+| Avg SM Utilization (%) | 5 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 869519 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 871043 |
+| Avg SM Utilization (%) | 1 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 871322 |
+| Avg SM Utilization (%) | 0 |
+| Avg Memory Utilization (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34734121/dcgm-gpu-stats-gpu201-23-l-34734121.out b/dcgm/bash/34734121/dcgm-gpu-stats-gpu201-23-l-34734121.out
new file mode 100644
index 0000000000000000000000000000000000000000..f590de6c5a80aca2f3ad7355de4e11768bb2bbb7
--- /dev/null
+++ b/dcgm/bash/34734121/dcgm-gpu-stats-gpu201-23-l-34734121.out
@@ -0,0 +1,35 @@
+Successfully retrieved statistics for job: 34734121.
++------------------------------------------------------------------------------+
+| GPU ID: 3 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Tue Jul 23 11:47:49 2024 |
+| End Time | Tue Jul 23 13:47:51 2024 |
+| Total Execution Time (sec) | 7202.22 |
+| No. of Processes | 0 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 226384 |
+| Power Usage (Watts) | Avg: 62.6807, Max: 81.445, Min: 62.015 |
+| Max GPU Memory Used (bytes) | 0 |
+| SM Clock (MHz) | Avg: 220, Max: 1410, Min: 210 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34738689/dcgm-gpu-stats-gpu201-16-r-34738689.out b/dcgm/bash/34738689/dcgm-gpu-stats-gpu201-16-r-34738689.out
new file mode 100644
index 0000000000000000000000000000000000000000..e1898117f8075490cc7c504e8f85434f7cdee8fd
--- /dev/null
+++ b/dcgm/bash/34738689/dcgm-gpu-stats-gpu201-16-r-34738689.out
@@ -0,0 +1,35 @@
+Successfully retrieved statistics for job: 34738689.
++------------------------------------------------------------------------------+
+| GPU ID: 3 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Wed Jul 24 10:14:38 2024 |
+| End Time | Wed Jul 24 11:45:33 2024 |
+| Total Execution Time (sec) | 5454.69 |
+| No. of Processes | 0 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 232516 |
+| Power Usage (Watts) | Avg: 64.2532, Max: 64.329, Min: 63.938 |
+| Max GPU Memory Used (bytes) | 0 |
+| SM Clock (MHz) | Avg: 210, Max: 210, Min: 210 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/dcgm/bash/34757693/dcgm-gpu-stats-gpu202-16-r-34757693.out b/dcgm/bash/34757693/dcgm-gpu-stats-gpu202-16-r-34757693.out
new file mode 100644
index 0000000000000000000000000000000000000000..869cf67eecf50e74d06c54b41952a3e0e65f06ea
--- /dev/null
+++ b/dcgm/bash/34757693/dcgm-gpu-stats-gpu202-16-r-34757693.out
@@ -0,0 +1,42 @@
+Successfully retrieved statistics for job: 34757693.
++------------------------------------------------------------------------------+
+| GPU ID: 2 |
++====================================+=========================================+
+|----- Execution Stats ------------+-----------------------------------------|
+| Start Time | Thu Jul 25 15:38:16 2024 |
+| End Time | Thu Jul 25 17:08:59 2024 |
+| Total Execution Time (sec) | 5442.54 |
+| No. of Processes | 2 |
++----- Performance Stats ----------+-----------------------------------------+
+| Energy Consumed (Joules) | 214029 |
+| Power Usage (Watts) | Avg: 59.2012, Max: 67.659, Min: 59.026 |
+| Max GPU Memory Used (bytes) | 7616856064 |
+| SM Clock (MHz) | Avg: 243, Max: 1080, Min: 210 |
+| Memory Clock (MHz) | Avg: 1593, Max: 1593, Min: 1593 |
+| SM Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| Memory Utilization (%) | Avg: 0, Max: 0, Min: 0 |
+| PCIe Rx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
+| PCIe Tx Bandwidth (megabytes) | Avg: N/A, Max: N/A, Min: N/A |
++----- Event Stats ----------------+-----------------------------------------+
+| Single Bit ECC Errors | 0 |
+| Double Bit ECC Errors | 0 |
+| PCIe Replay Warnings | 0 |
+| Critical XID Errors | 0 |
++----- Slowdown Stats -------------+-----------------------------------------+
+| Due to - Power (%) | 0 |
+| - Thermal (%) | 0 |
+| - Reliability (%) | Not Supported |
+| - Board Limit (%) | Not Supported |
+| - Low Utilization (%) | Not Supported |
+| - Sync Boost (%) | 0 |
++-- Compute Process Utilization ---+-----------------------------------------+
+| PID | 1095606 |
+| Avg SM Utilization (%) | 3 |
+| Avg Memory Utilization (%) | 0 |
+| PID | 1096190 |
+| Avg SM Utilization (%) | 14 |
+| Avg Memory Utilization (%) | 2 |
++----- Overall Health -------------+-----------------------------------------+
+| Overall Health | Healthy |
++------------------------------------+-----------------------------------------+
+
diff --git a/demo_v2.py b/demo_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd6d3813aeca32d560c6137000e1d8d718c67c15
--- /dev/null
+++ b/demo_v2.py
@@ -0,0 +1,648 @@
+# python demo_v2.py --cfg-path eval_configs/minigptv2_eval.yaml --gpu-id 0
+
+import argparse
+import os
+import random
+from collections import defaultdict
+
+import cv2
+import re
+
+import numpy as np
+from PIL import Image
+import torch
+import html
+import gradio as gr
+
+import torchvision.transforms as T
+import torch.backends.cudnn as cudnn
+
+from minigpt4.common.config import Config
+
+from minigpt4.common.registry import registry
+from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
+
+# imports modules for registration
+from minigpt4.datasets.builders import *
+from minigpt4.models import *
+from minigpt4.processors import *
+from minigpt4.runners import *
+from minigpt4.tasks import *
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Demo")
+ parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
+ help="path to configuration file.")
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+ parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+random.seed(42)
+np.random.seed(42)
+torch.manual_seed(42)
+
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+print('Initializing Chat')
+args = parse_args()
+cfg = Config(args)
+
+device = 'cuda:{}'.format(args.gpu_id)
+
+model_config = cfg.model_cfg
+model_config.device_8bit = args.gpu_id
+model_cls = registry.get_model_class(model_config.arch)
+model = model_cls.from_config(model_config).to(device)
+bounding_box_size = 100
+
+vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+
+model = model.eval()
+
+CONV_VISION = Conversation(
+ system="",
+ roles=(r"[INST] ", r" [/INST]"),
+ messages=[],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="",
+)
+
+
+def extract_substrings(string):
+ # first check if there is no-finished bracket
+ index = string.rfind('}')
+ if index != -1:
+ string = string[:index + 1]
+
+ pattern = r'
(.*?)\}(?!<)'
+ matches = re.findall(pattern, string)
+ substrings = [match for match in matches]
+
+ return substrings
+
+
+def is_overlapping(rect1, rect2):
+ x1, y1, x2, y2 = rect1
+ x3, y3, x4, y4 = rect2
+ return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+
+def computeIoU(bbox1, bbox2):
+ x1, y1, x2, y2 = bbox1
+ x3, y3, x4, y4 = bbox2
+ intersection_x1 = max(x1, x3)
+ intersection_y1 = max(y1, y3)
+ intersection_x2 = min(x2, x4)
+ intersection_y2 = min(y2, y4)
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
+ union_area = bbox1_area + bbox2_area - intersection_area
+ iou = intersection_area / union_area
+ return iou
+
+
+def save_tmp_img(visual_img):
+ file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
+ file_path = "/tmp/gradio" + file_name
+ visual_img.save(file_path)
+ return file_path
+
+
+def mask2bbox(mask):
+ if mask is None:
+ return ''
+ mask = mask.resize([100, 100], resample=Image.NEAREST)
+ mask = np.array(mask)[:, :, 0]
+
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+
+ if rows.sum():
+ # Get the top, bottom, left, and right boundaries
+ rmin, rmax = np.where(rows)[0][[0, -1]]
+ cmin, cmax = np.where(cols)[0][[0, -1]]
+ bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
+ else:
+ bbox = ''
+
+ return bbox
+
+
+def escape_markdown(text):
+ # List of Markdown special characters that need to be escaped
+ md_chars = ['<', '>']
+
+ # Escape each special character
+ for char in md_chars:
+ text = text.replace(char, '\\' + char)
+
+ return text
+
+
+def reverse_escape(text):
+ md_chars = ['\\<', '\\>']
+
+ for char in md_chars:
+ text = text.replace(char, char[1:])
+
+ return text
+
+
+colors = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (210, 210, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (114, 128, 250),
+ (0, 165, 255),
+ (0, 128, 0),
+ (144, 238, 144),
+ (238, 238, 175),
+ (255, 191, 0),
+ (0, 128, 0),
+ (226, 43, 138),
+ (255, 0, 255),
+ (0, 215, 255),
+]
+
+color_map = {
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
+ color_id, color in enumerate(colors)
+}
+
+used_colors = colors
+
+
+def visualize_all_bbox_together(image, generation):
+ if image is None:
+ return None, ''
+
+ generation = html.unescape(generation)
+ print('gen begin', generation)
+ image_width, image_height = image.size
+ image = image.resize([500, int(500 / image_width * image_height)])
+ image_width, image_height = image.size
+
+ string_list = extract_substrings(generation)
+ if string_list: # it is grounding or detection
+ mode = 'all'
+ entities = defaultdict(list)
+ i = 0
+ j = 0
+ for string in string_list:
+ try:
+ obj, string = string.split('
')
+ except ValueError:
+ print('wrong string: ', string)
+ continue
+ bbox_list = string.split('')
+ flag = False
+ for bbox_string in bbox_list:
+ integers = re.findall(r'-?\d+', bbox_string)
+ if len(integers) == 4:
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+
+ entities[obj].append([left, bottom, right, top])
+
+ j += 1
+ flag = True
+ if flag:
+ i += 1
+ else:
+ integers = re.findall(r'-?\d+', generation)
+
+ if len(integers) == 4: # it is refer
+ mode = 'single'
+
+ entities = list()
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
+ left = x0 / bounding_box_size * image_width
+ bottom = y0 / bounding_box_size * image_height
+ right = x1 / bounding_box_size * image_width
+ top = y1 / bounding_box_size * image_height
+ entities.append([left, bottom, right, top])
+ else:
+ # don't detect any valid bbox to visualize
+ return None, ''
+
+ if len(entities) == 0:
+ return None, ''
+
+ if isinstance(image, Image.Image):
+ image_h = image.height
+ image_w = image.width
+ image = np.array(image)
+
+ elif isinstance(image, str):
+ if os.path.exists(image):
+ pil_img = Image.open(image).convert("RGB")
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ image_h = pil_img.height
+ image_w = pil_img.width
+ else:
+ raise ValueError(f"invaild image path, {image}")
+ elif isinstance(image, torch.Tensor):
+
+ image_tensor = image.cpu()
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+ pil_img = T.ToPILImage()(image_tensor)
+ image_h = pil_img.height
+ image_w = pil_img.width
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
+ else:
+ raise ValueError(f"invaild image format, {type(image)} for {image}")
+
+ indices = list(range(len(entities)))
+
+ new_image = image.copy()
+
+ previous_bboxes = []
+ # size of text
+ text_size = 0.5
+ # thickness of text
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
+ box_line = 2
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+ base_height = int(text_height * 0.675)
+ text_offset_original = text_height - base_height
+ text_spaces = 2
+
+ # num_bboxes = sum(len(x[-1]) for x in entities)
+ used_colors = colors # random.sample(colors, k=num_bboxes)
+
+ color_id = -1
+ for entity_idx, entity_name in enumerate(entities):
+ if mode == 'single' or mode == 'identify':
+ bboxes = entity_name
+ bboxes = [bboxes]
+ else:
+ bboxes = entities[entity_name]
+ color_id += 1
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
+ skip_flag = False
+ orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
+
+ color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
+ new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+ if mode == 'all':
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+ x1 = orig_x1 - l_o
+ y1 = orig_y1 - l_o
+
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+ x1 = orig_x1 + r_o
+
+ # add text background
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
+ text_line)
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
+ text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+ for prev_bbox in previous_bboxes:
+ if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
+ prev_bbox['phrase'] == entity_name:
+ skip_flag = True
+ break
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+ if text_bg_y2 >= image_h:
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+ text_bg_y2 = image_h
+ y1 = image_h
+ break
+ if not skip_flag:
+ alpha = 0.5
+ for i in range(text_bg_y1, text_bg_y2):
+ for j in range(text_bg_x1, text_bg_x2):
+ if i < image_h and j < image_w:
+ if j < text_bg_x1 + 1.35 * c_width:
+ # original color
+ bg_color = color
+ else:
+ # white
+ bg_color = [255, 255, 255]
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
+ np.uint8)
+
+ cv2.putText(
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
+ cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+ )
+
+ previous_bboxes.append(
+ {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
+
+ if mode == 'all':
+ def color_iterator(colors):
+ while True:
+ for color in colors:
+ yield color
+
+ color_gen = color_iterator(colors)
+
+ # Add colors to phrases and remove
+ def colored_phrases(match):
+ phrase = match.group(1)
+ color = next(color_gen)
+ return f'{phrase}'
+
+ generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|', '', generation)
+ generation_colored = re.sub(r'(.*?)
', colored_phrases, generation)
+ else:
+ generation_colored = ''
+
+ pil_image = Image.fromarray(new_image)
+ return pil_image, generation_colored
+
+
+def gradio_reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
+ interactive=True), chat_state, img_list
+
+
+def image_upload_trigger(upload_flag, replace_flag, img_list):
+ # set the upload flag to true when receive a new image.
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+ upload_flag = 1
+ if img_list:
+ replace_flag = 1
+ return upload_flag, replace_flag
+
+
+def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
+ # set the upload flag to true when receive a new image.
+ # if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
+ upload_flag = 1
+ if img_list or replace_flag == 1:
+ replace_flag = 1
+
+ return upload_flag, replace_flag
+
+
+def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
+ if len(user_message) == 0:
+ text_box_show = 'Input should not be empty!'
+ else:
+ text_box_show = ''
+
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+ else:
+ mask = None
+
+ if '[identify]' in user_message:
+ # check if user provide bbox in the text input
+ integers = re.findall(r'-?\d+', user_message)
+ if len(integers) != 4: # no bbox in text
+ bbox = mask2bbox(mask)
+ user_message = user_message + bbox
+
+ if chat_state is None:
+ chat_state = CONV_VISION.copy()
+
+ if upload_flag:
+ if replace_flag:
+ chat_state = CONV_VISION.copy() # new image, reset everything
+ replace_flag = 0
+ chatbot = []
+ img_list = []
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
+ upload_flag = 0
+
+ chat.ask(user_message, chat_state)
+
+ chatbot = chatbot + [[user_message, None]]
+
+ if '[identify]' in user_message:
+ visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
+ if visual_img is not None:
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[(file_path,), None]]
+
+ return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
+
+
+def gradio_answer(chatbot, chat_state, img_list, temperature):
+ llm_message = chat.answer(conv=chat_state,
+ img_list=img_list,
+ temperature=temperature,
+ max_new_tokens=500,
+ max_length=2000)[0]
+ chatbot[-1][1] = llm_message
+ return chatbot, chat_state
+
+
+def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
+ if len(img_list) > 0:
+ if not isinstance(img_list[0], torch.Tensor):
+ chat.encode_img(img_list)
+ streamer = chat.stream_answer(conv=chat_state,
+ img_list=img_list,
+ temperature=temperature,
+ max_new_tokens=500,
+ max_length=2000)
+ output = ''
+ for new_output in streamer:
+ escapped = escape_markdown(new_output)
+ output += escapped
+ chatbot[-1][1] = output
+ yield chatbot, chat_state
+ chat_state.messages[-1][1] = ''
+ return chatbot, chat_state
+
+
+def gradio_visualize(chatbot, gr_img):
+ if isinstance(gr_img, dict):
+ gr_img, mask = gr_img['image'], gr_img['mask']
+
+ unescaped = reverse_escape(chatbot[-1][1])
+ visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
+ if visual_img is not None:
+ if len(generation_color):
+ chatbot[-1][1] = generation_color
+ file_path = save_tmp_img(visual_img)
+ chatbot = chatbot + [[None, (file_path,)]]
+
+ return chatbot
+
+
+def gradio_taskselect(idx):
+ prompt_list = [
+ '',
+ '[grounding] describe this image in detail',
+ '[refer] ',
+ '[detection] ',
+ '[identify] what is this ',
+ '[vqa] '
+ ]
+ instruct_list = [
+ '**Hint:** Type in whatever you want',
+ '**Hint:** Send the command to generate a grounded image description',
+ '**Hint:** Type in a phrase about an object in the image and send the command',
+ '**Hint:** Type in a caption or phrase, and see object locations in the image',
+ '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" botton on the top right of the image before redraw',
+ '**Hint:** Send a question to get a short answer',
+ ]
+ return prompt_list[idx], instruct_list[idx]
+
+
+
+
+chat = Chat(model, vis_processor, device=device)
+
+title = """MiniGPT-Med Demo
"""
+description = 'Welcome to Our MiniGPT-Med Chatbot Demo!'
+# article = """![](https://img.shields.io/badge/Project-Page-Green)
![](https://img.shields.io/badge/Paper-PDF-red)
![](https://img.shields.io/badge/GitHub-Repo-blue)
![](https://img.shields.io/badge/YouTube-Video-red)
"""
+article = """![](https://img.shields.io/badge/Project-Page-Green)
"""
+
+introduction = '''
+For Abilities Involving Visual Grounding:
+1. Grounding: CLICK **Send** to generate a grounded image description.
+2. Refer: Input a referring object and CLICK **Send**.
+3. Detection: Write a caption or phrase, and CLICK **Send**.
+4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
+5. VQA: Input a visual question and CLICK **Send**.
+6. No Tag: Input whatever you want and CLICK **Send** without any tagging
+
+You can also simply chat in free form!
+'''
+
+text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
+ scale=8)
+with gr.Blocks() as demo:
+ gr.Markdown(title)
+ # gr.Markdown(description)
+ gr.Markdown(article)
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ image = gr.Image(type="pil", tool='sketch', brush_radius=20)
+
+ temperature = gr.Slider(
+ minimum=0.1,
+ maximum=1.5,
+ value=0.6,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+
+ clear = gr.Button("Restart")
+
+ gr.Markdown(introduction)
+
+ with gr.Column():
+ chat_state = gr.State(value=None)
+ img_list = gr.State(value=[])
+ chatbot = gr.Chatbot(label='MiniGPT-Med')
+
+ dataset = gr.Dataset(
+ components=[gr.Textbox(visible=False)],
+ samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
+ type="index",
+ label='Task Shortcuts',
+ )
+ task_inst = gr.Markdown('**Hint:** Upload your image and chat')
+ with gr.Row():
+ text_input.render()
+ send = gr.Button("Send", variant='primary', size='sm', scale=1)
+
+ upload_flag = gr.State(value=0)
+ replace_flag = gr.State(value=0)
+ image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
+# [29, 44, 42, 56]
+ with gr.Row():
+ with gr.Column():
+ gr.Examples(examples=[
+ ["Med_examples_v2/xmlab149/source.jpg", "[identify] what is this {<56><16><84><58>}", upload_flag,
+ replace_flag, img_list],
+ ["Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.1495.1517874291.249176.jpg", "[detection] pneumonia", upload_flag, replace_flag, img_list],
+ ["Med_examples_v2/1.2.840.113654.2.55.48339325922382839066544590341580673064.png", "[refer] the nodule in the left lung", upload_flag, replace_flag,
+ img_list],
+ ["Med_examples_v2/xmlab589/source.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+ with gr.Column():
+ gr.Examples(examples=[
+ ["Med_examples_v2/synpic50958.jpg", "[vqa] What does the small white lesions in the aorta mean?",
+ upload_flag, replace_flag, img_list],
+ ["Med_examples_v2/5f4e8079-8225a5d2-1b0c3c46-4394a094-f285db0e.jpg", "Please provide a detailed description of the picture", upload_flag, replace_flag, img_list],
+ ["Med_examples_v2/1.2.276.0.7230010.3.1.4.8323329.16254.1517874395.786150.jpg", "Diagnose this image", upload_flag, replace_flag, img_list],
+ ["Med_examples_v2/synpic58547.jpg", "Could you describe the contents of this image for me?", upload_flag,
+ replace_flag, img_list],
+ ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
+ outputs=[upload_flag, replace_flag])
+
+ dataset.click(
+ gradio_taskselect,
+ inputs=[dataset],
+ outputs=[text_input, task_inst],
+ show_progress="hidden",
+ postprocess=False,
+ queue=False,
+ )
+
+ text_input.submit(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ send.click(
+ gradio_ask,
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
+ ).success(
+ gradio_stream_answer,
+ [chatbot, chat_state, img_list, temperature],
+ [chatbot, chat_state]
+ ).success(
+ gradio_visualize,
+ [chatbot, image],
+ [chatbot],
+ queue=False,
+ )
+
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
+
+demo.launch(share=True, enable_queue=True)
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8e51a0962afa92bae05b21c1656cc84ba4729b82
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,35 @@
+name: miniGPT-Med
+channels:
+ - pytorch
+ - defaults
+ - anaconda
+dependencies:
+ - python=3.9
+ - cudatoolkit
+ - pip
+ - pip:
+ - torch==2.0.0
+ - torchaudio
+ - torchvision
+ - huggingface-hub==0.18.0
+ - matplotlib==3.7.0
+ - psutil==5.9.4
+ - iopath
+ - pyyaml==6.0
+ - regex==2022.10.31
+ - tokenizers==0.13.2
+ - tqdm==4.64.1
+ - transformers==4.30.0
+ - timm==0.6.13
+ - webdataset==0.2.48
+ - omegaconf==2.3.0
+ - opencv-python==4.7.0.72
+ - decord==0.6.0
+ - peft==0.2.0
+ - sentence-transformers
+ - gradio==3.47.1
+ - accelerate==0.20.3
+ - bitsandbytes==0.37.0
+ - scikit-image
+ - visual-genome
+ - wandb
diff --git a/eval_configs/minigptv2_benchmark_evaluation.yaml b/eval_configs/minigptv2_benchmark_evaluation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..277e92b1ce085b29d62c6650ae5034a0e152cc5c
--- /dev/null
+++ b/eval_configs/minigptv2_benchmark_evaluation.yaml
@@ -0,0 +1,69 @@
+model:
+ arch: minigpt_v2
+ model_type: pretrain
+ max_txt_len: 500
+ end_sym: ""
+ low_resource: False
+ prompt_template: '[INST] {} [/INST]'
+ llama_model: "/ibex/project/c2106/RadGPT/MiniGPT4-v2/llama-2-7b-chat-hf"
+ ckpt: "/ibex/project/c2106/RadGPT/MiniGPT-Med-github/miniGPT_Med_.pth"
+ lora_r: 64
+ lora_alpha: 16
+
+datasets:
+ cc_sbu_align:
+ vis_processor:
+ train:
+ name: "blip2_image_eval"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+evaluation_datasets:
+ rsna:
+ eval_file_path: miniGPT-Med/json_files/RSNA/full_RSNA_1024.json
+ img_path: miniGPT-Med/RSNA/RSNA-bbox-1024
+ max_new_tokens: 100
+ batch_size: 10
+
+ radvqa:
+ eval_file_path: /miniGPT-Med/json_files/vqa/full_radVQA.json
+ img_path: /miniGPT-Med/radVQA/VQA_RAD_Images
+ max_new_tokens: 300
+ batch_size: 10
+
+ mimic_cxr:
+ eval_file_path: /miniGPT-Med/json_files/mimic/MIMIC_test.json
+ img_path: /miniGPT-Med/mimic-cxr-dataset/image
+ max_new_tokens: 300
+ batch_size: 10
+
+ nlst:
+ eval_file_path: /miniGPT-Med/json_files/NLST/NLST_test.json
+ img_path: /miniGPT-Med/NLST/NLST_images
+ max_new_tokens: 100
+ batch_size: 10
+
+ detect_mimic:
+ eval_file_path: /miniGPT-Med/json_files/MIMIC-bbox/MIMIC-benchmarck.json
+ img_path: /miniGPT-Med/mimic-cxr-dataset/image
+ max_new_tokens: 100
+ batch_size: 10
+
+ SLAKE:
+ eval_file_path: /miniGPT-Med/json_files/SLAKE/grounding_test_SLAKE.json
+ img_path: /miniGPT-Med/SLAKE_images/imgs
+ max_new_tokens: 100
+ batch_size: 10
+
+
+run:
+ task: image_text_pretrain
+ name: minigptv2_evaluation
+ save_path: /miniGPT-Med/expermints
+
+
+
+
+
diff --git a/eval_configs/minigptv2_eval.yaml b/eval_configs/minigptv2_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6509af07ebd7fd057c0cc60d8521a134be43b8c9
--- /dev/null
+++ b/eval_configs/minigptv2_eval.yaml
@@ -0,0 +1,24 @@
+model:
+ arch: minigpt_v2
+ model_type: pretrain
+ max_txt_len: 500
+ end_sym: ""
+ low_resource: True
+ prompt_template: '[INST] {} [/INST]'
+ ckpt: "/ibex/project/c2106/RadGPT/MiniGPT-Med-github/miniGPT_Med_.pth"
+ lora_r: 64
+ lora_alpha: 16
+
+
+datasets:
+ cc_sbu_align:
+ vis_processor:
+ train:
+ name: "blip2_image_eval"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+run:
+ task: image_text_pretrain
diff --git a/eval_scripts/.DS_Store b/eval_scripts/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/eval_scripts/.DS_Store differ
diff --git a/eval_scripts/__pycache__/IoU.cpython-39.pyc b/eval_scripts/__pycache__/IoU.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d9c4f40c04bdc68652286efd5b3b0d4a00b1467
Binary files /dev/null and b/eval_scripts/__pycache__/IoU.cpython-39.pyc differ
diff --git a/eval_scripts/__pycache__/clean_json.cpython-39.pyc b/eval_scripts/__pycache__/clean_json.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b9bc37a2e0c36c81631f175e32adbd75d000dea
Binary files /dev/null and b/eval_scripts/__pycache__/clean_json.cpython-39.pyc differ
diff --git a/eval_scripts/__pycache__/metrics.cpython-39.pyc b/eval_scripts/__pycache__/metrics.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..625784ac75ad8dbc762d77a911d9c4e10db2682a
Binary files /dev/null and b/eval_scripts/__pycache__/metrics.cpython-39.pyc differ
diff --git a/eval_scripts/clean_json.py b/eval_scripts/clean_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..8297427ec36ef5f00c15da75d4fa7db24dd88e36
--- /dev/null
+++ b/eval_scripts/clean_json.py
@@ -0,0 +1,74 @@
+import json
+import re
+
+def clean_mimic_json(messy_json, cleaned_output):
+ with open(messy_json, 'r') as f:
+ messy_data = json.load(f)
+
+ clean_data = []
+ for image_id, captions in messy_data.items():
+ image_id_clean = image_id.split('.')[0]
+ caption_clean = ' '.join(captions)
+
+ clean_item = {
+ "image_id": image_id_clean,
+ "caption": caption_clean
+ }
+
+ clean_data.append(clean_item)
+
+ with open(cleaned_output, 'w') as outfile:
+ json.dump(clean_data, outfile, indent=2)
+
+
+def clean_vqa_json(messy_json, cleaned_output):
+ with open(messy_json, "r") as file:
+ messy_json = json.load(file)
+
+ organized_json = {}
+
+ for key, values in messy_json.items():
+ organized_json[key] = []
+ for value in values:
+ organized_json[key].append({
+ "question": value["question"],
+ "answer": value["answer"]
+ })
+
+ with open(cleaned_output, "w") as outfile:
+ json.dump(organized_json, outfile, indent=4)
+
+
+
+def clean_detection_json(messy_json, cleaned_output):
+
+ with open(messy_json, "r") as input_file:
+ input_json = json.load(input_file)
+
+ organized_data = []
+
+ for key, value in input_json.items():
+ if value and isinstance(value, list) and len(value) > 0:
+ caption = value[0]
+ objects_match = caption.split("")
+ if len(objects_match) == 2:
+ object_part = objects_match[1].split("
")[0].strip()
+ else:
+ object_part = ""
+
+ bbox_match = re.findall(r'<(\d+)>', caption)
+
+ if object_part and bbox_match and len(bbox_match) == 4:
+ key_part = key.split(".png")[0]
+ bbox_values = [float(val) for val in bbox_match]
+
+ organized_item = {
+ "key": key_part,
+ "objects": [object_part],
+ "bbox": [bbox_values],
+ }
+
+ organized_data.append(organized_item)
+
+ with open(cleaned_output, "w") as output_file:
+ json.dump(organized_data, output_file, indent=4)
\ No newline at end of file
diff --git a/eval_scripts/metrics.py b/eval_scripts/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..725639c0ef670f84824011c618ace3cd972fecef
--- /dev/null
+++ b/eval_scripts/metrics.py
@@ -0,0 +1,164 @@
+import sys
+sys.path.append('.')
+
+import json
+import pandas as pd
+import csv
+from sentence_transformers import SentenceTransformer, util
+from minigpt4.common.eval_utils import computeIoU
+
+# Load pre-trained BERT model
+model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
+
+
+# BERT similarity function will be utilized in the two following functions
+def compute_bert_similarity(prediction_caption, ground_truth_caption):
+ prediction_embedding = model.encode([prediction_caption])
+ ground_truth_embedding = model.encode([ground_truth_caption])
+ similarity = util.pytorch_cos_sim(prediction_embedding, ground_truth_embedding)[0][0].item()
+ return similarity
+
+
+def MIMIC_BERT_Sim(gt_pth, pred_pth, output_csv):
+ # Read the ground truth and prediction JSON files
+ with open(gt_pth, 'r') as f:
+ ground_truth_data = json.load(f)
+
+ with open(pred_pth, 'r') as f:
+ prediction_data = json.load(f)
+
+ # Create a list to store BERT similarity data
+ bert_similarity_data = []
+
+ # Initialize variables to calculate the average
+ total_similarity = 0
+ total_count = 0
+
+ # Iterate over each item in the prediction_data list
+ for item in prediction_data:
+ # Extract the image_id and corresponding prediction caption
+ image_id = item["image_id"]
+ prediction_caption = item["caption"]
+
+ # Search for the matching ground truth caption based on image_id
+ ground_truth_caption = None
+ for gt_item in ground_truth_data:
+ if gt_item["image_id"] == image_id:
+ ground_truth_caption = gt_item["caption"]
+ break
+
+ if ground_truth_caption is not None:
+ bert_similarity = compute_bert_similarity(prediction_caption, ground_truth_caption)
+ bert_similarity_data.append({"image_id": image_id, "BERT_score": bert_similarity})
+
+ total_similarity += bert_similarity
+ total_count += 1
+
+ average_similarity = total_similarity / total_count if total_count > 0 else 0
+
+ df = pd.DataFrame(bert_similarity_data)
+ df_sorted = df.sort_values(by="BERT_score", ascending=True)
+ df_sorted.to_csv(output_csv, index=False)
+
+ return average_similarity
+
+def VQA_BERT_Sim(gt_pth, pred_pth, output_csv):
+ # Load ground truth JSON file
+ with open(gt_pth, 'r') as file:
+ gt_data = json.load(file)
+
+ # Load prediction JSON file
+ with open(pred_pth, 'r') as file:
+ prediction_data = json.load(file)
+
+ gt_qa_pairs = {(entry['image_name'], entry['question']): entry['answer'] for entry in gt_data}
+
+ def convert_to_dict(data):
+ qa_dict = {}
+ for image_name, qa_list in data.items():
+ for qa in qa_list:
+ key = (image_name, qa['question'])
+ qa_dict[key] = qa['answer']
+ return qa_dict
+
+ pred_qa_dict = convert_to_dict(prediction_data)
+
+ # Compute BERT similarity and create a list of results
+ results = []
+
+ for key, gt_answer in gt_qa_pairs.items():
+ if key in pred_qa_dict:
+ pred_answer = pred_qa_dict[key]
+ gt_answer = str(gt_answer)
+ pred_answer = str(pred_answer)
+
+ # Compute BERT similarity
+ similarity_score = compute_bert_similarity(pred_answer, gt_answer)
+
+ # Append the result to the list
+ results.append({
+ "img_name": key[0],
+ "question": key[1],
+ "answer": pred_answer,
+ "BERT_score": similarity_score
+ })
+
+ average_similarity = sum(entry["BERT_score"] for entry in results) / len(results) if results else 0
+ df = pd.DataFrame(results)
+ df_sorted = df.sort_values(by="BERT_score", ascending=True)
+ df_sorted.to_csv(output_csv, index=False)
+ print(f"Average BERT similarity score: {average_similarity}")
+
+
+#################################
+##############IoU################
+#################################
+
+def preprocess_bbox(bbox, original_size, image_size):
+ x1 = int((bbox[0] / original_size) * image_size)
+ y1 = int((bbox[1] / original_size) * image_size)
+ x2 = int((bbox[2] / original_size) * image_size)
+ y2 = int((bbox[3] / original_size) * image_size)
+ return [x1, y1, x2, y2]
+
+def average_iou(gt_pth, pred_pth, original_size, image_size, dataset_name, csv_filename):
+ # Load ground truth
+ with open(gt_pth, 'r') as file:
+ ground_truth = json.load(file)
+
+ # Load predictions
+ with open(pred_pth, 'r') as file:
+ predictions = json.load(file)
+
+ iou_list = []
+
+ with open(csv_filename, 'w', newline='') as csvfile:
+ fieldnames = ['image_name', 'IoU']
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+ writer.writeheader()
+
+ for gt_item in ground_truth:
+ gt_key = gt_item['key']
+ gt_bboxes = gt_item['bbox']
+ original_size = gt_item['height']
+ gt_processed_bboxes = [preprocess_bbox(bbox, original_size, image_size) for bbox in gt_bboxes]
+
+ for pred_item in predictions:
+ pred_key = pred_item['key'].replace(".png", "")
+
+ if gt_key == pred_key:
+ pred_bboxes = pred_item['bbox']
+ try:
+ for gt_bbox in gt_processed_bboxes:
+ for pred_bbox in pred_bboxes:
+ iou = computeIoU(gt_bbox, pred_bbox)
+ iou_list.append(iou)
+ writer.writerow({'image_name': gt_key, 'IoU': iou})
+ print(gt_key)
+ print(iou)
+ except Exception as e:
+ print("gt_bbox: ", gt_bbox)
+ print("gt_bbox: ", pred_bboxes)
+
+ # average_iou = sum(iou_list) / len(iou_list)
+ # print(f"Average IoU for dataset {dataset_name}: {average_iou:.4f}")
\ No newline at end of file
diff --git a/eval_scripts/model_evaluation.py b/eval_scripts/model_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..85dd6dc1deb602b21aab2e6ab5006f3222e86158
--- /dev/null
+++ b/eval_scripts/model_evaluation.py
@@ -0,0 +1,274 @@
+'''
+use this command in terminal to run the evaluation script
+torchrun --master-port 8888 --nproc_per_node 1 eval_scripts/model_evaluation.py --cfg-path eval_configs/minigptv2_benchmark_evaluation.yaml --dataset
+
+
+'''
+
+import sys
+sys.path.append('.')
+import os
+import re
+import json
+import argparse
+from collections import defaultdict
+import random
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from minigpt4.common.config import Config
+from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
+from minigpt4.conversation.conversation import CONV_VISION_minigptv2
+
+from minigpt4.datasets.datasets.mimic_cxr_dataset import evalMIMICDataset, evalDetectMimicDataset
+from minigpt4.datasets.datasets.radvqa_dataset import evalRadVQADataset
+from minigpt4.datasets.datasets.nlst_dataset import eval_NLST_Dataset
+from minigpt4.datasets.datasets.rsna_dataset import evalRSNADataset
+from minigpt4.datasets.datasets.SLAKE_dataset import evalSLAKEDataset
+#import cleaning classes
+from eval_scripts.clean_json import clean_mimic_json, clean_vqa_json, clean_detection_json
+from eval_scripts.metrics import MIMIC_BERT_Sim, VQA_BERT_Sim, average_iou
+
+def list_of_str(arg):
+ return list(map(str, arg.split(',')))
+
+parser = eval_parser()
+parser.add_argument("--dataset", type=list_of_str, help="dataset to evaluate")
+
+args = parser.parse_args()
+
+cfg = Config(args)
+
+
+model, vis_processor = init_model(args)
+model.eval()
+CONV_VISION = CONV_VISION_minigptv2
+conv_temp = CONV_VISION.copy()
+conv_temp.system = ""
+model.eval()
+save_path = cfg.run_cfg.save_path
+
+def process_mimic_dataset():
+ eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
+ img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
+ batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
+ max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
+
+ with open((eval_file_path), 'r') as f:
+ mimic = json.load(f)
+
+ data = evalMIMICDataset(mimic, vis_processor, img_path)
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+ minigpt4_predict = defaultdict(list)
+
+ for images, questions, img_ids in tqdm(eval_dataloader):
+ texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+ for answer, img_id, question in zip(answers, img_ids, questions):
+ minigpt4_predict[img_id].append(answer)
+
+ file_save_path = os.path.join(save_path,"MIMIC_inference_results_stage3.json")
+ with open(file_save_path,'w') as f:
+ json.dump(minigpt4_predict, f)
+ clean_mimic_json(file_save_path, file_save_path)
+
+ # csv file path to save the BERT results per each case
+ output_csv_path = '/miniGPT-Med/metric_results/bert_similarity_scores.csv'
+
+ # in MIMIC_BERT_Sim add the path of the ground_truth then the path of the inference result
+ average_similarity = MIMIC_BERT_Sim(eval_file_path, file_save_path, output_csv_path)
+ #print the average BERT_Sim
+ print("Average BERT Similarity:", average_similarity)
+
+def process_vqa_dataset():
+ eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
+ img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
+ batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
+ max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
+
+ with open((eval_file_path), 'r') as f:
+ radVQA = json.load(f)
+
+ data = evalRadVQADataset(radVQA, vis_processor, img_path)
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+ minigpt4_predict = defaultdict(list)
+
+ for images, questions, img_ids in tqdm(eval_dataloader):
+ texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+ for answer, img_id, question in zip(answers, img_ids, questions):
+ minigpt4_predict[img_id].append({"key":img_ids,"question": question.replace("[vqa]", "").strip() , "answer": answer})
+
+ file_save_path = os.path.join(save_path,"radVQA_inference_results.json")
+ output_csv_path = '/miniGPT-Med/BERT_Sim_results/vqa_bert_similarity_scores.csv'
+
+ with open(file_save_path,'w') as f:
+ json.dump(minigpt4_predict, f)
+
+ clean_vqa_json(file_save_path, file_save_path)
+ VQA_BERT_Sim(eval_file_path, file_save_path, output_csv_path)
+
+def process_nlst_dataset():
+ eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
+ img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
+ batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
+ max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
+
+ with open((eval_file_path), 'r') as f:
+ nlst = json.load(f)
+
+ data = eval_NLST_Dataset(nlst, vis_processor, img_path)
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+ minigpt4_predict = defaultdict(list)
+ resamples = []
+
+ for images, questions, img_ids in tqdm(eval_dataloader):
+
+ texts = prepare_texts(questions, conv_temp)
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+
+ for answer, img_id, question in zip(answers, img_ids, questions):
+
+ # answer = answer.replace("","").replace(" ","").strip()
+ pattern = r'\{<\d{1,2}><\d{1,2}><\d{1,2}><\d{1,2}>\}'
+ minigpt4_predict[img_id].append(answer)
+
+ file_save_path = os.path.join(save_path,"NLST_inference_result.json")
+ with open(file_save_path,'w') as f:
+ json.dump(minigpt4_predict, f)
+
+ csv_pth = os.path.join(save_path,"NLST_IoU_results.csv")
+ clean_detection_json(file_save_path,file_save_path)
+ average_iou(eval_file_path, file_save_path, 512, 100, "NLST", csv_pth)
+
+
+
+def process_rsna_dataset():
+ eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
+ print(eval_file_path)
+ img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
+ batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
+ max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
+ print("----config----")
+ with open((eval_file_path), 'r') as f:
+ nlst = json.load(f)
+
+ data = evalRSNADataset(nlst, vis_processor, img_path)
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+ minigpt4_predict = defaultdict(list)
+ resamples = []
+
+ for images, questions, img_ids in tqdm(eval_dataloader):
+ texts = prepare_texts(questions, conv_temp)
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+
+ for answer, img_id, question in zip(answers, img_ids, questions):
+
+ # answer = answer.replace("","").replace(" ","").strip()
+ pattern = r'\{<\d{1,2}><\d{1,2}><\d{1,2}><\d{1,2}>\}'
+ minigpt4_predict[img_id].append(answer)
+ print(img_id)
+ print(answer)
+
+ file_save_path = os.path.join(save_path,"RSNA_inference_result.json")
+ with open(file_save_path,'w') as f:
+ json.dump(minigpt4_predict, f)
+
+ csv_pth = os.path.join(save_path,"RSNA_IoU_results.csv")
+ clean_detection_json(file_save_path,file_save_path)
+ average_iou(eval_file_path, file_save_path, 1024, 100, "rsna", csv_pth)
+
+
+def process_detect_mimic():
+ eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
+ img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
+ batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
+ max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
+
+ with open((eval_file_path), 'r') as f:
+ nlst = json.load(f)
+
+ data = evalDetectMimicDataset(nlst, vis_processor, img_path)
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+ minigpt4_predict = defaultdict(list)
+ resamples = []
+
+ for images, questions, img_ids in tqdm(eval_dataloader):
+
+ texts = prepare_texts(questions, conv_temp)
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+
+ for answer, img_id, question in zip(answers, img_ids, questions):
+ pattern = r'\{<\d{1,2}><\d{1,2}><\d{1,2}><\d{1,2}>\}'
+ minigpt4_predict[img_id].append(answer)
+
+ file_save_path = os.path.join(save_path,"Detect_MIMIC_inference_result.json")
+ with open(file_save_path,'w') as f:
+ json.dump(minigpt4_predict, f)
+
+
+ csv_pth = os.path.join(save_path,"MIMIC_IoU_results.csv")
+ clean_detection_json(file_save_path,file_save_path)
+ average_iou(eval_file_path, file_save_path, "to be specified soon", 100, "MIMIC", csv_pth)
+
+
+
+def process_SLAKE_dataset():
+ eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"]
+ img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"]
+ batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"]
+ max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"]
+
+ with open((eval_file_path), 'r') as f:
+ SLAKE = json.load(f)
+
+ data = evalSLAKEDataset(SLAKE, vis_processor, img_path)
+ eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+ minigpt4_predict = defaultdict(list)
+ resamples = []
+
+ for images, questions, img_ids in tqdm(eval_dataloader):
+
+ texts = prepare_texts(questions, conv_temp)
+ answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+
+ for answer, img_id, question in zip(answers, img_ids, questions):
+
+ # answer = answer.replace("","").replace(" ","").strip()
+ pattern = r'\{<\d{1,2}><\d{1,2}><\d{1,2}><\d{1,2}>\}'
+ minigpt4_predict[img_id].append(answer)
+
+ file_save_path = os.path.join(save_path,"SLAKE_inference_result.json")
+ with open(file_save_path,'w') as f:
+ json.dump(minigpt4_predict, f)
+
+ csv_pth = os.path.join(save_path,"SLAKE_IoU_results.csv")
+ clean_detection_json(file_save_path,file_save_path)
+ average_iou(eval_file_path, file_save_path, 100, 100, "SLAKE", csv_pth)
+
+
+
+############################################################################
+for dataset in args.dataset:
+ if dataset == 'mimic_cxr':
+ process_mimic_dataset()
+
+ elif dataset == 'radvqa':
+ process_vqa_dataset()
+
+ elif dataset == 'nlst':
+ process_nlst_dataset()
+
+ elif dataset == 'rsna':
+ process_rsna_dataset()
+
+ elif dataset == 'detect_mimic':
+ process_detect_mimic()
+
+ elif dataset == 'SLAKE':
+ process_SLAKE_dataset()
+
+ else:
+ print(f"Dataset '{dataset}' is not supported.")
\ No newline at end of file
diff --git a/miniGPTV2.yml b/miniGPTV2.yml
new file mode 100644
index 0000000000000000000000000000000000000000..67aa4c59056b55a60bd4224e5ad9e4c8166d7374
--- /dev/null
+++ b/miniGPTV2.yml
@@ -0,0 +1,35 @@
+name: GPTv2
+channels:
+ - pytorch
+ - defaults
+ - anaconda
+dependencies:
+ - python=3.9
+ - cudatoolkit
+ - pip
+ - pip:
+ - torch==2.0.0
+ - torchaudio
+ - torchvision
+ - huggingface-hub==0.18.0
+ - matplotlib==3.7.0
+ - psutil==5.9.4
+ - iopath
+ - pyyaml==6.0
+ - regex==2022.10.31
+ - tokenizers==0.13.2
+ - tqdm==4.64.1
+ - transformers==4.30.0
+ - timm==0.6.13
+ - webdataset==0.2.48
+ - omegaconf==2.3.0
+ - opencv-python==4.7.0.72
+ - decord==0.6.0
+ - peft==0.2.0
+ - sentence-transformers
+ - gradio==3.47.1
+ - accelerate==0.20.3
+ - bitsandbytes==0.37.0
+ - scikit-image
+ - visual-genome
+ - wandb
\ No newline at end of file
diff --git a/miniGPT_Med_.pth b/miniGPT_Med_.pth
new file mode 100644
index 0000000000000000000000000000000000000000..584244d4a9728dd246ae05022a3ef1a172561b27
--- /dev/null
+++ b/miniGPT_Med_.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca2d7fc37dc5330cdae927c8a3ff649c5919c726eccb05cae921fb997028b08e
+size 679780138
diff --git a/minigpt4/.DS_Store b/minigpt4/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..dfe1a5f5a91916a5276353e98f78b01c51ae1898
Binary files /dev/null and b/minigpt4/.DS_Store differ
diff --git a/minigpt4/__init__.py b/minigpt4/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb31f42f9107a0b748b878deb1c5768019d62b32
--- /dev/null
+++ b/minigpt4/__init__.py
@@ -0,0 +1,31 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import sys
+
+from omegaconf import OmegaConf
+
+from minigpt4.common.registry import registry
+
+from minigpt4.datasets.builders import *
+from minigpt4.models import *
+from minigpt4.processors import *
+from minigpt4.tasks import *
+
+
+root_dir = os.path.dirname(os.path.abspath(__file__))
+default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
+
+registry.register_path("library_root", root_dir)
+repo_root = os.path.join(root_dir, "..")
+registry.register_path("repo_root", repo_root)
+cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+registry.register_path("cache_root", cache_root)
+
+registry.register("MAX_INT", sys.maxsize)
+registry.register("SPLIT_NAMES", ["train", "val", "test"])
diff --git a/minigpt4/__pycache__/__init__.cpython-310.pyc b/minigpt4/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7cba737a63ffb6faa7d9cf789e0867e50b466121
Binary files /dev/null and b/minigpt4/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/__pycache__/__init__.cpython-39.pyc b/minigpt4/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a4e164b815f467157eb60b7ccee95b9fe103ad6
Binary files /dev/null and b/minigpt4/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/common/.DS_Store b/minigpt4/common/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..671c42f1235d58027381a65f37f93a61422e43eb
Binary files /dev/null and b/minigpt4/common/.DS_Store differ
diff --git a/minigpt4/common/__init__.py b/minigpt4/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/minigpt4/common/__pycache__/__init__.cpython-310.pyc b/minigpt4/common/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a9ae6f76d53fafc9c0069c6f1e66b146ee805a62
Binary files /dev/null and b/minigpt4/common/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/common/__pycache__/__init__.cpython-39.pyc b/minigpt4/common/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91d48ed7c1455e2aeebad97442cfa276fa1f86f4
Binary files /dev/null and b/minigpt4/common/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/common/__pycache__/config.cpython-310.pyc b/minigpt4/common/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f96e46b19bfe3a8db806047d5934cf427d098633
Binary files /dev/null and b/minigpt4/common/__pycache__/config.cpython-310.pyc differ
diff --git a/minigpt4/common/__pycache__/config.cpython-39.pyc b/minigpt4/common/__pycache__/config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..056cf780306ce7642a3c51b5426d892d0890a5b4
Binary files /dev/null and b/minigpt4/common/__pycache__/config.cpython-39.pyc differ
diff --git a/minigpt4/common/__pycache__/dist_utils.cpython-310.pyc b/minigpt4/common/__pycache__/dist_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..544ab4cf0c7a1d5263f72da0e683e485c2f6d451
Binary files /dev/null and b/minigpt4/common/__pycache__/dist_utils.cpython-310.pyc differ
diff --git a/minigpt4/common/__pycache__/dist_utils.cpython-39.pyc b/minigpt4/common/__pycache__/dist_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0038f3e782a46c196ad9ded4717a0c366953b3dc
Binary files /dev/null and b/minigpt4/common/__pycache__/dist_utils.cpython-39.pyc differ
diff --git a/minigpt4/common/__pycache__/eval_utils.cpython-39.pyc b/minigpt4/common/__pycache__/eval_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bdcba1970b745df6cc13d2469a55f8ce8aad5c54
Binary files /dev/null and b/minigpt4/common/__pycache__/eval_utils.cpython-39.pyc differ
diff --git a/minigpt4/common/__pycache__/logger.cpython-310.pyc b/minigpt4/common/__pycache__/logger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7024ff71ea5044000317a6d46b08553690a299a
Binary files /dev/null and b/minigpt4/common/__pycache__/logger.cpython-310.pyc differ
diff --git a/minigpt4/common/__pycache__/logger.cpython-39.pyc b/minigpt4/common/__pycache__/logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f354c6f1565d127850c89de7178a86711469abe
Binary files /dev/null and b/minigpt4/common/__pycache__/logger.cpython-39.pyc differ
diff --git a/minigpt4/common/__pycache__/optims.cpython-39.pyc b/minigpt4/common/__pycache__/optims.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddf59b3be1419d789060d7cfefe470a1b7026147
Binary files /dev/null and b/minigpt4/common/__pycache__/optims.cpython-39.pyc differ
diff --git a/minigpt4/common/__pycache__/registry.cpython-310.pyc b/minigpt4/common/__pycache__/registry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61a5abbf0c8f9f5f48ea510dba9096f185d087ff
Binary files /dev/null and b/minigpt4/common/__pycache__/registry.cpython-310.pyc differ
diff --git a/minigpt4/common/__pycache__/registry.cpython-39.pyc b/minigpt4/common/__pycache__/registry.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d65a6bc47a95d275f9dac3d5902ebd93f2734223
Binary files /dev/null and b/minigpt4/common/__pycache__/registry.cpython-39.pyc differ
diff --git a/minigpt4/common/__pycache__/utils.cpython-310.pyc b/minigpt4/common/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b972d85b7da290579486669c4900e151c6eb8f0
Binary files /dev/null and b/minigpt4/common/__pycache__/utils.cpython-310.pyc differ
diff --git a/minigpt4/common/__pycache__/utils.cpython-39.pyc b/minigpt4/common/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24c33f9e6ce08eafe4605139cd1e12c318f0e465
Binary files /dev/null and b/minigpt4/common/__pycache__/utils.cpython-39.pyc differ
diff --git a/minigpt4/common/config.py b/minigpt4/common/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d3278bfe9caf59bddecd102d42a79ed8b71e55
--- /dev/null
+++ b/minigpt4/common/config.py
@@ -0,0 +1,496 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import json
+from typing import Dict
+
+from omegaconf import OmegaConf
+from minigpt4.common.registry import registry
+
+
+class Config:
+ def __init__(self, args):
+ self.config = {}
+
+ self.args = args
+
+ # Register the config and configuration for setup
+ registry.register("configuration", self)
+
+ user_config = self._build_opt_list(self.args.options)
+
+ config = OmegaConf.load(self.args.cfg_path)
+
+ runner_config = self.build_runner_config(config)
+ model_config = self.build_model_config(config, **user_config)
+ dataset_config = self.build_dataset_config(config)
+ evaluation_dataset_config = self.build_evaluation_dataset_config(config)
+
+ # Validate the user-provided runner configuration
+ # model and dataset configuration are supposed to be validated by the respective classes
+ # [TODO] validate the model/dataset configuration
+ # self._validate_runner_config(runner_config)
+
+ # Override the default configuration with user options.
+ self.config = OmegaConf.merge(
+ runner_config, model_config, dataset_config,evaluation_dataset_config, user_config
+ )
+
+ def _validate_runner_config(self, runner_config):
+ """
+ This method validates the configuration, such that
+ 1) all the user specified options are valid;
+ 2) no type mismatches between the user specified options and the config.
+ """
+ runner_config_validator = create_runner_config_validator()
+ runner_config_validator.validate(runner_config)
+
+ def _build_opt_list(self, opts):
+ opts_dot_list = self._convert_to_dot_list(opts)
+ return OmegaConf.from_dotlist(opts_dot_list)
+
+ @staticmethod
+ def build_model_config(config, **kwargs):
+ model = config.get("model", None)
+ assert model is not None, "Missing model configuration file."
+
+ model_cls = registry.get_model_class(model.arch)
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
+
+ model_type = kwargs.get("model.model_type", None)
+ if not model_type:
+ model_type = model.get("model_type", None)
+ # else use the model type selected by user.
+
+ assert model_type is not None, "Missing model_type."
+
+ model_config_path = model_cls.default_config_path(model_type=model_type)
+
+ model_config = OmegaConf.create()
+ # hierarchy override, customized config > default config
+ model_config = OmegaConf.merge(
+ model_config,
+ OmegaConf.load(model_config_path),
+ {"model": config["model"]},
+ )
+
+ return model_config
+
+ @staticmethod
+ def build_runner_config(config):
+ return {"run": config.run}
+
+ @staticmethod
+ def build_dataset_config(config):
+ datasets = config.get("datasets", None)
+ if datasets is None:
+ raise KeyError(
+ "Expecting 'datasets' as the root key for dataset configuration."
+ )
+
+ dataset_config = OmegaConf.create()
+
+ for dataset_name in datasets:
+ builder_cls = registry.get_builder_class(dataset_name)
+
+ dataset_config_type = datasets[dataset_name].get("type", "default")
+ dataset_config_path = builder_cls.default_config_path(
+ type=dataset_config_type
+ )
+
+ # hierarchy override, customized config > default config
+ dataset_config = OmegaConf.merge(
+ dataset_config,
+ OmegaConf.load(dataset_config_path),
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
+ )
+
+ return dataset_config
+
+
+ @staticmethod
+ def build_evaluation_dataset_config(config):
+ datasets = config.get("evaluation_datasets", None)
+ # if datasets is None:
+ # raise KeyError(
+ # "Expecting 'datasets' as the root key for dataset configuration."
+ # )
+
+ dataset_config = OmegaConf.create()
+
+ if datasets is not None:
+ for dataset_name in datasets:
+ builder_cls = registry.get_builder_class(dataset_name)
+
+ # hierarchy override, customized config > default config
+ dataset_config = OmegaConf.merge(
+ dataset_config,
+ {"evaluation_datasets": {dataset_name: config["evaluation_datasets"][dataset_name]}},
+ )
+
+ return dataset_config
+
+ def _convert_to_dot_list(self, opts):
+ if opts is None:
+ opts = []
+
+ if len(opts) == 0:
+ return opts
+
+ has_equal = opts[0].find("=") != -1
+
+ if has_equal:
+ return opts
+
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
+
+ def get_config(self):
+ return self.config
+
+ @property
+ def run_cfg(self):
+ return self.config.run
+
+ @property
+ def datasets_cfg(self):
+ return self.config.datasets
+
+ @property
+ def evaluation_datasets_cfg(self):
+ return self.config.evaluation_datasets
+
+ @property
+ def model_cfg(self):
+ return self.config.model
+
+ def pretty_print(self):
+ logging.info("\n===== Running Parameters =====")
+ logging.info(self._convert_node_to_json(self.config.run))
+
+ logging.info("\n====== Dataset Attributes ======")
+ datasets = self.config.datasets
+
+ for dataset in datasets:
+ if dataset in self.config.datasets:
+ logging.info(f"\n======== {dataset} =======")
+ dataset_config = self.config.datasets[dataset]
+ logging.info(self._convert_node_to_json(dataset_config))
+ else:
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
+
+ logging.info(f"\n====== Model Attributes ======")
+ logging.info(self._convert_node_to_json(self.config.model))
+
+ def _convert_node_to_json(self, node):
+ container = OmegaConf.to_container(node, resolve=True)
+ return json.dumps(container, indent=4, sort_keys=True)
+
+ def to_dict(self):
+ return OmegaConf.to_container(self.config)
+
+
+def node_to_dict(node):
+ return OmegaConf.to_container(node)
+
+
+class ConfigValidator:
+ """
+ This is a preliminary implementation to centralize and validate the configuration.
+ May be altered in the future.
+
+ A helper class to validate configurations from yaml file.
+
+ This serves the following purposes:
+ 1. Ensure all the options in the yaml are defined, raise error if not.
+ 2. when type mismatches are found, the validator will raise an error.
+ 3. a central place to store and display helpful messages for supported configurations.
+
+ """
+
+ class _Argument:
+ def __init__(self, name, choices=None, type=None, help=None):
+ self.name = name
+ self.val = None
+ self.choices = choices
+ self.type = type
+ self.help = help
+
+ def __str__(self):
+ s = f"{self.name}={self.val}"
+ if self.type is not None:
+ s += f", ({self.type})"
+ if self.choices is not None:
+ s += f", choices: {self.choices}"
+ if self.help is not None:
+ s += f", ({self.help})"
+ return s
+
+ def __init__(self, description):
+ self.description = description
+
+ self.arguments = dict()
+
+ self.parsed_args = None
+
+ def __getitem__(self, key):
+ assert self.parsed_args is not None, "No arguments parsed yet."
+
+ return self.parsed_args[key]
+
+ def __str__(self) -> str:
+ return self.format_help()
+
+ def add_argument(self, *args, **kwargs):
+ """
+ Assume the first argument is the name of the argument.
+ """
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
+
+ def validate(self, config=None):
+ """
+ Convert yaml config (dict-like) to list, required by argparse.
+ """
+ for k, v in config.items():
+ assert (
+ k in self.arguments
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
+
+ if self.arguments[k].type is not None:
+ try:
+ self.arguments[k].val = self.arguments[k].type(v)
+ except ValueError:
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
+
+ if self.arguments[k].choices is not None:
+ assert (
+ v in self.arguments[k].choices
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
+
+ return config
+
+ def format_arguments(self):
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
+
+ def format_help(self):
+ # description + key-value pair string for each argument
+ help_msg = str(self.description)
+ return help_msg + ", available arguments: " + self.format_arguments()
+
+ def print_help(self):
+ # display help message
+ print(self.format_help())
+
+
+def create_runner_config_validator():
+ validator = ConfigValidator(description="Runner configurations")
+
+ validator.add_argument(
+ "runner",
+ type=str,
+ choices=["runner_base", "runner_iter"],
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
+ runner runs based on iters. Default: runner_base""",
+ )
+ # add argumetns for training dataset ratios
+ validator.add_argument(
+ "train_dataset_ratios",
+ type=Dict[str, float],
+ help="""Ratios of training dataset. This is used in iteration-based runner.
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
+ Default: None""",
+ )
+ validator.add_argument(
+ "max_iters",
+ type=float,
+ help="Maximum number of iterations to run.",
+ )
+ validator.add_argument(
+ "max_epoch",
+ type=int,
+ help="Maximum number of epochs to run.",
+ )
+ # add arguments for iters_per_inner_epoch
+ validator.add_argument(
+ "iters_per_inner_epoch",
+ type=float,
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
+ )
+ lr_scheds_choices = registry.list_lr_schedulers()
+ validator.add_argument(
+ "lr_sched",
+ type=str,
+ choices=lr_scheds_choices,
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
+ )
+ task_choices = registry.list_tasks()
+ validator.add_argument(
+ "task",
+ type=str,
+ choices=task_choices,
+ help="Task to use, from {}".format(task_choices),
+ )
+ # add arguments for init_lr
+ validator.add_argument(
+ "init_lr",
+ type=float,
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
+ )
+ # add arguments for min_lr
+ validator.add_argument(
+ "min_lr",
+ type=float,
+ help="Minimum learning rate (after decay).",
+ )
+ # add arguments for warmup_lr
+ validator.add_argument(
+ "warmup_lr",
+ type=float,
+ help="Starting learning rate for warmup.",
+ )
+ # add arguments for learning rate decay rate
+ validator.add_argument(
+ "lr_decay_rate",
+ type=float,
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
+ )
+ # add arguments for weight decay
+ validator.add_argument(
+ "weight_decay",
+ type=float,
+ help="Weight decay rate.",
+ )
+ # add arguments for training batch size
+ validator.add_argument(
+ "batch_size_train",
+ type=int,
+ help="Training batch size.",
+ )
+ # add arguments for evaluation batch size
+ validator.add_argument(
+ "batch_size_eval",
+ type=int,
+ help="Evaluation batch size, including validation and testing.",
+ )
+ # add arguments for number of workers for data loading
+ validator.add_argument(
+ "num_workers",
+ help="Number of workers for data loading.",
+ )
+ # add arguments for warm up steps
+ validator.add_argument(
+ "warmup_steps",
+ type=int,
+ help="Number of warmup steps. Required if a warmup schedule is used.",
+ )
+ # add arguments for random seed
+ validator.add_argument(
+ "seed",
+ type=int,
+ help="Random seed.",
+ )
+ # add arguments for output directory
+ validator.add_argument(
+ "output_dir",
+ type=str,
+ help="Output directory to save checkpoints and logs.",
+ )
+ # add arguments for whether only use evaluation
+ validator.add_argument(
+ "evaluate",
+ help="Whether to only evaluate the model. If true, training will not be performed.",
+ )
+ # add arguments for splits used for training, e.g. ["train", "val"]
+ validator.add_argument(
+ "train_splits",
+ type=list,
+ help="Splits to use for training.",
+ )
+ # add arguments for splits used for validation, e.g. ["val"]
+ validator.add_argument(
+ "valid_splits",
+ type=list,
+ help="Splits to use for validation. If not provided, will skip the validation.",
+ )
+ # add arguments for splits used for testing, e.g. ["test"]
+ validator.add_argument(
+ "test_splits",
+ type=list,
+ help="Splits to use for testing. If not provided, will skip the testing.",
+ )
+ # add arguments for accumulating gradient for iterations
+ validator.add_argument(
+ "accum_grad_iters",
+ type=int,
+ help="Number of iterations to accumulate gradient for.",
+ )
+
+ # ====== distributed training ======
+ validator.add_argument(
+ "device",
+ type=str,
+ choices=["cpu", "cuda"],
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
+ )
+ validator.add_argument(
+ "world_size",
+ type=int,
+ help="Number of processes participating in the job.",
+ )
+ validator.add_argument("dist_url", type=str)
+ validator.add_argument("distributed", type=bool)
+ # add arguments to opt using distributed sampler during evaluation or not
+ validator.add_argument(
+ "use_dist_eval_sampler",
+ type=bool,
+ help="Whether to use distributed sampler during evaluation or not.",
+ )
+
+ # ====== task specific ======
+ # generation task specific arguments
+ # add arguments for maximal length of text output
+ validator.add_argument(
+ "max_len",
+ type=int,
+ help="Maximal length of text output.",
+ )
+ # add arguments for minimal length of text output
+ validator.add_argument(
+ "min_len",
+ type=int,
+ help="Minimal length of text output.",
+ )
+ # add arguments number of beams
+ validator.add_argument(
+ "num_beams",
+ type=int,
+ help="Number of beams used for beam search.",
+ )
+
+ # vqa task specific arguments
+ # add arguments for number of answer candidates
+ validator.add_argument(
+ "num_ans_candidates",
+ type=int,
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
+ )
+ # add arguments for inference method
+ validator.add_argument(
+ "inference_method",
+ type=str,
+ choices=["genearte", "rank"],
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
+ )
+
+ # ====== model specific ======
+ validator.add_argument(
+ "k_test",
+ type=int,
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
+ )
+
+ return validator
diff --git a/minigpt4/common/dist_utils.py b/minigpt4/common/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6fc1b904dccccbffbd96326b1506f8ff3ca19c1
--- /dev/null
+++ b/minigpt4/common/dist_utils.py
@@ -0,0 +1,140 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import timm.models.hub as timm_hub
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def init_distributed_mode(args):
+ if args.distributed is False:
+ print("Not using distributed mode")
+ return
+ elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ elif "SLURM_PROCID" in os.environ:
+ args.rank = int(os.environ["SLURM_PROCID"])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print("Not using distributed mode")
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}, world {}): {}".format(
+ args.rank, args.world_size, args.dist_url
+ ),
+ flush=True,
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ timeout=datetime.timedelta(
+ days=365
+ ), # allow auto-downloading and de-compressing
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+def get_dist_info():
+ if torch.__version__ < "1.0":
+ initialized = dist._initialized
+ else:
+ initialized = dist.is_initialized()
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else: # non-distributed training
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def main_process(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+ """
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+ """
+
+ def get_cached_file_path():
+ # a hack to sync the file path across processes
+ parts = torch.hub.urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+ return cached_file
+
+ if is_main_process():
+ timm_hub.download_cached_file(url, check_hash, progress)
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ return get_cached_file_path()
diff --git a/minigpt4/common/eval_utils.py b/minigpt4/common/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3087d2a820a4e2a0d16b9bbfeeaacb9c474653af
--- /dev/null
+++ b/minigpt4/common/eval_utils.py
@@ -0,0 +1,76 @@
+import argparse
+import numpy as np
+from nltk.translate.bleu_score import sentence_bleu
+
+from minigpt4.common.registry import registry
+from minigpt4.common.config import Config
+
+# imports modules for registration
+from minigpt4.datasets.builders import *
+from minigpt4.models import *
+from minigpt4.processors import *
+from minigpt4.runners import *
+from minigpt4.tasks import *
+
+
+
+def eval_parser():
+ parser = argparse.ArgumentParser(description="Demo")
+ parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+ parser.add_argument("--name", type=str, default='A2', help="evaluation name")
+ parser.add_argument("--ckpt", type=str, help="path to configuration file.")
+ parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.")
+ parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens")
+ parser.add_argument("--batch_size", type=int, default=32)
+ parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model")
+ parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha")
+ parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+ )
+ return parser
+
+
+def prepare_texts(texts, conv_temp):
+ convs = [conv_temp.copy() for _ in range(len(texts))]
+ [conv.append_message(
+ conv.roles[0], '
{}'.format(text)) for conv, text in zip(convs, texts)]
+ [conv.append_message(conv.roles[1], None) for conv in convs]
+ texts = [conv.get_prompt() for conv in convs]
+ return texts
+
+
+def init_model(args):
+ print('Initialization Model')
+ cfg = Config(args)
+ # cfg.model_cfg.ckpt = args.ckpt
+ # cfg.model_cfg.lora_r = args.lora_r
+ # cfg.model_cfg.lora_alpha = args.lora_alpha
+
+ model_config = cfg.model_cfg
+ model_cls = registry.get_model_class(model_config.arch)
+ model = model_cls.from_config(model_config).to('cuda:0')
+
+# import pudb; pudb.set_trace()
+ key = list(cfg.datasets_cfg.keys())[0]
+ vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+ print('Initialization Finished')
+ return model, vis_processor
+
+def computeIoU(bbox1, bbox2):
+ x1, y1, x2, y2 = bbox1
+ x3, y3, x4, y4 = bbox2
+ intersection_x1 = max(x1, x3)
+ intersection_y1 = max(y1, y3)
+ intersection_x2 = min(x2, x4)
+ intersection_y2 = min(y2, y4)
+ intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
+ bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
+ bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
+ union_area = bbox1_area + bbox2_area - intersection_area
+ iou = intersection_area / union_area
+ return iou
diff --git a/minigpt4/common/gradcam.py b/minigpt4/common/gradcam.py
new file mode 100644
index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0
--- /dev/null
+++ b/minigpt4/common/gradcam.py
@@ -0,0 +1,24 @@
+import numpy as np
+from matplotlib import pyplot as plt
+from scipy.ndimage import filters
+from skimage import transform as skimage_transform
+
+
+def getAttMap(img, attMap, blur=True, overlap=True):
+ attMap -= attMap.min()
+ if attMap.max() > 0:
+ attMap /= attMap.max()
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
+ if blur:
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
+ attMap -= attMap.min()
+ attMap /= attMap.max()
+ cmap = plt.get_cmap("jet")
+ attMapV = cmap(attMap)
+ attMapV = np.delete(attMapV, 3, 2)
+ if overlap:
+ attMap = (
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+ )
+ return attMap
diff --git a/minigpt4/common/logger.py b/minigpt4/common/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5a727213c6478606a154172830cdc43aae6f5a
--- /dev/null
+++ b/minigpt4/common/logger.py
@@ -0,0 +1,195 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import time
+from collections import defaultdict, deque
+
+import torch
+import torch.distributed as dist
+
+from minigpt4.common import dist_utils
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not dist_utils.is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ log_msg = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_msg.append("max mem: {memory:.0f}")
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len(iterable)
+ )
+ )
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def setup_logger():
+ logging.basicConfig(
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
+ format="%(asctime)s [%(levelname)s] %(message)s",
+ handlers=[logging.StreamHandler()],
+ )
diff --git a/minigpt4/common/optims.py b/minigpt4/common/optims.py
new file mode 100644
index 0000000000000000000000000000000000000000..58327f723d445633ce7d1b5c3cc799b041319a97
--- /dev/null
+++ b/minigpt4/common/optims.py
@@ -0,0 +1,119 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import math
+
+from minigpt4.common.registry import registry
+
+
+@registry.register_lr_scheduler("linear_warmup_step_lr")
+class LinearWarmupStepLRScheduler:
+ def __init__(
+ self,
+ optimizer,
+ max_epoch,
+ min_lr,
+ init_lr,
+ decay_rate=1,
+ warmup_start_lr=-1,
+ warmup_steps=0,
+ **kwargs
+ ):
+ self.optimizer = optimizer
+
+ self.max_epoch = max_epoch
+ self.min_lr = min_lr
+
+ self.decay_rate = decay_rate
+
+ self.init_lr = init_lr
+ self.warmup_steps = warmup_steps
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+ def step(self, cur_epoch, cur_step):
+ if cur_epoch == 0:
+ warmup_lr_schedule(
+ step=cur_step,
+ optimizer=self.optimizer,
+ max_step=self.warmup_steps,
+ init_lr=self.warmup_start_lr,
+ max_lr=self.init_lr,
+ )
+ else:
+ step_lr_schedule(
+ epoch=cur_epoch,
+ optimizer=self.optimizer,
+ init_lr=self.init_lr,
+ min_lr=self.min_lr,
+ decay_rate=self.decay_rate,
+ )
+
+
+@registry.register_lr_scheduler("linear_warmup_cosine_lr")
+class LinearWarmupCosineLRScheduler:
+ def __init__(
+ self,
+ optimizer,
+ max_epoch,
+ iters_per_epoch,
+ min_lr,
+ init_lr,
+ warmup_steps=0,
+ warmup_start_lr=-1,
+ **kwargs
+ ):
+ self.optimizer = optimizer
+
+ self.max_epoch = max_epoch
+ self.iters_per_epoch = iters_per_epoch
+ self.min_lr = min_lr
+
+ self.init_lr = init_lr
+ self.warmup_steps = warmup_steps
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+ def step(self, cur_epoch, cur_step):
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
+ if total_cur_step < self.warmup_steps:
+ warmup_lr_schedule(
+ step=cur_step,
+ optimizer=self.optimizer,
+ max_step=self.warmup_steps,
+ init_lr=self.warmup_start_lr,
+ max_lr=self.init_lr,
+ )
+ else:
+ cosine_lr_schedule(
+ epoch=total_cur_step,
+ optimizer=self.optimizer,
+ max_epoch=self.max_epoch * self.iters_per_epoch,
+ init_lr=self.init_lr,
+ min_lr=self.min_lr,
+ )
+
+
+def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+ """Decay the learning rate"""
+ lr = (init_lr - min_lr) * 0.5 * (
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
+ ) + min_lr
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+
+def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+ """Warmup the learning rate"""
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+
+
+def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+ """Decay the learning rate"""
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
diff --git a/minigpt4/common/registry.py b/minigpt4/common/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..679467a7411eda19ed956b810c21234322f06779
--- /dev/null
+++ b/minigpt4/common/registry.py
@@ -0,0 +1,329 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+
+class Registry:
+ mapping = {
+ "builder_name_mapping": {},
+ "task_name_mapping": {},
+ "processor_name_mapping": {},
+ "model_name_mapping": {},
+ "lr_scheduler_name_mapping": {},
+ "runner_name_mapping": {},
+ "state": {},
+ "paths": {},
+ }
+
+ @classmethod
+ def register_builder(cls, name):
+ r"""Register a dataset builder to registry with key 'name'
+
+ Args:
+ name: Key with which the builder will be registered.
+
+ Usage:
+
+ from minigpt4.common.registry import registry
+ from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
+ """
+
+ def wrap(builder_cls):
+ from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+
+ assert issubclass(
+ builder_cls, BaseDatasetBuilder
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
+ builder_cls
+ )
+ if name in cls.mapping["builder_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["builder_name_mapping"][name]
+ )
+ )
+ cls.mapping["builder_name_mapping"][name] = builder_cls
+ return builder_cls
+
+ return wrap
+
+ @classmethod
+ def register_task(cls, name):
+ r"""Register a task to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from minigpt4.common.registry import registry
+ """
+
+ def wrap(task_cls):
+ from minigpt4.tasks.base_task import BaseTask
+
+ assert issubclass(
+ task_cls, BaseTask
+ ), "All tasks must inherit BaseTask class"
+ if name in cls.mapping["task_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["task_name_mapping"][name]
+ )
+ )
+ cls.mapping["task_name_mapping"][name] = task_cls
+ return task_cls
+
+ return wrap
+
+ @classmethod
+ def register_model(cls, name):
+ r"""Register a task to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from minigpt4.common.registry import registry
+ """
+
+ def wrap(model_cls):
+ from minigpt4.models import BaseModel
+
+ assert issubclass(
+ model_cls, BaseModel
+ ), "All models must inherit BaseModel class"
+ if name in cls.mapping["model_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["model_name_mapping"][name]
+ )
+ )
+ cls.mapping["model_name_mapping"][name] = model_cls
+ return model_cls
+
+ return wrap
+
+ @classmethod
+ def register_processor(cls, name):
+ r"""Register a processor to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from minigpt4.common.registry import registry
+ """
+
+ def wrap(processor_cls):
+ from minigpt4.processors import BaseProcessor
+
+ assert issubclass(
+ processor_cls, BaseProcessor
+ ), "All processors must inherit BaseProcessor class"
+ if name in cls.mapping["processor_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["processor_name_mapping"][name]
+ )
+ )
+ cls.mapping["processor_name_mapping"][name] = processor_cls
+ return processor_cls
+
+ return wrap
+
+ @classmethod
+ def register_lr_scheduler(cls, name):
+ r"""Register a model to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from minigpt4.common.registry import registry
+ """
+
+ def wrap(lr_sched_cls):
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
+ )
+ )
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
+ return lr_sched_cls
+
+ return wrap
+
+ @classmethod
+ def register_runner(cls, name):
+ r"""Register a model to registry with key 'name'
+
+ Args:
+ name: Key with which the task will be registered.
+
+ Usage:
+
+ from minigpt4.common.registry import registry
+ """
+
+ def wrap(runner_cls):
+ if name in cls.mapping["runner_name_mapping"]:
+ raise KeyError(
+ "Name '{}' already registered for {}.".format(
+ name, cls.mapping["runner_name_mapping"][name]
+ )
+ )
+ cls.mapping["runner_name_mapping"][name] = runner_cls
+ return runner_cls
+
+ return wrap
+
+ @classmethod
+ def register_path(cls, name, path):
+ r"""Register a path to registry with key 'name'
+
+ Args:
+ name: Key with which the path will be registered.
+
+ Usage:
+
+ from minigpt4.common.registry import registry
+ """
+ assert isinstance(path, str), "All path must be str."
+ if name in cls.mapping["paths"]:
+ raise KeyError("Name '{}' already registered.".format(name))
+ cls.mapping["paths"][name] = path
+
+ @classmethod
+ def register(cls, name, obj):
+ r"""Register an item to registry with key 'name'
+
+ Args:
+ name: Key with which the item will be registered.
+
+ Usage::
+
+ from minigpt4.common.registry import registry
+
+ registry.register("config", {})
+ """
+ path = name.split(".")
+ current = cls.mapping["state"]
+
+ for part in path[:-1]:
+ if part not in current:
+ current[part] = {}
+ current = current[part]
+
+ current[path[-1]] = obj
+
+ # @classmethod
+ # def get_trainer_class(cls, name):
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_builder_class(cls, name):
+ return cls.mapping["builder_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_model_class(cls, name):
+ return cls.mapping["model_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_task_class(cls, name):
+ return cls.mapping["task_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_processor_class(cls, name):
+ return cls.mapping["processor_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_lr_scheduler_class(cls, name):
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
+
+ @classmethod
+ def get_runner_class(cls, name):
+ return cls.mapping["runner_name_mapping"].get(name, None)
+
+ @classmethod
+ def list_runners(cls):
+ return sorted(cls.mapping["runner_name_mapping"].keys())
+
+ @classmethod
+ def list_models(cls):
+ return sorted(cls.mapping["model_name_mapping"].keys())
+
+ @classmethod
+ def list_tasks(cls):
+ return sorted(cls.mapping["task_name_mapping"].keys())
+
+ @classmethod
+ def list_processors(cls):
+ return sorted(cls.mapping["processor_name_mapping"].keys())
+
+ @classmethod
+ def list_lr_schedulers(cls):
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
+
+ @classmethod
+ def list_datasets(cls):
+ return sorted(cls.mapping["builder_name_mapping"].keys())
+
+ @classmethod
+ def get_path(cls, name):
+ return cls.mapping["paths"].get(name, None)
+
+ @classmethod
+ def get(cls, name, default=None, no_warning=False):
+ r"""Get an item from registry with key 'name'
+
+ Args:
+ name (string): Key whose value needs to be retrieved.
+ default: If passed and key is not in registry, default value will
+ be returned with a warning. Default: None
+ no_warning (bool): If passed as True, warning when key doesn't exist
+ will not be generated. Useful for MMF's
+ internal operations. Default: False
+ """
+ original_name = name
+ name = name.split(".")
+ value = cls.mapping["state"]
+ for subname in name:
+ value = value.get(subname, default)
+ if value is default:
+ break
+
+ if (
+ "writer" in cls.mapping["state"]
+ and value == default
+ and no_warning is False
+ ):
+ cls.mapping["state"]["writer"].warning(
+ "Key {} is not present in registry, returning default value "
+ "of {}".format(original_name, default)
+ )
+ return value
+
+ @classmethod
+ def unregister(cls, name):
+ r"""Remove an item from registry with key 'name'
+
+ Args:
+ name: Key which needs to be removed.
+ Usage::
+
+ from mmf.common.registry import registry
+
+ config = registry.unregister("config")
+ """
+ return cls.mapping["state"].pop(name, None)
+
+
+registry = Registry()
diff --git a/minigpt4/common/utils.py b/minigpt4/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3069cd10ce986a1ec249490fa813cae9254bd0d
--- /dev/null
+++ b/minigpt4/common/utils.py
@@ -0,0 +1,424 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import io
+import json
+import logging
+import os
+import pickle
+import re
+import shutil
+import urllib
+import urllib.error
+import urllib.request
+from typing import Optional
+from urllib.parse import urlparse
+
+import numpy as np
+import pandas as pd
+import yaml
+from iopath.common.download import download
+from iopath.common.file_io import file_lock, g_pathmgr
+from minigpt4.common.registry import registry
+from torch.utils.model_zoo import tqdm
+from torchvision.datasets.utils import (
+ check_integrity,
+ download_file_from_google_drive,
+ extract_archive,
+)
+
+
+def now():
+ from datetime import datetime
+
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
+
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+
+def get_cache_path(rel_path):
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
+
+
+def get_abs_path(rel_path):
+ return os.path.join(registry.get_path("library_root"), rel_path)
+
+
+def load_json(filename):
+ with open(filename, "r") as f:
+ return json.load(f)
+
+
+# The following are adapted from torchvision and vissl
+# torchvision: https://github.com/pytorch/vision
+# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
+
+
+def makedir(dir_path):
+ """
+ Create the directory if it does not exist.
+ """
+ is_success = False
+ try:
+ if not g_pathmgr.exists(dir_path):
+ g_pathmgr.mkdirs(dir_path)
+ is_success = True
+ except BaseException:
+ print(f"Error creating directory: {dir_path}")
+ return is_success
+
+
+def get_redirected_url(url: str):
+ """
+ Given a URL, returns the URL it redirects to or the
+ original URL in case of no indirection
+ """
+ import requests
+
+ with requests.Session() as session:
+ with session.get(url, stream=True, allow_redirects=True) as response:
+ if response.history:
+ return response.url
+ else:
+ return url
+
+
+def to_google_drive_download_url(view_url: str) -> str:
+ """
+ Utility function to transform a view URL of google drive
+ to a download URL for google drive
+ Example input:
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
+ Example output:
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
+ """
+ splits = view_url.split("/")
+ assert splits[-1] == "view"
+ file_id = splits[-2]
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
+
+
+def download_google_drive_url(url: str, output_path: str, output_file_name: str):
+ """
+ Download a file from google drive
+ Downloading an URL from google drive requires confirmation when
+ the file of the size is too big (google drive notifies that
+ anti-viral checks cannot be performed on such files)
+ """
+ import requests
+
+ with requests.Session() as session:
+
+ # First get the confirmation token and append it to the URL
+ with session.get(url, stream=True, allow_redirects=True) as response:
+ for k, v in response.cookies.items():
+ if k.startswith("download_warning"):
+ url = url + "&confirm=" + v
+
+ # Then download the content of the file
+ with session.get(url, stream=True, verify=True) as response:
+ makedir(output_path)
+ path = os.path.join(output_path, output_file_name)
+ total_size = int(response.headers.get("Content-length", 0))
+ with open(path, "wb") as file:
+ from tqdm import tqdm
+
+ with tqdm(total=total_size) as progress_bar:
+ for block in response.iter_content(
+ chunk_size=io.DEFAULT_BUFFER_SIZE
+ ):
+ file.write(block)
+ progress_bar.update(len(block))
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+ parts = urlparse(url)
+
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+ return None
+
+ match = re.match(r"/file/d/(?P[^/]*)", parts.path)
+ if match is None:
+ return None
+
+ return match.group("id")
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
+ with open(filename, "wb") as fh:
+ with urllib.request.urlopen(
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
+ ) as response:
+ with tqdm(total=response.length) as pbar:
+ for chunk in iter(lambda: response.read(chunk_size), ""):
+ if not chunk:
+ break
+ pbar.update(chunk_size)
+ fh.write(chunk)
+
+
+def download_url(
+ url: str,
+ root: str,
+ filename: Optional[str] = None,
+ md5: Optional[str] = None,
+) -> None:
+ """Download a file from a url and place it in root.
+ Args:
+ url (str): URL to download file from
+ root (str): Directory to place downloaded file in
+ filename (str, optional): Name to save the file under.
+ If None, use the basename of the URL.
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
+ """
+ root = os.path.expanduser(root)
+ if not filename:
+ filename = os.path.basename(url)
+ fpath = os.path.join(root, filename)
+
+ makedir(root)
+
+ # check if file is already present locally
+ if check_integrity(fpath, md5):
+ print("Using downloaded and verified file: " + fpath)
+ return
+
+ # expand redirect chain if needed
+ url = get_redirected_url(url)
+
+ # check if file is located on Google Drive
+ file_id = _get_google_drive_file_id(url)
+ if file_id is not None:
+ return download_file_from_google_drive(file_id, root, filename, md5)
+
+ # download the file
+ try:
+ print("Downloading " + url + " to " + fpath)
+ _urlretrieve(url, fpath)
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
+ if url[:5] == "https":
+ url = url.replace("https:", "http:")
+ print(
+ "Failed download. Trying https -> http instead."
+ " Downloading " + url + " to " + fpath
+ )
+ _urlretrieve(url, fpath)
+ else:
+ raise e
+
+ # check integrity of downloaded file
+ if not check_integrity(fpath, md5):
+ raise RuntimeError("File not found or corrupted.")
+
+
+def download_and_extract_archive(
+ url: str,
+ download_root: str,
+ extract_root: Optional[str] = None,
+ filename: Optional[str] = None,
+ md5: Optional[str] = None,
+ remove_finished: bool = False,
+) -> None:
+ download_root = os.path.expanduser(download_root)
+ if extract_root is None:
+ extract_root = download_root
+ if not filename:
+ filename = os.path.basename(url)
+
+ download_url(url, download_root, filename, md5)
+
+ archive = os.path.join(download_root, filename)
+ print("Extracting {} to {}".format(archive, extract_root))
+ extract_archive(archive, extract_root, remove_finished)
+
+
+def cache_url(url: str, cache_dir: str) -> str:
+ """
+ This implementation downloads the remote resource and caches it locally.
+ The resource will only be downloaded if not previously requested.
+ """
+ parsed_url = urlparse(url)
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
+ makedir(dirname)
+ filename = url.split("/")[-1]
+ cached = os.path.join(dirname, filename)
+ with file_lock(cached):
+ if not os.path.isfile(cached):
+ logging.info(f"Downloading {url} to {cached} ...")
+ cached = download(url, dirname, filename=filename)
+ logging.info(f"URL {url} cached in {cached}")
+ return cached
+
+
+# TODO (prigoyal): convert this into RAII-style API
+def create_file_symlink(file1, file2):
+ """
+ Simply create the symlinks for a given file1 to file2.
+ Useful during model checkpointing to symlinks to the
+ latest successful checkpoint.
+ """
+ try:
+ if g_pathmgr.exists(file2):
+ g_pathmgr.rm(file2)
+ g_pathmgr.symlink(file1, file2)
+ except Exception as e:
+ logging.info(f"Could NOT create symlink. Error: {e}")
+
+
+def save_file(data, filename, append_to_json=True, verbose=True):
+ """
+ Common i/o utility to handle saving data to various file formats.
+ Supported:
+ .pkl, .pickle, .npy, .json
+ Specifically for .json, users have the option to either append (default)
+ or rewrite by passing in Boolean value to append_to_json.
+ """
+ if verbose:
+ logging.info(f"Saving data to file: {filename}")
+ file_ext = os.path.splitext(filename)[1]
+ if file_ext in [".pkl", ".pickle"]:
+ with g_pathmgr.open(filename, "wb") as fopen:
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
+ elif file_ext == ".npy":
+ with g_pathmgr.open(filename, "wb") as fopen:
+ np.save(fopen, data)
+ elif file_ext == ".json":
+ if append_to_json:
+ with g_pathmgr.open(filename, "a") as fopen:
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
+ fopen.flush()
+ else:
+ with g_pathmgr.open(filename, "w") as fopen:
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
+ fopen.flush()
+ elif file_ext == ".yaml":
+ with g_pathmgr.open(filename, "w") as fopen:
+ dump = yaml.dump(data)
+ fopen.write(dump)
+ fopen.flush()
+ else:
+ raise Exception(f"Saving {file_ext} is not supported yet")
+
+ if verbose:
+ logging.info(f"Saved data to file: {filename}")
+
+
+def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
+ """
+ Common i/o utility to handle loading data from various file formats.
+ Supported:
+ .pkl, .pickle, .npy, .json
+ For the npy files, we support reading the files in mmap_mode.
+ If the mmap_mode of reading is not successful, we load data without the
+ mmap_mode.
+ """
+ if verbose:
+ logging.info(f"Loading data from file: {filename}")
+
+ file_ext = os.path.splitext(filename)[1]
+ if file_ext == ".txt":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = fopen.readlines()
+ elif file_ext in [".pkl", ".pickle"]:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = pickle.load(fopen, encoding="latin1")
+ elif file_ext == ".npy":
+ if mmap_mode:
+ try:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(
+ fopen,
+ allow_pickle=allow_pickle,
+ encoding="latin1",
+ mmap_mode=mmap_mode,
+ )
+ except ValueError as e:
+ logging.info(
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
+ )
+ data = np.load(
+ filename,
+ allow_pickle=allow_pickle,
+ encoding="latin1",
+ mmap_mode=mmap_mode,
+ )
+ logging.info("Successfully loaded without g_pathmgr")
+ except Exception:
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+ else:
+ with g_pathmgr.open(filename, "rb") as fopen:
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+ elif file_ext == ".json":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = json.load(fopen)
+ elif file_ext == ".yaml":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
+ elif file_ext == ".csv":
+ with g_pathmgr.open(filename, "r") as fopen:
+ data = pd.read_csv(fopen)
+ else:
+ raise Exception(f"Reading from {file_ext} is not supported yet")
+ return data
+
+
+def abspath(resource_path: str):
+ """
+ Make a path absolute, but take into account prefixes like
+ "http://" or "manifold://"
+ """
+ regex = re.compile(r"^\w+://")
+ if regex.match(resource_path) is None:
+ return os.path.abspath(resource_path)
+ else:
+ return resource_path
+
+
+def makedir(dir_path):
+ """
+ Create the directory if it does not exist.
+ """
+ is_success = False
+ try:
+ if not g_pathmgr.exists(dir_path):
+ g_pathmgr.mkdirs(dir_path)
+ is_success = True
+ except BaseException:
+ logging.info(f"Error creating directory: {dir_path}")
+ return is_success
+
+
+def is_url(input_url):
+ """
+ Check if an input string is a url. look for http(s):// and ignoring the case
+ """
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
+ return is_url
+
+
+def cleanup_dir(dir):
+ """
+ Utility for deleting a directory. Useful for cleaning the storage space
+ that contains various training artifacts like checkpoints, data etc.
+ """
+ if os.path.exists(dir):
+ logging.info(f"Deleting directory: {dir}")
+ shutil.rmtree(dir)
+ logging.info(f"Deleted contents of directory: {dir}")
+
+
+def get_file_size(filename):
+ """
+ Given a file, get the size of file in MB
+ """
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
+ return size_in_mb
diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ca21d805684d71593c8d738798822411bdecc6
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
@@ -0,0 +1,89 @@
+# coding: utf-8
+
+import sys
+dataDir = '../../VQA'
+sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
+from vqa import VQA
+from vqaEvaluation.vqaEval import VQAEval
+import matplotlib.pyplot as plt
+import skimage.io as io
+import json
+import random
+import os
+
+# set up file names and paths
+versionType ='v2_' # this should be '' when using VQA v2.0 dataset
+taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
+dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
+dataSubType ='train2014'
+annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
+quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
+imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
+resultType ='fake'
+fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
+
+# An example result json file has been provided in './Results' folder.
+
+[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
+resultType, fileType) for fileType in fileTypes]
+
+# create vqa object and vqaRes object
+vqa = VQA(annFile, quesFile)
+vqaRes = vqa.loadRes(resFile, quesFile)
+
+# create vqaEval object by taking vqa and vqaRes
+vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
+
+# evaluate results
+"""
+If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
+By default it uses all the question ids in annotation file
+"""
+vqaEval.evaluate()
+
+# print accuracies
+print "\n"
+print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
+print "Per Question Type Accuracy is the following:"
+for quesType in vqaEval.accuracy['perQuestionType']:
+ print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
+print "\n"
+print "Per Answer Type Accuracy is the following:"
+for ansType in vqaEval.accuracy['perAnswerType']:
+ print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
+print "\n"
+# demo how to use evalQA to retrieve low score result
+evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
+if len(evals) > 0:
+ print 'ground truth answers'
+ randomEval = random.choice(evals)
+ randomAnn = vqa.loadQA(randomEval)
+ vqa.showQA(randomAnn)
+
+ print '\n'
+ print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval])
+ ann = vqaRes.loadQA(randomEval)[0]
+ print "Answer: %s\n" %(ann['answer'])
+
+ imgId = randomAnn[0]['image_id']
+ imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+ if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
+# plot accuracy for various question types
+plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
+plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
+plt.title('Per Question Type Accuracy', fontsize=10)
+plt.xlabel('Question Types', fontsize=10)
+plt.ylabel('Accuracy', fontsize=10)
+plt.show()
+
+# save evaluation results to ./Results folder
+json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
+json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
+json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
+json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
+
diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..148424d7391f6c8e8070f6dd20f02e2ddb1899cc
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
@@ -0,0 +1 @@
+author='aagrawal'
diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a656044433b08c3b3a7610e0d4f701c9f3f752a
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
@@ -0,0 +1,192 @@
+# coding=utf-8
+
+__author__='aagrawal'
+
+import re
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
+import sys
+
+
+class VQAEval:
+ def __init__(self, vqa, vqaRes, n=2):
+ self.n = n
+ self.accuracy = {}
+ self.evalQA = {}
+ self.evalQuesType = {}
+ self.evalAnsType = {}
+ self.vqa = vqa
+ self.vqaRes = vqaRes
+ self.params = {'question_id': vqa.getQuesIds()}
+ self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
+ "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
+ "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
+ "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
+ "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
+ "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
+ "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
+ "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
+ "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
+ "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
+ "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
+ "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
+ "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
+ "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
+ "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
+ "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
+ "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
+ "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
+ "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
+ "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
+ "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
+ "youll": "you'll", "youre": "you're", "youve": "you've"}
+ self.manualMap = { 'none': '0',
+ 'zero': '0',
+ 'one': '1',
+ 'two': '2',
+ 'three': '3',
+ 'four': '4',
+ 'five': '5',
+ 'six': '6',
+ 'seven': '7',
+ 'eight': '8',
+ 'nine': '9',
+ 'ten': '10'
+ }
+ self.articles = ['a',
+ 'an',
+ 'the'
+ ]
+
+
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
+ self.commaStrip = re.compile("(\d)(\,)(\d)")
+ self.punct = [';', r"/", '[', ']', '"', '{', '}',
+ '(', ')', '=', '+', '\\', '_', '-',
+ '>', '<', '@', '`', ',', '?', '!']
+
+
+ def evaluate(self, quesIds=None):
+ if quesIds == None:
+ quesIds = [quesId for quesId in self.params['question_id']]
+ gts = {}
+ res = {}
+ for quesId in quesIds:
+ gts[quesId] = self.vqa.qa[quesId]
+ res[quesId] = self.vqaRes.qa[quesId]
+
+ # =================================================
+ # Compute accuracy
+ # =================================================
+ accQA = []
+ accQuesType = {}
+ accAnsType = {}
+ # print "computing accuracy"
+ step = 0
+ for quesId in quesIds:
+ for ansDic in gts[quesId]['answers']:
+ ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
+ ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
+ ansDic['answer'] = ansDic['answer'].strip()
+ resAns = res[quesId]['answer']
+ resAns = resAns.replace('\n', ' ')
+ resAns = resAns.replace('\t', ' ')
+ resAns = resAns.strip()
+ gtAcc = []
+ gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
+
+ if len(set(gtAnswers)) > 1:
+ for ansDic in gts[quesId]['answers']:
+ ansDic['answer'] = self.processPunctuation(ansDic['answer'])
+ ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
+ resAns = self.processPunctuation(resAns)
+ resAns = self.processDigitArticle(resAns)
+
+ for gtAnsDatum in gts[quesId]['answers']:
+ otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
+ matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
+ acc = min(1, float(len(matchingAns))/3)
+ gtAcc.append(acc)
+ quesType = gts[quesId]['question_type']
+ ansType = gts[quesId]['answer_type']
+ avgGTAcc = float(sum(gtAcc))/len(gtAcc)
+ accQA.append(avgGTAcc)
+ if quesType not in accQuesType:
+ accQuesType[quesType] = []
+ accQuesType[quesType].append(avgGTAcc)
+ if ansType not in accAnsType:
+ accAnsType[ansType] = []
+ accAnsType[ansType].append(avgGTAcc)
+ self.setEvalQA(quesId, avgGTAcc)
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
+ if step%100 == 0:
+ self.updateProgress(step/float(len(quesIds)))
+ step = step + 1
+
+ self.setAccuracy(accQA, accQuesType, accAnsType)
+ # print "Done computing accuracy"
+
+ def processPunctuation(self, inText):
+ outText = inText
+ for p in self.punct:
+ if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
+ outText = outText.replace(p, '')
+ else:
+ outText = outText.replace(p, ' ')
+ outText = self.periodStrip.sub("",
+ outText,
+ re.UNICODE)
+ return outText
+
+ def processDigitArticle(self, inText):
+ outText = []
+ tempText = inText.lower().split()
+ for word in tempText:
+ word = self.manualMap.setdefault(word, word)
+ if word not in self.articles:
+ outText.append(word)
+ else:
+ pass
+ for wordId, word in enumerate(outText):
+ if word in self.contractions:
+ outText[wordId] = self.contractions[word]
+ outText = ' '.join(outText)
+ return outText
+
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
+ self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
+ self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
+ self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
+
+ def setEvalQA(self, quesId, acc):
+ self.evalQA[quesId] = round(100*acc, self.n)
+
+ def setEvalQuesType(self, quesId, quesType, acc):
+ if quesType not in self.evalQuesType:
+ self.evalQuesType[quesType] = {}
+ self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
+
+ def setEvalAnsType(self, quesId, ansType, acc):
+ if ansType not in self.evalAnsType:
+ self.evalAnsType[ansType] = {}
+ self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
+
+ def updateProgress(self, progress):
+ barLength = 20
+ status = ""
+ if isinstance(progress, int):
+ progress = float(progress)
+ if not isinstance(progress, float):
+ progress = 0
+ status = "error: progress var must be float\r\n"
+ if progress < 0:
+ progress = 0
+ status = "Halt...\r\n"
+ if progress >= 1:
+ progress = 1
+ status = "Done...\r\n"
+ block = int(round(barLength*progress))
+ text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
+ sys.stdout.write(text)
+ sys.stdout.flush()
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
new file mode 100644
index 0000000000000000000000000000000000000000..406b59642a7c2c208b87b0222a299e48a5831eb1
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
@@ -0,0 +1,73 @@
+# coding: utf-8
+
+from vqaTools.vqa import VQA
+import random
+import skimage.io as io
+import matplotlib.pyplot as plt
+import os
+
+dataDir ='../../VQA'
+versionType ='v2_' # this should be '' when using VQA v2.0 dataset
+taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
+dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
+dataSubType ='train2014'
+annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
+quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
+imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
+
+# initialize VQA api for QA annotations
+vqa=VQA(annFile, quesFile)
+
+# load and display QA annotations for given question types
+"""
+All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
+"""
+annIds = vqa.getQuesIds(quesTypes='how many');
+anns = vqa.loadQA(annIds)
+randomAnn = random.choice(anns)
+vqa.showQA([randomAnn])
+imgId = randomAnn['image_id']
+imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
+# load and display QA annotations for given answer types
+"""
+ansTypes can be one of the following
+yes/no
+number
+other
+"""
+annIds = vqa.getQuesIds(ansTypes='yes/no');
+anns = vqa.loadQA(annIds)
+randomAnn = random.choice(anns)
+vqa.showQA([randomAnn])
+imgId = randomAnn['image_id']
+imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
+# load and display QA annotations for given images
+"""
+Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
+Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
+"""
+ids = vqa.getImgIds()
+annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
+anns = vqa.loadQA(annIds)
+randomAnn = random.choice(anns)
+vqa.showQA([randomAnn])
+imgId = randomAnn['image_id']
+imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..072d8d90cd261c19c62fa4624ca22471fe72abfd
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
@@ -0,0 +1 @@
+__author__ = 'aagrawal'
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f769619fc64ce150d1a462d91ea29282f08104a
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
@@ -0,0 +1,179 @@
+__author__ = 'aagrawal'
+__version__ = '0.9'
+
+# Interface for accessing the VQA dataset.
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
+
+# The following functions are defined:
+# VQA - VQA class that loads VQA annotation file and prepares data structures.
+# getQuesIds - Get question ids that satisfy given filter conditions.
+# getImgIds - Get image ids that satisfy given filter conditions.
+# loadQA - Load questions and answers with the specified question ids.
+# showQA - Display the specified questions and answers.
+# loadRes - Load result file and create result object.
+
+# Help on each function can be accessed by: "help(COCO.function)"
+
+import json
+import datetime
+import copy
+
+
+class VQA:
+ def __init__(self, annotation_file=None, question_file=None):
+ """
+ Constructor of VQA helper class for reading and visualizing questions and answers.
+ :param annotation_file (str): location of VQA annotation file
+ :return:
+ """
+ # load dataset
+ self.dataset = {}
+ self.questions = {}
+ self.qa = {}
+ self.qqa = {}
+ self.imgToQA = {}
+ if not annotation_file == None and not question_file == None:
+ # print 'loading VQA annotations and questions into memory...'
+ time_t = datetime.datetime.utcnow()
+ dataset = json.load(open(annotation_file, 'r'))
+ questions = json.load(open(question_file, 'r'))
+ # print datetime.datetime.utcnow() - time_t
+ self.dataset = dataset
+ self.questions = questions
+ self.createIndex()
+
+ def createIndex(self):
+ imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
+ qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
+ qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
+ for ann in self.dataset['annotations']:
+ imgToQA[ann['image_id']] += [ann]
+ qa[ann['question_id']] = ann
+ for ques in self.questions['questions']:
+ qqa[ques['question_id']] = ques
+ # print 'index created!'
+
+ # create class members
+ self.qa = qa
+ self.qqa = qqa
+ self.imgToQA = imgToQA
+
+ def info(self):
+ """
+ Print information about the VQA annotation file.
+ :return:
+ """
+
+ # for key, value in self.datset['info'].items():
+ # print '%s: %s'%(key, value)
+
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get question ids that satisfy given filter conditions. default skips that filter
+ :param imgIds (int array) : get question ids for given imgs
+ quesTypes (str array) : get question ids for given question types
+ ansTypes (str array) : get question ids for given answer types
+ :return: ids (int array) : integer array of question ids
+ """
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(imgIds) == 0:
+ anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
+ else:
+ anns = self.dataset['annotations']
+ anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
+ anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
+ ids = [ann['question_id'] for ann in anns]
+ return ids
+
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get image ids that satisfy given filter conditions. default skips that filter
+ :param quesIds (int array) : get image ids for given question ids
+ quesTypes (str array) : get image ids for given question types
+ ansTypes (str array) : get image ids for given answer types
+ :return: ids (int array) : integer array of image ids
+ """
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(quesIds) == 0:
+ anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
+ else:
+ anns = self.dataset['annotations']
+ anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
+ anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
+ ids = [ann['image_id'] for ann in anns]
+ return ids
+
+ def loadQA(self, ids=[]):
+ """
+ Load questions and answers with the specified question ids.
+ :param ids (int array) : integer ids specifying question ids
+ :return: qa (object array) : loaded qa objects
+ """
+ if type(ids) == list:
+ return [self.qa[id] for id in ids]
+ elif type(ids) == int:
+ return [self.qa[ids]]
+
+ def showQA(self, anns):
+ """
+ Display the specified annotations.
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ for ann in anns:
+ quesId = ann['question_id']
+ print("Question: %s" % (self.qqa[quesId]['question']))
+ for ans in ann['answers']:
+ print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
+
+ def loadRes(self, resFile, quesFile):
+ """
+ Load result file and return a result object.
+ :param resFile (str) : file name of result file
+ :return: res (obj) : result api object
+ """
+ res = VQA()
+ res.questions = json.load(open(quesFile))
+ res.dataset['info'] = copy.deepcopy(self.questions['info'])
+ res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
+ res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
+ res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
+ res.dataset['license'] = copy.deepcopy(self.questions['license'])
+
+ # print 'Loading and preparing results... '
+ time_t = datetime.datetime.utcnow()
+ anns = json.load(open(resFile))
+ assert type(anns) == list, 'results is not an array of objects'
+ annsQuesIds = [ann['question_id'] for ann in anns]
+ assert set(annsQuesIds) == set(self.getQuesIds()), \
+ 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
+ for ann in anns:
+ quesId = ann['question_id']
+ if res.dataset['task_type'] == 'Multiple Choice':
+ assert ann['answer'] in self.qqa[quesId][
+ 'multiple_choices'], 'predicted answer is not one of the multiple choices'
+ qaAnn = self.qa[quesId]
+ ann['image_id'] = qaAnn['image_id']
+ ann['question_type'] = qaAnn['question_type']
+ ann['answer_type'] = qaAnn['answer_type']
+ # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
+
+ res.dataset['annotations'] = anns
+ res.createIndex()
+ return res
diff --git a/minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt b/minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt
new file mode 100644
index 0000000000000000000000000000000000000000..44304fc865d1fee83ca73a36d3fbe2580cc4b5f9
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt
@@ -0,0 +1,81 @@
+how many
+what color is the
+is the
+where is the
+what
+what is
+are the
+what is the
+is there a
+does the
+is the woman
+is the man
+what is on the
+is it
+is the girl
+is the boy
+is the dog
+are they
+who is
+what kind of
+what color are the
+what is in the
+what is the man
+is there
+what is the woman
+what are the
+what is the boy
+are there
+what is the girl
+is this
+how
+which
+how many people are
+is the cat
+why is the
+are
+will the
+what type of
+what is the dog
+do
+is she
+does
+do the
+is
+is the baby
+are there any
+is the lady
+can
+what animal is
+where are the
+is the sun
+what are they
+did the
+what is the cat
+what is the lady
+how many clouds are
+is that
+is the little girl
+is he
+are these
+how many trees are
+how many pillows
+are the people
+why
+is the young
+how many windows are
+is this a
+what is the little
+is the tv
+how many animals are
+who
+how many pictures
+how many plants are
+how many birds are
+what color is
+what is the baby
+is anyone
+what color
+how many bushes
+is the old man
+none of the above
diff --git a/minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt b/minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt
new file mode 100644
index 0000000000000000000000000000000000000000..95590506bf8af7ba1eaeb91746b43da0eb9b4baa
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt
@@ -0,0 +1,65 @@
+how many
+is the
+what
+what color is the
+what is the
+is this
+is this a
+what is
+are the
+what kind of
+is there a
+what type of
+is it
+what are the
+where is the
+is there
+does the
+what color are the
+are these
+are there
+which
+is
+what is the man
+is the man
+are
+how
+does this
+what is on the
+what does the
+how many people are
+what is in the
+what is this
+do
+what are
+are they
+what time
+what sport is
+are there any
+is he
+what color is
+why
+where are the
+what color
+who is
+what animal is
+is the woman
+is this an
+do you
+how many people are in
+what room is
+has
+is this person
+what is the woman
+can you
+why is the
+is the person
+what is the color of the
+what is the person
+could
+was
+is that a
+what number is
+what is the name
+what brand
+none of the above
diff --git a/minigpt4/common/vqa_tools/VQA/README.md b/minigpt4/common/vqa_tools/VQA/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..439d59d4d7c761423ab7016ab8768105b2df6c35
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/README.md
@@ -0,0 +1,80 @@
+Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
+===================
+## VQA v2.0 release ##
+This release consists of
+- Real
+ - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
+ - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
+ - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
+
+There is only one type of task
+- Open-ended task
+
+## VQA v1.0 release ##
+This release consists of
+- Real
+ - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
+ - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
+ - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
+- Abstract
+ - 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
+ - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
+ - 600,000 answers for training and 300,000 answers for validation (10 per question)
+
+There are two types of tasks
+- Open-ended task
+- Multiple-choice task (18 choices per question)
+
+## Requirements ##
+- python 2.7
+- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
+- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
+
+## Files ##
+./Questions
+- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
+- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
+- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
+ - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
+ - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
+- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
+
+./Annotations
+- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
+- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
+- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
+ - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
+ - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
+- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
+
+./Images
+- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
+- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
+
+./PythonHelperTools
+- This directory contains the Python API to read and visualize the VQA dataset
+- vqaDemo.py (demo script)
+- vqaTools (API to read and visualize data)
+
+./PythonEvaluationTools
+- This directory contains the Python evaluation code
+- vqaEvalDemo.py (evaluation demo script)
+- vqaEvaluation (evaluation code)
+
+./Results
+- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
+- Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details.
+
+./QuestionTypes
+- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
+- mscoco_question_types.txt
+- abstract_v002_question_types.txt
+
+## References ##
+- [VQA: Visual Question Answering](http://visualqa.org/)
+- [Microsoft COCO](http://mscoco.org/)
+
+## Developers ##
+- Aishwarya Agrawal (Virginia Tech)
+- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
+- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
diff --git a/minigpt4/common/vqa_tools/VQA/license.txt b/minigpt4/common/vqa_tools/VQA/license.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f87c06bb4f439b09dec29988b9b23c5995d0e7d4
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/license.txt
@@ -0,0 +1,30 @@
+Copyright (c) 2014, Aishwarya Agrawal
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+The views and conclusions contained in the software and documentation are
+those
+of the authors and should not be interpreted as representing official
+policies,
+either expressed or implied, of the FreeBSD Project.
diff --git a/minigpt4/common/vqa_tools/__init__.py b/minigpt4/common/vqa_tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b98da85428159ad0dcfab7685c080848ecf8c7b
--- /dev/null
+++ b/minigpt4/common/vqa_tools/__init__.py
@@ -0,0 +1,8 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+__author__ = "aagrawal"
diff --git a/minigpt4/common/vqa_tools/vqa.py b/minigpt4/common/vqa_tools/vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..a386b9094b0528b33e7511aff4027f30459a7ff7
--- /dev/null
+++ b/minigpt4/common/vqa_tools/vqa.py
@@ -0,0 +1,211 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+__author__ = "aagrawal"
+__version__ = "0.9"
+
+# Interface for accessing the VQA dataset.
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
+
+# The following functions are defined:
+# VQA - VQA class that loads VQA annotation file and prepares data structures.
+# getQuesIds - Get question ids that satisfy given filter conditions.
+# getImgIds - Get image ids that satisfy given filter conditions.
+# loadQA - Load questions and answers with the specified question ids.
+# showQA - Display the specified questions and answers.
+# loadRes - Load result file and create result object.
+
+# Help on each function can be accessed by: "help(COCO.function)"
+
+import json
+import datetime
+import copy
+
+
+class VQA:
+ def __init__(self, annotation_file=None, question_file=None):
+ """
+ Constructor of VQA helper class for reading and visualizing questions and answers.
+ :param annotation_file (str): location of VQA annotation file
+ :return:
+ """
+ # load dataset
+ self.dataset = {}
+ self.questions = {}
+ self.qa = {}
+ self.qqa = {}
+ self.imgToQA = {}
+ if not annotation_file == None and not question_file == None:
+ print("loading VQA annotations and questions into memory...")
+ time_t = datetime.datetime.utcnow()
+ dataset = json.load(open(annotation_file, "r"))
+ questions = json.load(open(question_file, "r"))
+ self.dataset = dataset
+ self.questions = questions
+ self.createIndex()
+
+ def createIndex(self):
+ # create index
+ print("creating index...")
+ imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
+ qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
+ qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
+ for ann in self.dataset["annotations"]:
+ imgToQA[ann["image_id"]] += [ann]
+ qa[ann["question_id"]] = ann
+ for ques in self.questions["questions"]:
+ qqa[ques["question_id"]] = ques
+ print("index created!")
+
+ # create class members
+ self.qa = qa
+ self.qqa = qqa
+ self.imgToQA = imgToQA
+
+ def info(self):
+ """
+ Print information about the VQA annotation file.
+ :return:
+ """
+ for key, value in self.datset["info"].items():
+ print("%s: %s" % (key, value))
+
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get question ids that satisfy given filter conditions. default skips that filter
+ :param imgIds (int array) : get question ids for given imgs
+ quesTypes (str array) : get question ids for given question types
+ ansTypes (str array) : get question ids for given answer types
+ :return: ids (int array) : integer array of question ids
+ """
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset["annotations"]
+ else:
+ if not len(imgIds) == 0:
+ anns = sum(
+ [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
+ [],
+ )
+ else:
+ anns = self.dataset["annotations"]
+ anns = (
+ anns
+ if len(quesTypes) == 0
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
+ )
+ anns = (
+ anns
+ if len(ansTypes) == 0
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
+ )
+ ids = [ann["question_id"] for ann in anns]
+ return ids
+
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get image ids that satisfy given filter conditions. default skips that filter
+ :param quesIds (int array) : get image ids for given question ids
+ quesTypes (str array) : get image ids for given question types
+ ansTypes (str array) : get image ids for given answer types
+ :return: ids (int array) : integer array of image ids
+ """
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset["annotations"]
+ else:
+ if not len(quesIds) == 0:
+ anns = sum(
+ [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
+ )
+ else:
+ anns = self.dataset["annotations"]
+ anns = (
+ anns
+ if len(quesTypes) == 0
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
+ )
+ anns = (
+ anns
+ if len(ansTypes) == 0
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
+ )
+ ids = [ann["image_id"] for ann in anns]
+ return ids
+
+ def loadQA(self, ids=[]):
+ """
+ Load questions and answers with the specified question ids.
+ :param ids (int array) : integer ids specifying question ids
+ :return: qa (object array) : loaded qa objects
+ """
+ if type(ids) == list:
+ return [self.qa[id] for id in ids]
+ elif type(ids) == int:
+ return [self.qa[ids]]
+
+ def showQA(self, anns):
+ """
+ Display the specified annotations.
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ for ann in anns:
+ quesId = ann["question_id"]
+ print("Question: %s" % (self.qqa[quesId]["question"]))
+ for ans in ann["answers"]:
+ print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
+
+ def loadRes(self, resFile, quesFile):
+ """
+ Load result file and return a result object.
+ :param resFile (str) : file name of result file
+ :return: res (obj) : result api object
+ """
+ res = VQA()
+ res.questions = json.load(open(quesFile))
+ res.dataset["info"] = copy.deepcopy(self.questions["info"])
+ res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
+ res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
+ res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
+ res.dataset["license"] = copy.deepcopy(self.questions["license"])
+
+ print("Loading and preparing results... ")
+ time_t = datetime.datetime.utcnow()
+ anns = json.load(open(resFile))
+ assert type(anns) == list, "results is not an array of objects"
+ annsQuesIds = [ann["question_id"] for ann in anns]
+ assert set(annsQuesIds) == set(
+ self.getQuesIds()
+ ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
+ for ann in anns:
+ quesId = ann["question_id"]
+ if res.dataset["task_type"] == "Multiple Choice":
+ assert (
+ ann["answer"] in self.qqa[quesId]["multiple_choices"]
+ ), "predicted answer is not one of the multiple choices"
+ qaAnn = self.qa[quesId]
+ ann["image_id"] = qaAnn["image_id"]
+ ann["question_type"] = qaAnn["question_type"]
+ ann["answer_type"] = qaAnn["answer_type"]
+ print(
+ "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
+ )
+
+ res.dataset["annotations"] = anns
+ res.createIndex()
+ return res
diff --git a/minigpt4/common/vqa_tools/vqa_eval.py b/minigpt4/common/vqa_tools/vqa_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee808b349bb6166c744338b02af2bc84a68650ff
--- /dev/null
+++ b/minigpt4/common/vqa_tools/vqa_eval.py
@@ -0,0 +1,324 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+# coding=utf-8
+
+__author__ = "aagrawal"
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
+import sys
+import re
+
+
+class VQAEval:
+ def __init__(self, vqa=None, vqaRes=None, n=2):
+ self.n = n
+ self.accuracy = {}
+ self.evalQA = {}
+ self.evalQuesType = {}
+ self.evalAnsType = {}
+ self.vqa = vqa
+ self.vqaRes = vqaRes
+ if vqa is not None:
+ self.params = {"question_id": vqa.getQuesIds()}
+ self.contractions = {
+ "aint": "ain't",
+ "arent": "aren't",
+ "cant": "can't",
+ "couldve": "could've",
+ "couldnt": "couldn't",
+ "couldn'tve": "couldn't've",
+ "couldnt've": "couldn't've",
+ "didnt": "didn't",
+ "doesnt": "doesn't",
+ "dont": "don't",
+ "hadnt": "hadn't",
+ "hadnt've": "hadn't've",
+ "hadn'tve": "hadn't've",
+ "hasnt": "hasn't",
+ "havent": "haven't",
+ "hed": "he'd",
+ "hed've": "he'd've",
+ "he'dve": "he'd've",
+ "hes": "he's",
+ "howd": "how'd",
+ "howll": "how'll",
+ "hows": "how's",
+ "Id've": "I'd've",
+ "I'dve": "I'd've",
+ "Im": "I'm",
+ "Ive": "I've",
+ "isnt": "isn't",
+ "itd": "it'd",
+ "itd've": "it'd've",
+ "it'dve": "it'd've",
+ "itll": "it'll",
+ "let's": "let's",
+ "maam": "ma'am",
+ "mightnt": "mightn't",
+ "mightnt've": "mightn't've",
+ "mightn'tve": "mightn't've",
+ "mightve": "might've",
+ "mustnt": "mustn't",
+ "mustve": "must've",
+ "neednt": "needn't",
+ "notve": "not've",
+ "oclock": "o'clock",
+ "oughtnt": "oughtn't",
+ "ow's'at": "'ow's'at",
+ "'ows'at": "'ow's'at",
+ "'ow'sat": "'ow's'at",
+ "shant": "shan't",
+ "shed've": "she'd've",
+ "she'dve": "she'd've",
+ "she's": "she's",
+ "shouldve": "should've",
+ "shouldnt": "shouldn't",
+ "shouldnt've": "shouldn't've",
+ "shouldn'tve": "shouldn't've",
+ "somebody'd": "somebodyd",
+ "somebodyd've": "somebody'd've",
+ "somebody'dve": "somebody'd've",
+ "somebodyll": "somebody'll",
+ "somebodys": "somebody's",
+ "someoned": "someone'd",
+ "someoned've": "someone'd've",
+ "someone'dve": "someone'd've",
+ "someonell": "someone'll",
+ "someones": "someone's",
+ "somethingd": "something'd",
+ "somethingd've": "something'd've",
+ "something'dve": "something'd've",
+ "somethingll": "something'll",
+ "thats": "that's",
+ "thered": "there'd",
+ "thered've": "there'd've",
+ "there'dve": "there'd've",
+ "therere": "there're",
+ "theres": "there's",
+ "theyd": "they'd",
+ "theyd've": "they'd've",
+ "they'dve": "they'd've",
+ "theyll": "they'll",
+ "theyre": "they're",
+ "theyve": "they've",
+ "twas": "'twas",
+ "wasnt": "wasn't",
+ "wed've": "we'd've",
+ "we'dve": "we'd've",
+ "weve": "we've",
+ "werent": "weren't",
+ "whatll": "what'll",
+ "whatre": "what're",
+ "whats": "what's",
+ "whatve": "what've",
+ "whens": "when's",
+ "whered": "where'd",
+ "wheres": "where's",
+ "whereve": "where've",
+ "whod": "who'd",
+ "whod've": "who'd've",
+ "who'dve": "who'd've",
+ "wholl": "who'll",
+ "whos": "who's",
+ "whove": "who've",
+ "whyll": "why'll",
+ "whyre": "why're",
+ "whys": "why's",
+ "wont": "won't",
+ "wouldve": "would've",
+ "wouldnt": "wouldn't",
+ "wouldnt've": "wouldn't've",
+ "wouldn'tve": "wouldn't've",
+ "yall": "y'all",
+ "yall'll": "y'all'll",
+ "y'allll": "y'all'll",
+ "yall'd've": "y'all'd've",
+ "y'alld've": "y'all'd've",
+ "y'all'dve": "y'all'd've",
+ "youd": "you'd",
+ "youd've": "you'd've",
+ "you'dve": "you'd've",
+ "youll": "you'll",
+ "youre": "you're",
+ "youve": "you've",
+ }
+ self.manualMap = {
+ "none": "0",
+ "zero": "0",
+ "one": "1",
+ "two": "2",
+ "three": "3",
+ "four": "4",
+ "five": "5",
+ "six": "6",
+ "seven": "7",
+ "eight": "8",
+ "nine": "9",
+ "ten": "10",
+ }
+ self.articles = ["a", "an", "the"]
+
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
+ self.commaStrip = re.compile("(\d)(,)(\d)")
+ self.punct = [
+ ";",
+ r"/",
+ "[",
+ "]",
+ '"',
+ "{",
+ "}",
+ "(",
+ ")",
+ "=",
+ "+",
+ "\\",
+ "_",
+ "-",
+ ">",
+ "<",
+ "@",
+ "`",
+ ",",
+ "?",
+ "!",
+ ]
+
+ def evaluate(self, quesIds=None):
+ if quesIds == None:
+ quesIds = [quesId for quesId in self.params["question_id"]]
+ gts = {}
+ res = {}
+ for quesId in quesIds:
+ gts[quesId] = self.vqa.qa[quesId]
+ res[quesId] = self.vqaRes.qa[quesId]
+
+ # =================================================
+ # Compute accuracy
+ # =================================================
+ accQA = []
+ accQuesType = {}
+ accAnsType = {}
+ print("computing accuracy")
+ step = 0
+ for quesId in quesIds:
+ resAns = res[quesId]["answer"]
+ resAns = resAns.replace("\n", " ")
+ resAns = resAns.replace("\t", " ")
+ resAns = resAns.strip()
+ resAns = self.processPunctuation(resAns)
+ resAns = self.processDigitArticle(resAns)
+ gtAcc = []
+ gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
+ if len(set(gtAnswers)) > 1:
+ for ansDic in gts[quesId]["answers"]:
+ ansDic["answer"] = self.processPunctuation(ansDic["answer"])
+ for gtAnsDatum in gts[quesId]["answers"]:
+ otherGTAns = [
+ item for item in gts[quesId]["answers"] if item != gtAnsDatum
+ ]
+ matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
+ acc = min(1, float(len(matchingAns)) / 3)
+ gtAcc.append(acc)
+ quesType = gts[quesId]["question_type"]
+ ansType = gts[quesId]["answer_type"]
+ avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
+ accQA.append(avgGTAcc)
+ if quesType not in accQuesType:
+ accQuesType[quesType] = []
+ accQuesType[quesType].append(avgGTAcc)
+ if ansType not in accAnsType:
+ accAnsType[ansType] = []
+ accAnsType[ansType].append(avgGTAcc)
+ self.setEvalQA(quesId, avgGTAcc)
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
+ if step % 100 == 0:
+ self.updateProgress(step / float(len(quesIds)))
+ step = step + 1
+
+ self.setAccuracy(accQA, accQuesType, accAnsType)
+ print("Done computing accuracy")
+
+ def processPunctuation(self, inText):
+ outText = inText
+ for p in self.punct:
+ if (p + " " in inText or " " + p in inText) or (
+ re.search(self.commaStrip, inText) != None
+ ):
+ outText = outText.replace(p, "")
+ else:
+ outText = outText.replace(p, " ")
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
+ return outText
+
+ def processDigitArticle(self, inText):
+ outText = []
+ tempText = inText.lower().split()
+ for word in tempText:
+ word = self.manualMap.setdefault(word, word)
+ if word not in self.articles:
+ outText.append(word)
+ else:
+ pass
+ for wordId, word in enumerate(outText):
+ if word in self.contractions:
+ outText[wordId] = self.contractions[word]
+ outText = " ".join(outText)
+ return outText
+
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
+ self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
+ self.accuracy["perQuestionType"] = {
+ quesType: round(
+ 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
+ self.n,
+ )
+ for quesType in accQuesType
+ }
+ self.accuracy["perAnswerType"] = {
+ ansType: round(
+ 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
+ )
+ for ansType in accAnsType
+ }
+
+ def setEvalQA(self, quesId, acc):
+ self.evalQA[quesId] = round(100 * acc, self.n)
+
+ def setEvalQuesType(self, quesId, quesType, acc):
+ if quesType not in self.evalQuesType:
+ self.evalQuesType[quesType] = {}
+ self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
+
+ def setEvalAnsType(self, quesId, ansType, acc):
+ if ansType not in self.evalAnsType:
+ self.evalAnsType[ansType] = {}
+ self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
+
+ def updateProgress(self, progress):
+ barLength = 20
+ status = ""
+ if isinstance(progress, int):
+ progress = float(progress)
+ if not isinstance(progress, float):
+ progress = 0
+ status = "error: progress var must be float\r\n"
+ if progress < 0:
+ progress = 0
+ status = "Halt...\r\n"
+ if progress >= 1:
+ progress = 1
+ status = "Done...\r\n"
+ block = int(round(barLength * progress))
+ text = "\rFinshed Percent: [{0}] {1}% {2}".format(
+ "#" * block + "-" * (barLength - block), int(progress * 100), status
+ )
+ sys.stdout.write(text)
+ sys.stdout.flush()
diff --git a/minigpt4/configs/.DS_Store b/minigpt4/configs/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..36141bb62b682f7fe7f49f91165de294958da267
Binary files /dev/null and b/minigpt4/configs/.DS_Store differ
diff --git a/minigpt4/configs/datasets/cc_sbu/align.yaml b/minigpt4/configs/datasets/cc_sbu/align.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5710834200fe45449d60185d467d6bcb90a98cca
--- /dev/null
+++ b/minigpt4/configs/datasets/cc_sbu/align.yaml
@@ -0,0 +1,5 @@
+datasets:
+ cc_sbu_align:
+ data_type: images
+ build_info:
+ storage: /path/to/cc_sbu_align/
diff --git a/minigpt4/configs/datasets/cc_sbu/defaults.yaml b/minigpt4/configs/datasets/cc_sbu/defaults.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60390eece551fe06a0f7c3ebb395351794b9f5f1
--- /dev/null
+++ b/minigpt4/configs/datasets/cc_sbu/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+ cc_sbu:
+ data_type: images
+ build_info:
+ storage: /path/to/cc_sbu_dataset/{00000..01255}.tar
diff --git a/minigpt4/configs/datasets/detect_mimic/detect_mimic.yaml b/minigpt4/configs/datasets/detect_mimic/detect_mimic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..38d2af4c744188eb385b40f0729e10d7af184343
--- /dev/null
+++ b/minigpt4/configs/datasets/detect_mimic/detect_mimic.yaml
@@ -0,0 +1,6 @@
+datasets:
+ detect_mimic:
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/mimic-cxr-dataset/detection_MIMIC
+ ann_path: /miniGPT-Med/json_files/MIMIC-bbox/train_detect_mimic.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/grounding_SLAKE/grounding_SLAKE.yaml b/minigpt4/configs/datasets/grounding_SLAKE/grounding_SLAKE.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..82e04fb7fb92c77b54a58fae77a01886d708e58c
--- /dev/null
+++ b/minigpt4/configs/datasets/grounding_SLAKE/grounding_SLAKE.yaml
@@ -0,0 +1,6 @@
+datasets:
+ grounding_SLAKE:
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/SLAKE_images/imgs
+ ann_path: /miniGPT-Med/json_files/SLAKE/grounding_train_SLAKE.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/grounding_nlst/grounding_nlst.yaml b/minigpt4/configs/datasets/grounding_nlst/grounding_nlst.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..56d4483155ac5fdf96132d03f59b4840d79e9a95
--- /dev/null
+++ b/minigpt4/configs/datasets/grounding_nlst/grounding_nlst.yaml
@@ -0,0 +1,6 @@
+datasets:
+ grounding_nlst:
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/NLST/NLST_images
+ ann_path: /miniGPT-Med/json_files/NLST/NLST_train.json
diff --git a/minigpt4/configs/datasets/grounding_rsna/grounding_rsna.yaml b/minigpt4/configs/datasets/grounding_rsna/grounding_rsna.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b7e30acb8fd2e1bb6cf4784c7b1128bd4ee0265d
--- /dev/null
+++ b/minigpt4/configs/datasets/grounding_rsna/grounding_rsna.yaml
@@ -0,0 +1,6 @@
+datasets:
+ grounding_rsna: ## check this
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/RSNA/RSNA-bbox-1024
+ ann_path: /miniGPT-Med/json_files/RSNA/RSNA_train.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/identify_nlst/identify_nlst.yaml b/minigpt4/configs/datasets/identify_nlst/identify_nlst.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..63ccac039d3103f0da4ac4d0fa922fed40ecf801
--- /dev/null
+++ b/minigpt4/configs/datasets/identify_nlst/identify_nlst.yaml
@@ -0,0 +1,6 @@
+datasets:
+ identify_nlst:
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/NLST/NLST_images
+ ann_path: /miniGPT-Med/json_files/NLST/NLST_train.json
diff --git a/minigpt4/configs/datasets/identify_rsna/identify_rsna.yaml b/minigpt4/configs/datasets/identify_rsna/identify_rsna.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3eaba4da276d32129ca0ea47018653eade4a472d
--- /dev/null
+++ b/minigpt4/configs/datasets/identify_rsna/identify_rsna.yaml
@@ -0,0 +1,6 @@
+datasets:
+ identify_rsna: ## check this
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/RSNA/RSNA-bbox-1024
+ ann_path: /miniGPT-Med/json_files/RSNA/RSNA_train.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/mimic_cxr/mimic_cxr.yaml b/minigpt4/configs/datasets/mimic_cxr/mimic_cxr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..61fcd052e259bf294d0befab50fb22ee656dbbc5
--- /dev/null
+++ b/minigpt4/configs/datasets/mimic_cxr/mimic_cxr.yaml
@@ -0,0 +1,6 @@
+datasets:
+ mimic_cxr:
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/mimic-cxr-dataset/image
+ ann_path: /miniGPT-Med/json_files/mimic/MIMIC_train.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/nlst/nlst.yaml b/minigpt4/configs/datasets/nlst/nlst.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..07c334f3ff962d5ba597c81a1a76c258d0cda203
--- /dev/null
+++ b/minigpt4/configs/datasets/nlst/nlst.yaml
@@ -0,0 +1,6 @@
+datasets:
+ nlst: ## check this
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/NLST/NLST_images
+ ann_path: /miniGPT-Med/json_files/NLST/NLST_train.json
diff --git a/minigpt4/configs/datasets/radvqa/radvqa.yaml b/minigpt4/configs/datasets/radvqa/radvqa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1f90cadb4f57c76fca784bcdc3c1cf0f2c4a1c87
--- /dev/null
+++ b/minigpt4/configs/datasets/radvqa/radvqa.yaml
@@ -0,0 +1,6 @@
+datasets:
+ radvqa: ## check this
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/radVQA/VQA_RAD_Images
+ ann_path: /miniGPT-Med/json_files/vqa/vqa_train.json
diff --git a/minigpt4/configs/datasets/refer_nlst/refer_nlst.yaml b/minigpt4/configs/datasets/refer_nlst/refer_nlst.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dc9cbaaad1dbb5040f31fcce912b48860c6f761c
--- /dev/null
+++ b/minigpt4/configs/datasets/refer_nlst/refer_nlst.yaml
@@ -0,0 +1,6 @@
+datasets:
+ refer_nlst:
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/NLST/NLST_images
+ ann_path: /miniGPT-Med/json_files/NLST/NLST_train.json
diff --git a/minigpt4/configs/datasets/refer_rsna/refer_rsna.yaml b/minigpt4/configs/datasets/refer_rsna/refer_rsna.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f1deda304ccab91b71c487ccd6652ed1e708268f
--- /dev/null
+++ b/minigpt4/configs/datasets/refer_rsna/refer_rsna.yaml
@@ -0,0 +1,6 @@
+datasets:
+ refer_rsna: ## check this
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/RSNA/RSNA-bbox-1024
+ ann_path: /miniGPT-Med/json_files/RSNA/RSNA_train.json
\ No newline at end of file
diff --git a/minigpt4/configs/datasets/rsna/rsna.yaml b/minigpt4/configs/datasets/rsna/rsna.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..df746befa439d38a17276b9890f66b0a207db4d9
--- /dev/null
+++ b/minigpt4/configs/datasets/rsna/rsna.yaml
@@ -0,0 +1,6 @@
+datasets:
+ rsna: ## check this
+ data_type: images
+ build_info:
+ image_path: /miniGPT-Med/RSNA/RSNA-bbox-1024
+ ann_path: /miniGPT-Med/json_files/RSNA/RSNA_train.json
diff --git a/minigpt4/configs/default.yaml b/minigpt4/configs/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ff5a6a23fa2e3914938631b96c71fdf723dbbc10
--- /dev/null
+++ b/minigpt4/configs/default.yaml
@@ -0,0 +1,5 @@
+env:
+ # For default users
+ # cache_root: "cache"
+ # For internal use with persistent storage
+ cache_root: "/export/home/.cache/minigpt4"
diff --git a/minigpt4/configs/models/minigpt_v2.yaml b/minigpt4/configs/models/minigpt_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d0dd799057b0d9cccc04366f0900c1622a1000b5
--- /dev/null
+++ b/minigpt4/configs/models/minigpt_v2.yaml
@@ -0,0 +1,31 @@
+model:
+ arch: minigpt_v2
+
+ # vit encoder
+ image_size: 448
+ drop_path_rate: 0
+ use_grad_checkpoint: False
+ vit_precision: "fp16"
+ freeze_vit: True
+
+ # generation configs
+ prompt: ""
+
+ llama_model: "/ibex/project/c2106/RadGPT/MiniGPT4-v2/llama-2-7b-chat-hf"
+ lora_r: 64
+ lora_alpha: 16
+
+
+preprocess:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ eval:
+ name: "blip2_image_eval"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+ eval:
+ name: "blip_caption"
\ No newline at end of file
diff --git a/minigpt4/conversation/.DS_Store b/minigpt4/conversation/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..38734ca2de71d90578b12a191d5ff30a57f26d5c
Binary files /dev/null and b/minigpt4/conversation/.DS_Store differ
diff --git a/minigpt4/conversation/__init__.py b/minigpt4/conversation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/minigpt4/conversation/__pycache__/__init__.cpython-310.pyc b/minigpt4/conversation/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3f66bc171f09f2368108732061b070d8e7b1f71
Binary files /dev/null and b/minigpt4/conversation/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/conversation/__pycache__/__init__.cpython-39.pyc b/minigpt4/conversation/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4cb976430df8ca1c08dfa50b997258c1ea78dbc
Binary files /dev/null and b/minigpt4/conversation/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/conversation/__pycache__/conversation.cpython-310.pyc b/minigpt4/conversation/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca51d725200e7354672909b391d0e8e2dcc53000
Binary files /dev/null and b/minigpt4/conversation/__pycache__/conversation.cpython-310.pyc differ
diff --git a/minigpt4/conversation/__pycache__/conversation.cpython-39.pyc b/minigpt4/conversation/__pycache__/conversation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..307b91db0fe00124615604edc6617a84d121f827
Binary files /dev/null and b/minigpt4/conversation/__pycache__/conversation.cpython-39.pyc differ
diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..829e034f758d66087afda14b79bbe7a27f2b3852
--- /dev/null
+++ b/minigpt4/conversation/conversation.py
@@ -0,0 +1,233 @@
+import argparse
+import time
+from threading import Thread
+from PIL import Image
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple, Any
+
+from minigpt4.common.registry import registry
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ # system_img: List[Image.Image] = []
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+
+ skip_next: bool = False
+ conv_id: Any = None
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + message + seps[i % 2]
+ else:
+ ret += role
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ # system_img=self.system_img,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ conv_id=self.conv_id)
+
+ def dict(self):
+ return {
+ "system": self.system,
+ # "system_img": self.system_img,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ "conv_id": self.conv_id,
+ }
+
+
+class StoppingCriteriaSub(StoppingCriteria):
+
+ def __init__(self, stops=[], encounters=1):
+ super().__init__()
+ self.stops = stops
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ for stop in self.stops:
+ if torch.all(input_ids[:, -len(stop):] == stop).item():
+ return True
+
+ return False
+
+
+CONV_VISION_Vicuna0 = Conversation(
+ system="Give the following image:
ImageContent. "
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
+ roles=("Human: ", "Assistant: "),
+ messages=[],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+CONV_VISION_LLama2 = Conversation(
+ system="Give the following image:
ImageContent. "
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
+ roles=("[INST] ", " [/INST] "),
+ messages=[],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="",
+)
+
+CONV_VISION_minigptv2 = Conversation(
+ system="",
+ roles=("[INST] ", " [/INST]"),
+ messages=[],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="",
+)
+
+class Chat:
+ def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
+ self.device = device
+ self.model = model
+ self.vis_processor = vis_processor
+
+ if stopping_criteria is not None:
+ self.stopping_criteria = stopping_criteria
+ else:
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+ def ask(self, text, conv):
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
+ and conv.messages[-1][1][-6:] == '': # last message is image.
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
+ else:
+ conv.append_message(conv.roles[0], text)
+
+ def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+ repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+ embs = self.model.get_context_emb(prompt, img_list)
+
+ current_max_len = embs.shape[1] + max_new_tokens
+ if current_max_len - max_length > 0:
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
+ 'The model will not see the contexts outside the range.')
+ begin_idx = max(0, current_max_len - max_length)
+ embs = embs[:, begin_idx:]
+
+ generation_kwargs = dict(
+ inputs_embeds=embs,
+ max_new_tokens=max_new_tokens,
+ stopping_criteria=self.stopping_criteria,
+ num_beams=num_beams,
+ do_sample=True,
+ min_length=min_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty,
+ temperature=float(temperature),
+ )
+ return generation_kwargs
+
+ def answer(self, conv, img_list, **kargs):
+ generation_dict = self.answer_prepare(conv, img_list, **kargs)
+ output_token = self.model_generate(**generation_dict)[0]
+ output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
+
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
+ output_text = output_text.split('Assistant:')[-1].strip()
+
+ conv.messages[-1][1] = output_text
+ return output_text, output_token.cpu().numpy()
+
+ def stream_answer(self, conv, img_list, **kargs):
+ generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
+ streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
+ generation_kwargs['streamer'] = streamer
+ thread = Thread(target=self.model_generate, kwargs=generation_kwargs)
+ thread.start()
+ return streamer
+
+ def model_generate(self, *args, **kwargs):
+ # for 8 bit and 16 bit compatibility
+ with self.model.maybe_autocast():
+ output = self.model.llama_model.generate(*args, **kwargs)
+ return output
+
+ def encode_img(self, img_list):
+ image = img_list[0]
+ img_list.pop(0)
+ if isinstance(image, str): # is a image path
+ raw_image = Image.open(image).convert('RGB')
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
+ elif isinstance(image, Image.Image):
+ raw_image = image
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
+ elif isinstance(image, torch.Tensor):
+ if len(image.shape) == 3:
+ image = image.unsqueeze(0)
+ image = image.to(self.device)
+
+ image_emb, _ = self.model.encode_img(image)
+ img_list.append(image_emb)
+
+ def upload_img(self, image, conv, img_list):
+ conv.append_message(conv.roles[0], "
")
+ img_list.append(image)
+ msg = "Received."
+
+ return msg
+
diff --git a/minigpt4/datasets/.DS_Store b/minigpt4/datasets/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..79fa52a65174964ea080fb0679ca9b672b086660
Binary files /dev/null and b/minigpt4/datasets/.DS_Store differ
diff --git a/minigpt4/datasets/__init__.py b/minigpt4/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/minigpt4/datasets/__pycache__/__init__.cpython-310.pyc b/minigpt4/datasets/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..113ed8f3c085dd3d3ad80f79203b8098a7a8b73a
Binary files /dev/null and b/minigpt4/datasets/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/datasets/__pycache__/__init__.cpython-38.pyc b/minigpt4/datasets/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c978bde7a6bb80704611195c561306741448cfc
Binary files /dev/null and b/minigpt4/datasets/__pycache__/__init__.cpython-38.pyc differ
diff --git a/minigpt4/datasets/__pycache__/__init__.cpython-39.pyc b/minigpt4/datasets/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4ac468e8b50d52c84c813917b6e98731f71b013
Binary files /dev/null and b/minigpt4/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/datasets/__pycache__/data_utils.cpython-310.pyc b/minigpt4/datasets/__pycache__/data_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bfce5dcf8e7ef43a1653ee79c66ce7c68abfcd8
Binary files /dev/null and b/minigpt4/datasets/__pycache__/data_utils.cpython-310.pyc differ
diff --git a/minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc b/minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bd9de37ff2f5813f363caade1784b70b1647d9e
Binary files /dev/null and b/minigpt4/datasets/__pycache__/data_utils.cpython-39.pyc differ
diff --git a/minigpt4/datasets/builders/.DS_Store b/minigpt4/datasets/builders/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..38734ca2de71d90578b12a191d5ff30a57f26d5c
Binary files /dev/null and b/minigpt4/datasets/builders/.DS_Store differ
diff --git a/minigpt4/datasets/builders/__init__.py b/minigpt4/datasets/builders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51f171c6c868903a721e30a68f4e3b2e960cbbad
--- /dev/null
+++ b/minigpt4/datasets/builders/__init__.py
@@ -0,0 +1,88 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
+from minigpt4.datasets.builders.image_text_pair_builder import (
+ MimicCxrBuilder,
+ RadVQABuilder,
+ RSNABuilder,
+ ReferRSNABuilder,
+ IdentifyRSNABuilder,
+ NlstBuilder,
+ ReferNLSTBuilder,
+ IdentifyNLSTBuilder,
+ GroundingSLAKEBuilder,
+ # DetectMIMICBuilder,
+
+)
+
+from minigpt4.common.registry import registry
+
+__all__ = [
+ 'MimicCxrBuilder',
+ "RadVQABuilder",
+ "RSNABuilder",
+ "ReferRSNABuilder",
+ "IdentifyRSNABuilder",
+ "NlstBuilder",
+ "ReferNLSTBuilder",
+ "IdentifyNLSTBuilder",
+ "GroundingSLAKEBuilder",
+ # "DetectMIMICBuilder",
+]
+
+
+def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
+ """
+ Example
+
+ >>> dataset = load_dataset("coco_caption", cfg=None)
+ >>> splits = dataset.keys()
+ >>> print([len(dataset[split]) for split in splits])
+
+ """
+ if cfg_path is None:
+ cfg = None
+ else:
+ cfg = load_dataset_config(cfg_path)
+
+ try:
+ builder = registry.get_builder_class(name)(cfg)
+ except TypeError:
+ print(
+ f"Dataset {name} not found. Available datasets:\n"
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
+ )
+ exit(1)
+
+ if vis_path is not None:
+ if data_type is None:
+ # use default data type in the config
+ data_type = builder.config.data_type
+
+ assert (
+ data_type in builder.config.build_info
+ ), f"Invalid data_type {data_type} for {name}."
+
+ builder.config.build_info.get(data_type).storage = vis_path
+
+ dataset = builder.build_datasets()
+ return dataset
+
+
+class DatasetZoo:
+ def __init__(self) -> None:
+ self.dataset_zoo = {
+ k: list(v.DATASET_CONFIG_DICT.keys())
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
+ }
+
+ def get_names(self):
+ return list(self.dataset_zoo.keys())
+
+
+dataset_zoo = DatasetZoo()
diff --git a/minigpt4/datasets/builders/__pycache__/__init__.cpython-310.pyc b/minigpt4/datasets/builders/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bd33d4f7bc0b5dfc4f9674f1021196654d1f41b
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/__init__.cpython-38.pyc b/minigpt4/datasets/builders/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9192e145a2f0fa2e645269e18d43b917ed9b493c
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/__init__.cpython-38.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc b/minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..809abb87e61713ecdbc955be70a5695663dfb6fc
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc b/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7c071095464347bbf43856f84d7f460f6d292ab
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-38.pyc b/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d55c1299aca16bd1e9d671df396ea2513db53a51
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-38.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc b/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6379d2047c95676c5f7a977fdd59aa7e8826330
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-310.pyc b/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..503805df95c1d89e2a22020f2fe8b1680dceb605
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-310.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc b/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..234139af7163838b8a0ba4758716adf20f670afa
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc differ
diff --git a/minigpt4/datasets/builders/__pycache__/vqa_builder.cpython-39.pyc b/minigpt4/datasets/builders/__pycache__/vqa_builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3485cefc8179ce3c2b7e77bc0cc72e3273b4bc72
Binary files /dev/null and b/minigpt4/datasets/builders/__pycache__/vqa_builder.cpython-39.pyc differ
diff --git a/minigpt4/datasets/builders/base_dataset_builder.py b/minigpt4/datasets/builders/base_dataset_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b607e3c0a8abaa6b1ccbc711e27ff3755f5ec11
--- /dev/null
+++ b/minigpt4/datasets/builders/base_dataset_builder.py
@@ -0,0 +1,236 @@
+"""
+ This file is from
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+import shutil
+import warnings
+
+from omegaconf import OmegaConf
+import torch.distributed as dist
+from torchvision.datasets.utils import download_url
+
+import minigpt4.common.utils as utils
+from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
+from minigpt4.common.registry import registry
+from minigpt4.processors.base_processor import BaseProcessor
+
+
+
+class BaseDatasetBuilder:
+ train_dataset_cls, eval_dataset_cls = None, None
+
+ def __init__(self, cfg=None):
+ super().__init__()
+
+ if cfg is None:
+ # help to create datasets from default config.
+ self.config = load_dataset_config(self.default_config_path())
+ elif isinstance(cfg, str):
+ self.config = load_dataset_config(cfg)
+ else:
+ # when called from task.build_dataset()
+ self.config = cfg
+
+ self.data_type = self.config.data_type
+
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+
+ def build_datasets(self):
+ # download, split, etc...
+ # only called on 1 GPU/TPU in distributed
+
+ if is_main_process():
+ self._download_data()
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ datasets = self.build() # dataset['train'/'val'/'test']
+
+ return datasets
+
+ def build_processors(self):
+ vis_proc_cfg = self.config.get("vis_processor")
+ txt_proc_cfg = self.config.get("text_processor")
+
+ if vis_proc_cfg is not None:
+ vis_train_cfg = vis_proc_cfg.get("train")
+ vis_eval_cfg = vis_proc_cfg.get("eval")
+
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
+
+ if txt_proc_cfg is not None:
+ txt_train_cfg = txt_proc_cfg.get("train")
+ txt_eval_cfg = txt_proc_cfg.get("eval")
+
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
+
+ @staticmethod
+ def _build_proc_from_cfg(cfg):
+ return (
+ registry.get_processor_class(cfg.name).from_config(cfg)
+ if cfg is not None
+ else None
+ )
+
+ @classmethod
+ def default_config_path(cls, type="default"):
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
+
+ def _download_data(self):
+ self._download_ann()
+ self._download_vis()
+
+ def _download_ann(self):
+ """
+ Download annotation files if necessary.
+ All the vision-language datasets should have annotations of unified format.
+
+ storage_path can be:
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
+
+ Local annotation paths should be relative.
+ """
+ anns = self.config.build_info.annotations
+
+ splits = anns.keys()
+
+ cache_root = registry.get_path("cache_root")
+
+ for split in splits:
+ info = anns[split]
+
+ urls, storage_paths = info.get("url", None), info.storage
+
+ if isinstance(urls, str):
+ urls = [urls]
+ if isinstance(storage_paths, str):
+ storage_paths = [storage_paths]
+
+ assert len(urls) == len(storage_paths)
+
+ for url_or_filename, storage_path in zip(urls, storage_paths):
+ # if storage_path is relative, make it full by prefixing with cache_root.
+ if not os.path.isabs(storage_path):
+ storage_path = os.path.join(cache_root, storage_path)
+
+ dirname = os.path.dirname(storage_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ if os.path.isfile(url_or_filename):
+ src, dst = url_or_filename, storage_path
+ if not os.path.exists(dst):
+ shutil.copyfile(src=src, dst=dst)
+ else:
+ logging.info("Using existing file {}.".format(dst))
+ else:
+ if os.path.isdir(storage_path):
+ # if only dirname is provided, suffix with basename of URL.
+ raise ValueError(
+ "Expecting storage_path to be a file path, got directory {}".format(
+ storage_path
+ )
+ )
+ else:
+ filename = os.path.basename(storage_path)
+
+ download_url(url=url_or_filename, root=dirname, filename=filename)
+
+ def _download_vis(self):
+
+ storage_path = self.config.build_info.get(self.data_type).storage
+ storage_path = utils.get_cache_path(storage_path)
+
+ if not os.path.exists(storage_path):
+ warnings.warn(
+ f"""
+ The specified path {storage_path} for visual inputs does not exist.
+ Please provide a correct path to the visual inputs or
+ refer to datasets/download_scripts/README.md for downloading instructions.
+ """
+ )
+
+ def build(self):
+ """
+ Create by split datasets inheriting torch.utils.data.Datasets.
+
+ # build() can be dataset-specific. Overwrite to customize.
+ """
+ self.build_processors()
+
+ build_info = self.config.build_info
+
+ ann_info = build_info.annotations
+ vis_info = build_info.get(self.data_type)
+
+ datasets = dict()
+ for split in ann_info.keys():
+ if split not in ["train", "val", "test"]:
+ continue
+
+ is_train = split == "train"
+
+ # processors
+ vis_processor = (
+ self.vis_processors["train"]
+ if is_train
+ else self.vis_processors["eval"]
+ )
+ text_processor = (
+ self.text_processors["train"]
+ if is_train
+ else self.text_processors["eval"]
+ )
+
+ # annotation path
+ ann_paths = ann_info.get(split).storage
+ if isinstance(ann_paths, str):
+ ann_paths = [ann_paths]
+
+ abs_ann_paths = []
+ for ann_path in ann_paths:
+ if not os.path.isabs(ann_path):
+ ann_path = utils.get_cache_path(ann_path)
+ abs_ann_paths.append(ann_path)
+ ann_paths = abs_ann_paths
+
+ # visual data storage path
+ vis_path = os.path.join(vis_info.storage, split)
+
+ if not os.path.isabs(vis_path):
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
+ vis_path = utils.get_cache_path(vis_path)
+
+ if not os.path.exists(vis_path):
+ warnings.warn("storage path {} does not exist.".format(vis_path))
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
+ datasets[split] = dataset_cls(
+ vis_processor=vis_processor,
+ text_processor=text_processor,
+ ann_paths=ann_paths,
+ vis_root=vis_path,
+ )
+
+ return datasets
+
+
+def load_dataset_config(cfg_path):
+ cfg = OmegaConf.load(cfg_path).datasets
+ cfg = cfg[list(cfg.keys())[0]]
+
+ return cfg
diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..725b5ca69bb5021a88e42a58ed52e2cf1dab3ca4
--- /dev/null
+++ b/minigpt4/datasets/builders/image_text_pair_builder.py
@@ -0,0 +1,281 @@
+import os
+import logging
+import warnings
+
+from minigpt4.common.registry import registry
+from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
+from minigpt4.datasets.datasets.mimic_cxr_dataset import MimicCxrDataset
+from minigpt4.datasets.datasets.radvqa_dataset import RadVQADataset
+from minigpt4.datasets.datasets.rsna_dataset import RSNADataset,ReferRSNADataset,IdentifyRSNADataset
+from minigpt4.datasets.datasets.nlst_dataset import NlstDataset,ReferNLSTDataset,IdentifyNLSTDataset
+from minigpt4.datasets.datasets.SLAKE_dataset import GroundingSLAKEDatase
+
+@registry.register_builder("cc_sbu_align")
+class CCSBUAlignBuilder(BaseDatasetBuilder):
+ train_dataset_cls = CCSBUAlignDataset
+
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/cc_sbu/align.yaml",
+ }
+
+ def build_datasets(self):
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+ logging.info("Building datasets...")
+ self.build_processors()
+
+ build_info = self.config.build_info
+ storage_path = build_info.storage
+
+ datasets = dict()
+
+ if not os.path.exists(storage_path):
+ warnings.warn("storage path {} does not exist.".format(storage_path))
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
+ vis_root=os.path.join(storage_path, 'image'),
+ )
+
+ return datasets
+
+@registry.register_builder("mimic_cxr")
+class MimicCxrBuilder(BaseDatasetBuilder):
+ train_dataset_cls = MimicCxrDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/mimic_cxr/mimic_cxr.yaml",
+ }
+
+ def build_datasets(self):
+ logging.info("Building MIMIC dataset...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+
+ return datasets
+
+@registry.register_builder("radvqa")
+class RadVQABuilder(BaseDatasetBuilder):
+ train_dataset_cls = RadVQADataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/radvqa/radvqa.yaml",
+ }
+ def build_datasets(self):
+ logging.info("Building RADVQA datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+
+ return datasets
+
+@registry.register_builder("rsna")
+class RSNABuilder(BaseDatasetBuilder):
+ train_dataset_cls = RSNADataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/rsna/rsna.yaml",
+ }
+ def build_datasets(self):
+ logging.info("Building RSNA dataset...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+ return datasets
+
+@registry.register_builder("refer_rsna")
+class ReferRSNABuilder(BaseDatasetBuilder):
+ train_dataset_cls = ReferRSNADataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/refer_rsna/refer_rsna.yaml",
+ }
+
+ def build_datasets(self):
+ logging.info("Building [refer] RSNA datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+ return datasets
+
+@registry.register_builder("identify_rsna")
+class IdentifyRSNABuilder(BaseDatasetBuilder):
+ train_dataset_cls = IdentifyRSNADataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/identify_rsna/identify_rsna.yaml",
+ }
+ def build_datasets(self):
+ logging.info("Building [identify] RSNA dataset...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+ return datasets
+
+
+@registry.register_builder("nlst")
+class NlstBuilder(BaseDatasetBuilder):
+ train_dataset_cls = NlstDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/nlst/nlst.yaml",
+ }
+ def build_datasets(self):
+ logging.info("Building NLST dataset...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+@registry.register_builder("refer_nlst")
+class ReferNLSTBuilder(BaseDatasetBuilder):
+ train_dataset_cls = NlstDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/refer_nlst/refer_nlst.yaml",
+ }
+ def build_datasets(self):
+ logging.info("Building [refer] NLST dataset...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+@registry.register_builder("identify_nlst")
+class IdentifyNLSTBuilder(BaseDatasetBuilder):
+ train_dataset_cls = NlstDataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/identify_nlst/identify_nlst.yaml",
+ }
+ def build_datasets(self):
+ logging.info("Building [identify] NLST dataset...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+@registry.register_builder("grounding_SLAKE")
+class GroundingSLAKEBuilder(BaseDatasetBuilder):
+ train_dataset_cls = GroundingSLAKEDatase
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/grounding_SLAKE/grounding_SLAKE.yaml",
+ }
+
+ def build_datasets(self):
+ logging.info("Building [grounding] NLST dataset...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ dataset_cls = self.train_dataset_cls
+
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors['train'],
+ text_processor=self.text_processors['train'],
+ ann_path=build_info.ann_path,
+ vis_root=build_info.image_path,
+ )
+
+ return datasets
+
+
+
+# @registry.register_builder("detect_mimic")
+# class DetectMIMICBuilder(BaseDatasetBuilder):
+# train_dataset_cls = Detect_MIMIC
+# DATASET_CONFIG_DICT = {
+# "default": "configs/datasets/detect_mimic/detect_mimic.yaml",
+# }
+# def build_datasets(self):
+# logging.info("Building NLST dataset...")
+# self.build_processors()
+# build_info = self.config.build_info
+# datasets = dict()
+
+# dataset_cls = self.train_dataset_cls
+
+# datasets['train'] = dataset_cls(
+# vis_processor=self.vis_processors['train'],
+# text_processor=self.text_processors['train'],
+# ann_path=build_info.ann_path,
+# vis_root=build_info.image_path,
+# )
+
+# return datasets
+
+
diff --git a/minigpt4/datasets/data_utils.py b/minigpt4/datasets/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..773b10facf26e89f71db6f7841a0377f93f1a2a9
--- /dev/null
+++ b/minigpt4/datasets/data_utils.py
@@ -0,0 +1,199 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import gzip
+import logging
+import os
+import random as rnd
+import tarfile
+import zipfile
+import random
+from typing import List
+from tqdm import tqdm
+
+import decord
+from decord import VideoReader
+import webdataset as wds
+import numpy as np
+import torch
+from torch.utils.data.dataset import IterableDataset
+
+from minigpt4.common.registry import registry
+from minigpt4.datasets.datasets.base_dataset import ConcatDataset
+
+
+decord.bridge.set_bridge("torch")
+MAX_INT = registry.get("MAX_INT")
+
+
+class ChainDataset(wds.DataPipeline):
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
+
+ This class is useful to assemble different existing dataset streams. The
+ chaining operation is done on-the-fly, so concatenating large-scale
+ datasets with this class will be efficient.
+
+ Args:
+ datasets (iterable of IterableDataset): datasets to be chained together
+ """
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
+ super().__init__()
+ self.datasets = datasets
+ self.prob = []
+ self.names = []
+ for dataset in self.datasets:
+ if hasattr(dataset, 'name'):
+ self.names.append(dataset.name)
+ else:
+ self.names.append('Unknown')
+ if hasattr(dataset, 'sample_ratio'):
+ self.prob.append(dataset.sample_ratio)
+ else:
+ self.prob.append(1)
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
+
+ def __iter__(self):
+ datastreams = [iter(dataset) for dataset in self.datasets]
+ while True:
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
+ yield next(select_datastream)
+
+
+def apply_to_sample(f, sample):
+ if len(sample) == 0:
+ return {}
+
+ def _apply(x):
+ if torch.is_tensor(x):
+ return f(x)
+ elif isinstance(x, dict):
+ return {key: _apply(value) for key, value in x.items()}
+ elif isinstance(x, list):
+ return [_apply(x) for x in x]
+ else:
+ return x
+
+ return _apply(sample)
+
+
+def move_to_cuda(sample):
+ def _move_to_cuda(tensor):
+ return tensor.cuda()
+
+ return apply_to_sample(_move_to_cuda, sample)
+
+
+def prepare_sample(samples, cuda_enabled=True):
+ if cuda_enabled:
+ samples = move_to_cuda(samples)
+
+ # TODO fp16 support
+
+ return samples
+
+
+def reorg_datasets_by_split(datasets, batch_sizes):
+ """
+ Organizes datasets by split.
+
+ Args:
+ datasets: dict of torch.utils.data.Dataset objects by name.
+
+ Returns:
+ Dict of datasets by split {split_name: List[Datasets]}.
+ """
+ # if len(datasets) == 1:
+ # return datasets[list(datasets.keys())[0]]
+ # else:
+ reorg_datasets = dict()
+ reorg_batch_sizes = dict()
+
+ # reorganize by split
+ for dataset_name, dataset in datasets.items():
+ for split_name, dataset_split in dataset.items():
+ if split_name not in reorg_datasets:
+ reorg_datasets[split_name] = [dataset_split]
+ reorg_batch_sizes[split_name] = [batch_sizes[dataset_name]]
+ else:
+ reorg_datasets[split_name].append(dataset_split)
+ reorg_batch_sizes[split_name].append(batch_sizes[dataset_name])
+
+ return reorg_datasets, reorg_batch_sizes
+
+
+def concat_datasets(datasets):
+ """
+ Concatenates multiple datasets into a single dataset.
+
+ It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
+ generic IterableDataset because it requires creating separate samplers.
+
+ Now only supports conctenating training datasets and assuming validation and testing
+ have only a single dataset. This is because metrics should not be computed on the concatenated
+ datasets.
+
+ Args:
+ datasets: dict of torch.utils.data.Dataset objects by split.
+
+ Returns:
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
+ "val" and "test" remain the same.
+
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
+ a tuple, where the first element is a concatenated map-style dataset and the second
+ element is a chained DataPipeline dataset.
+
+ """
+ # concatenate datasets in the same split
+ for split_name in datasets:
+ if split_name != "train":
+ assert (
+ len(datasets[split_name]) == 1
+ ), "Do not support multiple {} datasets.".format(split_name)
+ datasets[split_name] = datasets[split_name][0]
+ else:
+ iterable_datasets, map_datasets = [], []
+ for dataset in datasets[split_name]:
+ if isinstance(dataset, wds.DataPipeline):
+ logging.info(
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
+ dataset
+ )
+ )
+ iterable_datasets.append(dataset)
+ elif isinstance(dataset, IterableDataset):
+ raise NotImplementedError(
+ "Do not support concatenation of generic IterableDataset."
+ )
+ else:
+ map_datasets.append(dataset)
+
+ # if len(iterable_datasets) > 0:
+ # concatenate map-style datasets and iterable-style datasets separately
+ if len(iterable_datasets) > 1:
+ chained_datasets = (
+ ChainDataset(iterable_datasets)
+ )
+ elif len(iterable_datasets) == 1:
+ chained_datasets = iterable_datasets[0]
+ else:
+ chained_datasets = None
+
+ concat_datasets = (
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+ )
+
+ train_datasets = concat_datasets, chained_datasets
+ train_datasets = tuple([x for x in train_datasets if x is not None])
+ train_datasets = (
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
+ )
+
+ datasets[split_name] = train_datasets
+
+ return datasets
+
diff --git a/minigpt4/datasets/datasets/MS-CXR.py b/minigpt4/datasets/datasets/MS-CXR.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e0021f59e5fcf07eaff29faeab89cce454c63d
--- /dev/null
+++ b/minigpt4/datasets/datasets/MS-CXR.py
@@ -0,0 +1,172 @@
+import json
+import os
+import random
+from PIL import Image
+from torch.utils.data import Dataset
+
+class MS_CXRDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ self.vis_root = vis_root
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ self.original_size = 1024
+ self.image_size = 100
+ self.instruction_pool = ['[detection] pneumonia']
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ return self.bbox_phrase_preprocess(index)
+
+ def prepare_image_and_annotations(self, info):
+ image = self.process_image(info["key"])
+ bboxs, ref_phrases = self.generate_bboxs_and_phrases(info)
+ return image, bboxs, ref_phrases
+
+ def process_image(self, image_file):
+ image_path = os.path.join(self.vis_root, image_file)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ return self.vis_processor(image)
+
+ def generate_bboxs_and_phrases(self, info):
+ bboxs, ref_phrases = [], []
+ for bbox in info["bbox"]:
+ scaled_bbox = self.scale_bbox(*bbox)
+ self.assert_bbox_in_range(*scaled_bbox)
+ ref_phrases.append("pneumonia")
+ bboxs.append(f"{{<{scaled_bbox[0]}><{scaled_bbox[1]}><{scaled_bbox[2]}><{scaled_bbox[3]}>}}")
+ return bboxs, ref_phrases
+
+ def scale_bbox(self, x1, y1, x2, y2):
+ scale = lambda x: int((x / self.original_size) * self.image_size)
+ return scale(x1), scale(y1), scale(x2), scale(y2)
+
+ def assert_bbox_in_range(self, x1, y1, x2, y2):
+ for coord in [x1, y1, x2, y2]:
+ assert 0 <= coord <= self.image_size, f"{coord} out of range"
+
+ def generate_caption(self, phrases, bounding_boxes):
+ phrase_bbox={}
+ for phrase, bbox in zip(phrases, bounding_boxes):
+ if phrase not in phrase_bbox.keys():
+ generated_phrase = "{}
".format(phrase)
+ generated_phrase_bbox = generated_phrase+str(bbox)
+ else:
+ generated_phrase = phrase_bbox[phrase]
+ generated_phrase_bbox = generated_phrase+""+str(bbox)
+ phrase_bbox[phrase] = generated_phrase_bbox
+ generated_caption= ' '.join(phrase_bbox.values())
+ return generated_caption
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image, bboxs, ref_phrases = self.prepare_image_and_annotations(info)
+
+ generated_caption = self.generate_caption(ref_phrases, bboxs)
+ instruction = f'[INST]
{self.instruction_pool[0]} [/INST]'
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": generated_caption,
+ "image_id": info['key'],
+ }
+
+class ReferMS_CXRDataset(MS_CXRDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path, dataset='refcoco', splitBy='unc'):
+ super().__init__(vis_processor, text_processor, vis_root, ann_path)
+ self.instruction_pool = [
+ "[refer] pneumonia"
+ "[refer] give me the location of pneumonia",
+ "[refer] where is pneumonia ?",
+ "[refer] from this image, tell me the location of pneumonia",
+ "[refer] the location of pneumonia is ",
+ "[refer] could you tell me the location for pneumonia ?",
+ "[refer] where can I locate the pneumonia ?",
+ ]
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image, bboxs, ref_phrases = self.prepare_image_and_annotations(info)
+
+ generated_caption = self.generate_caption(ref_phrases, bboxs)
+ instruction = '[INST]
{} [/INST]'.format(random.choice(self.instruction_pool))
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": generated_caption,
+ "image_id": info['key'],
+ }
+
+class IdentifyMS_CXRDataset(MS_CXRDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ super().__init__(vis_processor, text_processor, vis_root, ann_path)
+ self.instruction_pool = [
+ "[identify] {}",
+ "[identify] what object is in this location {}",
+ "[identify] identify the object present at this location {}",
+ "[identify] what is it in {}",
+ "[identify] describe this object in {}",
+ "[identify] this {} is",
+ "[identify] the object in {} is",
+ ]
+
+ def generate_boxes(self, phrases, bounding_boxes):
+ phrase_bbox = {}
+ for phrase, bbox in zip(phrases, bounding_boxes):
+ if phrase not in phrase_bbox:
+ grounded_bbox = str(bbox)
+ else:
+ grounded_bbox = phrase_bbox[phrase] + "" + str(bbox)
+ phrase_bbox[phrase] = grounded_bbox
+
+ ground_boxes = ' '.join(phrase_bbox.values())
+ return ground_boxes
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image = self.process_image(info['key'])
+ ref_exps = info["bbox"]
+ caption = info["rephrased_caption"]
+ bboxs, ref_phrases = self.generate_bboxs_and_phrases(info)
+ identify_boxes = self.generate_boxes(ref_phrases,bboxs)
+ instruction = random.choice(self.instruction_pool).format(identify_boxes)
+ instruction = f'[INST]
{instruction} [/INST]'
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": caption,
+ "image_id": info['key'],
+ }
+
+
+class evalMS_CXRDataset(Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = data['key']
+ sent = data['objects']
+ image_path = os.path.join(self.root_path, img_id)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+ question = "[detection] pneumonia"
+
+ return image, question, img_id
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/SLAKE_dataset.py b/minigpt4/datasets/datasets/SLAKE_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..206021ad4b5a0261d2c640fc2929da524d971656
--- /dev/null
+++ b/minigpt4/datasets/datasets/SLAKE_dataset.py
@@ -0,0 +1,71 @@
+import json
+import os
+import random
+from PIL import Image
+from torch.utils.data import Dataset
+
+class GroundingSLAKEDatase(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+
+ self.vis_root = vis_root
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.instruction_pool = [
+ '[grounding] please describe this image in details',
+ '[grounding] describe this image as detailed as possible',
+ '[grounding] summarize this image in details',
+ '[grounding] give a thorough description of what you see in this image',
+ ]
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+
+ image_file = info['folder_name']
+ image_path = os.path.join(self.vis_root, image_file)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+
+ answer = info['grounded_caption']
+
+ instruction = random.choice(self.instruction_pool)
+
+ instruction = "
{} ".format(instruction)
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": info['folder_name'],
+ }
+
+
+class evalSLAKEDataset(Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = data['folder_name']
+ # sent = data['objects']
+ image_path = os.path.join(self.root_path, img_id)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+ question = "[grounding] please describe this image in details"
+
+ return image, question, img_id
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/__init__.py b/minigpt4/datasets/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/minigpt4/datasets/datasets/__pycache__/SLAKE_dataset.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/SLAKE_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a8d0abce934930e0b9e914930f9e45e6e77c1270
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/SLAKE_dataset.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/SLAKE_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/SLAKE_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb540b4098a6f51af7e6ac5cbff5f7398cd1ed9f
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/SLAKE_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/__init__.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c2e8632e6b1bfb70f9c868864281496c377df9e
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9601b15ec10942e8e64063d6e9942ac21d69c46f
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/aok_vqa_datasets.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/aok_vqa_datasets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..998765b1ff71c56278febf3d4f7278e1d54b9939
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/aok_vqa_datasets.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/aok_vqa_reasoning_datasets.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/aok_vqa_reasoning_datasets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea9d9a12e37ecd45f464e4156395f1c4c169d505
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/aok_vqa_reasoning_datasets.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6ea57960154ed9440c5a577dab342e90535aa271
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1aa80367c3a888eae34bb685c1c0fd0cca90359a
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/base_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c4ba92f3e7744427cb468dcaee1d14ef9f52ccd
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0185a767f002f3839bd7aedfb9e67550d0f044d0
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/caption_datasets.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/caption_reasoning.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/caption_reasoning.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea619405f121a6ccc6c4f57352569d3f6f8cbf54
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/caption_reasoning.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e114be2e623aa31e880dbbeeaed07f1bf47e2fa3
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96bcc6b1051e677c511dcfcb3cc22b21d93f08df
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/cc_sbu_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/coco_caption.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/coco_caption.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eff7d6ab588df83e2b82e9583b97b407408bae76
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/coco_caption.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/coco_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/coco_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d1b91e0e7114076b61f0a9e50e5daed011c3026
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/coco_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/coco_vqa_datasets.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/coco_vqa_datasets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d81693294c35a0fb71d8da37d93f2010adf185b5
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/coco_vqa_datasets.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/cot.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/cot.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..349d1ec00374dc9c49b6e333cd341713601df6fc
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/cot.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/coyo_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/coyo_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74689dc963eec6b2d21a60c224cdcfee26179c28
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/coyo_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31afc9c5b1eca607e688579344c25c966c55d128
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80f2d4e78f8209a558c1beca2ba41762c1c357a6
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/dataloader_utils.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/doc_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/doc_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8ca3ddf9e2c1f3ba860da1240de2756644ce89c
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/doc_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/gqa_datasets.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/gqa_datasets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..259b9e8e8f79521cda852d27845a2b7086297301
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/gqa_datasets.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/grounded_detailed_image_caption_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/grounded_detailed_image_caption_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9beb5c7a858f28b5489bfd62f6507fab10d54426
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/grounded_detailed_image_caption_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0dc8ac14c0ad216ea3563f767eb542e2458a7f2
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/laion_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/llava_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/llava_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..272b24066280f9e5057c6349c58cf6b2d0ef1ff8
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/llava_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/locna_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/locna_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..44129ad15c85c780c9ec207eac6a1cfa8ef9f268
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/locna_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/luna_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/luna_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..906749468c1fb83908fdd74bf57f404a3d32031a
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/luna_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/lvis_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/lvis_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aba2e1bcda2f87bcf6c5989f7a372e9a4429d8d5
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/lvis_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/mammogram_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/mammogram_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..21995c33ac1fed687e8912642b48ac99d6cf1b25
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/mammogram_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/mimic_cxr_dataset.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/mimic_cxr_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98a98faa66507bdb6f23403289512492936e3da2
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/mimic_cxr_dataset.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/mimic_cxr_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/mimic_cxr_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9729c7a4c4b5c3bf0a0dbb9983594c4899fce6b
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/mimic_cxr_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/nlst_dataset.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/nlst_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c1f9291d9a711417edfb223fa4ae21003b738ab
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/nlst_dataset.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/nlst_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/nlst_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8da006f9525d60f409048ab34cd55977ded75c1
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/nlst_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/open_images.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/open_images.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..683ce3b9ddcd907c0eadb7ede46fa67cb80542ab
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/open_images.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/paint_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/paint_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a57fd5c2286586ed2e34a8198719b929fa4763f
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/paint_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/radvqa_dataset.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/radvqa_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f401705fa6e6b8581d1583260e694e99518bfd2b
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/radvqa_dataset.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/radvqa_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/radvqa_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31df15e4e42749fdc69868f249e0c3f96fc9f89c
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/radvqa_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/reasoning_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/reasoning_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..102456c1fbfe847818811ead9ea30a018752b25e
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/reasoning_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/rsna_dataset.cpython-310.pyc b/minigpt4/datasets/datasets/__pycache__/rsna_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0010d4406ac7bada92faebe42c56ee018ee34342
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/rsna_dataset.cpython-310.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/rsna_dataset.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/rsna_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5580c3362bf03c338fb225e8c0ac5e63035c03b
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/rsna_dataset.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/text_caps.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/text_caps.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aab79703ff5079ac29a423235eb1aa76ec94d4ef
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/text_caps.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/unnatural_instruction.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/unnatural_instruction.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07d4aa4d87424e2fbfde1d678665629c8e52ae79
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/unnatural_instruction.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/video_datasets.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/video_datasets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..010dc2ea793fa81576cd80a95c774f26426b2007
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/video_datasets.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/__pycache__/vqa_datasets.cpython-39.pyc b/minigpt4/datasets/datasets/__pycache__/vqa_datasets.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d35a774fe27f362490dfc7f92053750316e35aeb
Binary files /dev/null and b/minigpt4/datasets/datasets/__pycache__/vqa_datasets.cpython-39.pyc differ
diff --git a/minigpt4/datasets/datasets/base_dataset.py b/minigpt4/datasets/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aebd87d6b5229d8560bc7a01bde0aa9b1c6fb63
--- /dev/null
+++ b/minigpt4/datasets/datasets/base_dataset.py
@@ -0,0 +1,74 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import json
+from typing import Iterable
+
+from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data.dataloader import default_collate
+
+
+class BaseDataset(Dataset):
+ def __init__(
+ self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
+ ):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ self.vis_root = vis_root
+
+ self.annotation = []
+ # print("ann paths", ann_paths)
+ for ann_path in ann_paths:
+ # print("ann_path", ann_path)
+ ann = json.load(open(ann_path, "r"))
+ if isinstance(ann, dict):
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
+ # self.annotation.extend(json.load(open(ann_path, "r")))
+ else:
+ self.annotation.extend(json.load(open(ann_path, "r")))
+
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self._add_instance_ids()
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def collater(self, samples):
+ return default_collate(samples)
+
+ def set_processors(self, vis_processor, text_processor):
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ def _add_instance_ids(self, key="instance_id"):
+ for idx, ann in enumerate(self.annotation):
+ ann[key] = str(idx)
+
+
+class ConcatDataset(ConcatDataset):
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
+ super().__init__(datasets)
+
+ def collater(self, samples):
+ # TODO For now only supports datasets with same underlying collater implementations
+ all_keys = set()
+ for s in samples:
+ all_keys.update(s)
+
+ shared_keys = all_keys
+ for s in samples:
+ shared_keys = shared_keys & set(s.keys())
+
+ samples_shared_keys = []
+ for s in samples:
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+ return self.datasets[0].collater(samples_shared_keys)
diff --git a/minigpt4/datasets/datasets/caption_datasets.py b/minigpt4/datasets/datasets/caption_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..be40164cf38df218f7c1e96e8c8ef31c18cce841
--- /dev/null
+++ b/minigpt4/datasets/datasets/caption_datasets.py
@@ -0,0 +1,150 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+from collections import OrderedDict
+
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from PIL import Image
+import random
+
+class __DisplMixin:
+ def displ_item(self, index):
+ sample, ann = self.__getitem__(index), self.annotation[index]
+
+ return OrderedDict(
+ {
+ "file": ann["image"],
+ "caption": ann["caption"],
+ "image": sample["image"],
+ }
+ )
+
+
+class CaptionDataset(BaseDataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ self.img_ids = {}
+ n = 0
+ for ann in self.annotation:
+ img_id = ann["image_id"]
+ if img_id not in self.img_ids.keys():
+ self.img_ids[img_id] = n
+ n += 1
+
+ def __getitem__(self, index):
+
+ # TODO this assumes image input, not general enough
+ ann = self.annotation[index]
+
+ img_file = '{:0>12}.jpg'.format(ann["image_id"])
+ image_path = os.path.join(self.vis_root, img_file)
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ caption = self.text_processor(ann["caption"])
+
+ return {
+ "image": image,
+ "answer": caption,
+ "image_id": self.img_ids[ann["image_id"]],
+ }
+
+
+class COCOCaptionDataset(BaseDataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ self.img_ids = {}
+ n = 0
+
+ self.filter_anntation = []
+
+ for ann in self.annotation:
+ if "train" in ann["image"]:
+ self.filter_anntation.append(ann)
+ self.annotation = self.filter_anntation
+
+ for ann in self.annotation:
+ img_id = ann["image_id"]
+ if img_id not in self.img_ids.keys():
+ self.img_ids[img_id] = n
+ n += 1
+
+ self.instruction_pool = [
+ 'Briefly describe this image.',
+ 'Provide a concise depiction of this image.',
+ 'Present a short description of this image.',
+ 'Summarize this image in a few words.',
+ 'A short image caption:',
+ 'A short image description:',
+ 'A photo of ',
+ 'An image that shows ',
+ 'Write a short description for the image. ',
+ 'Write a description for the photo.',
+ 'Provide a description of what is presented in the photo.',
+ 'Briefly describe the content of the image.',
+ 'Can you briefly explain what you see in the image?',
+ 'Could you use a few words to describe what you perceive in the photo?',
+ 'Please provide a short depiction of the picture.',
+ 'Using language, provide a short account of the image.',
+ 'Use a few words to illustrate what is happening in the picture.',
+ ]
+ def __getitem__(self, index):
+
+ # TODO this assumes image input, not general enough
+ ann = self.annotation[index]
+
+ # img_file = '{:0>12}.jpg'.format(ann["image_id"])
+ img_file = ann["image"].split("/")[-1]
+ image_path = os.path.join(self.vis_root, img_file)
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+ caption = self.text_processor(ann["caption"])
+
+ instruction = random.choice(self.instruction_pool)
+ instruction = "
[caption] {} ".format(instruction)
+
+ return {
+ "image": image,
+ "answer": caption,
+ "instruction_input": instruction,
+ }
+
+class CaptionEvalDataset(BaseDataset, __DisplMixin):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+ """
+ vis_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ """
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.vis_root, ann["image"])
+ image = Image.open(image_path).convert("RGB")
+
+ image = self.vis_processor(image)
+
+ return {
+ "image": image,
+ "image_id": ann["image_id"],
+ "instance_id": ann["instance_id"],
+ }
diff --git a/minigpt4/datasets/datasets/cc_sbu_dataset.py b/minigpt4/datasets/datasets/cc_sbu_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ab3dbbeb96cfed2c82e2c365d5a4a468f28f69b
--- /dev/null
+++ b/minigpt4/datasets/datasets/cc_sbu_dataset.py
@@ -0,0 +1,190 @@
+import os
+from PIL import Image
+import webdataset as wds
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+import json
+import random
+from webdataset import select
+
+
+def process_bbox(phrases, boxes):
+ new_boxes = []
+ for box in boxes:
+ small_box = []
+ for ele in box:
+ small_box.append(int(round(ele,2)*224))
+ new_boxes.append(small_box)
+
+ output = dict()
+
+ for index,phrase in enumerate(phrases):
+ box = new_boxes[index]
+ if phrase not in output.keys():
+ output[phrase]=[str(box)]
+ else:
+ output[phrase].append(str(box))
+
+ full_sentence = ""
+ for phrase in output.keys():
+ if len(output[phrase])==1:
+ bboxs = output[phrase][0]
+ sentence = "{}: {} ".format(phrase,bboxs)
+ else:
+ if len(output[phrase]) >2:
+ output[phrase] = random.sample(output[phrase],1)
+ bboxs = ",".join(output[phrase])
+ sentence = "{}: {} ".format(phrase,bboxs)
+ full_sentence += sentence
+
+ return full_sentence
+
+
+def sample_phrase_box(phrases, boxes):
+ new_boxes = []
+ for box in boxes:
+ small_box = []
+ for ele in box:
+ small_box.append(int(round(ele,2)*224))
+ new_boxes.append(small_box)
+
+ index = random.sample(range(0,len(phrases)),1)[0]
+ return phrases[index], str(new_boxes[index])
+
+def sample_phrase(phrases, region):
+ # new_boxes = []
+ # for box in boxes:
+ # small_box = []
+ # for ele in box:
+ # small_box.append(int(round(ele,2)*224))
+ # new_boxes.append(small_box)
+
+ index = random.sample(range(0,len(phrases)),1)[0]
+
+ return phrases[index], region[index]
+
+
+
+
+class CCSBUDataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, location):
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+ self.instruction_pool = [
+ 'Briefly describe this image.',
+ 'Provide a concise depiction of this image.',
+ 'Present a short description of this image.',
+ 'Summarize this image in a few words.',
+ 'A short image caption:',
+ 'A short image description:',
+ 'A photo of ',
+ 'An image that shows ',
+ 'Write a short description for the image. ',
+ 'Write a description for the photo.',
+ 'Provide a description of what is presented in the photo.',
+ 'Briefly describe the content of the image.',
+ 'Can you briefly explain what you see in the image?',
+ 'Could you use a few words to describe what you perceive in the photo?',
+ 'Please provide a short depiction of the picture.',
+ 'Using language, provide a short account of the image.',
+ 'Use a few words to illustrate what is happening in the picture.',
+ ]
+
+ self.inner_dataset = wds.DataPipeline(
+ wds.ResampledShards(location),
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
+ wds.shuffle(1000, handler=wds.warn_and_continue),
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
+ )
+
+ def to_dict(self, sample):
+ instruction = random.choice(self.instruction_pool)
+
+ # instruction = "###Human:
{}###Assistant: ".format(instruction)
+ instruction = "
[caption] {} ".format(instruction)
+
+ return {
+ "image": sample[0],
+ "instruction_input": instruction,
+ "answer": self.text_processor(sample[1]["caption"]),
+ }
+
+
+class CCSBUBBOXDataset(BaseDataset):
+ def __init__(self, vis_processor, text_processor, location):
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+ self.bbox_json = json.load(open("/ibex/project/c2133/aa_shenx/GroundingDINO/cc_box_filter_new.json"))
+
+ self.inner_dataset = wds.DataPipeline(
+ wds.ResampledShards(location),
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
+ wds.shuffle(1000, handler=wds.warn_and_continue),
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+ wds.select(self.filter_sample),
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
+ )
+
+ def filter_sample(self,sample):
+ # print(sample[1]["key"] in self.bbox_json)
+ return sample[1]["key"] in self.bbox_json
+
+ def to_dict(self, sample):
+
+ image_key = sample[1]["key"]
+
+ phrases = self.bbox_json[image_key]["phrases"]
+ boxes = self.bbox_json[image_key]["boxes"]
+ phrase_region = self.bbox_json[image_key]["box_regions"]
+
+ phrase, region = sample_phrase(phrases,phrase_region)
+
+ # phrase = " the bounding box of "+phrase+" is "
+ # box = phrase+box
+
+ phrase_input = "Given an image, identify the objects and their bounding boxes in the format of {object, x1,y1,x2,y2}. "
+ box_input = phrase_input + region
+
+ return {
+ "image": sample[0],
+ "answer": self.text_processor(sample[1]["caption"]),
+ "phrase_input": self.text_processor(phrase_input),
+ "box_input": self.text_processor(box_input),
+ "data_type": "bbox",
+ "question_split": True
+ }
+
+
+
+
+
+class CCSBUAlignDataset(CaptionDataset):
+
+ def __getitem__(self, index):
+
+ # TODO this assumes image input, not general enough
+ ann = self.annotation[index]
+
+ img_file = '{}.jpg'.format(ann["image_id"])
+ image_path = os.path.join(self.vis_root, img_file)
+ image = Image.open(image_path).convert("RGB")
+
+ # if ann["image_id"] in self.bbox_json:
+ # print(ann["image_id"])
+ # else:
+ # print("false")
+ # assert False
+
+ image = self.vis_processor(image)
+ caption = ann["caption"]
+
+ return {
+ "image": image,
+ "answer": caption,
+ "image_id": self.img_ids[ann["image_id"]],
+ "data_type": "caption",
+ "question_split": True
+ }
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/dataloader_utils.py b/minigpt4/datasets/datasets/dataloader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eaa3a58b0ad42ca7937fb51b46e53511cc3cd0c
--- /dev/null
+++ b/minigpt4/datasets/datasets/dataloader_utils.py
@@ -0,0 +1,162 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import time
+import random
+import torch
+from minigpt4.datasets.data_utils import move_to_cuda
+from torch.utils.data import DataLoader
+
+
+class MultiIterLoader:
+ """
+ A simple wrapper for iterating over multiple iterators.
+
+ Args:
+ loaders (List[Loader]): List of Iterator loaders.
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
+ """
+
+ def __init__(self, loaders, ratios=None):
+ # assert all loaders has __next__ method
+ for loader in loaders:
+ assert hasattr(
+ loader, "__next__"
+ ), "Loader {} has no __next__ method.".format(loader)
+
+ if ratios is None:
+ ratios = [1.0] * len(loaders)
+ else:
+ assert len(ratios) == len(loaders)
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
+
+ self.loaders = loaders
+ self.ratios = ratios
+
+ def __next__(self):
+ # random sample from each loader by ratio
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
+ return next(self.loaders[loader_idx])
+
+
+class PrefetchLoader(object):
+ """
+ Modified from https://github.com/ChenRocks/UNITER.
+
+ overlap compute and cuda data transfer
+ (copied and then modified from nvidia apex)
+ """
+
+ def __init__(self, loader):
+ self.loader = loader
+ self.stream = torch.cuda.Stream()
+
+ def __iter__(self):
+ loader_it = iter(self.loader)
+ self.preload(loader_it)
+ batch = self.next(loader_it)
+ while batch is not None:
+ is_tuple = isinstance(batch, tuple)
+ if is_tuple:
+ task, batch = batch
+
+ if is_tuple:
+ yield task, batch
+ else:
+ yield batch
+ batch = self.next(loader_it)
+
+ def __len__(self):
+ return len(self.loader)
+
+ def preload(self, it):
+ try:
+ self.batch = next(it)
+ except StopIteration:
+ self.batch = None
+ return
+ # if record_stream() doesn't work, another option is to make sure
+ # device inputs are created on the main stream.
+ # self.next_input_gpu = torch.empty_like(self.next_input,
+ # device='cuda')
+ # self.next_target_gpu = torch.empty_like(self.next_target,
+ # device='cuda')
+ # Need to make sure the memory allocated for next_* is not still in use
+ # by the main stream at the time we start copying to next_*:
+ # self.stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.stream):
+ self.batch = move_to_cuda(self.batch)
+ # more code for the alternative if record_stream() doesn't work:
+ # copy_ will record the use of the pinned source tensor in this
+ # side stream.
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+ # self.next_input = self.next_input_gpu
+ # self.next_target = self.next_target_gpu
+
+ def next(self, it):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is not None:
+ record_cuda_stream(batch)
+ self.preload(it)
+ return batch
+
+ def __getattr__(self, name):
+ method = self.loader.__getattribute__(name)
+ return method
+
+
+def record_cuda_stream(batch):
+ if isinstance(batch, torch.Tensor):
+ batch.record_stream(torch.cuda.current_stream())
+ elif isinstance(batch, list) or isinstance(batch, tuple):
+ for t in batch:
+ record_cuda_stream(t)
+ elif isinstance(batch, dict):
+ for t in batch.values():
+ record_cuda_stream(t)
+ else:
+ pass
+
+
+class IterLoader:
+ """
+ A wrapper to convert DataLoader as an infinite iterator.
+
+ Modified from:
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
+ """
+
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
+ self._dataloader = dataloader
+ self.iter_loader = iter(self._dataloader)
+ self._use_distributed = use_distributed
+ self._epoch = 0
+
+ @property
+ def epoch(self) -> int:
+ return self._epoch
+
+ def __next__(self):
+ try:
+ data = next(self.iter_loader)
+ except StopIteration:
+ self._epoch += 1
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
+ self._dataloader.sampler.set_epoch(self._epoch)
+ time.sleep(2) # Prevent possible deadlock during epoch transition
+ self.iter_loader = iter(self._dataloader)
+ data = next(self.iter_loader)
+
+ return data
+
+ def __iter__(self):
+ return self
+
+ def __len__(self):
+ return len(self._dataloader)
diff --git a/minigpt4/datasets/datasets/mimic_cxr_dataset.py b/minigpt4/datasets/datasets/mimic_cxr_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cfe11b5a36e912de4b1215a56740a78c0d3aa93
--- /dev/null
+++ b/minigpt4/datasets/datasets/mimic_cxr_dataset.py
@@ -0,0 +1,101 @@
+import os
+import json
+import random
+from PIL import Image
+from torch.utils.data import Dataset
+
+class MimicCxrDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ self.vis_root = vis_root
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ self.instruction_pool = [
+ 'Describe this image in detail',
+ 'Take a look at this image and describe what you notice',
+ 'Please provide a detailed description of the picture',
+ 'Could you describe the contents of this image for me?'
+ ]
+
+ def load_image(self, image_id):
+ image_file = f'{image_id}.jpg'
+ image_path = os.path.join(self.vis_root, image_file)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+ return image
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+ image = self.load_image(info['image_id'])
+ instruction = random.choice(self.instruction_pool)
+ instruction = f'
{self.text_processor(instruction)}'
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": info['caption'],
+ "image_id": info['image_id'],
+ }
+
+#####Eval Classes#####
+
+class evalMIMICDataset(Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ self.instruction_pool = [
+ 'Describe this image in detail',
+ 'Take a look at this image and describe what you notice',
+ 'Please provide a detailed description of the picture',
+ 'Could you describe the contents of this image for me?'
+ ]
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ info = self.loaded_data[idx]
+ img_id = '{}.jpg'.format(info['image_id'])
+ image_path = os.path.join(self.root_path, img_id)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+
+ answer = info['caption']
+ question = random.choice(self.instruction_pool)
+
+ return image, question, img_id
+
+
+class evalDetectMimicDataset(Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = data['key']
+ sent = data['objects']
+ image_path = os.path.join(self.root_path, img_id)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+ question = f"[detection] {sent}"
+
+ return image, question, img_id
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/nlst_dataset.py b/minigpt4/datasets/datasets/nlst_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca6b8fd466515547dbe6a51166b21a6d352fd0b6
--- /dev/null
+++ b/minigpt4/datasets/datasets/nlst_dataset.py
@@ -0,0 +1,173 @@
+import json
+import os
+import random
+from PIL import Image
+from torch.utils.data import Dataset
+
+class NlstDataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ self.vis_root = vis_root
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ self.original_size = 512
+ self.image_size = 100
+ self.instruction_pool = ['[detection] tumor']
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ return self.bbox_phrase_preprocess(index)
+
+ def prepare_image_and_annotations(self, info):
+ image = self.process_image(info["key"])
+ bboxs, ref_phrases = self.generate_bboxs_and_phrases(info)
+ return image, bboxs, ref_phrases
+
+ def process_image(self, image_file):
+ image_file = '{}.png'.format(image_file)
+ image_path = os.path.join(self.vis_root, image_file)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ return self.vis_processor(image)
+
+ def generate_bboxs_and_phrases(self, info):
+ bboxs, ref_phrases = [], []
+ for bbox in info["bbox"]:
+ scaled_bbox = self.scale_bbox(*bbox)
+ self.assert_bbox_in_range(*scaled_bbox)
+ ref_phrases.append("tumor")
+ bboxs.append(f"{{<{scaled_bbox[0]}><{scaled_bbox[1]}><{scaled_bbox[2]}><{scaled_bbox[3]}>}}")
+ return bboxs, ref_phrases
+
+ def scale_bbox(self, x1, y1, x2, y2):
+ scale = lambda x: int((x / self.original_size) * self.image_size)
+ return scale(x1), scale(y1), scale(x2), scale(y2)
+
+ def assert_bbox_in_range(self, x1, y1, x2, y2):
+ for coord in [x1, y1, x2, y2]:
+ assert 0 <= coord <= self.image_size, f"{coord} out of range"
+
+ def generate_caption(self, phrases, bounding_boxes):
+ phrase_bbox={}
+ for phrase, bbox in zip(phrases, bounding_boxes):
+ if phrase not in phrase_bbox.keys():
+ generated_phrase = "{}
".format(phrase)
+ generated_phrase_bbox = generated_phrase+str(bbox)
+ else:
+ generated_phrase = phrase_bbox[phrase]
+ generated_phrase_bbox = generated_phrase+""+str(bbox)
+ phrase_bbox[phrase] = generated_phrase_bbox
+ grounded_caption= ' '.join(phrase_bbox.values())
+ return grounded_caption
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image, bboxs, ref_phrases = self.prepare_image_and_annotations(info)
+
+ generated_caption = self.generate_caption(ref_phrases, bboxs)
+ instruction = f'[INST]
{self.instruction_pool[0]} [/INST]'
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": generated_caption,
+ "image_id": info['key'],
+ }
+
+class ReferNLSTDataset(NlstDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ super().__init__(vis_processor, text_processor, vis_root, ann_path)
+ self.instruction_pool = [
+ "[refer] tumor",
+ "[refer] give me the location of tumor ",
+ "[refer] where is tumor ?",
+ "[refer] from this image, tell me the location of tumor",
+ "[refer] the location of tumor is ",
+ "[refer] could you tell me the location for tumor ?",
+ "[refer] where can I locate the tumor ?",
+ ]
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image, bboxs, ref_phrases = self.prepare_image_and_annotations(info)
+
+ grounded_caption = self.generate_caption(ref_phrases, bboxs)
+ instruction = '[INST]
{} [/INST]'.format(random.choice(self.instruction_pool))
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": grounded_caption,
+ "image_id": info['key'],
+ }
+
+class IdentifyNLSTDataset(NlstDataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ super().__init__(vis_processor, text_processor, vis_root, ann_path)
+ self.instruction_pool = [
+ "[identify] {}",
+ "[identify] what object is in this location {}",
+ "[identify] identify the object present at this location {}",
+ "[identify] what is it in {}",
+ "[identify] describe this object in {}",
+ "[identify] this {} is",
+ "[identify] the object in {} is",
+ ]
+
+ def generate_boxes(self, phrases, bounding_boxes):
+ phrase_bbox = {}
+ for phrase, bbox in zip(phrases, bounding_boxes):
+ if phrase not in phrase_bbox:
+ grounded_bbox = str(bbox)
+ else:
+ grounded_bbox = phrase_bbox[phrase] + "" + str(bbox)
+ phrase_bbox[phrase] = grounded_bbox
+
+ ground_boxes = ' '.join(phrase_bbox.values())
+ return ground_boxes
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image = self.process_image(info['key'])
+ ref_exps = info["bbox"]
+ caption = info["rephrased_caption"]
+ bboxs, ref_phrases = self.generate_bboxs_and_phrases(info)
+ identify_boxes = self.generate_boxes(ref_phrases,bboxs)
+ identify_boxes = ' '.join([bbox for bbox in bboxs])
+ instruction = random.choice(self.instruction_pool).format(identify_boxes)
+ instruction = f'[INST]
{instruction} [/INST]'
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": caption,
+ "image_id": info['key'],
+ }
+
+class eval_NLST_Dataset(Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = '{}.png'.format(data['key'])
+ sent = data['objects']
+ image_path = os.path.join(self.root_path, img_id)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+ question = "[detection] tumor"
+
+ return image, question, img_id
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/radvqa_dataset.py b/minigpt4/datasets/datasets/radvqa_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4192f19563407ab845723dee080f79210c7cb910
--- /dev/null
+++ b/minigpt4/datasets/datasets/radvqa_dataset.py
@@ -0,0 +1,59 @@
+import os
+import json
+from PIL import Image
+from torch.utils.data import Dataset
+
+class RadVQADataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ self.vis_root = vis_root
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+ self.instruction_pool = ["[vqa] {}"]
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ def process_image(self, image_name):
+ image_path = os.path.join(self.vis_root, image_name)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ return self.vis_processor(image)
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ info = self.ann[index]
+ image = self.process_image(info['image_name'])
+ instruction = self.text_processor(self.instruction_pool[0].format(info['question']))
+ instruction = '[INST]
{} [/INST]'.format(instruction)
+
+ answer = str(info['answer'])
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": info['image_name'],
+ }
+
+class evalRadVQADataset(Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ info = self.loaded_data[idx]
+ image_file = '{}'.format(info['image_name'])
+ image_path = os.path.join(self.root_path, image_file)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+ question = "[vqa] {}".format(info['question'])
+ return image, question, image_file
diff --git a/minigpt4/datasets/datasets/rsna_dataset.py b/minigpt4/datasets/datasets/rsna_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45f435b7fecf1fb454d2f8fcadeee1440d279e5
--- /dev/null
+++ b/minigpt4/datasets/datasets/rsna_dataset.py
@@ -0,0 +1,172 @@
+import json
+import os
+import random
+from PIL import Image
+from torch.utils.data import Dataset
+
+class RSNADataset(Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ self.vis_root = vis_root
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ with open(ann_path, 'r') as f:
+ self.ann = json.load(f)
+
+ self.original_size = 1024
+ self.image_size = 100
+ self.instruction_pool = ['[detection] pneumonia']
+
+ def __len__(self):
+ return len(self.ann)
+
+ def __getitem__(self, index):
+ return self.bbox_phrase_preprocess(index)
+
+ def prepare_image_and_annotations(self, info):
+ image = self.process_image(info["key"])
+ bboxs, ref_phrases = self.generate_bboxs_and_phrases(info)
+ return image, bboxs, ref_phrases
+
+ def process_image(self, image_file):
+ image_path = os.path.join(self.vis_root, image_file)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ return self.vis_processor(image)
+
+ def generate_bboxs_and_phrases(self, info):
+ bboxs, ref_phrases = [], []
+ for bbox in info["bbox"]:
+ scaled_bbox = self.scale_bbox(*bbox)
+ self.assert_bbox_in_range(*scaled_bbox)
+ ref_phrases.append("pneumonia")
+ bboxs.append(f"{{<{scaled_bbox[0]}><{scaled_bbox[1]}><{scaled_bbox[2]}><{scaled_bbox[3]}>}}")
+ return bboxs, ref_phrases
+
+ def scale_bbox(self, x1, y1, x2, y2):
+ scale = lambda x: int((x / self.original_size) * self.image_size)
+ return scale(x1), scale(y1), scale(x2), scale(y2)
+
+ def assert_bbox_in_range(self, x1, y1, x2, y2):
+ for coord in [x1, y1, x2, y2]:
+ assert 0 <= coord <= self.image_size, f"{coord} out of range"
+
+ def generate_caption(self, phrases, bounding_boxes):
+ phrase_bbox={}
+ for phrase, bbox in zip(phrases, bounding_boxes):
+ if phrase not in phrase_bbox.keys():
+ generated_phrase = "{}
".format(phrase)
+ generated_phrase_bbox = generated_phrase+str(bbox)
+ else:
+ generated_phrase = phrase_bbox[phrase]
+ generated_phrase_bbox = generated_phrase+""+str(bbox)
+ phrase_bbox[phrase] = generated_phrase_bbox
+ generated_caption= ' '.join(phrase_bbox.values())
+ return generated_caption
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image, bboxs, ref_phrases = self.prepare_image_and_annotations(info)
+
+ generated_caption = self.generate_caption(ref_phrases, bboxs)
+ instruction = f'[INST]
{self.instruction_pool[0]} [/INST]'
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": generated_caption,
+ "image_id": info['key'],
+ }
+
+class ReferRSNADataset(RSNADataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path, dataset='refcoco', splitBy='unc'):
+ super().__init__(vis_processor, text_processor, vis_root, ann_path)
+ self.instruction_pool = [
+ "[refer] pneumonia"
+ "[refer] give me the location of pneumonia",
+ "[refer] where is pneumonia ?",
+ "[refer] from this image, tell me the location of pneumonia",
+ "[refer] the location of pneumonia is ",
+ "[refer] could you tell me the location for pneumonia ?",
+ "[refer] where can I locate the pneumonia ?",
+ ]
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image, bboxs, ref_phrases = self.prepare_image_and_annotations(info)
+
+ generated_caption = self.generate_caption(ref_phrases, bboxs)
+ instruction = '[INST]
{} [/INST]'.format(random.choice(self.instruction_pool))
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": generated_caption,
+ "image_id": info['key'],
+ }
+
+class IdentifyRSNADataset(RSNADataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ super().__init__(vis_processor, text_processor, vis_root, ann_path)
+ self.instruction_pool = [
+ "[identify] {}",
+ "[identify] what object is in this location {}",
+ "[identify] identify the object present at this location {}",
+ "[identify] what is it in {}",
+ "[identify] describe this object in {}",
+ "[identify] this {} is",
+ "[identify] the object in {} is",
+ ]
+
+ def generate_boxes(self, phrases, bounding_boxes):
+ phrase_bbox = {}
+ for phrase, bbox in zip(phrases, bounding_boxes):
+ if phrase not in phrase_bbox:
+ grounded_bbox = str(bbox)
+ else:
+ grounded_bbox = phrase_bbox[phrase] + "" + str(bbox)
+ phrase_bbox[phrase] = grounded_bbox
+
+ ground_boxes = ' '.join(phrase_bbox.values())
+ return ground_boxes
+
+ def bbox_phrase_preprocess(self, index):
+ info = self.ann[index]
+ image = self.process_image(info['key'])
+ ref_exps = info["bbox"]
+ caption = info["rephrased_caption"]
+ bboxs, ref_phrases = self.generate_bboxs_and_phrases(info)
+ identify_boxes = self.generate_boxes(ref_phrases,bboxs)
+ instruction = random.choice(self.instruction_pool).format(identify_boxes)
+ instruction = f'[INST]
{instruction} [/INST]'
+
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": caption,
+ "image_id": info['key'],
+ }
+
+
+class evalRSNADataset(Dataset):
+ def __init__(self, loaded_data, vis_processor, root_path):
+ self.loaded_data = loaded_data
+ self.root_path = root_path
+ self.vis_processor = vis_processor
+
+ def __len__(self):
+ return len(self.loaded_data)
+
+ def __getitem__(self, idx):
+ data = self.loaded_data[idx]
+ img_id = data['key']
+ sent = data['objects']
+ image_path = os.path.join(self.root_path, img_id)
+ grayscale_image = Image.open(image_path).convert("L")
+ image = Image.new("RGB", grayscale_image.size)
+ image.paste(grayscale_image)
+ image = self.vis_processor(image)
+ question = "[detection] pneumonia"
+
+ return image, question, img_id
\ No newline at end of file
diff --git a/minigpt4/models/Qformer.py b/minigpt4/models/Qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71b12375e10511858a9c505dc795181e6ce5603
--- /dev/null
+++ b/minigpt4/models/Qformer.py
@@ -0,0 +1,1216 @@
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+ )
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size
+ )
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
+ )
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+
+ self.config = config
+
+ def forward(
+ self,
+ input_ids=None,
+ position_ids=None,
+ query_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ seq_length = input_ids.size()[1]
+ else:
+ seq_length = 0
+
+ if position_ids is None:
+ position_ids = self.position_ids[
+ :, past_key_values_length : seq_length + past_key_values_length
+ ].clone()
+
+ if input_ids is not None:
+ embeddings = self.word_embeddings(input_ids)
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings = embeddings + position_embeddings
+
+ if query_embeds is not None:
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
+ else:
+ embeddings = query_embeds
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+ config, "embedding_size"
+ ):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
+ )
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ mixed_query_layer = self.query(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if (
+ self.position_embedding_type == "relative_key"
+ or self.position_embedding_type == "relative_key_query"
+ ):
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(-1, 1)
+ position_ids_r = torch.arange(
+ seq_length, dtype=torch.long, device=hidden_states.device
+ ).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(
+ distance + self.max_position_embeddings - 1
+ )
+ positional_embedding = positional_embedding.to(
+ dtype=query_layer.dtype
+ ) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum(
+ "bhld,lrd->bhlr", query_layer, positional_embedding
+ )
+ relative_position_scores_key = torch.einsum(
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
+ )
+ attention_scores = (
+ attention_scores
+ + relative_position_scores_query
+ + relative_position_scores_key
+ )
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
+ )
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.self.num_attention_heads,
+ self.self.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = (
+ self.self.attention_head_size * self.self.num_attention_heads
+ )
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[
+ 1:
+ ] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if (
+ self.config.add_cross_attention
+ and layer_num % self.config.cross_attention_freq == 0
+ ):
+ self.crossattention = BertAttention(
+ config, is_cross_attention=self.config.add_cross_attention
+ )
+ self.has_cross_attention = True
+ else:
+ self.has_cross_attention = False
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ self.intermediate_query = BertIntermediate(config)
+ self.output_query = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ query_length=0,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = (
+ past_key_value[:2] if past_key_value is not None else None
+ )
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:-1]
+
+ present_key_value = self_attention_outputs[-1]
+
+ if query_length > 0:
+ query_attention_output = attention_output[:, :query_length, :]
+
+ if self.has_cross_attention:
+ assert (
+ encoder_hidden_states is not None
+ ), "encoder_hidden_states must be given for cross-attention layers"
+ cross_attention_outputs = self.crossattention(
+ query_attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ query_attention_output = cross_attention_outputs[0]
+ outputs = (
+ outputs + cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk_query,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ query_attention_output,
+ )
+ if attention_output.shape[1] > query_length:
+ layer_output_text = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output[:, query_length:, :],
+ )
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+ else:
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+ def feed_forward_chunk_query(self, attention_output):
+ intermediate_output = self.intermediate_query(attention_output)
+ layer_output = self.output_query(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList(
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ query_length=0,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = (
+ () if output_attentions and self.config.add_cross_attention else None
+ )
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ if use_cache:
+ logger.warn(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(
+ *inputs, past_key_value, output_attentions, query_length
+ )
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ query_length,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=False):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: Tensor,
+ input_shape: Tuple[int],
+ device: device,
+ is_decoder: bool,
+ has_query: bool = False,
+ ) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = (
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+ <= seq_ids[None, :, None]
+ )
+
+ # add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ if has_query: # UniLM style attention mask
+ causal_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, prefix_seq_len, seq_length),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=1,
+ )
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+ extended_attention_mask = (
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=self.dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is None:
+ assert (
+ query_embeds is not None
+ ), "You have to specify query_embeds when input_ids is None"
+
+ # past_key_values_length
+ past_key_values_length = (
+ past_key_values[0][0].shape[2] - self.config.query_length
+ if past_key_values is not None
+ else 0
+ )
+
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ query_embeds=query_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ input_shape = embedding_output.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = embedding_output.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length + past_key_values_length)), device=device
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if is_decoder:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask,
+ input_ids.shape,
+ device,
+ is_decoder,
+ has_query=(query_embeds is not None),
+ )
+ else:
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape, device, is_decoder
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
+ 0
+ ].size()
+ else:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
+ ]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ query_length=query_length,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = (
+ self.pooler(sequence_output) if self.pooler is not None else None
+ )
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=True,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction="mean",
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ if labels is not None:
+ use_cache = False
+ if past_key_values is not None:
+ query_embeds = None
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ sequence_output = outputs[0]
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+ if reduction == "none":
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
+ ):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "query_embeds": query_embeds,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx) for past_state in layer_past
+ ),
+ )
+ return reordered_past
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ query_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ query_embeds=query_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ )
+
+ if query_embeds is not None:
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+ prediction_scores = self.cls(sequence_output)
+
+ if return_logits:
+ return prediction_scores
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return (
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ )
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc01b56181aa81554efbe9df10ab3678a1c7bb86
--- /dev/null
+++ b/minigpt4/models/__init__.py
@@ -0,0 +1,202 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import torch
+from omegaconf import OmegaConf
+
+from minigpt4.common.registry import registry
+from minigpt4.models.base_model import BaseModel
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.minigpt4 import MiniGPT4
+from minigpt4.models.minigpt_v2 import MiniGPTv2
+from minigpt4.processors.base_processor import BaseProcessor
+
+
+__all__ = [
+ "load_model",
+ "BaseModel",
+ "MiniGPTBase",
+ "MiniGPT4",
+ "MiniGPTv2"
+]
+
+
+def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
+ """
+ Load supported models.
+
+ To list all available models and types in registry:
+ >>> from minigpt4.models import model_zoo
+ >>> print(model_zoo)
+
+ Args:
+ name (str): name of the model.
+ model_type (str): type of the model.
+ is_eval (bool): whether the model is in eval mode. Default: False.
+ device (str): device to use. Default: "cpu".
+ checkpoint (str): path or to checkpoint. Default: None.
+ Note that expecting the checkpoint to have the same keys in state_dict as the model.
+
+ Returns:
+ model (torch.nn.Module): model.
+ """
+
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
+
+ if checkpoint is not None:
+ model.load_checkpoint(checkpoint)
+
+ if is_eval:
+ model.eval()
+
+ if device == "cpu":
+ model = model.float()
+
+ return model.to(device)
+
+
+def load_preprocess(config):
+ """
+ Load preprocessor configs and construct preprocessors.
+
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
+
+ Args:
+ config (dict): preprocessor configs.
+
+ Returns:
+ vis_processors (dict): preprocessors for visual inputs.
+ txt_processors (dict): preprocessors for text inputs.
+
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
+ """
+
+ def _build_proc_from_cfg(cfg):
+ return (
+ registry.get_processor_class(cfg.name).from_config(cfg)
+ if cfg is not None
+ else BaseProcessor()
+ )
+
+ vis_processors = dict()
+ txt_processors = dict()
+
+ vis_proc_cfg = config.get("vis_processor")
+ txt_proc_cfg = config.get("text_processor")
+
+ if vis_proc_cfg is not None:
+ vis_train_cfg = vis_proc_cfg.get("train")
+ vis_eval_cfg = vis_proc_cfg.get("eval")
+ else:
+ vis_train_cfg = None
+ vis_eval_cfg = None
+
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
+
+ if txt_proc_cfg is not None:
+ txt_train_cfg = txt_proc_cfg.get("train")
+ txt_eval_cfg = txt_proc_cfg.get("eval")
+ else:
+ txt_train_cfg = None
+ txt_eval_cfg = None
+
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
+
+ return vis_processors, txt_processors
+
+
+def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
+ """
+ Load model and its related preprocessors.
+
+ List all available models and types in registry:
+ >>> from minigpt4.models import model_zoo
+ >>> print(model_zoo)
+
+ Args:
+ name (str): name of the model.
+ model_type (str): type of the model.
+ is_eval (bool): whether the model is in eval mode. Default: False.
+ device (str): device to use. Default: "cpu".
+
+ Returns:
+ model (torch.nn.Module): model.
+ vis_processors (dict): preprocessors for visual inputs.
+ txt_processors (dict): preprocessors for text inputs.
+ """
+ model_cls = registry.get_model_class(name)
+
+ # load model
+ model = model_cls.from_pretrained(model_type=model_type)
+
+ if is_eval:
+ model.eval()
+
+ # load preprocess
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
+ if cfg is not None:
+ preprocess_cfg = cfg.preprocess
+
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+ else:
+ vis_processors, txt_processors = None, None
+ logging.info(
+ f"""No default preprocess for model {name} ({model_type}).
+ This can happen if the model is not finetuned on downstream datasets,
+ or it is not intended for direct use without finetuning.
+ """
+ )
+
+ if device == "cpu" or device == torch.device("cpu"):
+ model = model.float()
+
+ return model.to(device), vis_processors, txt_processors
+
+
+class ModelZoo:
+ """
+ A utility class to create string representation of available model architectures and types.
+
+ >>> from minigpt4.models import model_zoo
+ >>> # list all available models
+ >>> print(model_zoo)
+ >>> # show total number of models
+ >>> print(len(model_zoo))
+ """
+
+ def __init__(self) -> None:
+ self.model_zoo = {
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
+ for k, v in registry.mapping["model_name_mapping"].items()
+ }
+
+ def __str__(self) -> str:
+ return (
+ "=" * 50
+ + "\n"
+ + f"{'Architectures':<30} {'Types'}\n"
+ + "=" * 50
+ + "\n"
+ + "\n".join(
+ [
+ f"{name:<30} {', '.join(types)}"
+ for name, types in self.model_zoo.items()
+ ]
+ )
+ )
+
+ def __iter__(self):
+ return iter(self.model_zoo.items())
+
+ def __len__(self):
+ return sum([len(v) for v in self.model_zoo.values()])
+
+
+model_zoo = ModelZoo()
diff --git a/minigpt4/models/__pycache__/Qformer.cpython-310.pyc b/minigpt4/models/__pycache__/Qformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ef500e9e9cace6bd9fabca004066e537ca1059d
Binary files /dev/null and b/minigpt4/models/__pycache__/Qformer.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/Qformer.cpython-39.pyc b/minigpt4/models/__pycache__/Qformer.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1a4e2f12d1d09fc56484d3e1405489cc3d5d601
Binary files /dev/null and b/minigpt4/models/__pycache__/Qformer.cpython-39.pyc differ
diff --git a/minigpt4/models/__pycache__/__init__.cpython-310.pyc b/minigpt4/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..57978ef1419849a11904399a4825d934383d5660
Binary files /dev/null and b/minigpt4/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/__init__.cpython-39.pyc b/minigpt4/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..043c98a2f33d05f3276cc0ebf5c2f6cad490f434
Binary files /dev/null and b/minigpt4/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/models/__pycache__/base_model.cpython-310.pyc b/minigpt4/models/__pycache__/base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dfcf84683df9a176a6a07bbf6673812b3b030a4
Binary files /dev/null and b/minigpt4/models/__pycache__/base_model.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/base_model.cpython-39.pyc b/minigpt4/models/__pycache__/base_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8f8a0522d83819f2b58d33f61051a17c80c47e0
Binary files /dev/null and b/minigpt4/models/__pycache__/base_model.cpython-39.pyc differ
diff --git a/minigpt4/models/__pycache__/eva_vit.cpython-310.pyc b/minigpt4/models/__pycache__/eva_vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..375cc04af76266791d8dd0a543feb92fa0e991b4
Binary files /dev/null and b/minigpt4/models/__pycache__/eva_vit.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/eva_vit.cpython-39.pyc b/minigpt4/models/__pycache__/eva_vit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8209ab9e5e32caafdf7cc55a8837fc6c4e884cf0
Binary files /dev/null and b/minigpt4/models/__pycache__/eva_vit.cpython-39.pyc differ
diff --git a/minigpt4/models/__pycache__/minigpt4.cpython-310.pyc b/minigpt4/models/__pycache__/minigpt4.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20277cc46c940c34a76bd8a49dcdf8886cf03384
Binary files /dev/null and b/minigpt4/models/__pycache__/minigpt4.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/minigpt4.cpython-39.pyc b/minigpt4/models/__pycache__/minigpt4.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4852bd69aff3b9eb525122266d1f5cb404979b07
Binary files /dev/null and b/minigpt4/models/__pycache__/minigpt4.cpython-39.pyc differ
diff --git a/minigpt4/models/__pycache__/minigpt_base.cpython-310.pyc b/minigpt4/models/__pycache__/minigpt_base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce2798ecb670de3156bbeffd2636b15c58a79e14
Binary files /dev/null and b/minigpt4/models/__pycache__/minigpt_base.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/minigpt_base.cpython-39.pyc b/minigpt4/models/__pycache__/minigpt_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db91df10b900eb15662582717b8a280faf786d24
Binary files /dev/null and b/minigpt4/models/__pycache__/minigpt_base.cpython-39.pyc differ
diff --git a/minigpt4/models/__pycache__/minigpt_v2.cpython-310.pyc b/minigpt4/models/__pycache__/minigpt_v2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3900259694e9e04f78faf137278a3e3f703767f
Binary files /dev/null and b/minigpt4/models/__pycache__/minigpt_v2.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/minigpt_v2.cpython-39.pyc b/minigpt4/models/__pycache__/minigpt_v2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..178665a6478913f6881ba1d75ca70abdb50dd82c
Binary files /dev/null and b/minigpt4/models/__pycache__/minigpt_v2.cpython-39.pyc differ
diff --git a/minigpt4/models/__pycache__/modeling_llama.cpython-310.pyc b/minigpt4/models/__pycache__/modeling_llama.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cc1abae03022d9c53eebd9e0541eac24808aab30
Binary files /dev/null and b/minigpt4/models/__pycache__/modeling_llama.cpython-310.pyc differ
diff --git a/minigpt4/models/__pycache__/modeling_llama.cpython-39.pyc b/minigpt4/models/__pycache__/modeling_llama.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec6805593cd9e6a3f74e582c377c9a7e3ba48564
Binary files /dev/null and b/minigpt4/models/__pycache__/modeling_llama.cpython-39.pyc differ
diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d70ca18cc21fdff952c8922beef7644164a872aa
--- /dev/null
+++ b/minigpt4/models/base_model.py
@@ -0,0 +1,248 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import logging
+import contextlib
+
+from omegaconf import OmegaConf
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers import LlamaTokenizer
+from peft import (
+ LoraConfig,
+ get_peft_model,
+ prepare_model_for_int8_training,
+)
+
+from minigpt4.common.dist_utils import download_cached_file
+from minigpt4.common.utils import get_abs_path, is_url
+from minigpt4.models.eva_vit import create_eva_vit_g
+from minigpt4.models.modeling_llama import LlamaForCausalLM
+
+
+
+class BaseModel(nn.Module):
+ """Base class for models."""
+
+ def __init__(self):
+ super().__init__()
+
+ @property
+ def device(self):
+ return list(self.parameters())[-1].device
+
+ def load_checkpoint(self, url_or_filename):
+ """
+ Load from a finetuned checkpoint.
+
+ This should expect no mismatch in the model keys and the checkpoint keys.
+ """
+
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ if "model" in checkpoint.keys():
+ state_dict = checkpoint["model"]
+ else:
+ state_dict = checkpoint
+
+ msg = self.load_state_dict(state_dict, strict=False)
+
+ logging.info("Missing keys {}".format(msg.missing_keys))
+ logging.info("load checkpoint from %s" % url_or_filename)
+
+ return msg
+
+ @classmethod
+ def from_pretrained(cls, model_type):
+ """
+ Build a pretrained model from default configuration file, specified by model_type.
+
+ Args:
+ - model_type (str): model type, specifying architecture and checkpoints.
+
+ Returns:
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
+ """
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
+ model = cls.from_config(model_cfg)
+
+ return model
+
+ @classmethod
+ def default_config_path(cls, model_type):
+ assert (
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
+ ), "Unknown model type {}".format(model_type)
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
+
+ def load_checkpoint_from_config(self, cfg, **kwargs):
+ """
+ Load checkpoint as specified in the config file.
+
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
+ When loading the pretrained model, each task-specific architecture may define their
+ own load_from_pretrained() method.
+ """
+ load_finetuned = cfg.get("load_finetuned", True)
+ if load_finetuned:
+ finetune_path = cfg.get("finetuned", None)
+ assert (
+ finetune_path is not None
+ ), "Found load_finetuned is True, but finetune_path is None."
+ self.load_checkpoint(url_or_filename=finetune_path)
+ else:
+ # load pre-trained weights
+ pretrain_path = cfg.get("pretrained", None)
+ assert "Found load_finetuned is False, but pretrain_path is None."
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
+
+ def before_evaluation(self, **kwargs):
+ pass
+
+ def show_n_params(self, return_str=True):
+ tot = 0
+ for p in self.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return "{:.1f}M".format(tot / 1e6)
+ else:
+ return "{:.1f}K".format(tot / 1e3)
+ else:
+ return tot
+
+ def maybe_autocast(self, dtype=torch.float16):
+ # if on cpu, don't use autocast
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+ enable_autocast = self.device != torch.device("cpu")
+
+ if enable_autocast:
+ return torch.cuda.amp.autocast(dtype=dtype)
+ else:
+ return contextlib.nullcontext()
+
+ @classmethod
+ def init_vision_encoder(
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
+ ):
+ logging.info('Loading VIT')
+
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
+ if not freeze:
+ precision = "fp32" # fp16 is not for training
+
+ visual_encoder = create_eva_vit_g(
+ img_size, drop_path_rate, use_grad_checkpoint, precision
+ )
+
+ ln_vision = LayerNorm(visual_encoder.num_features)
+
+ if freeze:
+ for name, param in visual_encoder.named_parameters():
+ param.requires_grad = False
+ visual_encoder = visual_encoder.eval()
+ visual_encoder.train = disabled_train
+ for name, param in ln_vision.named_parameters():
+ param.requires_grad = False
+ ln_vision = ln_vision.eval()
+ ln_vision.train = disabled_train
+ logging.info("freeze vision encoder")
+
+ logging.info('Loading VIT Done')
+ return visual_encoder, ln_vision
+
+ def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
+ lora_target_modules=["q_proj","v_proj"], **lora_kargs):
+ logging.info('Loading LLAMA')
+ llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
+ llama_tokenizer.pad_token = "$$"
+
+ if low_resource:
+ llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model_path,
+ torch_dtype=torch.float16,
+ load_in_8bit=True,
+ device_map={'': low_res_device}
+ )
+ else:
+ llama_model = LlamaForCausalLM.from_pretrained(
+ llama_model_path,
+ torch_dtype=torch.float16,
+ )
+
+ if lora_r > 0:
+ llama_model = prepare_model_for_int8_training(llama_model)
+ loraconfig = LoraConfig(
+ r=lora_r,
+ bias="none",
+ task_type="CAUSAL_LM",
+ target_modules=lora_target_modules,
+ **lora_kargs
+ )
+ llama_model = get_peft_model(llama_model, loraconfig)
+
+ llama_model.print_trainable_parameters()
+
+ else:
+ for name, param in llama_model.named_parameters():
+ param.requires_grad = False
+ logging.info('Loading LLAMA Done')
+ return llama_model, llama_tokenizer
+
+
+ def load_from_pretrained(self, url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location="cpu")
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+
+ msg = self.load_state_dict(state_dict, strict=False)
+
+ # logging.info("Missing keys {}".format(msg.missing_keys))
+ logging.info("load checkpoint from %s" % url_or_filename)
+
+ return msg
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+
+
+
diff --git a/minigpt4/models/eva_vit.py b/minigpt4/models/eva_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fcc63a74049f1faf65c99943ef94f72383ca3f5
--- /dev/null
+++ b/minigpt4/models/eva_vit.py
@@ -0,0 +1,442 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+
+from minigpt4.common.dist_utils import download_cached_file
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ **kwargs
+ }
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+ # commit this for the orignal BERT implement
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if init_values is not None and init_values > 0:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ def forward(self):
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
+ super().__init__()
+ self.image_size = img_size
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+ self.use_checkpoint = use_checkpoint
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+ for i in range(depth)])
+# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ # trunc_normal_(self.mask_token, std=.02)
+# if isinstance(self.head, nn.Linear):
+# trunc_normal_(self.head.weight, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+# if isinstance(self.head, nn.Linear):
+# self.head.weight.data.mul_(init_scale)
+# self.head.bias.data.mul_(init_scale)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+ else:
+ x = blk(x, rel_pos_bias)
+ return x
+# x = self.norm(x)
+
+# if self.fc_norm is not None:
+# t = x[:, 1:, :]
+# return self.fc_norm(t.mean(1))
+# else:
+# return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+# x = self.head(x)
+ return x
+
+ def get_intermediate_layers(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ features = []
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ x = blk(x, rel_pos_bias)
+ features.append(x)
+
+ return features
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def convert_weights_to_fp16(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+# if isinstance(l, (nn.MultiheadAttention, Attention)):
+# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+# tensor = getattr(l, attr)
+# if tensor is not None:
+# tensor.data = tensor.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
+ model = VisionTransformer(
+ img_size=img_size,
+ patch_size=14,
+ use_mean_pooling=False,
+ embed_dim=1408,
+ depth=39,
+ num_heads=1408//88,
+ mlp_ratio=4.3637,
+ qkv_bias=True,
+ drop_path_rate=drop_path_rate,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ use_checkpoint=use_checkpoint,
+ )
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+ cached_file = download_cached_file(
+ url, check_hash=False, progress=True
+ )
+ state_dict = torch.load(cached_file, map_location="cpu")
+ interpolate_pos_embed(model,state_dict)
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
+# print(incompatible_keys)
+
+ if precision == "fp16":
+# model.to("cuda")
+ convert_weights_to_fp16(model)
+ return model
\ No newline at end of file
diff --git a/minigpt4/models/minigpt4.py b/minigpt4/models/minigpt4.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e4798bb9713467b0ddac2dcec3cb1681c6418d
--- /dev/null
+++ b/minigpt4/models/minigpt4.py
@@ -0,0 +1,195 @@
+import logging
+import random
+
+import torch
+from torch.cuda.amp import autocast as autocast
+import torch.nn as nn
+
+from minigpt4.common.registry import registry
+from minigpt4.models.base_model import disabled_train
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
+
+
+@registry.register_model("minigpt4")
+class MiniGPT4(MiniGPTBase):
+ """
+ MiniGPT-4 model
+ """
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
+ "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
+ }
+
+ def __init__(
+ self,
+ vit_model="eva_clip_g",
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
+ img_size=224,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision="fp16",
+ freeze_vit=True,
+ has_qformer=True,
+ freeze_qformer=True,
+ num_query_token=32,
+ llama_model="",
+ prompt_path="",
+ prompt_template="",
+ max_txt_len=32,
+ end_sym='\n',
+ low_resource=False, # use 8 bit and put vit in cpu
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
+ ):
+ super().__init__(
+ vit_model=vit_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ llama_model=llama_model,
+ max_txt_len=max_txt_len,
+ end_sym=end_sym,
+ low_resource=low_resource,
+ device_8bit=device_8bit,
+ )
+
+ self.has_qformer = has_qformer
+ if self.has_qformer:
+ print('Loading Q-Former')
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features, freeze_qformer
+ )
+ self.load_from_pretrained(url_or_filename=q_former_model) # load q-former weights here
+
+ img_f_dim = self.Qformer.config.hidden_size
+ print('Loading Q-Former Done')
+ else:
+ img_f_dim = self.visual_encoder.num_features * 4
+ print('Do not use Q-Former here.')
+
+ self.llama_proj = nn.Linear(
+ img_f_dim, self.llama_model.config.hidden_size
+ )
+
+ if prompt_path:
+ with open(prompt_path, 'r') as f:
+ raw_prompts = f.read().splitlines()
+ filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt]
+ self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
+ print('Load {} training prompts'.format(len(self.prompt_list)))
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
+ else:
+ self.prompt_list = []
+
+ @classmethod
+ def init_Qformer(cls, num_query_token, vision_width, freeze):
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+ encoder_config.encoder_width = vision_width
+ # insert cross-attention layer every other block
+ encoder_config.add_cross_attention = True
+ encoder_config.cross_attention_freq = 2
+ encoder_config.query_length = num_query_token
+ Qformer = BertLMHeadModel(config=encoder_config)
+ query_tokens = nn.Parameter(
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
+ )
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+
+ Qformer.cls = None
+ Qformer.bert.embeddings.word_embeddings = None
+ Qformer.bert.embeddings.position_embeddings = None
+ for layer in Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+
+ if freeze:
+ for name, param in Qformer.named_parameters():
+ param.requires_grad = False
+ Qformer = Qformer.eval()
+ Qformer.train = disabled_train
+ query_tokens.requires_grad = False
+ logging.info("freeze Qformer")
+
+ return Qformer, query_tokens
+
+ def encode_img(self, image):
+ device = image.device
+
+ if len(image.shape) > 4:
+ image = image.reshape(-1, *image.shape[-3:])
+
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+ if self.has_qformer:
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
+
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
+ else:
+ image_embeds = image_embeds[:, 1:, :]
+ bs, pn, hs = image_embeds.shape
+ image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
+
+ inputs_llama = self.llama_proj(image_embeds)
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
+ return inputs_llama, atts_llama
+
+ @classmethod
+ def from_config(cls, cfg):
+ vit_model = cfg.get("vit_model", "eva_clip_g")
+ q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
+ img_size = cfg.get("image_size")
+ num_query_token = cfg.get("num_query_token")
+ llama_model = cfg.get("llama_model")
+
+ drop_path_rate = cfg.get("drop_path_rate", 0)
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+ vit_precision = cfg.get("vit_precision", "fp16")
+ freeze_vit = cfg.get("freeze_vit", True)
+ has_qformer = cfg.get("has_qformer", True)
+ freeze_qformer = cfg.get("freeze_qformer", True)
+ low_resource = cfg.get("low_resource", False)
+ device_8bit = cfg.get("device_8bit", 0)
+
+ prompt_path = cfg.get("prompt_path", "")
+ prompt_template = cfg.get("prompt_template", "")
+ max_txt_len = cfg.get("max_txt_len", 32)
+ end_sym = cfg.get("end_sym", '\n')
+
+ model = cls(
+ vit_model=vit_model,
+ q_former_model=q_former_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ has_qformer=has_qformer,
+ freeze_qformer=freeze_qformer,
+ num_query_token=num_query_token,
+ llama_model=llama_model,
+ prompt_path=prompt_path,
+ prompt_template=prompt_template,
+ max_txt_len=max_txt_len,
+ end_sym=end_sym,
+ low_resource=low_resource,
+ device_8bit=device_8bit,
+ )
+
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
+ if ckpt_path:
+ print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ msg = model.load_state_dict(ckpt['model'], strict=False)
+
+ return model
diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..66c8c80ee68a0625caebd0b37746776036f8102c
--- /dev/null
+++ b/minigpt4/models/minigpt_base.py
@@ -0,0 +1,410 @@
+import logging
+import random
+
+import torch
+from torch.cuda.amp import autocast as autocast
+import torch.nn as nn
+
+from minigpt4.common.registry import registry
+from minigpt4.models.base_model import BaseModel
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+from minigpt4.conversation.conversation import StoppingCriteriaSub
+
+class MiniGPTBase(BaseModel):
+ """
+ Base class for MiniGPT-4 and MiniGPT-v2
+ """
+
+ def __init__(
+ self,
+ vit_model="eva_clip_g",
+ img_size=224,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision="fp16",
+ freeze_vit=True,
+ llama_model="",
+ max_txt_len=32,
+ max_context_len=3800,
+ prompt_template="",
+ end_sym='\n',
+ low_resource=False, # use 8 bit and put vit in cpu
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
+ lora_r=0, # lora_r means lora is not used
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.05,
+ ):
+ super().__init__()
+
+ self.llama_model, self.llama_tokenizer = self.init_llm(
+ llama_model_path=llama_model,
+ low_resource=low_resource,
+ low_res_device=device_8bit,
+ lora_r=lora_r,
+ lora_target_modules=lora_target_modules,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
+ )
+
+ self.max_txt_len = max_txt_len
+ self.max_context_len = max_context_len
+ self.end_sym = end_sym
+
+ self.prompt_template = prompt_template
+ self.prompt_list = []
+
+ def vit_to_cpu(self):
+ self.ln_vision.to("cpu")
+ self.ln_vision.float()
+ self.visual_encoder.to("cpu")
+ self.visual_encoder.float()
+
+ def get_context_emb(self, prompt, img_list):
+ device = img_list[0].device
+ prompt_segs = prompt.split('')
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
+ seg_tokens = [
+ self.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids # only add bos to the first seg
+ for i, seg in enumerate(prompt_segs)
+ ]
+ seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
+
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
+
+ def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
+ if prompts is None or len(prompts) == 0:
+ # prompts is not provided, just return the original image embedding
+ return img_embeds, atts_img
+ elif img_embeds is None:
+ # prompt is provided but there is no image embedding. return the prompt embedding in right padding
+ self.llama_tokenizer.padding_side = "right"
+ prompt_tokens = self.llama_tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding="longest",
+ add_special_tokens=False
+ ).to(self.device)
+ prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
+ atts_prompt = prompt_tokens.attention_mask
+ return prompt_embeds, atts_prompt
+ else:
+ # return the multi-modal embedding in right padding
+ emb_lists = []
+ if isinstance(prompts, str):
+ prompts = [prompts] * len(img_embeds)
+
+ for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
+ pn = each_img_embed.shape[-2]
+ if lengths is not None:
+ each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
+ each_img_embed = each_img_embed[:lengths[idx] * pn]
+ p_segs = each_prompt.split('')
+ interleave_emb = []
+ for idx, seg in enumerate(p_segs[:-1]):
+ p_tokens = self.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_embed = self.embed_tokens(p_tokens.input_ids)
+ interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1))
+ wrapped_emb = torch.cat(interleave_emb, dim=1)
+ p_tokens = self.llama_tokenizer(
+ p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_embed = self.embed_tokens(p_tokens.input_ids)
+ wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
+ emb_lists.append(wrapped_emb)
+
+ emb_lens = [emb.shape[1] for emb in emb_lists]
+ pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
+
+ max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
+ wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
+ wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
+
+ for i, emb in enumerate(emb_lists):
+ length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
+ wrapped_embs[i, :length] = emb[:, :length]
+ wrapped_atts[i, :length] = 1
+ return wrapped_embs, wrapped_atts
+
+ def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
+ """
+ Concatenate the batched input embedding and batched output embedding together.
+ Both the input and the output embedding should be right padded.
+ """
+ input_lens = []
+ cat_embs = []
+ cat_atts = []
+ for i in range(input_embs.size(0)):
+ input_len = input_atts[i].sum()
+ input_lens.append(input_len)
+ cat_embs.append(
+ torch.cat([
+ input_embs[i][:input_len],
+ output_embs[i],
+ input_embs[i][input_len:]
+ ])
+ )
+ cat_atts.append(
+ torch.cat([
+ input_atts[i][:input_len],
+ output_atts[i],
+ input_atts[i][input_len:]
+ ])
+ )
+ cat_embs = torch.stack(cat_embs)
+ cat_atts = torch.stack(cat_atts)
+ return cat_embs, cat_atts, input_lens
+
+ def tokenize_conversation(self, conv_q, conv_a):
+ """concatenate conversation and make sure the model is only trained to regress the answer"""
+
+ to_regress_token_ids_list = []
+ targets_list = []
+
+ batch_size = len(conv_q)
+ for batch_idx in range(batch_size):
+ questions, answers = conv_q[batch_idx], conv_a[batch_idx]
+ questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q,
+ return_tensors="pt",
+ add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it
+ answers = [self.llama_tokenizer(a + self.end_sym,
+ return_tensors="pt",
+ add_special_tokens=False).to(self.device) for a in answers]
+ cur_id = []
+ cur_target = []
+ for i in range(len(questions)):
+ cur_id.append(answers[i].input_ids)
+ cur_target.append(answers[i].input_ids)
+ cur_id.append(questions[i].input_ids)
+ cur_target.append(torch.ones_like(questions[i].input_ids) * -100)
+
+ cur_id.append(answers[-1].input_ids)
+ cur_target.append(answers[-1].input_ids)
+
+ cur_id = torch.cat(cur_id, dim=1)
+ cur_target = torch.cat(cur_target, dim=1)
+ to_regress_token_ids_list.append(cur_id)
+ targets_list.append(cur_target)
+
+ max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
+ to_regress_token_ids = torch.ones([batch_size, max_len],
+ dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id
+ targets = torch.ones([batch_size, max_len],
+ dtype=cur_id.dtype, device=self.device) * -100
+ for batch_idx in range(batch_size):
+ cur_len = to_regress_token_ids_list[batch_idx].shape[1]
+ to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len]
+ targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
+
+ to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int)
+
+ return to_regress_token_ids, to_regress_token_attn, targets
+
+ def preparing_embedding(self, samples):
+ ### prepare input tokens
+ if 'image' in samples:
+ img_embeds, img_atts = self.encode_img(samples["image"])
+ else:
+ img_embeds = img_atts = None
+
+ if 'conv_q' in samples:
+ # handeling conversation datasets
+ conv_q, conv_a = samples['conv_q'], samples['conv_a']
+
+ connect_sym = samples['connect_sym'][0]
+ conv_q = [q.split(connect_sym)for q in conv_q]
+ conv_a = [a.split(connect_sym) for a in conv_a]
+
+ conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
+
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
+ regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
+
+ else:
+ if "instruction_input" in samples:
+ instruction = samples["instruction_input"]
+ elif self.prompt_list:
+ instruction = random.choice(self.prompt_list)
+ else:
+ instruction = None
+
+ if hasattr(self, 'chat_template') and self.chat_template:
+ instruction = [self.prompt_template.format(instruct) for instruct in instruction]
+
+ if 'length' in samples:
+ # the input is a image train (like videos)
+ bsz, pn, hs = img_embeds.shape
+ img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
+ else:
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
+
+ ### prepare target tokens
+ self.llama_tokenizer.padding_side = "right"
+ text = [t + self.end_sym for t in samples["answer"]]
+
+ regress_tokens = self.llama_tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=self.max_txt_len,
+ add_special_tokens=False
+ ).to(self.device)
+
+ regress_token_ids = regress_tokens.input_ids
+ regress_atts = regress_tokens.attention_mask
+ part_targets = regress_token_ids.masked_fill(
+ regress_token_ids == self.llama_tokenizer.pad_token_id, -100
+ )
+
+ regress_embeds = self.embed_tokens(regress_token_ids)
+
+ return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
+
+ def forward(self, samples, reduction='mean'):
+ # prepare the embedding to condition and the embedding to regress
+ cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
+ self.preparing_embedding(samples)
+
+ # concat the embedding to condition and the embedding to regress
+ inputs_embeds, attention_mask, input_lens = \
+ self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
+
+ # get bos token embedding
+ bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
+ bos_embeds = self.embed_tokens(bos)
+ bos_atts = cond_atts[:, :1]
+
+ # add bos token at the begining
+ inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
+ attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
+
+ # ensemble the final targets
+ targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
+ dtype=torch.long).to(self.device).fill_(-100)
+
+ for i, target in enumerate(part_targets):
+ targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
+
+ with self.maybe_autocast():
+ outputs = self.llama_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ return_dict=True,
+ labels=targets,
+ reduction=reduction
+ )
+ loss = outputs.loss
+
+ return {"loss": loss}
+
+ def embed_tokens(self, token_ids):
+ if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model
+ embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
+ else:
+ embeds = self.llama_model.base_model.embed_tokens(token_ids)
+ return embeds
+
+ @torch.no_grad()
+ def generate(
+ self,
+ images,
+ texts,
+ num_beams=1,
+ max_new_tokens=20,
+ min_length=1,
+ top_p=0.9,
+ repetition_penalty=1,
+ length_penalty=1,
+ temperature=1,
+ do_sample=False,
+ stop_words_ids=[2],
+ ):
+ '''
+ function for generate test use
+ '''
+
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
+ stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
+
+ img_embeds, atts_img = self.encode_img(images.to(self.device))
+ image_lists = [[image_emb[None]] for image_emb in img_embeds]
+
+ batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
+
+ batch_size = len(batch_embs)
+ max_len = max([emb.shape[1] for emb in batch_embs])
+ emb_dim = batch_embs[0].shape[2]
+ dtype = batch_embs[0].dtype
+ device = batch_embs[0].device
+
+ embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
+ attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
+ for i, emb in enumerate(batch_embs):
+ emb_len = emb.shape[1]
+ embs[i, -emb_len:] = emb[0]
+ attn_mask[i, -emb_len:] = 1
+
+ with self.maybe_autocast():
+ outputs = self.llama_model.generate(
+ inputs_embeds=embs,
+ attention_mask=attn_mask,
+ max_new_tokens=max_new_tokens,
+ num_beams=num_beams,
+ length_penalty=length_penalty,
+ temperature=temperature,
+ do_sample=do_sample,
+ min_length=min_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ # stopping_criteria=stopping_criteria,
+ )
+
+ # with self.maybe_autocast():
+ # outputs = self.llama_model.generate(
+ # inputs_embeds=embs,
+ # attention_mask=attn_mask,
+ # max_new_tokens=max_new_tokens,
+ # num_beams=num_beams,
+ # do_sample=do_sample,
+ # # stopping_criteria=stopping_criteria,
+ # )
+ answers = []
+ for output_token in outputs:
+ if output_token[0] == 0:
+ output_token = output_token[1:]
+ output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
+ output_texts = output_texts.split('')[0] # remove the stop sign
+ output_texts = output_texts.replace("", "")
+ output_texts = output_texts.split(r'[/INST]')[-1].strip()
+ answers.append(output_texts)
+
+ return answers
+
+ @torch.no_grad()
+ def multi_select(self, images, texts, answers, num_cand=None):
+ all_losses = []
+ for answer in answers:
+ choice_samples = {
+ 'image': images,
+ 'instruction_input': texts,
+ 'answer': answer
+ }
+ loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
+ all_losses.append(loss)
+ torch.cuda.empty_cache()
+ all_losses = torch.cat(all_losses, dim=-1)
+ if num_cand is not None:
+ for i in range(all_losses.shape[0]):
+ all_losses[i, num_cand[i]:] = 9999
+ output_class_ranks = torch.argsort(all_losses, dim=-1)
+ return output_class_ranks.tolist()
diff --git a/minigpt4/models/minigpt_v2.py b/minigpt4/models/minigpt_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a046b0baff41db50477e35904af9bcad5baa619c
--- /dev/null
+++ b/minigpt4/models/minigpt_v2.py
@@ -0,0 +1,139 @@
+import logging
+import random
+
+import torch
+from torch.cuda.amp import autocast as autocast
+import torch.nn as nn
+
+from minigpt4.common.registry import registry
+from minigpt4.models.base_model import disabled_train
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
+
+
+@registry.register_model("minigpt_v2")
+class MiniGPTv2(MiniGPTBase):
+ """
+ MiniGPT-v2 model
+ """
+
+ PRETRAINED_MODEL_CONFIG_DICT = {
+ "pretrain": "configs/models/minigpt_v2.yaml",
+ }
+
+ def __init__(
+ self,
+ vit_model="eva_clip_g",
+ img_size=448,
+ drop_path_rate=0,
+ use_grad_checkpoint=False,
+ vit_precision="fp16",
+ freeze_vit=True,
+ llama_model="",
+ prompt_template='[INST] {} [/INST]',
+ max_txt_len=300,
+ end_sym='\n',
+ lora_r=64,
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.05,
+ chat_template=False,
+ use_grad_checkpoint_llm=False,
+ max_context_len=3800,
+ low_resource=False, # use 8 bit and put vit in cpu
+ device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
+ ):
+ super().__init__(
+ vit_model=vit_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ llama_model=llama_model,
+ max_txt_len=max_txt_len,
+ max_context_len=max_context_len,
+ end_sym=end_sym,
+ prompt_template=prompt_template,
+ low_resource=low_resource,
+ device_8bit=device_8bit,
+ lora_r=lora_r,
+ lora_target_modules=lora_target_modules,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ img_f_dim = self.visual_encoder.num_features * 4
+ self.llama_proj = nn.Linear(
+ img_f_dim, self.llama_model.config.hidden_size
+ )
+ self.chat_template = chat_template
+
+ if use_grad_checkpoint_llm:
+ self.llama_model.gradient_checkpointing_enable()
+
+ def encode_img(self, image):
+ device = image.device
+
+ if len(image.shape) > 4:
+ image = image.reshape(-1, *image.shape[-3:])
+
+ with self.maybe_autocast():
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+ image_embeds = image_embeds[:, 1:, :]
+ bs, pn, hs = image_embeds.shape
+ image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
+
+ inputs_llama = self.llama_proj(image_embeds)
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
+ return inputs_llama, atts_llama
+
+ @classmethod
+ def from_config(cls, cfg):
+ vit_model = cfg.get("vit_model", "eva_clip_g")
+ img_size = cfg.get("image_size")
+ llama_model = cfg.get("llama_model")
+
+ drop_path_rate = cfg.get("drop_path_rate", 0)
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+ vit_precision = cfg.get("vit_precision", "fp16")
+ freeze_vit = cfg.get("freeze_vit", True)
+ low_resource = cfg.get("low_resource", False)
+
+ prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
+ max_txt_len = cfg.get("max_txt_len", 300)
+ end_sym = cfg.get("end_sym", '\n')
+
+ lora_r = cfg.get("lora_r", 64)
+ lora_alpha = cfg.get("lora_alpha", 16)
+ chat_template = cfg.get("chat_template", False)
+
+ use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
+ max_context_len = cfg.get("max_context_len", 3800)
+
+ model = cls(
+ vit_model=vit_model,
+ img_size=img_size,
+ drop_path_rate=drop_path_rate,
+ use_grad_checkpoint=use_grad_checkpoint,
+ vit_precision=vit_precision,
+ freeze_vit=freeze_vit,
+ llama_model=llama_model,
+ prompt_template=prompt_template,
+ max_txt_len=max_txt_len,
+ low_resource=low_resource,
+ end_sym=end_sym,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ chat_template=chat_template,
+ use_grad_checkpoint_llm=use_grad_checkpoint_llm,
+ max_context_len=max_context_len,
+ )
+
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
+ if ckpt_path:
+ print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ msg = model.load_state_dict(ckpt['model'], strict=False)
+
+ return model
diff --git a/minigpt4/models/modeling_llama.py b/minigpt4/models/modeling_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d59a53faf45ef55cf127714489201d84a9364d9
--- /dev/null
+++ b/minigpt4/models/modeling_llama.py
@@ -0,0 +1,111 @@
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+
+from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
+from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
+
+
+class LlamaForCausalLM(LlamaForCausalLMOrig):
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ reduction: Optional[str] = "mean",
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(reduction=reduction)
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+ if reduction == "none":
+ loss = loss.view(logits.size(0), -1).mean(1)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/minigpt4/processors/__init__.py b/minigpt4/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e560eaa15f3266dbc1ffbca70bdc791901737a60
--- /dev/null
+++ b/minigpt4/processors/__init__.py
@@ -0,0 +1,33 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from minigpt4.processors.base_processor import BaseProcessor
+from minigpt4.processors.blip_processors import (
+ Blip2ImageTrainProcessor,
+ Blip2ImageEvalProcessor,
+ BlipCaptionProcessor,
+)
+
+from minigpt4.common.registry import registry
+
+__all__ = [
+ "BaseProcessor",
+ "Blip2ImageTrainProcessor",
+ "Blip2ImageEvalProcessor",
+ "BlipCaptionProcessor",
+]
+
+
+def load_processor(name, cfg=None):
+ """
+ Example
+
+ >>> processor = load_processor("alpro_video_train", cfg=None)
+ """
+ processor = registry.get_processor_class(name).from_config(cfg)
+
+ return processor
diff --git a/minigpt4/processors/__pycache__/__init__.cpython-310.pyc b/minigpt4/processors/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da8d575ccae031260b489440c5a364f6ae96c28b
Binary files /dev/null and b/minigpt4/processors/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/processors/__pycache__/__init__.cpython-39.pyc b/minigpt4/processors/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79211354aacfbebc5eb62cacf1433d7113ffae4a
Binary files /dev/null and b/minigpt4/processors/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/processors/__pycache__/base_processor.cpython-310.pyc b/minigpt4/processors/__pycache__/base_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c60d0bb2fa90ee5d6d0486f6e7658ba0fffcfd2a
Binary files /dev/null and b/minigpt4/processors/__pycache__/base_processor.cpython-310.pyc differ
diff --git a/minigpt4/processors/__pycache__/base_processor.cpython-39.pyc b/minigpt4/processors/__pycache__/base_processor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8d375505262900593a1a830401a2a0de500c3b7
Binary files /dev/null and b/minigpt4/processors/__pycache__/base_processor.cpython-39.pyc differ
diff --git a/minigpt4/processors/__pycache__/blip_processors.cpython-310.pyc b/minigpt4/processors/__pycache__/blip_processors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a56a529a3f51f680bc72ac0599873fe6898a10f
Binary files /dev/null and b/minigpt4/processors/__pycache__/blip_processors.cpython-310.pyc differ
diff --git a/minigpt4/processors/__pycache__/blip_processors.cpython-39.pyc b/minigpt4/processors/__pycache__/blip_processors.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b9f8faa80dff7d4ff42ec743d8e979965676e74
Binary files /dev/null and b/minigpt4/processors/__pycache__/blip_processors.cpython-39.pyc differ
diff --git a/minigpt4/processors/__pycache__/randaugment.cpython-310.pyc b/minigpt4/processors/__pycache__/randaugment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8fe1be96f2c5c296c06d42fc7e3261ac8faa4f7
Binary files /dev/null and b/minigpt4/processors/__pycache__/randaugment.cpython-310.pyc differ
diff --git a/minigpt4/processors/__pycache__/randaugment.cpython-39.pyc b/minigpt4/processors/__pycache__/randaugment.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..184023b4bf3f1a2bf1cbe2b294cd1a5a590542a7
Binary files /dev/null and b/minigpt4/processors/__pycache__/randaugment.cpython-39.pyc differ
diff --git a/minigpt4/processors/base_processor.py b/minigpt4/processors/base_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..39b33cdf8fcd97cfd3e4a5fbece6593357af9d41
--- /dev/null
+++ b/minigpt4/processors/base_processor.py
@@ -0,0 +1,26 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from omegaconf import OmegaConf
+
+
+class BaseProcessor:
+ def __init__(self):
+ self.transform = lambda x: x
+ return
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ return cls()
+
+ def build(self, **kwargs):
+ cfg = OmegaConf.create(kwargs)
+
+ return self.from_config(cfg)
diff --git a/minigpt4/processors/blip_processors.py b/minigpt4/processors/blip_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd26160ec96a8458cdac083d19c19695937a7a62
--- /dev/null
+++ b/minigpt4/processors/blip_processors.py
@@ -0,0 +1,141 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import re
+
+from minigpt4.common.registry import registry
+from minigpt4.processors.base_processor import BaseProcessor
+from minigpt4.processors.randaugment import RandomAugment
+from omegaconf import OmegaConf
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class BlipImageBaseProcessor(BaseProcessor):
+ def __init__(self, mean=None, std=None):
+ if mean is None:
+ mean = (0.48145466, 0.4578275, 0.40821073)
+ if std is None:
+ std = (0.26862954, 0.26130258, 0.27577711)
+
+ self.normalize = transforms.Normalize(mean, std)
+
+
+@registry.register_processor("blip_caption")
+class BlipCaptionProcessor(BaseProcessor):
+ def __init__(self, prompt="", max_words=50):
+ self.prompt = prompt
+ self.max_words = max_words
+
+ def __call__(self, caption):
+ caption = self.prompt + self.pre_caption(caption)
+
+ return caption
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ prompt = cfg.get("prompt", "")
+ max_words = cfg.get("max_words", 50)
+
+ return cls(prompt=prompt, max_words=max_words)
+
+ def pre_caption(self, caption):
+ caption = re.sub(
+ r"([.!\"()*#:;~])",
+ " ",
+ caption.lower(),
+ )
+ caption = re.sub(
+ r"\s{2,}",
+ " ",
+ caption,
+ )
+ caption = caption.rstrip("\n")
+ caption = caption.strip(" ")
+
+ # truncate caption
+ caption_words = caption.split(" ")
+ if len(caption_words) > self.max_words:
+ caption = " ".join(caption_words[: self.max_words])
+
+ return caption
+
+
+@registry.register_processor("blip2_image_train")
+class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(
+ image_size,
+ scale=(min_scale, max_scale),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 224)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ min_scale = cfg.get("min_scale", 0.5)
+ max_scale = cfg.get("max_scale", 1.0)
+
+ return cls(
+ image_size=image_size,
+ mean=mean,
+ std=std,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ )
+
+
+@registry.register_processor("blip2_image_eval")
+class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
+ def __init__(self, image_size=224, mean=None, std=None):
+ super().__init__(mean=mean, std=std)
+
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize(
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+ ),
+ transforms.ToTensor(),
+ self.normalize,
+ ]
+ )
+
+ def __call__(self, item):
+ return self.transform(item)
+
+ @classmethod
+ def from_config(cls, cfg=None):
+ if cfg is None:
+ cfg = OmegaConf.create()
+
+ image_size = cfg.get("image_size", 224)
+
+ mean = cfg.get("mean", None)
+ std = cfg.get("std", None)
+
+ return cls(image_size=image_size, mean=mean, std=std)
\ No newline at end of file
diff --git a/minigpt4/processors/randaugment.py b/minigpt4/processors/randaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..7034a49ad5fc63b97910790017432617ff4c6d7b
--- /dev/null
+++ b/minigpt4/processors/randaugment.py
@@ -0,0 +1,398 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import cv2
+import numpy as np
+
+import torch
+
+
+## aug functions
+def identity_func(img):
+ return img
+
+
+def autocontrast_func(img, cutoff=0):
+ """
+ same output as PIL.ImageOps.autocontrast
+ """
+ n_bins = 256
+
+ def tune_channel(ch):
+ n = ch.size
+ cut = cutoff * n // 100
+ if cut == 0:
+ high, low = ch.max(), ch.min()
+ else:
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ low = np.argwhere(np.cumsum(hist) > cut)
+ low = 0 if low.shape[0] == 0 else low[0]
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+ if high <= low:
+ table = np.arange(n_bins)
+ else:
+ scale = (n_bins - 1) / (high - low)
+ offset = -low * scale
+ table = np.arange(n_bins) * scale + offset
+ table[table < 0] = 0
+ table[table > n_bins - 1] = n_bins - 1
+ table = table.clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def equalize_func(img):
+ """
+ same output as PIL.ImageOps.equalize
+ PIL's implementation is different from cv2.equalize
+ """
+ n_bins = 256
+
+ def tune_channel(ch):
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ non_zero_hist = hist[hist != 0].reshape(-1)
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+ if step == 0:
+ return ch
+ n = np.empty_like(hist)
+ n[0] = step // 2
+ n[1:] = hist[:-1]
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+ """
+ like PIL, rotate by degree, not radians
+ """
+ H, W = img.shape[0], img.shape[1]
+ center = W / 2, H / 2
+ M = cv2.getRotationMatrix2D(center, degree, 1)
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+ return out
+
+
+def solarize_func(img, thresh=128):
+ """
+ same output as PIL.ImageOps.posterize
+ """
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
+ table = table.clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def color_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Color
+ """
+ ## implementation according to PIL definition, quite slow
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+ # out = blend(degenerate, img, factor)
+ # M = (
+ # np.eye(3) * factor
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+ # )[np.newaxis, np.newaxis, :]
+ M = np.float32(
+ [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
+ ) * factor + np.float32([[0.114], [0.587], [0.299]])
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+ return out
+
+
+def contrast_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Contrast
+ """
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+ table = (
+ np.array([(el - mean) * factor + mean for el in range(256)])
+ .clip(0, 255)
+ .astype(np.uint8)
+ )
+ out = table[img]
+ return out
+
+
+def brightness_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Contrast
+ """
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def sharpness_func(img, factor):
+ """
+ The differences the this result and PIL are all on the 4 boundaries, the center
+ areas are same
+ """
+ kernel = np.ones((3, 3), dtype=np.float32)
+ kernel[1][1] = 5
+ kernel /= 13
+ degenerate = cv2.filter2D(img, -1, kernel)
+ if factor == 0.0:
+ out = degenerate
+ elif factor == 1.0:
+ out = img
+ else:
+ out = img.astype(np.float32)
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+ out = out.astype(np.uint8)
+ return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+ """
+ same output as PIL.Image.transform
+ """
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+ """
+ same output as PIL.Image.transform
+ """
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def posterize_func(img, bits):
+ """
+ same output as PIL.ImageOps.posterize
+ """
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+ return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
+ out = cv2.warpAffine(
+ img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+ ).astype(np.uint8)
+ return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+ replace = np.array(replace, dtype=np.uint8)
+ H, W = img.shape[0], img.shape[1]
+ rh, rw = np.random.random(2)
+ pad_size = pad_size // 2
+ ch, cw = int(rh * H), int(rw * W)
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+ out = img.copy()
+ out[x1:x2, y1:y2, :] = replace
+ return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+
+ return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 0.3
+ if np.random.random() > 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * float(translate_const)
+ if np.random.random() > 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * cutout_const)
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 256)
+ return (level,)
+
+ return level_to_args
+
+
+def none_level_to_args(level):
+ return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 4)
+ return (level,)
+
+ return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 30
+ if np.random.random() < 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+func_dict = {
+ "Identity": identity_func,
+ "AutoContrast": autocontrast_func,
+ "Equalize": equalize_func,
+ "Rotate": rotate_func,
+ "Solarize": solarize_func,
+ "Color": color_func,
+ "Contrast": contrast_func,
+ "Brightness": brightness_func,
+ "Sharpness": sharpness_func,
+ "ShearX": shear_x_func,
+ "TranslateX": translate_x_func,
+ "TranslateY": translate_y_func,
+ "Posterize": posterize_func,
+ "ShearY": shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+ "Identity": none_level_to_args,
+ "AutoContrast": none_level_to_args,
+ "Equalize": none_level_to_args,
+ "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
+ "Solarize": solarize_level_to_args(MAX_LEVEL),
+ "Color": enhance_level_to_args(MAX_LEVEL),
+ "Contrast": enhance_level_to_args(MAX_LEVEL),
+ "Brightness": enhance_level_to_args(MAX_LEVEL),
+ "Sharpness": enhance_level_to_args(MAX_LEVEL),
+ "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
+ "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+ "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+ "Posterize": posterize_level_to_args(MAX_LEVEL),
+ "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
+}
+
+
+class RandomAugment(object):
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+ self.N = N
+ self.M = M
+ self.isPIL = isPIL
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N)
+ return [(op, 0.5, self.M) for op in sampled_ops]
+
+ def __call__(self, img):
+ if self.isPIL:
+ img = np.array(img)
+ ops = self.get_random_ops()
+ for name, prob, level in ops:
+ if np.random.random() > prob:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return img
+
+
+class VideoRandomAugment(object):
+ def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
+ self.N = N
+ self.M = M
+ self.p = p
+ self.tensor_in_tensor_out = tensor_in_tensor_out
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N, replace=False)
+ return [(op, self.M) for op in sampled_ops]
+
+ def __call__(self, frames):
+ assert (
+ frames.shape[-1] == 3
+ ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
+
+ if self.tensor_in_tensor_out:
+ frames = frames.numpy().astype(np.uint8)
+
+ num_frames = frames.shape[0]
+
+ ops = num_frames * [self.get_random_ops()]
+ apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
+
+ frames = torch.stack(
+ list(map(self._aug, frames, ops, apply_or_not)), dim=0
+ ).float()
+
+ return frames
+
+ def _aug(self, img, ops, apply_or_not):
+ for i, (name, level) in enumerate(ops):
+ if not apply_or_not[i]:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return torch.from_numpy(img)
+
+
+if __name__ == "__main__":
+ a = RandomAugment()
+ img = np.random.randn(32, 32, 3)
+ a(img)
diff --git a/minigpt4/runners/__init__.py b/minigpt4/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..64e7a4d643a8b5a1714687f42d43347a94b72373
--- /dev/null
+++ b/minigpt4/runners/__init__.py
@@ -0,0 +1,10 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from minigpt4.runners.runner_base import RunnerBase
+
+__all__ = ["RunnerBase"]
diff --git a/minigpt4/runners/__pycache__/__init__.cpython-310.pyc b/minigpt4/runners/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..110c95562f061ac263f942108fdab7eb89e3cff2
Binary files /dev/null and b/minigpt4/runners/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/runners/__pycache__/__init__.cpython-39.pyc b/minigpt4/runners/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..996da796e228af4ed3c2ecb6fda3406ce10611ac
Binary files /dev/null and b/minigpt4/runners/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/runners/__pycache__/runner_base.cpython-310.pyc b/minigpt4/runners/__pycache__/runner_base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49b9a11c7e5a32831c198cd8f15f1d133eda1dd3
Binary files /dev/null and b/minigpt4/runners/__pycache__/runner_base.cpython-310.pyc differ
diff --git a/minigpt4/runners/__pycache__/runner_base.cpython-39.pyc b/minigpt4/runners/__pycache__/runner_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19de23a4b039041dd56233bdef1a7223f16d0ecb
Binary files /dev/null and b/minigpt4/runners/__pycache__/runner_base.cpython-39.pyc differ
diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc8dc1de2c5546bdc6e1ab2f04d92813c9d023bb
--- /dev/null
+++ b/minigpt4/runners/runner_base.py
@@ -0,0 +1,659 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import json
+import logging
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from minigpt4.common.dist_utils import (
+ download_cached_file,
+ get_rank,
+ get_world_size,
+ is_main_process,
+ main_process,
+)
+from minigpt4.common.registry import registry
+from minigpt4.common.utils import is_url
+from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
+from minigpt4.datasets.datasets.dataloader_utils import (
+ IterLoader,
+ MultiIterLoader,
+ PrefetchLoader,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, DistributedSampler
+
+
+@registry.register_runner("runner_base")
+class RunnerBase:
+ """
+ A runner class to train and evaluate a model given a task and datasets.
+
+ The runner uses pytorch distributed data parallel by default. Future release
+ will support other distributed frameworks.
+ """
+
+ def __init__(self, cfg, task, model, datasets, job_id):
+ self.config = cfg
+ self.job_id = job_id
+
+ self.task = task
+ self.datasets = datasets
+
+ self._model = model
+
+ self._wrapped_model = None
+ self._device = None
+ self._optimizer = None
+ self._scaler = None
+ self._dataloaders = None
+ self._lr_sched = None
+
+ self.start_epoch = 0
+
+ # self.setup_seeds()
+ self.setup_output_dir()
+
+ @property
+ def device(self):
+ if self._device is None:
+ self._device = torch.device(self.config.run_cfg.device)
+
+ return self._device
+
+ @property
+ def use_distributed(self):
+ return self.config.run_cfg.distributed
+
+ @property
+ def model(self):
+ """
+ A property to get the DDP-wrapped model on the device.
+ """
+ # move model to device
+ if self._model.device != self.device:
+ self._model = self._model.to(self.device)
+
+ # distributed training wrapper
+ if self.use_distributed:
+ if self._wrapped_model is None:
+ self._wrapped_model = DDP(
+ self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=True
+ )
+ else:
+ self._wrapped_model = self._model
+
+ return self._wrapped_model
+
+ @property
+ def optimizer(self):
+ # TODO make optimizer class and configurations
+ if self._optimizer is None:
+ num_parameters = 0
+ p_wd, p_non_wd = [], []
+ for n, p in self.model.named_parameters():
+ if not p.requires_grad:
+ continue # frozen weights
+ print(n)
+ if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
+ p_non_wd.append(p)
+ else:
+ p_wd.append(p)
+ num_parameters += p.data.nelement()
+ logging.info("number of trainable parameters: %d" % num_parameters)
+ optim_params = [
+ {
+ "params": p_wd,
+ "weight_decay": float(self.config.run_cfg.weight_decay),
+ },
+ {"params": p_non_wd, "weight_decay": 0},
+ ]
+ beta2 = self.config.run_cfg.get("beta2", 0.999)
+ self._optimizer = torch.optim.AdamW(
+ optim_params,
+ lr=float(self.config.run_cfg.init_lr),
+ weight_decay=float(self.config.run_cfg.weight_decay),
+ betas=(0.9, beta2),
+ )
+
+ return self._optimizer
+
+ @property
+ def scaler(self):
+ amp = self.config.run_cfg.get("amp", False)
+
+ if amp:
+ if self._scaler is None:
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ return self._scaler
+
+ @property
+ def lr_scheduler(self):
+ """
+ A property to get and create learning rate scheduler by split just in need.
+ """
+ if self._lr_sched is None:
+ lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
+
+ # max_epoch = self.config.run_cfg.max_epoch
+ max_epoch = self.max_epoch
+ # min_lr = self.config.run_cfg.min_lr
+ min_lr = self.min_lr
+ # init_lr = self.config.run_cfg.init_lr
+ init_lr = self.init_lr
+
+ # optional parameters
+ decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
+ warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
+ warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
+ iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None)
+
+ if iters_per_epoch is None:
+ try:
+ iters_per_epoch = len(self.dataloaders['train'])
+ except (AttributeError, TypeError):
+ iters_per_epoch = 10000
+
+ self._lr_sched = lr_sched_cls(
+ optimizer=self.optimizer,
+ max_epoch=max_epoch,
+ iters_per_epoch=iters_per_epoch,
+ min_lr=min_lr,
+ init_lr=init_lr,
+ decay_rate=decay_rate,
+ warmup_start_lr=warmup_start_lr,
+ warmup_steps=warmup_steps,
+ )
+
+ return self._lr_sched
+
+ @property
+ def dataloaders(self) -> dict:
+ """
+ A property to get and create dataloaders by split just in need.
+
+ If no train_dataset_ratio is provided, concatenate map-style datasets and
+ chain wds.DataPipe datasets separately. Training set becomes a tuple
+ (ConcatDataset, ChainDataset), both are optional but at least one of them is
+ required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
+
+ If train_dataset_ratio is provided, create a MultiIterLoader to sample
+ each dataset by ratios during training.
+
+ Currently do not support multiple datasets for validation and test.
+
+ Returns:
+ dict: {split_name: (tuples of) dataloader}
+ """
+ if self._dataloaders is None:
+
+ # concatenate map-style datasets and chain wds.DataPipe datasets separately
+ # training set becomes a tuple (ConcatDataset, ChainDataset), both are
+ # optional but at least one of them is required. The resultant ConcatDataset
+ # and ChainDataset will be sampled evenly.
+ logging.info(
+ "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
+ )
+
+ batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size
+ for dataset_name in self.datasets.keys()}
+ datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes)
+ self.datasets = datasets
+ # self.datasets = concat_datasets(datasets)
+
+ # print dataset statistics after concatenation/chaining
+ for split_name in self.datasets:
+ if isinstance(self.datasets[split_name], tuple) or isinstance(
+ self.datasets[split_name], list
+ ):
+ # mixed wds.DataPipeline and torch.utils.data.Dataset
+ num_records = sum(
+ [
+ len(d)
+ if not type(d) in [wds.DataPipeline, ChainDataset]
+ else 0
+ for d in self.datasets[split_name]
+ ]
+ )
+
+ else:
+ if hasattr(self.datasets[split_name], "__len__"):
+ # a single map-style dataset
+ num_records = len(self.datasets[split_name])
+ else:
+ # a single wds.DataPipeline
+ num_records = -1
+ logging.info(
+ "Only a single wds.DataPipeline dataset, no __len__ attribute."
+ )
+
+ if num_records >= 0:
+ logging.info(
+ "Loaded {} records for {} split from the dataset.".format(
+ num_records, split_name
+ )
+ )
+
+ # create dataloaders
+ split_names = sorted(self.datasets.keys())
+
+ datasets = [self.datasets[split] for split in split_names]
+ batch_sizes = [batch_sizes[split] for split in split_names]
+ is_trains = [split in self.train_splits for split in split_names]
+
+ print("batch sizes", batch_sizes)
+
+ collate_fns = []
+ for dataset in datasets:
+ if isinstance(dataset, tuple) or isinstance(dataset, list):
+ collate_fns.append([getattr(d, "collater", None) for d in dataset])
+ else:
+ collate_fns.append(getattr(dataset, "collater", None))
+
+ dataloaders = self.create_loaders(
+ datasets=datasets,
+ num_workers=self.config.run_cfg.num_workers,
+ batch_sizes=batch_sizes,
+ is_trains=is_trains,
+ collate_fns=collate_fns,
+ )
+
+ self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+ return self._dataloaders
+
+ @property
+ def cuda_enabled(self):
+ return self.device.type == "cuda"
+
+ @property
+ def max_epoch(self):
+ return int(self.config.run_cfg.max_epoch)
+
+ @property
+ def log_freq(self):
+ log_freq = self.config.run_cfg.get("log_freq", 50)
+ return int(log_freq)
+
+ @property
+ def init_lr(self):
+ return float(self.config.run_cfg.init_lr)
+
+ @property
+ def min_lr(self):
+ return float(self.config.run_cfg.min_lr)
+
+ @property
+ def accum_grad_iters(self):
+ return int(self.config.run_cfg.get("accum_grad_iters", 1))
+
+ @property
+ def valid_splits(self):
+ valid_splits = self.config.run_cfg.get("valid_splits", [])
+
+ if len(valid_splits) == 0:
+ logging.info("No validation splits found.")
+
+ return valid_splits
+
+ @property
+ def test_splits(self):
+ test_splits = self.config.run_cfg.get("test_splits", [])
+
+ return test_splits
+
+ @property
+ def train_splits(self):
+ train_splits = self.config.run_cfg.get("train_splits", [])
+
+ if len(train_splits) == 0:
+ logging.info("Empty train splits.")
+
+ return train_splits
+
+ @property
+ def evaluate_only(self):
+ """
+ Set to True to skip training.
+ """
+ return self.config.run_cfg.evaluate
+
+ @property
+ def use_dist_eval_sampler(self):
+ return self.config.run_cfg.get("use_dist_eval_sampler", True)
+
+ @property
+ def resume_ckpt_path(self):
+ return self.config.run_cfg.get("resume_ckpt_path", None)
+
+ @property
+ def train_loader(self):
+ train_dataloader = self.dataloaders["train"]
+
+ return train_dataloader
+
+ def setup_output_dir(self):
+ lib_root = Path(registry.get_path("library_root"))
+
+ output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
+ # output_dir = lib_root / self.config.run_cfg.output_dir
+ result_dir = output_dir / "result"
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ registry.register_path("result_dir", str(result_dir))
+ registry.register_path("output_dir", str(output_dir))
+
+ self.result_dir = result_dir
+ self.output_dir = output_dir
+
+ def train(self):
+ start_time = time.time()
+ best_agg_metric = 0
+ best_epoch = 0
+
+ self.log_config()
+
+ # resume from checkpoint if specified
+ if not self.evaluate_only and self.resume_ckpt_path is not None:
+ self._load_checkpoint(self.resume_ckpt_path)
+
+ for cur_epoch in range(self.start_epoch, self.max_epoch):
+ # training phase
+ if not self.evaluate_only:
+ logging.info("Start training")
+ train_stats = self.train_epoch(cur_epoch)
+ self.log_stats(split_name="train", stats=train_stats)
+
+ # evaluation phase
+ if len(self.valid_splits) > 0:
+ for split_name in self.valid_splits:
+ logging.info("Evaluating on {}.".format(split_name))
+
+ val_log = self.eval_epoch(
+ split_name=split_name, cur_epoch=cur_epoch
+ )
+ if val_log is not None:
+ if is_main_process():
+ assert (
+ "agg_metrics" in val_log
+ ), "No agg_metrics found in validation log."
+
+ agg_metrics = val_log["agg_metrics"]
+ if agg_metrics > best_agg_metric and split_name == "val":
+ best_epoch, best_agg_metric = cur_epoch, agg_metrics
+
+ self._save_checkpoint(cur_epoch, is_best=True)
+
+ val_log.update({"best_epoch": best_epoch})
+ self.log_stats(val_log, split_name)
+
+ else:
+ # if no validation split is provided, we just save the checkpoint at the end of each epoch.
+ if not self.evaluate_only:
+ self._save_checkpoint(cur_epoch, is_best=False)
+
+ if self.evaluate_only:
+ break
+
+ if self.config.run_cfg.distributed:
+ dist.barrier()
+
+ # testing phase
+ test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
+ self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ logging.info("Training time {}".format(total_time_str))
+
+ def evaluate(self, cur_epoch="best", skip_reload=False):
+ test_logs = dict()
+
+ if len(self.test_splits) > 0:
+ for split_name in self.test_splits:
+ test_logs[split_name] = self.eval_epoch(
+ split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
+ )
+
+ return test_logs
+
+ def train_epoch(self, epoch):
+ # train
+ self.model.train()
+
+ return self.task.train_epoch(
+ epoch=epoch,
+ model=self.model,
+ data_loader=self.train_loader,
+ optimizer=self.optimizer,
+ scaler=self.scaler,
+ lr_scheduler=self.lr_scheduler,
+ cuda_enabled=self.cuda_enabled,
+ log_freq=self.log_freq,
+ accum_grad_iters=self.accum_grad_iters,
+ )
+
+ @torch.no_grad()
+ def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
+ """
+ Evaluate the model on a given split.
+
+ Args:
+ split_name (str): name of the split to evaluate on.
+ cur_epoch (int): current epoch.
+ skip_reload_best (bool): whether to skip reloading the best checkpoint.
+ During training, we will reload the best checkpoint for validation.
+ During testing, we will use provided weights and skip reloading the best checkpoint .
+ """
+ data_loader = self.dataloaders.get(split_name, None)
+ assert data_loader, "data_loader for split {} is None.".format(split_name)
+
+ # TODO In validation, you need to compute loss as well as metrics
+ # TODO consider moving to model.before_evaluation()
+ model = self.unwrap_dist_model(self.model)
+ if not skip_reload and cur_epoch == "best":
+ model = self._reload_best_model(model)
+ model.eval()
+
+ self.task.before_evaluation(
+ model=model,
+ dataset=self.datasets[split_name],
+ )
+ results = self.task.evaluation(model, data_loader)
+
+ if results is not None:
+ return self.task.after_evaluation(
+ val_result=results,
+ split_name=split_name,
+ epoch=cur_epoch,
+ )
+
+ def unwrap_dist_model(self, model):
+ if self.use_distributed:
+ return model.module
+ else:
+ return model
+
+ def create_loaders(
+ self,
+ datasets,
+ num_workers,
+ batch_sizes,
+ is_trains,
+ collate_fns,
+ dataset_ratios=None,
+ ):
+ """
+ Create dataloaders for training and validation.
+ """
+
+ def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
+ # create a single dataloader for each split
+ if isinstance(dataset, ChainDataset) or isinstance(
+ dataset, wds.DataPipeline
+ ):
+ # wds.WebdDataset instance are chained together
+ # webdataset.DataPipeline has its own sampler and collate_fn
+ loader = iter(
+ DataLoader(
+ dataset,
+ batch_size=bsz,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
+ )
+ else:
+ # map-style dataset are concatenated together
+ # setup distributed sampler
+
+ if self.use_distributed:
+ sampler = DistributedSampler(
+ dataset,
+ shuffle=is_train,
+ num_replicas=get_world_size(),
+ rank=get_rank(),
+ )
+ if not self.use_dist_eval_sampler:
+ # e.g. retrieval evaluation
+ sampler = sampler if is_train else None
+ else:
+ sampler = None
+
+ loader = DataLoader(
+ dataset,
+ batch_size=bsz,
+ num_workers=num_workers,
+ pin_memory=True,
+ sampler=sampler,
+ shuffle=sampler is None and is_train,
+ collate_fn=collate_fn,
+ drop_last=True if is_train else False,
+ )
+ loader = PrefetchLoader(loader)
+
+ if is_train:
+ loader = IterLoader(loader, use_distributed=self.use_distributed)
+
+ return loader
+
+ loaders = []
+
+ for dataset, bsz, is_train, collate_fn in zip(
+ datasets, batch_sizes, is_trains, collate_fns
+ ):
+ if isinstance(dataset, list) or isinstance(dataset, tuple):
+ if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None:
+ dataset_ratios = [d.sample_ratio for d in dataset]
+ loader = MultiIterLoader(
+ loaders=[
+ _create_loader(d, num_workers, bsz[i], is_train, collate_fn[i])
+ for i, d in enumerate(dataset)
+ ],
+ ratios=dataset_ratios,
+ )
+ else:
+ loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
+
+ loaders.append(loader)
+
+ return loaders
+
+ @main_process
+ def _save_checkpoint(self, cur_epoch, is_best=False):
+ """
+ Save the checkpoint at the current epoch.
+ """
+ model_no_ddp = self.unwrap_dist_model(self.model)
+ param_grad_dic = {
+ k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
+ }
+ state_dict = model_no_ddp.state_dict()
+ for k in list(state_dict.keys()):
+ if k in param_grad_dic.keys() and not param_grad_dic[k]:
+ # delete parameters that do not require gradient
+ del state_dict[k]
+ save_obj = {
+ "model": state_dict,
+ "optimizer": self.optimizer.state_dict(),
+ "config": self.config.to_dict(),
+ "scaler": self.scaler.state_dict() if self.scaler else None,
+ "epoch": cur_epoch,
+ }
+ save_to = os.path.join(
+ self.output_dir,
+ "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
+ )
+ logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
+ torch.save(save_obj, save_to)
+
+ def _reload_best_model(self, model):
+ """
+ Load the best checkpoint for evaluation.
+ """
+ checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
+
+ logging.info("Loading checkpoint from {}.".format(checkpoint_path))
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ try:
+ model.load_state_dict(checkpoint["model"])
+ except RuntimeError as e:
+ logging.warning(
+ """
+ Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
+ Trying to load the model with strict=False.
+ """
+ )
+ model.load_state_dict(checkpoint["model"], strict=False)
+ return model
+
+ def _load_checkpoint(self, url_or_filename):
+ """
+ Resume from a checkpoint.
+ """
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(
+ url_or_filename, check_hash=False, progress=True
+ )
+ checkpoint = torch.load(cached_file, map_location=self.device)
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location=self.device)
+ else:
+ raise RuntimeError("checkpoint url or path is invalid")
+
+ state_dict = checkpoint["model"]
+ message = self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False)
+
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ if self.scaler and "scaler" in checkpoint:
+ self.scaler.load_state_dict(checkpoint["scaler"])
+
+ self.start_epoch = checkpoint["epoch"] + 1
+ print("resume the checkpoint")
+ logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+ @main_process
+ def log_stats(self, stats, split_name):
+ if isinstance(stats, dict):
+ log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ elif isinstance(stats, list):
+ pass
+
+ @main_process
+ def log_config(self):
+ with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+ f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
diff --git a/minigpt4/tasks/__init__.py b/minigpt4/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab1fb1c8289535cf9397bb9805c0cba3666ad26f
--- /dev/null
+++ b/minigpt4/tasks/__init__.py
@@ -0,0 +1,26 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from minigpt4.common.registry import registry
+from minigpt4.tasks.base_task import BaseTask
+from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask
+
+
+def setup_task(cfg):
+ assert "task" in cfg.run_cfg, "Task name must be provided."
+
+ task_name = cfg.run_cfg.task
+ task = registry.get_task_class(task_name).setup_task(cfg=cfg)
+ assert task is not None, "Task {} not properly registered.".format(task_name)
+
+ return task
+
+
+__all__ = [
+ "BaseTask",
+ "ImageTextPretrainTask",
+]
diff --git a/minigpt4/tasks/__pycache__/__init__.cpython-310.pyc b/minigpt4/tasks/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6fbab072bd8f2487017580493447ec8c65282481
Binary files /dev/null and b/minigpt4/tasks/__pycache__/__init__.cpython-310.pyc differ
diff --git a/minigpt4/tasks/__pycache__/__init__.cpython-39.pyc b/minigpt4/tasks/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..931b3ed6cda0fb7a7c7de17416a38f20f125e754
Binary files /dev/null and b/minigpt4/tasks/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/tasks/__pycache__/base_task.cpython-310.pyc b/minigpt4/tasks/__pycache__/base_task.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1001f51512f31b335233f8862b29d36c33009d1
Binary files /dev/null and b/minigpt4/tasks/__pycache__/base_task.cpython-310.pyc differ
diff --git a/minigpt4/tasks/__pycache__/base_task.cpython-39.pyc b/minigpt4/tasks/__pycache__/base_task.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e81f2a103f797f94eca4e734a49b9931f762056
Binary files /dev/null and b/minigpt4/tasks/__pycache__/base_task.cpython-39.pyc differ
diff --git a/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-310.pyc b/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfe294319d64b658c065aa62b881285d5a39cc05
Binary files /dev/null and b/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-310.pyc differ
diff --git a/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc b/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9949e7684d43c0e7a3b50b8026cb4b8175891c7
Binary files /dev/null and b/minigpt4/tasks/__pycache__/image_text_pretrain.cpython-39.pyc differ
diff --git a/minigpt4/tasks/base_task.py b/minigpt4/tasks/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cfa46ce6ae8b0319e7094d23bf9d1ff0393f9b9
--- /dev/null
+++ b/minigpt4/tasks/base_task.py
@@ -0,0 +1,290 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
+from minigpt4.common.logger import MetricLogger, SmoothedValue
+from minigpt4.common.registry import registry
+from minigpt4.datasets.data_utils import prepare_sample
+import wandb
+
+class BaseTask:
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ self.inst_id_key = "instance_id"
+ self.cfg = ""
+
+ @classmethod
+ def setup_task(cls, **kwargs):
+ return cls()
+
+ def build_model(self, cfg):
+ self.cfg = cfg
+ model_config = cfg.model_cfg
+
+ model_cls = registry.get_model_class(model_config.arch)
+ return model_cls.from_config(model_config)
+
+ def build_datasets(self, cfg):
+ """
+ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
+ Download dataset and annotations automatically if not exist.
+
+ Args:
+ cfg (common.config.Config): _description_
+
+ Returns:
+ dict: Dictionary of torch.utils.data.Dataset objects by split.
+ """
+
+ datasets = dict()
+
+ datasets_config = cfg.datasets_cfg
+
+ assert len(datasets_config) > 0, "At least one dataset has to be specified."
+
+ for name in datasets_config:
+ dataset_config = datasets_config[name]
+
+ builder = registry.get_builder_class(name)(dataset_config)
+ dataset = builder.build_datasets()
+
+ dataset['train'].name = name
+ if 'sample_ratio' in dataset_config:
+ dataset['train'].sample_ratio = dataset_config.sample_ratio
+
+ datasets[name] = dataset
+
+ return datasets
+
+ def train_step(self, model, samples):
+ loss = model(samples)["loss"]
+ return loss
+
+ def valid_step(self, model, samples):
+ raise NotImplementedError
+
+ def before_evaluation(self, model, dataset, **kwargs):
+ model.before_evaluation(dataset=dataset, task_type=type(self))
+
+ def after_evaluation(self, **kwargs):
+ pass
+
+ def inference_step(self):
+ raise NotImplementedError
+
+ def evaluation(self, model, data_loader, cuda_enabled=True):
+ metric_logger = MetricLogger(delimiter=" ")
+ header = "Evaluation"
+ # TODO make it configurable
+ print_freq = 10
+
+ results = []
+
+ for samples in metric_logger.log_every(data_loader, print_freq, header):
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+
+ eval_output = self.valid_step(model=model, samples=samples)
+ results.extend(eval_output)
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ return results
+
+ def train_epoch(
+ self,
+ epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ cuda_enabled=False,
+ log_freq=50,
+ accum_grad_iters=1,
+ ):
+ return self._train_inner_loop(
+ epoch=epoch,
+ iters_per_epoch=lr_scheduler.iters_per_epoch,
+ model=model,
+ data_loader=data_loader,
+ optimizer=optimizer,
+ scaler=scaler,
+ lr_scheduler=lr_scheduler,
+ log_freq=log_freq,
+ cuda_enabled=cuda_enabled,
+ accum_grad_iters=accum_grad_iters,
+ )
+
+ def train_iters(
+ self,
+ epoch,
+ start_iters,
+ iters_per_inner_epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ cuda_enabled=False,
+ log_freq=50,
+ accum_grad_iters=1,
+ ):
+ return self._train_inner_loop(
+ epoch=epoch,
+ start_iters=start_iters,
+ iters_per_epoch=iters_per_inner_epoch,
+ model=model,
+ data_loader=data_loader,
+ optimizer=optimizer,
+ scaler=scaler,
+ lr_scheduler=lr_scheduler,
+ log_freq=log_freq,
+ cuda_enabled=cuda_enabled,
+ accum_grad_iters=accum_grad_iters,
+ )
+
+ def _train_inner_loop(
+ self,
+ epoch,
+ iters_per_epoch,
+ model,
+ data_loader,
+ optimizer,
+ lr_scheduler,
+ scaler=None,
+ start_iters=None,
+ log_freq=50,
+ cuda_enabled=False,
+ accum_grad_iters=1,
+ ):
+ """
+ An inner training loop compatible with both epoch-based and iter-based training.
+
+ When using epoch-based, training stops after one epoch; when using iter-based,
+ training stops after #iters_per_epoch iterations.
+ """
+ use_amp = scaler is not None
+
+ if not hasattr(data_loader, "__next__"):
+ # convert to iterator if not already
+ data_loader = iter(data_loader)
+
+ metric_logger = MetricLogger(delimiter=" ")
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
+ metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
+
+ # if iter-based runner, schedule lr based on inner epoch.
+ logging.info(
+ "Start training epoch {}, {} iters per inner epoch.".format(
+ epoch, iters_per_epoch
+ )
+ )
+ header = "Train: data epoch: [{}]".format(epoch)
+ if start_iters is None:
+ # epoch-based runner
+ inner_epoch = epoch
+ else:
+ # In iter-based runner, we schedule the learning rate based on iterations.
+ inner_epoch = start_iters // iters_per_epoch
+ header = header + "; inner epoch [{}]".format(inner_epoch)
+
+ for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
+ # if using iter-based runner, we stop after iters_per_epoch iterations.
+ if i >= iters_per_epoch:
+ break
+
+ samples = next(data_loader)
+
+ samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+ samples.update(
+ {
+ "epoch": inner_epoch,
+ "num_iters_per_epoch": iters_per_epoch,
+ "iters": i,
+ }
+ )
+
+ lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ loss = self.train_step(model=model, samples=samples)
+
+ # after_train_step()
+ if use_amp:
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ # update gradients every accum_grad_iters iterations
+ if (i + 1) % accum_grad_iters == 0:
+ if use_amp:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ optimizer.step()
+ optimizer.zero_grad()
+ # if self.cfg.wandb_log:
+ if self.cfg.run_cfg.wandb_log:
+ wandb.log({"epoch": inner_epoch, "loss": loss})
+ metric_logger.update(loss=loss.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # after train_epoch()
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ logging.info("Averaged stats: " + str(metric_logger.global_avg()))
+ return {
+ k: "{:.3f}".format(meter.global_avg)
+ for k, meter in metric_logger.meters.items()
+ }
+
+ @staticmethod
+ def save_result(result, result_dir, filename, remove_duplicate=""):
+ import json
+
+ result_file = os.path.join(
+ result_dir, "%s_rank%d.json" % (filename, get_rank())
+ )
+ final_result_file = os.path.join(result_dir, "%s.json" % filename)
+
+ json.dump(result, open(result_file, "w"))
+
+ if is_dist_avail_and_initialized():
+ dist.barrier()
+
+ if is_main_process():
+ logging.warning("rank %d starts merging results." % get_rank())
+ # combine results from all processes
+ result = []
+
+ for rank in range(get_world_size()):
+ result_file = os.path.join(
+ result_dir, "%s_rank%d.json" % (filename, rank)
+ )
+ res = json.load(open(result_file, "r"))
+ result += res
+
+ if remove_duplicate:
+ result_new = []
+ id_list = []
+ for res in result:
+ if res[remove_duplicate] not in id_list:
+ id_list.append(res[remove_duplicate])
+ result_new.append(res)
+ result = result_new
+
+ json.dump(result, open(final_result_file, "w"))
+ print("result file saved to %s" % final_result_file)
+
+ return final_result_file
diff --git a/minigpt4/tasks/image_text_pretrain.py b/minigpt4/tasks/image_text_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbe8ec83a5dc95ee26a36e457feb394d18b7cd17
--- /dev/null
+++ b/minigpt4/tasks/image_text_pretrain.py
@@ -0,0 +1,18 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from minigpt4.common.registry import registry
+from minigpt4.tasks.base_task import BaseTask
+
+
+@registry.register_task("image_text_pretrain")
+class ImageTextPretrainTask(BaseTask):
+ def __init__(self):
+ super().__init__()
+
+ def evaluation(self, model, data_loader, cuda_enabled=True):
+ pass
diff --git a/prompts/alignment.txt b/prompts/alignment.txt
new file mode 100644
index 0000000000000000000000000000000000000000..38ae75a9cee293861f06544cbff6fdc4aa941d85
--- /dev/null
+++ b/prompts/alignment.txt
@@ -0,0 +1,4 @@
+
Describe this image in detail.
+
Take a look at this image and describe what you notice.
+
Please provide a detailed description of the picture.
+
Could you describe the contents of this image for me?
\ No newline at end of file
diff --git a/requirmentsss.txt b/requirmentsss.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fa238851425a7f8c33f11aeb153b41d905e8cba7
--- /dev/null
+++ b/requirmentsss.txt
@@ -0,0 +1,6 @@
+protobuf<4
+sentencepiece==0.1.98
+timm
+torch==2.0.1
+open-clip-torch==2.16.0
+open-flamingo==2.0.1
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dead8e599625cb6cb33c41beaa906f35ace8194
--- /dev/null
+++ b/train.py
@@ -0,0 +1,104 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import argparse
+import os
+import random
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import wandb
+
+import minigpt4.tasks as tasks
+from minigpt4.common.config import Config
+from minigpt4.common.dist_utils import get_rank, init_distributed_mode
+from minigpt4.common.logger import setup_logger
+from minigpt4.common.optims import (
+ LinearWarmupCosineLRScheduler,
+ LinearWarmupStepLRScheduler,
+)
+from minigpt4.common.registry import registry
+from minigpt4.common.utils import now
+
+# imports modules for registration
+from minigpt4.datasets.builders import *
+from minigpt4.models import *
+from minigpt4.processors import *
+from minigpt4.runners import *
+from minigpt4.tasks import *
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Training")
+
+ parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+ parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+ )
+ args = parser.parse_args()
+
+ return args
+
+
+def setup_seeds(config):
+ seed = config.run_cfg.seed + get_rank()
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ cudnn.benchmark = False
+ cudnn.deterministic = True
+
+
+def get_runner_class(cfg):
+ """
+ Get runner class from config. Default to epoch-based runner.
+ """
+ runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
+
+ return runner_cls
+
+
+def main():
+ # allow auto-dl completes on main process without timeout when using NCCL backend.
+ # os.environ["NCCL_BLOCKING_WAIT"] = "1"
+
+ # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
+ job_id = now()
+ args = parse_args()
+ cfg = Config(args)
+
+ init_distributed_mode(cfg.run_cfg)
+ setup_seeds(cfg)
+
+ # set after init_distributed_mode() to only log on master.
+ setup_logger()
+ cfg.pretty_print()
+
+ task = tasks.setup_task(cfg)
+ datasets = task.build_datasets(cfg)
+ model = task.build_model(cfg)
+
+ if cfg.run_cfg.wandb_log:
+ wandb.login()
+ wandb.init(project="minigptv", name=cfg.run_cfg.job_name)
+ wandb.watch(model)
+
+ runner = get_runner_class(cfg)(
+ cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
+ )
+ runner.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/train_configs/minigptv2_finetune.yaml b/train_configs/minigptv2_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..438cf0d43fee942cea4920f73cd116e23c0d13fe
--- /dev/null
+++ b/train_configs/minigptv2_finetune.yaml
@@ -0,0 +1,135 @@
+# torchrun --nproc-per-node 1 train.py --cfg-path train_configs/minigptv2_finetune.yaml
+model:
+ arch: minigpt_v2
+ model_type: pretrain
+ max_txt_len: 1024
+ image_size: 448
+ end_sym: ""
+ llama_model: "llama-2-7b-chat-hf"
+ ckpt: "/MiniGPT4-v2/checkpoints/checkpoint_stage3.pth"
+ use_grad_checkpoint: True
+ chat_template: True
+ lora_r: 64
+ lora_alpha: 16
+
+datasets:
+ grounding_SLAKE:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+ mimic_cxr:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+ radvqa:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+ # rsna:
+ # batch_size: 6
+ # vis_processor:
+ # train:
+ # name: "blip2_image_train"
+ # image_size: 448
+ # text_processor:
+ # train:
+ # name: "blip_caption"
+
+ # refer_rsna:
+ # batch_size: 6
+ # vis_processor:
+ # train:
+ # name: "blip2_image_train"
+ # image_size: 448
+ # text_processor:
+ # train:
+ # name: "blip_caption"
+
+ # identify_rsna:
+ # batch_size: 6
+ # vis_processor:
+ # train:
+ # name: "blip2_image_train"
+ # image_size: 448
+ # text_processor:
+ # train:
+ # name: "blip_caption"
+
+ nlst:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+ refer_nlst:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+ identify_nlst:
+ batch_size: 6
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 448
+ text_processor:
+ train:
+ name: "blip_caption"
+
+
+run:
+ task: image_text_pretrain
+ lr_sched: "linear_warmup_cosine_lr"
+ init_lr: 1e-5
+ min_lr: 8e-5
+ warmup_lr: 1e-6
+
+ weight_decay: 0.05
+ max_epoch: 100
+ num_workers: 6
+ warmup_steps: 1000
+ iters_per_epoch: 1000
+
+ seed: 42
+ output_dir: "expermints_folder"
+
+ amp: True
+ resume_ckpt_path: null
+
+ evaluate: False
+ train_splits: ["train"]
+
+ device: "cuda"
+ world_size: 1
+ dist_url: "env://"
+ distributed: True
+
+ wandb_log: True
+ job_name: minigptv2_finetune_final
\ No newline at end of file