WYBar committed
Commit 08acebb · Parent: 8fe62ee

fix crello pretrained

Files changed (3)
  1. app.py +17 -15
  2. app_test.py +21 -18
  3. modeling_crello.py +21 -8
app.py CHANGED
@@ -50,10 +50,11 @@ def generate_unique_filename():
     unique_filename = f"{timestamp}"
     return unique_filename
 
+git_token = os.environ.get("GIT_TOKEN")
 def upload_to_github(file_path,
                      repo='WYBar/gradiodemo_svg',
                      branch='main',
-                     token='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'):
+                     token=git_token):
     if not os.path.isfile(file_path):
         print(f"File not found: {file_path}")
         return None
@@ -274,26 +275,21 @@ def buildmodel(**kwargs):
         pad_token_id=tokenizer.pad_token_id,
         ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens),
     )
-    model_args.freeze_lm = True
-    model_args.opt_version = "WYBar/LLM_For_Layout_Planning"
+    model_args.freeze_lm = False
+    model_args.opt_version = input_model
     model_args.use_lora = False
     model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
     # model = CrelloModel.from_pretrained(
     #     resume,
     #     config=model_args
     # ).to(device)
-    # model = CrelloModel.from_pretrained(
-    #     "WYBar/LLM_For_Layout_Planning",
-    #     subfolder="checkpoint-26000",  # load the checkpoint directory
-    #     config=model_args,
-    #     # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
-    # )
-    model = CrelloModel(config=model_args)
-    print("before .to(device)")
-    model = model.to(device)
-    print("after .to(device)")
-    model = model.bfloat16()
-    model.eval()
+    model = CrelloModel.from_pretrained(
+        "WYBar/LLM_For_Layout_Planning",
+        subfolder="checkpoint-26000",  # load the checkpoint directory
+        config=model_args,
+        # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
+    )
+    # model = CrelloModel(config=model_args)
 
     tokenizer.add_special_tokens({"mask_token": "<mask>"})
     quantizer.additional_special_tokens.add("<mask>")
@@ -328,6 +324,12 @@ def construction_layout():
     model.lm.resize_token_embeddings(129423)
     model.input_embeddings = model.lm.get_input_embeddings()
     print('after token embeddings to match the tokenizer', 129423)
+
+    print("before .to(device)")
+    model = model.to(device)
+    print("after .to(device)")
+    model = model.bfloat16()
+    model.eval()
     return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
 
 @torch.no_grad()
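
The substantive fix in app.py is twofold: the model is now restored with CrelloModel.from_pretrained(..., subfolder="checkpoint-26000") instead of being instantiated empty from the config, and the .to(device) / .bfloat16() / .eval() calls move out of buildmodel() into construction_layout(), after resize_token_embeddings. A minimal sketch of that load-resize-cast ordering, assuming a generic HF causal LM (the checkpoint id is a placeholder; the diff uses CrelloModel and vocab size 129423):

import torch
from transformers import AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Restore pretrained weights first (placeholder checkpoint id).
model = AutoModelForCausalLM.from_pretrained("some-org/some-checkpoint")

# 2. Resize embeddings to the tokenizer's vocabulary next; resizing
#    reallocates the embedding matrix, so doing it before the final
#    device/dtype casts keeps every parameter in one place and dtype.
model.resize_token_embeddings(129423)

# 3. Only now move and cast the whole model, then switch to
#    inference mode, mirroring the order the commit establishes.
model = model.to(device).bfloat16()
model.eval()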
app_test.py CHANGED
@@ -50,10 +50,11 @@ def generate_unique_filename():
     unique_filename = f"{timestamp}"
     return unique_filename
 
+git_token = os.environ.get("GIT_TOKEN")
 def upload_to_github(file_path,
                      repo='WYBar/gradiodemo_svg',
                      branch='main',
-                     token='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'):
+                     token=git_token):
     if not os.path.isfile(file_path):
         print(f"File not found: {file_path}")
         return None
@@ -274,26 +275,22 @@ def buildmodel(**kwargs):
         pad_token_id=tokenizer.pad_token_id,
         ignore_ids=tokenizer.convert_tokens_to_ids(quantizer.ignore_tokens),
     )
-    model_args.freeze_lm = True
-    model_args.opt_version = "WYBar/LLM_For_Layout_Planning"
+    model_args.freeze_lm = False
+    model_args.opt_version = input_model
     model_args.use_lora = False
     model_args.load_in_4bit = kwargs.get('load_in_4bit', False)
     # model = CrelloModel.from_pretrained(
     #     resume,
     #     config=model_args
     # ).to(device)
-    # model = CrelloModel.from_pretrained(
-    #     "WYBar/LLM_For_Layout_Planning",
-    #     subfolder="checkpoint-26000",  # load the checkpoint directory
-    #     config=model_args,
-    #     # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
-    # )
-    model = CrelloModel(config=model_args)
-    print("before .to(device)")
-    model = model.to(device)
-    print("after .to(device)")
-    model = model.bfloat16()
-    model.eval()
+
+    model = CrelloModel.from_pretrained(
+        "WYBar/LLM_For_Layout_Planning",
+        subfolder="checkpoint-26000",  # load the checkpoint directory
+        config=model_args,
+        cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
+    )
+    # model = CrelloModel(config=model_args)
 
     tokenizer.add_special_tokens({"mask_token": "<mask>"})
     quantizer.additional_special_tokens.add("<mask>")
@@ -307,8 +304,8 @@ def buildmodel(**kwargs):
 def construction_layout():
     params_dict = {
         # needs to be modified
-        "input_model": "WYBar/LLM_For_Layout_Planning",
-        "resume": "WYBar/LLM_For_Layout_Planning",
+        "input_model": "/openseg_blob/v-sirui/temporary/2024-02-21/Layout_train/COLEv2/Design_LLM/checkpoint/Meta-Llama-3-8B",
+        "resume": "/openseg_blob/v-sirui/temporary/2024-02-21/SVD/Int2lay_1016/checkpoint/int2lay_1031/1031_test/checkpoint-26000/",
 
         "seed": 0,
         "mask_values": False,
@@ -328,6 +325,12 @@ def construction_layout():
     model.lm.resize_token_embeddings(129423)
     model.input_embeddings = model.lm.get_input_embeddings()
     print('after token embeddings to match the tokenizer', 129423)
+
+    print("before .to(device)")
+    model = model.to(device)
+    print("after .to(device)")
+    model = model.bfloat16()
+    model.eval()
     return model, quantizer, tokenizer, params_dict["width"], params_dict["height"], device
 
 @torch.no_grad()
@@ -678,7 +681,7 @@ def main():
         inputs=[intention_input, temperature_input, top_p_input, seed_input, true_gs_input, inference_steps_input],
         outputs=[list_box_output, result_images, svg_file, svg_editor, text_input, tuple_input]
     )
-    demo.launch()
+    demo.launch(server_name='0.0.0.0', server_port=7860)
 
 if __name__ == "__main__":
     main()
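
Two deployment details in app_test.py are worth noting: the GitHub token now comes from the GIT_TOKEN environment variable instead of a hard-coded string, and demo.launch binds to 0.0.0.0:7860 so the Gradio server is reachable from outside a container (7860 is the conventional Spaces port). A condensed sketch of the same pattern; the GIT_TOKEN name matches the diff, everything else is illustrative:

import os
import gradio as gr

# Read the secret from the environment; None if unset. As a default
# argument below it is bound once, at function-definition time, so
# GIT_TOKEN must be set before this module is imported.
git_token = os.environ.get("GIT_TOKEN")

def upload_to_github(file_path, token=git_token):
    if token is None or not os.path.isfile(file_path):
        print("missing token or file; skipping upload")
        return None
    # ... perform the authenticated upload here ...

demo = gr.Interface(fn=lambda s: s, inputs="text", outputs="text")
# Bind to all interfaces on the conventional Gradio/Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)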
modeling_crello.py CHANGED
@@ -1,6 +1,5 @@
 import torch
-from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoModelForCausalLM, OPTForCausalLM
-# from transformers import BitsAndBytesConfig
+from transformers import PreTrainedModel, PretrainedConfig, AutoModel, AutoModelForCausalLM, OPTForCausalLM, BitsAndBytesConfig
 from torch import nn
 import os
 from typing import Optional, List
@@ -117,12 +116,13 @@ class CrelloModel(PreTrainedModel):
 
     def __init__(self, config: CrelloModelConfig):  # explicitly declare the config type
         super().__init__(config)
+        use_auth_token = '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'
 
         self.pad_token_id = config.pad_token_id
 
         self.args = config
 
-        opt_version = "WYBar/LLM_For_Layout_Planning"
+        opt_version = config.opt_version
 
         print(f"Using {opt_version} for the language model.")
 
@@ -132,7 +132,9 @@
         else:
             if config.load_in_4bit:
                 print("\n would load_in_4bit")
-                quantization_config = None
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=config.load_in_4bit
+                )
             # This means: fit the entire model on the GPU:0
             local_rank = int(os.environ.get("LOCAL_RANK", 0))
             device_map = {"": local_rank}
@@ -151,8 +153,20 @@
             # device_map=device_map,
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
-            # cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
+            cache_dir="/openseg_blob/v-yanbin/GradioDemo/cache_dir",
         )
+        # self.lm = AutoModelForCausalLM.from_pretrained(
+        #     opt_version,
+        #     use_auth_token=use_auth_token,
+        #     quantization_config=quantization_config,
+        #     device_map=device_map,
+        #     trust_remote_code=True,
+        #     # attn_implementation="flash_attention_2",
+        #     # flash_attn=True,
+        #     # flash_rotary=True,
+        #     # fused_dense=True,
+        #     torch_dtype=torch.bfloat16,
+        # )
         word_embed_proj_dim = self.lm.config.hidden_size
         self.config.hidden_size = self.lm.config.hidden_size
         self.opt_version = opt_version
@@ -160,8 +174,8 @@
         if self.args.freeze_lm:
             self.lm.eval()
             print("Freezing the LM.")
-            # for param in self.lm.parameters():
-            #     param.requires_grad = False
+            for param in self.lm.parameters():
+                param.requires_grad = False
         else:
             print("\n no freeze lm, so to train lm")
             self.lm.train()
@@ -170,7 +184,6 @@
         # print('resize token embeddings to match the tokenizer', config.vocab_size)
         # self.lm.resize_token_embeddings(config.vocab_size)
         # self.input_embeddings = self.lm.get_input_embeddings()
-        # print('after token embeddings to match the tokenizer', config.vocab_size)
 
     def train(self, mode=True):
         super().train(mode=mode)
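
In modeling_crello.py the 4-bit branch now constructs a real BitsAndBytesConfig instead of leaving quantization_config as None, and the freeze branch actually disables gradients (eval() alone only switches layer behavior such as dropout; it does not stop weight updates). A small sketch of both pieces under the same assumptions as the diff; the checkpoint id is a placeholder:

import os
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Mirror the diff: enable 4-bit loading through bitsandbytes.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# An empty-string key pins the entire model to one device,
# here the GPU given by LOCAL_RANK (0 by default).
device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}

lm = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # placeholder checkpoint id
    quantization_config=quantization_config,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Freezing the LM: requires_grad=False is what excludes the weights
# from the backward pass; eval() complements it for inference.
lm.eval()
for param in lm.parameters():
    param.requires_grad = False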