Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
432eb42
1
Parent(s):
b5ddb7e
feat: 支持添加训练好的模型到配置文件里
Browse files- ChuanhuChatbot.py +1 -1
- modules/config.py +4 -0
- modules/train_func.py +13 -1
ChuanhuChatbot.py
CHANGED
@@ -15,8 +15,8 @@ from modules.presets import *
|
|
15 |
from modules.overwrites import *
|
16 |
from modules.webui import *
|
17 |
from modules.repo import *
|
|
|
18 |
from modules.models.models import get_model
|
19 |
-
from modules.train_func import handle_dataset_selection, handle_dataset_clear, upload_to_openai, start_training, get_training_status, add_to_models, cancel_all_jobs
|
20 |
|
21 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
22 |
|
|
|
15 |
from modules.overwrites import *
|
16 |
from modules.webui import *
|
17 |
from modules.repo import *
|
18 |
+
from modules.train_func import *
|
19 |
from modules.models.models import get_model
|
|
|
20 |
|
21 |
logging.getLogger("httpx").setLevel(logging.WARNING)
|
22 |
|
modules/config.py
CHANGED
@@ -96,6 +96,10 @@ else:
|
|
96 |
sensitive_id = config.get("sensitive_id", "")
|
97 |
sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
|
98 |
|
|
|
|
|
|
|
|
|
99 |
|
100 |
google_palm_api_key = config.get("google_palm_api_key", "")
|
101 |
google_palm_api_key = os.environ.get(
|
|
|
96 |
sensitive_id = config.get("sensitive_id", "")
|
97 |
sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
|
98 |
|
99 |
+
# 模型配置
|
100 |
+
if "extra_models" in config:
|
101 |
+
presets.MODELS.extend(config["extra_models"])
|
102 |
+
logging.info(f"已添加额外的模型:{config['extra_models']}")
|
103 |
|
104 |
google_palm_api_key = config.get("google_palm_api_key", "")
|
105 |
google_palm_api_key = os.environ.get(
|
modules/train_func.py
CHANGED
@@ -5,6 +5,7 @@ import traceback
|
|
5 |
import openai
|
6 |
import gradio as gr
|
7 |
import ujson as json
|
|
|
8 |
|
9 |
import modules.presets as presets
|
10 |
from modules.utils import get_file_hash
|
@@ -112,7 +113,18 @@ def handle_dataset_clear():
|
|
112 |
def add_to_models():
|
113 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
114 |
succeeded_jobs = [job for job in openai.FineTuningJob.list()["data"] if job["status"] == "succeeded"]
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
|
117 |
|
118 |
def cancel_all_jobs():
|
|
|
5 |
import openai
|
6 |
import gradio as gr
|
7 |
import ujson as json
|
8 |
+
import commentjson
|
9 |
|
10 |
import modules.presets as presets
|
11 |
from modules.utils import get_file_hash
|
|
|
113 |
def add_to_models():
|
114 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
115 |
succeeded_jobs = [job for job in openai.FineTuningJob.list()["data"] if job["status"] == "succeeded"]
|
116 |
+
extra_models = [job["fine_tuned_model"] for job in succeeded_jobs]
|
117 |
+
presets.MODELS.extend(extra_models)
|
118 |
+
|
119 |
+
with open('config.json', 'r') as f:
|
120 |
+
data = commentjson.load(f)
|
121 |
+
if 'extra_models' in data:
|
122 |
+
data['extra_models'].extend(extra_models)
|
123 |
+
else:
|
124 |
+
data['extra_models'] = extra_models
|
125 |
+
with open('config.json', 'w') as f:
|
126 |
+
commentjson.dump(data, f, indent=4)
|
127 |
+
|
128 |
return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
|
129 |
|
130 |
def cancel_all_jobs():
|