File size: 1,497 Bytes
4e5c5cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import argparse
from huggingface_hub import login, snapshot_download
from typing import Literal
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel

# Root directory under which downloaded models are stored; read from the
# DATA_PATH environment variable (None when the variable is unset).
DATA_PATH = os.getenv('DATA_PATH')
# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when the variable is unset).
HUGGINGFACE_ACCESS_TOKEN = os.getenv('HF_TOKEN')

# Authenticate against the Hugging Face Hub as soon as this module is loaded.
# NOTE(review): when HF_TOKEN is unset this passes token=None — confirm that
# login() behaves acceptably in that case (e.g. does not block on a prompt).
login(token=HUGGINGFACE_ACCESS_TOKEN)

class ModelConfig(BaseModel):
    """Validated description of which model to fetch and how to fetch it."""

    # Hugging Face Hub repository id of the model to download.
    model_id: str
    # 'snapshot' copies the repository files as they are; 'model' loads the
    # model through SentenceTransformer and saves the loaded model.
    mode: Literal['snapshot', 'model'] = 'model'

    # pydantic v2 reserves the "model_" prefix for its own attributes and
    # warns about fields such as `model_id`; an empty protected_namespaces
    # tuple turns that check off. `model_config` replaces the nested
    # `class Config`, which pydantic v2 deprecates.
    model_config = {'protected_namespaces': ()}

def download(config: ModelConfig):
    """Download the model described by *config* into ``DATA_PATH``.

    The result is written to ``<DATA_PATH>/<model_id>``. In 'snapshot' mode
    the repository files of the 'main' revision are fetched with
    ``snapshot_download`` (git metadata and README files are skipped). In
    any other mode the model is loaded with ``SentenceTransformer`` and then
    saved to the target directory.

    Args:
        config: Model id and download mode.

    Raises:
        RuntimeError: If the ``DATA_PATH`` environment variable is not set.

    Errors raised by the underlying download/save calls propagate unchanged.
    """
    # Fail fast with a clear message instead of an opaque TypeError from
    # os.path.join -- and before any (potentially large) download starts.
    if DATA_PATH is None:
        raise RuntimeError(
            'DATA_PATH environment variable must be set to the directory '
            'where models are stored'
        )
    target_dir = os.path.join(DATA_PATH, config.model_id)

    if config.mode == 'snapshot':
        snapshot_download(
            config.model_id,
            revision='main',
            ignore_patterns=['*.git*', '*README.md'],
            local_dir=target_dir,
        )
        return

    # NOTE(review): trust_remote_code=True presumably lets code shipped in
    # the model repository run locally -- confirm only trusted model ids are
    # passed to this script.
    model = SentenceTransformer(
        config.model_id,
        trust_remote_code=True,
    )
    model.save(target_dir)
    
def run():
    """Parse the command line and download the requested model.

    Command-line options:
        -i / --input: model id to download (required).
        -m / --mode: 'snapshot' or 'model'; defaults to 'model'.

    An unknown mode is rejected by argparse itself (usage message, exit
    status 2) rather than surfacing later as a validation traceback.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--input',
        help='model id to download',
        required=True,
    )
    parser.add_argument(
        '-m', '--mode',
        help='mode to download',
        # Keep in sync with the values accepted by ModelConfig.mode.
        choices=('snapshot', 'model'),
        default='model',
    )

    args = parser.parse_args()
    config = ModelConfig(
        model_id=args.input,
        mode=args.mode,
    )

    download(config)

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    run()