csuhan committed
Commit ae280e8 · 1 Parent(s): 2e4b1b9
Files changed (2):
  1. app.py +21 -14
  2. requirements.txt +2 -1
app.py CHANGED
@@ -6,6 +6,7 @@ import time
 from pathlib import Path
 from typing import Tuple
 
+from huggingface_hub import hf_hub_download
 from PIL import Image
 import gradio as gr
 import torch
@@ -46,7 +47,8 @@ def setup_model_parallel() -> Tuple[int, int]:
 
 
 def load(
-    ckpt_dir: str,
+    ckpt_path: str,
+    param_path: str,
     tokenizer_path: str,
     instruct_adapter_path: str,
     caption_adapter_path: str,
@@ -56,18 +58,19 @@ def load(
     max_batch_size: int,
 ) -> LLaMA:
     start_time = time.time()
-    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-    assert world_size == len(
-        checkpoints
-    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
-    ckpt_path = checkpoints[local_rank]
+    # checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+    # assert world_size == len(
+    #     checkpoints
+    # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
+    # ckpt_path = checkpoints[local_rank]
     print("Loading")
     checkpoint = torch.load(ckpt_path, map_location="cpu")
     instruct_adapter_checkpoint = torch.load(
         instruct_adapter_path, map_location="cpu")
     caption_adapter_checkpoint = torch.load(
         caption_adapter_path, map_location="cpu")
-    with open(Path(ckpt_dir) / "params.json", "r") as f:
+    # with open(Path(ckpt_dir) / "params.json", "r") as f:
+    with open(param_path, "r") as f:
         params = json.loads(f.read())
 
     model_args: ModelArgs = ModelArgs(
@@ -149,9 +152,10 @@ def download_llama_7b(ckpt_dir, tokenizer_path):
     # if not os.path.exists(tokenizer_path):
     #     os.system(
     #         f"wget -O {tokenizer_path} https://huggingface.co/nyanko7/LLaMA-7B/resolve/main/tokenizer.model")
-    if not os.path.exists(ckpt_path):
-        os.system("git lfs install")
-        os.system("git clone https://huggingface.co/nyanko7/LLaMA-7B")
+    # if not os.path.exists(ckpt_path):
+    #     os.system("git lfs install")
+    #     os.system("git clone https://huggingface.co/nyanko7/LLaMA-7B")
+
     print("LLaMA-7B downloaded")
 
 def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
@@ -164,15 +168,18 @@ def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
 
 # ckpt_dir = "/data1/llma/7B"
 # tokenizer_path = "/data1/llma/tokenizer.model"
-ckpt_dir = "LLaMA-7B/"
-tokenizer_path = "LLaMA-7B/tokenizer.model"
+# ckpt_dir = "LLaMA-7B/"
+# tokenizer_path = "LLaMA-7B/tokenizer.model"
+ckpt_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="consolidated.00.pth")
+param_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="params.json")
+tokenizer_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
 instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
 caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
 max_seq_len = 512
 max_batch_size = 1
 
 # download models
-download_llama_7b(ckpt_dir, tokenizer_path)
+# download_llama_7b(ckpt_dir, tokenizer_path)
 download_llama_adapter(instruct_adapter_path, caption_adapter_path)
 
 local_rank, world_size = setup_model_parallel()
@@ -180,7 +187,7 @@ if local_rank > 0:
     sys.stdout = open(os.devnull, "w")
 
 generator = load(
-    ckpt_dir, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
+    ckpt_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
 )
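The net effect of the app.py change: instead of `git clone`-ing the whole LLaMA-7B repo via git-lfs and globbing rank-indexed `*.pth` shards, the app fetches exactly the three files it needs with `hf_hub_download`, and the old world-size assertion can go because a single consolidated shard is always used. A minimal, self-contained sketch of that pattern (the repo id and filenames are taken from the diff; everything else is illustrative, not the app's exact code):

```python
import json

import torch
from huggingface_hub import hf_hub_download

# Each call downloads the file on first use, then serves it from the local
# Hub cache (~/.cache/huggingface/hub by default) and returns the local path.
ckpt_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="consolidated.00.pth")
param_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="params.json")

# A single consolidated checkpoint means there is no per-rank shard to pick,
# which is why the world_size == len(checkpoints) assertion was commented out.
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(param_path, "r") as f:
    params = json.load(f)  # model hyperparameters (dim, n_layers, n_heads, ...)
```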
requirements.txt CHANGED
@@ -1,6 +1,7 @@
-torch==1.12.0
+torch
 fairscale
 sentencepiece
 Pillow
+huggingface_hub
 git+https://github.com/csuhan/timm_0_3_2.git
 git+https://github.com/openai/CLIP.git
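On the dependency side, `huggingface_hub` is added because app.py now imports `hf_hub_download`, and the hard `torch==1.12.0` pin is relaxed so any available torch build satisfies the requirement. A quick import check (illustrative only) of the new dependency set:

```python
import torch
import huggingface_hub  # new dependency used by app.py for hf_hub_download

print(torch.__version__)            # any installed build now satisfies `torch`
print(huggingface_hub.__version__)  # confirms huggingface_hub is importable
```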