research14 commited on
Commit
7eaa7b0
·
1 Parent(s): da85b3a
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -3,26 +3,25 @@ import random
3
  import time
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
- # Load models and tokenizers
7
- model_names = ["lmsys/vicuna-7b-v1.3", "gpt2"]
8
- models = [AutoModelForCausalLM.from_pretrained(name) for name in model_names]
9
- tokenizers = [AutoTokenizer.from_pretrained(name) for name in model_names]
10
 
11
  with gr.Blocks() as demo:
12
  with gr.Row():
13
  vicuna_chatbot = gr.Chatbot(label="Vicuna", live=True)
14
- gpt2_chatbot = gr.Chatbot(label="GPT-2", live=True)
15
  msg = gr.Textbox()
16
- clear = gr.ClearButton([msg, vicuna_chatbot, gpt2_chatbot])
17
 
18
  def respond(message, chat_history, chatbot_idx):
19
- input_ids = tokenizers[chatbot_idx].encode(message, return_tensors="pt")
20
- output = models[chatbot_idx].generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2)
21
- bot_message = tokenizers[chatbot_idx].decode(output[0], skip_special_tokens=True)
22
  chat_history.append((message, bot_message))
23
  time.sleep(2)
24
  return "", chat_history
25
 
26
- msg.submit(respond, [msg, vicuna_chatbot, 0], [msg, gpt2_chatbot, 1])
27
 
28
  demo.launch()
 
3
  import time
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
+ # Load Vicuna 7B model and tokenizer
7
+ model_name = "lmsys/vicuna-7b-v1.3"
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
  with gr.Blocks() as demo:
12
  with gr.Row():
13
  vicuna_chatbot = gr.Chatbot(label="Vicuna", live=True)
 
14
  msg = gr.Textbox()
15
+ clear = gr.ClearButton([msg, vicuna_chatbot])
16
 
17
  def respond(message, chat_history, chatbot_idx):
18
+ input_ids = tokenizer.encode(message, return_tensors="pt")
19
+ output = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2)
20
+ bot_message = tokenizer.decode(output[0], skip_special_tokens=True)
21
  chat_history.append((message, bot_message))
22
  time.sleep(2)
23
  return "", chat_history
24
 
25
+ msg.submit(respond, [msg, vicuna_chatbot, 0])
26
 
27
  demo.launch()