---
license: apache-2.0
datasets:
- liuhaotian/LLaVA-Pretrain
- liuhaotian/LLaVA-Instruct-150K
language:
- en
- zh
library_name: transformers
---
# WORK IN PROGRESS
We present TinyLLaVA, a small (1.4B-parameter) vision-language chatbot that achieves performance comparable to contemporary vision-language models on common benchmarks while using fewer parameters.
TinyLLaVA was trained by fine-tuning [TinyLlama](https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3) on the [LLaVA-1.5](https://github.com/haotian-liu/LLaVA) dataset, following the LLaVA-1.5 training recipe. For more details, please refer to the [LLaVA-1.5 paper](https://arxiv.org/abs/2310.03744).
## Model Performance
We have evaluated TinyLLaVA on [GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html), [VizWiz](https://www.vizwiz.com/), [VQAv2](https://visualqa.org/), [TextVQA](https://textvqa.org/) and [SQA](https://github.com/lupantech/ScienceQA).
| Model | VQAv2 | GQA | SQA | TextVQA | VizWiz |
| -------------------- | :------------: | :------------: | :------------: | :------------: | :------------: |
| TinyLLaVA-v1-tinyllama | 73.41 | 57.54 | 59.40 | 46.37 | |
| TinyLLaVA-v1-stablelm | 74.9 | 58.86 | 62.82 | 49.52 | 35.6 |
| TinyLLaVA-v1.1-tinyllama| 75.24 | 59.43 | 58.80 | 48.05 | 34.74 |
| TinyLLaVA-v1.1-stablelm| 76.34 | 60.26 | 63.06 | 51.6 | 36.34 |
| BLIP-2 | 41.00 | 41.00 | 61.00 | 42.50 | 19.60 |
| LLaVA-v1.5-7B | 78.50 | 62.00 | 66.80 | 61.3 | 50 |
| LLaVA-v1.5-13B | 80.00 | 63.30 | 71.60 | 61.3 | 53.6 |
| Qwen-VL-7B | 78.80 | 59.30 | 67.10 | 63.8 | 35.2 |
| Qwen-VL-13B | 78.20 | 57.50 | 68.20 | 61.5 | 38.9 |
More evaluations are ongoing.
## Model Preparations
#### - Transformers Version
Make sure to have `transformers >= 4.35.3`.
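If you want to verify this programmatically, a minimal sketch is shown below (it only relies on `packaging`, which ships as a `transformers` dependency):
```python
import transformers
from packaging import version

# Fail early if the installed transformers release is too old for the classes used below.
assert version.parse(transformers.__version__) >= version.parse("4.35.3"), (
    f"Found transformers {transformers.__version__}; please upgrade to >= 4.35.3."
)
```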
#### - Prompt Template
The model supports multi-image and multi-prompt generation. When using the model, make sure to follow the correct prompt template (`USER: <image>xxx\nASSISTANT:`), where the `<image>` token is a special placeholder marking where the image embeddings are inserted, as illustrated in the sketch below.
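For instance, a single-image question is slotted into the template like this (the question text is purely illustrative):
```python
# Fill the prompt template with a question; the <image> token marks where the
# image embeddings will be inserted by the processor.
question = "What is shown in this image?"  # illustrative question
prompt = f"USER: <image>\n{question}\nASSISTANT:"
```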
## Model Inference with `pipeline` and `transformers`
#### - Using `pipeline`:
Below we use the [`bczhou/tiny-llava-v1-hf`](https://huggingface.co/bczhou/tiny-llava-v1-hf) checkpoint.
```python
from transformers import pipeline
from PIL import Image
import requests

model_id = "bczhou/tiny-llava-v1-hf"
pipe = pipeline("image-to-text", model=model_id)

# Load a demo image and ask a multiple-choice question about it.
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud\nASSISTANT:"

outputs = pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
print(outputs[0])
>>> {"generated_text": "USER: \nWhat does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud\nASSISTANT: The label 15 represents lava, which is a type of volcanic rock."}
```
#### - Using pure `transformers`:
Below is an example script to run generation in `float16` precision on a GPU device:
```python
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "bczhou/tiny-llava-v1-hf"
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Load the model in float16 and place it on GPU 0.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)

# Preprocess the image and prompt, then generate with greedy decoding.
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors="pt").to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))
```
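Since the model supports multi-image and multi-prompt generation, batched inference can look roughly like the sketch below. It reuses `model` and `processor` from the example above; the questions and the second image URL are only placeholders, and padding assumes the tokenizer has a pad token configured:
```python
# Sketch of batched, multi-prompt generation: one <image> token per prompt and
# one image per prompt, passed to the processor in the same order.
prompts = [
    "USER: <image>\nWhat are these?\nASSISTANT:",
    "USER: <image>\nWhat does this diagram show?\nASSISTANT:",
]
urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
]
images = [Image.open(requests.get(u, stream=True).raw) for u in urls]

inputs = processor(prompts, images=images, padding=True, return_tensors="pt").to(0, torch.float16)
outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
for text in processor.batch_decode(outputs, skip_special_tokens=True):
    print(text)
```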
## Contact
This model was trained by [Baichuan Zhou](https://baichuanzhou.github.io/) from Beihang University, under the supervision of [Prof. Lei Huang](https://huangleibuaa.github.io/).
## ✏ Citation
If you find our paper and code useful for your research, please consider giving a star :star: and a citation :pencil:.
```BibTeX
@misc{zhou2024tinyllava,
title={TinyLLaVA: A Framework of Small-scale Large Multimodal Models},
author={Baichuan Zhou and Ying Hu and Xi Weng and Junlong Jia and Jie Luo and Xien Liu and Ji Wu and Lei Huang},
year={2024},
eprint={2402.14289},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```