burtenshaw committed
Commit dc616b0 · 1 Parent(s): aac30ac

implement generation from ratings for DPO

Files changed (1)
  1. data/generate_dpo.py  +167  -0
data/generate_dpo.py ADDED
@@ -0,0 +1,167 @@
+ import json
+ from typing import TYPE_CHECKING, Literal, Union
+
+ from datasets import Dataset, concatenate_datasets
+ from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.pipeline import Pipeline
+ from distilabel.steps import KeepColumns, Step, StepInput
+ from distilabel.steps.tasks import TextGeneration
+ from typing_extensions import override
+
+ if TYPE_CHECKING:
+     # Typing-only imports for the string annotations used below; the exact
+     # module path may differ between distilabel versions.
+     from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput
+
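+ # The templates below are Jinja-style and are rendered by the TextGeneration
+ # tasks further down against the "conversation" column (see
+ # `columns=["conversation"]`). CHOSEN_* asks the model for an improved final
+ # turn on a negatively rated conversation; REJECT_* asks for a deliberately
+ # poor final turn on a positively rated one.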
+ CHOSEN_TEMPLATE = """
+ You are provided with a conversation between a human and an AI assistant.
+ The final message has been rated negatively. Your task is to regenerate the response.
+ {% for message in conversation %}
+ {{ message["role"] }}: {{ message["content"] }}
+ {% endfor %}
+ Replacement improved message:
+ """.rstrip()
+
+ CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to regenerate high quality responses to user queries, when other assistants go wrong."
+
+ REJECT_TEMPLATE = """
+ You are provided with a conversation between a human and an AI assistant.
+ The final message has been rated positively. Your task is to regenerate a POOR QUALITY response.
+ {% for message in conversation %}
+ {{ message["role"] }}: {{ message["content"] }}
+ {% endfor %}
+ Replacement poor quality message:
+ """.rstrip()
+
+ REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to deliberately generate poor quality responses to user queries, to serve as rejected examples."
+
+
+ class FilterConversationRatings(Step):
+     """Filters conversations based on the rating of the last message."""
+
+     target_column: Union[Literal["chosen"], Literal["rejected"]]
+     batch_size: int = 5
+
+     @override
+     def process(self, dataset: StepInput) -> "GeneratorStepOutput":
+         column_rating_map = {
+             "chosen": 1,
+             "rejected": -1,
+         }
+
+         target_rating = column_rating_map[self.target_column]
+
+         for batch_start in range(0, len(dataset), self.batch_size):
+             batch = dataset[batch_start : batch_start + self.batch_size]
+             filtered_batch = []
+             for row in batch:
+                 _conversation = row["conversation"]
+                 conversation = None
+                 # Truncate at the first message whose integer rating matches the
+                 # target rating for this column.
+                 for idx, message in enumerate(_conversation, 1):
+                     if not isinstance(message["rating"], int):
+                         continue
+                     if message["rating"] == target_rating:
+                         conversation = _conversation[:idx]
+                         break
+                 if conversation:
+                     filtered_batch.append({"conversation": conversation})
+             yield filtered_batch
+
+     @property
+     def outputs(self) -> "StepColumns":
+         return ["conversation"]
+
+
+ class AppendToConversationStep(Step):
+     """Appends a generated message to a conversation."""
+
+     @property
+     def inputs(self) -> "StepColumns":
+         return ["generation", "conversation"]
+
+     @property
+     def outputs(self) -> "StepColumns":
+         return ["generated_conversation", "conversation"]
+
+     def process(self, inputs: StepInput) -> "StepOutput":
+         for input in inputs:
+             if not input["generation"]:
+                 continue
+             if not input["conversation"]:
+                 continue
+             # The generated conversation replaces the final turn with the newly
+             # generated assistant message.
+             input["generated_conversation"] = [
+                 {"role": message["role"], "content": message["content"]}
+                 for message in input["conversation"][:-1]
+             ] + [{"role": "assistant", "content": input["generation"]}]
+             # Keep only role/content on the original conversation (drop ratings).
+             input["conversation"] = [
+                 {"role": message["role"], "content": message["content"]}
+                 for message in input["conversation"]
+             ]
+         yield inputs
+
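+ # Two pipelines build the DPO preference pairs from opposite directions:
+ #   * `rejection_pipeline` takes conversations whose last rated turn is negative,
+ #     regenerates an improved final turn, and maps it to "chosen" (the original
+ #     conversation becomes "rejected").
+ #   * `chosen_pipeline` takes conversations whose last rated turn is positive,
+ #     regenerates a deliberately poor final turn, and maps it to "rejected" (the
+ #     original conversation stays "chosen").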
+
+ with Pipeline(
+     name="conversation_rejection",
+     description="Generate a chosen response to a rejected conversation.",
+ ) as rejection_pipeline:
+     rejected_dataset = FilterConversationRatings(target_column="rejected")
+
+     chosen_text_gen = TextGeneration(
+         llm=InferenceEndpointsLLM(
+             model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+         ),
+         system_prompt=CHOSEN_SYSTEM_PROMPT,
+         template=CHOSEN_TEMPLATE,
+         columns=["conversation"],
+     )
+
+     append_chosen = AppendToConversationStep(
+         output_mappings={
+             "generated_conversation": "chosen",
+             "conversation": "rejected",
+         },
+     )
+
+     keep_columns = KeepColumns(
+         columns=["chosen", "rejected"],
+     )
+
+     rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns
+
+ with Pipeline(
+     name="conversation_chosen",
+     description="Generate a rejected response to a chosen conversation.",
+ ) as chosen_pipeline:
+     chosen_dataset = FilterConversationRatings(target_column="chosen")
+
+     rejected_text_gen = TextGeneration(
+         llm=InferenceEndpointsLLM(
+             model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
+         ),
+         system_prompt=REJECT_SYSTEM_PROMPT,
+         template=REJECT_TEMPLATE,
+         columns=["conversation"],
+     )
+
+     append_rejected = AppendToConversationStep(
+         output_mappings={
+             "generated_conversation": "rejected",
+             "conversation": "chosen",
+         },
+     )
+
+     keep_columns = KeepColumns(
+         columns=["chosen", "rejected"],
+     )
+
+     chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns
+
+ if __name__ == "__main__":
+     dataset_path = "example_data.json"
+     with open(dataset_path) as f:
+         data = json.load(f)
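+
+     # Assumed shape of `example_data.json` (inferred from the steps above, not
+     # documented in this commit): a list of rows, each holding a "conversation"
+     # list of messages with "role", "content", and a "rating" that is 1
+     # (positive) or -1 (negative) on rated assistant turns, e.g.:
+     #   [{"conversation": [
+     #       {"role": "user", "content": "...", "rating": None},
+     #       {"role": "assistant", "content": "...", "rating": -1}]}]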
+
+     dataset = Dataset.from_list(data)
+     rejected_dataset = rejection_pipeline.run(dataset=dataset, use_cache=False)
+     chosen_dataset = chosen_pipeline.run(dataset=dataset, use_cache=False)
+
+     # Each `Pipeline.run` call returns a Distiset; take the "default" leaf's
+     # train split from each before combining the preference pairs.
+     dataset = concatenate_datasets(
+         dsets=[rejected_dataset["default"]["train"], chosen_dataset["default"]["train"]]
+     )
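+
+     # A minimal follow-up sketch (not part of this commit): the combined pairs
+     # could be persisted with the standard `datasets` API, e.g.
+     #   dataset.to_json("dpo_pairs.jsonl")
+     #   dataset.push_to_hub("<org>/<dataset-name>")  # placeholder repo id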