research14 commited on
Commit
85bd1c9
·
1 Parent(s): f2d6f20
Files changed (1) hide show
  1. app.py +33 -52
app.py CHANGED
@@ -7,72 +7,53 @@ model_name = "lmsys/vicuna-7b-v1.3"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
- with gr.Blocks() as demo:
11
- gr.Markdown("# LLM Evaluator With Linguistic Scrutiny")
12
 
13
- with gr.Tab("POS"):
14
- gr.Markdown(" Description ")
15
 
16
- prompt_POS = gr.Textbox(show_label=False, placeholder="Write a prompt and press enter")
17
-
18
- gr.Markdown("Strategy 1 QA")
19
- with gr.Row():
20
- vicuna_S1_chatbot_POS = gr.Chatbot(label="vicuna-7b")
21
- llama_S1_chatbot_POS = gr.Chatbot(label="llama-7b")
22
- gpt_S1_chatbot_POS = gr.Chatbot(label="gpt-3.5")
23
- clear = gr.ClearButton([prompt_POS, vicuna_S1_chatbot_POS])
24
- gr.Markdown("Strategy 2 Instruction")
25
- with gr.Row():
26
- vicuna_S2_chatbot_POS = gr.Chatbot(label="vicuna-7b")
27
- llama_S2_chatbot_POS = gr.Chatbot(label="llama-7b")
28
- gpt_S2_chatbot_POS = gr.Chatbot(label="gpt-3.5")
29
- clear = gr.ClearButton([prompt_POS, vicuna_S2_chatbot_POS])
30
- gr.Markdown("Strategy 3 Structured Prompting")
31
  with gr.Row():
32
- vicuna_S3_chatbot_POS = gr.Chatbot(label="vicuna-7b")
33
- llama_S3_chatbot_POS = gr.Chatbot(label="llama-7b")
34
- gpt_S3_chatbot_POS = gr.Chatbot(label="gpt-3.5")
35
- clear = gr.ClearButton([prompt_POS, vicuna_S3_chatbot_POS])
36
-
37
- with gr.Tab("Chunk"):
38
- gr.Markdown(" Description ")
39
 
40
- prompt_CHUNK = gr.Textbox(show_label=False, placeholder="Write a prompt and press enter")
41
 
42
- gr.Markdown("Strategy 1 QA")
43
- with gr.Row():
44
- vicuna_S1_chatbot_CHUNK = gr.Chatbot(label="vicuna-7b")
45
- llama_S1_chatbot_CHUNK = gr.Chatbot(label="llama-7b")
46
- gpt_S1_chatbot_CHUNK = gr.Chatbot(label="gpt-3.5")
47
- clear = gr.ClearButton([prompt_CHUNK, vicuna_S1_chatbot_CHUNK])
48
- gr.Markdown("Strategy 2 Instruction")
49
- with gr.Row():
50
- vicuna_S2_chatbot_CHUNK = gr.Chatbot(label="vicuna-7b")
51
- llama_S2_chatbot_CHUNK = gr.Chatbot(label="llama-7b")
52
- gpt_S2_chatbot_CHUNK = gr.Chatbot(label="gpt-3.5")
53
- clear = gr.ClearButton([prompt_CHUNK, vicuna_S2_chatbot_CHUNK])
54
- gr.Markdown("Strategy 3 Structured Prompting")
55
- with gr.Row():
56
- vicuna_S3_chatbot_CHUNK = gr.Chatbot(label="vicuna-7b")
57
- llama_S3_chatbot_CHUNK = gr.Chatbot(label="llama-7b")
58
- gpt_S3_chatbot_CHUNK = gr.Chatbot(label="gpt-3.5")
59
- clear = gr.ClearButton([prompt_CHUNK, vicuna_S3_chatbot_CHUNK])
60
 
 
61
  def respond(message, chat_history):
62
  input_ids = tokenizer.encode(message, return_tensors="pt")
63
  output_ids = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2)
64
  bot_message = tokenizer.decode(output_ids[0], skip_special_tokens=True)
65
-
66
  chat_history.append((message, bot_message))
