import streamlit as st
from transformers import AutoModel, AutoTokenizer
import mdtex2html
from utils import load_model_on_gpus

st.set_page_config(page_title="ChatGLM2-6B", page_icon=":robot:")
st.header("ChatGLM2-6B")


@st.cache_resource
def get_model():
    # Cache the tokenizer and model so Streamlit does not reload them on every rerun.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    # To load the model on multiple GPUs instead:
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model = model.eval()
    return tokenizer, model


tokenizer, model = get_model()


def postprocess(chat):
    # Render Markdown/LaTeX in each (user, response) pair as HTML for display.
    # Returns a new list so the raw history in session state stays unconverted.
    return [(mdtex2html.convert(user), mdtex2html.convert(response)) for user, response in chat]


# Generation parameters live outside the input check so they can be set before the first message.
max_length = st.slider("Max Length:", 0, 32768, 8192, 1)
top_p = st.slider("Top P:", 0.0, 1.0, 0.8, 0.01)
temperature = st.slider("Temperature:", 0.0, 1.0, 0.95, 0.01)

if 'history' not in st.session_state:
    st.session_state['history'] = []
if 'past_key_values' not in st.session_state:
    st.session_state['past_key_values'] = None

user_input = st.text_area("Input:", height=200, placeholder="Ask me anything!")

if user_input:
    history = st.session_state['history']
    with st.spinner("Thinking..."):
        # ChatGLM2's stream_chat yields (response, history, past_key_values) tuples;
        # iterate to completion to get the final response and the updated KV cache.
        for response, history, past_key_values in model.stream_chat(
                tokenizer, user_input, history,
                past_key_values=st.session_state['past_key_values'],
                max_length=max_length, top_p=top_p, temperature=temperature,
                return_past_key_values=True):
            pass
    # Persist the updated conversation and cache across reruns.
    st.session_state['history'] = history
    st.session_state['past_key_values'] = past_key_values

for user, chatbot in postprocess(st.session_state['history']):
    if user:
        st.markdown(f"**Human:** {user}", unsafe_allow_html=True)
    if chatbot:
        st.markdown(f"**AI:** {chatbot}", unsafe_allow_html=True)

if st.button("Clear History"):
    st.session_state['history'] = []
    st.session_state['past_key_values'] = None
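
# Usage (a sketch; the filename "web_demo.py" is illustrative, substitute your own):
# launch the app with Streamlit rather than plain Python, e.g.
#   streamlit run web_demo.py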