Update app.py
Browse files
app.py
CHANGED
@@ -74,9 +74,9 @@ class S3ModelLoader:
|
|
74 |
s3_uri = self._get_s3_uri(model_name)
|
75 |
try:
|
76 |
logging.info(f"Trying to load {model_name} from S3...")
|
77 |
-
config = AutoConfig.from_pretrained(s3_uri, local_files_only=
|
78 |
-
model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=
|
79 |
-
tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=
|
80 |
|
81 |
if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
|
82 |
tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
|
@@ -87,10 +87,10 @@ class S3ModelLoader:
|
|
87 |
logging.info(f"Model {model_name} not found in S3. Downloading...")
|
88 |
try:
|
89 |
model_info = self.api.model_info(model_name)
|
90 |
-
files_to_download = [f.
|
91 |
|
92 |
temp_dir = "temp_model"
|
93 |
-
os.makedirs(temp_dir, exist_ok
|
94 |
|
95 |
for file_name in files_to_download:
|
96 |
hf_hub_download(repo_id=model_name, filename=file_name, local_dir=temp_dir, token=HUGGINGFACE_HUB_TOKEN)
|
|
|
74 |
s3_uri = self._get_s3_uri(model_name)
|
75 |
try:
|
76 |
logging.info(f"Trying to load {model_name} from S3...")
|
77 |
+
config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
|
78 |
+
model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
|
79 |
+
tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
|
80 |
|
81 |
if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
|
82 |
tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
|
|
|
87 |
logging.info(f"Model {model_name} not found in S3. Downloading...")
|
88 |
try:
|
89 |
model_info = self.api.model_info(model_name)
|
90 |
+
files_to_download = [f.filename for f in self.api.list_repo_files(model_name)]
|
91 |
|
92 |
temp_dir = "temp_model"
|
93 |
+
os.makedirs(temp_dir, exist_ok=True)
|
94 |
|
95 |
for file_name in files_to_download:
|
96 |
hf_hub_download(repo_id=model_name, filename=file_name, local_dir=temp_dir, token=HUGGINGFACE_HUB_TOKEN)
|