carolanderson committed on
Commit f2f3156 · 1 Parent(s): ee12bcf

edit to handle mistral prompt format

Files changed (1): app.py (+92, -24)
app.py CHANGED
@@ -5,6 +5,7 @@ from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import HuggingFaceHub
 from langchain.prompts.chat import (
+    PromptTemplate,
     ChatPromptTemplate,
     MessagesPlaceholder,
     SystemMessagePromptTemplate,
@@ -16,6 +17,9 @@ from langchain.schema import AIMessage, HumanMessage
 from openai.error import AuthenticationError
 import streamlit as st

+from langchain import verbose
+verbose = True
+

 def setup_memory():
     msgs = StreamlitChatMessageHistory(key="basic_chat_app")
@@ -26,42 +30,55 @@ def setup_memory():
     return memory


-def use_existing_chain(model, provider, temp):
+def use_existing_chain(model, provider, temp, max_tokens):
+    # TODO: consider whether prompt needs to be checked here
+    if "mistral" in model:
+        return False
     if "current_chain" in st.session_state:
         current_chain = st.session_state.current_chain
         if (current_chain.model == model) \
             and (current_chain.provider == provider) \
-            and (current_chain.temp == temp):
+            and (current_chain.temp == temp) \
+            and (current_chain.max_tokens == max_tokens):
             return True
     return False


 class CurrentChain():
-    def __init__(self, model, provider, memory, temp):
+    def __init__(self, model, provider, prompt, memory, temp, max_tokens=64):
         self.model = model
         self.provider = provider
         self.temp = temp
+        self.max_tokens = max_tokens
+
         logging.info(f"setting up new chain with params {model_name}, {provider}, {temp}")
         if provider == "OpenAI":
             llm = ChatOpenAI(model_name=model, temperature=temp)
         elif provider == "HuggingFace":
+            # TODO: expose the controls below as widgets and clean up init
             llm = HuggingFaceHub(repo_id=model,
-                                 model_kwargs={"temperature": temp, "max_length": 64})
-        prompt = ChatPromptTemplate(
-            messages=[
-                SystemMessagePromptTemplate.from_template(
-                    "You are a nice chatbot having a conversation with a human."
-                ),
-                MessagesPlaceholder(variable_name="chat_history"),
-                HumanMessagePromptTemplate.from_template("{input}")
-            ]
-        )
+                                 model_kwargs={"temperature": temp,
+                                               "max_new_tokens": 256,
+                                               "top_p": 0.95,
+                                               "repetition_penalty": 1.0,
+                                               "do_sample": True,
+                                               "seed": 42})
+
         self.conversation = LLMChain(
             llm=llm,
             prompt=prompt,
             verbose=True,
             memory=memory
         )
+
+
+def format_mistral_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt


 if __name__ == "__main__":
@@ -79,7 +96,10 @@ if __name__ == "__main__":
     model_name = st.sidebar.selectbox(
         label = "Choose a model",
         options = ["gpt-3.5-turbo (OpenAI)",
-                   "bigscience/bloom (HuggingFace)"
+                   # "bigscience/bloom (HuggingFace)", # runs
+                   # "microsoft/DialoGPT-medium (HuggingFace)", # throws error
+                   # "google/flan-t5-xxl (HuggingFace)", # runs
+                   "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
                    ],
         help="Which LLM to use",
         )
@@ -92,39 +112,87 @@ if __name__ == "__main__":
         min_value=float(0),
         max_value=1.0,
         step=0.1,
-        value=0.9,
+        value=0.4,
         help="Set the decoding temperature"
         )
+
+    max_tokens = st.sidebar.slider(
+        label="Max tokens",
+        min_value=32,
+        max_value=2048,
+        step=1,
+        value=1028,
+        help="Set the maximum number of tokens to generate"
+        ) # TODO: edit this, not currently using
     ##########################
+
     model = model_name.split("(")[0].rstrip() # remove name of model provider
     provider = model_name.split("(")[-1].split(")")[0]
+
     if "session_memory" not in st.session_state:
-        st.session_state.session_memory = setup_memory()
-
-    if use_existing_chain(model, provider, temp):
+        st.session_state.session_memory = setup_memory() # for openai
+
+    if "history" not in st.session_state:
+        st.session_state.history = [] # for mistral
+
+    if "mistral" in model:
+        prompt = PromptTemplate(input_variables=["input"],
+                                template="{input}")
+    else:
+        prompt = ChatPromptTemplate(
+            messages=[
+                SystemMessagePromptTemplate.from_template(
+                    "You are a nice chatbot having a conversation with a human."
+                ),
+                MessagesPlaceholder(variable_name="chat_history"),
+                HumanMessagePromptTemplate.from_template("{input}")
+            ],
+            verbose=True
+        )
+
+    if use_existing_chain(model, provider, temp, max_tokens):
         chain = st.session_state.current_chain
     else:
         chain = CurrentChain(model,
                              provider,
+                             prompt,
                              st.session_state.session_memory,
-                             temp)
+                             temp,
+                             max_tokens)
         st.session_state.current_chain = chain

     conversation = chain.conversation
+
     if st.button("Clear history"):
-        conversation.memory.clear()
-    for message in conversation.memory.buffer: # display chat history
-        st.chat_message(message.type).write(message.content)
+        conversation.memory.clear() # for openai
+        st.session_state.history = [] # for mistral
+        logging.info("history cleared")
+
+    for user_msg, asst_msg in st.session_state.history:
+        with st.chat_message("user"):
+            st.write(user_msg)
+        with st.chat_message("assistant"):
+            st.write(asst_msg)
+
     text = st.chat_input()
     if text:
         with st.chat_message("user"):
             st.write(text)
+        logging.info(text)
         try:
-            result = conversation.predict(input=text)
+            if "mistral" in model:
+                full_prompt = format_mistral_prompt(text, st.session_state.history)
+                result = conversation.predict(input=full_prompt)
+            else:
+                result = conversation.predict(input=text)
+
+            st.session_state.history.append((text, result))
+            logging.info(repr(result))
             with st.chat_message("assistant"):
                 st.write(result)
         except (AuthenticationError, ValueError):
-            st.warning("Enter a valid API key", icon="⚠️")
+            st.warning("Supply a valid API key", icon="⚠️")
+
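The hard-coded decoding settings can likewise be tried outside the app. A minimal sketch, assuming a valid HUGGINGFACEHUB_API_TOKEN in the environment and the same legacy langchain HuggingFaceHub wrapper that app.py imports:

from langchain.llms import HuggingFaceHub

# Values mirror the commit's model_kwargs; only temperature is exposed
# as a sidebar widget in the app (its default is now 0.4).
llm = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1",
                     model_kwargs={"temperature": 0.4,
                                   "max_new_tokens": 256,
                                   "top_p": 0.95,
                                   "repetition_penalty": 1.0,
                                   "do_sample": True,
                                   "seed": 42})
print(llm("<s>[INST] Say hello in one sentence. [/INST]"))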
 
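One way to confirm that the hand-rolled format matches what the model was trained on is to compare it against the tokenizer's bundled chat template. A sketch, assuming a recent transformers release with apply_chat_template (downloads the tokenizer on first run):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
messages = [
    {"role": "user", "content": "Hi there!"},
    {"role": "assistant", "content": "Hello! How can I help you today?"},
    {"role": "user", "content": "Tell me a joke."},
]
# Render the same conversation with the model's own template; the result
# should agree with format_mistral_prompt up to whitespace.
print(tok.apply_chat_template(messages, tokenize=False))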