Not-Grim-Refer's picture
Update app.py
bb97cbe
raw
history blame
2.12 kB
import streamlit as st
from transformers import AutoModel, AutoTokenizer
import mdtex2html
from utils import load_model_on_gpus
# --- App setup -------------------------------------------------------------
# Configure the Streamlit page; set_page_config must be the first Streamlit
# command executed on a script run.
st.set_page_config(page_title="ChatGLM2-6B", page_icon=":robot:")
st.header("ChatGLM2-6B")
# Load the ChatGLM2-6B tokenizer and model from the Hugging Face Hub.
# trust_remote_code=True is required because ChatGLM2 ships custom modeling
# code; NOTE(review): this executes code from the model repo — acceptable
# only for a trusted source. The model is moved to a single GPU here.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
# Load model on multiple GPUs
#model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
# Inference mode: disables dropout and other training-only behavior.
model = model.eval()
def postprocess(chat):
    """Render every (user, response) markdown pair in *chat* to HTML in place.

    Mutates the given list via slice assignment and returns the same object,
    with each element replaced by a tuple of mdtex2html-converted strings.
    """
    chat[:] = [
        (mdtex2html.convert(prompt), mdtex2html.convert(reply))
        for prompt, reply in chat
    ]
    return chat
user_input = st.text_area("Input:", height=200, placeholder="Ask me anything!")
if user_input:
    # Pull the running conversation from session state (empty on first turn).
    history = st.session_state.get('history', [])
    # Generation controls; these render only after the user typed something,
    # matching the original layout.
    max_length = st.slider("Max Length:", 0, 32768, 8192, 1)
    top_p = st.slider("Top P:", 0.0, 1.0, 0.8, 0.01)
    temperature = st.slider("Temperature:", 0.0, 1.0, 0.95, 0.01)
    if 'past_key_values' not in st.session_state:
        st.session_state['past_key_values'] = None
    with st.spinner("Thinking..."):
        # FIX: encode to a tensor on the model's device. The original passed
        # the bare Python list from tokenizer.encode, which generate() cannot
        # consume directly.
        input_ids = tokenizer.encode(user_input, return_tensors="pt").to(model.device)
        # return_past_key_values is a ChatGLM-specific generate() kwarg that
        # lets us carry the KV cache across turns via session state.
        response = model.generate(input_ids,
                                  max_length=max_length,
                                  top_p=top_p,
                                  temperature=temperature,
                                  return_dict_in_generate=True,
                                  output_scores=True,
                                  return_past_key_values=True,
                                  past_key_values=st.session_state.past_key_values)
    st.session_state.past_key_values = response.past_key_values
    # FIX: decode generated token ids to text. The original appended the raw
    # id tensor to history, which postprocess/markdown would mangle.
    reply_text = tokenizer.decode(response.sequences[0], skip_special_tokens=True)
    history.append((user_input, reply_text))
    # FIX: persist the conversation. The original never wrote history back to
    # session state, so it was discarded on every Streamlit rerun.
    st.session_state['history'] = history
    history = postprocess(history)
    for user, chatbot in history:
        message = f"**Human:** {user}" if user else ""
        response = f"**AI:** {chatbot}" if chatbot else ""
        # Blank markdown line between the two turns so they don't run together.
        st.markdown(message + "\n\n" + response, unsafe_allow_html=True)
# Reset the conversation on demand: drop both the rendered history and the
# cached attention key/value state used to continue generation.
if st.button("Clear History"):
    for state_key, cleared_value in (('history', []), ('past_key_values', None)):
        st.session_state[state_key] = cleared_value