sci-m-wang commited on
Commit
7d9d04e
·
verified ·
1 Parent(s): e07b128

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +45 -50
src/streamlit_app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  from openai import OpenAI
3
  import time
4
  import json # 导入json库
 
5
 
6
  # =======================================================================
7
  # 1. 导入您的类和数据
@@ -11,7 +12,7 @@ import json # 导入json库
11
  from ms_patient import MsPatient
12
 
13
  # 从本地JSON文件加载数据集的函数
14
- def load_data_from_json(filepath="/app/src/CPsyCounS-3134.json"):
15
  """
16
  从本地的JSON文件加载数据集。
17
  请确保您已将数据集文件上传到与此应用相同的目录中。
@@ -31,7 +32,6 @@ def load_data_from_json(filepath="/app/src/CPsyCounS-3134.json"):
31
  return []
32
 
33
  # 加载数据
34
- # 注意:现在每次脚本重新运行时都会从本地JSON文件加载
35
  ALL_PATIENTS = load_data_from_json()
36
 
37
 
@@ -79,49 +79,40 @@ st.markdown("""
79
  """, unsafe_allow_html=True)
80
 
81
 
82
- # --- 初始化 Session State ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  if "patient_agent" not in st.session_state:
84
  st.session_state.patient_agent = None
85
  if "messages" not in st.session_state:
86
  st.session_state.messages = []
87
  if "selected_patient_id" not in st.session_state:
88
  st.session_state.selected_patient_id = None
89
- if "openai_client" not in st.session_state:
90
- st.session_state.openai_client = None
91
- if "model_name" not in st.session_state:
92
- st.session_state.model_name = "gpt-4o-mini" # 默认模型
93
 
94
  # --- 侧边栏 ---
95
  with st.sidebar:
96
  st.title("👩 AnnaAgent 设置")
97
  st.markdown("---")
98
 
99
- # API Key 输入
100
- with st.expander("🔑 API 设置", expanded=True):
101
- api_key = st.text_input("输入您的 OpenAI API Key", type="password", help="您的API Key将仅用于本次会话,不会被储存。")
102
- base_url = st.text_input("API Base URL (可选)", value="https://api.openai.com/v1")
103
- model_name = st.text_input("模型名称", value=st.session_state.model_name)
104
-
105
- if st.button("连接模型"):
106
- if api_key:
107
- try:
108
- st.session_state.openai_client = OpenAI(api_key=api_key, base_url=base_url)
109
- st.session_state.model_name = model_name
110
- st.success("连接成功!")
111
- if st.session_state.patient_agent:
112
- st.session_state.patient_agent.client = st.session_state.openai_client
113
- except Exception as e:
114
- st.error(f"连接失败: {e}")
115
- else:
116
- st.warning("请输入API Key。")
117
-
118
- st.markdown("---")
119
-
120
  # 病人选择
121
  if not ALL_PATIENTS:
122
  st.error("无法加载病人数据。请检查JSON文件是否已上传且格式正确。")
123
  else:
124
- patient_options = {p["id"]: f"{p['portrait']['gender']},{p['portrait']['age']}岁 - {p['portrait']['symptoms']}" for p in ALL_PATIENTS}
125
  selected_id = st.selectbox(
126
  "选择一位病人进行对话",
127
  options=list(patient_options.keys()),
@@ -162,24 +153,28 @@ with st.sidebar:
162
  st.title("💬 与 Anna 对话")
163
  st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。")
164
 
165
- # 显示聊天记录
166
- for message in st.session_state.messages:
167
- avatar = "👩" if message["role"] == "assistant" else "🧑‍⚕️"
168
- with st.chat_message(message["role"], avatar=avatar):
169
- st.markdown(message["content"])
170
-
171
- # 获取用户输入
172
- if prompt := st.chat_input("请输入您想说的话..."):
173
- st.session_state.messages.append({"role": "user", "content": prompt})
174
- with st.chat_message("user", avatar="🧑‍⚕️"):
175
- st.markdown(prompt)
176
-
177
- if st.session_state.patient_agent:
178
- with st.chat_message("assistant", avatar="👩"):
179
- with st.spinner("Anna正在思考..."):
180
- response = st.session_state.patient_agent.chat(prompt)
181
- st.markdown(response)
182
-
183
- st.session_state.messages.append({"role": "assistant", "content": response})
184
- else:
185
- st.warning("请先在左侧选择一位病人并配置API Key。")
 
 
 
 
 
2
  from openai import OpenAI
3
  import time
4
  import json # 导入json库
5
+ import os # 导入os库用于读取环境变量
6
 
7
  # =======================================================================
8
  # 1. 导入您的类和数据
 
12
  from ms_patient import MsPatient
13
 
14
  # 从本地JSON文件加载数据集的函数
15
+ def load_data_from_json(filepath="Anna-CPsyCounD.json"):
16
  """
