ncoria committed
Commit c784516 · verified · 1 Parent(s): 92a7fcc

remove hf_token parameter

model is now public, no need for hf_token validation
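
With the checkpoint public, the loader no longer takes a token; the call site reduces to something like this sketch (assuming the file is importable as the module get_llava_response):

    from get_llava_response import REPO_NAME, load_llava_checkpoint_hf

    # previously: load_llava_checkpoint_hf(REPO_NAME, hf_token)
    tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME)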

Files changed (1)
  1. get_llava_response.py +184 -185
get_llava_response.py CHANGED
@@ -1,186 +1,185 @@
- import argparse
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
- import numpy as np
-
- from huggingface_hub import whoami
-
- import llava
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
- from llava.conversation import conv_templates, SeparatorStyle
- from llava.model.builder import load_pretrained_model
- from llava.utils import disable_torch_init
- from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
-
- from PIL import Image
-
- import requests
- from PIL import Image
- from io import BytesIO
- from transformers import TextStreamer
- from tqdm import tqdm
-
- import warnings
- warnings.filterwarnings('ignore')
-
- REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
-
- def load_image(image_file):
-     if image_file.startswith('http://') or image_file.startswith('https://'):
-         response = requests.get(image_file)
-         image = Image.open(BytesIO(response.content)).convert('RGB')
-     else:
-         image = Image.open(image_file).convert('RGB')
-     return image
-
- def load_llava_checkpoint(model_path: str):
-     model_name = get_model_name_from_path(model_path)
-     return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")
-
- def load_llava_checkpoint_hf(model_path, hf_token):
-     user = whoami(token=hf_token)
-     kwargs = {"device_map": "auto"}
-     kwargs['load_in_4bit'] = True
-     kwargs['quantization_config'] = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_compute_dtype=torch.float16,
-         bnb_4bit_use_double_quant=True,
-         bnb_4bit_quant_type='nf4'
-     )
-     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=hf_token)
-     model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, token=hf_token, **kwargs)
-     mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
-     mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
-     if mm_use_im_patch_token:
-         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
-     if mm_use_im_start_end:
-         tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
-     model.resize_token_embeddings(len(tokenizer))
-
-     vision_tower = model.get_vision_tower()
-     if not vision_tower.is_loaded:
-         vision_tower.load_model(device_map="auto")
-     image_processor = vision_tower.image_processor
-     return tokenizer, model, image_processor
-
- def get_llava_response(user_prompts: list[str],
-                        images: list,
-                        sys_prompt: str,
-                        tokenizer,
-                        model,
-                        image_processor,
-                        model_path = REPO_NAME,
-                        stream_output = True):
-     """
-     This function returns the response from the given model. It creates a one turn conversation in which
-     the only content is a system prompt and the given user message applied to each image.
-
-     Parameters:
-     ----------
-     user_prompt : str
-         The prompt sent by the user.
-     images : str
-         List of images from file.
-     sys_prompt : str
-         The prompt that sets the tone for the conversation.
-     model_path : str
-         The path to the merged checkpoint or base model.
-
-     Returns:
-     --------
-     """
-     # set up and load model
-     model_name = get_model_name_from_path(model_path)
-     temperature = 0.2 # default
-     max_new_tokens = 512 # default
-
-     # determine conversation type
-     if "llama-2" in model_name.lower():
-         conv_mode = "llava_llama_2"
-     elif "mistral" in model_name.lower():
-         conv_mode = "mistral_instruct"
-     elif "v1.6-34b" in model_name.lower():
-         conv_mode = "chatml_direct"
-     elif "v1" in model_name.lower():
-         conv_mode = "llava_v1"
-     elif "mpt" in model_name.lower():
-         conv_mode = "mpt"
-     else:
-         conv_mode = "llava_v0"
-
-     # run clean conversation for each image
-     llm_outputs = []
-     for i, img in tqdm(enumerate(images)):
-         # set up clean conversation
-         conv = conv_templates[conv_mode].copy()
-         if "mpt" in model_name.lower():
-             roles = ('user', 'assistant')
-         else:
-             roles = conv.roles
-
-         conv.system = sys_prompt
-
-         # load image
-         # image = load_image("../images/mouse.png") # previous method
-         if isinstance(img, np.ndarray) and len(img.shape) == 2:
-             img = Image.fromarray(img, 'L')
-         elif isinstance(img, np.ndarray):
-             img = Image.fromarray(img)
-
-         image = img.convert('RGB')
-         image_size = image.size
-
-         # NOTE: image is simply PIL Image (.convert('RGB')), no need for temp files!
-
-         # Similar operation in model_worker.py
-         image_tensor = process_images([image], image_processor, model.config)
-         if type(image_tensor) is list:
-             image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
-         else:
-             image_tensor = image_tensor.to(model.device, dtype=torch.float16)
-
-         # execute conversation
-         inp = user_prompts[i]
-         if image is not None:
-             # first message
-             if model.config.mm_use_im_start_end:
-                 inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
-             else:
-                 inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
-             image = None
-         conv.append_message(conv.roles[0], inp)
-         conv.append_message(conv.roles[1], None)
-         prompt = conv.get_prompt()
-         input_ids = tokenizer_image_token(prompt,
-                                           tokenizer,
-                                           IMAGE_TOKEN_INDEX,
-                                           return_tensors='pt').unsqueeze(0).to(model.device)
-         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
-         keywords = [stop_str]
-         if stream_output:
-             streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-         else:
-             streamer = None
-
-         with torch.inference_mode():
-             output_ids = model.generate(
-                 input_ids,
-                 images=image_tensor,
-                 image_sizes=[image_size],
-                 do_sample=True if temperature > 0 else False,
-                 temperature=temperature,
-                 max_new_tokens=max_new_tokens,
-                 streamer=streamer,
-                 use_cache=True)
-
-         outputs = tokenizer.decode(output_ids[0]).strip()
-         llm_outputs.append(outputs)
-     return llm_outputs
-
-
-
-
-
-
-
-
 
+ import argparse
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+ import numpy as np
+
+ from huggingface_hub import whoami
+
+ import llava
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN
+ from llava.conversation import conv_templates, SeparatorStyle
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
+
+ from PIL import Image
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from transformers import TextStreamer
+ from tqdm import tqdm
+
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ REPO_NAME = 'ncoria/llava-lora-vicuna-clip-5-epochs-merge'
+
+ def load_image(image_file):
+     if image_file.startswith('http://') or image_file.startswith('https://'):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert('RGB')
+     else:
+         image = Image.open(image_file).convert('RGB')
+     return image
+
+ def load_llava_checkpoint(model_path: str):
+     model_name = get_model_name_from_path(model_path)
+     return load_pretrained_model(model_path, None, model_name, load_4bit=True, device="cuda")
+
+ def load_llava_checkpoint_hf(model_path):
+     kwargs = {"device_map": "auto"}
+     kwargs['load_in_4bit'] = True
+     kwargs['quantization_config'] = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.float16,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type='nf4'
+     )
+     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+     model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+     mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+     mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+     if mm_use_im_patch_token:
+         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+     if mm_use_im_start_end:
+         tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+     model.resize_token_embeddings(len(tokenizer))
+
+     vision_tower = model.get_vision_tower()
+     if not vision_tower.is_loaded:
+         vision_tower.load_model(device_map="auto")
+     image_processor = vision_tower.image_processor
+     return tokenizer, model, image_processor
+
+ def get_llava_response(user_prompts: list[str],
+                        images: list,
+                        sys_prompt: str,
+                        tokenizer,
+                        model,
+                        image_processor,
+                        model_path = REPO_NAME,
+                        stream_output = True):
+     """
+     This function returns the response from the given model. It creates a one turn conversation in which
+     the only content is a system prompt and the given user message applied to each image.
+
+     Parameters:
+     ----------
+     user_prompt : str
+         The prompt sent by the user.
+     images : str
+         List of images from file.
+     sys_prompt : str
+         The prompt that sets the tone for the conversation.
+     model_path : str
+         The path to the merged checkpoint or base model.
+
+     Returns:
+     --------
+     """
+     # set up and load model
+     model_name = get_model_name_from_path(model_path)
+     temperature = 0.2 # default
+     max_new_tokens = 512 # default
+
+     # determine conversation type
+     if "llama-2" in model_name.lower():
+         conv_mode = "llava_llama_2"
+     elif "mistral" in model_name.lower():
+         conv_mode = "mistral_instruct"
+     elif "v1.6-34b" in model_name.lower():
+         conv_mode = "chatml_direct"
+     elif "v1" in model_name.lower():
+         conv_mode = "llava_v1"
+     elif "mpt" in model_name.lower():
+         conv_mode = "mpt"
+     else:
+         conv_mode = "llava_v0"
+
+     # run clean conversation for each image
+     llm_outputs = []
+     for i, img in tqdm(enumerate(images)):
+         # set up clean conversation
+         conv = conv_templates[conv_mode].copy()
+         if "mpt" in model_name.lower():
+             roles = ('user', 'assistant')
+         else:
+             roles = conv.roles
+
+         conv.system = sys_prompt
+
+         # load image
+         # image = load_image("../images/mouse.png") # previous method
+         if isinstance(img, np.ndarray) and len(img.shape) == 2:
+             img = Image.fromarray(img, 'L')
+         elif isinstance(img, np.ndarray):
+             img = Image.fromarray(img)
+
+         image = img.convert('RGB')
+         image_size = image.size
+
+         # NOTE: image is simply PIL Image (.convert('RGB')), no need for temp files!
+
+         # Similar operation in model_worker.py
+         image_tensor = process_images([image], image_processor, model.config)
+         if type(image_tensor) is list:
+             image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
+         else:
+             image_tensor = image_tensor.to(model.device, dtype=torch.float16)
+
+         # execute conversation
+         inp = user_prompts[i]
+         if image is not None:
+             # first message
+             if model.config.mm_use_im_start_end:
+                 inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
+             else:
+                 inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
+             image = None
+         conv.append_message(conv.roles[0], inp)
+         conv.append_message(conv.roles[1], None)
+         prompt = conv.get_prompt()
+         input_ids = tokenizer_image_token(prompt,
+                                           tokenizer,
+                                           IMAGE_TOKEN_INDEX,
+                                           return_tensors='pt').unsqueeze(0).to(model.device)
+         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+         keywords = [stop_str]
+         if stream_output:
+             streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+         else:
+             streamer = None
+
+         with torch.inference_mode():
+             output_ids = model.generate(
+                 input_ids,
+                 images=image_tensor,
+                 image_sizes=[image_size],
+                 do_sample=True if temperature > 0 else False,
+                 temperature=temperature,
+                 max_new_tokens=max_new_tokens,
+                 streamer=streamer,
+                 use_cache=True)
+
+         outputs = tokenizer.decode(output_ids[0]).strip()
+         llm_outputs.append(outputs)
+     return llm_outputs
+
+
+
+
+
+
+
+
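
For reference, a minimal end-to-end sketch of how the updated functions in this file fit together after the change (hypothetical image path and prompts; assumes the llava package, bitsandbytes, and a CUDA device are available):

    from PIL import Image
    from get_llava_response import REPO_NAME, load_llava_checkpoint_hf, get_llava_response

    # load the public merged checkpoint in 4-bit; no Hugging Face token is needed anymore
    tokenizer, model, image_processor = load_llava_checkpoint_hf(REPO_NAME)

    # one prompt per image; "frame.png" is a placeholder path
    images = [Image.open("frame.png")]
    prompts = ["Describe what the animal in this frame is doing."]

    responses = get_llava_response(prompts,
                                   images,
                                   sys_prompt="You are an expert behavior annotator.",
                                   tokenizer=tokenizer,
                                   model=model,
                                   image_processor=image_processor,
                                   stream_output=False)
    print(responses[0])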