adwardlee committed · Commit 4710e9b · verified · 1 Parent(s): 7a1a200

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +115 -1
README.md CHANGED
@@ -1,3 +1,117 @@
  ---
- license: mit
+ license: apache-2.0
+ pipeline_tag: Image-text
  ---
+
+
+ <p align="center">
+ <img src="logo_en.png" width="400"/>
+ </p>
+
+ <p align="center">
+ <b><font size="6">ImageGuard</font></b>
+ </p>
+
+ <div align="center">
+
+ [💻GitHub Repo](https://github.com/adwardlee/t2i_safety)
+
+ [Paper](https://arxiv.org/abs/)
+
+ </div>
+
+ **ImageGuard** is a vision-language model (VLM) based on [InternLM-XComposer2](https://github.com/InternLM/InternLM-XComposer) for advanced image safety evaluation.
+
+ ### Import from Transformers
+ ImageGuard works with transformers>=4.42.
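+
+ Before running the quickstart, it can help to confirm the environment. The check below is only an illustrative sketch; the stated requirement is transformers>=4.42, and a CUDA device is assumed because the example moves the model to GPU:
+
+ ```python
+ # Illustrative environment check; not part of the ImageGuard repo.
+ from importlib.metadata import version
+ from packaging.version import parse
+ import torch
+
+ assert parse(version('transformers')) >= parse('4.42.0'), 'ImageGuard needs transformers>=4.42'
+ assert torch.cuda.is_available(), 'the quickstart below runs the model on a CUDA device'
+ ```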
+
+ ## Quickstart
+ We provide a simple example to show how to use ImageGuard with 🤗 Transformers.
+ ```python
+ import os
+ import argparse
+
+ import torch
+ import yaml
+ from PIL import Image
+
+ from utils.img_utils import ImageProcessor
+ from utils.arguments import ModelArguments, DataArguments, EvalArguments, LoraArguments
+ from utils.model_utils import init_model
+ from utils.conv_utils import fair_query, safe_query
+
+
+ def load_yaml(cfg_path):
+     with open(cfg_path, 'r', encoding='utf-8') as f:
+         result = yaml.load(f.read(), Loader=yaml.FullLoader)
+     return result
+
+
+ def textprocess(safe=True):
+     # Build the evaluation prompt: safe_query targets toxicity and privacy,
+     # fair_query targets fairness.
+     if safe:
+         conversation = safe_query('Internlm')
+     else:
+         conversation = fair_query('Internlm')
+     return conversation
+
+
+ def model_init(
+         model_args: ModelArguments,
+         data_args: DataArguments,
+         training_args: EvalArguments,
+         lora_args: LoraArguments,
+         model_cfg):
+     model, tokenizer = init_model(model_args.model_name_or_path, training_args, data_args, lora_args, model_cfg)
+     model.cuda().eval().half()
+     model.tokenizer = tokenizer
+     return model
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--load_dir', required=False, type=str, default='lora/')
+     parser.add_argument('--base_model', type=str, default='internlm/internlm-xcomposer2-vl-7b')
+     args = parser.parse_args()
+
+     # Load the training / LoRA configuration shipped with the ImageGuard checkpoint.
+     load_dir = args.load_dir
+     config = load_yaml(os.path.join(load_dir, 'config.yaml'))
+     model_cfg = config['model_cfg']
+     data_cfg = config['data_cfg']['data_cfg']
+     model_cfg['model_name'] = 'Internlm'
+     data_cfg['train']['model_name'] = 'Internlm'
+     lora_cfg = config['lora_cfg']
+     training_cfg = config['training_cfg']
+
+     model_args = ModelArguments()
+     model_args.model_name_or_path = args.base_model
+
+     lora_args = LoraArguments()
+     lora_args.lora_alpha = lora_cfg['lora_alpha']
+     lora_args.lora_bias = lora_cfg['lora_bias']
+     lora_args.lora_dropout = lora_cfg['lora_dropout']
+     lora_args.lora_r = lora_cfg['lora_r']
+     lora_args.lora_target_modules = lora_cfg['lora_target_modules']
+     lora_args.lora_weight_path = load_dir  # comment out to test the base model without the ImageGuard LoRA weights
+
+     train_args = EvalArguments()
+     train_args.max_length = training_cfg['max_length']
+     train_args.fix_vit = training_cfg['fix_vit']
+     train_args.fix_sampler = training_cfg['fix_sampler']
+     train_args.use_lora = training_cfg['use_lora']
+     train_args.gradient_checkpointing = training_cfg['gradient_checkpointing']
+     data_args = DataArguments()
+
+     model = model_init(model_args, data_args, train_args, lora_args, model_cfg)
+     print('model device:', model.device, flush=True)
+
+     img = Image.open('punch.png')
+     safe = True  # True for toxicity and privacy, False for fairness
+     prompt = textprocess(safe=safe)
+     vis_processor = ImageProcessor(image_size=490)
+     image = vis_processor(img)[None, :, :, :]
+     with torch.cuda.amp.autocast():
+         response, _ = model.chat(model.tokenizer, prompt, image, history=[], do_sample=False, meta_instruction=None)
+     print(response)
+     # unsafe\n violence
+ ```
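+
+ The model replies in plain text with a verdict followed by a category; for the example image above it prints `unsafe` and `violence`. The helper below is only an illustrative sketch (not part of the repo or the model API) of how that reply could be turned into a structured result:
+
+ ```python
+ # Hypothetical helper: split ImageGuard's plain-text reply into (verdict, category).
+ def parse_response(response: str):
+     lines = [line.strip() for line in response.strip().splitlines() if line.strip()]
+     verdict = lines[0].lower() if lines else 'unknown'        # e.g. 'unsafe'
+     category = lines[1].lower() if len(lines) > 1 else None   # e.g. 'violence'
+     return verdict, category
+
+ print(parse_response('unsafe\n violence'))  # ('unsafe', 'violence')
+ ```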
+
+ ### Open Source License
+ The code is licensed under Apache-2.0, while the model weights are fully open for academic research and also allow free commercial usage.