Update app.py
app.py CHANGED
@@ -136,27 +136,77 @@
 
 
 # new code
-import os
-from langchain_openai import ChatOpenAI
-from langchain.schema import AIMessage, HumanMessage
-import openai
-import gradio as gr
-
-
-os.environ["OPENAI_API_KEY"] = "sk-proj-..."  # Replace with your key
-
-llm = ChatOpenAI(temperature=1.0, model='gpt-3.5-turbo')
+# import os
+# from langchain_openai import ChatOpenAI
+# from langchain.schema import AIMessage, HumanMessage
+# import openai
+# import gradio as gr
+
+
+# os.environ["OPENAI_API_KEY"] = "sk-proj-..."  # Replace with your key
+
+# llm = ChatOpenAI(temperature=1.0, model='gpt-3.5-turbo')
+
+# def predict(message, history):
+#     history_langchain_format = []
+#     for msg in history:
+#         if msg['role'] == "user":
+#             history_langchain_format.append(HumanMessage(content=msg['content']))
+#         elif msg['role'] == "assistant":
+#             history_langchain_format.append(AIMessage(content=msg['content']))
+#     history_langchain_format.append(HumanMessage(content=message))
+#     gpt_response = llm(history_langchain_format)
+#     return gpt_response.content
+
+# gr.ChatInterface(predict).launch()
+
+
+
+import gradio as gr
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from threading import Thread
+
+tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
+model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
+model = model.to('cuda:0')
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [29, 0]
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
 
 def predict(message, history):
-    history_langchain_format = []
-    for msg in history:
-        if msg['role'] == "user":
-            history_langchain_format.append(HumanMessage(content=msg['content']))
-        elif msg['role'] == "assistant":
-            history_langchain_format.append(AIMessage(content=msg['content']))
-    history_langchain_format.append(HumanMessage(content=message))
-    gpt_response = llm(history_langchain_format)
-    return gpt_response.content
+    history_transformer_format = list(zip(history[:-1], history[1:])) + [[message, ""]]
+    stop = StopOnTokens()
+
+    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
+                        for item in history_transformer_format])
+
+    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=1000,
+        temperature=1.0,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    partial_message = ""
+    for new_token in streamer:
+        if new_token != '<':
+            partial_message += new_token
+            yield partial_message
 
 gr.ChatInterface(predict).launch()
 