Spaces:
Sleeping
Sleeping
File size: 1,497 Bytes
4e5c5cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import os
import argparse
from huggingface_hub import login, snapshot_download
from typing import Literal
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
# Root directory under which downloaded models are written (joined with the
# model id in download()). Read from the environment; None when unset.
DATA_PATH = os.getenv('DATA_PATH')
# Hugging Face access token read from the HF_TOKEN environment variable;
# presumably only required for gated/private repositories — confirm.
HUGGINGFACE_ACCESS_TOKEN = os.getenv('HF_TOKEN')
# login(token=None) falls back to an interactive prompt, which blocks or fails
# in a non-interactive environment, so authenticate only when a token exists.
if HUGGINGFACE_ACCESS_TOKEN:
    login(token=HUGGINGFACE_ACCESS_TOKEN)
class ModelConfig(BaseModel):
    """Validated options for a single model download request."""

    # Hugging Face repository id; also used as the sub-directory name under
    # DATA_PATH where the downloaded files are written.
    model_id: str
    # 'snapshot' copies the raw repository files; 'model' loads the repository
    # through SentenceTransformer and saves it back to disk.
    mode: Literal['snapshot', 'model'] = 'model'

    # Pydantic v2 reserves the "model_" prefix and warns about fields such as
    # `model_id`; an empty tuple disables that check. A plain dict is accepted
    # as `model_config` (ConfigDict is a TypedDict) and replaces the deprecated
    # inner `class Config`.
    model_config = {'protected_namespaces': ()}
def download(config: ModelConfig):
    """Download ``config.model_id`` into ``os.path.join(DATA_PATH, model_id)``.

    Args:
        config: Validated request. When ``config.mode == 'snapshot'`` the
            repository files at revision ``main`` are fetched with
            ``snapshot_download``, skipping files matching ``*.git*`` and
            ``*README.md``. Otherwise the model is loaded with
            ``SentenceTransformer`` and written out with ``model.save``.

    Raises:
        RuntimeError: if the ``DATA_PATH`` environment variable is not set.
        Exceptions from the Hugging Face download/load calls propagate as-is.
    """
    if DATA_PATH is None:
        # Without this check os.path.join(None, ...) fails with an opaque
        # TypeError that does not mention the missing environment variable.
        raise RuntimeError('DATA_PATH environment variable is not set')
    target_dir = os.path.join(DATA_PATH, config.model_id)

    if config.mode == 'snapshot':
        snapshot_download(
            config.model_id,
            revision='main',
            ignore_patterns=['*.git*', '*README.md'],
            local_dir=target_dir,
        )
        return

    # trust_remote_code=True allows the repository to run its own Python code
    # on this machine while loading; only use it with repositories you trust.
    model = SentenceTransformer(config.model_id, trust_remote_code=True)
    model.save(target_dir)
def run(argv=None):
    """Parse command-line arguments and download the requested model.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``;
            the default ``None`` keeps the normal command-line behavior.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--input',
        help='model id to download',
        required=True,
    )
    parser.add_argument(
        '-m', '--mode',
        help='mode to download',
        default='model',
        # Mirror the allowed values of ModelConfig.mode so an invalid mode is
        # rejected with a usage message instead of a pydantic traceback.
        choices=['snapshot', 'model'],
    )
    args = parser.parse_args(argv)
    config = ModelConfig(
        model_id=args.input,
        mode=args.mode,
    )
    download(config)
# Script entry point: parse CLI arguments and perform the download.
if __name__ == '__main__':
    run()