marianna13 committed
Commit ef52c23 · 1 Parent(s): 230a504
app.py CHANGED
@@ -215,6 +215,8 @@ def http_bot(
            template_name = "mpt_text"
        elif "llama-2" in model_name:
            template_name = "llama_2"
+        elif "phi" in model_name:
+            template_name = "phi"
        else:
            template_name = "vicuna_v1"
        new_state = conv_templates[template_name].copy()
@@ -604,7 +606,7 @@ if __name__ == "__main__":
    args = get_args()
    logger.info(f"args: {args}")

-    model_path = "liuhaotian/llava-v1.5-13b"
+    model_path = "marianna13/llava-phi-2-3b"
    bits = int(os.getenv("bits", 8))

    controller_proc = start_controller()
llava/__init__.py CHANGED
@@ -1 +1,2 @@
from .model import LlavaLlamaForCausalLM
+from .model import LlavaMistralForCausalLM
llava/conversation.py CHANGED
@@ -10,6 +10,7 @@ class SeparatorStyle(Enum):
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
+    PHI = auto()


@dataclasses.dataclass
@@ -72,6 +73,28 @@ class Conversation:
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i == 0: message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+
+        elif self.sep_style == SeparatorStyle.PHI:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+            wrap_inst = lambda msg: f"Instruct: {msg} \nOutput:"
+            ret = ""
+
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
@@ -261,6 +284,30 @@ conv_vicuna_v1 = Conversation(
    sep2="</s>",
)

+phi = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v0",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.PHI,
+    sep="<|endoftext|>",
+    sep2="<|endoftext|>",
+)
+
+conv_phi = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v2",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep="<|endoftext|>",
+    sep2="<|endoftext|>",
+)
+
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

@@ -287,6 +334,19 @@ conv_llava_llama_2 = Conversation(
    sep2="</s>",
)

+llava_phi = Conversation(
+    system="You are a helpful language and vision assistant. "
+    "You are able to understand the visual content that the user provides, "
+    "and assist the user with a variety of tasks using natural language.",
+    roles=("USER", "ASSISTANT"),
+    version="llava_phi",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep="<|endoftext|>",
+    sep2="<|endoftext|>",
+)
+
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
@@ -344,6 +404,18 @@ conv_llava_v1 = Conversation(
    sep2="</s>",
)

+conv_mistral_v1 = Conversation(
+    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+    roles=("user", "assistant"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.LLAMA_2,
+    sep="<s>",
+    sep2="</s>",
+)
+
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
@@ -364,7 +436,7 @@ conv_templates = {
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
-
+    "mistral": conv_llama_2,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
@@ -372,7 +444,10 @@ conv_templates = {
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
-
+    "conv_mistral_v1": conv_mistral_v1,
+    "llava_phi": llava_phi,
+    "conv_phi": conv_phi,
+    "phi": phi,
    "mpt": conv_mpt,
}

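For orientation, a minimal sketch (not part of the commit) of what the new "phi" template renders, derived only from the PHI branch and the `phi` Conversation added above; exact whitespace depends on the surrounding get_prompt code:

from llava.conversation import conv_templates

conv = conv_templates["phi"].copy()
conv.append_message(conv.roles[0], "What is shown in the image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# Expected shape, roughly:
#   Instruct: <<SYS>>\n<system message>\n<</SYS>>\n\nWhat is shown in the image? \nOutput:
# (the leading sep "<|endoftext|>" is removed by ret.lstrip(self.sep))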
llava/eval/eval_science_qa.py CHANGED
@@ -13,6 +13,8 @@ def get_args():
    parser.add_argument('--output-result', type=str)
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
+    parser.add_argument('--pid-splits-path')
+    parser.add_argument('--problems-path')
    return parser.parse_args()


@@ -39,8 +41,8 @@ if __name__ == "__main__":
    args = get_args()

    base_dir = args.base_dir
-    split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
-    problems = json.load(open(os.path.join(base_dir, "problems.json")))
+    split_indices = json.load(open(args.pid_splits_path))[args.split]
+    problems = json.load(open(args.problems_path))
    predictions = [json.loads(line) for line in open(args.result_file)]
    predictions = {pred['question_id']: pred for pred in predictions}
    split_problems = {idx: problems[idx] for idx in split_indices}
@@ -54,18 +56,26 @@ if __name__ == "__main__":
    sqa_results['outputs'] = {}

    for prob_id, prob in split_problems.items():
+        # prob_id = f'{args.split}/{prob_id}'
        if prob_id not in predictions:
-            continue
-        pred = predictions[prob_id]
-        pred_text = pred['text']
-
-        pattern = re.compile(r'The answer is ([A-Z]).')
-        res = pattern.findall(pred_text)
-        if len(res) == 1:
-            answer = res[0]  # 'A', 'B', ...
+            pred = {'text': 'FAILED', 'prompt': 'Unknown'}
+            pred_text = 'FAILED'
        else:
-            answer = "FAILED"
+            pred = predictions[prob_id]
+            pred_text = pred['text']

+        if pred_text in args.options:
+            answer = pred_text
+        elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
+            answer = pred_text[0]
+        else:
+            pattern = re.compile(r'The answer is ([A-Z])')
+            res = pattern.findall(pred_text)
+            if len(res) == 1:
+                answer = res[0]  # 'A', 'B', ...
+            else:
+                answer = "FAILED"
+
        pred_idx = get_pred_idx(answer, prob['choices'], args.options)

        analysis = {
@@ -85,9 +95,16 @@ if __name__ == "__main__":
        else:
            results['incorrect'].append(analysis)

+
    correct = len(results['correct'])
    total = len(results['correct']) + len(results['incorrect'])
-    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
+    ###### IMG ######
+    multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
+    multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
+    multimodal_total = multimodal_correct + multimodal_incorrect
+    ###### IMG ######
+
+    print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')

    sqa_results['acc'] = correct / total * 100
    sqa_results['correct'] = correct
@@ -96,4 +113,4 @@ if __name__ == "__main__":
    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)
    with open(args.output_result, 'w') as f:
-        json.dump(sqa_results, f, indent=2)
+        json.dump(sqa_results, f, indent=2)
llava/eval/model_vqa_science.py CHANGED
@@ -57,6 +57,10 @@ def eval_model(args):
        else:
            images = None

+        if args.single_pred_prompt:
+            qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
+            cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
+
        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
@@ -64,19 +68,22 @@ def eval_model(args):

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

-        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+        stop_str = conv.sep2
+        stop_str = "\n" if "phi" in model_name else stop_str
        keywords = [stop_str]
-        stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None
+        stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)]
+        eos_token_id = tokenizer.eos_token_id

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images,
-                do_sample=True,
-                temperature=0.2,
+                do_sample=True if args.temperature > 0 else False,
+                temperature=args.temperature,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=stopping_criteria,
+                # eos_token_id=eos_token_id
            )

        input_token_len = input_ids.shape[1]
@@ -88,7 +95,9 @@ def eval_model(args):
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
-
+        # outputs = outputs.replace("\n<</SYS>>", "")
+        # print("question:\n", cur_prompt)
+        # print("answer:\n", outputs)
        # prompt for answer
        if args.answer_prompter:
            outputs_reasoning = outputs
@@ -98,11 +107,11 @@ def eval_model(args):
                output_ids = model.generate(
                    input_ids,
                    images=images,
-                    do_sample=True,
-                    temperature=0.2,
+                    do_sample=True if args.temperature > 0 else False,
+                    temperature=args.temperature,
                    max_new_tokens=64,
                    use_cache=True,
-                    stopping_criteria=[stopping_criteria])
+                    stopping_criteria=stopping_criteria)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
@@ -135,7 +144,9 @@ if __name__ == "__main__":
    parser.add_argument("--conv-mode", type=str, default="llava_v0")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
+    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--answer-prompter", action="store_true")
+    parser.add_argument("--single-pred-prompt", action="store_true")
    args = parser.parse_args()

-    eval_model(args)
+    eval_model(args)
llava/model/__init__.py CHANGED
@@ -1,2 +1,4 @@
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
-from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
+from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaConfig
+from .language_model.llava_phi import LlavaPhiForCausalLM, LlavaConfig
+# from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
llava/model/apply_delta.py CHANGED
@@ -8,6 +8,7 @@ import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava import LlavaLlamaForCausalLM
+from llava import LlavaMistralForCausalLM


def apply_delta(base_model_path, target_model_path, delta_path):
@@ -16,7 +17,7 @@ def apply_delta(base_model_path, target_model_path, delta_path):
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
-    delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+    delta = LlavaMistralForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
llava/model/builder.py CHANGED
@@ -25,7 +25,6 @@ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, D

def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto"):
    kwargs = {"device_map": device_map}
-    kwargs["offload_folder"] = "offload"

    if load_8bit:
        kwargs['load_in_8bit'] = True
@@ -48,7 +47,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA from base model...')
-            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+            model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
@@ -90,18 +89,24 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            cfg_pretrained = AutoConfig.from_pretrained(model_path)
-            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+            model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
-            if 'mpt' in model_name.lower():
+            if 'phi' in model_name.lower():
+                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+                model = LlavaPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+            elif 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+            elif 'phi' in model_name.lower():
+                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+                model = LlavaPhiForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-                model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+                model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    else:
        # Load language model
        if model_base is not None:
@@ -138,14 +143,13 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
-
-
-    vision_tower.to(device=model.device, dtype=torch.float16)
+    vision_tower.to(device='cuda', dtype=torch.float16)
    image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048
+    print("model loaded", model)

    return tokenizer, model, image_processor, context_len
llava/model/language_model/llava_llama.py CHANGED
@@ -28,7 +28,7 @@ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
-    model_type = "llava"
+    model_type = "bakllava"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
@@ -136,5 +136,5 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
        )
        return model_inputs

-AutoConfig.register("llava", LlavaConfig)
+AutoConfig.register("bakllava", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
llava/model/multimodal_encoder/builder.py CHANGED
@@ -5,7 +5,7 @@ from .clip_encoder import CLIPVisionTower
def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
-    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
+    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or vision_tower.startswith("apple"):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
llava/model/multimodal_encoder/clip_encoder.py CHANGED
@@ -2,6 +2,30 @@ import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+from huggingface_hub import hf_hub_download
+import json
+
+def get_open_clip_image_processor(model_name):
+    config_path = hf_hub_download(model_name, filename="open_clip_config.json")
+
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+    image_size = config['model_cfg']['vision_cfg']['image_size']
+    image_mean = config['preprocess_cfg']['mean']
+    image_std = config['preprocess_cfg']['std']
+    size = {"shortest_edge": image_size}
+    crop_size = {
+        "height": image_size,
+        "width": image_size
+    }
+
+    return CLIPImageProcessor(
+        image_size=image_size,
+        image_mean=image_mean,
+        image_std=image_std,
+        crop_size=crop_size,
+        size=size
+    )


class CLIPVisionTower(nn.Module):
@@ -13,15 +37,18 @@ class CLIPVisionTower(nn.Module):
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
-
        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
-        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
+        if self.vision_tower_name.startswith("apple") or self.vision_tower_name.startswith("laion"):
+            self.image_processor = get_open_clip_image_processor(self.vision_tower_name)
+        else:
+            self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True
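A hedged usage sketch of the new OpenCLIP-config path in load_model; the checkpoint id below is only an illustrative OpenCLIP-style repo assumed to ship an open_clip_config.json, not one referenced by this commit:

from PIL import Image

# hypothetical OpenCLIP-style vision tower id (any repo exposing open_clip_config.json)
processor = get_open_clip_image_processor("apple/DFN2B-CLIP-ViT-L-14")
image = Image.open("example.jpg").convert("RGB")
pixel_values = processor.preprocess(image, return_tensors="pt")["pixel_values"]  # (1, 3, H, W) tensor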
llava/train/train.py CHANGED
@@ -20,7 +20,7 @@ from dataclasses import dataclass, field
import json
import logging
import pathlib
-from typing import Dict, Optional, Sequence, List
+from typing import Dict, Optional, Sequence, List, Union

import torch

@@ -35,7 +35,14 @@ from llava.model import *
from llava.mm_utils import tokenizer_image_token

from PIL import Image
-
+import webdataset as wds
+import io
+import deepspeed
+import time
+from deepspeed.accelerator import get_accelerator
+import zipfile
+from webdataset_utils import get_wds_data
+import math

local_rank = None

@@ -62,17 +69,30 @@ class ModelArguments:

@dataclass
class DataArguments:
-    data_path: str = field(default=None,
+    dataset_type: str = "webdataset"
+    dataset_resampled: bool = False
+    lengths_path: Optional[str] = None
+    data_path: Union[List[str], str] = field(default=None,
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)
+    train_data_weights: Optional[List[str]] = None
+    # dataloader_num_workers: Optional[int] = None
+    # seed: int = 0


@dataclass
class TrainingArguments(transformers.TrainingArguments):
+    num_training_samples: int = field(default=None)
+    resume_from_checkpoint: bool = False
+    deepspeed_config: str = field(default=None)
+    lr: float = field(default=1e-3)
+    beta1: float = field(default=0.5)
+    beta2: float = field(default=0.999)
+    num_train_epochs: int = field(default=1)
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
@@ -85,6 +105,11 @@ class TrainingArguments(transformers.TrainingArguments):
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
+    dispatch_batches: bool = field(default=None)
+    pin_memory: bool = field(default=False)
+    resume: Optional[str] = field(default=None)
+
+    # train_num_samples: int = field()
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
@@ -104,6 +129,8 @@ class TrainingArguments(transformers.TrainingArguments):
    lora_weight_path: str = ""
    lora_bias: str = "none"
    group_by_modality_length: bool = field(default=False)
+    token: str = None
+    train_mode: str = "visual_instruction"  # ["visual_instruction", "language_pretraining"]


def maybe_zero_3(param, ignore_status=False, name=None):
@@ -617,7 +644,6 @@ def preprocess(
        tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)
-
    return dict(input_ids=input_ids, labels=targets)


@@ -634,6 +660,8 @@ class LazySupervisedDataset(Dataset):
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args
+        self.zip_file = zipfile.ZipFile(self.data_args.image_folder, 'r')
+

    def __len__(self):
        return len(self.list_data_dict)
@@ -664,7 +692,12 @@ class LazySupervisedDataset(Dataset):
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
-            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+            while True:
+                try:
+                    image = Image.open(io.BytesIO(self.zip_file.read(image_file)))
+                    break
+                except:
+                    pass
            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
@@ -721,8 +754,11 @@ class DataCollatorForSupervisedDataset(object):
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
+
+
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
+
        batch = dict(
            input_ids=input_ids,
            labels=labels,
@@ -738,6 +774,90 @@ class DataCollatorForSupervisedDataset(object):

        return batch

+class WdsProcessor:
+    def __init__(self, tokenizer, data_args):
+        self.data_args = data_args
+        self.tokenizer = tokenizer
+        # processor = self.data_args.image_processor
+    def expand2square(self, pil_img, background_color):
+        width, height = pil_img.size
+        if width == height:
+            return pil_img
+        elif width > height:
+            result = Image.new(pil_img.mode, (width, width), background_color)
+            result.paste(pil_img, (0, (width - height) // 2))
+            return result
+        else:
+            result = Image.new(pil_img.mode, (height, height), background_color)
+            result.paste(pil_img, ((height - width) // 2, 0))
+            return result
+
+    def preprocess_wds(self, data):
+
+        image, sources = data
+        has_image = 'image' in sources
+        sources = [sources]
+        image_processor = self.data_args.image_processor
+        if has_image:
+            if self.data_args.image_aspect_ratio == 'pad':
+
+                image = self.expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+                image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            else:
+                image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            sources = preprocess_multimodal(
+                copy.deepcopy([e["conversations"] for e in sources]),
+                self.data_args)
+        else:
+            sources = copy.deepcopy([e["conversations"] for e in sources])
+
+        data_dict = preprocess(
+            sources,
+            self.tokenizer,
+            has_image=has_image)
+
+        data_dict = dict(input_ids=data_dict["input_ids"][0],
+                         labels=data_dict["labels"][0])
+
+
+        if has_image:
+            data_dict['image'] = image
+        elif self.data_args.is_multimodal:
+            # image does not exist in the data, but the model is multimodal
+            crop_size = self.data_args.image_processor.crop_size
+            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+
+        return data_dict
+
+def get_wds_dataset(tokenizer, data_args, training_args):
+    visual_instruction = training_args.train_mode == "visual_instruction"
+
+    round_fn = math.ceil
+
+    if data_args.lengths_path:
+        with open(data_args.lengths_path, 'r') as f:
+            lengths = json.load(f)['length_list']
+        num_samples = len(lengths)
+    elif training_args.num_training_samples:
+        num_samples = training_args.num_training_samples
+    global_batch_size = training_args.per_device_train_batch_size * training_args.world_size
+    num_batches = round_fn(num_samples / global_batch_size)
+    num_workers = max(1, training_args.dataloader_num_workers)
+    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
+    num_batches = num_worker_batches * num_workers
+    training_args.max_steps = num_batches
+    data_args.train_num_samples = training_args.num_training_samples
+
+    data_args.tokenizer = tokenizer
+    data_args.dataloader_num_workers = training_args.dataloader_num_workers
+    data_args.batch_size = training_args.per_device_train_batch_size
+    data_args.world_size = training_args.world_size
+    wds_processor = WdsProcessor(tokenizer, data_args)
+    train_data = get_wds_data(data_args, is_train=True, wds_processor=wds_processor.preprocess_wds)
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) if visual_instruction else None
+    return dict(train_dataset=train_data,
+                eval_dataset=None,
+                data_collator=data_collator)

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
@@ -788,16 +908,30 @@ def train():
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
+        elif 'mistral' in model_args.model_name_or_path.lower():
+            model = LlavaMistralForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=training_args.cache_dir,
+                **bnb_model_from_pretrained_args
+            )
+        elif 'phi' in model_args.model_name_or_path:
+            model = LlavaPhiForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=training_args.cache_dir,
+                **bnb_model_from_pretrained_args
+            )
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
+                token=training_args.token,
                **bnb_model_from_pretrained_args
            )
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
+            token=training_args.token,
            **bnb_model_from_pretrained_args
        )
    model.config.use_cache = False
@@ -915,8 +1049,24 @@ def train():
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

-    data_module = make_supervised_data_module(tokenizer=tokenizer,
-                                              data_args=data_args)
+    if data_args.dataset_type == "webdataset":
+
+        training_args.group_by_length = False
+        data_module = get_wds_dataset(
+            tokenizer=tokenizer,
+            data_args=data_args,
+            training_args=training_args
+        )
+
+
+    elif data_args.dataset_type == "files":
+        data_module = make_supervised_data_module(
+            tokenizer=tokenizer,
+            data_args=data_args
+        )
+    else:
+        raise ValueError(f"Unknown dataset type {data_args.dataset_type}! Dataset type should be either `webdataset` or `files`")
+
    trainer = LLaVATrainer(model=model,
                           tokenizer=tokenizer,
                           args=training_args,
llava/train/train_mem.py CHANGED
@@ -3,9 +3,9 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
-from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+# from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()

from llava.train.train import train

requirements.txt CHANGED
@@ -16,7 +16,7 @@ shortuuid
httpx==0.24.0
deepspeed==0.9.5
peft==0.4.0
-transformers==4.31.0
+transformers==4.36.0
accelerate==0.21.0
bitsandbytes==0.41.0
scikit-learn==1.2.2