Tristan Thrush commited on
Commit
460dbe4
·
1 Parent(s): d94c767

initial example of rlhf

Browse files
Files changed (6) hide show
  1. README.md +51 -1
  2. app.py +203 -0
  3. collect.py +55 -0
  4. config.py.example +6 -0
  5. requirements.txt +6 -0
  6. utils.py +39 -0
README.md CHANGED
@@ -1 +1,51 @@
1
- # rlhf-interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: RLHF
3
+ emoji: 🏢
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.0.17
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ A basic example of an RLHF interface with a Gradio app.
13
+
14
+ **Instructions for someone to use for their own project:**
15
+
16
+ *Setting up the Space*
17
+ 1. Clone this repo and deploy it on your own Hugging Face space.
18
+ 2. Add the following secrets to your space:
19
+ - `HF_TOKEN`: One of your Hugging Face tokens.
20
+ - `DATASET_REPO_URL`: The url to an empty dataset that you created the hub. It
21
+ can be a private or public dataset.
22
+ - `FORCE_PUSH`: "yes"
23
+ When you run this space on mturk and when people visit your space on
24
+ huggingface.co, the app will use your token to automatically store new HITs
25
+ in your dataset. Setting `FORCE_PUSH` to "yes" ensures that your repo will
26
+ force push changes to the dataset during data collection. Otherwise,
27
+ accidental manual changes to your dataset could result in your space gettin
28
+ merge conflicts as it automatically tries to push the dataset to the hub. For
29
+ local development, add these three keys to a `.env` file, and consider setting
30
+ `FORCE_PUSH` to "no".
31
+ *Running Data Collection*
32
+ 1. On your local repo that you pulled, create a copy of `config.py.example`,
33
+ just called `config.py`. Now, put keys from your AWS account in `config.py`.
34
+ These keys should be for an AWS account that has the
35
+ AmazonMechanicalTurkFullAccess permission. You also need to
36
+ create an mturk requestor account associated with your AWS account.
37
+ 2. Run `python collect.py` locally.
38
+
39
+ *Profit*
40
+ Now, you should be watching hits come into your Hugging Face dataset
41
+ automatically!
42
+
43
+ *Tips and Tricks*
44
+ - Use caution while doing local development of your space and
45
+ simultaneously running it on mturk. Consider setting `FORCE_PUSH` to "no" in
46
+ your local `.env` file.
47
+ - huggingface spaces have limited computational resources and memory. If you
48
+ run too many HITs and/or assignments at once, then you could encounter issues.
49
+ You could also encounter issues if you are trying to create a dataset that is
50
+ very large. Check the log of your space for any errors that could be happening.
51
+
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Basic example for doing model-in-the-loop dynamic adversarial data collection
2
+ # using Gradio Blocks.
3
+ import os
4
+ import random
5
+ import uuid
6
+ from urllib.parse import parse_qs
7
+ import gradio as gr
8
+ import requests
9
+ from transformers import pipeline, Conversation
10
+ from huggingface_hub import Repository
11
+ from dotenv import load_dotenv
12
+ from pathlib import Path
13
+ import json
14
+ from utils import force_git_push
15
+ import threading
16
+
17
+ # These variables are for storing the mturk HITs in a Hugging Face dataset.
18
+ if Path(".env").is_file():
19
+ load_dotenv(".env")
20
+ DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
21
+ FORCE_PUSH = os.getenv("FORCE_PUSH")
22
+ HF_TOKEN = os.getenv("HF_TOKEN")
23
+ DATA_FILENAME = "data.jsonl"
24
+ DATA_FILE = os.path.join("data", DATA_FILENAME)
25
+ repo = Repository(
26
+ local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
27
+ )
28
+
29
+ TOTAL_CNT = 4 # How many user inputs per HIT
30
+
31
+ # This function pushes the HIT data written in data.jsonl to our Hugging Face
32
+ # dataset every minute. Adjust the frequency to suit your needs.
33
+ PUSH_FREQUENCY = 60
34
+ def asynchronous_push(f_stop):
35
+ if repo.is_repo_clean():
36
+ print("Repo currently clean. Ignoring push_to_hub")
37
+ else:
38
+ repo.git_add(auto_lfs_track=True)
39
+ repo.git_commit("Auto commit by space")
40
+ if FORCE_PUSH == "yes":
41
+ force_git_push(repo)
42
+ else:
43
+ repo.git_push()
44
+ if not f_stop.is_set():
45
+ # call again in 60 seconds
46
+ threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()
47
+
48
+ f_stop = threading.Event()
49
+ asynchronous_push(f_stop)
50
+
51
+ # Now let's run the app!
52
+ chatbot = pipeline(model="microsoft/DialoGPT-medium")
53
+
54
+ demo = gr.Blocks()
55
+
56
+ with demo:
57
+ dummy = gr.Textbox(visible=False) # dummy for passing assignmentId
58
+
59
+ # We keep track of state as a JSON
60
+ state_dict = {
61
+ "conversation_id": uuid.uuid64(),
62
+ "assignmentId": "",
63
+ "cnt": 0, "data": [],
64
+ "past_user_inputs": [],
65
+ "generated_responses": [],
66
+ "response_1": "",
67
+ "response_2": "",
68
+ }
69
+ state = gr.JSON(state_dict, visible=False)
70
+
71
+ gr.Markdown("# RLHF Interface")
72
+ gr.Markdown("Choose the best model output")
73
+
74
+ state_display = gr.Markdown(f"Your messages: 0/{TOTAL_CNT}")
75
+
76
+ # Generate model prediction
77
+ # Default model: distilbert-base-uncased-finetuned-sst-2-english
78
+ def _predict(txt, state):
79
+ conversation_1 = Conversation(past_user_inputs=state["past_user_inputs"].copy(), generated_responses=state["generated_responses"].copy())
80
+ conversation_2 = Conversation(past_user_inputs=state["past_user_inputs"].copy(), generated_responses=state["generated_responses"].copy())
81
+ conversation_1.add_user_input(txt)
82
+ conversation_2.add_user_input(txt)
83
+ conversation_1 = chatbot(conversation_1, do_sample=True, seed=420)
84
+ conversation_2 = chatbot(conversation_2, do_sample=True, seed=69)
85
+ response_1 = conversation_1.generated_responses[-1]
86
+ response_2 = conversation_2.generated_responses[-1]
87
+
88
+ state["cnt"] += 1
89
+
90
+ new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
91
+
92
+ state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2})
93
+ state["past_user_inputs"].append(txt)
94
+
95
+ if state["cnt"] == TOTAL_CNT:
96
+ # Write the HIT data to our local dataset because the worker has
97
+ # submitted everything now.
98
+ with open(DATA_FILE, "a") as jsonlfile:
99
+ json_data_with_assignment_id =\
100
+ [json.dumps(dict({"assignmentId": state["assignmentId"]}, **datum)) for datum in state["data"]]
101
+ jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")
102
+
103
+ past_conversation_string = "<br />".join(["<br />".join(["😃: " + user_input, "🤖: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
104
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True, choices=[response_1, response_2], interactive=True, value=response_1), gr.update(value=past_conversation_string), state, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), new_state_md, dummy
105
+
106
+ def _select_response(selected_response, state, dummy):
107
+ done = state["cnt"] == TOTAL_CNT
108
+ toggle_example_submit = gr.update(visible=not done)
109
+ state["generated_responses"].append(selected_response)
110
+ state["data"][-1]["selected_response"] = selected_response
111
+ past_conversation_string = "<br />".join(["<br />".join(["😃: " + user_input, "🤖: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"])])
112
+ query = parse_qs(dummy[1:])
113
+ if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
114
+ # It seems that someone is using this app on mturk. We need to
115
+ # store the assignmentId in the state before submit_hit_button
116
+ # is clicked. We can do this here in _predict. We need to save the
117
+ # assignmentId so that the turker can get credit for their HIT.
118
+ state["assignmentId"] = query["assignmentId"][0]
119
+ toggle_final_submit = gr.update(visible=done)
120
+ toggle_final_submit_preview = gr.update(visible=False)
121
+ else:
122
+ toggle_final_submit_preview = gr.update(visible=done)
123
+ toggle_final_submit = gr.update(visible=False)
124
+ text_input = gr.update(visible=False) if done else gr.update(visible=True)
125
+ return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
126
+
127
+ # Input fields
128
+ past_conversation = gr.Markdown()
129
+ text_input = gr.Textbox(placeholder="Enter a statement", show_label=False)
130
+ select_response = gr.Radio(choices=[None, None], visible=False, label="Choose the best response")
131
+ select_response_button = gr.Button("Select Response", visible=False)
132
+ with gr.Column() as example_submit:
133
+ submit_ex_button = gr.Button("Submit")
134
+ with gr.Column(visible=False) as final_submit:
135
+ submit_hit_button = gr.Button("Submit HIT")
136
+ with gr.Column(visible=False) as final_submit_preview:
137
+ submit_hit_button_preview = gr.Button("Submit Work (preview mode; no mturk HIT credit, but your examples will still be stored)")
138
+
139
+ # Button event handlers
140
+ get_window_location_search_js = """
141
+ function(text_input, label_input, state, dummy) {
142
+ return [text_input, label_input, state, window.location.search];
143
+ }
144
+ """
145
+
146
+ select_response_button.click(
147
+ _select_response,
148
+ inputs=[select_response, state, dummy],
149
+ outputs=[select_response, example_submit, text_input, select_response_button, state, past_conversation, example_submit, final_submit, final_submit_preview],
150
+ _js=get_window_location_search_js,
151
+ )
152
+
153
+ submit_ex_button.click(
154
+ _predict,
155
+ inputs=[text_input, state],
156
+ outputs=[text_input, select_response_button, select_response, past_conversation, state, example_submit, final_submit, final_submit_preview, state_display, dummy],
157
+ _js=get_window_location_search_js,
158
+ )
159
+
160
+ post_hit_js = """
161
+ function(state) {
162
+ // If there is an assignmentId, then the submitter is on mturk
163
+ // and has accepted the HIT. So, we need to submit their HIT.
164
+ const form = document.createElement('form');
165
+ form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
166
+ form.method = 'post';
167
+ for (const key in state) {
168
+ const hiddenField = document.createElement('input');
169
+ hiddenField.type = 'hidden';
170
+ hiddenField.name = key;
171
+ hiddenField.value = state[key];
172
+ form.appendChild(hiddenField);
173
+ };
174
+ document.body.appendChild(form);
175
+ form.submit();
176
+ return state;
177
+ }
178
+ """
179
+
180
+ submit_hit_button.click(
181
+ lambda state: state,
182
+ inputs=[state],
183
+ outputs=[state],
184
+ _js=post_hit_js,
185
+ )
186
+
187
+ refresh_app_js = """
188
+ function(state) {
189
+ // The following line here loads the app again so the user can
190
+ // enter in another preview-mode "HIT".
191
+ window.location.href = window.location.href;
192
+ return state;
193
+ }
194
+ """
195
+
196
+ submit_hit_button_preview.click(
197
+ lambda state: state,
198
+ inputs=[state],
199
+ outputs=[state],
200
+ _js=refresh_app_js,
201
+ )
202
+
203
+ demo.launch()
collect.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Basic example for running MTurk data collection against a Space
2
+ # For more information see https://docs.aws.amazon.com/mturk/index.html
3
+
4
+ import boto3
5
+ from boto.mturk.question import ExternalQuestion
6
+
7
+ from config import MTURK_KEY, MTURK_SECRET
8
+ import argparse
9
+
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--mturk_region", default="us-east-1", help="The region for mturk (default: us-east-1)")
12
+ parser.add_argument("--space_name", default="Tristan/dadc", help="Name of the accompanying Hugging Face space (default: Tristan/dadc)")
13
+ parser.add_argument("--num_hits", type=int, default=5, help="The number of HITs.")
14
+ parser.add_argument("--num_assignments", type=int, default=1, help="The number of times that the HIT can be accepted and completed.")
15
+ parser.add_argument("--live_mode", action="store_true", help="""
16
+ Whether to run in live mode with real turkers. This will charge your account money.
17
+ If you don't use this flag, the HITs will be deployed on the sandbox version of mturk,
18
+ which will not charge your account money.
19
+ """
20
+ )
21
+
22
+ args = parser.parse_args()
23
+
24
+ MTURK_URL = f"https://mturk-requester{'' if args.live_mode else '-sandbox'}.{args.mturk_region}.amazonaws.com"
25
+
26
+ mturk = boto3.client(
27
+ "mturk",
28
+ aws_access_key_id=MTURK_KEY,
29
+ aws_secret_access_key=MTURK_SECRET,
30
+ region_name=args.mturk_region,
31
+ endpoint_url=MTURK_URL,
32
+ )
33
+
34
+ # This is the URL that makes the space embeddable in an mturk iframe
35
+ question = ExternalQuestion(f"https://hf.space/embed/{args.space_name}/+?__theme=light",
36
+ frame_height=600
37
+ )
38
+
39
+ for i in range(args.num_hits):
40
+ new_hit = mturk.create_hit(
41
+ Title="Beat the AI",
42
+ Description="Try to fool an AI by creating examples that it gets wrong",
43
+ Keywords="fool the model",
44
+ Reward="0.15",
45
+ MaxAssignments=args.num_assignments,
46
+ LifetimeInSeconds=172800,
47
+ AssignmentDurationInSeconds=600,
48
+ AutoApprovalDelayInSeconds=14400,
49
+ Question=question.get_as_xml(),
50
+ )
51
+
52
+ print(
53
+ f"HIT Group Link: https://worker{'' if args.live_mode else 'sandbox'}.mturk.com/mturk/preview?groupId="
54
+ + new_hit["HIT"]["HITGroupId"]
55
+ )
config.py.example ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Fill in the information and rename this file config.py
2
+ # You can obtain the key and secret in the AWS Identity
3
+ # and Access Management (IAM) panel.
4
+
5
+ MTURK_KEY = ''
6
+ MTURK_SECRET = '
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==1.12.0
2
+ transformers==4.20.1
3
+ gradio==3.0.26
4
+ boto3==1.24.32
5
+ huggingface_hub==0.8.1
6
+ python-dotenv==0.20.0
utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ from huggingface_hub.repository import _lfs_log_progress
3
+
4
+ def force_git_push(
5
+ repo,
6
+ ):
7
+ """
8
+ force a simple git push
9
+ Blocking. Will return url to commit on remote
10
+ repo.
11
+ """
12
+ command = "git push --force"
13
+
14
+ try:
15
+ with _lfs_log_progress():
16
+ process = subprocess.Popen(
17
+ command.split(),
18
+ stderr=subprocess.PIPE,
19
+ stdout=subprocess.PIPE,
20
+ encoding="utf-8",
21
+ cwd=repo.local_dir,
22
+ )
23
+
24
+ stdout, stderr = process.communicate()
25
+ return_code = process.poll()
26
+ process.kill()
27
+
28
+ if len(stderr):
29
+ print(stderr)
30
+
31
+ if return_code:
32
+ raise subprocess.CalledProcessError(
33
+ return_code, process.args, output=stdout, stderr=stderr
34
+ )
35
+
36
+ except subprocess.CalledProcessError as exc:
37
+ raise EnvironmentError(exc.stderr)
38
+
39
+ return repo.git_head_commit_url()