csuhan commited on
Commit
e79860f
·
1 Parent(s): 3ea8b34

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -94,8 +94,8 @@ def model_worker(
94
  with default_tensor_type(dtype=target_dtype, device="cuda"):
95
  model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
96
  for ckpt_id in range(args.num_ckpts):
97
- # ckpt_path = hf_hub_download(repo_id=args.pretrained_path, filename=args.ckpt_format.format(str(ckpt_id)))
98
- ckpt_path = os.path.join(args.pretrained_path, args.ckpt_format.format(str(ckpt_id)))
99
  print(f"Loading pretrained weights {ckpt_path}")
100
  checkpoint = torch.load(ckpt_path, map_location='cpu')
101
  msg = model.load_state_dict(checkpoint, strict=False)
@@ -349,10 +349,8 @@ class DemoConfig:
349
  llama_type = "onellm"
350
  llama_config = "config/llama2/7B.json"
351
  model_max_seq_len = 2048
352
- # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
353
- # pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
354
- # pretrained_path = "csuhan/OneLLM-7B-hf"
355
- pretrained_path = "/home/pgao/jiaming/weights/7B_v20_splits/"
356
  ckpt_format = "consolidated.00-of-01.s{}.pth"
357
  num_ckpts = 10
358
  master_port = 23863
 
94
  with default_tensor_type(dtype=target_dtype, device="cuda"):
95
  model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
96
  for ckpt_id in range(args.num_ckpts):
97
+ ckpt_path = hf_hub_download(repo_id=args.pretrained_path, filename=args.ckpt_format.format(str(ckpt_id)))
98
+ # ckpt_path = os.path.join(args.pretrained_path, args.ckpt_format.format(str(ckpt_id)))
99
  print(f"Loading pretrained weights {ckpt_path}")
100
  checkpoint = torch.load(ckpt_path, map_location='cpu')
101
  msg = model.load_state_dict(checkpoint, strict=False)
 
349
  llama_type = "onellm"
350
  llama_config = "config/llama2/7B.json"
351
  model_max_seq_len = 2048
352
+ pretrained_path = "csuhan/OneLLM-7B-hf"
353
+ # pretrained_path = "/home/pgao/jiaming/weights/7B_v20_splits/"
 
 
354
  ckpt_format = "consolidated.00-of-01.s{}.pth"
355
  num_ckpts = 10
356
  master_port = 23863