---
license: mit
---
## RS-LLaVA: Large Vision Language Model for Joint Captioning and Question Answering in Remote Sensing Imagery
- **Repository:** https://github.com/BigData-KSU/RS-LLaVA
- **Paper:** https://www.mdpi.com/2072-4292/16/9/1477
- **Demo:** Soon.
## How to Get Started with the Model
### Install
1. Clone this repository and navigate to the RS-LLaVA folder
```
git clone https://github.com/BigData-KSU/RS-LLaVA.git
cd RS-LLaVA
```
2. Set up the conda environment
```
conda create -n rs-llava python=3.10 -y
conda activate rs-llava
pip install --upgrade pip # enable PEP 660 support
```
3. Install additional packages
```
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install transformers==4.35
pip install einops
pip install SentencePiece
pip install accelerate
pip install peft
```
---
### Inference
Use the code below to get started with the model.
```python
import torch
import os
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import math
# Model checkpoints: LoRA weights applied on top of the Intel/neural-chat-7b-v3-3 base
model_path = 'BigData-KSU/RS-llava-v1.5-7b-LoRA'
model_base = 'Intel/neural-chat-7b-v3-3'

# Conversation template used to format prompts
conv_mode = 'llava_v1'

disable_torch_init()
model_name = get_model_name_from_path(model_path)
print('model name', model_name)
print('model base', model_base)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)


def chat_with_RS_LLaVA(cur_prompt, image_name):
    # Load the image and preprocess it into a pixel tensor
    image_mem = Image.open(image_name)
    image_tensor = image_processor.preprocess(image_mem, return_tensors='pt')['pixel_values'][0]

    # Prepend the image tokens expected by the model
    if model.config.mm_use_im_start_end:
        cur_prompt = f"{DEFAULT_IM_START_TOKEN} {DEFAULT_IMAGE_TOKEN} {DEFAULT_IM_END_TOKEN}\n{cur_prompt}"
    else:
        cur_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{cur_prompt}"

    # Build the conversation prompt from the template
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], cur_prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize the prompt, replacing the image placeholder with IMAGE_TOKEN_INDEX
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

    # Stop generation at the conversation separator
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            do_sample=True,
            temperature=0.2,
            top_p=None,
            num_beams=1,
            no_repeat_ngram_size=3,
            max_new_tokens=2048,
            use_cache=True)

    # Decode only the newly generated tokens
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    return outputs


if __name__ == "__main__":
    print('Model input...............')
    cur_prompt = 'Generate three questions and answers about the content of this image. Then, compile a summary.'
    image_name = 'assets/example_images/parking_lot_010.jpg'

    outputs = chat_with_RS_LLaVA(cur_prompt, image_name)
    print('Model Response.....')
    print(outputs)
```
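The same helper can be reused for captioning or visual question answering by changing the prompt; for example, with the same sample image:
```python
# Captioning prompt
caption = chat_with_RS_LLaVA('Describe this image in detail.',
                             'assets/example_images/parking_lot_010.jpg')
print(caption)

# VQA-style prompt on the same image
answer = chat_with_RS_LLaVA('How many vehicles are visible in this image?',
                            'assets/example_images/parking_lot_010.jpg')
print(answer)
```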
## Training Details
Training RS-LLaVA is carried out in three stages:
#### Stage 1: Pretraining (Feature Alignment)
This stage uses the LAION/CC/SBU BLIP-Caption Concept-balanced 558K dataset together with two RS captioning datasets, [NWPU](https://github.com/HaiyanHuang98/NWPU-Captions) and [RSICD](https://huggingface.co/datasets/arampacha/rsicd).
| Dataset | Size | Link |
| --- | --- |--- |
|CC-3M Concept-balanced 595K|211 MB|[Link](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)|
|NWPU-RSICD-Pretrain|16.6 MB|[Link](https://huggingface.co/datasets/BigData-KSU/RS-instructions-dataset/blob/main/NWPU-RSICD-pretrain.json)|
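A minimal loading sketch for the RS pretraining file, assuming it is downloaded next to the script and that its fields (`image`, `conversations`) follow the LLaVA conversation format:
```python
import json

# Inspect the RS pretraining data (LLaVA-style conversation records)
with open('NWPU-RSICD-pretrain.json', 'r') as f:
    records = json.load(f)

print(len(records), 'records')
sample = records[0]
print(sample['image'])                   # image file name
for turn in sample['conversations']:     # alternating human/gpt turns
    print(turn['from'], ':', turn['value'])
```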
#### Stage 2: Visual Instruction Tuning
To teach the model to follow instructions, we used the proposed RS-Instructions Dataset together with the LLaVA-Instruct-150K dataset.
| Dataset | Size | Link |
| --- | --- |--- |
|RS-Instructions|91.3 MB|[Link](https://huggingface.co/datasets/BigData-KSU/RS-instructions-dataset/blob/main/NWPU-RSICD-UAV-UCM-LR-DOTA-intrcutions.json)|
|llava_v1_5_mix665k|1.03 GB|[Link](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json)|
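For instruction tuning, the two files are used together. A minimal sketch of building a combined training mixture, assuming both share the same LLaVA conversation schema (the output file name is illustrative):
```python
import json

# Merge the RS-Instructions data with the LLaVA mix into one training file
files = ['NWPU-RSICD-UAV-UCM-LR-DOTA-intrcutions.json', 'llava_v1_5_mix665k.json']
combined = []
for path in files:
    with open(path, 'r') as f:
        combined.extend(json.load(f))

with open('stage2_instruction_mix.json', 'w') as f:   # illustrative output name
    json.dump(combined, f)

print(len(combined), 'instruction samples')
```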
#### Stage 3: Downstream Task Tuning
In this stage, the model is fine-tuned on one of the downstream tasks (e.g., RS image captioning or VQA).
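The repository's training scripts handle this stage; purely as an illustration of a LoRA setup for downstream tuning with PEFT, a configuration along these lines could be used (the hyperparameters and target modules below are placeholders, not the values used in the paper):
```python
from peft import LoraConfig, get_peft_model

# Illustrative LoRA configuration for downstream tuning.
# Hyperparameters and target modules are placeholders, not the paper's settings.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    task_type='CAUSAL_LM',
)

# Wrap the language model before fine-tuning on the captioning or VQA data:
# model = get_peft_model(base_model, lora_config)
```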
## Citation
**BibTeX:**
```bibtex
@Article{rs16091477,
AUTHOR = {Bazi, Yakoub and Bashmal, Laila and Al Rahhal, Mohamad Mahmoud and Ricci, Riccardo and Melgani, Farid},
TITLE = {RS-LLaVA: A Large Vision-Language Model for Joint Captioning and Question Answering in Remote Sensing Imagery},
JOURNAL = {Remote Sensing},
VOLUME = {16},
YEAR = {2024},
NUMBER = {9},
ARTICLE-NUMBER = {1477},
URL = {https://www.mdpi.com/2072-4292/16/9/1477},
ISSN = {2072-4292},
DOI = {10.3390/rs16091477}
}
```