threadshare commited on
Commit
0c75406
·
verified ·
1 Parent(s): 1cc6189

先从本地加载

Browse files
Files changed (1) hide show
  1. handler.py +8 -1
handler.py CHANGED
@@ -11,7 +11,14 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
11
 
12
  class EndpointHandler:
13
  def __init__(self, path=""):
14
- self.model_name_or_path = "threadshare/Peach-9B-8k-Roleplay"
 
 
 
 
 
 
 
15
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, flash_atten=True)
16
  self.model = AutoModelForCausalLM.from_pretrained(
17
  self.model_name_or_path, torch_dtype=torch.bfloat16,
 
11
 
12
  class EndpointHandler:
13
  def __init__(self, path=""):
14
+ local_config_path = "./config.json"
15
+ remote_model_name = "threadshare/Peach-9B-8k-Roleplay"
16
+
17
+ # Check if local config file exists
18
+ if os.path.exists(local_config_path):
19
+ self.model_name_or_path = "."
20
+ else:
21
+ self.model_name_or_path = remote_model_name
22
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, flash_atten=True)
23
  self.model = AutoModelForCausalLM.from_pretrained(
24
  self.model_name_or_path, torch_dtype=torch.bfloat16,