Update README.md
README.md (CHANGED)
```diff
@@ -16,7 +16,7 @@ We release ChatQA1.5, which excels at RAG-based conversational question answering
 
 
 ## Benchmark Results
-Results in ConvRAG are as follows:
+Results in ConvRAG Bench are as follows:
 
 | | ChatQA-1.0-7B | Command-R-Plus | Llama-3-instruct-70b | GPT-4-0613 | ChatQA-1.0-70B | ChatQA-1.5-8B | ChatQA-1.5-70B |
 | -- |:--:|:--:|:--:|:--:|:--:|:--:|:--:|
@@ -33,7 +33,7 @@ Results in ConvRAG are as follows:
 | Average (all) | 47.71 | 50.93 | 52.52 | 53.90 | 54.14 | 55.17 | 58.25 |
 | Average (exclude HybriDial) | 46.96 | 51.40 | 52.95 | 54.35 | 53.89 | 53.99 | 57.14 |
 
-Note that ChatQA-1.5 used some samples from the HybriDial training dataset. To ensure fair comparison, we also compare average scores excluding HybriDial. The data and evaluation scripts for ConvRAG can be found here.
+Note that ChatQA-1.5 used some samples from the HybriDial training dataset. To ensure fair comparison, we also compare average scores excluding HybriDial. The data and evaluation scripts for ConvRAG can be found [here](https://huggingface.co/datasets/nvidia/ConvRAG-Bench).
 
 
 ## Prompt Format
```
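The link added in this hunk points to the ConvRAG Bench data on Hugging Face. As a quick orientation (not part of the README itself), here is a minimal sketch of pulling the benchmark with the `datasets` library; the configuration name `hybridial` is an assumption based on the HybriDial subset discussed above, so check the dataset page for the actual configuration names and splits:

```python
# Minimal sketch of fetching the linked benchmark data. The configuration name
# "hybridial" is an assumption (HybriDial is one subset named in the README);
# the real configuration names are listed on the dataset page.
from datasets import load_dataset

convrag = load_dataset("nvidia/ConvRAG-Bench", "hybridial")
print(convrag)
```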
```diff
@@ -72,14 +72,14 @@ def get_formatted_input(messages, context):
     system = "System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context."
     instruction = "Please give a full and complete answer for the question."
 
-    for item in
+    for item in messages:
         if item['role'] == "user":
             ## only apply this instruction for the first user turn
             item['content'] = instruction + " " + item['content']
             break
 
     conversation = ""
-    for item in
+    for item in messages:
         if item["role"] == "user":
             conversation += "User: " + item["content"] + "\n\n"
         else:
```
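This hunk completes the two `for` loops over `messages`, which the function treats as a list of `{"role", "content"}` dicts. A tiny standalone illustration of what the first patched loop does (the sample turn is invented):

```python
# Invented sample showing the first patched loop's effect: it prepends the
# instruction to the first user turn only, then stops.
instruction = "Please give a full and complete answer for the question."
messages = [
    {"role": "user", "content": "When was the report released?"},
]

for item in messages:
    if item["role"] == "user":
        item["content"] = instruction + " " + item["content"]
        break

print(messages[0]["content"])
# Please give a full and complete answer for the question. When was the report released?
```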
```diff
@@ -90,17 +90,14 @@ def get_formatted_input(messages, context):
     return formatted_input
 
 formatted_input = get_formatted_input(messages, context)
-
+tokenized_prompt = tokenizer(tokenizer.bos_token + formatted_input, return_tensors="pt").to(model.device)
 
 terminators = [
     tokenizer.eos_token_id,
     tokenizer.convert_tokens_to_ids("<|eot_id|>")
 ]
 
-outputs = model.generate(
-    input_ids,
-    max_new_tokens=128,
-    eos_token_id=terminators)
+outputs = model.generate(input_ids=tokenized_prompt.input_ids, attention_mask=tokenized_prompt.attention_mask, max_new_tokens=128, eos_token_id=terminators)
 
 response = outputs[0][input_ids.shape[-1]:]
 print(tokenizer.decode(response, skip_special_tokens=True))
```
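Taken together, the patch replaces an undefined `input_ids` with an explicit `tokenized_prompt`, prepends the BOS token, and passes an attention mask to `generate`. One caveat survives as a context line: `response = outputs[0][input_ids.shape[-1]:]` still references `input_ids`, which no longer exists after this change; presumably it should slice with `tokenized_prompt.input_ids`. Below is a runnable end-to-end sketch under that assumption; the checkpoint name and sample inputs are illustrative, and the tail of `get_formatted_input` is a plausible completion of lines the diff elides:

```python
# End-to-end sketch of the patched snippet. Assumptions: the checkpoint name and
# sample inputs are illustrative; the tail of get_formatted_input completes lines
# elided from the diff; the final slice uses tokenized_prompt.input_ids because
# bare input_ids is undefined after this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_formatted_input(messages, context):
    system = "System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context."
    instruction = "Please give a full and complete answer for the question."

    for item in messages:
        if item["role"] == "user":
            ## only apply this instruction for the first user turn
            item["content"] = instruction + " " + item["content"]
            break

    conversation = ""
    for item in messages:
        if item["role"] == "user":
            conversation += "User: " + item["content"] + "\n\n"
        else:
            # Assumed symmetric handling of assistant turns (elided in the diff).
            conversation += "Assistant: " + item["content"] + "\n\n"
    conversation += "Assistant:"

    # Assumed final assembly: system prompt, retrieved context, then the transcript.
    return system + "\n\n" + context + "\n\n" + conversation

model_id = "nvidia/Llama3-ChatQA-1.5-8B"  # illustrative checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

messages = [{"role": "user", "content": "When was the report released?"}]  # invented sample
context = "The annual report was released in March 2024 and covers fiscal year 2023."

formatted_input = get_formatted_input(messages, context)
tokenized_prompt = tokenizer(tokenizer.bos_token + formatted_input, return_tensors="pt").to(model.device)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

outputs = model.generate(
    input_ids=tokenized_prompt.input_ids,
    attention_mask=tokenized_prompt.attention_mask,
    max_new_tokens=128,
    eos_token_id=terminators,
)

# Slice off the prompt tokens; tokenized_prompt.input_ids is the assumed fix for
# the README's bare input_ids here.
response = outputs[0][tokenized_prompt.input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```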