Not-Grim-Refer committed on
Commit bb97cbe
1 Parent(s): 81bc518

Update app.py

Files changed (1)
  1. app.py +53 -73
app.py CHANGED
@@ -1,74 +1,54 @@
  import streamlit as st
- from queue import Queue
- from langchain import HuggingFaceHub, PromptTemplate, LLMChain
-
- # Set the title of the Streamlit app
- st.title("Falcon QA Bot")
-
- # Get the Hugging Face Hub API token from Streamlit secrets
- huggingfacehub_api_token = st.secrets["hf_token"]
-
- # Set the repository ID for the Falcon model
- repo_id = "tiiuae/falcon-7b-instruct"
-
- # Initialize the Hugging Face Hub and LLMChain
- llm = HuggingFaceHub(
-     huggingfacehub_api_token=huggingfacehub_api_token,
-     repo_id=repo_id,
-     model_kwargs={"temperature": 0.2, "max_new_tokens": 2000}
- )
-
- # Define the template for the assistant's response
- template = """
- You are an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
-
- {question}
- """
-
- # Create a queue to store user questions
- queue = Queue()
-
- def chat(query):
-     """
-     Generates a response to the user's question using the LLMChain model.
-
-     :param query: User's question.
-     :type query: str
-     :return: Response to the user's question.
-     :rtype: str
-     """
-     # Create a prompt template with the question variable
-     prompt = PromptTemplate(template=template, input_variables=["question"])
-
-     # Create an LLMChain instance with the prompt and the Falcon model
-     llm_chain = LLMChain(prompt=prompt, verbose=True, llm=llm)
-
-     # Generate a response to the user's question
-     result = llm_chain.predict(question=query)
-
-     return result
-
- def main():
-     """
-     Main function for the Streamlit app.
-     """
-     # Get the user's question from the input text box
-     user_question = st.text_input("What do you want to ask about", placeholder="Input your question here")
-
-     if user_question:
-         # Add the user's question to the queue
-         queue.put(user_question)
-
-         # Check if there are any waiting users
-         if not queue.empty():
-             # Get the next user's question from the queue
-             query = queue.get()
-
-             # Generate a response to the user's question
-             response = chat(query)
-
-             # Display the response to the user
-             st.write(response, unsafe_allow_html=True)
-
- if __name__ == '__main__':
-     main()
+ from transformers import AutoModel, AutoTokenizer
+ import mdtex2html
+ from utils import load_model_on_gpus
+
+ st.set_page_config(page_title="ChatGLM2-6B", page_icon=":robot:")
+
+ st.header("ChatGLM2-6B")
+
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+ model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
+ # Load the model on multiple GPUs instead of a single .cuda() device:
+ # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
+ model = model.eval()
+
+ def postprocess(chat):
+     # Convert Markdown/LaTeX in each (user, response) turn to HTML for display
+     for i, (user, response) in enumerate(chat):
+         chat[i] = (mdtex2html.convert(user), mdtex2html.convert(response))
+     return chat
+
+ user_input = st.text_area("Input:", height=200, placeholder="Ask me anything!")
+ if user_input:
+     history = st.session_state.get('history', [])
+
+     max_length = st.slider("Max Length:", 0, 32768, 8192, 1)
+     top_p = st.slider("Top P:", 0.0, 1.0, 0.8, 0.01)
+     temperature = st.slider("Temperature:", 0.0, 1.0, 0.95, 0.01)
+
+     if 'past_key_values' not in st.session_state:
+         st.session_state['past_key_values'] = None
+
+     with st.spinner("Thinking..."):
+         # stream_chat yields partial results; the last one is the final answer.
+         # Cached past_key_values are reused across turns to speed up generation.
+         for response, history, past_key_values in model.stream_chat(
+                 tokenizer, user_input, history,
+                 past_key_values=st.session_state.past_key_values,
+                 return_past_key_values=True,
+                 max_length=max_length,
+                 top_p=top_p,
+                 temperature=temperature):
+             pass
+
+     st.session_state['past_key_values'] = past_key_values
+     # stream_chat already appends (user_input, response) to history
+     st.session_state['history'] = history
+
+     # Render a converted copy so the raw history in session state stays untouched
+     display_history = postprocess([(user, response) for user, response in history])
+     for user, chatbot in display_history:
+         if user:
+             st.markdown(f"**Human:** {user}", unsafe_allow_html=True)
+         if chatbot:
+             st.markdown(f"**AI:** {chatbot}", unsafe_allow_html=True)
+
+ if st.button("Clear History"):
+     st.session_state['history'] = []
+     st.session_state['past_key_values'] = None
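
For multi-GPU deployments, the commented-out load_model_on_gpus line can replace the single-device .cuda() call. A minimal sketch, assuming utils.py is the helper shipped with the official THUDM/ChatGLM2-6B repository (the exact signature may differ in other forks):

from utils import load_model_on_gpus  # assumed helper from the ChatGLM2-6B repo

# Split the model's transformer layers across two GPUs instead of calling .cuda()
model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
model = model.eval()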