TrungNQ committed on
Commit
da8f43a
·
verified ·
1 Parent(s): ea5770c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -2
app.py CHANGED
@@ -93,7 +93,7 @@ out=grad.Textbox(lines=1, label="Genereated Question")
93
  grad.Interface(text2text, inputs=[context,ans], outputs=out).launch()
94
  '''
95
 
96
- #5.21
97
  from transformers import AutoTokenizer, AutoModelWithLMHead
98
  import gradio as grad
99
  text2text_tkn = AutoTokenizer.from_pretrained("deep-learning-analytics/wikihow-t5-small")
@@ -118,4 +118,70 @@ def text2text_summary(para):
118
 
119
  para=grad.Textbox(lines=10, label="Paragraph", placeholder="Copy paragraph")
120
  out=grad.Textbox(lines=1, label="Summary")
121
- grad.Interface(text2text_summary, inputs=para, outputs=out).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  grad.Interface(text2text, inputs=[context,ans], outputs=out).launch()
94
  '''
95
 
96
+ '''5.21
97
  from transformers import AutoTokenizer, AutoModelWithLMHead
98
  import gradio as grad
99
  text2text_tkn = AutoTokenizer.from_pretrained("deep-learning-analytics/wikihow-t5-small")
 
118
 
119
  para=grad.Textbox(lines=10, label="Paragraph", placeholder="Copy paragraph")
120
  out=grad.Textbox(lines=1, label="Summary")
121
+ grad.Interface(text2text_summary, inputs=para, outputs=out).launch()
122
+ '''
123
+
124
+ #5.28
125
+ from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
126
+ import torch
127
+
128
+
129
+ chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
130
+ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
131
+
132
+
133
+ #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
134
+ #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
135
+
136
def converse(user_input, chat_history=None):
    """Generate one chatbot turn with DialoGPT and render the chat as HTML.

    Parameters
    ----------
    user_input : str
        The user's latest message.
    chat_history : list | None
        Token-id history carried between calls via the gradio "state"
        input/output. Defaults to None (not a mutable []) so the
        history list is never shared across calls.

    Returns
    -------
    tuple
        (html transcript string, updated token-id history list).
    """
    # Fix: the original used a mutable default argument ([]), which is
    # shared between every call of the function.
    if chat_history is None:
        chat_history = []

    # Encode the new user message, terminated by the EOS token.
    user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids

    # keep history in the tensor
    bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)

    # get response
    chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
    print(chat_history)

    # Split the decoded transcript on the EOS marker: even indices are
    # the user's messages, odd indices are the bot's replies.
    response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")

    print("starting to print response")
    print(response)

    # html for display
    html = "<div class='mybot'>"
    for x, mesg in enumerate(response):
        if x % 2 != 0:
            mesg = "Alicia:" + mesg
            clazz = "alicia"
        else:
            clazz = "user"

        print("value of x")
        print(x)
        print("message")
        print(mesg)

        html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
    html += "</div>"
    print(html)
    return html, chat_history
172
+
173
import gradio as grad

# Styling for the chat transcript.
# Fixes vs. original:
#  - container selector was '.mychat', but converse() emits
#    <div class='mybot'> — the rule never matched; renamed to '.mybot'.
#  - '.mesg.alicia' separated two declarations with a comma
#    ("color:white,align-self:self-end"), which is invalid CSS; now ';'.
css = """
.mybot {display:flex;flex-direction:column}
.mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.mesg.user {background-color:lightblue;color:white}
.mesg.alicia {background-color:orange;color:white;align-self:self-end}
.footer {display:none !important}
"""

# grad.Textbox matches the component style already used elsewhere in
# this file; grad.inputs.Textbox is the deprecated pre-3.x namespace.
text = grad.Textbox(placeholder="Lets chat")

# "state" threads chat_history between calls; "html" renders the
# transcript string returned by converse().
grad.Interface(fn=converse,
               theme="default",
               inputs=[text, "state"],
               outputs=["html", "state"],
               css=css).launch()