Tuchuanhuhuhu commited on
Commit
432eb42
·
1 Parent(s): b5ddb7e

feat: 支持添加训练好的模型到配置文件里

Browse files
Files changed (3) hide show
  1. ChuanhuChatbot.py +1 -1
  2. modules/config.py +4 -0
  3. modules/train_func.py +13 -1
ChuanhuChatbot.py CHANGED
@@ -15,8 +15,8 @@ from modules.presets import *
15
  from modules.overwrites import *
16
  from modules.webui import *
17
  from modules.repo import *
 
18
  from modules.models.models import get_model
19
- from modules.train_func import handle_dataset_selection, handle_dataset_clear, upload_to_openai, start_training, get_training_status, add_to_models, cancel_all_jobs
20
 
21
  logging.getLogger("httpx").setLevel(logging.WARNING)
22
 
 
15
  from modules.overwrites import *
16
  from modules.webui import *
17
  from modules.repo import *
18
+ from modules.train_func import *
19
  from modules.models.models import get_model
 
20
 
21
  logging.getLogger("httpx").setLevel(logging.WARNING)
22
 
modules/config.py CHANGED
@@ -96,6 +96,10 @@ else:
96
  sensitive_id = config.get("sensitive_id", "")
97
  sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
98
 
 
 
 
 
99
 
100
  google_palm_api_key = config.get("google_palm_api_key", "")
101
  google_palm_api_key = os.environ.get(
 
96
  sensitive_id = config.get("sensitive_id", "")
97
  sensitive_id = os.environ.get("SENSITIVE_ID", sensitive_id)
98
 
99
+ # 模型配置
100
+ if "extra_models" in config:
101
+ presets.MODELS.extend(config["extra_models"])
102
+ logging.info(f"已添加额外的模型:{config['extra_models']}")
103
 
104
  google_palm_api_key = config.get("google_palm_api_key", "")
105
  google_palm_api_key = os.environ.get(
modules/train_func.py CHANGED
@@ -5,6 +5,7 @@ import traceback
5
  import openai
6
  import gradio as gr
7
  import ujson as json
 
8
 
9
  import modules.presets as presets
10
  from modules.utils import get_file_hash
@@ -112,7 +113,18 @@ def handle_dataset_clear():
112
  def add_to_models():
113
  openai.api_key = os.getenv("OPENAI_API_KEY")
114
  succeeded_jobs = [job for job in openai.FineTuningJob.list()["data"] if job["status"] == "succeeded"]
115
- presets.MODELS.extend([job["fine_tuned_model"] for job in succeeded_jobs])
 
 
 
 
 
 
 
 
 
 
 
116
  return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
117
 
118
  def cancel_all_jobs():
 
5
  import openai
6
  import gradio as gr
7
  import ujson as json
8
+ import commentjson
9
 
10
  import modules.presets as presets
11
  from modules.utils import get_file_hash
 
113
  def add_to_models():
114
  openai.api_key = os.getenv("OPENAI_API_KEY")
115
  succeeded_jobs = [job for job in openai.FineTuningJob.list()["data"] if job["status"] == "succeeded"]
116
+ extra_models = [job["fine_tuned_model"] for job in succeeded_jobs]
117
+ presets.MODELS.extend(extra_models)
118
+
119
+ with open('config.json', 'r') as f:
120
+ data = commentjson.load(f)
121
+ if 'extra_models' in data:
122
+ data['extra_models'].extend(extra_models)
123
+ else:
124
+ data['extra_models'] = extra_models
125
+ with open('config.json', 'w') as f:
126
+ commentjson.dump(data, f, indent=4)
127
+
128
  return gr.update(choices=presets.MODELS), f"成功添加了 {len(succeeded_jobs)} 个模型。"
129
 
130
  def cancel_all_jobs():