"""Command-line helper that downloads a Hugging Face model to local storage.

The destination root is read from the ``DATA_PATH`` environment variable and
the access token from ``HF_TOKEN``. Each model is stored under
``<DATA_PATH>/<model_id>``.
"""
import argparse
import os
from typing import Literal

from huggingface_hub import login, snapshot_download
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Root directory under which models are stored; None when the variable is
# unset (validated in download()).
DATA_PATH = os.getenv('DATA_PATH')
HUGGINGFACE_ACCESS_TOKEN = os.getenv('HF_TOKEN')

# Authenticate once, at import time, so both download paths can use the token.
# NOTE(review): when HF_TOKEN is unset this passes token=None; confirm that
# huggingface_hub's handling of that case is acceptable for non-interactive
# runs.
login(token=HUGGINGFACE_ACCESS_TOKEN)


class ModelConfig(BaseModel):
    """Validated description of what to download and how."""

    # Repository id passed to snapshot_download / SentenceTransformer.
    model_id: str
    # 'snapshot' copies the repository files via snapshot_download;
    # 'model' loads through SentenceTransformer and re-saves it.
    mode: Literal['snapshot', 'model'] = 'model'

    class Config:
        # NOTE(review): presumably cleared so the ``model_id`` field name does
        # not collide with pydantic's protected "model_" namespace - confirm
        # against the installed pydantic version.
        protected_namespaces = ()


def download(config: ModelConfig) -> None:
    """Download the model described by *config* into ``DATA_PATH``.

    Args:
        config: Model id and download mode.

    Raises:
        TypeError: If the ``DATA_PATH`` environment variable is not set.
        Exception: Any error raised by ``snapshot_download``,
            ``SentenceTransformer`` or ``model.save`` propagates unchanged.
    """
    # Resolve the destination before any network work. Previously, in 'model'
    # mode, an unset DATA_PATH only surfaced (as an opaque TypeError from
    # os.path.join) after the model had already been loaded. TypeError is kept
    # so the exception type matches the earlier behaviour.
    if DATA_PATH is None:
        raise TypeError(
            'DATA_PATH environment variable is not set; '
            'cannot determine where to store the model'
        )
    target_dir = os.path.join(DATA_PATH, config.model_id)

    if config.mode == 'snapshot':
        snapshot_download(
            config.model_id,
            revision='main',
            ignore_patterns=['*.git*', '*README.md'],
            local_dir=target_dir,
        )
        return

    # NOTE(review): trust_remote_code=True presumably permits code shipped in
    # the model repository to run locally - only use with trusted model ids.
    model = SentenceTransformer(
        config.model_id,
        trust_remote_code=True,
    )
    model.save(target_dir)


def run() -> None:
    """Parse command-line arguments and download the requested model."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i',
        '--input',
        help='model id to download',
        required=True,
    )
    parser.add_argument(
        '-m',
        '--mode',
        help='mode to download',
        default='model',
        # Reject unknown modes with a usage error instead of a pydantic
        # validation traceback.
        choices=('snapshot', 'model'),
    )
    args = parser.parse_args()

    config = ModelConfig(
        model_id=args.input,
        mode=args.mode,
    )
    download(config)


if __name__ == '__main__':
    run()