先从本地加载 (Load from the local config first)
handler.py (+8 lines, −1 line)
handler.py
CHANGED
@@ -11,7 +11,14 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
|
11 |
|
12 |
class EndpointHandler:
|
13 |
def __init__(self, path=""):
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, flash_atten=True)
|
16 |
self.model = AutoModelForCausalLM.from_pretrained(
|
17 |
self.model_name_or_path, torch_dtype=torch.bfloat16,
|
|
|
11 |
|
12 |
class EndpointHandler:
|
13 |
def __init__(self, path=""):
|
14 |
+
local_config_path = "./config.json"
|
15 |
+
remote_model_name = "threadshare/Peach-9B-8k-Roleplay"
|
16 |
+
|
17 |
+
# Check if local config file exists
|
18 |
+
if os.path.exists(local_config_path):
|
19 |
+
self.model_name_or_path = "."
|
20 |
+
else:
|
21 |
+
self.model_name_or_path = remote_model_name
|
22 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True, flash_atten=True)
|
23 |
self.model = AutoModelForCausalLM.from_pretrained(
|
24 |
self.model_name_or_path, torch_dtype=torch.bfloat16,
|