|
|
|
|
|
|
|
import argparse |
|
import os |
|
import tempfile |
|
from datetime import datetime |
|
from typing import List |
|
|
|
import torch |
|
import transformers |
|
|
|
from llmfoundry import MPTConfig, MPTForCausalLM |
|
from llmfoundry.utils.huggingface_hub_utils import \ |
|
edit_files_for_hf_compatibility |
|
|
|
|
|
def main(hf_repos_for_upload: List[str]): |
|
current_datetime = datetime.now() |
|
formatted_datetime = current_datetime.strftime('%B %d, %Y %H:%M:%S') |
|
|
|
from huggingface_hub import HfApi |
|
api = HfApi() |
|
|
|
|
|
MPTConfig.register_for_auto_class() |
|
|
|
|
|
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM') |
|
|
|
config = MPTConfig() |
|
config.attn_config['attn_impl'] = 'torch' |
|
loaded_hf_model = MPTForCausalLM(config) |
|
with tempfile.TemporaryDirectory() as _tmp_dir: |
|
original_save_dir = os.path.join(_tmp_dir, 'model_current') |
|
loaded_hf_model.save_pretrained(original_save_dir) |
|
|
|
edit_files_for_hf_compatibility(original_save_dir) |
|
|
|
for repo in hf_repos_for_upload: |
|
print(f'Testing code changes for {repo}') |
|
pr_model = transformers.AutoModelForCausalLM.from_pretrained( |
|
original_save_dir, trust_remote_code=True, device_map='auto') |
|
pr_tokenizer = transformers.AutoTokenizer.from_pretrained( |
|
repo, trust_remote_code=True) |
|
|
|
generation = pr_model.generate(pr_tokenizer( |
|
'MosaicML is', return_tensors='pt').input_ids.to( |
|
'cuda' if torch.cuda.is_available() else 'cpu'), |
|
max_new_tokens=2) |
|
_ = pr_tokenizer.batch_decode(generation) |
|
|
|
print(f'Opening PR against {repo}') |
|
result = api.upload_folder( |
|
folder_path=original_save_dir, |
|
repo_id=repo, |
|
use_auth_token=True, |
|
repo_type='model', |
|
allow_patterns=['*.py'], |
|
commit_message=f'LLM-foundry update {formatted_datetime}', |
|
create_pr=True, |
|
) |
|
|
|
print(f'PR opened: {result}') |
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser( |
|
description= |
|
'Update MPT code in HuggingFace Hub repos to be in sync with the local codebase' |
|
) |
|
parser.add_argument('--hf_repos_for_upload', |
|
help='List of repos to open PRs against', |
|
nargs='+', |
|
required=True) |
|
|
|
args = parser.parse_args() |
|
|
|
main(hf_repos_for_upload=args.hf_repos_for_upload,) |
|
|