Hjgugugjhuhjggg committed on
Commit
f2e20dd
·
verified ·
1 Parent(s): 99862b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -16,11 +16,12 @@ from transformers import (
16
  StoppingCriteriaList
17
  )
18
  import boto3
19
- from huggingface_hub import hf_hub_download
20
  import soundfile as sf
21
  import numpy as np
22
  import torch
23
  import uvicorn
 
24
 
25
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
26
 
@@ -64,6 +65,7 @@ class S3ModelLoader:
64
  def __init__(self, bucket_name, s3_client):
65
  self.bucket_name = bucket_name
66
  self.s3_client = s3_client
 
67
 
68
  def _get_s3_uri(self, model_name):
69
  return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
@@ -72,9 +74,9 @@ class S3ModelLoader:
72
  s3_uri = self._get_s3_uri(model_name)
73
  try:
74
  logging.info(f"Trying to load {model_name} from S3...")
75
- config = AutoConfig.from_pretrained(s3_uri)
76
- model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
77
- tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config)
78
 
79
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
80
  tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
@@ -84,9 +86,18 @@ class S3ModelLoader:
84
  except EnvironmentError:
85
  logging.info(f"Model {model_name} not found in S3. Downloading...")
86
  try:
87
- config = AutoConfig.from_pretrained(model_name)
88
- tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
89
- model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
 
 
 
 
 
 
 
 
 
90
 
91
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
92
  tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
@@ -96,6 +107,9 @@ class S3ModelLoader:
96
  model.save_pretrained(s3_uri)
97
  tokenizer.save_pretrained(s3_uri)
98
  logging.info(f"Saved {model_name} to S3 successfully.")
 
 
 
99
  return model, tokenizer
100
  except Exception as e:
101
  logging.exception(f"Error downloading/uploading model: {e}")
@@ -122,7 +136,7 @@ async def generate(request: Request, body: GenerateRequest):
122
  top_k=validated_body.top_k,
123
  repetition_penalty=validated_body.repetition_penalty,
124
  do_sample=validated_body.do_sample,
125
- num_return_sequences=validated_body.num_return_sequences
126
  )
127
 
128
  async def stream_text():
@@ -139,7 +153,6 @@ async def generate(request: Request, body: GenerateRequest):
139
  break
140
 
141
  generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
142
-
143
  stopping_criteria = StoppingCriteriaList(
144
  [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
145
  )
 
16
  StoppingCriteriaList
17
  )
18
  import boto3
19
+ from huggingface_hub import hf_hub_download, HfApi
20
  import soundfile as sf
21
  import numpy as np
22
  import torch
23
  import uvicorn
24
+ import shutil
25
 
26
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
27
 
 
65
  def __init__(self, bucket_name, s3_client):
66
  self.bucket_name = bucket_name
67
  self.s3_client = s3_client
68
+ self.api = HfApi()
69
 
70
  def _get_s3_uri(self, model_name):
71
  return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
 
74
  s3_uri = self._get_s3_uri(model_name)
75
  try:
76
  logging.info(f"Trying to load {model_name} from S3...")
77
+ config = AutoConfig.from_pretrained(s3_uri, local_files_only=False)
78
+ model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False)
79
+ tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=False)
80
 
81
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
82
  tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
86
  except EnvironmentError:
87
  logging.info(f"Model {model_name} not found in S3. Downloading...")
88
  try:
89
+ model_info = self.api.model_info(model_name)
90
+ files_to_download = [f.rfilename for f in self.api.list_repo_files(model_name)]
91
+
92
+ temp_dir = "temp_model"
93
+ os.makedirs(temp_dir, exist_ok = True)
94
+
95
+ for file_name in files_to_download:
96
+ hf_hub_download(repo_id=model_name, filename=file_name, local_dir=temp_dir, token=HUGGINGFACE_HUB_TOKEN)
97
+
98
+ config = AutoConfig.from_pretrained(temp_dir)
99
+ tokenizer = AutoTokenizer.from_pretrained(temp_dir, config=config)
100
+ model = AutoModelForCausalLM.from_pretrained(temp_dir, config=config)
101
 
102
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
103
  tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
107
  model.save_pretrained(s3_uri)
108
  tokenizer.save_pretrained(s3_uri)
109
  logging.info(f"Saved {model_name} to S3 successfully.")
110
+
111
+ shutil.rmtree(temp_dir)
112
+
113
  return model, tokenizer
114
  except Exception as e:
115
  logging.exception(f"Error downloading/uploading model: {e}")
 
136
  top_k=validated_body.top_k,
137
  repetition_penalty=validated_body.repetition_penalty,
138
  do_sample=validated_body.do_sample,
139
+ num_return_sequences=validated_body.num_return_sequences,
140
  )
141
 
142
  async def stream_text():
 
153
  break
154
 
155
  generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
 
156
  stopping_criteria = StoppingCriteriaList(
157
  [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
158
  )