kz209 committed
Commit 031841d · 1 Parent(s): 8e22bd4

update format
Browse files
- app.py +2 -5
- pages/__init__.py +1 -1
- pages/arena.py +46 -22
- pages/batch_evaluation.py +44 -30
- pages/leaderboard.py +39 -21
- pages/summarization_playground.py +103 -26
- utils/__init__.py +1 -1
- utils/data.py +3 -3
- utils/metric.py +3 -2
- utils/model.py +46 -29
- utils/multiple_stream.py +19 -14
app.py
CHANGED
@@ -13,13 +13,10 @@ This application is for **display** and is designed to facilitate **fast prototy
 
 Select a demo from the sidebar below to begin experimentation."""
 
+
 with gr.Blocks() as demo:
     with gr.Column(scale=4):
-        content = content = gr.Blocks(
-            gr.Markdown(
-                welcome_message()
-            )
-        )
+        content = content = gr.Blocks(gr.Markdown(welcome_message()))
 
     with gr.Tabs() as tabs:
         with gr.TabItem("Summarization"):
pages/__init__.py
CHANGED
@@ -1,4 +1,4 @@
 # This is the __init__.py file for the utils package
 # You can add any initialization code or import statements here
 
-__all__ = [
+__all__ = ["arena", "batch_evaluation", "leaderboard", "summarization_playground"]
pages/arena.py
CHANGED
@@ -10,9 +10,10 @@ from utils.multiple_stream import stream_data
 
 def random_data_selection():
     datapoint = random.choice(dataset)
-    datapoint = datapoint[
+    datapoint = datapoint["section_text"] + "\n\nDialogue:\n" + datapoint["dialogue"]
     return datapoint
 
+
 def create_arena():
     with open("prompt/prompt.json", "r") as file:
         json_data = file.read()
@@ -21,19 +22,24 @@ def create_arena():
     with gr.Blocks(css=custom_css) as demo:
         with gr.Group():
             datapoint = random_data_selection()
-            gr.Markdown(
-Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
-            data_textbox = gr.Textbox(
+            gr.Markdown(
+                """This arena is designed to compare different prompts. Click the button to stream responses from randomly shuffled prompts. Each column represents a response generated from one randomly selected prompt.
+
+Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
+            )
+
+            data_textbox = gr.Textbox(
+                label="Data",
+                lines=10,
+                placeholder="Datapoints to test...",
+                value=datapoint,
+            )
             with gr.Row():
                 random_selection_button = gr.Button("Change Data")
                 stream_button = gr.Button("✨ Click to Streaming ✨")
 
             random_selection_button.click(
-                fn=random_data_selection,
-                inputs=[],
-                outputs=[data_textbox]
+                fn=random_data_selection, inputs=[], outputs=[data_textbox]
             )
 
             random.shuffle(prompts)
@@ -42,43 +48,56 @@ Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
             # Store prompts in state components
             state_prompts = gr.State(value=prompts)
             state_random_selected_prompts = gr.State(value=random_selected_prompts)
+
             with gr.Row():
-                columns = [
+                columns = [
+                    gr.Textbox(label=f"Prompt {i+1}", lines=10)
+                    for i in range(len(random_selected_prompts))
+                ]
+
             model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
 
             def start_streaming(data, random_selected_prompts):
-                content_list = [
+                content_list = [
+                    prompt["prompt"] + "\n{" + data + "}\n\nsummary:"
+                    for prompt in random_selected_prompts
+                ]
                 for response_data in stream_data(content_list, model):
-                    updates = [
+                    updates = [
+                        gr.update(value=response_data[i]) for i in range(len(columns))
+                    ]
                     yield tuple(updates)
 
             stream_button.click(
                 fn=start_streaming,
                 inputs=[data_textbox, state_random_selected_prompts],
                 outputs=columns,
-                show_progress=False
+                show_progress=False,
+            )
+
+            choice = gr.Radio(
+                label="Choose the best response:",
+                choices=["Response 1", "Response 2", "Response 3"],
             )
 
-            choice = gr.Radio(label="Choose the best response:", choices=["Response 1", "Response 2", "Response 3"])
-
             submit_button = gr.Button("Submit")
 
             output = gr.Textbox(label="You selected:", visible=False)
 
-            def update_prompt_metrics(
+            def update_prompt_metrics(
+                selected_choice, prompts, random_selected_prompts
+            ):
                 if selected_choice == "Response 1":
-                    prompt_id = random_selected_prompts[0][
+                    prompt_id = random_selected_prompts[0]["id"]
                 elif selected_choice == "Response 2":
-                    prompt_id = random_selected_prompts[1][
+                    prompt_id = random_selected_prompts[1]["id"]
                 elif selected_choice == "Response 3":
-                    prompt_id = random_selected_prompts[2][
+                    prompt_id = random_selected_prompts[2]["id"]
                 else:
                     raise ValueError(f"No corresponding response of {selected_choice}")
 
                 for prompt in prompts:
-                    if prompt[
+                    if prompt["id"] == prompt_id:
                         prompt["metric"]["winning_number"] += 1
                         break
                 else:
@@ -87,7 +106,11 @@ Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
                 with open("prompt/prompt.json", "w") as f:
                     json.dump(prompts, f)
 
-                return
+                return (
+                    gr.update(value=f"You selected: {selected_choice}", visible=True),
+                    gr.update(interactive=False),
+                    gr.update(interactive=False),
+                )
 
             submit_button.click(
                 fn=update_prompt_metrics,
@@ -97,6 +120,7 @@ Once the streaming is complete, you can choose the best response.\u2764\ufe0f"""
 
     return demo
 
+
 if __name__ == "__main__":
     demo = create_arena()
     demo.queue()
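Both this page and pages/leaderboard.py read prompt/prompt.json, so each entry is assumed to carry at least an `id`, a `prompt` template, and a `metric` block with a `winning_number` counter (the leaderboard additionally reads `url`, `author`, and `metric["Rouge"]`). A minimal sketch of that assumed record shape and of the vote bookkeeping done in `update_prompt_metrics`; the field values are made up for illustration:

```python
import json

# Hypothetical prompt.json entry; only the keys come from how the pages read
# the file, the values are invented for this sketch.
example_prompt = {
    "id": "prompt-001",
    "author": "example-author",
    "url": "https://example.com/prompt-001",
    "prompt": "Summarize the following dialogue",
    "metric": {"Rouge": 0.0, "winning_number": 0},
}


def record_win(prompts, prompt_id):
    """Increment the winner's counter the same way update_prompt_metrics does."""
    for prompt in prompts:
        if prompt["id"] == prompt_id:
            prompt["metric"]["winning_number"] += 1
            break
    else:
        raise ValueError(f"No prompt with id {prompt_id}")
    return prompts


if __name__ == "__main__":
    prompts = record_win([example_prompt], "prompt-001")
    print(json.dumps(prompts, indent=2))
```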
pages/batch_evaluation.py
CHANGED
@@ -12,21 +12,22 @@ from utils.model import Model
 
 load_dotenv()
 
+
 def display_results(response_list):
-    overall_score = np.mean([r[
+    overall_score = np.mean([r["metric_score"]["rouge_score"] for r in response_list])
+
     html_output = f"<h2>Overall Score: {overall_score:.2f}</h2>"
+
     for i, item in enumerate(response_list, 1):
-        dialogue = item[
-        summary = item[
-        response = item[
-        rouge_score = item[
-        dialogue = html.escape(item[
-        summary = html.escape(item[
-        response = html.escape(item[
+        dialogue = item["dialogue"]
+        summary = item["summary"]
+        response = item["response"]
+        rouge_score = item["metric_score"]["rouge_score"]
+
+        dialogue = html.escape(item["dialogue"]).replace("\n", "<br>")
+        summary = html.escape(item["summary"]).replace("\n", "<br>")
+        response = html.escape(item["response"]).replace("\n", "<br>")
+
         html_output += f"""
         <details>
         <summary>Response {i} (Rouge Score: {rouge_score:.2f})</summary>
@@ -49,6 +50,7 @@ def display_results(response_list):
 
     return html_output
 
+
 def process(model_selection, prompt, num=10):
     response_list = []
     with open("test_samples/test_data.json", "r") as file:
@@ -57,21 +59,21 @@ def process(model_selection, prompt, num=10):
 
     for i, data in enumerate(dataset):
         logging.info(f"Start testing datapoint {i+1}")
-        dialogue = data[
-        format = data[
-        summary = data[
-        response = generate_answer(
+        dialogue = data["dialogue"]
+        format = data["format"]
+        summary = data["summary"]
+        response = generate_answer(
+            dialogue, model_selection, prompt + f" Output following {format} format."
+        )
 
         rouge_score = metric_rouge_score(response, summary)
 
         response_list.append(
             {
-                'rouge_score': rouge_score
-            }
+                "dialogue": dialogue,
+                "summary": summary,
+                "response": response,
+                "metric_score": {"rouge_score": rouge_score},
             }
         )
 
@@ -81,22 +83,34 @@ def process(model_selection, prompt, num=10):
 
 
 def create_batch_evaluation_interface():
-    with gr.Blocks(
-        gr.
+    with gr.Blocks(
+        theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css
+    ) as demo:
+        gr.Markdown(
+            "## Here are evaluation setups. It will run though datapoints in test_data.josn to generate and evaluate. Show results once finished."
+        )
 
-        model_dropdown = gr.Dropdown(
+        model_dropdown = gr.Dropdown(
+            choices=Model.__model_list__,
+            label="Choose a model",
+            value=Model.__model_list__[0],
+        )
+        Template_text = gr.Textbox(
+            value="""Summarize the following dialogue""",
+            label="Input Prompting Template",
+            lines=8,
+            placeholder="Input your prompts",
+        )
         submit_button = gr.Button("✨ Submit ✨")
         output = gr.HTML(label="Results")
 
         submit_button.click(
-            process,
-            inputs=[model_dropdown, Template_text],
-            outputs=output
+            process, inputs=[model_dropdown, Template_text], outputs=output
        )
 
    return demo
 
+
 if __name__ == "__main__":
    demo = create_batch_evaluation_interface()
    demo.launch()
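For reference, the headline number rendered by `display_results` is simply the mean of the per-example ROUGE scores stored under `metric_score`; a small sketch with made-up scores and the same dict layout that `process` appends:

```python
import numpy as np

# Dummy entries mirroring what process() builds; the scores are invented.
response_list = [
    {"dialogue": "d1", "summary": "s1", "response": "r1",
     "metric_score": {"rouge_score": 0.42}},
    {"dialogue": "d2", "summary": "s2", "response": "r2",
     "metric_score": {"rouge_score": 0.58}},
]

# Same aggregation as display_results uses for the "Overall Score" heading.
overall_score = np.mean([r["metric_score"]["rouge_score"] for r in response_list])
print(f"Overall Score: {overall_score:.2f}")  # -> Overall Score: 0.50
```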
pages/leaderboard.py
CHANGED
@@ -9,72 +9,90 @@ import pandas as pd
 def create_html_with_tooltip(id, base_url):
     return f'<a href="{base_url}"target="_blank">{id}</a>'
 
+
 # Load prompts from JSON
 with open("prompt/prompt.json", "r") as file:
     json_data = file.read()
 prompts = json.loads(json_data)
 
 # Prepare leaderboard data
-winning_rate = [prompt[
-winning_rate = [round(num / sum(winning_rate), 4)for num in winning_rate]
+winning_rate = [prompt["metric"]["winning_number"] for prompt in prompts]
+winning_rate = [round(num / sum(winning_rate), 4) for num in winning_rate]
 data = {
+    "Rank": [i + 1 for i in range(len(prompts))],
+    "Methods": [
+        create_html_with_tooltip(prompt["id"], prompt["url"]) for prompt in prompts
+    ],
+    "Rouge Score": [prompt["metric"]["Rouge"] for prompt in prompts],
+    "Winning Rate": winning_rate,
+    "Authors": [prompt["author"] for prompt in prompts],
 }
 
 # Create DataFrame and sort by Rouge Score
 df = pd.DataFrame(data)
-df.sort_values(by=
-df[
+df.sort_values(by="Rouge Score", ascending=False, inplace=True, ignore_index=True)
+df["Rank"] = range(1, len(df) + 1)
 
 # Assign medals for top 3 authors
-medals = [
+medals = ["🏅", "🥈", "🥉"]
 for i in range(3):
-    df.loc[i,
+    df.loc[i, "Authors"] = f"{medals[i]} {df.loc[i, 'Authors']}"
+
 
 # Function to update the leaderboard
 def update_leaderboard(sort_by):
     sorted_df = df.sort_values(by=sort_by, ascending=False, ignore_index=True)
-    sorted_df[
+    sorted_df["Rank"] = range(1, len(sorted_df) + 1)
 
     # Convert DataFrame to HTML with clickable headers for sorting
     table_html = sorted_df.to_html(index=False, escape=False)
 
     # Add sorting links to column headers
     for column in sorted_df.columns:
-        table_html = table_html.replace(
+        table_html = table_html.replace(
+            f"<th>{column}</th>",
+            f'<th><a href="#" onclick="sortBy(\'{column}\'); return false;">{column}</a></th>',
+        )
 
     return table_html
 
+
 # Define Gradio interface
 def create_leaderboard():
-    with gr.Blocks(
+    with gr.Blocks(
+        css="""
     .tooltip { cursor: pointer; color: blue; text-decoration: underline; }
     table { border-collapse: collapse; width: 100%; }
     th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
     th { background-color: #f2f2f2; }
     #prompt-display { display: none; }
-    """
+    """
+    ) as demo:
         gr.Markdown("# 🏆 Summarization Arena Leaderboard")
         with gr.Row():
-            gr.Markdown(
+            gr.Markdown(
+                "[Blog](placeholder) | [GitHub](placeholder) | [Paper](placeholder) | [Dataset](placeholder) | [Twitter](placeholder) | [Discord](placeholder)"
+            )
+            gr.Markdown(
+                "Welcome to our open platform for evaluating LLM summarization capabilities."
+            )
+
         # Dropdown for sorting
         sort_by = gr.Dropdown(list(df.columns), label="Sort by", value="Rouge Score")
 
         # Display the leaderboard
         leaderboard = gr.HTML(update_leaderboard("Rouge Score"), elem_id="leaderboard")
+
         # Change sorting when dropdown is changed
-        sort_by.change(
+        sort_by.change(
+            fn=lambda sort: update_leaderboard(sort),
+            inputs=sort_by,
+            outputs=leaderboard,
+        )
 
     return demo
 
+
 # Launch Gradio interface
 if __name__ == "__main__":
     demo = create_leaderboard()
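The "Winning Rate" column is each prompt's `winning_number` divided by the total number of recorded wins, rounded to four decimals; a quick sketch with made-up counts:

```python
# Invented vote counts for three prompts.
winning_number = [12, 5, 3]

# Same normalization as leaderboard.py: the rates sum to roughly 1.0.
winning_rate = [round(num / sum(winning_number), 4) for num in winning_number]
print(winning_rate)  # [0.6, 0.25, 0.15]
```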
pages/summarization_playground.py
CHANGED
@@ -65,27 +65,26 @@ input-label {
 }
 """
 
-__model_on_gpu__ =
+__model_on_gpu__ = ""
 model = {model_name: None for model_name in Model.__model_list__}
 
-random_label =
+random_label = "🔀 Random dialogue from dataset"
 examples = {
     "example 1": """Boston's injury reporting for Kristaps Porziņģis has been fairly coy. He missed Game 3, but his coach told reporters just before Game 4 that was technically available, but with a catch.
 Joe Mazzulla said Porziņģis would "only be used in specific instances, if necessary." That sounds like the team doesn't want to risk further injury to his dislocated Posterior Tibialis (or some other body part, due to overcompensation for the ankle), unless it's in a desperate situation.
 Being up 3-1, with Game 5 at home, doesn't qualify as desperate. So, expect the Celtics to continue slow-playing KP's return.
 It'd obviously be nice for Boston to have his rim protection and jump shooting back. It was missed in the Game 4 blowout, but the Celtics have also demonstrated they can win without the big man throughout this campaign.
 On top of winning Game 3 of this series, Boston is plus-10.9 points per 100 possessions when Porziņģis has been off the floor this regular and postseason.""",
     "example 2": """Prior to the Finals, we predicted that Dereck Lively II's minutes would swell over the course of the series, and that's starting to play out.
 He averaged 18.8 minutes in Games 1 and 2 and was up to 26.2 in Games 3 and 4. That's with the regulars being pulled long before the final buzzer in Friday's game, too.
 Expect the rookie's playing time to continue to climb in Game 5. It seems increasingly clear that coach Jason Kidd trusts him over the rest of Dallas' bigs, and it's not hard to see why.
 Lively has been absolutely relentless on the offensive glass all postseason. He makes solid decisions as a passer when his rolls don't immediately lead to dunks. And he's not a liability when caught defending guards or wings outside.
 All of that has led to postseason averages of 8.2 points, 7.6 rebounds, 1.4 assists and 1.0 blocks in just 21.9 minutes, as well as a double-double in 22 minutes of Game 4.
 Back in Boston, Kidd is going to rely on Lively even more. He'll play close to 30 minutes and reach double-figures in both scoring and rebounding again.""",
-    random_label: ""
+    random_label: "",
 }
 
+
 def model_device_check(model_name):
     global __model_on_gpu__
 
@@ -106,56 +105,134 @@ def get_model_batch_generation(model_name):
     return model[model_name]
 
 
-def generate_answer(
+def generate_answer(
+    sources, model_name, prompt, temperature=0.0001, max_new_tokens=500, do_sample=True
+):
     model_device_check(model_name)
-    content = prompt +
-    answer =
+    content = prompt + "\n{" + sources + "}\n\nsummary:"
+    answer = (
+        model[model_name]
+        .gen(content, temperature, max_new_tokens, do_sample)[0]
+        .strip()
+    )
 
     return answer
 
+
+def process_input(
+    input_text,
+    model_selection,
+    prompt,
+    temperature=0.0001,
+    max_new_tokens=500,
+    do_sample=True,
+):
     if input_text:
         logging.info("Start generation")
-        response = generate_answer(
+        response = generate_answer(
+            input_text, model_selection, prompt, temperature, max_new_tokens, do_sample
+        )
+        return (
+            f"## Original Dialogue:\n\n{input_text}\n\n## Summarization:\n\n{response}"
+        )
     else:
         return "Please fill the input to generate outputs."
 
+
 def update_input(example):
     if example == random_label:
         datapoint = random.choice(dataset)
-        return datapoint[
+        return datapoint["section_text"] + "\n\nDialogue:\n" + datapoint["dialogue"]
     return examples[example]
 
+
 def create_summarization_interface():
-    with gr.Blocks(
-        gr.
+    with gr.Blocks(
+        theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css
+    ) as demo:
+        gr.Markdown(
+            "## This is a playground to test prompts for clinical dialogue summarizations"
+        )
 
         with gr.Row():
-            example_dropdown = gr.Dropdown(
+            example_dropdown = gr.Dropdown(
+                choices=list(examples.keys()),
+                label="Choose an example",
+                value=random_label,
+            )
+            model_dropdown = gr.Dropdown(
+                choices=Model.__model_list__,
+                label="Choose a model",
+                value=Model.__model_list__[0],
+            )
+
+        gr.Markdown(
+            "<div style='border: 4px solid white; padding: 3px; border-radius: 5px;width:100px;padding-top: 0.5px;padding-bottom: 10px;'><h3>Prompt 👥</h3></center></div>"
+        )
+        Template_text = gr.Textbox(
+            value="""Summarize the following dialogue""",
+            label="Input Prompting Template",
+            lines=4,
+            placeholder="Input your prompts",
+        )
         datapoint = random.choice(dataset)
-        input_text = gr.Textbox(
+        input_text = gr.Textbox(
+            label="Input Dialogue",
+            lines=7,
+            placeholder="Enter text here...",
+            value=datapoint["section_text"] + "\n\nDialogue:\n" + datapoint["dialogue"],
+        )
         submit_button = gr.Button("✨ Submit ✨")
 
         with gr.Row():
             with gr.Column(scale=1):
-                gr.Markdown(
+                gr.Markdown(
+                    "<div style='border: 4px solid white; padding: 2px; border-radius: 5px;width:130px;padding-bottom: 10px;'><b><h3>Parameters 📈</h3></center></b></div>"
+                )
                 with gr.Column():
-                    temperature = gr.Number(
+                    temperature = gr.Number(
+                        label="Temperature",
+                        elem_classes="parameter-text",
+                        value=0.0001,
+                        minimum=0.000001,
+                        maximum=1.0,
+                    )
+                    max_new_tokens = gr.Number(
+                        label="Max New Tokens",
+                        elem_classes="parameter-text",
+                        value=500,
+                        precision=0,
+                        minimum=0,
+                        maximum=500,
+                    )
+                    do_sample = gr.Dropdown(
+                        [True, False],
+                        label="Do Sample",
+                        elem_classes="parameter-text",
+                        value=True,
+                    )
             with gr.Column(scale=3):
                 output = gr.Markdown(line_breaks=True)
 
-        example_dropdown.change(
+        example_dropdown.change(
+            update_input, inputs=[example_dropdown], outputs=[input_text]
+        )
+        submit_button.click(
+            process_input,
+            inputs=[
+                input_text,
+                model_dropdown,
+                Template_text,
+                temperature,
+                max_new_tokens,
+                do_sample,
+            ],
+            outputs=[output],
+        )
 
     return demo
 
+
 if __name__ == "__main__":
     demo = create_summarization_interface()
     demo.launch()
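The text handed to the model by `generate_answer` is a plain concatenation: the prompt template, the dialogue wrapped in braces, and a trailing `summary:` cue. A sketch of that assembly with made-up strings:

```python
# Example strings are invented; only the concatenation pattern comes from
# generate_answer in this file.
prompt = "Summarize the following dialogue"
sources = "Doctor: How are you feeling today?\nPatient: Much better, thanks."

content = prompt + "\n{" + sources + "}\n\nsummary:"
print(content)
```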
utils/__init__.py
CHANGED
@@ -1,4 +1,4 @@
 # This is the __init__.py file for the utils package
 # You can add any initialization code or import statements here
 
-__all__ = [
+__all__ = ["multiple_stream", "model", "data", "metric"]
utils/data.py
CHANGED
@@ -1,4 +1,4 @@
-from datasets import load_dataset
-dialogsum = load_dataset('har1/MTS_Dialogue-Clinical_Note')
-dataset = list(dialogsum['train'])
-
+from datasets import load_dataset
+
+dialogsum = load_dataset("har1/MTS_Dialogue-Clinical_Note")
+dataset = list(dialogsum["train"])
utils/metric.py
CHANGED
@@ -1,6 +1,7 @@
 from rouge_score import rouge_scorer
 
-scorer = rouge_scorer.RougeScorer([
+scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
+
 
 def metric_rouge_score(pred, ref):
-    return scorer.score(pred, ref)[
+    return scorer.score(pred, ref)["rougeL"].fmeasure
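This wraps Google's `rouge_score` package; a short usage sketch with made-up sentences. Note that the library documents its argument order as `score(target, prediction)`, and each requested metric comes back as a Score tuple with `precision`, `recall`, and `fmeasure`:

```python
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

reference = "the patient reports improvement after medication"
prediction = "patient reports improvement after taking the medication"

# Returns {"rougeL": Score(precision=..., recall=..., fmeasure=...)}.
result = scorer.score(reference, prediction)
print(result["rougeL"].fmeasure)  # ROUGE-L F-measure in [0, 1]
```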
utils/model.py
CHANGED
@@ -6,7 +6,8 @@ from huggingface_hub import login
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 from vllm import LLM, SamplingParams
 
-login(token=os.getenv(
+login(token=os.getenv("HF_TOKEN"))
+
 
 class Model(torch.nn.Module):
     number_of_models = 0
@@ -15,17 +16,17 @@ class Model(torch.nn.Module):
         "lmsys/vicuna-7b-v1.5",
         "google-t5/t5-large",
         "mistralai/Mistral-7B-Instruct-v0.1",
-        "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
     ]
 
     def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
         super(Model, self).__init__()
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.name = model_name
         self.use_vllm = model_name != "google-t5/t5-large"
 
-        logging.info(f
+        logging.info(f"Start loading model {self.name}")
 
         if self.use_vllm:
             # 使用vLLM加载模型
@@ -33,18 +34,16 @@ class Model(torch.nn.Module):
                 model=model_name,
                 dtype="half",
                 tokenizer=model_name,
-                trust_remote_code=True
+                trust_remote_code=True,
             )
         else:
             # 加载原始transformers模型
             self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.bfloat16,
-                device_map="auto"
+                model_name, torch_dtype=torch.bfloat16, device_map="auto"
             )
             self.model.eval()
 
-        logging.info(f
+        logging.info(f"Loaded model {self.name}")
         self.update()
 
     @classmethod
@@ -56,13 +55,15 @@ class Model(torch.nn.Module):
             sampling_params = SamplingParams(
                 temperature=temp,
                 max_tokens=max_length,
-                #top_p=0.95 if do_sample else 1.0,
-                stop_token_ids=[self.tokenizer.eos_token_id]
+                # top_p=0.95 if do_sample else 1.0,
+                stop_token_ids=[self.tokenizer.eos_token_id],
             )
             outputs = self.llm.generate(content_list, sampling_params)
             return [output.outputs[0].text for output in outputs]
         else:
-            input_ids = self.tokenizer(
+            input_ids = self.tokenizer(
+                content_list, return_tensors="pt", padding=True, truncation=True
+            ).input_ids.to(self.model.device)
             outputs = self.model.generate(
                 input_ids,
                 max_new_tokens=max_length,
@@ -70,7 +71,9 @@ class Model(torch.nn.Module):
                 temperature=temp,
                 eos_token_id=self.tokenizer.eos_token_id,
             )
-            return self.tokenizer.batch_decode(
+            return self.tokenizer.batch_decode(
+                outputs[:, input_ids.shape[1] :], skip_special_tokens=True
+            )
 
     def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
         if self.use_vllm:
@@ -78,24 +81,28 @@ class Model(torch.nn.Module):
                 temperature=temp,
                 max_tokens=max_length,
                 top_p=0.95 if do_sample else 1.0,
-                stop_token_ids=[self.tokenizer.eos_token_id]
+                stop_token_ids=[self.tokenizer.eos_token_id],
             )
             outputs = self.llm.generate(content_list, sampling_params, stream=True)
+
             prev_token_ids = [[] for _ in content_list]
+
             for output in outputs:
                 for i, request_output in enumerate(output.outputs):
                     current_token_ids = request_output.token_ids
-                    new_token_ids = current_token_ids[len(prev_token_ids[i]):]
+                    new_token_ids = current_token_ids[len(prev_token_ids[i]) :]
                     prev_token_ids[i] = current_token_ids.copy()
+
                     for token_id in new_token_ids:
-                        token_text = self.tokenizer.decode(
+                        token_text = self.tokenizer.decode(
+                            token_id, skip_special_tokens=True
+                        )
                        yield i, token_text
         else:
-            input_ids = self.tokenizer(
+            input_ids = self.tokenizer(
+                content_list, return_tensors="pt", padding=True, truncation=True
+            ).input_ids.to(self.model.device)
+
             gen_kwargs = {
                 "input_ids": input_ids,
                 "do_sample": do_sample,
@@ -103,7 +110,7 @@ class Model(torch.nn.Module):
                 "eos_token_id": self.tokenizer.eos_token_id,
                 "max_new_tokens": 1,
                 "return_dict_in_generate": True,
-                "output_scores": True
+                "output_scores": True,
             }
 
             generated_tokens = 0
@@ -113,16 +120,26 @@ class Model(torch.nn.Module):
             while generated_tokens < max_length and len(active_sequences) > 0:
                 with torch.no_grad():
                     output = self.model.generate(**gen_kwargs)
+
                 next_tokens = output.sequences[:, -1].unsqueeze(-1)
+
                 for i, token in zip(active_sequences, next_tokens):
-                    yield i.item(), self.tokenizer.decode(
+                    yield i.item(), self.tokenizer.decode(
+                        token[0], skip_special_tokens=True
+                    )
 
-                gen_kwargs["input_ids"] = torch.cat(
+                gen_kwargs["input_ids"] = torch.cat(
+                    [gen_kwargs["input_ids"], next_tokens], dim=-1
+                )
                 generated_tokens += 1
 
-                completed = (
+                completed = (
+                    (next_tokens.squeeze(-1) == self.tokenizer.eos_token_id)
+                    .nonzero()
+                    .squeeze(-1)
+                )
+                active_sequences = torch.tensor(
+                    [i for i in active_sequences if i not in completed]
+                )
                 if len(active_sequences) > 0:
                     gen_kwargs["input_ids"] = gen_kwargs["input_ids"][active_sequences]
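`Model.streaming` yields `(input_index, token_text)` pairs one token at a time, which is the contract `stream_data` in utils/multiple_stream.py relies on to grow one output box per input. A minimal sketch of that consumption pattern using a stand-in generator, so nothing needs to be downloaded or put on a GPU:

```python
# Stand-in for Model.streaming: yields (index, token) pairs per input text.
def fake_streaming(content_list):
    for i, text in enumerate(content_list):
        for word in text.split():
            yield i, word


# Accumulate tokens per column, the same way stream_data builds its outputs.
content_list = ["first prompt here", "second prompt here"]
outputs = ["" for _ in content_list]
for idx, token in fake_streaming(content_list):
    outputs[idx] += token + " "

print(outputs)
```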
utils/multiple_stream.py
CHANGED
@@ -7,32 +7,36 @@ TEST = """ Test of Time. A Benchmark for Evaluating LLMs on Temporal Reasoning. Large language models have
 showcased remarkable reasoning capabilities, yet they remain susceptible to errors, particularly in temporal
 reasoning tasks involving complex temporal logic. """
 
+
 def generate_data_test():
     """Generator to yield words"""
     temp = copy.deepcopy(TEST)
     l1 = temp.split()
     random.shuffle(l1)
-    temp =
+    temp = " ".join(l1)
     for word in temp.split(" "):
         yield word + " "
 
+
 def stream_data(content_list, model):
     """Stream data to three columns"""
     outputs = ["" for _ in content_list]
 
     # Use the gen method to handle batch generation
     generator = model.streaming(content_list)
+
     while True:
         updated = False
 
         try:
-            id, word = next(
+            id, word = next(
+                generator
+            )  # Get the next generated word for the corresponding content
             outputs[id] += f"{word} "
             updated = True
         except StopIteration:
             break
+
         if updated:
             yield tuple(outputs)
 
@@ -41,21 +45,22 @@ def create_interface():
     with gr.Blocks() as demo:
         with gr.Group():
             with gr.Row():
-                columns = [
+                columns = [
+                    gr.Textbox(label=f"Column {i+1}", lines=10) for i in range(3)
+                ]
+
             start_btn = gr.Button("Start Streaming")
+
             def start_streaming():
-                content_list = [
+                content_list = [
+                    col.value for col in columns
+                ]  # Get input texts from text boxes
                 for data in stream_data(content_list):
                     updates = [gr.update(value=data[i]) for i in range(len(columns))]
                     yield tuple(updates)
+
             start_btn.click(
-                fn=start_streaming,
-                inputs=[],
-                outputs=columns,
-                show_progress=False
+                fn=start_streaming, inputs=[], outputs=columns, show_progress=False
             )
 
     return demo
@@ -64,4 +69,4 @@ def create_interface():
 if __name__ == "__main__":
     demo = create_interface()
     demo.queue()
     demo.launch()