File size: 2,673 Bytes
de4ade4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import tempfile
from datetime import datetime
from typing import List
import torch
import transformers
from llmfoundry import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
edit_files_for_hf_compatibility
def main(hf_repos_for_upload: List[str]):
    """Open Hub PRs that sync the MPT modeling code with this checkout.

    A freshly initialised ``MPTForCausalLM`` (default ``MPTConfig`` with the
    ``'torch'`` attention implementation) is saved to a temporary directory
    and post-processed by ``edit_files_for_hf_compatibility``. The saved
    directory is then loaded back through ``AutoModelForCausalLM`` with
    ``trust_remote_code=True``. For each repo, that model generates two tokens
    from a prompt encoded with the repo's own tokenizer as a smoke test.
    Only after the smoke test returns is a PR opened against the repo. The PR
    uploads just the ``*.py`` files from the saved directory (never weights).

    Args:
        hf_repos_for_upload: Hugging Face Hub model repo ids to open PRs
            against.

    Any exception raised while loading, generating, or uploading propagates
    and aborts processing of the remaining repos.
    """
    current_datetime = datetime.now()
    formatted_datetime = current_datetime.strftime('%B %d, %Y %H:%M:%S')

    from huggingface_hub import HfApi
    api = HfApi()

    # Register the MPT classes with the transformers auto classes. Presumably
    # this is what makes save_pretrained emit the custom-code files that the
    # upload below pushes (it only uploads ``*.py``) — confirm against the
    # transformers custom-model docs.
    MPTConfig.register_for_auto_class()
    MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

    config = MPTConfig()
    # NOTE(review): presumably 'torch' is chosen so the smoke test does not
    # depend on an optional fused-attention backend being installed.
    config.attn_config['attn_impl'] = 'torch'
    loaded_hf_model = MPTForCausalLM(config)

    with tempfile.TemporaryDirectory() as _tmp_dir:
        original_save_dir = os.path.join(_tmp_dir, 'model_current')
        loaded_hf_model.save_pretrained(original_save_dir)
        # Only the saved directory is used from here on. Drop the in-memory
        # model before a second copy is loaded, to keep peak memory down.
        del loaded_hf_model

        edit_files_for_hf_compatibility(original_save_dir)

        # The saved directory is the same for every repo, so load it through
        # the remote-code path once rather than once per repo.
        pr_model = transformers.AutoModelForCausalLM.from_pretrained(
            original_save_dir, trust_remote_code=True, device_map='auto')

        for repo in hf_repos_for_upload:
            print(f'Testing code changes for {repo}')
            pr_tokenizer = transformers.AutoTokenizer.from_pretrained(
                repo, trust_remote_code=True)
            input_ids = pr_tokenizer('MosaicML is',
                                     return_tensors='pt').input_ids
            # device_map='auto' decides where the model lives; send the inputs
            # to the model's device instead of guessing cuda/cpu separately.
            generation = pr_model.generate(input_ids.to(pr_model.device),
                                           max_new_tokens=2)
            # Decoding only checks that the round-trip works; text is unused.
            _ = pr_tokenizer.batch_decode(generation)

            print(f'Opening PR against {repo}')
            # NOTE(review): newer huggingface_hub releases name the auth
            # argument ``token``; ``use_auth_token`` is kept here for
            # compatibility — confirm against the pinned version.
            result = api.upload_folder(
                folder_path=original_save_dir,
                repo_id=repo,
                use_auth_token=True,
                repo_type='model',
                allow_patterns=['*.py'],
                commit_message=f'LLM-foundry update {formatted_datetime}',
                create_pr=True,
            )

            print(f'PR opened: {result}')
if __name__ == '__main__':
    # Command-line entry point: one or more Hub repo ids are required, and
    # each one receives its own PR from main().
    parser = argparse.ArgumentParser(
        description=
        'Update MPT code in HuggingFace Hub repos to be in sync with the local codebase'
    )
    parser.add_argument('--hf_repos_for_upload',
                        help='List of repos to open PRs against',
                        nargs='+',
                        required=True)
    args = parser.parse_args()
    main(hf_repos_for_upload=args.hf_repos_for_upload)
|