17
  从本地的JSON文件加载数据集。
18
  请确保您已将数据集文件上传到与此应用相同的目录中。
 
32
  return []
33
 
34
  # 加载数据
 
35
  ALL_PATIENTS = load_data_from_json()
36
 
37
 
 
79
  """, unsafe_allow_html=True)
80
 
81
 
82
+ # --- 初始化 Session State 和 OpenAI Client ---
83
+ # 仅在会话状态中不存在时,才从环境变量初始化客户端
84
+ if "openai_client" not in st.session_state:
85
+ api_key = os.getenv("OPENAI_API_KEY")
86
+ if not api_key:
87
+ st.session_state.openai_client = None
88
+ st.session_state.model_name = None
89
+ else:
90
+ try:
91
+ st.session_state.openai_client = OpenAI(api_key=api_key, base_url=os.getenv("OPENAI_BASE_URL"))
92
+ st.session_state.model_name = os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini")
93
+ except Exception as e:
94
+ st.error(f"初始化OpenAI客户端失败: {e}")
95
+ st.session_state.openai_client = None
96
+ st.session_state.model_name = None
97
+
98
  if "patient_agent" not in st.session_state:
99
  st.session_state.patient_agent = None
100
  if "messages" not in st.session_state:
101
  st.session_state.messages = []
102
  if "selected_patient_id" not in st.session_state:
103
  st.session_state.selected_patient_id = None
104
+
 
 
 
105
 
106
  # --- 侧边栏 ---
107
  with st.sidebar:
108
  st.title("👩 AnnaAgent 设置")
109
  st.markdown("---")
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # 病人选择
112
  if not ALL_PATIENTS:
113
  st.error("无法加载病人数据。请检查JSON文件是否已上传且格式正确。")
114
  else:
115
+ patient_options = {p["id"]: f"{p['portrait']['gender']},{p['portrait']['age']}岁 - {p['portrait']['symptom']}" for p in ALL_PATIENTS}
116
  selected_id = st.selectbox(
117
  "选择一位病人进行对话",
118
  options=list(patient_options.keys()),
 
153
  st.title("💬 与 Anna 对话")
154
  st.caption("这是一个模拟心理咨询的AI Agent。由 `MsPatient` 类驱动。")
155
 
156
+ # 检查API Key是否已在后台设置
157
+ if not st.session_state.openai_client:
158
+ st.error("后台未设置 OPENAI_API_KEY。请在Hugging Face Space的'Settings' -> 'Secrets'中进行设置后刷新页面。")
159
+ else:
160
+ # 显示聊天记录
161
+ for message in st.session_state.messages:
162
+ avatar = "👩" if message["role"] == "assistant" else "🧑‍⚕️"
163
+ with st.chat_message(message["role"], avatar=avatar):
164
+ st.markdown(message["content"])
165
+
166
+ # 获取用户输入
167
+ if prompt := st.chat_input("请输入您想说的话..."):
168
+ st.session_state.messages.append({"role": "user", "content": prompt})
169
+ with st.chat_message("user", avatar="🧑‍⚕️"):
170
+ st.markdown(prompt)
171
+
172
+ if st.session_state.patient_agent:
173
+ with st.chat_message("assistant", avatar="👩"):
174
+ with st.spinner("Anna正在思考..."):
175
+ response = st.session_state.patient_agent.chat(prompt)
176
+ st.markdown(response)
177
+
178
+ st.session_state.messages.append({"role": "assistant", "content": response})
179
+ else:
180
+ st.warning("请先在左侧选择一位病人。")