67
  time.sleep(2)
68
  return "", chat_history
69
 
70
- prompt_POS.submit(respond, [prompt_POS, vicuna_S1_chatbot_POS], [prompt_POS, vicuna_S1_chatbot_POS])
71
- prompt_POS.submit(respond, [prompt_POS, vicuna_S2_chatbot_POS], [prompt_POS, vicuna_S2_chatbot_POS])
72
- prompt_POS.submit(respond, [prompt_POS, vicuna_S3_chatbot_POS], [prompt_POS, vicuna_S3_chatbot_POS])
 
 
 
73
 
74
- prompt_CHUNK.submit(respond, [prompt_CHUNK, vicuna_S1_chatbot_CHUNK], [prompt_CHUNK, vicuna_S1_chatbot_CHUNK])
75
- prompt_CHUNK.submit(respond, [prompt_CHUNK, vicuna_S2_chatbot_CHUNK], [prompt_CHUNK, vicuna_S2_chatbot_CHUNK])
76
- prompt_CHUNK.submit(respond, [prompt_CHUNK, vicuna_S3_chatbot_CHUNK], [prompt_CHUNK, vicuna_S3_chatbot_CHUNK])
77
 
 
 
 
 
78
  demo.launch()
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
+ def create_chatbot_tab(description, prompt_placeholder, strategy_labels):
11
+ tab = gr.Tab(description)
12
 
13
+ prompt_textbox = gr.Textbox(show_label=False, placeholder=prompt_placeholder)
 
14
 
15
+ chatbots = []
16
+ for strategy_label in strategy_labels:
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  with gr.Row():
18
+ vicuna_chatbot = gr.Chatbot(label="vicuna-7b")
19
+ llama_chatbot = gr.Chatbot(label="llama-7b")
20
+ gpt_chatbot = gr.Chatbot(label="gpt-3.5")
21
+ chatbots.append(vicuna_chatbot)
 
 
 
22
 
23
+ clear_button = gr.ClearButton([prompt_textbox] + chatbots)
24
 
25
+ # Add components within the gr.Blocks context
26
+ with tab:
27
+ gr.Col(prompt_textbox)
28
+ for chatbot in chatbots:
29
+ gr.Col(chatbot)
30
+ gr.Col(clear_button)
31
+
32
+ return tab, prompt_textbox, chatbots
 
 
 
 
 
 
 
 
 
 
33
 
34
+ def create_submit_function(prompt_textbox, chatbots):
35
  def respond(message, chat_history):
36
  input_ids = tokenizer.encode(message, return_tensors="pt")
37
  output_ids = model.generate(input_ids, max_length=50, num_beams=5, no_repeat_ngram_size=2)
38
  bot_message = tokenizer.decode(output_ids[0], skip_special_tokens=True)
39
+
40
  chat_history.append((message, bot_message))
41
  time.sleep(2)
42
  return "", chat_history
43
 
44
+ for chatbot in chatbots:
45
+ prompt_textbox.submit(respond, [prompt_textbox, chatbot], [prompt_textbox, chatbot])
46
+
47
+ # Create POS and Chunk tabs
48
+ pos_tab, pos_prompt_textbox, pos_chatbots = create_chatbot_tab("POS", "Write a prompt and press enter", ["Strategy 1 QA", "Strategy 2 Instruction", "Strategy 3 Structured Prompting"])
49
+ chunk_tab, chunk_prompt_textbox, chunk_chatbots = create_chatbot_tab("Chunk", "Write a prompt and press enter", ["Strategy 1 QA", "Strategy 2 Instruction", "Strategy 3 Structured Prompting"])
50
 
51
+ # Create submit functions for POS and Chunk tabs
52
+ create_submit_function(pos_prompt_textbox, pos_chatbots)
53
+ create_submit_function(chunk_prompt_textbox, chunk_chatbots)
54
 
55
+ # Launch the demo with POS and Chunk tabs
56
+ demo = gr.Blocks()
57
+ demo.append(pos_tab)
58
+ demo.append(chunk_tab)
59
  demo.launch()