---
license: mit
---
## RS-LLaVA: Large Vision Language Model for Joint Captioning and Question Answering in Remote Sensing Imagery
- **Repository:** https://github.com/BigData-KSU/RS-LLaVA
- **Paper:** https://www.mdpi.com/2072-4292/16/9/1477
- **Demo:** Soon.
## How to Get Started with the Model
### Install
1. Clone this repository and navigate to the RS-LLaVA folder
```
git clone https://github.com/BigData-KSU/RS-LLaVA.git
cd RS-LLaVA
```
2. Create and activate the conda environment
```
conda create -n rs-llava python=3.10 -y
conda activate rs-llava
pip install --upgrade pip # enable PEP 660 support
```
3. Install additional packages
```
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install transformers==4.35
pip install einops
pip install SentencePiece
pip install accelerate
pip install peft
```
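As an optional sanity check (not part of the original instructions), the following sketch verifies that the pinned packages and a CUDA device are visible before moving on to inference:
```python
import torch
import transformers

# Illustrative environment check: the inference example below assumes a CUDA GPU
# and the transformers version pinned above (4.35).
print('transformers:', transformers.__version__)
print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
```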
---
### Inference
Use the code below to get started with the model.
```python
import torch
import os
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import math
# Model checkpoints: RS-LLaVA LoRA weights and their base LLM.
model_path = 'BigData-KSU/RS-llava-v1.5-7b-LoRA'
model_base = 'Intel/neural-chat-7b-v3-3'

# Conversation template used for prompting.
conv_mode = 'llava_v1'

disable_torch_init()

model_name = get_model_name_from_path(model_path)
print('model name', model_name)
print('model base', model_base)

tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)
def chat_with_RS_LLaVA(cur_prompt, image_name):
    # Load the image and extract its pixel values with the image processor.
    image_mem = Image.open(image_name)
    image_tensor = image_processor.preprocess(image_mem, return_tensors='pt')['pixel_values'][0]

    # Prepend the image token(s) to the text prompt.
    if model.config.mm_use_im_start_end:
        cur_prompt = f"{DEFAULT_IM_START_TOKEN} {DEFAULT_IMAGE_TOKEN} {DEFAULT_IM_END_TOKEN}\n{cur_prompt}"
    else:
        cur_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{cur_prompt}"

    # Build the prompt from a fresh copy of the conversation template.
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], cur_prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize the prompt, mapping the image placeholder to IMAGE_TOKEN_INDEX.
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            do_sample=True,
            temperature=0.2,
            top_p=None,
            num_beams=1,
            no_repeat_ngram_size=3,
            max_new_tokens=2048,
            use_cache=True)

    # Decode only the newly generated tokens.
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    return outputs
if __name__ == "__main__":
    print('Model input...............')
    cur_prompt = 'Generate three questions and answers about the content of this image. Then, compile a summary.'
    image_name = 'assets/example_images/parking_lot_010.jpg'

    outputs = chat_with_RS_LLaVA(cur_prompt, image_name)
    print('Model Response.....')
    print(outputs)
```
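The same helper can be reused for single-question VQA. The prompt below is only an illustration and reuses the example image referenced above:
```python
# Illustrative follow-up query (assumes the model and helper defined above are loaded).
question = 'How many cars are visible in the parking lot?'
print(chat_with_RS_LLaVA(question, 'assets/example_images/parking_lot_010.jpg'))
```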
## Training Details
Training RS-LLaVA is carried out in three stages:
#### Stage 1: Pretraining (Feature Alignment):
This stage uses the LAION/CC/SBU BLIP-Caption Concept-balanced 558K dataset together with two remote-sensing captioning datasets, [NWPU](https://github.com/HaiyanHuang98/NWPU-Captions) and [RSICD](https://huggingface.co/datasets/arampacha/rsicd).
| Dataset | Size | Link |
| --- | --- |--- |
|CC-3M Concept-balanced 595K|211 MB|[Link](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)|
|NWPU-RSICD-Pretrain|16.6 MB|[Link](https://huggingface.co/datasets/BigData-KSU/RS-instructions-dataset/blob/main/NWPU-RSICD-pretrain.json)|
#### Stage 2: Visual Instruction Tuning:
To teach the model to follow instructions, we used the proposed RS-Instructions Dataset together with the LLaVA-Instruct-150K dataset; an illustrative record layout is sketched after the table.
| Dataset | Size | Link |
| --- | --- |--- |
|RS-Instructions|91.3 MB|[Link](https://huggingface.co/datasets/BigData-KSU/RS-instructions-dataset/blob/main/NWPU-RSICD-UAV-UCM-LR-DOTA-intrcutions.json)|
|llava_v1_5_mix665k|1.03 GB|[Link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json)|
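Both instruction files are assumed to follow the LLaVA-style conversation JSON layout; the record below is a hypothetical illustration of that format, not an entry copied from either dataset:
```python
import json

# Hypothetical LLaVA-style instruction record (field names follow the upstream
# LLaVA data format; the id, image path, and text are invented for illustration).
record = {
    "id": "example_0001",
    "image": "RSICD/airport_1.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\nDescribe the content of this image."},
        {"from": "gpt", "value": "An airport scene with several parked airplanes near the terminal."},
    ],
}
print(json.dumps(record, indent=2))
```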
#### Stage 3: Downstream Task Tuning:
In this stage, the model is fine-tuned on one of the downstream tasks (e.g., RS image captioning or VQA).
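As a rough sketch only (the repository's training scripts define the actual procedure and hyperparameters), downstream tuning of the language model with LoRA adapters via `peft` could look like the following; the rank, alpha, and target modules are assumptions, not the values used for RS-LLaVA:
```python
from peft import LoraConfig, get_peft_model

# Illustrative LoRA configuration; all hyperparameters here are assumptions.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# `model` is the Stage 2 checkpoint loaded as in the inference example above.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# ...then fine-tune on the captioning or VQA data with a standard training loop.
```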
## Citation
**BibTeX:**
```bibtex
@Article{rs16091477,
AUTHOR = {Bazi, Yakoub and Bashmal, Laila and Al Rahhal, Mohamad Mahmoud and Ricci, Riccardo and Melgani, Farid},
TITLE = {RS-LLaVA: A Large Vision-Language Model for Joint Captioning and Question Answering in Remote Sensing Imagery},
JOURNAL = {Remote Sensing},
VOLUME = {16},
YEAR = {2024},
NUMBER = {9},
ARTICLE-NUMBER = {1477},
URL = {https://www.mdpi.com/2072-4292/16/9/1477},
ISSN = {2072-4292},
DOI = {10.3390/rs16091477}
}
```