diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..73c8609270fc105dd7746e6515a311d36130745a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/apple.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/demo_vl.gif filter=lfs diff=lfs merge=lfs -text
+assets/mm_tutorial/Beijing.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/mm_tutorial/Chongqing.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/mm_tutorial/Rebecca_(1939_poster).jpeg filter=lfs diff=lfs merge=lfs -text
+assets/mm_tutorial/Shanghai_Output.jpg filter=lfs diff=lfs merge=lfs -text
+assets/touchstone_datasets.jpg filter=lfs diff=lfs merge=lfs -text
+assets/touchstone_logo.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99ce8124f50e14afc55fe8ad119e5c5792e5c004
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -0,0 +1,88 @@
+name: 🐞 Bug
+description: 提交错误报告 | File a bug/issue
+title: "[BUG]
+
+
+
+
+
+ Qwen-VL 🤖 | 🤗  | Qwen-VL-Chat 🤖 | 🤗  |  Demo  |  Report   |   Discord + +
++ 中文  |   English +
+
+
+
+
+
+We release two models of the Qwen-VL series:
+- Qwen-VL: The pre-trained LVLM. It uses Qwen-7B to initialize the LLM and [Openclip ViT-bigG](https://github.com/mlfoundations/open_clip) to initialize the visual encoder, and connects the two with a randomly initialized cross-attention layer (a conceptual sketch of this wiring follows the list).
+- Qwen-VL-Chat: A multimodal LLM-based AI assistant trained with alignment techniques. Qwen-VL-Chat supports more flexible interaction, such as multi-image inputs, multi-round question answering, and creative tasks.
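+
+A highly simplified, conceptual sketch of this wiring (not the actual Qwen-VL implementation; module names, sizes, and the use of learned queries are illustrative placeholders):
+
+```python
+import torch
+import torch.nn as nn
+
+class VisionLanguageAdapter(nn.Module):
+    """Randomly initialized cross-attention bridge: learned queries attend over ViT patch features."""
+    def __init__(self, vit_dim=1664, llm_dim=4096, num_queries=256, num_heads=16):
+        super().__init__()
+        self.queries = nn.Parameter(torch.randn(num_queries, llm_dim))
+        self.cross_attn = nn.MultiheadAttention(llm_dim, num_heads, kdim=vit_dim, vdim=vit_dim, batch_first=True)
+
+    def forward(self, patch_feats):                  # (B, num_patches, vit_dim) from the ViT encoder
+        q = self.queries.unsqueeze(0).expand(patch_feats.size(0), -1, -1)
+        visual_tokens, _ = self.cross_attn(q, patch_feats, patch_feats)
+        return visual_tokens                         # (B, num_queries, llm_dim), fed to the LLM as visual tokens
+```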
+
+
+## Evaluation
+
+We evaluated the model's abilities from two perspectives:
+1. **Standard Benchmarks**: We evaluate the model's basic task capabilities on four major categories of multimodal tasks:
+ - Zero-shot Captioning: Evaluate the model's zero-shot image captioning ability on unseen datasets;
+ - General VQA: Evaluate the model's general question-answering ability on images, such as yes/no judgments, colors, counts, object categories, etc.;
+ - Text-based VQA: Evaluate the model's ability to recognize and answer questions about text in images, such as document QA, chart QA, etc.;
+ - Referring Expression Comprehension: Evaluate the ability to localize a target object in an image described by a referring expression.
+
+2. **TouchStone**: To evaluate the overall text-image dialogue capability and alignment with human preferences, we construct a benchmark called TouchStone, which uses GPT-4-based scoring to evaluate LVLMs.
+ - The TouchStone benchmark covers a total of 300+ images, 800+ questions, and 27 categories, such as attribute-based Q&A, celebrity recognition, writing poetry, summarizing multiple images, product comparison, math problem solving, etc.;
+ - To work around GPT-4's current inability to take images as direct input, TouchStone provides fine-grained, human-labeled image annotations. These annotations, together with the questions and the model's outputs, are then given to GPT-4 for scoring.
+ - The benchmark includes both English and Chinese versions.
+
+The results of the evaluation are as follows:
+
+Qwen-VL outperforms current SOTA generalist models on multiple VL tasks and covers a broader range of capabilities.
+
+
+
+
+### Zero-shot Captioning & General VQA
+
+| Model type | Model | NoCaps | Flickr30K | VQAv2 (dev) | OK-VQA | GQA | SciQA-Img (0-shot) | VizWiz (0-shot) |
+|---|---|---|---|---|---|---|---|---|
+| Generalist Models | Flamingo-9B | - | 61.5 | 51.8 | 44.7 | - | - | 28.8 |
+| | Flamingo-80B | - | 67.2 | 56.3 | 50.6 | - | - | 31.6 |
+| | Unified-IO-XL | 100.0 | - | 77.9 | 54.0 | - | - | - |
+| | Kosmos-1 | - | 67.1 | 51.0 | - | - | - | 29.2 |
+| | Kosmos-2 | - | 66.7 | 45.6 | - | - | - | - |
+| | BLIP-2 (Vicuna-13B) | 103.9 | 71.6 | 65.0 | 45.9 | 32.3 | 61.0 | 19.6 |
+| | InstructBLIP (Vicuna-13B) | 121.9 | 82.8 | - | - | 49.5 | 63.1 | 33.4 |
+| | Shikra (Vicuna-13B) | - | 73.9 | 77.36 | 47.16 | - | - | - |
+| | Qwen-VL (Qwen-7B) | 121.4 | 85.8 | 78.8 | 58.6 | 59.3 | 67.1 | 35.2 |
+| | Qwen-VL-Chat | 120.2 | 81.0 | 78.2 | 56.6 | 57.5 | 68.2 | 38.9 |
+| Previous SOTA (Per Task Fine-tuning) | - | 127.0 (PALI-17B) | 84.5 (InstructBLIP-FlanT5-XL) | 86.1 (PALI-X-55B) | 66.1 (PALI-X-55B) | 72.1 (CFR) | 92.53 (LLaVa+GPT-4) | 70.9 (PALI-X-55B) |
+
+### Text-based VQA
+
+| Model type | Model | TextVQA | DocVQA | ChartQA | AI2D | OCR-VQA |
+|---|---|---|---|---|---|---|
+| Generalist Models | BLIP-2 (Vicuna-13B) | 42.4 | - | - | - | - |
+| | InstructBLIP (Vicuna-13B) | 50.7 | - | - | - | - |
+| | mPLUG-DocOwl (LLaMA-7B) | 52.6 | 62.2 | 57.4 | - | - |
+| | Pic2Struct-Large (1.3B) | - | 76.6 | 58.6 | 42.1 | 71.3 |
+| | Qwen-VL (Qwen-7B) | 63.8 | 65.1 | 65.7 | 62.3 | 75.7 |
+| Specialist SOTAs (Specialist/Finetuned) | PALI-X-55B (Single-task FT, without OCR pipeline) | 71.44 | 80.0 | 70.0 | 81.2 | 75.0 |
+
+### Referring Expression Comprehension
+
+| Model type | Model | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val-u | RefCOCOg test-u | GRIT refexp |
+|---|---|---|---|---|---|---|---|---|---|---|
+| Generalist Models | GPV-2 | - | - | - | - | - | - | - | - | 51.50 |
+| | OFA-L* | 79.96 | 83.67 | 76.39 | 68.29 | 76.00 | 61.75 | 67.57 | 67.58 | 61.70 |
+| | Unified-IO | - | - | - | - | - | - | - | - | 78.61 |
+| | VisionLLM-H | - | 86.70 | - | - | - | - | - | - | - |
+| | Shikra-7B | 87.01 | 90.61 | 80.24 | 81.60 | 87.36 | 72.12 | 82.27 | 82.19 | 69.34 |
+| | Shikra-13B | 87.83 | 91.11 | 81.81 | 82.89 | 87.79 | 74.41 | 82.64 | 83.16 | 69.03 |
+| | Qwen-VL-7B | 89.36 | 92.26 | 85.34 | 83.12 | 88.25 | 77.21 | 85.58 | 85.48 | 78.22 |
+| | Qwen-VL-7B-Chat | 88.55 | 92.27 | 84.51 | 82.82 | 88.59 | 76.79 | 85.96 | 86.32 | - |
+| Specialist SOTAs (Specialist/Finetuned) | G-DINO-L | 90.56 | 93.19 | 88.24 | 82.75 | 88.95 | 75.92 | 86.13 | 87.02 | - |
+| | UNINEXT-H | 92.64 | 94.33 | 91.46 | 85.24 | 89.63 | 79.79 | 88.73 | 89.37 | - |
+| | ONE-PEACE | 92.58 | 94.18 | 89.26 | 88.77 | 92.21 | 83.23 | 89.22 | 89.27 | - |
+
+
+
+
+
+
+#### Running Qwen-VL
+
+Running the Qwen-VL pretrained base model is also simple.
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+import torch
+torch.manual_seed(1234)
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)
+
+# use bf16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL", device_map="auto", trust_remote_code=True, bf16=True).eval()
+# use fp16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL", device_map="auto", trust_remote_code=True, fp16=True).eval()
+# use cpu only
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL", device_map="cpu", trust_remote_code=True).eval()
+# use cuda device
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL", device_map="cuda", trust_remote_code=True).eval()
+
+# Specify hyperparameters for generation
+model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)
+
+query = tokenizer.from_list_format([
+ {'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'}, # Either a local path or a URL
+ {'text': 'Generate the caption in English with grounding:'},
+])
+inputs = tokenizer(query, return_tensors='pt')
+inputs = inputs.to(model.device)
+pred = model.generate(**inputs)
+response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
+print(response)
+# <img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Generate the caption in English with grounding: Woman
+```
+
+Multi-turn dialogue with Qwen-VL-Chat goes through `model.chat` (load `Qwen/Qwen-VL-Chat` in the same way as above); the image is passed either as a local path or a URL wrapped in `<img></img>` tags:
+
+```python
+# 1st dialogue turn
+image_path = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'
+response, history = model.chat(tokenizer, query=f'<img>{image_path}</img>这是什么', history=None)
+print(response)
+# 图中是一名年轻女子在沙滩上和她的狗玩耍,狗的品种是拉布拉多。她们坐在沙滩上,狗的前腿抬起来,与人互动。
+
+# 2nd dialogue turn
+response, history = model.chat(tokenizer, '输出击掌的检测框', history=history)
+print(response)
+# "击掌"
+```
+
+
+
+## Demo
+
+### Web UI
+
+We provide code for users to build a web UI demo. Before you start, make sure you install the following packages:
+
+```
+pip install -r requirements_web_demo.txt
+```
+
+Then run the command below and click on the generated link:
+
+```
+python web_demo_mm.py
+```
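+
+For reference, here is a minimal sketch (not the actual `web_demo_mm.py`) of how such a demo can wire `model.chat` into a Gradio interface; the checkpoint name and the `<img></img>` prompt format follow the examples above, everything else is illustrative:
+
+```python
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    "Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True).eval()
+
+def answer(image_path, question, history):
+    # Wrap the image (local path or URL) in <img></img> tags and continue the dialogue.
+    query = f'<img>{image_path}</img>{question}' if image_path else question
+    response, history = model.chat(tokenizer, query=query, history=history or None)
+    return response, history
+
+with gr.Blocks() as demo:
+    state = gr.State(None)                        # dialogue history carried across turns
+    image = gr.Textbox(label="Image path or URL")
+    question = gr.Textbox(label="Question")
+    output = gr.Textbox(label="Answer")
+    gr.Button("Submit").click(answer, [image, question, state], [output, state])
+
+demo.launch()
+```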
+
+## FAQ
+
+If you run into problems, please check the [FAQ](FAQ.md) and existing issues for a solution before opening a new issue.
+
+
+## License Agreement
+
+Researchers and developers are free to use the code and model weights of both Qwen-VL and Qwen-VL-Chat, including for commercial use. See [LICENSE](LICENSE) for details.
+
+## Contact Us
+
+If you would like to leave a message for our research or product team, feel free to send an email to qianwen_opensource@alibabacloud.com.
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/README_CN.md b/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..ddbf06ea7e2f15493a63f2d8f914b40bf3a9205b
--- /dev/null
+++ b/README_CN.md
@@ -0,0 +1,666 @@
+
+
+
+
+
+
+
+
+ Qwen-VL 🤖 | 🤗  | Qwen-VL-Chat 🤖 | 🤗  |  Demo  |  Report   |   Discord +
++ 中文  |  English +
+
+
+
+
+
+目前,我们提供了 Qwen-VL 系列的两个模型:
+- Qwen-VL: Qwen-VL 以 Qwen-7B 的预训练模型作为语言模型的初始化,并以 [Openclip ViT-bigG](https://github.com/mlfoundations/open_clip) 作为视觉编码器的初始化,中间加入单层随机初始化的 cross-attention,经过约1.5B的图文数据训练得到。最终图像输入分辨率为448。
+- Qwen-VL-Chat: 在 Qwen-VL 的基础上,我们使用对齐机制打造了基于大语言模型的视觉AI助手Qwen-VL-Chat,它支持更灵活的交互方式,包括多图、多轮问答、创作等能力。
+
+
+## 评测
+
+我们从两个角度评测了两个模型的能力:
+1. 在**英文标准 Benchmark** 上评测模型的基础任务能力。目前评测了四大类多模态任务:
+ - Zero-shot Captioning: 评测模型在未见过数据集上的零样本图片描述能力;
+ - General VQA: 评测模型的通用问答能力,例如判断题、颜色、个数、类目等问答能力;
+ - Text-based VQA:评测模型对于图片中文字相关的识别/问答能力,例如文档问答、图表问答、文字问答等;
+ - Referring Expression Comprehension:评测模型给定物体描述画检测框的能力;
+
+2. **试金石 (TouchStone)**:为了评测模型整体的图文对话能力和人类对齐水平。我们为此构建了一个基于 GPT4 打分来评测 LVLM 模型的 Benchmark:TouchStone。在 TouchStone-v0.1 中:
+ - 评测基准总计涵盖 300+张图片、800+道题目、27个类别。包括基础属性问答、人物地标问答、影视作品问答、视觉推理、反事实推理、诗歌创作、故事写作,商品比较、图片解题等**尽可能广泛的类别**。
+ - 为了弥补目前 GPT4 无法直接读取图片的缺陷,我们给所有的带评测图片提供了**人工标注的充分详细描述**,并且将图片的详细描述、问题和模型的输出结果一起交给 GPT4 打分。
+ - 评测同时包含英文版本和中文版本。
+
+评测结果如下:
+
+Qwen-VL在多个VL任务上相比目前SOTA的Generalist Models都有明显优势,并且在能力范围也覆盖更加全面。
+
+
+
+
+### 零样本图像描述生成(Zero-shot Image Caption) 及 通用视觉问答(General VQA)
+
+| Model type | Model | NoCaps | Flickr30K | VQAv2 (dev) | OK-VQA | GQA | SciQA-Img (0-shot) | VizWiz (0-shot) |
+|---|---|---|---|---|---|---|---|---|
+| Generalist Models | Flamingo-9B | - | 61.5 | 51.8 | 44.7 | - | - | 28.8 |
+| | Flamingo-80B | - | 67.2 | 56.3 | 50.6 | - | - | 31.6 |
+| | Unified-IO-XL | 100.0 | - | 77.9 | 54.0 | - | - | - |
+| | Kosmos-1 | - | 67.1 | 51.0 | - | - | - | 29.2 |
+| | Kosmos-2 | - | 66.7 | 45.6 | - | - | - | - |
+| | BLIP-2 (Vicuna-13B) | 103.9 | 71.6 | 65.0 | 45.9 | 32.3 | 61.0 | 19.6 |
+| | InstructBLIP (Vicuna-13B) | 121.9 | 82.8 | - | - | 49.5 | 63.1 | 33.4 |
+| | Shikra (Vicuna-13B) | - | 73.9 | 77.36 | 47.16 | - | - | - |
+| | Qwen-VL (Qwen-7B) | 121.4 | 85.8 | 78.8 | 58.6 | 59.3 | 67.1 | 35.2 |
+| | Qwen-VL-Chat | 120.2 | 81.0 | 78.2 | 56.6 | 57.5 | 68.2 | 38.9 |
+| Previous SOTA (Per Task Fine-tuning) | - | 127.0 (PALI-17B) | 84.5 (InstructBLIP-FlanT5-XL) | 86.1 (PALI-X-55B) | 66.1 (PALI-X-55B) | 72.1 (CFR) | 92.53 (LLaVa+GPT-4) | 70.9 (PALI-X-55B) |
+
+### 文字相关视觉问答(Text-based VQA)
+
+| Model type | Model | TextVQA | DocVQA | ChartQA | AI2D | OCR-VQA |
+|---|---|---|---|---|---|---|
+| Generalist Models | BLIP-2 (Vicuna-13B) | 42.4 | - | - | - | - |
+| | InstructBLIP (Vicuna-13B) | 50.7 | - | - | - | - |
+| | mPLUG-DocOwl (LLaMA-7B) | 52.6 | 62.2 | 57.4 | - | - |
+| | Pic2Struct-Large (1.3B) | - | 76.6 | 58.6 | 42.1 | 71.3 |
+| | Qwen-VL (Qwen-7B) | 63.8 | 65.1 | 65.7 | 62.3 | 75.7 |
+| Specialist SOTAs (Specialist/Finetuned) | PALI-X-55B (Single-task FT, without OCR pipeline) | 71.44 | 80.0 | 70.0 | 81.2 | 75.0 |
+
+### 细粒度视觉定位(Referring Expression Comprehension)
+
+| Model type | Model | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val-u | RefCOCOg test-u | GRIT refexp |
+|---|---|---|---|---|---|---|---|---|---|---|
+| Generalist Models | GPV-2 | - | - | - | - | - | - | - | - | 51.50 |
+| | OFA-L* | 79.96 | 83.67 | 76.39 | 68.29 | 76.00 | 61.75 | 67.57 | 67.58 | 61.70 |
+| | Unified-IO | - | - | - | - | - | - | - | - | 78.61 |
+| | VisionLLM-H | - | 86.70 | - | - | - | - | - | - | - |
+| | Shikra-7B | 87.01 | 90.61 | 80.24 | 81.60 | 87.36 | 72.12 | 82.27 | 82.19 | 69.34 |
+| | Shikra-13B | 87.83 | 91.11 | 81.81 | 82.89 | 87.79 | 74.41 | 82.64 | 83.16 | 69.03 |
+| | Qwen-VL-7B | 89.36 | 92.26 | 85.34 | 83.12 | 88.25 | 77.21 | 85.58 | 85.48 | 78.22 |
+| | Qwen-VL-7B-Chat | 88.55 | 92.27 | 84.51 | 82.82 | 88.59 | 76.79 | 85.96 | 86.32 | - |
+| Specialist SOTAs (Specialist/Finetuned) | G-DINO-L | 90.56 | 93.19 | 88.24 | 82.75 | 88.95 | 75.92 | 86.13 | 87.02 | - |
+| | UNINEXT-H | 92.64 | 94.33 | 91.46 | 85.24 | 89.63 | 79.79 | 88.73 | 89.37 | - |
+| | ONE-PEACE | 92.58 | 94.18 | 89.26 | 88.77 | 92.21 | 83.23 | 89.22 | 89.27 | - |
+
+
+
+运行Qwen-VL同样非常简单。
+
+推理代码与英文版 README 中的示例相同,预期输出形如:
+
+```
+<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Generate the caption in English with grounding: Woman
+```
+
+
+
+
+#### 🤖 ModelScope
+
+魔搭(ModelScope)是开源的模型即服务共享平台,为泛AI开发者提供灵活、易用、低成本的一站式模型服务产品。使用ModelScope同样非常简单,代码如下所示:
+
+```python
+from modelscope import (
+ snapshot_download, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+)
+import torch
+model_id = 'qwen/Qwen-VL-Chat'
+revision = 'v1.0.0'
+
+model_dir = snapshot_download(model_id, revision=revision)
+torch.manual_seed(1234)
+
+tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+if not hasattr(tokenizer, 'model_dir'):
+ tokenizer.model_dir = model_dir
+# use bf16
+# model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, bf16=True).eval()
+# use fp16
+model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True, fp16=True).eval()
+# use cpu
+# model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="cpu", trust_remote_code=True).eval()
+# use auto
+# model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True).eval()
+
+# Specify hyperparameters for generation
+model.generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=True)
+
+# 1st dialogue turn
+# Either a local path or a URL between <img></img> tags
+image_path = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'
+response, history = model.chat(tokenizer, query=f'<img>{image_path}</img>这是什么', history=None)
+print(response)
+# 图中是一名年轻女子在沙滩上和她的狗玩耍,狗的品种是拉布拉多。她们坐在沙滩上,狗的前腿抬起来,与人互动。
+```
diff --git a/touchstone/README.md b/touchstone/README.md
new file mode 100644
--- /dev/null
+++ b/touchstone/README.md
+
+ 中文  |  English
+
+
+
+We comprehensively evaluate the model's abilities along five dimensions. As shown in the figure above, examples of the 27 subtasks are given. From perception to cognition to creativity, the requirements on models rise as the difficulty increases; LVLM capabilities are currently still at an early stage. Our dataset contains 800+ questions across 27 categories.
+
+## Methods
+
+
+We apply a powerful LLM as a judge to enable automated evaluation. To effectively comprehend the contents of an image, we manually substitute the actual image input with fine-grained textual annotations. By inputting these annotations and corresponding questions to a powerful LLM like GPT4, we obtain reference answers.
+
+For the evaluation of the LVLMs, we provide actual images and questions as input and obtain their respective answers. Finally, we employ GPT4 to score the answers generated by the LVLMs based on the fine-grained annotations and questions. The scoring instructions require the model to assess the usefulness, relevance, and accuracy of the answers, considering the annotations as the content of the images. To ensure fairness in the evaluation, each model's answer is compared against a consistent reference answer from GPT4. The average score of the model in all questions is taken as the final score.
+
+To eliminate the influence of answer position, we perform a second scoring round by swapping the positions of the answers and then compute the average of the two scores obtained. This approach aims to mitigate any bias introduced by the placement of the answers.
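+
+A minimal sketch of this swap-and-average scoring (not the official TouchStone code; `query_judge` is a placeholder for a call to the judge LLM, e.g. GPT-4, that returns a pair of scores for the two answers shown to it, given the human annotation and the question):
+
+```python
+def position_debiased_score(annotation, question, reference_answer, model_answer, query_judge):
+    # Round 1: the model answer is shown first, the GPT-4 reference answer second.
+    model_first, _ = query_judge(annotation, question, model_answer, reference_answer)
+    # Round 2: positions swapped, so any preference of the judge for a slot cancels out.
+    _, model_second = query_judge(annotation, question, reference_answer, model_answer)
+    return (model_first + model_second) / 2
+```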
+
+
+
+
+### Evaluation
+
+#### Evaluation in English-based Multimodal Dialogue
+
+| Model | Score |
+|---------------|-------|
+| PandaGPT | 488.5 |
+| MiniGPT4 | 531.7 |
+| InstructBLIP | 552.4 |
+| LLaMA-AdapterV2 | 590.1 |
+| mPLUG-Owl | 605.4 |
+| LLaVA | 602.7 |
+| Qwen-VL-Chat | 645.2 |
+
+#### Evaluation in Chinese-based Multimodal Dialogue
+
+| Model | Score |
+|---------------|-------|
+| VisualGLM | 247.1 |
+| Qwen-VL-Chat | 401.2 |
+
diff --git a/touchstone/README_CN.md b/touchstone/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..66ce2a09f5e87503ef59832df6b8c2091edd0a86
--- /dev/null
+++ b/touchstone/README_CN.md
@@ -0,0 +1,68 @@
+
+
+
+ 中文  |  English
+
+
+
+我们从五个维度综合评估了模型的能力。 如上图所示,给出了27个子任务的示例。 从感知到认知,再到创造力,随着难度的增加,对模型的要求也越来越高。 目前,LVLM的能力还处于早期阶段。 我们的数据集包含800+道题目、27个类别。
+
+## 测评方式
+
+我们应用SOTA的LLM进行自动化评估。 为了有效地理解图像的内容,我们人工用细粒度的文本注释替换实际的图像输入。 通过将这些注释和相应的问题输入到像GPT4这样强大的LLM中,我们可以获得参考答案。
+
+对于待测评的LVLM,我们提供实际图像和问题作为输入并获得各自的答案。 最后,我们使用GPT4根据细粒度注释和问题对LVLM生成的答案进行评分。 评分指令要求模型评估答案的有用性、相关性和准确性,并将人工注解视为图像的内容。 为了确保评估的公平性,每个模型的答案都会与 GPT4生成的参考答案进行比较。 模型在所有问题上的平均得分作为最终得分。
+
+为了消除答案位置的影响,我们通过交换答案的位置来进行第二轮评分,然后计算获得的两次分数的平均值。
+
+
+
+
+
+## 测评结果
+
+#### 英文版本测评
+
+| Model | Score |
+|---------------|-------|
+| PandaGPT | 488.5 |
+| MiniGPT4 | 531.7 |
+| InstructBLIP | 552.4 |
+| LLaMA-AdapterV2 | 590.1 |
+| mPLUG-Owl | 605.4 |
+| LLaVA | 602.7 |
+| Qwen-VL-Chat | 645.2 |
+
+#### 中文版本测评
+
+| Model | Score |
+|---------------|-------|
+| VisualGLM | 247.1 |
+| Qwen-VL-Chat | 401.2 |
+
diff --git a/web_demo_mm.py b/web_demo_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad17d954c2ddf22ceaaed3973a96c888962e6536
--- /dev/null
+++ b/web_demo_mm.py
@@ -0,0 +1,234 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""A simple web interactive chat demo based on gradio."""
+
+from argparse import ArgumentParser
+from pathlib import Path
+
+import copy
+import gradio as gr
+import os
+import re
+import secrets
+import tempfile
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
+DEFAULT_CKPT_PATH = 'Qwen/Qwen-VL-Chat'
+BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
diff --git a/eval_mm/evaluate_caption.py b/eval_mm/evaluate_caption.py
new file mode 100644
--- /dev/null
+++ b/eval_mm/evaluate_caption.py
+    prompt = '<img>{}</img>Describe the image in English:'
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint, device_map='cuda', trust_remote_code=True).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint,
+ trust_remote_code=True)
+
+ random.seed(args.seed)
+ dataset = CaptionDataset(
+ train=ds_collections[args.dataset]['train'],
+ test=ds_collections[args.dataset]['test'],
+ tokenizer=tokenizer,
+ prompt=prompt,
+ few_shot=args.few_shot,
+ )
+ coco_karpathy_test_loader = torch.utils.data.DataLoader(
+ dataset=dataset,
+ sampler=InferenceSampler(len(dataset)),
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=False,
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
+ )
+
+ image_ids = []
+ captions = []
+ for _, (ids, input_ids,
+ attention_mask) in tqdm(enumerate(coco_karpathy_test_loader)):
+ pred = model.generate(
+ input_ids=input_ids.cuda(),
+ attention_mask=attention_mask.cuda(),
+ do_sample=False,
+ num_beams=1,
+ max_new_tokens=30,
+ min_new_tokens=8,
+ length_penalty=0,
+ num_return_sequences=1,
+ use_cache=True,
+ pad_token_id=tokenizer.eod_id,
+ eos_token_id=tokenizer.eod_id,
+ )
+ image_ids.extend(ids)
+ captions.extend([
+ tokenizer.decode(_[input_ids.size(1):].cpu(),
+ skip_special_tokens=True).strip() for _ in pred
+ ])
+
+ torch.distributed.barrier()
+
+ world_size = torch.distributed.get_world_size()
+ merged_ids = [None for _ in range(world_size)]
+ merged_captions = [None for _ in range(world_size)]
+ torch.distributed.all_gather_object(merged_ids, image_ids)
+ torch.distributed.all_gather_object(merged_captions, captions)
+
+ merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)]
+ merged_captions = [
+ _ for _ in itertools.chain.from_iterable(merged_captions)
+ ]
+
+ if torch.distributed.get_rank() == 0:
+ results = []
+ for image_id, caption in zip(merged_ids, merged_captions):
+ results.append({
+ 'image_id': int(image_id),
+ 'caption': caption,
+ })
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
+ results_file = f'{args.dataset}_{time_prefix}.json'
+ json.dump(results, open(results_file, 'w'))
+
+ coco = COCO(ds_collections[args.dataset]['test'])
+ coco_result = coco.loadRes(results_file)
+ coco_eval = COCOEvalCap(coco, coco_result)
+ coco_eval.evaluate()
+
+ print(coco_eval.eval.items())
+ torch.distributed.barrier()
diff --git a/eval_mm/evaluate_grounding.py b/eval_mm/evaluate_grounding.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a76223cc4e51a0381ab7cafdf3a7abe570eae99
--- /dev/null
+++ b/eval_mm/evaluate_grounding.py
@@ -0,0 +1,213 @@
+import argparse
+import itertools
+import json
+import os
+import re
+from functools import partial
+
+import torch
+from torchvision.ops.boxes import box_area
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ds_collections = {
+ 'refcoco_val': 'data/refcoco/refcoco_val.jsonl',
+ 'refcoco_testA': 'data/refcoco/refcoco_testA.jsonl',
+ 'refcoco_testB': 'data/refcoco/refcoco_testB.jsonl',
+ 'refcoco+_val': 'data/refcoco+/refcoco+_val.jsonl',
+ 'refcoco+_testA': 'data/refcoco+/refcoco+_testA.jsonl',
+ 'refcoco+_testB': 'data/refcoco+/refcoco+_testB.jsonl',
+ 'refcocog_val': 'data/refcocog/refcocog_val.jsonl',
+ 'refcocog_test': 'data/refcocog/refcocog_test.jsonl',
+}
+
+
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / union
+ return iou, union
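+
+# Usage note (illustrative, not part of the original script): on RefCOCO-style benchmarks a
+# predicted box is typically counted as correct when its IoU with the ground truth is >= 0.5, e.g.
+#   iou, _ = box_iou(torch.tensor([[0., 0., 10., 10.]]), torch.tensor([[0., 0., 10., 5.]]))
+#   iou.item()  # 0.5 -> counted as correct at the 0.5 threshold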
+
+
+def collate_fn(batches, tokenizer):
+
+ texts = [_['text'] for _ in batches]
+ bboxes = [_['bbox'] for _ in batches]
+ hws = [_['hw'] for _ in batches]
+
+ input_ids = tokenizer(texts, return_tensors='pt', padding='longest')
+
+ return input_ids.input_ids, input_ids.attention_mask, bboxes, hws
+
+
+class RefCOCODataset(torch.utils.data.Dataset):
+
+ def __init__(self, test, tokenizer, prompt):
+ self.datas = open(test).readlines()
+ self.tokenizer = tokenizer
+ self.prompt = prompt
+
+ def __len__(self):
+ return len(self.datas)
+
+ def __getitem__(self, idx):
+ data = json.loads(self.datas[idx].strip())
+ image = data['image']
+ text = data['sent']
+ bbox = data['bbox']
+
+ w, h = data['width'], data['height']
+
+ return {
+ 'text': self.prompt.format(image, text),
+ 'bbox': bbox,
+ 'hw': (h, w),
+ }
+
+
+class InferenceSampler(torch.utils.data.sampler.Sampler):
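+    """Shards dataset indices into contiguous, near-equal, non-overlapping per-rank ranges so
+    that each distributed process evaluates a disjoint slice; per-rank outputs are merged later
+    via torch.distributed.all_gather_object."""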
+
+ def __init__(self, size):
+ self._size = int(size)
+ assert size > 0
+ self._rank = torch.distributed.get_rank()
+ self._world_size = torch.distributed.get_world_size()
+ self._local_indices = self._get_local_indices(size, self._world_size,
+ self._rank)
+
+ @staticmethod
+ def _get_local_indices(total_size, world_size, rank):
+ shard_size = total_size // world_size
+ left = total_size % world_size
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
+
+ begin = sum(shard_sizes[:rank])
+ end = min(sum(shard_sizes[:rank + 1]), total_size)
+ return range(begin, end)
+
+ def __iter__(self):
+ yield from self._local_indices
+
+ def __len__(self):
+ return len(self._local_indices)
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--checkpoint', type=str, default='')
+ parser.add_argument('--dataset', type=str, default='')
+ parser.add_argument('--batch-size', type=int, default=1)
+ parser.add_argument('--num-workers', type=int, default=1)
+ args = parser.parse_args()
+
+ torch.distributed.init_process_group(
+ backend='nccl',
+ world_size=int(os.getenv('WORLD_SIZE', '1')),
+ rank=int(os.getenv('RANK', '0')),
+ )
+
+ torch.cuda.set_device(torch.distributed.get_rank())
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint, device_map='cuda', trust_remote_code=True).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint,
+ trust_remote_code=True)
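+    # Pad on the left so that, in a batch, generation starts right after each prompt's last token;
+    # the tokenizer has no dedicated pad token, so the end-of-document id is reused for padding.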
+ tokenizer.padding_side = 'left'
+ tokenizer.pad_token_id = tokenizer.eod_id
+
+    prompt = '<img>{}</img><ref>{}</ref><box>'
diff --git a/eval_mm/evaluate_multiple_choice.py b/eval_mm/evaluate_multiple_choice.py
new file mode 100644
--- /dev/null
+++ b/eval_mm/evaluate_multiple_choice.py
+    prompt = '<img>{}</img>Context: {}\nQuestion: {}\nOptions: {}\nAnswer:'
+
+ dataset = MultipleChoiceDataste(test=ds_collections[args.dataset]['test'],
+ prompt=prompt,
+ tokenizer=tokenizer)
+ dataloader = torch.utils.data.DataLoader(
+ dataset=dataset,
+ sampler=InferenceSampler(len(dataset)),
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=False,
+ collate_fn=partial(collate_fn, pad_token_id=tokenizer.eod_id),
+ )
+
+ results = []
+ with torch.no_grad():
+ for _, (input_tokens, attention_mask, target_lengths, answer,
+ chunk_sizes) in tqdm(enumerate(dataloader)):
+
+ outputs = model(
+ input_ids=input_tokens[:, :-1].cuda(),
+ attention_mask=attention_mask[:, :-1].cuda(),
+ return_dict=True,
+ )
+ losses = torch.nn.functional.cross_entropy(outputs.logits.permute(
+ 0, 2, 1),
+ input_tokens[:,
+ 1:].cuda(),
+ reduction='none')
+
+ losses = losses.split(chunk_sizes, dim=0)
+
+ for loss, target_length, answer in zip(losses, target_lengths,
+ answer):
+
+ target_loss = loss.mean(-1)
+ for _ in range(len(target_length)):
+ target_loss[_] = loss[_, -target_length[_]:].mean()
+ pred = target_loss.argmin().item()
+ if pred == answer:
+ results.append(1)
+ else:
+ results.append(0)
+
+ torch.distributed.barrier()
+
+ world_size = torch.distributed.get_world_size()
+ merged_results = [None for _ in range(world_size)]
+ torch.distributed.all_gather_object(merged_results, results)
+
+ merged_results = [_ for _ in itertools.chain.from_iterable(merged_results)]
+
+ if torch.distributed.get_rank() == 0:
+ print(f'Acc@1: {sum(merged_results) / len(merged_results)}')
+
+ torch.distributed.barrier()
diff --git a/eval_mm/evaluate_vizwiz_testdev.py b/eval_mm/evaluate_vizwiz_testdev.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f40422b12809493b886fa08844cc17e26005467
--- /dev/null
+++ b/eval_mm/evaluate_vizwiz_testdev.py
@@ -0,0 +1,167 @@
+import argparse
+import itertools
+import json
+import os
+import random
+import time
+from functools import partial
+
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def collate_fn(batches, tokenizer):
+
+ images = [_['image'] for _ in batches]
+ questions = [_['question'] for _ in batches]
+
+ input_ids = tokenizer(questions, return_tensors='pt', padding='longest')
+
+ return images, input_ids.input_ids, input_ids.attention_mask
+
+
+class VQADataset(torch.utils.data.Dataset):
+
+ def __init__(self, train, test, prompt, few_shot):
+ self.test = json.load(open(test))
+ self.prompt = prompt
+
+ self.few_shot = few_shot
+ if few_shot > 0:
+ self.train = open(train).readlines()
+
+ def __len__(self):
+ return len(self.test)
+
+ def __getitem__(self, idx):
+ data = self.test[idx]
+ image, question = data['image'], data['question']
+
+ few_shot_prompt = ''
+ if self.few_shot > 0:
+ few_shot_samples = random.sample(self.train, self.few_shot)
+ for sample in few_shot_samples:
+ sample = json.loads(sample.strip())
+ few_shot_prompt += self.prompt.format(
+ sample['image'],
+ sample['question']) + f" {sample['answer']}"
+
+ return {
+ 'image': data['image'],
+ 'question': few_shot_prompt + self.prompt.format(image, question),
+ }
+
+
+class InferenceSampler(torch.utils.data.sampler.Sampler):
+
+ def __init__(self, size):
+ self._size = int(size)
+ assert size > 0
+ self._rank = torch.distributed.get_rank()
+ self._world_size = torch.distributed.get_world_size()
+ self._local_indices = self._get_local_indices(size, self._world_size,
+ self._rank)
+
+ @staticmethod
+ def _get_local_indices(total_size, world_size, rank):
+ shard_size = total_size // world_size
+ left = total_size % world_size
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
+
+ begin = sum(shard_sizes[:rank])
+ end = min(sum(shard_sizes[:rank + 1]), total_size)
+ return range(begin, end)
+
+ def __iter__(self):
+ yield from self._local_indices
+
+ def __len__(self):
+ return len(self._local_indices)
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--checkpoint', type=str, default='')
+ parser.add_argument('--batch-size', type=int, default=1)
+ parser.add_argument('--num-workers', type=int, default=1)
+ parser.add_argument('--few-shot', type=int, default=0)
+ parser.add_argument('--seed', type=int, default=0)
+ args = parser.parse_args()
+
+ torch.distributed.init_process_group(
+ backend='nccl',
+ world_size=int(os.getenv('WORLD_SIZE', '1')),
+ rank=int(os.getenv('RANK', '0')),
+ )
+
+ torch.cuda.set_device(torch.distributed.get_rank())
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint, device_map='cuda', trust_remote_code=True).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint,
+ trust_remote_code=True)
+ tokenizer.padding_side = 'left'
+ tokenizer.pad_token_id = tokenizer.eod_id
+
+    prompt = '<img>data/vizwiz/test/{}</img>{} Answer:'
+
+ random.seed(args.seed)
+ dataset = VQADataset(
+ train='data/vizwiz/vizwiz_train.jsonl',
+ test='data/vizwiz/test.json',
+ prompt=prompt,
+ few_shot=args.few_shot,
+ )
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset=dataset,
+ sampler=InferenceSampler(len(dataset)),
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=False,
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
+ )
+
+ outputs = []
+ for _, (images, input_ids, attention_mask) in tqdm(enumerate(dataloader)):
+ pred = model.generate(
+ input_ids=input_ids.cuda(),
+ attention_mask=attention_mask.cuda(),
+ do_sample=False,
+ num_beams=1,
+ max_new_tokens=10,
+ min_new_tokens=1,
+ length_penalty=1,
+ num_return_sequences=1,
+ output_hidden_states=True,
+ use_cache=True,
+ pad_token_id=tokenizer.eod_id,
+ eos_token_id=tokenizer.eod_id,
+ )
+ answers = [
+ tokenizer.decode(_[input_ids.size(1):].cpu(),
+ skip_special_tokens=True).strip() for _ in pred
+ ]
+
+ for image, answer in zip(images, answers):
+ outputs.append({'image': image, 'answer': answer})
+
+ torch.distributed.barrier()
+
+ world_size = torch.distributed.get_world_size()
+ merged_outputs = [None for _ in range(world_size)]
+ torch.distributed.all_gather_object(merged_outputs, outputs)
+
+ merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
+
+ if torch.distributed.get_rank() == 0:
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
+ results_file = f'vizwiz_testdev_{time_prefix}_fs{args.few_shot}_s{args.seed}.json'
+ json.dump(merged_outputs, open(results_file, 'w'),
+ ensure_ascii=False) # save to results
+
+ torch.distributed.barrier()
diff --git a/eval_mm/evaluate_vqa.py b/eval_mm/evaluate_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..2accb43745652dace9189cda1852e610eb171987
--- /dev/null
+++ b/eval_mm/evaluate_vqa.py
@@ -0,0 +1,357 @@
+import argparse
+import itertools
+import json
+import os
+import random
+import time
+from functools import partial
+from typing import Optional
+
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from vqa import VQA
+from vqa_eval import VQAEval
+
+ds_collections = {
+ 'vqav2_val': {
+ 'train': 'data/vqav2/vqav2_train.jsonl',
+ 'test': 'data/vqav2/vqav2_val.jsonl',
+ 'question': 'data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json',
+ 'annotation': 'data/vqav2/v2_mscoco_val2014_annotations.json',
+ 'metric': 'vqa_score',
+ 'max_new_tokens': 10,
+ },
+ 'okvqa_val': {
+ 'train': 'data/okvqa/okvqa_train.jsonl',
+ 'test': 'data/okvqa/okvqa_val.jsonl',
+ 'question': 'data/okvqa/OpenEnded_mscoco_val2014_questions.json',
+ 'annotation': 'data/okvqa/mscoco_val2014_annotations.json',
+ 'metric': 'vqa_score',
+ 'max_new_tokens': 10,
+ },
+ 'textvqa_val': {
+ 'train': 'data/textvqa/textvqa_train.jsonl',
+ 'test': 'data/textvqa/textvqa_val.jsonl',
+ 'question': 'data/textvqa/textvqa_val_questions.json',
+ 'annotation': 'data/textvqa/textvqa_val_annotations.json',
+ 'metric': 'vqa_score',
+ 'max_new_tokens': 10,
+ },
+ 'vizwiz_val': {
+ 'train': 'data/vizwiz/vizwiz_train.jsonl',
+ 'test': 'data/vizwiz/vizwiz_val.jsonl',
+ 'question': 'data/vizwiz/vizwiz_val_questions.json',
+ 'annotation': 'data/vizwiz/vizwiz_val_annotations.json',
+ 'metric': 'vqa_score',
+ 'max_new_tokens': 10,
+ },
+ 'docvqa': {
+ 'train': 'data/DocVQA/train.jsonl',
+ 'test': 'data/DocVQA/val.jsonl',
+ # 'question': '',
+ 'annotation': './data/DocVQA/val/val_v1.0.json',
+ 'metric': 'anls',
+ 'max_new_tokens': 100,
+ },
+ 'infographicsvqa': {
+ 'train': 'data/InfographicsVQA/train.jsonl',
+ 'test': 'data/InfographicsVQA/val.jsonl',
+ # 'question': '',
+ 'annotation': './data/InfographicsVQA/infographicVQA_val_v1.0.json',
+ 'metric': 'anls',
+ 'max_new_tokens': 100,
+ },
+ 'chartqa': {
+ 'train': 'data/ChartQA/train.jsonl',
+ 'test': 'data/ChartQA/val_human.jsonl',
+ # 'question': '',
+ # 'annotation': '',
+ 'metric': 'relaxed_accuracy',
+ 'max_new_tokens': 100,
+ },
+ 'gqa': {
+ 'train': 'data/GQA/train.jsonl',
+ 'test': 'data/GQA/testdev_balanced.jsonl',
+ # 'question': '',
+ # 'annotation': '',
+ 'metric': 'accuracy',
+ 'max_new_tokens': 10,
+ },
+ 'ocrvqa': {
+ 'train': 'data/OCR-VQA/train.jsonl',
+ 'test': 'data/OCR-VQA/val.jsonl',
+ # 'question': '',
+ # 'annotation': '',
+ 'metric': 'accuracy',
+ 'max_new_tokens': 10,
+ },
+ 'ai2diagram': {
+ 'train': 'data/AI2Diagram/train.jsonl',
+ 'test': 'data/AI2Diagram/test.jsonl',
+ # 'question': '',
+ # 'annotation': '',
+ 'metric': 'accuracy',
+ 'max_new_tokens': 10,
+ }
+}
+
+# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
+def relaxed_correctness(target: str,
+ prediction: str,
+ max_relative_change: float = 0.05) -> bool:
+ """Calculates relaxed correctness.
+
+ The correctness tolerates certain error ratio defined by max_relative_change.
+ See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
+ “Following Methani et al. (2020), we use a relaxed accuracy measure for the
+ numeric answers to allow a minor inaccuracy that may result from the automatic
+ data extraction process. We consider an answer to be correct if it is within
+ 5% of the gold answer. For non-numeric answers, we still need an exact match
+ to consider an answer to be correct.”
+
+ Args:
+ target: Target string.
+ prediction: Predicted string.
+ max_relative_change: Maximum relative change.
+
+ Returns:
+ Whether the prediction was correct given the specified tolerance.
+ """
+
+ def _to_float(text: str) -> Optional[float]:
+ try:
+ if text.endswith("%"):
+ # Convert percentages to floats.
+ return float(text.rstrip("%")) / 100.0
+ else:
+ return float(text)
+ except ValueError:
+ return None
+
+ prediction_float = _to_float(prediction)
+ target_float = _to_float(target)
+ if prediction_float is not None and target_float:
+ relative_change = abs(
+ prediction_float - target_float) / abs(target_float)
+ return relative_change <= max_relative_change
+ else:
+ return prediction.lower() == target.lower()
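+
+# Illustrative, hand-checked cases of the 5% tolerance rule above (not part of the original script);
+# note that evaluate_relaxed_accuracy below passes the model answer as `target` and each gold
+# annotation as `prediction`, so the tolerance is relative to the model answer.
+#   relaxed_correctness('100', '96')    -> True   (|96 - 100| / 100 = 0.04 <= 0.05)
+#   relaxed_correctness('100', '94')    -> False  (|94 - 100| / 100 = 0.06 >  0.05)
+#   relaxed_correctness('0.05', '5%')   -> True   ('5%' is parsed as 0.05; exact match)
+#   relaxed_correctness('left', 'Left') -> True   (non-numeric: case-insensitive exact match)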
+
+def evaluate_relaxed_accuracy(entries):
+ scores = []
+ for elem in entries:
+ score = max([relaxed_correctness(elem['answer'].strip(), ann) for ann in elem['annotation']])
+ scores.append(score)
+ return sum(scores) / len(scores)
+
+def evaluate_exact_match_accuracy(entries):
+ scores = []
+ for elem in entries:
+ score = max([(1.0 if (elem['answer'].strip().lower() == ann.strip().lower()) else 0.0) for ann in elem['annotation']])
+ scores.append(score)
+ return sum(scores) / len(scores)
+
+
+def collate_fn(batches, tokenizer):
+
+ questions = [_['question'] for _ in batches]
+ question_ids = [_['question_id'] for _ in batches]
+ annotations = [_['annotation'] for _ in batches]
+
+ input_ids = tokenizer(questions, return_tensors='pt', padding='longest')
+
+ return question_ids, input_ids.input_ids, input_ids.attention_mask, annotations
+
+
+class VQADataset(torch.utils.data.Dataset):
+
+ def __init__(self, train, test, prompt, few_shot):
+ self.test = open(test).readlines()
+ self.prompt = prompt
+
+ self.few_shot = few_shot
+ if few_shot > 0:
+ self.train = open(train).readlines()
+
+ def __len__(self):
+ return len(self.test)
+
+ def __getitem__(self, idx):
+ data = json.loads(self.test[idx].strip())
+ image, question, question_id, annotation = data['image'], data['question'], data[
+ 'question_id'], data['answer']
+
+ few_shot_prompt = ''
+ if self.few_shot > 0:
+ few_shot_samples = random.sample(self.train, self.few_shot)
+ for sample in few_shot_samples:
+ sample = json.loads(sample.strip())
+ few_shot_prompt += self.prompt.format(
+ sample['image'],
+ sample['question']) + f" {sample['answer']}"
+
+ return {
+ 'question': few_shot_prompt + self.prompt.format(image, question),
+ 'question_id': question_id,
+ 'annotation': annotation
+ }
+
+
+class InferenceSampler(torch.utils.data.sampler.Sampler):
+
+ def __init__(self, size):
+ self._size = int(size)
+ assert size > 0
+ self._rank = torch.distributed.get_rank()
+ self._world_size = torch.distributed.get_world_size()
+ self._local_indices = self._get_local_indices(size, self._world_size,
+ self._rank)
+
+ @staticmethod
+ def _get_local_indices(total_size, world_size, rank):
+ shard_size = total_size // world_size
+ left = total_size % world_size
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
+
+ begin = sum(shard_sizes[:rank])
+ end = min(sum(shard_sizes[:rank + 1]), total_size)
+ return range(begin, end)
+
+ def __iter__(self):
+ yield from self._local_indices
+
+ def __len__(self):
+ return len(self._local_indices)
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--checkpoint', type=str, default='')
+ parser.add_argument('--dataset', type=str, default='')
+ parser.add_argument('--batch-size', type=int, default=1)
+ parser.add_argument('--num-workers', type=int, default=1)
+ parser.add_argument('--few-shot', type=int, default=0)
+ parser.add_argument('--seed', type=int, default=0)
+ args = parser.parse_args()
+
+ torch.distributed.init_process_group(
+ backend='nccl',
+ world_size=int(os.getenv('WORLD_SIZE', '1')),
+ rank=int(os.getenv('RANK', '0')),
+ )
+
+ torch.cuda.set_device(torch.distributed.get_rank())
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.checkpoint, device_map='cuda', trust_remote_code=True).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint,
+ trust_remote_code=True)
+ tokenizer.padding_side = 'left'
+ tokenizer.pad_token_id = tokenizer.eod_id
+
+    prompt = '<img>{}</img>{} Answer:'
+
+ random.seed(args.seed)
+ dataset = VQADataset(
+ train=ds_collections[args.dataset]['train'],
+ test=ds_collections[args.dataset]['test'],
+ prompt=prompt,
+ few_shot=args.few_shot,
+ )
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset=dataset,
+ sampler=InferenceSampler(len(dataset)),
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=False,
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
+ )
+
+ outputs = []
+ for _, (question_ids, input_ids,
+ attention_mask, annotations) in tqdm(enumerate(dataloader)):
+ pred = model.generate(
+ input_ids=input_ids.cuda(),
+ attention_mask=attention_mask.cuda(),
+ do_sample=False,
+ num_beams=1,
+ max_new_tokens=ds_collections[args.dataset]['max_new_tokens'],
+ min_new_tokens=1,
+ length_penalty=1,
+ num_return_sequences=1,
+ output_hidden_states=True,
+ use_cache=True,
+ pad_token_id=tokenizer.eod_id,
+ eos_token_id=tokenizer.eod_id,
+ )
+ answers = [
+ tokenizer.decode(_[input_ids.size(1):].cpu(),
+ skip_special_tokens=True).strip() for _ in pred
+ ]
+
+ for question_id, answer, annotation in zip(question_ids, answers, annotations):
+ try:
+ outputs.append({'question_id': int(question_id), 'answer': answer, 'annotation': annotation})
+ except:
+ outputs.append({'question_id': question_id, 'answer': answer, 'annotation': annotation})
+
+ torch.distributed.barrier()
+
+ world_size = torch.distributed.get_world_size()
+ merged_outputs = [None for _ in range(world_size)]
+ torch.distributed.all_gather_object(merged_outputs, outputs)
+
+ merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
+
+ if torch.distributed.get_rank() == 0:
+ time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
+ results_file = f'{args.dataset}_{time_prefix}_fs{args.few_shot}_s{args.seed}.json'
+ json.dump(merged_outputs, open(results_file, 'w'),
+ ensure_ascii=False) # save to results
+
+ if ds_collections[args.dataset]['metric'] == 'vqa_score':
+ vqa = VQA(ds_collections[args.dataset]['annotation'],
+ ds_collections[args.dataset]['question'])
+ results = vqa.loadRes(
+ resFile=results_file,
+ quesFile=ds_collections[args.dataset]['question'])
+ vqa_scorer = VQAEval(vqa, results, n=2)
+ vqa_scorer.evaluate()
+
+ print(vqa_scorer.accuracy)
+
+ elif ds_collections[args.dataset]['metric'] == 'anls':
+ merged_outputs = [{'answer': _['answer'], 'questionId': _['question_id']} for _ in merged_outputs]
+ results_file = f'{args.dataset}_official_{time_prefix}.json'
+ json.dump(merged_outputs, open(results_file, 'w'), ensure_ascii=False)
+ print('python infographicsvqa_eval.py -g ' + ds_collections[args.dataset]['annotation'] + ' -s ' + results_file)
+ os.system('python infographicsvqa_eval.py -g ' + ds_collections[args.dataset]['annotation'] + ' -s ' + results_file)
+ elif ds_collections[args.dataset]['metric'] == 'relaxed_accuracy':
+ print({'relaxed_accuracy': evaluate_relaxed_accuracy(merged_outputs)})
+ elif ds_collections[args.dataset]['metric'] == 'accuracy':
+ if 'gqa' in args.dataset:
+ for entry in merged_outputs:
+ response = entry['answer']
+ response = response.strip().split('.')[0].split(',')[0].split('!')[0].lower()
+ if 'is ' in response:
+ response = response.split('is ')[1]
+ if 'are ' in response:
+ response = response.split('are ')[1]
+ if 'a ' in response:
+ response = response.split('a ')[1]
+ if 'an ' in response:
+ response = response.split('an ')[1]
+ if 'the ' in response:
+ response = response.split('the ')[1]
+ if ' of' in response:
+ response = response.split(' of')[0]
+ response = response.strip()
+ entry['answer'] = response
+ print({'accuracy': evaluate_exact_match_accuracy(merged_outputs)})
+
+ torch.distributed.barrier()
diff --git a/eval_mm/vqa.py b/eval_mm/vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1ee18f0532a4f8ed1f4ee4a33c162f7c4375398
--- /dev/null
+++ b/eval_mm/vqa.py
@@ -0,0 +1,206 @@
+"""Copyright (c) 2022, salesforce.com, inc.
+
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+__author__ = 'aagrawal'
+__version__ = '0.9'
+
+# Interface for accessing the VQA dataset.
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
+
+# The following functions are defined:
+# VQA - VQA class that loads VQA annotation file and prepares data structures.
+# getQuesIds - Get question ids that satisfy given filter conditions.
+# getImgIds - Get image ids that satisfy given filter conditions.
+# loadQA - Load questions and answers with the specified question ids.
+# showQA - Display the specified questions and answers.
+# loadRes - Load result file and create result object.
+
+# Help on each function can be accessed by: "help(COCO.function)"
+
+import copy
+import datetime
+import json
+
+
+class VQA:
+
+ def __init__(self, annotation_file=None, question_file=None):
+ """Constructor of VQA helper class for reading and visualizing
+ questions and answers.
+
+ :param annotation_file (str): location of VQA annotation file
+ :return:
+ """
+ # load dataset
+ self.dataset = {}
+ self.questions = {}
+ self.qa = {}
+ self.qqa = {}
+ self.imgToQA = {}
+ if not annotation_file == None and not question_file == None:
+ print('loading VQA annotations and questions into memory...')
+ time_t = datetime.datetime.utcnow()
+ dataset = json.load(open(annotation_file, 'r'))
+ questions = json.load(open(question_file, 'r'))
+ self.dataset = dataset
+ self.questions = questions
+ self.createIndex()
+
+ def createIndex(self):
+ # create index
+ print('creating index...')
+ imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
+ qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
+ qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
+ for ann in self.dataset['annotations']:
+ imgToQA[ann['image_id']] += [ann]
+ qa[ann['question_id']] = ann
+ for ques in self.questions['questions']:
+ qqa[ques['question_id']] = ques
+ print('index created!')
+
+ # create class members
+ self.qa = qa
+ self.qqa = qqa
+ self.imgToQA = imgToQA
+
+ def info(self):
+ """Print information about the VQA annotation file.
+
+ :return:
+ """
+        for key, value in self.dataset['info'].items():
+ print('%s: %s' % (key, value))
+
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
+ """Get question ids that satisfy given filter conditions. default skips
+ that filter.
+
+ :param imgIds (int array) : get question ids for given imgs
+ quesTypes (str array) : get question ids for given question types
+ ansTypes (str array) : get question ids for given answer types
+ :return: ids (int array) : integer array of question ids
+ """
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(imgIds) == 0:
+ anns = sum(
+ [
+ self.imgToQA[imgId]
+ for imgId in imgIds if imgId in self.imgToQA
+ ],
+ [],
+ )
+ else:
+ anns = self.dataset['annotations']
+ anns = (anns if len(quesTypes) == 0 else
+ [ann for ann in anns if ann['question_type'] in quesTypes])
+ anns = (anns if len(ansTypes) == 0 else
+ [ann for ann in anns if ann['answer_type'] in ansTypes])
+ ids = [ann['question_id'] for ann in anns]
+ return ids
+
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
+ """Get image ids that satisfy given filter conditions. default skips
+ that filter.
+
+ :param quesIds (int array) : get image ids for given question ids
+ quesTypes (str array) : get image ids for given question types
+ ansTypes (str array) : get image ids for given answer types
+ :return: ids (int array) : integer array of image ids
+ """
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(quesIds) == 0:
+ anns = sum([
+ self.qa[quesId] for quesId in quesIds if quesId in self.qa
+ ], [])
+ else:
+ anns = self.dataset['annotations']
+ anns = (anns if len(quesTypes) == 0 else
+ [ann for ann in anns if ann['question_type'] in quesTypes])
+ anns = (anns if len(ansTypes) == 0 else
+ [ann for ann in anns if ann['answer_type'] in ansTypes])
+ ids = [ann['image_id'] for ann in anns]
+ return ids
+
+ def loadQA(self, ids=[]):
+ """Load questions and answers with the specified question ids.
+
+ :param ids (int array) : integer ids specifying question ids
+ :return: qa (object array) : loaded qa objects
+ """
+ if type(ids) == list:
+ return [self.qa[id] for id in ids]
+ elif type(ids) == int:
+ return [self.qa[ids]]
+
+ def showQA(self, anns):
+ """Display the specified annotations.
+
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ for ann in anns:
+ quesId = ann['question_id']
+ print('Question: %s' % (self.qqa[quesId]['question']))
+ for ans in ann['answers']:
+ print('Answer %d: %s' % (ans['answer_id'], ans['answer']))
+
+ def loadRes(self, resFile, quesFile):
+ """Load result file and return a result object.
+
+ :param resFile (str) : file name of result file
+ :return: res (obj) : result api object
+ """
+ res = VQA()
+ res.questions = json.load(open(quesFile))
+ res.dataset['info'] = copy.deepcopy(self.questions['info'])
+ res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
+ res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
+ res.dataset['data_subtype'] = copy.deepcopy(
+ self.questions['data_subtype'])
+ res.dataset['license'] = copy.deepcopy(self.questions['license'])
+
+ print('Loading and preparing results... ')
+ time_t = datetime.datetime.utcnow()
+ anns = json.load(open(resFile))
+ assert type(anns) == list, 'results is not an array of objects'
+ annsQuesIds = [ann['question_id'] for ann in anns]
+ assert set(annsQuesIds) == set(
+ self.getQuesIds()
+        ), 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is at least one question id that does not belong to the question ids in the annotation file.'
+ for ann in anns:
+ quesId = ann['question_id']
+ if res.dataset['task_type'] == 'Multiple Choice':
+ assert (
+ ann['answer'] in self.qqa[quesId]['multiple_choices']
+ ), 'predicted answer is not one of the multiple choices'
+ qaAnn = self.qa[quesId]
+ ann['image_id'] = qaAnn['image_id']
+ ann['question_type'] = qaAnn['question_type']
+ ann['answer_type'] = qaAnn['answer_type']
+ print('DONE (t=%0.2fs)' %
+ ((datetime.datetime.utcnow() - time_t).total_seconds()))
+
+ res.dataset['annotations'] = anns
+ res.createIndex()
+ return res
diff --git a/eval_mm/vqa_eval.py b/eval_mm/vqa_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1329ae13cd7f3857a839c95462118738e61b0d6d
--- /dev/null
+++ b/eval_mm/vqa_eval.py
@@ -0,0 +1,330 @@
+"""Copyright (c) 2022, salesforce.com, inc.
+
+All rights reserved.
+SPDX-License-Identifier: BSD-3-Clause
+For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+# coding=utf-8
+
+__author__ = 'aagrawal'
+
+import re
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
+import sys
+
+
+class VQAEval:
+
+ def __init__(self, vqa=None, vqaRes=None, n=2):
+ self.n = n
+ self.accuracy = {}
+ self.evalQA = {}
+ self.evalQuesType = {}
+ self.evalAnsType = {}
+ self.vqa = vqa
+ self.vqaRes = vqaRes
+ if vqa is not None:
+ self.params = {'question_id': vqa.getQuesIds()}
+ self.contractions = {
+ 'aint': "ain't",
+ 'arent': "aren't",
+ 'cant': "can't",
+ 'couldve': "could've",
+ 'couldnt': "couldn't",
+ "couldn'tve": "couldn't've",
+ "couldnt've": "couldn't've",
+ 'didnt': "didn't",
+ 'doesnt': "doesn't",
+ 'dont': "don't",
+ 'hadnt': "hadn't",
+ "hadnt've": "hadn't've",
+ "hadn'tve": "hadn't've",
+ 'hasnt': "hasn't",
+ 'havent': "haven't",
+ 'hed': "he'd",
+ "hed've": "he'd've",
+ "he'dve": "he'd've",
+ 'hes': "he's",
+ 'howd': "how'd",
+ 'howll': "how'll",
+ 'hows': "how's",
+ "Id've": "I'd've",
+ "I'dve": "I'd've",
+ 'Im': "I'm",
+ 'Ive': "I've",
+ 'isnt': "isn't",
+ 'itd': "it'd",
+ "itd've": "it'd've",
+ "it'dve": "it'd've",
+ 'itll': "it'll",
+ "let's": "let's",
+ 'maam': "ma'am",
+ 'mightnt': "mightn't",
+ "mightnt've": "mightn't've",
+ "mightn'tve": "mightn't've",
+ 'mightve': "might've",
+ 'mustnt': "mustn't",
+ 'mustve': "must've",
+ 'neednt': "needn't",
+ 'notve': "not've",
+ 'oclock': "o'clock",
+ 'oughtnt': "oughtn't",
+ "ow's'at": "'ow's'at",
+ "'ows'at": "'ow's'at",
+ "'ow'sat": "'ow's'at",
+ 'shant': "shan't",
+ "shed've": "she'd've",
+ "she'dve": "she'd've",
+ "she's": "she's",
+ 'shouldve': "should've",
+ 'shouldnt': "shouldn't",
+ "shouldnt've": "shouldn't've",
+ "shouldn'tve": "shouldn't've",
+ "somebody'd": 'somebodyd',
+ "somebodyd've": "somebody'd've",
+ "somebody'dve": "somebody'd've",
+ 'somebodyll': "somebody'll",
+ 'somebodys': "somebody's",
+ 'someoned': "someone'd",
+ "someoned've": "someone'd've",
+ "someone'dve": "someone'd've",
+ 'someonell': "someone'll",
+ 'someones': "someone's",
+ 'somethingd': "something'd",
+ "somethingd've": "something'd've",
+ "something'dve": "something'd've",
+ 'somethingll': "something'll",
+ 'thats': "that's",
+ 'thered': "there'd",
+ "thered've": "there'd've",
+ "there'dve": "there'd've",
+ 'therere': "there're",
+ 'theres': "there's",
+ 'theyd': "they'd",
+ "theyd've": "they'd've",
+ "they'dve": "they'd've",
+ 'theyll': "they'll",
+ 'theyre': "they're",
+ 'theyve': "they've",
+ 'twas': "'twas",
+ 'wasnt': "wasn't",
+ "wed've": "we'd've",
+ "we'dve": "we'd've",
+ 'weve': "we've",
+ 'werent': "weren't",
+ 'whatll': "what'll",
+ 'whatre': "what're",
+ 'whats': "what's",
+ 'whatve': "what've",
+ 'whens': "when's",
+ 'whered': "where'd",
+ 'wheres': "where's",
+ 'whereve': "where've",
+ 'whod': "who'd",
+ "whod've": "who'd've",
+ "who'dve": "who'd've",
+ 'wholl': "who'll",
+ 'whos': "who's",
+ 'whove': "who've",
+ 'whyll': "why'll",
+ 'whyre': "why're",
+ 'whys': "why's",
+ 'wont': "won't",
+ 'wouldve': "would've",
+ 'wouldnt': "wouldn't",
+ "wouldnt've": "wouldn't've",
+ "wouldn'tve": "wouldn't've",
+ 'yall': "y'all",
+ "yall'll": "y'all'll",
+ "y'allll": "y'all'll",
+ "yall'd've": "y'all'd've",
+ "y'alld've": "y'all'd've",
+ "y'all'dve": "y'all'd've",
+ 'youd': "you'd",
+ "youd've": "you'd've",
+ "you'dve": "you'd've",
+ 'youll': "you'll",
+ 'youre': "you're",
+ 'youve': "you've",
+ }
+ self.manualMap = {
+ 'none': '0',
+ 'zero': '0',
+ 'one': '1',
+ 'two': '2',
+ 'three': '3',
+ 'four': '4',
+ 'five': '5',
+ 'six': '6',
+ 'seven': '7',
+ 'eight': '8',
+ 'nine': '9',
+ 'ten': '10',
+ }
+ self.articles = ['a', 'an', 'the']
+
+ self.periodStrip = re.compile('(?!<=\d)(\.)(?!\d)')
+ self.commaStrip = re.compile('(\d)(,)(\d)')
+ self.punct = [
+ ';',
+ r'/',
+ '[',
+ ']',
+ '"',
+ '{',
+ '}',
+ '(',
+ ')',
+ '=',
+ '+',
+ '\\',
+ '_',
+ '-',
+ '>',
+ '<',
+ '@',
+ '`',
+ ',',
+ '?',
+ '!',
+ ]
+
+ def evaluate(self, quesIds=None):
+ if quesIds == None:
+ quesIds = [quesId for quesId in self.params['question_id']]
+ gts = {}
+ res = {}
+ for quesId in quesIds:
+ gts[quesId] = self.vqa.qa[quesId]
+ res[quesId] = self.vqaRes.qa[quesId]
+
+ # =================================================
+ # Compute accuracy
+ # =================================================
+ accQA = []
+ accQuesType = {}
+ accAnsType = {}
+ print('computing accuracy')
+ step = 0
+ for quesId in quesIds:
+ resAns = res[quesId]['answer']
+ resAns = resAns.replace('\n', ' ')
+ resAns = resAns.replace('\t', ' ')
+ resAns = resAns.strip()
+ resAns = self.processPunctuation(resAns)
+ resAns = self.processDigitArticle(resAns)
+ gtAcc = []
+ gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
+ if len(set(gtAnswers)) > 1:
+ for ansDic in gts[quesId]['answers']:
+ ansDic['answer'] = self.processPunctuation(
+ ansDic['answer'])
+ for gtAnsDatum in gts[quesId]['answers']:
+ otherGTAns = [
+ item for item in gts[quesId]['answers']
+ if item != gtAnsDatum
+ ]
+ matchingAns = [
+ item for item in otherGTAns if item['answer'] == resAns
+ ]
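+                # Standard VQA accuracy: an answer is credited min(#matching human answers / 3, 1),
+                # averaged over the leave-one-out subsets of the 10 ground-truth answers.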
+ acc = min(1, float(len(matchingAns)) / 3)
+ gtAcc.append(acc)
+ quesType = gts[quesId]['question_type']
+ ansType = gts[quesId]['answer_type']
+ avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
+ accQA.append(avgGTAcc)
+ if quesType not in accQuesType:
+ accQuesType[quesType] = []
+ accQuesType[quesType].append(avgGTAcc)
+ if ansType not in accAnsType:
+ accAnsType[ansType] = []
+ accAnsType[ansType].append(avgGTAcc)
+ self.setEvalQA(quesId, avgGTAcc)
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
+ if step % 100 == 0:
+ self.updateProgress(step / float(len(quesIds)))
+ step = step + 1
+
+ self.setAccuracy(accQA, accQuesType, accAnsType)
+ print('Done computing accuracy')
+
+ def processPunctuation(self, inText):
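+        # Punctuation normalization from the official VQA evaluation script: depending on
+        # context each punctuation mark is deleted or replaced with a space, and periods
+        # not followed by a digit are stripped afterwards.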
+ outText = inText
+ for p in self.punct:
+            if (p + ' ' in inText or ' ' + p in inText) \
+                    or (re.search(self.commaStrip, inText) is not None):
+ outText = outText.replace(p, '')
+ else:
+ outText = outText.replace(p, ' ')
+        outText = self.periodStrip.sub('', outText)
+ return outText
+
+ def processDigitArticle(self, inText):
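+        # Map spelled-out numbers to digits ("two" -> "2"), drop articles ("a", "an",
+        # "the"), and restore apostrophes in common contractions ("dont" -> "don't").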
+ outText = []
+ tempText = inText.lower().split()
+ for word in tempText:
+ word = self.manualMap.setdefault(word, word)
+ if word not in self.articles:
+ outText.append(word)
+ else:
+ pass
+ for wordId, word in enumerate(outText):
+ if word in self.contractions:
+ outText[wordId] = self.contractions[word]
+ outText = ' '.join(outText)
+ return outText
+
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
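+        # Aggregate overall, per-question-type, and per-answer-type accuracies as
+        # percentages rounded to self.n decimal places.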
+ self.accuracy['overall'] = round(100 * float(sum(accQA)) / len(accQA),
+ self.n)
+ self.accuracy['perQuestionType'] = {
+ quesType: round(
+ 100 * float(sum(accQuesType[quesType])) /
+ len(accQuesType[quesType]),
+ self.n,
+ )
+ for quesType in accQuesType
+ }
+ self.accuracy['perAnswerType'] = {
+ ansType: round(
+ 100 * float(sum(accAnsType[ansType])) /
+ len(accAnsType[ansType]), self.n)
+ for ansType in accAnsType
+ }
+
+ def setEvalQA(self, quesId, acc):
+ self.evalQA[quesId] = round(100 * acc, self.n)
+
+ def setEvalQuesType(self, quesId, quesType, acc):
+ if quesType not in self.evalQuesType:
+ self.evalQuesType[quesType] = {}
+ self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
+
+ def setEvalAnsType(self, quesId, ansType, acc):
+ if ansType not in self.evalAnsType:
+ self.evalAnsType[ansType] = {}
+ self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
+
+ def updateProgress(self, progress):
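+        # Render a simple 20-character text progress bar on stdout.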
+ barLength = 20
+ status = ''
+ if isinstance(progress, int):
+ progress = float(progress)
+ if not isinstance(progress, float):
+ progress = 0
+ status = 'error: progress var must be float\r\n'
+ if progress < 0:
+ progress = 0
+ status = 'Halt...\r\n'
+ if progress >= 1:
+ progress = 1
+ status = 'Done...\r\n'
+ block = int(round(barLength * progress))
+        text = '\rFinished Percent: [{0}] {1}% {2}'.format(
+ '#' * block + '-' * (barLength - block), int(progress * 100),
+ status)
+ sys.stdout.write(text)
+ sys.stdout.flush()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7f70a086db405864f510df8823e8cf250a9b2919
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+transformers==4.31.0
+accelerate
+tiktoken
+einops
+transformers_stream_generator==0.0.4
+scipy
+torchvision
+pillow
+tensorboard
+matplotlib
diff --git a/requirements_web_demo.txt b/requirements_web_demo.txt
new file mode 100644
index 0000000000000000000000000000000000000000..25aceddaba2623925a4c9f20f2bb00c4282b4db7
--- /dev/null
+++ b/requirements_web_demo.txt
@@ -0,0 +1 @@
+gradio
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..00ca9a06bc9ef7923ea4da7375fd282cf08892bd
--- /dev/null
+++ b/test.py
@@ -0,0 +1,38 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+import torch
+torch.manual_seed(1234)
+
+# Note: The default behavior now has injection attack prevention off.
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
+
+# use bf16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
+# use fp16
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
+# use cpu only
+# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cpu", trust_remote_code=True).eval()
+# use cuda device
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True).eval()
+
+# Specify hyperparameters for generation
+model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
+
+# 1st dialogue turn
+query = tokenizer.from_list_format([
+    {'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},  # Either a local path or a URL
+    {'text': '这是什么?'},  # "What is this?"
+])
+response, history = model.chat(tokenizer, query=query, history=None)
+print(response)
+# 图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。
+# ("The image shows a woman playing with a dog on the beach; the dog beside her is a Labrador, and they are on the beach.")
+
+# 2nd dialogue turn
+response, history = model.chat(tokenizer, '框出图中击掌的位置', history=history)  # "Box the high-five in the image"
+print(response)
+# 击掌 ("high five")
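+
+# Optionally visualize the grounding result. draw_bbox_on_latest_picture is the same
+# tokenizer helper used by the web demo; it returns None when no box is found in the response.
+image = tokenizer.draw_bbox_on_latest_picture(response, history)
+if image is not None:
+    image.save('1.jpg')
+else:
+    print("no box")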
+
+
+
+
+
+
+**TOUCHSTONE** is a comprehensive assessment of multimodal language models, covering not only basic recognition and comprehension but also literary creation. By converting multimodal information into text via human annotation, TouchStone automates the evaluation of dialogue quality, leveraging the power of advanced language models as judges without the need for manual scoring.
+
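+As a rough, hypothetical sketch of this idea (not the actual TouchStone prompts or scripts), the text handed to the judge model for a single item could be assembled like this:
+
+```python
+# Hypothetical sketch: the judge model never sees the image itself; a fine-grained
+# human annotation of the image stands in for it. Field names and wording here are
+# illustrative, not the real TouchStone prompt.
+def build_judge_prompt(annotation: str, question: str, answer_a: str, answer_b: str) -> str:
+    return (
+        "You are judging two assistants answering a question about an image.\n"
+        f"[Human image annotation] {annotation}\n"
+        f"[Question] {question}\n"
+        f"[Assistant A] {answer_a}\n"
+        f"[Assistant B] {answer_b}\n"
+        "Score each assistant from 1 to 10 and briefly justify the scores."
+    )
+```
+
+Because everything the judge sees is plain text, any sufficiently strong text-only language model can act as the grader.
+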
+## DATASET
+
+To evaluate the abilities of LVLMs, we construct a diverse and comprehensive dataset that covers five key dimensions: basic descriptive ability, visual recognition ability, visual comprehension ability, visual storytelling ability, and multi-image analysis ability.
+
+- **Basic Descriptive Ability** Image description involves the ability of a model to describe the information contained in an image, including simple and detailed descriptions. Simple descriptions are typically short phrases that describe the main subject and action of the image, while detailed descriptions provide more in-depth information about the image scene, the objects' attributes, and their relationships.
+
+- **Visual Recognition Ability** Image recognition is the task of recognizing objects or scenes within an image and inferring relevant information. This area can be further divided into several sub-tasks, including attribute QA, movie/TV recognition, art recognition, landmark recognition, celebrity recognition, emotion recognition, text recognition, object recognition, and structured content recognition.
+
+- **Visual Comprehension Ability** Image understanding involves the ability of a model to understand the meaning of an image and associated tasks. This area encompasses several sub-tasks, such as style appreciation, abstract image understanding, meme understanding, image analysis, chart analysis, general problem-solving, and reasoning QA.
+
+- **Visual Storytelling Ability** Visual storytelling is literary creation based on visual content, including writing emails, poetry, stories, ads/commodity recommendations, and brainstorming.
+
+- **Multi-Image Analysis Ability** Multi-image analysis is the task of analyzing and comparing multiple images. This area includes tasks such as comparing two/multiple images, summarizing multiple image information, comparing commodities, and step-by-step analysis of images.
+
+
+
+
+
+
+
+
+
+
+
+**TOUCHSTONE** 是一种针对多模态语言模型(LVLM)的自动化综合评估方法,评估不仅包括基本的认知和理解,还延伸到文学创作。通过人类注解将多模态信息转换为文本,我们的 TouchStone 可以利用SOTA的语言模型来自动化地完成对LVLMs的多模态对话质量评估。
+
+## 数据集
+
+为了评估 LVLMs 的能力,我们构建了一个多样化且全面的数据集,涵盖五个关键维度:基本描述能力、视觉识别能力、视觉理解能力、视觉叙事能力和多图分析能力。
+
+- **基本描述能力** 图像描述考验模型总结图片信息的能力,包括简单描述和详细描述。 简单描述通常是描述图像的主要内容和关系的简短短语,而详细描述则提供有关图像场景、其属性和关系的更深入的信息。
+
+- **视觉识别能力** 图像识别考察模型提取图像中内容的属性以及关联到知识库的能力。为了考察这方面能力,测试的问题包括属性QA、影视识别、艺术识别、地标识别、名人识别、情感识别、文本识别、物体识别和结构内容识别。
+
+- **视觉理解能力** 图像理解需要模型理解图像内容并完成推理进行相关任务。 这方面包含了例如风格欣赏、抽象图像理解、模因理解、图像分析、图表分析、一般问题解决和推理问答等任务。
+
+- **视觉叙事能力** 视觉叙事能力是基于视觉内容的文学创作能力,包括撰写电子邮件、诗歌、故事、广告/商品推荐、头脑风暴等。
+
+- **多图分析能力** 多图分析是分析和比较多幅图像的任务。该领域包括比较两个/多个图像、总结多个图像信息、比较商品以及逐步分析图像等任务。
+
+
+
+
"
+ else:
+ if i > 0:
+ if count % 2 == 1:
+ line = line.replace("`", r"\`")
+ line = line.replace("<", "<")
+ line = line.replace(">", ">")
+ line = line.replace(" ", " ")
+ line = line.replace("*", "*")
+ line = line.replace("_", "_")
+ line = line.replace("-", "-")
+ line = line.replace(".", ".")
+ line = line.replace("!", "!")
+ line = line.replace("(", "(")
+ line = line.replace(")", ")")
+ line = line.replace("$", "$")
+ lines[i] = "'
+ else:
+ lines[i] = f"
" + line
+ text = "".join(lines)
+ return text
+
+
+def _launch_demo(args, model, tokenizer):
+ uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
+ Path(tempfile.gettempdir()) / "gradio"
+ )
+
+ def predict(_chatbot, task_history):
+ query = task_history[-1][0]
+ print("User: " + _parse_text(query))
+ history_cp = copy.deepcopy(task_history)
+ full_response = ""
+
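+        # Flatten the UI history into the prompt format model.chat expects:
+        # each uploaded image becomes a "Picture N: <img>path</img>" line that is
+        # prepended to the text query that follows it.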
+ history_filter = []
+ pic_idx = 1
+ pre = ""
+ for i, (q, a) in enumerate(history_cp):
+ if isinstance(q, (tuple, list)):
+                q = f'Picture {pic_idx}: <img>{q[0]}</img>'
+ pre += q + '\n'
+ pic_idx += 1
+ else:
+ pre += q
+ history_filter.append((pre, a))
+ pre = ""
+ history, message = history_filter[:-1], history_filter[-1][0]
+ response, history = model.chat(tokenizer, message, history=history)
+ image = tokenizer.draw_bbox_on_latest_picture(response, history)
+ if image is not None:
+ temp_dir = secrets.token_hex(20)
+ temp_dir = Path(uploaded_file_dir) / temp_dir
+ temp_dir.mkdir(exist_ok=True, parents=True)
+ name = f"tmp{secrets.token_hex(5)}.jpg"
+ filename = temp_dir / name
+ image.save(str(filename))
+ _chatbot[-1] = (_parse_text(query), (str(filename),))
+            chat_response = response.replace("<ref>", "")
+            chat_response = chat_response.replace(r"</ref>", "")
+ chat_response = re.sub(BOX_TAG_PATTERN, "", chat_response)
+ if chat_response != "":
+ _chatbot.append((None, chat_response))
+ else:
+ _chatbot[-1] = (_parse_text(query), response)
+ full_response = _parse_text(response)
+
+ task_history[-1] = (query, full_response)
+ print("Qwen-VL-Chat: " + _parse_text(full_response))
+ return _chatbot
+
+ def regenerate(_chatbot, task_history):
+ if not task_history:
+ return _chatbot
+ item = task_history[-1]
+ if item[1] is None:
+ return _chatbot
+ task_history[-1] = (item[0], None)
+ chatbot_item = _chatbot.pop(-1)
+ if chatbot_item[0] is None:
+ _chatbot[-1] = (_chatbot[-1][0], None)
+ else:
+ _chatbot.append((chatbot_item[0], None))
+ return predict(_chatbot, task_history)
+
+ def add_text(history, task_history, text):
+ history = history + [(_parse_text(text), None)]
+ task_history = task_history + [(text, None)]
+ return history, task_history, ""
+
+ def add_file(history, task_history, file):
+ history = history + [((file.name,), None)]
+ task_history = task_history + [((file.name,), None)]
+ return history, task_history
+
+ def reset_user_input():
+ return gr.update(value="")
+
+ def reset_state(task_history):
+ task_history.clear()
+ return []
+
+ with gr.Blocks() as demo:
+ gr.Markdown("""\
+