Spaces:
Paused
Paused
File size: 4,745 Bytes
5bb6ad4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# import gdown
# import os
# import argparse
# def download_model(model_id, folder, filename):
# os.makedirs(folder, exist_ok=True)
# url = f"https://drive.google.com/uc?id={model_id}"
# output_path = os.path.join(folder, filename)
# print(f"Downloading model to {output_path}...")
# gdown.download(url, output_path, quiet=False)
# print("Download complete!")
# def main():
# parser = argparse.ArgumentParser(description="Download models using gdown and organize them into appropriate folders.")
# parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
# parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
# parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")
# args = parser.parse_args()
# pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
# fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
# dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"
# if args.pretrained:
# download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt")
# if args.sft:
# download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt")
# if args.dpo:
# download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt")
# if __name__ == "__main__":
# main()
# import os
# import argparse
# def download_model(model_id, folder, filename, access_token):
# os.makedirs(folder, exist_ok=True)
# output_path = os.path.join(folder, filename)
# url = f"https://www.googleapis.com/drive/v3/files/{model_id}?alt=media"
# command = f"curl -H \"Authorization: Bearer {access_token}\" {url} -o {output_path}"
# print(f"Downloading model to {output_path}...")
# os.system(command)
# print("Download complete!")
# def main():
# parser = argparse.ArgumentParser(description="Download models using Google Drive API and organize them into appropriate folders.")
# parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
# parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
# parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")
# parser.add_argument("--token", type=str, required=True, help="Google Drive API Access Token")
# args = parser.parse_args()
# pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
# fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
# dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"
# if args.pretrained:
# download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt", args.token)
# if args.sft:
# download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt", args.token)
# if args.dpo:
# download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt", args.token)
# if __name__ == "__main__":
# main()
# download_model_weight.py
import argparse
import os
import sys

from huggingface_hub import hf_hub_download, login
def download_model(repo_id, filename, cache_dir, *, min_size_bytes=1024 * 1024):
    """Download ``filename`` from the Hugging Face Hub repo ``repo_id``.

    The file is fetched with ``hf_hub_download`` into ``cache_dir`` (using the
    ``HF_TOKEN`` environment variable as the access token, if set). The
    resulting path is then sanity-checked: it must be a regular file strictly
    larger than ``min_size_bytes``. This is a cheap plausibility check that
    the result is a real weights file rather than an empty or truncated one.

    Args:
        repo_id: Hub repository id, e.g. ``"user/repo"``.
        filename: Path of the file inside the repository.
        cache_dir: Local directory used as the Hub download cache.
        min_size_bytes: Files of this size or smaller are rejected.
            Defaults to 1 MiB.

    Returns:
        The local filesystem path of the downloaded file.

    Raises:
        ValueError: If the downloaded path is not a file or is not larger
            than ``min_size_bytes``.
        Exception: Any error raised by ``hf_hub_download`` is reported on
            stderr and re-raised unchanged.
    """
    try:
        # ``resume_download`` is intentionally not passed: huggingface_hub
        # deprecated it (downloads resume automatically when possible) and
        # passing it only produces a FutureWarning ahead of its removal.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            force_download=False,
            token=os.getenv("HF_TOKEN"),
        )
        size = os.path.getsize(model_path) if os.path.isfile(model_path) else None
        if size is None or size <= min_size_bytes:
            raise ValueError(
                "Downloaded file is too small or invalid: "
                f"{model_path} (size={size} bytes, must exceed {min_size_bytes})"
            )
        return model_path
    except Exception as e:
        # Report and re-raise so callers still see the original exception.
        print(f"Download failed: {e}", file=sys.stderr)
        raise
def main(argv=None):
    """Parse command-line options and download the selected model weights.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, matching a
            normal command-line invocation.

    Returns:
        The local path of the downloaded model file.
    """
    # Single source of truth: the CLI ``choices`` are derived from this
    # mapping, so a new model type only needs a new entry here.
    model_config = {
        "pretrained": {
            "repo_id": "YuvrajSingh9886/StoryLlama",
            "filename": "snapshot_4650.pt",
            "cache_dir": "weights/pretrained",
        },
    }

    parser = argparse.ArgumentParser(description="Download models from Hugging Face Hub")
    parser.add_argument(
        "--model_type",
        choices=list(model_config),
        required=True,
        help="Type of model to download",
    )
    args = parser.parse_args(argv)

    config = model_config[args.model_type]
    os.makedirs(config["cache_dir"], exist_ok=True)
    print(f"Downloading {args.model_type} model...")
    model_path = download_model(
        config["repo_id"],
        config["filename"],
        config["cache_dir"],
    )
    print(f"Successfully downloaded to: {model_path}")
    return model_path
if __name__ == "__main__":
    # download_model() already passes HF_TOKEN to the Hub explicitly, so
    # logging in is not required for the download itself. login(token=None)
    # would fall back to an interactive token prompt, which blocks or fails
    # in non-interactive runs, so only log in when a token is actually set.
    _hf_token = os.getenv("HF_TOKEN")
    if _hf_token:
        login(token=_hf_token)
    main()
|