Michael398 committed on
Commit
d36d097
·
verified ·
1 Parent(s): e5f5f24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -15
app.py CHANGED
@@ -5,18 +5,34 @@ import gradio as gr
5
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
6
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
7
 
8
- def predict(input, history=[]):
9
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
10
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
11
- history = model.generate(bot_input_ids, max_length=4000, pad_token_id=tokenizer.eos_token_id).tolist()
12
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
13
- response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)]
14
- return response, history
15
-
16
- gr.Interface(fn=predict,
17
- inputs=["text", "state"],
18
- outputs=["chatbot", "state"]).launch()
19
- gr.Interface(..., allow_flagging="never").launch(share=True, inline=True)
20
-
21
- iface = gr.Interface(...) # your interface
22
- iface.launch(share=True, show_error=True, enable_queue=True, allowed_paths=None, inline=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
# DialoGPT-large checkpoint: a single shared identifier so the tokenizer
# and the causal-LM weights always come from the same model card.
_CHECKPOINT = "microsoft/DialoGPT-large"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
8
def predict(user_input, history=None):
    """Generate one DialoGPT chat turn, preserving conversation history.

    Args:
        user_input: The user's message as a plain string.
        history: Token-id history from the previous turn as a nested list
            (as returned by this function), or None/empty on the first turn.

    Returns:
        A ``(response, history)`` tuple: the decoded bot reply string and
        the updated token-id history as a nested list (JSON-serializable,
        suitable for a Gradio "state" component).
    """
    # Fix: a mutable default (history=[]) is created once and shared across
    # every call, so turns from unrelated sessions would leak into each
    # other. Use the None sentinel instead.
    if history is None:
        history = []

    # Encode the user input followed by the end-of-text separator token.
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Append the new turn to the running conversation, if there is one.
    if history:
        bot_input_ids = torch.cat([torch.tensor(history), new_input_ids], dim=-1)
    else:
        bot_input_ids = new_input_ids

    # Generate up to max_length total tokens. DialoGPT has no pad token,
    # so EOS is used for padding (also silences the transformers warning).
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens — everything after the prompt.
    response = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )

    return response, chat_history_ids.tolist()
29
+
30
# Interface configuration: text in/out plus a hidden "state" component that
# carries the token-id history between turns of the conversation.
_interface_config = {
    "fn": predict,
    "inputs": ["text", "state"],
    "outputs": ["text", "state"],
    "title": "DialoGPT Chatbot",
    "description": "Chat with DialoGPT-large. Your chat history is preserved.",
}

iface = gr.Interface(**_interface_config)

# share=True exposes a public link; show_error surfaces tracebacks in the UI.
iface.launch(share=True, show_error=True, inline=True)