gururise committed on
Commit
47dc020
·
1 Parent(s): a2f6401

complete chatbot

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -116,7 +116,9 @@ def chat(
116
  ):
117
  global model
118
  history = history or []
119
-
 
 
120
  if model == None:
121
  gc.collect()
122
  if (DEVICE == "cuda"):
@@ -126,6 +128,22 @@ def chat(
126
  if len(history) == 0:
127
  # no history, so lets reset chat state
128
  model.resetState()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  max_new_tokens = int(max_new_tokens)
131
  temperature = float(temperature)
@@ -139,15 +157,13 @@ def chat(
139
 
140
  if temperature == 0.0:
141
  temperature = 0.01
142
- if prompt == "":
143
- prompt = " "
144
 
145
  print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
146
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
147
  # Load prompt
148
- model.loadContext(newctx=prompt)
149
- generated_text = ""
150
- done = False
151
  generated_text = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
152
 
153
  generated_text = generated_text.lstrip("\n ")
@@ -230,7 +246,7 @@ chatiface = gr.Interface(
230
  gr.Slider(1, 256, value=60), # max_tokens
231
  gr.Slider(0.0, 1.0, value=0.8), # temperature
232
  gr.Slider(0.0, 1.0, value=0.85), # top_p
233
- gr.Textbox(lines=1, value="<|endoftext|>") # stop
234
  ],
235
  outputs=[gr.Chatbot(color_map=("green", "pink")),"state"],
236
  ).queue()
 
116
  ):
117
  global model
118
  history = history or []
119
+
120
+ intro = ""
121
+
122
  if model == None:
123
  gc.collect()
124
  if (DEVICE == "cuda"):
 
128
  if len(history) == 0:
129
  # no history, so lets reset chat state
130
  model.resetState()
131
+ print("reset chat state")
132
+ intro = '''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
133
+
134
+ USER: What year was the french revolution?
135
+ FRITZ: The French Revolution started in 1789, and lasted 10 years until 1799.
136
+ USER: 3+5=?
137
+ FRITZ: The answer is 8.
138
+ USER: What year did the Berlin Wall fall?
139
+ FRITZ: The Berlin wall fell in 1989 and was the start of the collapse of the iron curtain.
140
+ USER: solve for a: 9-a=2
141
+ FRITZ: The answer is a=7, because 9-7 = 2.
142
+ USER: wat is lhc
143
+ FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
144
+ USER: Do you know who I am?
145
+ FRITZ: Only if you tell me more about yourself.. what are your interests?
146
+ '''
147
 
148
  max_new_tokens = int(max_new_tokens)
149
  temperature = float(temperature)
 
157
 
158
  if temperature == 0.0:
159
  temperature = 0.01
 
 
160
 
161
  print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
162
  print(f"OUTPUT ({datetime.now()}):\n-------\n")
163
  # Load prompt
164
+ prompt = "USER: " + prompt + "\n"
165
+ model.loadContext(newctx=intro+prompt)
166
+
167
  generated_text = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
168
 
169
  generated_text = generated_text.lstrip("\n ")
 
246
  gr.Slider(1, 256, value=60), # max_tokens
247
  gr.Slider(0.0, 1.0, value=0.8), # temperature
248
  gr.Slider(0.0, 1.0, value=0.85), # top_p
249
+ gr.Textbox(lines=1, value="USER:,<|endoftext|>") # stop
250
  ],
251
  outputs=[gr.Chatbot(color_map=("green", "pink")),"state"],
252
  ).queue()