# StoryLlama / download_model_weight.py
# YuvrajSingh9886's picture
# Upload 12 files
# 5bb6ad4 verified
# import gdown
# import os
# import argparse
# def download_model(model_id, folder, filename):
# os.makedirs(folder, exist_ok=True)
# url = f"https://drive.google.com/uc?id={model_id}"
# output_path = os.path.join(folder, filename)
# print(f"Downloading model to {output_path}...")
# gdown.download(url, output_path, quiet=False)
# print("Download complete!")
# def main():
# parser = argparse.ArgumentParser(description="Download models using gdown and organize them into appropriate folders.")
# parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
# parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
# parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")
# args = parser.parse_args()
# pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
# fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
# dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"
# if args.pretrained:
# download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt")
# if args.sft:
# download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt")
# if args.dpo:
# download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt")
# if __name__ == "__main__":
# main()
# import os
# import argparse
# def download_model(model_id, folder, filename, access_token):
# os.makedirs(folder, exist_ok=True)
# output_path = os.path.join(folder, filename)
# url = f"https://www.googleapis.com/drive/v3/files/{model_id}?alt=media"
# command = f"curl -H \"Authorization: Bearer {access_token}\" {url} -o {output_path}"
# print(f"Downloading model to {output_path}...")
# os.system(command)
# print("Download complete!")
# def main():
# parser = argparse.ArgumentParser(description="Download models using Google Drive API and organize them into appropriate folders.")
# parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model")
# parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model")
# parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model")
# parser.add_argument("--token", type=str, required=True, help="Google Drive API Access Token")
# args = parser.parse_args()
# pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98"
# fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf"
# dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh"
# if args.pretrained:
# download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt", args.token)
# if args.sft:
# download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt", args.token)
# if args.dpo:
# download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt", args.token)
# if __name__ == "__main__":
# main()
# download_model_weight.py
import os
import argparse
from huggingface_hub import hf_hub_download, login
def download_model(repo_id, filename, cache_dir):
    """Download one file from a Hugging Face Hub repository and sanity-check it.

    The access token is read from the ``HF_TOKEN`` environment variable and
    forwarded to ``hf_hub_download`` (``None`` when the variable is unset).

    Args:
        repo_id: Hub repository identifier, e.g. ``"user/repo"``.
        filename: Name of the file to fetch from the repository.
        cache_dir: Directory passed to ``hf_hub_download`` as ``cache_dir``.

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        ValueError: If the returned path does not exist or the file is not
            larger than 1 MiB.
        Exception: Any error raised while downloading is printed and then
            re-raised unchanged.
    """
    # Files of 1 MiB or less are treated as a failed/invalid download.
    min_valid_bytes = 1024 * 1024
    try:
        # ``resume_download`` is deliberately not passed: huggingface_hub has
        # deprecated it (interrupted downloads are resumed by default), so
        # passing it only triggers a FutureWarning on current releases.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            force_download=False,
            token=os.getenv("HF_TOKEN"),
        )
        if os.path.exists(model_path) and os.path.getsize(model_path) > min_valid_bytes:
            return model_path
        # Raised inside the ``try`` on purpose so it is reported by the same
        # handler as download errors.
        raise ValueError("Downloaded file is too small or invalid")
    except Exception as e:
        print(f"Download failed: {str(e)}")
        raise
def main():
    """Command-line entry point: download the model selected by ``--model_type``.

    ``--model_type`` is required and must be one of the configured model
    types (currently only ``pretrained``); argparse rejects anything else.
    The target cache directory is created if needed, the checkpoint is
    fetched through ``download_model``, and the resulting path is printed.
    """
    # Per-model download settings, keyed by the value of --model_type.
    model_config = {
        "pretrained": {
            "repo_id": "YuvrajSingh9886/StoryLlama",
            "filename": "snapshot_4650.pt",
            "cache_dir": "weights/pretrained",
        },
    }

    parser = argparse.ArgumentParser(
        description="Download models from Hugging Face Hub"
    )
    parser.add_argument(
        "--model_type",
        choices=list(model_config),
        required=True,
        help="Type of model to download",
    )
    model_type = parser.parse_args().model_type

    selected = model_config[model_type]
    target_dir = selected["cache_dir"]
    os.makedirs(target_dir, exist_ok=True)

    print(f"Downloading {model_type} model...")
    model_path = download_model(
        selected["repo_id"],
        selected["filename"],
        target_dir,
    )
    print(f"Successfully downloaded to: {model_path}")
if __name__ == "__main__":
    # Authenticate only when HF_TOKEN is set. Calling login() with token=None
    # falls back to an interactive token prompt, which blocks or fails in
    # unattended runs (CI, containers).
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    main()