Spaces: Runtime error

marianna13 committed on
Commit · ef52c23 · 1 Parent(s): 230a504

added phi

Browse files
- app.py +3 -1
- llava/__init__.py +1 -0
- llava/conversation.py +77 -2
- llava/eval/eval_science_qa.py +30 -13
- llava/eval/model_vqa_science.py +20 -9
- llava/model/__init__.py +3 -1
- llava/model/apply_delta.py +2 -1
- llava/model/builder.py +12 -8
- llava/model/language_model/llava_llama.py +2 -2
- llava/model/multimodal_encoder/builder.py +1 -1
- llava/model/multimodal_encoder/clip_encoder.py +29 -2
- llava/train/train.py +157 -7
- llava/train/train_mem.py +2 -2
- requirements.txt +1 -1
app.py
CHANGED
@@ -215,6 +215,8 @@ def http_bot(
         template_name = "mpt_text"
     elif "llama-2" in model_name:
         template_name = "llama_2"
+    elif "phi" in model_name:
+        template_name = "phi"
     else:
         template_name = "vicuna_v1"
     new_state = conv_templates[template_name].copy()
@@ -604,7 +606,7 @@ if __name__ == "__main__":
    args = get_args()
    logger.info(f"args: {args}")

-   model_path = "
+   model_path = "marianna13/llava-phi-2-3b"
    bits = int(os.getenv("bits", 8))

    controller_proc = start_controller()
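With the default model_path now pointing at marianna13/llava-phi-2-3b, the model name contains "phi", so the new branch above routes requests to the "phi" conversation template. Below is a minimal sketch of that routing, assuming conv_templates from llava/conversation.py; pick_template is a hypothetical helper (the "mpt" condition is assumed, since the diff only shows its body), not code from this commit.

    from llava.conversation import conv_templates

    def pick_template(model_name: str) -> str:
        # Mirrors the if/elif chain added in http_bot above (illustrative only).
        if "mpt" in model_name:
            return "mpt_text"
        elif "llama-2" in model_name:
            return "llama_2"
        elif "phi" in model_name:
            return "phi"
        return "vicuna_v1"

    state = conv_templates[pick_template("marianna13/llava-phi-2-3b")].copy()
    print(state.sep_style)  # expected: SeparatorStyle.PHI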
llava/__init__.py
CHANGED
@@ -1 +1,2 @@
 from .model import LlavaLlamaForCausalLM
+from .model import LlavaMistralForCausalLM
llava/conversation.py
CHANGED
@@ -10,6 +10,7 @@ class SeparatorStyle(Enum):
     MPT = auto()
     PLAIN = auto()
     LLAMA_2 = auto()
+    PHI = auto()
 
 
 @dataclasses.dataclass
@@ -72,6 +73,28 @@ class Conversation:
             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
             ret = ""
 
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i == 0: message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+
+        elif self.sep_style == SeparatorStyle.PHI:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+            wrap_inst = lambda msg: f"Instruct: {msg} \nOutput:"
+            ret = ""
+
             for i, (role, message) in enumerate(messages):
                 if i == 0:
                     assert message, "first message should not be none"
@@ -261,6 +284,30 @@ conv_vicuna_v1 = Conversation(
     sep2="</s>",
 )
 
+phi = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v0",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.PHI,
+    sep="<|endoftext|>",
+    sep2="<|endoftext|>",
+)
+
+conv_phi = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v2",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep="<|endoftext|>",
+    sep2="<|endoftext|>",
+)
+
 conv_llama_2 = Conversation(
     system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
 
@@ -287,6 +334,19 @@ conv_llava_llama_2 = Conversation(
     sep2="</s>",
 )
 
+llava_phi = Conversation(
+    system="You are a helpful language and vision assistant. "
+           "You are able to understand the visual content that the user provides, "
+           "and assist the user with a variety of tasks using natural language.",
+    roles=("USER", "ASSISTANT"),
+    version="llava_phi",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep="<|endoftext|>",
+    sep2="<|endoftext|>",
+)
+
 conv_mpt = Conversation(
     system="""<|im_start|>system
 A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
@@ -344,6 +404,18 @@ conv_llava_v1 = Conversation(
     sep2="</s>",
 )
 
+conv_mistral_v1 = Conversation(
+    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+    roles=("user", "assistant"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
 conv_llava_v1_mmtag = Conversation(
     system="A chat between a curious user and an artificial intelligence assistant. "
            "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
@@ -364,7 +436,7 @@ conv_templates = {
     "v1": conv_vicuna_v1,
     "vicuna_v1": conv_vicuna_v1,
     "llama_2": conv_llama_2,
-
+    "mistral": conv_llama_2,
     "plain": conv_llava_plain,
     "v0_plain": conv_llava_plain,
     "llava_v0": conv_llava_v0,
@@ -372,7 +444,10 @@
     "llava_v1": conv_llava_v1,
     "v1_mmtag": conv_llava_v1_mmtag,
     "llava_llama_2": conv_llava_llama_2,
-
+    "conv_mistral_v1": conv_mistral_v1,
+    "llava_phi": llava_phi,
+    "conv_phi": conv_phi,
+    "phi": phi,
     "mpt": conv_mpt,
 }
 
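For reference, here is roughly what the new SeparatorStyle.PHI branch renders for a single user turn of the "phi" template above (sep and sep2 are both <|endoftext|>, which the final lstrip strips from the front). This is a hand-traced sketch of get_prompt's behaviour, not code from the commit, and the exact whitespace may differ slightly.

    from llava.conversation import conv_templates

    conv = conv_templates["phi"].copy()
    conv.append_message(conv.roles[0], "What is shown in the image?")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    # Approximate shape of the rendered prompt: the system text is wrapped in
    # <<SYS>> markers and the user turn is wrapped as "Instruct: ... \nOutput:":
    #   Instruct: <<SYS>>
    #   A chat between a curious user and an artificial intelligence assistant. ...
    #   <</SYS>>
    #
    #   What is shown in the image? 
    #   Output:
    print(prompt)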
llava/eval/eval_science_qa.py
CHANGED
@@ -13,6 +13,8 @@ def get_args():
     parser.add_argument('--output-result', type=str)
     parser.add_argument('--split', type=str, default='test')
     parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
+    parser.add_argument('--pid-splits-path')
+    parser.add_argument('--problems-path')
     return parser.parse_args()
 
 
@@ -39,8 +41,8 @@ if __name__ == "__main__":
     args = get_args()
 
     base_dir = args.base_dir
-    split_indices = json.load(open(
-    problems = json.load(open(
+    split_indices = json.load(open(args.pid_splits_path))[args.split]
+    problems = json.load(open(args.problems_path))
     predictions = [json.loads(line) for line in open(args.result_file)]
     predictions = {pred['question_id']: pred for pred in predictions}
     split_problems = {idx: problems[idx] for idx in split_indices}
@@ -54,18 +56,26 @@ if __name__ == "__main__":
     sqa_results['outputs'] = {}
 
     for prob_id, prob in split_problems.items():
+        # prob_id = f'{args.split}/{prob_id}'
         if prob_id not in predictions:
-
-
-        pred_text = pred['text']
-
-        pattern = re.compile(r'The answer is ([A-Z]).')
-        res = pattern.findall(pred_text)
-        if len(res) == 1:
-            answer = res[0]  # 'A', 'B', ...
+            pred = {'text': 'FAILED', 'prompt': 'Unknown'}
+            pred_text = 'FAILED'
         else:
-
+            pred = predictions[prob_id]
+            pred_text = pred['text']
 
+        if pred_text in args.options:
+            answer = pred_text
+        elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
+            answer = pred_text[0]
+        else:
+            pattern = re.compile(r'The answer is ([A-Z])')
+            res = pattern.findall(pred_text)
+            if len(res) == 1:
+                answer = res[0]  # 'A', 'B', ...
+            else:
+                answer = "FAILED"
+
         pred_idx = get_pred_idx(answer, prob['choices'], args.options)
 
         analysis = {
@@ -85,9 +95,16 @@ if __name__ == "__main__":
         else:
             results['incorrect'].append(analysis)
 
+
     correct = len(results['correct'])
     total = len(results['correct']) + len(results['incorrect'])
-
+    ###### IMG ######
+    multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
+    multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
+    multimodal_total = multimodal_correct + multimodal_incorrect
+    ###### IMG ######
+
+    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
 
     sqa_results['acc'] = correct / total * 100
     sqa_results['correct'] = correct
@@ -96,4 +113,4 @@ if __name__ == "__main__":
     with open(args.output_file, 'w') as f:
         json.dump(results, f, indent=2)
     with open(args.output_result, 'w') as f:
-        json.dump(sqa_results, f, indent=2)
+        json.dump(sqa_results, f, indent=2)
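The edited loop above now accepts a bare option letter, an "A. ..." style answer, or the older "The answer is X" phrasing before falling back to FAILED. A small standalone sketch of that extraction order (the function name and default options are illustrative, not part of the commit):

    import re

    def extract_answer(pred_text, options=("A", "B", "C", "D", "E")):
        if pred_text in options:                      # bare letter, e.g. "B"
            return pred_text
        if len(pred_text) >= 3 and pred_text[0] in options and pred_text[1:3] == ". ":
            return pred_text[0]                       # e.g. "B. Because ..."
        res = re.findall(r'The answer is ([A-Z])', pred_text)
        return res[0] if len(res) == 1 else "FAILED"

    print(extract_answer("B. Water expands when it freezes."))  # -> B
    print(extract_answer("The answer is C."))                   # -> C
    print(extract_answer("I am not sure."))                     # -> FAILED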
llava/eval/model_vqa_science.py
CHANGED
@@ -57,6 +57,10 @@ def eval_model(args):
        else:
            images = None

+       if args.single_pred_prompt:
+           qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
+           cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
+
        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
@@ -64,19 +68,22 @@ def eval_model(args):

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

-       stop_str = conv.
+       stop_str = conv.sep2
+       stop_str = "\n" if "phi" in model_name else stop_str
        keywords = [stop_str]
-       stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)]
+       stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)]
+       eos_token_id = tokenizer.eos_token_id

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images,
-               do_sample=True,
-               temperature=
+               do_sample=True if args.temperature > 0 else False,
+               temperature=args.temperature,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=stopping_criteria,
+               # eos_token_id=eos_token_id
            )

        input_token_len = input_ids.shape[1]
@@ -88,7 +95,9 @@ def eval_model(args):
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
-
+       # outputs = outputs.replace("\n<</SYS>>", "")
+       # print("question:\n", cur_prompt)
+       # print("answer:\n", outputs)
        # prompt for answer
        if args.answer_prompter:
            outputs_reasoning = outputs
@@ -98,11 +107,11 @@ def eval_model(args):
                output_ids = model.generate(
                    input_ids,
                    images=images,
-                   do_sample=True,
-                   temperature=
+                   do_sample=True if args.temperature > 0 else False,
+                   temperature=args.temperature,
                    max_new_tokens=64,
                    use_cache=True,
-                   stopping_criteria=
+                   stopping_criteria=stopping_criteria)

            input_token_len = input_ids.shape[1]
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
@@ -135,7 +144,9 @@ if __name__ == "__main__":
    parser.add_argument("--conv-mode", type=str, default="llava_v0")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
+   parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--answer-prompter", action="store_true")
+   parser.add_argument("--single-pred-prompt", action="store_true")
    args = parser.parse_args()

-   eval_model(args)
+   eval_model(args)
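Two behavioural changes above are worth calling out: sampling is only enabled when --temperature is greater than zero, and for phi checkpoints the stop string becomes a plain newline instead of conv.sep2 (which is "<|endoftext|>" for the phi templates). A small illustration of both decisions; the helper name is made up and not part of the commit.

    def generation_settings(model_name: str, conv_sep2: str, temperature: float):
        # Mirrors the stop-string and sampling logic added above (illustrative only).
        stop_str = "\n" if "phi" in model_name else conv_sep2
        do_sample = True if temperature > 0 else False
        return {"stop_str": stop_str, "do_sample": do_sample, "temperature": temperature}

    print(generation_settings("llava-phi-2-3b", "<|endoftext|>", 0.0))
    # -> {'stop_str': '\n', 'do_sample': False, 'temperature': 0.0}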
llava/model/__init__.py
CHANGED
@@ -1,2 +1,4 @@
 from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
-from .language_model.
+from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaConfig
+from .language_model.llava_phi import LlavaPhiForCausalLM, LlavaConfig
+# from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
llava/model/apply_delta.py
CHANGED
@@ -8,6 +8,7 @@ import torch
 from tqdm import tqdm
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from llava import LlavaLlamaForCausalLM
+from llava import LlavaMistralForCausalLM
 
 
 def apply_delta(base_model_path, target_model_path, delta_path):
@@ -16,7 +17,7 @@ def apply_delta(base_model_path, target_model_path, delta_path):
         base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
 
     print("Loading delta")
-    delta =
+    delta = LlavaMistralForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
     delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
 
     print("Applying delta")
llava/model/builder.py
CHANGED
@@ -25,7 +25,6 @@ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, D
 
 def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
     kwargs = {"device_map": device_map}
-    kwargs["offload_folder"] = "offload"
 
     if load_8bit:
         kwargs['load_in_8bit'] = True
@@ -48,7 +47,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
             lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
             print('Loading LLaVA from base model...')
-            model =
+            model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
             if model.lm_head.weight.shape[0] != token_num:
                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
@@ -90,18 +89,24 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
             else:
                 tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                 cfg_pretrained = AutoConfig.from_pretrained(model_path)
-                model =
+                model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
 
             mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
             mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
             model.load_state_dict(mm_projector_weights, strict=False)
         else:
-            if '
+            if 'phi' in model_name.lower():
+                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+                model = LlavaPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+            elif 'mpt' in model_name.lower():
                 tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                 model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+            elif 'phi' in model_name.lower():
+                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+                model = LlavaPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
             else:
                 tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-                model =
+                model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
     else:
         # Load language model
         if model_base is not None:
@@ -138,14 +143,13 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
     vision_tower = model.get_vision_tower()
     if not vision_tower.is_loaded:
         vision_tower.load_model()
-
-
-    vision_tower.to(device=model.device, dtype=torch.float16)
+    vision_tower.to(device='cuda', dtype=torch.float16)
     image_processor = vision_tower.image_processor
 
     if hasattr(model.config, "max_sequence_length"):
         context_len = model.config.max_sequence_length
     else:
         context_len = 2048
+    print("model loaded", model)
 
     return tokenizer, model, image_processor, context_len
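The non-LoRA, non-projector path now dispatches on the model name: phi checkpoints load LlavaPhiForCausalLM, mpt checkpoints load LlavaMPTForCausalLM, and everything else falls back to LlavaMistralForCausalLM. A condensed sketch of that order follows (helper name is illustrative); note that the second 'phi' branch in the diff can never fire, because the first branch already matches any name containing "phi".

    def pick_model_class(model_name: str) -> str:
        # Condensed view of the dispatch in load_pretrained_model above;
        # returns class names as strings for illustration only.
        name = model_name.lower()
        if "phi" in name:
            return "LlavaPhiForCausalLM"
        elif "mpt" in name:
            return "LlavaMPTForCausalLM"
        elif "phi" in name:          # unreachable duplicate of the first branch
            return "LlavaPhiForCausalLM"
        else:
            return "LlavaMistralForCausalLM"

    print(pick_model_class("llava-phi-2-3b"))   # -> LlavaPhiForCausalLM
    print(pick_model_class("llava-v1.5-13b"))   # -> LlavaMistralForCausalLM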
llava/model/language_model/llava_llama.py
CHANGED
@@ -28,7 +28,7 @@ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
 
 
 class LlavaConfig(LlamaConfig):
-    model_type = "
+    model_type = "bakllava"
 
 
 class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
@@ -136,5 +136,5 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
         )
         return model_inputs
 
-AutoConfig.register("
+AutoConfig.register("bakllava", LlavaConfig)
 AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
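The model_type string and the AutoConfig registration are changed together; both must also match the "model_type" field in a checkpoint's config.json, otherwise AutoModelForCausalLM cannot resolve the class. A minimal sketch of that registration pattern with placeholder classes (not the classes from this repo):

    from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM

    class ToyLlavaConfig(LlamaConfig):
        model_type = "bakllava"          # must equal the name registered below

    class ToyLlavaForCausalLM(LlamaForCausalLM):
        config_class = ToyLlavaConfig

    AutoConfig.register("bakllava", ToyLlavaConfig)
    AutoModelForCausalLM.register(ToyLlavaConfig, ToyLlavaForCausalLM)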
llava/model/multimodal_encoder/builder.py
CHANGED
@@ -5,7 +5,7 @@ from .clip_encoder import CLIPVisionTower
 def build_vision_tower(vision_tower_cfg, **kwargs):
     vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
     is_absolute_path_exists = os.path.exists(vision_tower)
-    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
+    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or vision_tower.startswith("apple"):
         return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
 
     raise ValueError(f'Unknown vision tower: {vision_tower}')
llava/model/multimodal_encoder/clip_encoder.py
CHANGED
@@ -2,6 +2,30 @@ import torch
 import torch.nn as nn
 
 from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+from huggingface_hub import hf_hub_download
+import json
+
+def get_open_clip_image_processor(model_name):
+    config_path = hf_hub_download(model_name, filename="open_clip_config.json")
+
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+    image_size = config['model_cfg']['vision_cfg']['image_size']
+    image_mean = config['preprocess_cfg']['mean']
+    image_std = config['preprocess_cfg']['std']
+    size = {"shortest_edge": image_size}
+    crop_size = {
+        "height": image_size,
+        "width": image_size
+    }
+
+    return CLIPImageProcessor(
+        image_size=image_size,
+        image_mean=image_mean,
+        image_std=image_std,
+        crop_size=crop_size,
+        size=size
+    )
 
 
 class CLIPVisionTower(nn.Module):
@@ -13,15 +37,18 @@ class CLIPVisionTower(nn.Module):
         self.vision_tower_name = vision_tower
         self.select_layer = args.mm_vision_select_layer
         self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
-
         if not delay_load:
             self.load_model()
         else:
             self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
 
     def load_model(self):
-        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
         self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
+        if self.vision_tower_name.startswith("apple") or self.vision_tower_name.startswith("laion"):
+            self.image_processor = get_open_clip_image_processor(self.vision_tower_name)
+        else:
+            self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+
         self.vision_tower.requires_grad_(False)
 
         self.is_loaded = True
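get_open_clip_image_processor builds a transformers CLIPImageProcessor from the open_clip_config.json that OpenCLIP-style repositories (such as the apple/ and laion/ vision towers enabled in builder.py) publish on the Hub. A hedged usage sketch; the repository id is illustrative and must actually host an open_clip_config.json for hf_hub_download to succeed.

    from PIL import Image
    # assumes the module layout of this repo
    from llava.model.multimodal_encoder.clip_encoder import get_open_clip_image_processor

    processor = get_open_clip_image_processor("apple/DFN5B-CLIP-ViT-H-14")  # illustrative repo id
    pixels = processor.preprocess(Image.new("RGB", (512, 384)), return_tensors="pt")["pixel_values"]
    print(pixels.shape)  # (1, 3, image_size, image_size), with image_size read from the config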
llava/train/train.py
CHANGED
@@ -20,7 +20,7 @@ from dataclasses import dataclass, field
 import json
 import logging
 import pathlib
-from typing import Dict, Optional, Sequence, List
+from typing import Dict, Optional, Sequence, List, Union
 
 import torch
 
@@ -35,7 +35,14 @@ from llava.model import *
 from llava.mm_utils import tokenizer_image_token
 
 from PIL import Image
-
+import webdataset as wds
+import io
+import deepspeed
+import time
+from deepspeed.accelerator import get_accelerator
+import zipfile
+from webdataset_utils import get_wds_data
+import math
 
 local_rank = None
 
@@ -62,17 +69,30 @@ class ModelArguments:
 
 @dataclass
 class DataArguments:
-
+    dataset_type: str = "webdataset"
+    dataset_resampled: bool = False
+    lengths_path: Optional[str] = None
+    data_path: Union[List[str], str] = field(default=None,
                                              metadata={"help": "Path to the training data."})
     lazy_preprocess: bool = False
     is_multimodal: bool = False
     image_folder: Optional[str] = field(default=None)
     image_aspect_ratio: str = 'square'
     image_grid_pinpoints: Optional[str] = field(default=None)
+    train_data_weights: Optional[List[str]] = None
+    # dataloader_num_workers: Optional[int] = None
+    # seed: int = 0
 
 
 @dataclass
 class TrainingArguments(transformers.TrainingArguments):
+    num_training_samples: int = field(default=None)
+    resume_from_checkpoint: bool = False
+    deepspeed_config: str = field(default=None)
+    lr: float = field(default=1e-3)
+    beta1: float = field(default=0.5)
+    beta2: float = field(default=0.999)
+    num_train_epochs: int = field(default=1)
     cache_dir: Optional[str] = field(default=None)
     optim: str = field(default="adamw_torch")
     remove_unused_columns: bool = field(default=False)
@@ -85,6 +105,11 @@ class TrainingArguments(transformers.TrainingArguments):
             "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
         },
     )
+    dispatch_batches: bool = field(default=None)
+    pin_memory: bool = field(default=False)
+    resume: Optional[str] = field(default=None)
+
+    # train_num_samples: int = field()
     double_quant: bool = field(
         default=True,
         metadata={"help": "Compress the quantization statistics through double quantization."}
@@ -104,6 +129,8 @@ class TrainingArguments(transformers.TrainingArguments):
     lora_weight_path: str = ""
     lora_bias: str = "none"
     group_by_modality_length: bool = field(default=False)
+    token: str = None
+    train_mode: str = "visual_instruction"  # ["visual_instruction", "language_pretraining"]
 
 
 def maybe_zero_3(param, ignore_status=False, name=None):
@@ -617,7 +644,6 @@ def preprocess(
     tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
     speakers = [sentence["from"] for sentence in source]
     _mask_targets(target, tokenized_lens, speakers)
-
     return dict(input_ids=input_ids, labels=targets)
 
 
@@ -634,6 +660,8 @@ class LazySupervisedDataset(Dataset):
         self.tokenizer = tokenizer
         self.list_data_dict = list_data_dict
         self.data_args = data_args
+        self.zip_file = zipfile.ZipFile(self.data_args.image_folder, 'r')
+
 
     def __len__(self):
         return len(self.list_data_dict)
@@ -664,7 +692,12 @@ class LazySupervisedDataset(Dataset):
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
-
+           while True:
+               try:
+                   image = Image.open(io.BytesIO(self.zip_file.read(image_file)))
+                   break
+               except:
+                   pass
           if self.data_args.image_aspect_ratio == 'pad':
               def expand2square(pil_img, background_color):
                   width, height = pil_img.size
@@ -721,8 +754,11 @@ class DataCollatorForSupervisedDataset(object):
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
+
+
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
+
        batch = dict(
            input_ids=input_ids,
            labels=labels,
@@ -738,6 +774,90 @@ class DataCollatorForSupervisedDataset(object):
 
        return batch
 
+class WdsProcessor:
+    def __init__(self, tokenizer, data_args):
+        self.data_args = data_args
+        self.tokenizer = tokenizer
+        # processor = self.data_args.image_processor
+    def expand2square(self, pil_img, background_color):
+        width, height = pil_img.size
+        if width == height:
+            return pil_img
+        elif width > height:
+            result = Image.new(pil_img.mode, (width, width), background_color)
+            result.paste(pil_img, (0, (width - height) // 2))
+            return result
+        else:
+            result = Image.new(pil_img.mode, (height, height), background_color)
+            result.paste(pil_img, ((height - width) // 2, 0))
+            return result
+
+    def preprocess_wds(self, data):
+
+        image, sources = data
+        has_image = 'image' in sources
+        sources = [sources]
+        image_processor = self.data_args.image_processor
+        if has_image:
+            if self.data_args.image_aspect_ratio == 'pad':
+
+                image = self.expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+                image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            else:
+                image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            sources = preprocess_multimodal(
+                copy.deepcopy([e["conversations"] for e in sources]),
+                self.data_args)
+        else:
+            sources = copy.deepcopy([e["conversations"] for e in sources])
+
+        data_dict = preprocess(
+            sources,
+            self.tokenizer,
+            has_image=has_image)
+
+        data_dict = dict(input_ids=data_dict["input_ids"][0],
+                         labels=data_dict["labels"][0])
+
+
+        if has_image:
+            data_dict['image'] = image
+        elif self.data_args.is_multimodal:
+            # image does not exist in the data, but the model is multimodal
+            crop_size = self.data_args.image_processor.crop_size
+            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+
+        return data_dict
+
+def get_wds_dataset(tokenizer, data_args, training_args):
+    visual_instruction = training_args.train_mode == "visual_instruction"
+
+    round_fn = math.ceil
+
+    if data_args.lengths_path:
+        with open(data_args.lengths_path, 'r') as f:
+            lengths = json.load(f)['length_list']
+        num_samples = len(lengths)
+    elif training_args.num_training_samples:
+        num_samples = training_args.num_training_samples
+    global_batch_size = training_args.per_device_train_batch_size * training_args.world_size
+    num_batches = round_fn(num_samples / global_batch_size)
+    num_workers = max(1, training_args.dataloader_num_workers)
+    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
+    num_batches = num_worker_batches * num_workers
+    training_args.max_steps = num_batches
+    data_args.train_num_samples = training_args.num_training_samples
+
+    data_args.tokenizer = tokenizer
+    data_args.dataloader_num_workers = training_args.dataloader_num_workers
+    data_args.batch_size = training_args.per_device_train_batch_size
+    data_args.world_size = training_args.world_size
+    wds_processor = WdsProcessor(tokenizer, data_args)
+    train_data = get_wds_data(data_args, is_train=True, wds_processor=wds_processor.preprocess_wds)
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) if visual_instruction else None
+    return dict(train_dataset=train_data,
+                eval_dataset=None,
+                data_collator=data_collator)
 
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                 data_args) -> Dict:
@@ -788,16 +908,30 @@ def train():
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
+       elif 'mistral' in model_args.model_name_or_path.lower():
+           model = LlavaMistralForCausalLM.from_pretrained(
+               model_args.model_name_or_path,
+               cache_dir=training_args.cache_dir,
+               **bnb_model_from_pretrained_args
+           )
+       elif 'phi' in model_args.model_name_or_path:
+           model = LlavaPhiForCausalLM.from_pretrained(
+               model_args.model_name_or_path,
+               cache_dir=training_args.cache_dir,
+               **bnb_model_from_pretrained_args
+           )
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
+               token=training_args.token,
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
+           token=training_args.token,
            **bnb_model_from_pretrained_args
        )
    model.config.use_cache = False
@@ -915,8 +1049,24 @@ def train():
                if training_args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

-
-
+    if data_args.dataset_type == "webdataset":
+
+        training_args.group_by_length = False
+        data_module = get_wds_dataset(
+            tokenizer=tokenizer,
+            data_args=data_args,
+            training_args=training_args
+        )
+
+
+    elif data_args.dataset_type == "files":
+        data_module = make_supervised_data_module(
+            tokenizer=tokenizer,
+            data_args=data_args
+        )
+    else:
+        ValueError(f"Unknown dataset type {data_args.dataset_type}! Dataset type should be euther `webdataset` or `files`")
+
    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_args,
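Because a streaming webdataset has no usable __len__, get_wds_dataset above derives max_steps for the HF Trainer from the sample count: batches are counted against the global batch size and then rounded up to a multiple of the dataloader worker count. The arithmetic, traced with made-up numbers:

    import math

    num_samples = 558_128            # e.g. from --num_training_samples or a lengths file
    per_device_batch_size = 16
    world_size = 8
    dataloader_num_workers = 4

    global_batch_size = per_device_batch_size * world_size                # 128
    num_batches = math.ceil(num_samples / global_batch_size)              # 4361
    num_worker_batches = math.ceil(num_batches / dataloader_num_workers)  # 1091 per worker
    max_steps = num_worker_batches * dataloader_num_workers               # 4364 -> training_args.max_steps
    print(global_batch_size, num_batches, max_steps)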
llava/train/train_mem.py
CHANGED
@@ -3,9 +3,9 @@
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
 
 # Need to call this before importing transformers.
-from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+# from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
 
-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()
 
 from llava.train.train import train
 
requirements.txt
CHANGED
@@ -16,7 +16,7 @@ shortuuid
 httpx==0.24.0
 deepspeed==0.9.5
 peft==0.4.0
-transformers==4.
+transformers==4.36.0
 accelerate==0.21.0
 bitsandbytes==0.41.0
 scikit-learn==1.2.2