# embedding/scripts/hf_model_download.py
# nam pham
# feat: first update
# 4e5c5cb
import os
import argparse
from huggingface_hub import login, snapshot_download
from typing import Literal
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
# Root directory that downloads are written under; each model is stored in
# DATA_PATH/<model_id>. Read from the environment, so it is None when unset.
DATA_PATH = os.getenv('DATA_PATH')
# Hugging Face access token taken from HF_TOKEN; None when unset.
HUGGINGFACE_ACCESS_TOKEN = os.getenv('HF_TOKEN')
# Only authenticate when a token was actually provided: huggingface_hub's
# login() with token=None falls back to an interactive prompt, which would
# block this non-interactive script at import time.
if HUGGINGFACE_ACCESS_TOKEN:
    login(token=HUGGINGFACE_ACCESS_TOKEN)
class ModelConfig(BaseModel):
    """Validated description of which model to download and how.

    Values are checked by pydantic when an instance is constructed; an
    unsupported ``mode`` is rejected with a validation error.
    """

    # Hub repository id; also used as the sub-directory name under DATA_PATH.
    model_id: str
    # 'snapshot' -> fetch the raw repository files with snapshot_download.
    # 'model'    -> load via SentenceTransformer and re-save it (the default).
    mode: Literal['snapshot', 'model'] = 'model'

    class Config:
        # Pydantic v2 reserves the "model_" attribute prefix; clearing the
        # protected namespaces allows the `model_id` field to be declared.
        protected_namespaces = ()
def download(config: ModelConfig) -> None:
    """Download the model described by *config* into DATA_PATH/<model_id>.

    In 'snapshot' mode the repository files at revision 'main' are fetched
    directly, skipping git metadata and README files. In any other mode
    ('model') the model is loaded through SentenceTransformer and then
    saved to the target directory.

    Args:
        config: Validated download settings.

    Raises:
        RuntimeError: if the DATA_PATH environment variable is not set.
        Exception: errors from huggingface_hub / sentence_transformers
            propagate unchanged.
    """
    if DATA_PATH is None:
        # Fail early with an actionable message instead of the opaque
        # TypeError that os.path.join(None, ...) would raise.
        raise RuntimeError('DATA_PATH environment variable is not set')

    target_dir = os.path.join(DATA_PATH, config.model_id)

    if config.mode == 'snapshot':
        snapshot_download(
            config.model_id,
            revision='main',
            ignore_patterns=['*.git*', '*README.md'],
            local_dir=target_dir,
        )
    else:
        # trust_remote_code=True allows executing code shipped in the model
        # repository; only use this path with model ids you trust.
        model = SentenceTransformer(
            config.model_id,
            trust_remote_code=True,
        )
        model.save(target_dir)
def run(argv=None):
    """Parse command-line arguments and download the requested model.

    Args:
        argv: Argument list to parse. When None (the default), argparse
            reads ``sys.argv[1:]``, so plain ``run()`` behaves as before.

    argparse exits with status 2 when ``--input`` is missing or ``--mode``
    is not one of the supported values.
    """
    parser = argparse.ArgumentParser(
        description='Download a Hugging Face model into DATA_PATH/<model_id>.',
    )
    parser.add_argument(
        '-i', '--input',
        help='model id to download',
        required=True,
    )
    parser.add_argument(
        '-m', '--mode',
        help='mode to download',
        # Validate at the CLI boundary so a typo yields a usage message
        # rather than a pydantic ValidationError traceback.
        choices=['snapshot', 'model'],
        default='model',
    )
    args = parser.parse_args(argv)
    config = ModelConfig(
        model_id=args.input,
        mode=args.mode,
    )
    download(config)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    run()