Spaces:
Build error
Build error
complete chatbot
Browse files
app.py
CHANGED
@@ -116,7 +116,9 @@ def chat(
|
|
116 |
):
|
117 |
global model
|
118 |
history = history or []
|
119 |
-
|
|
|
|
|
120 |
if model == None:
|
121 |
gc.collect()
|
122 |
if (DEVICE == "cuda"):
|
@@ -126,6 +128,22 @@ def chat(
|
|
126 |
if len(history) == 0:
|
127 |
# no history, so lets reset chat state
|
128 |
model.resetState()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
max_new_tokens = int(max_new_tokens)
|
131 |
temperature = float(temperature)
|
@@ -139,15 +157,13 @@ def chat(
|
|
139 |
|
140 |
if temperature == 0.0:
|
141 |
temperature = 0.01
|
142 |
-
if prompt == "":
|
143 |
-
prompt = " "
|
144 |
|
145 |
print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
|
146 |
print(f"OUTPUT ({datetime.now()}):\n-------\n")
|
147 |
# Load prompt
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
generated_text = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
|
152 |
|
153 |
generated_text = generated_text.lstrip("\n ")
|
@@ -230,7 +246,7 @@ chatiface = gr.Interface(
|
|
230 |
gr.Slider(1, 256, value=60), # max_tokens
|
231 |
gr.Slider(0.0, 1.0, value=0.8), # temperature
|
232 |
gr.Slider(0.0, 1.0, value=0.85), # top_p
|
233 |
-
gr.Textbox(lines=1, value="
|
234 |
],
|
235 |
outputs=[gr.Chatbot(color_map=("green", "pink")),"state"],
|
236 |
).queue()
|
|
|
116 |
):
|
117 |
global model
|
118 |
history = history or []
|
119 |
+
|
120 |
+
intro = ""
|
121 |
+
|
122 |
if model == None:
|
123 |
gc.collect()
|
124 |
if (DEVICE == "cuda"):
|
|
|
128 |
if len(history) == 0:
|
129 |
# no history, so lets reset chat state
|
130 |
model.resetState()
|
131 |
+
print("reset chat state")
|
132 |
+
intro = '''The following is a verbose and detailed conversation between an AI assistant called FRITZ, and a human user called USER. FRITZ is intelligent, knowledgeable, wise and polite.
|
133 |
+
|
134 |
+
USER: What year was the french revolution?
|
135 |
+
FRITZ: The French Revolution started in 1789, and lasted 10 years until 1799.
|
136 |
+
USER: 3+5=?
|
137 |
+
FRITZ: The answer is 8.
|
138 |
+
USER: What year did the Berlin Wall fall?
|
139 |
+
FRITZ: The Berlin wall fell in 1989 and was the start of the collapse of the iron curtain.
|
140 |
+
USER: solve for a: 9-a=2
|
141 |
+
FRITZ: The answer is a=7, because 9-7 = 2.
|
142 |
+
USER: wat is lhc
|
143 |
+
FRITZ: The Large Hadron Collider (LHC) is a high-energy particle collider, built by CERN, and completed in 2008. It was used to confirm the existence of the Higgs boson in 2012.
|
144 |
+
USER: Do you know who I am?
|
145 |
+
FRITZ: Only if you tell me more about yourself.. what are your interests?
|
146 |
+
'''
|
147 |
|
148 |
max_new_tokens = int(max_new_tokens)
|
149 |
temperature = float(temperature)
|
|
|
157 |
|
158 |
if temperature == 0.0:
|
159 |
temperature = 0.01
|
|
|
|
|
160 |
|
161 |
print(f"CHAT ({datetime.now()}):\n-------\n{prompt}")
|
162 |
print(f"OUTPUT ({datetime.now()}):\n-------\n")
|
163 |
# Load prompt
|
164 |
+
prompt = "USER: " + prompt + "\n"
|
165 |
+
model.loadContext(newctx=intro+prompt)
|
166 |
+
|
167 |
generated_text = model.forward(number=max_new_tokens, stopStrings=stop,temp=temperature,top_p_usual=top_p)["output"]
|
168 |
|
169 |
generated_text = generated_text.lstrip("\n ")
|
|
|
246 |
gr.Slider(1, 256, value=60), # max_tokens
|
247 |
gr.Slider(0.0, 1.0, value=0.8), # temperature
|
248 |
gr.Slider(0.0, 1.0, value=0.85), # top_p
|
249 |
+
gr.Textbox(lines=1, value="USER:,<|endoftext|>") # stop
|
250 |
],
|
251 |
outputs=[gr.Chatbot(color_map=("green", "pink")),"state"],
|
252 |
).queue()
|