nikshep01 committed
Commit 7a73b12 · verified · 1 Parent(s): c19a8bf

Update app.py

Files changed (1)
  1. app.py +66 -16
app.py CHANGED
@@ -136,27 +136,77 @@
 
 
 # new code
-import os
-from langchain_openai import ChatOpenAI
-from langchain.schema import AIMessage, HumanMessage
-import openai
-import gradio as gr
-
-
-os.environ["OPENAI_API_KEY"] = "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"  # Replace with your key
+# import os
+# from langchain_openai import ChatOpenAI
+# from langchain.schema import AIMessage, HumanMessage
+# import openai
+# import gradio as gr
+
+
+# os.environ["OPENAI_API_KEY"] = "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"  # Replace with your key
 
-llm = ChatOpenAI(temperature=1.0, model='gpt-3.5-turbo')
+# llm = ChatOpenAI(temperature=1.0, model='gpt-3.5-turbo')
 
+# def predict(message, history):
+#     history_langchain_format = []
+#     for msg in history:
+#         if msg['role'] == "user":
+#             history_langchain_format.append(HumanMessage(content=msg['content']))
+#         elif msg['role'] == "assistant":
+#             history_langchain_format.append(AIMessage(content=msg['content']))
+#     history_langchain_format.append(HumanMessage(content=message))
+#     gpt_response = llm(history_langchain_format)
+#     return gpt_response.content
+
+# gr.ChatInterface(predict).launch()
+
+
+
+import gradio as gr
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from threading import Thread
+
+tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")
+model = AutoModelForCausalLM.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1", torch_dtype=torch.float16)
+model = model.to('cuda:0')
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [29, 0]
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
 def predict(message, history):
-    history_langchain_format = []
-    for msg in history:
-        if msg['role'] == "user":
-            history_langchain_format.append(HumanMessage(content=msg['content']))
-        elif msg['role'] == "assistant":
-            history_langchain_format.append(AIMessage(content=msg['content']))
-    history_langchain_format.append(HumanMessage(content=message))
-    gpt_response = llm(history_langchain_format)
-    return gpt_response.content
+    history_transformer_format = list(zip(history[:-1], history[1:])) + [[message, ""]]
+    stop = StopOnTokens()
+
+    messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
+                        for item in history_transformer_format])
+
+    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=1000,
+        temperature=1.0,
+        num_beams=1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    partial_message = ""
+    for new_token in streamer:
+        if new_token != '<':
+            partial_message += new_token
+            yield partial_message
 
 gr.ChatInterface(predict).launch()
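
In short, the commit replaces the LangChain ChatOpenAI backend with local, token-by-token generation from RedPajama-INCITE-Chat-3B: model.generate runs in a background Thread, a TextIteratorStreamer yields decoded text as it is produced, and predict yields growing partial messages that gr.ChatInterface renders as a streaming reply. A minimal sketch of that thread-plus-streamer pattern follows; it is illustrative only, and the distilgpt2 checkpoint, prompt, and generation settings here are assumptions chosen for a small CPU run, not part of the commit.

# Sketch of the streaming pattern used in the new app.py (assumed small model, CPU).
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

def stream_reply(prompt):
    inputs = tokenizer([prompt], return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs in a worker thread while this generator drains the streamer
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=40, do_sample=True, top_p=0.95)
    Thread(target=model.generate, kwargs=kwargs).start()
    partial = ""
    for new_text in streamer:   # decoded text pieces arrive as they are generated
        partial += new_text
        yield partial           # each yield is one progressively longer chat reply

if __name__ == "__main__":
    for partial in stream_reply("\n<human>: Hello, who are you?\n<bot>:"):
        print(partial)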