|
import streamlit as st |
|
from transformers import AutoModel, AutoTokenizer |
|
import mdtex2html |
|
from utils import load_model_on_gpus |
|
|
|
st.set_page_config(page_title="ChatGLM2-6B", page_icon=":robot:")

st.header("ChatGLM2-6B")


@st.cache_resource
def _get_model():
    """Load the ChatGLM2-6B tokenizer and model once per server process.

    Streamlit re-executes this script on every widget interaction; without
    ``st.cache_resource`` the 6B-parameter checkpoint would be reloaded onto
    the GPU on every rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    # Multi-GPU alternative using the helper imported at the top of the file:
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    return tokenizer, model.eval()


tokenizer, model = _get_model()
|
|
|
def postprocess(chat):
    """Render each (user, response) pair's Markdown/LaTeX as HTML.

    Conversion happens in place: *chat* itself is rewritten and then
    returned for caller convenience.
    """
    chat[:] = [(mdtex2html.convert(user_msg), mdtex2html.convert(bot_msg))
               for user_msg, bot_msg in chat]
    return chat
|
|
|
user_input = st.text_area("Input:", height=200, placeholder="Ask me anything!")

# Generation controls live outside the `if` so they are visible (and keep
# their values) before the first prompt is submitted.
max_length = st.slider("Max Length:", 0, 32768, 8192, 1)
top_p = st.slider("Top P:", 0.0, 1.0, 0.8, 0.01)
temperature = st.slider("Temperature:", 0.0, 1.0, 0.95, 0.01)

if user_input:
    history = st.session_state.get('history', [])

    if 'past_key_values' not in st.session_state:
        st.session_state['past_key_values'] = None

    with st.spinner("Thinking..."):
        # encode(..., return_tensors="pt") yields the batched LongTensor that
        # generate() expects (a bare Python list of ids would fail); move it
        # to the model's device.
        input_ids = tokenizer.encode(user_input, return_tensors="pt").to(model.device)
        # return_past_key_values is a ChatGLM-specific generate kwarg
        # (trust_remote_code model) — presumably returns the KV cache for
        # incremental decoding; confirm against the model card.
        output = model.generate(input_ids,
                                max_length=max_length,
                                top_p=top_p,
                                temperature=temperature,
                                return_dict_in_generate=True,
                                output_scores=True,
                                return_past_key_values=True,
                                past_key_values=st.session_state.past_key_values)

    st.session_state.past_key_values = output.past_key_values

    # Decode the generated token ids to text; output.sequences[0] is a raw
    # id tensor and is not human-readable as-is.
    reply = tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    history.append((user_input, reply))
    # Persist the raw (un-converted) transcript so the conversation survives
    # Streamlit reruns; previously it was read but never written back, so
    # every turn was lost.
    st.session_state['history'] = history

    # Render an HTML-converted copy. postprocess() mutates its argument, so
    # pass a shallow copy to keep the stored history as plain markdown.
    for user, bot in postprocess(list(history)):
        human_md = f"**Human:** {user}" if user else ""
        ai_md = f"**AI:** {bot}" if bot else ""
        st.markdown(human_md + ai_md, unsafe_allow_html=True)
|
|
|
# Reset the conversation: dropping both the transcript and the cached
# attention state makes the next prompt start from scratch.
if st.button("Clear History"):
    for key, blank in (('history', []), ('past_key_values', None)):
        st.session_state[key] = blank