---
license: apache-2.0
pipeline_tag: image-text-to-text
---

<p align="center">
 <img src="logo_en.png" width="400"/>
</p>

<p align="center">
 <b><font size="6">InternLM-XComposer2</font></b>
</p>

<div align="center">

[💻Github Repo](https://github.com/adwardlee/t2i_safety)

[Paper](https://arxiv.org/abs/)

</div>

**ImageGuard** is a vision-language model (VLM) based on [InternLM-XComposer2](https://github.com/InternLM/InternLM-XComposer) for advanced image safety evaluation.

### Import from Transformers
ImageGuard works with transformers>=4.42.
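If you want to fail fast on an incompatible environment, a minimal version check like the sketch below can help (it assumes the `packaging` library is installed alongside transformers):

```python
# Minimal sanity check for the transformers version.
# Assumes the `packaging` library is available (it ships with pip).
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.42.0"), \
    "ImageGuard requires transformers>=4.42"
```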

## Quickstart
We provide a simple example to show how to use ImageGuard with 🤗 Transformers.
```python
import os
import argparse

import torch
import yaml
from PIL import Image

from utils.img_utils import ImageProcessor
from utils.arguments import ModelArguments, DataArguments, EvalArguments, LoraArguments
from utils.model_utils import init_model
from utils.conv_utils import fair_query, safe_query


def load_yaml(cfg_path):
    with open(cfg_path, 'r', encoding='utf-8') as f:
        result = yaml.load(f.read(), Loader=yaml.FullLoader)
    return result


def textprocess(safe=True):
    # Build the evaluation prompt: the safety query covers toxicity and
    # privacy, the fairness query covers bias-related categories.
    if safe:
        conversation = safe_query('Internlm')
    else:
        conversation = fair_query('Internlm')
    return conversation


def model_init(
        model_args: ModelArguments,
        data_args: DataArguments,
        training_args: EvalArguments,
        lora_args: LoraArguments,
        model_cfg):
    model, tokenizer = init_model(model_args.model_name_or_path, training_args, data_args, lora_args, model_cfg)
    # Move to GPU, switch to inference mode, and run in half precision.
    model.cuda().eval().half()
    model.tokenizer = tokenizer
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_dir', required=False, type=str, default='lora/')
    parser.add_argument('--base_model', type=str, default='internlm/internlm-xcomposer2-vl-7b')
    args = parser.parse_args()

    # Load the training-time configuration shipped with the LoRA checkpoint.
    load_dir = args.load_dir
    config = load_yaml(os.path.join(load_dir, 'config.yaml'))
    model_cfg = config['model_cfg']
    data_cfg = config['data_cfg']['data_cfg']
    model_cfg['model_name'] = 'Internlm'
    data_cfg['train']['model_name'] = 'Internlm'
    lora_cfg = config['lora_cfg']
    training_cfg = config['training_cfg']

    model_args = ModelArguments()
    model_args.model_name_or_path = args.base_model
    lora_args = LoraArguments()
    lora_args.lora_alpha = lora_cfg['lora_alpha']
    lora_args.lora_bias = lora_cfg['lora_bias']
    lora_args.lora_dropout = lora_cfg['lora_dropout']
    lora_args.lora_r = lora_cfg['lora_r']
    lora_args.lora_target_modules = lora_cfg['lora_target_modules']
    lora_args.lora_weight_path = load_dir  # comment out this line to test the base model without LoRA weights
    train_args = EvalArguments()
    train_args.max_length = training_cfg['max_length']
    train_args.fix_vit = training_cfg['fix_vit']
    train_args.fix_sampler = training_cfg['fix_sampler']
    train_args.use_lora = training_cfg['use_lora']
    train_args.gradient_checkpointing = training_cfg['gradient_checkpointing']
    data_args = DataArguments()

    model = model_init(model_args, data_args, train_args, lora_args, model_cfg)
    print('model device:', model.device, flush=True)

    # Evaluate a single image.
    img = Image.open('punch.png')
    safe = True  # True for toxicity and privacy, False for fairness
    prompt = textprocess(safe=safe)
    vis_processor = ImageProcessor(image_size=490)
    image = vis_processor(img)[None, :, :, :]
    with torch.cuda.amp.autocast():
        response, _ = model.chat(model.tokenizer, prompt, image, history=[], do_sample=False, meta_instruction=None)
    print(response)
    # Expected output: unsafe\n violence
```
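The response is a short text verdict. If you want a structured result, a small helper like the hypothetical sketch below can split it into a verdict and a category; note that only the `unsafe\n violence` format shown above is confirmed, so anything beyond it is an assumption:

```python
# Hypothetical helper: split ImageGuard's text response into a verdict and an
# optional category. Only the "unsafe\n violence" example above is confirmed;
# the general format is an assumption.
def parse_response(response: str):
    lines = [line.strip() for line in response.strip().splitlines() if line.strip()]
    verdict = lines[0] if lines else None            # e.g. 'safe' or 'unsafe'
    category = lines[1] if len(lines) > 1 else None  # e.g. 'violence'
    return verdict, category

print(parse_response('unsafe\n violence'))  # ('unsafe', 'violence')
```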

### Open Source License
The code is licensed under Apache-2.0, while the model weights are fully open for academic research and also allow free commercial use.