Tyrannosaurus
commited on
Commit
•
38fd7ad
1
Parent(s):
d96ba24
Upload 10 files
Browse files- .gitattributes +2 -34
- ArtGPT-4.pdf +0 -0
- LICENSE +29 -0
- LICENSE.md +14 -0
- LICENSE_Lavis.md +14 -0
- PrepareVicuna.md +35 -0
- README.md +118 -3
- demo.py +157 -0
- environment.yml +63 -0
- train.py +104 -0
.gitattributes
CHANGED
@@ -1,34 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
# Auto detect text files and perform LF normalization
|
2 |
+
* text=auto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ArtGPT-4.pdf
ADDED
Binary file (678 kB). View file
|
|
LICENSE
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2023, DLYuanGod
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without
|
7 |
+
modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
10 |
+
list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
this list of conditions and the following disclaimer in the documentation
|
14 |
+
and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
3. Neither the name of the copyright holder nor the names of its
|
17 |
+
contributors may be used to endorse or promote products derived from
|
18 |
+
this software without specific prior written permission.
|
19 |
+
|
20 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
LICENSE.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright 2023 Deyao Zhu
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
11 |
+
|
12 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
LICENSE_Lavis.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Salesforce, Inc.
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
11 |
+
|
12 |
+
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
PrepareVicuna.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## How to Prepare Vicuna Weight
|
2 |
+
Vicuna is an open-source LLAMA-based LLM that has a performance close to ChatGPT.
|
3 |
+
We currently use the v0 version of Vicuna-13B.
|
4 |
+
|
5 |
+
To prepare Vicuna’s weight, first download Vicuna’s **delta** weight from [https://huggingface.co/lmsys/vicuna-13b-delta-v0](https://huggingface.co/lmsys/vicuna-13b-delta-v0).
|
6 |
+
In case you have git-lfs installed (https://git-lfs.com), this can be done by
|
7 |
+
|
8 |
+
```
|
9 |
+
git lfs install
|
10 |
+
git clone https://huggingface.co/lmsys/vicuna-13b-delta-v0 # more powerful, need at least 24G gpu memory
|
11 |
+
# or
|
12 |
+
git clone https://huggingface.co/lmsys/vicuna-7b-delta-v0 # smaller, need 12G gpu memory
|
13 |
+
```
|
14 |
+
|
15 |
+
Note that this is not directly the working weight, but the difference between the working weight and the original weight of LLAMA-13B. (Due to LLAMA’s rules, we cannot distribute the weight of LLAMA.)
|
16 |
+
|
17 |
+
Then, you need to obtain the original LLAMA-7B or LLAMA-13B weights in the HuggingFace format
|
18 |
+
either following the instruction provided by HuggingFace
|
19 |
+
[here](https://huggingface.co/docs/transformers/main/model_doc/llama) or from the Internet.
|
20 |
+
|
21 |
+
When these two weights are ready, we can use tools from Vicuna’s team to create the real working weight.
|
22 |
+
First, Install their library that is compatible with v0 Vicuna by
|
23 |
+
|
24 |
+
```
|
25 |
+
pip install git+https://github.com/lm-sys/[email protected]
|
26 |
+
```
|
27 |
+
|
28 |
+
Then, run the following command to create the final working weight
|
29 |
+
|
30 |
+
```
|
31 |
+
python -m fastchat.model.apply_delta --base /path/to/llama-13bOR7b-hf/ --target /path/to/save/working/vicuna/weight/ --delta /path/to/vicuna-13bOR7b-delta-v0/
|
32 |
+
```
|
33 |
+
|
34 |
+
Now you are good to go!
|
35 |
+
|
README.md
CHANGED
@@ -1,3 +1,118 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ArtGPT-4: Artistic Vision-Language Understanding with Adapter-enhanced MiniGPT-4
|
2 |
+
[Zhengqing Yuan](https://orcid.org/0000-0002-4870-8492)*, [Huiwen Xue]()*, [Xinyi Wang]()*, [Yongming Liu](https://www.semanticscholar.org/author/Yongming-Liu/2130184867)*, [Zhuanzhe Zhao](https://www.semanticscholar.org/author/Zhuanzhe-Zhao/2727550)*, and [Kun Wang](https://www.ahpu.edu.cn/jsjyxxgc/2023/0220/c5472a187109/page.htm)*. *Equal Contribution
|
3 |
+
|
4 |
+
**Anhui Polytechnic University, Soochow University**
|
5 |
+
|
6 |
+
<a href='https://artgpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='ArtGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
|
7 |
+
<!-- <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) -->
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
## Online Demo
|
12 |
+
|
13 |
+
<!-- Click the image to chat with MiniGPT-4 around your images
|
14 |
+
[![demo](figs/online_demo.png)](https://artgpt-4.github.io) -->
|
15 |
+
Waiting for updates...
|
16 |
+
|
17 |
+
|
18 |
+
## Examples
|
19 |
+
| | |
|
20 |
+
:-------------------------:|:-------------------------:
|
21 |
+
![Dec Image](figs/examples/Art1.png)
|
22 |
+
![Art](figs/examples/Art2.png)
|
23 |
+
![Web](figs/examples/Art3-G.png)
|
24 |
+
|
25 |
+
More examples can be found in the [project page](https://artgpt-4.github.io).
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
## Introduction
|
30 |
+
- ArtGPT-4 is a novel model that builds upon the architecture of MiniGPT-4 by incorporating tailored linear layers and activation functions into Vicuna, specifically designed to optimize the model's performance in vision-language tasks.
|
31 |
+
- The modifications made to Vicuna in ArtGPT-4 enable the model to better capture intricate details and understand the meaning of artistic images, resulting in improved image understanding compared to the original MiniGPT-4 model.
|
32 |
+
- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset.
|
33 |
+
- ArtGPT-4 was trained using about 200 GB of image-text pairs on a Tesla A100 device in just 2 hours, demonstrating impressive efficiency and effectiveness in training.
|
34 |
+
- In addition to improved image understanding, ArtGPT-4 is capable of generating visual code, including aesthetically pleasing HTML/CSS web pages, with a more artistic flair.
|
35 |
+
|
36 |
+
|
37 |
+
![overview](figs/examples/TBLOC.png)
|
38 |
+
|
39 |
+
|
40 |
+
## Getting Started
|
41 |
+
### Installation
|
42 |
+
|
43 |
+
**1. Prepare the code and the environment**
|
44 |
+
|
45 |
+
Git clone our repository, creating a python environment and ativate it via the following command
|
46 |
+
|
47 |
+
```bash
|
48 |
+
git clone https://github.com/DLYuanGod/ArtGPT-4.git
|
49 |
+
cd ArtGPT-4
|
50 |
+
conda env create -f environment.yml
|
51 |
+
conda activate artgpt4
|
52 |
+
```
|
53 |
+
|
54 |
+
|
55 |
+
**2. Prepare the pretrained Vicuna weights**
|
56 |
+
|
57 |
+
The current version of MiniGPT-4 is built on the v0 versoin of Vicuna-13B.
|
58 |
+
Please refer to our instruction [here](PrepareVicuna.md)
|
59 |
+
to prepare the Vicuna weights.
|
60 |
+
The final weights would be in a single folder in a structure similar to the following:
|
61 |
+
|
62 |
+
```
|
63 |
+
vicuna_weights
|
64 |
+
├── config.json
|
65 |
+
├── generation_config.json
|
66 |
+
├── pytorch_model.bin.index.json
|
67 |
+
├── pytorch_model-00001-of-00003.bin
|
68 |
+
...
|
69 |
+
```
|
70 |
+
|
71 |
+
Then, set the path to the vicuna weight in the model config file
|
72 |
+
[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
|
73 |
+
|
74 |
+
**3. Prepare the pretrained MiniGPT-4 checkpoint**
|
75 |
+
[Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link)
|
76 |
+
|
77 |
+
|
78 |
+
Then, set the path to the pretrained checkpoint in the evaluation config file
|
79 |
+
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
### Launching Demo Locally
|
84 |
+
|
85 |
+
Try out our demo [demo.py](demo.py) on your local machine by running
|
86 |
+
|
87 |
+
```
|
88 |
+
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
|
89 |
+
```
|
90 |
+
|
91 |
+
|
92 |
+
### Training
|
93 |
+
The training of ArtGPT-4 contains two alignment stages. The training process for the step is consistent with that of [MiniGPT-4](https://minigpt-4.github.io/).
|
94 |
+
|
95 |
+
**Datasets**
|
96 |
+
We use [Laion-aesthetic](https://github.com/LAION-AI/laion-datasets/blob/main/laion-aesthetic.md) from the LAION-5B dataset, which amounts to approximately 200GB for the first 302 tar files.
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
## Acknowledgement
|
101 |
+
|
102 |
+
+ [MiniGPT-4](https://minigpt-4.github.io/) Our work is based on improvements to the model.
|
103 |
+
|
104 |
+
|
105 |
+
If you're using ArtGPT-4 in your research or applications, please cite using this BibTeX:
|
106 |
+
```bibtex
|
107 |
+
@article{yuan2023artgpt4,
|
108 |
+
title={ArtGPT-4: Artistic Vision-Language Understanding with Adapter-enhanced MiniGPT-4},
|
109 |
+
author={Yuan, Zhengqng and Xue, Huiwen and Wang, Xinyi and Liu, Yongming and Zhao, zhuanzhe and Wang, Kun},
|
110 |
+
year={2023}
|
111 |
+
}
|
112 |
+
```
|
113 |
+
|
114 |
+
|
115 |
+
## License
|
116 |
+
This repository is under [BSD 3-Clause License](LICENSE.md).
|
117 |
+
Many codes are based on [Lavis](https://github.com/salesforce/LAVIS) with
|
118 |
+
BSD 3-Clause License [here](LICENSE_Lavis.md).
|
demo.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.backends.cudnn as cudnn
|
8 |
+
import gradio as gr
|
9 |
+
|
10 |
+
from minigpt4.common.config import Config
|
11 |
+
from minigpt4.common.dist_utils import get_rank
|
12 |
+
from minigpt4.common.registry import registry
|
13 |
+
from minigpt4.conversation.conversation import Chat, CONV_VISION
|
14 |
+
|
15 |
+
# imports modules for registration
|
16 |
+
from minigpt4.datasets.builders import *
|
17 |
+
from minigpt4.models import *
|
18 |
+
from minigpt4.processors import *
|
19 |
+
from minigpt4.runners import *
|
20 |
+
from minigpt4.tasks import *
|
21 |
+
|
22 |
+
|
23 |
+
def parse_args():
|
24 |
+
parser = argparse.ArgumentParser(description="Demo")
|
25 |
+
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
|
26 |
+
parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
|
27 |
+
parser.add_argument(
|
28 |
+
"--options",
|
29 |
+
nargs="+",
|
30 |
+
help="override some settings in the used config, the key-value pair "
|
31 |
+
"in xxx=yyy format will be merged into config file (deprecate), "
|
32 |
+
"change to --cfg-options instead.",
|
33 |
+
)
|
34 |
+
args = parser.parse_args()
|
35 |
+
return args
|
36 |
+
|
37 |
+
|
38 |
+
def setup_seeds(config):
|
39 |
+
seed = config.run_cfg.seed + get_rank()
|
40 |
+
|
41 |
+
random.seed(seed)
|
42 |
+
np.random.seed(seed)
|
43 |
+
torch.manual_seed(seed)
|
44 |
+
|
45 |
+
cudnn.benchmark = False
|
46 |
+
cudnn.deterministic = True
|
47 |
+
|
48 |
+
|
49 |
+
# ========================================
|
50 |
+
# Model Initialization
|
51 |
+
# ========================================
|
52 |
+
|
53 |
+
print('Initializing Chat')
|
54 |
+
args = parse_args()
|
55 |
+
cfg = Config(args)
|
56 |
+
|
57 |
+
model_config = cfg.model_cfg
|
58 |
+
model_config.device_8bit = args.gpu_id
|
59 |
+
model_cls = registry.get_model_class(model_config.arch)
|
60 |
+
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
|
61 |
+
|
62 |
+
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
|
63 |
+
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
64 |
+
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
|
65 |
+
print('Initialization Finished')
|
66 |
+
|
67 |
+
# ========================================
|
68 |
+
# Gradio Setting
|
69 |
+
# ========================================
|
70 |
+
|
71 |
+
def gradio_reset(chat_state, img_list):
|
72 |
+
if chat_state is not None:
|
73 |
+
chat_state.messages = []
|
74 |
+
if img_list is not None:
|
75 |
+
img_list = []
|
76 |
+
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
|
77 |
+
|
78 |
+
def upload_img(gr_img, text_input, chat_state):
|
79 |
+
if gr_img is None:
|
80 |
+
return None, None, gr.update(interactive=True), chat_state, None
|
81 |
+
chat_state = CONV_VISION.copy()
|
82 |
+
img_list = []
|
83 |
+
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
84 |
+
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
|
85 |
+
|
86 |
+
def gradio_ask(user_message, chatbot, chat_state):
|
87 |
+
if len(user_message) == 0:
|
88 |
+
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
|
89 |
+
chat.ask(user_message, chat_state)
|
90 |
+
chatbot = chatbot + [[user_message, None]]
|
91 |
+
return '', chatbot, chat_state
|
92 |
+
|
93 |
+
|
94 |
+
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
|
95 |
+
llm_message = chat.answer(conv=chat_state,
|
96 |
+
img_list=img_list,
|
97 |
+
num_beams=num_beams,
|
98 |
+
temperature=temperature,
|
99 |
+
max_new_tokens=300,
|
100 |
+
max_length=2000)[0]
|
101 |
+
chatbot[-1][1] = llm_message
|
102 |
+
return chatbot, chat_state, img_list
|
103 |
+
|
104 |
+
title = """<h1 align="center">ArtGPT-4:</h1>"""
|
105 |
+
description = """<h2 align="center"><font color="skyblue">Artistic Vision-Language Understanding with Adapter-enhanced MiniGPT-4</h2>"""
|
106 |
+
Authors = """<h3 align="center"><font color="skyblue">Zhengqing Yuan, Yongming Liu, Xinyi Wang, Zhuanzhe Zhao</h3>"""
|
107 |
+
Address = """<h3 align="center">School of Artificial Intelligence, Anhui Polytechnic University</h3>"""
|
108 |
+
#article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>"
|
109 |
+
""
|
110 |
+
|
111 |
+
#TODO show examples below
|
112 |
+
|
113 |
+
with gr.Blocks() as demo:
|
114 |
+
gr.Markdown(title)
|
115 |
+
gr.Markdown(description)
|
116 |
+
gr.Markdown(Authors)
|
117 |
+
gr.Markdown(Address)
|
118 |
+
#gr.Markdown(article)
|
119 |
+
|
120 |
+
with gr.Row():
|
121 |
+
with gr.Column(scale=0.5):
|
122 |
+
image = gr.Image(type="pil")
|
123 |
+
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
|
124 |
+
clear = gr.Button("Restart")
|
125 |
+
|
126 |
+
num_beams = gr.Slider(
|
127 |
+
minimum=1,
|
128 |
+
maximum=10,
|
129 |
+
value=1,
|
130 |
+
step=1,
|
131 |
+
interactive=True,
|
132 |
+
label="beam search numbers)",
|
133 |
+
)
|
134 |
+
|
135 |
+
temperature = gr.Slider(
|
136 |
+
minimum=0.1,
|
137 |
+
maximum=2.0,
|
138 |
+
value=1.0,
|
139 |
+
step=0.1,
|
140 |
+
interactive=True,
|
141 |
+
label="Temperature",
|
142 |
+
)
|
143 |
+
|
144 |
+
with gr.Column():
|
145 |
+
chat_state = gr.State()
|
146 |
+
img_list = gr.State()
|
147 |
+
chatbot = gr.Chatbot(label='ArtGPT-4')
|
148 |
+
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
|
149 |
+
|
150 |
+
upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
|
151 |
+
|
152 |
+
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
|
153 |
+
gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
|
154 |
+
)
|
155 |
+
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
|
156 |
+
|
157 |
+
demo.launch(share=True, enable_queue=True)
|
environment.yml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: artgpt4
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
- anaconda
|
6 |
+
dependencies:
|
7 |
+
- python=3.9
|
8 |
+
- cudatoolkit
|
9 |
+
- pip
|
10 |
+
- pytorch=1.12.1
|
11 |
+
- pytorch-mutex=1.0=cuda
|
12 |
+
- torchaudio=0.12.1
|
13 |
+
- torchvision=0.13.1
|
14 |
+
- pip:
|
15 |
+
- accelerate==0.16.0
|
16 |
+
- aiohttp==3.8.4
|
17 |
+
- aiosignal==1.3.1
|
18 |
+
- async-timeout==4.0.2
|
19 |
+
- attrs==22.2.0
|
20 |
+
- bitsandbytes==0.37.0
|
21 |
+
- cchardet==2.1.7
|
22 |
+
- chardet==5.1.0
|
23 |
+
- contourpy==1.0.7
|
24 |
+
- cycler==0.11.0
|
25 |
+
- filelock==3.9.0
|
26 |
+
- fonttools==4.38.0
|
27 |
+
- frozenlist==1.3.3
|
28 |
+
- huggingface-hub==0.13.4
|
29 |
+
- importlib-resources==5.12.0
|
30 |
+
- kiwisolver==1.4.4
|
31 |
+
- matplotlib==3.7.0
|
32 |
+
- multidict==6.0.4
|
33 |
+
- openai==0.27.0
|
34 |
+
- packaging==23.0
|
35 |
+
- psutil==5.9.4
|
36 |
+
- pycocotools==2.0.6
|
37 |
+
- pyparsing==3.0.9
|
38 |
+
- python-dateutil==2.8.2
|
39 |
+
- pyyaml==6.0
|
40 |
+
- regex==2022.10.31
|
41 |
+
- tokenizers==0.13.2
|
42 |
+
- tqdm==4.64.1
|
43 |
+
- transformers==4.28.0
|
44 |
+
- timm==0.6.13
|
45 |
+
- spacy==3.5.1
|
46 |
+
- webdataset==0.2.48
|
47 |
+
- scikit-learn==1.2.2
|
48 |
+
- scipy==1.10.1
|
49 |
+
- yarl==1.8.2
|
50 |
+
- zipp==3.14.0
|
51 |
+
- omegaconf==2.3.0
|
52 |
+
- opencv-python==4.7.0.72
|
53 |
+
- iopath==0.1.10
|
54 |
+
- decord==0.6.0
|
55 |
+
- tenacity==8.2.2
|
56 |
+
- peft
|
57 |
+
- pycocoevalcap
|
58 |
+
- sentence-transformers
|
59 |
+
- umap-learn
|
60 |
+
- notebook
|
61 |
+
- gradio==3.24.1
|
62 |
+
- gradio-client==0.0.8
|
63 |
+
- wandb
|
train.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.backends.cudnn as cudnn
|
15 |
+
|
16 |
+
import minigpt4.tasks as tasks
|
17 |
+
from minigpt4.common.config import Config
|
18 |
+
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
|
19 |
+
from minigpt4.common.logger import setup_logger
|
20 |
+
from minigpt4.common.optims import (
|
21 |
+
LinearWarmupCosineLRScheduler,
|
22 |
+
LinearWarmupStepLRScheduler,
|
23 |
+
)
|
24 |
+
from minigpt4.common.registry import registry
|
25 |
+
from minigpt4.common.utils import now
|
26 |
+
|
27 |
+
# imports modules for registration
|
28 |
+
from minigpt4.datasets.builders import *
|
29 |
+
from minigpt4.models import *
|
30 |
+
from minigpt4.processors import *
|
31 |
+
from minigpt4.runners import *
|
32 |
+
from minigpt4.tasks import *
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
def parse_args():
|
37 |
+
parser = argparse.ArgumentParser(description="Training")
|
38 |
+
|
39 |
+
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
|
40 |
+
parser.add_argument(
|
41 |
+
"--options",
|
42 |
+
nargs="+",
|
43 |
+
help="override some settings in the used config, the key-value pair "
|
44 |
+
"in xxx=yyy format will be merged into config file (deprecate), "
|
45 |
+
"change to --cfg-options instead.",
|
46 |
+
)
|
47 |
+
|
48 |
+
args = parser.parse_args()
|
49 |
+
# if 'LOCAL_RANK' not in os.environ:
|
50 |
+
# os.environ['LOCAL_RANK'] = str(args.local_rank)
|
51 |
+
|
52 |
+
return args
|
53 |
+
|
54 |
+
|
55 |
+
def setup_seeds(config):
|
56 |
+
seed = config.run_cfg.seed + get_rank()
|
57 |
+
|
58 |
+
random.seed(seed)
|
59 |
+
np.random.seed(seed)
|
60 |
+
torch.manual_seed(seed)
|
61 |
+
|
62 |
+
cudnn.benchmark = False
|
63 |
+
cudnn.deterministic = True
|
64 |
+
|
65 |
+
|
66 |
+
def get_runner_class(cfg):
|
67 |
+
"""
|
68 |
+
Get runner class from config. Default to epoch-based runner.
|
69 |
+
"""
|
70 |
+
runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
|
71 |
+
|
72 |
+
return runner_cls
|
73 |
+
|
74 |
+
|
75 |
+
def main():
|
76 |
+
# allow auto-dl completes on main process without timeout when using NCCL backend.
|
77 |
+
# os.environ["NCCL_BLOCKING_WAIT"] = "1"
|
78 |
+
|
79 |
+
# set before init_distributed_mode() to ensure the same job_id shared across all ranks.
|
80 |
+
job_id = now()
|
81 |
+
|
82 |
+
cfg = Config(parse_args())
|
83 |
+
|
84 |
+
init_distributed_mode(cfg.run_cfg)
|
85 |
+
|
86 |
+
setup_seeds(cfg)
|
87 |
+
|
88 |
+
# set after init_distributed_mode() to only log on master.
|
89 |
+
setup_logger()
|
90 |
+
|
91 |
+
cfg.pretty_print()
|
92 |
+
|
93 |
+
task = tasks.setup_task(cfg)
|
94 |
+
datasets = task.build_datasets(cfg)
|
95 |
+
model = task.build_model(cfg)
|
96 |
+
|
97 |
+
runner = get_runner_class(cfg)(
|
98 |
+
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
|
99 |
+
)
|
100 |
+
runner.train()
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
main()
|