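# ChatGLM3-6B Streamlit chat demo.
# To launch the app locally (the filename below is illustrative; use
# whatever name this file is saved under):
#   streamlit run web_demo_streamlit.py
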
import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

# Set the page title, icon, and layout
st.set_page_config(
    page_title="ChatGLM3-6B Demo",
    page_icon=":robot:",
    layout="wide"
)

# Set to a Hugging Face model ID or a local folder path
model_path = "THUDM/chatglm3-6b"

# Cache the model and tokenizer so Streamlit does not reload them on every rerun
@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
    # Multi-GPU support: use the two lines below instead of the line above,
    # and set num_gpus to the number of GPUs you actually have
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
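    # Lower-memory alternatives (sketches based on the loading options
    # documented in the ChatGLM3 repository; verify against the model card
    # of the checkpoint you actually use):
    # model = AutoModel.from_pretrained(model_path, trust_remote_code=True).quantize(4).cuda()  # 4-bit quantization
    # model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float()  # CPU-only inference (slow)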
    model = model.eval()
    return tokenizer, model

# Load the ChatGLM3 model and tokenizer
tokenizer, model = get_model()

# Initialize the chat history and past key values in session state
if "history" not in st.session_state:
    st.session_state.history = []
if "past_key_values" not in st.session_state:
    st.session_state.past_key_values = None

# Sidebar sliders for max_length, top_p, and temperature
max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)

# Button to clear the conversation history
buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
    st.session_state.history = []
    st.session_state.past_key_values = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.rerun()

# Render the chat history
for i, message in enumerate(st.session_state.history):
    if message["role"] == "user":
        with st.chat_message(name="user", avatar="user"):
            st.markdown(message["content"])
    else:
        with st.chat_message(name="assistant", avatar="assistant"):
            st.markdown(message["content"])

# Placeholders for the latest user input and the model's streamed reply
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

# Get user input
prompt_text = st.chat_input("Enter your question")

# If the user entered something, generate a reply
if prompt_text:
    input_placeholder.markdown(prompt_text)
    history = st.session_state.history
    past_key_values = st.session_state.past_key_values
    # stream_chat yields progressively longer partial responses; re-rendering
    # each one produces a typewriter-style streaming effect
    for response, history, past_key_values in model.stream_chat(
        tokenizer,
        prompt_text,
        history,
        past_key_values=past_key_values,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        return_past_key_values=True,
    ):
        message_placeholder.markdown(response)
    # Persist the updated history and past key values after streaming completes
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values