File size: 18,226 Bytes
5d92357
c5c065e
5d92357
 
8185fe8
 
 
295e94f
62bce9d
5d92357
 
 
11a61b4
 
 
 
8185fe8
 
2600399
8185fe8
b8a7cf9
8185fe8
 
 
 
5d92357
 
8185fe8
11a61b4
 
ea1fca9
c752f9e
 
b921db9
 
 
 
5d92357
 
11a61b4
 
05b7e51
 
11a61b4
 
 
 
 
481afa0
b921db9
 
 
 
 
 
 
 
5d92357
c752f9e
 
b921db9
 
 
 
5d92357
 
c752f9e
 
 
e150a4a
 
c752f9e
 
 
 
24bf24b
c752f9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d92357
 
6a4226d
 
 
 
8224785
6a4226d
 
 
 
 
 
 
 
 
c752f9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06cefb6
b206d70
c752f9e
 
 
 
b57bf75
62bce9d
c752f9e
 
b57bf75
c752f9e
 
 
62bce9d
 
c752f9e
 
 
 
3cd31d8
 
15a601f
c752f9e
 
8224785
06cefb6
62bce9d
c752f9e
 
 
7916def
c752f9e
 
 
62bce9d
 
c752f9e
3cd31d8
 
c752f9e
 
15a601f
c752f9e
 
 
 
3cd31d8
 
 
 
5d92357
62bce9d
8224785
62bce9d
 
 
32a3bbb
62bce9d
 
 
 
 
6a4226d
 
 
 
 
 
 
 
62bce9d
 
5d92357
7916def
 
 
90d2ad7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d92357
 
7128b84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d92357
7916def
 
 
 
3cd31d8
7916def
 
 
 
7128b84
 
 
f8a5c23
7128b84
8b35b55
 
7916def
b206d70
f8a5c23
7128b84
f8a5c23
6c09b42
7128b84
f8a5c23
3cd31d8
f8a5c23
3cd31d8
0ef49e6
3cd31d8
7128b84
 
 
 
0ef49e6
f8a5c23
0ef49e6
 
 
7128b84
3cd31d8
8b35b55
3cd31d8
 
 
 
5d92357
62bce9d
c752f9e
62bce9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import gradio as gr
import os
import torch
import transformers
import huggingface_hub
import datetime
import json
import shutil
import threading

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# To suppress the following warning:
# huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Required environment variables, read in this order; a missing one raises
# KeyError at import time. MODE is either 'debug' or 'prod'.
HF_TOKEN_DOWNLOAD, HF_TOKEN_UPLOAD, MODE = (
    os.environ[name] for name in ('HF_TOKEN_DOWNLOAD', 'HF_TOKEN_UPLOAD', 'MODE')
)

# Model to load and the dataset repository that collected records are pushed to.
MODEL_NAME = 'liujch1998/vera'
DATASET_REPO_URL = "https://huggingface.co/datasets/liujch1998/cd-pi-dataset"

# Local clone of the dataset repository and the JSON-lines file appended to.
DATA_DIR = 'data'
DATA_PATH = os.path.join(DATA_DIR, 'data.jsonl')

class Interactive:
    """Tokenizer, T5 encoder and linear scoring head used to score statements.

    When ``MODE == 'debug'`` only the tokenizer is loaded and every statement
    receives a fixed placeholder prediction, so the app can be exercised
    without loading the model weights.
    """

    # Statements are truncated to this many tokens before being encoded.
    MAX_LENGTH = 128

    def __init__(self):
        """Load the tokenizer and, outside debug mode, the model and head."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
        if MODE == 'debug':
            return
        self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto', offload_folder='offload')
        self.model.D = self.model.shared.embedding_dim
        # The scoring head has no separate checkpoint entry: its weight, its
        # bias and the calibration divisor are read from rows 32099, 32098 and
        # 32097 of the shared embedding matrix.
        self.linear = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype).to(device)
        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
        self.model.eval()
        self.t = self.model.shared.weight[32097, 0].item()

    @staticmethod
    def _make_record(statement, logit, logit_calibrated, score, score_calibrated):
        """Build the result dict for one statement, stamped with the current local time."""
        return {
            'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
            'statement': statement,
            'logit': logit,
            'logit_calibrated': logit_calibrated,
            'score': score,
            'score_calibrated': score_calibrated,
        }

    def run(self, statement):
        """Score a single statement.

        Returns:
            A dict with keys 'timestamp', 'statement', 'logit',
            'logit_calibrated', 'score' and 'score_calibrated'.
        """
        # A batch of one has no padding, so the batched path selects the same
        # final token the single-statement code used to index with -1.
        return self.runs([statement])[0]

    def runs(self, statements):
        """Score a batch of statements.

        Args:
            statements: list of statement strings.

        Returns:
            One result dict per statement, in input order (same keys as
            ``run``). An empty input yields an empty list.
        """
        if not statements:
            # Nothing to score; avoids handing an empty batch to the tokenizer.
            return []
        if MODE == 'debug':
            return [self._make_record(statement, 0.0, 0.0, 0.5, 0.5) for statement in statements]
        # Truncate exactly like the single-statement path did, so both entry
        # points treat over-long input the same way.
        tok = self.tokenizer.batch_encode_plus(statements, return_tensors='pt', padding='longest', truncation='longest_first', max_length=self.MAX_LENGTH)
        input_ids = tok.input_ids.to(device)
        attention_mask = tok.attention_mask.to(device)
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            # Index of the last attended token of each row (assumes the
            # tokenizer pads on the right).
            last_indices = attention_mask.sum(dim=1, keepdim=True) - 1 # (B, 1)
            last_indices = last_indices.unsqueeze(-1).expand(-1, -1, self.model.D) # (B, 1, D)
            last_hidden_state = output.last_hidden_state.to(device) # (B, L, D)
            hidden = last_hidden_state.gather(dim=1, index=last_indices).squeeze(1) # (B, D)
            logits = self.linear(hidden).squeeze(-1) # (B)
            logits_calibrated = logits / self.t
            scores = logits.sigmoid()
            scores_calibrated = logits_calibrated.sigmoid()
        return [
            self._make_record(statement, logit.item(), logit_calibrated.item(), score.item(), score_calibrated.item())
            for statement, logit, logit_calibrated, score, score_calibrated
            in zip(statements, logits, logits_calibrated, scores, scores_calibrated)
        ]

# Single shared model wrapper, built once at import time.
interactive = Interactive()

# Start from a fresh clone of the dataset repository. ignore_errors=True covers
# the first run, when DATA_DIR does not exist yet, without a bare except.
shutil.rmtree(DATA_DIR, ignore_errors=True)
repo = huggingface_hub.Repository(
    local_dir=DATA_DIR,
    clone_from=DATASET_REPO_URL,
    token=HF_TOKEN_UPLOAD,
    repo_type='dataset',
)
repo.git_pull()
# Held while appending to DATA_PATH and while pushing, so the two never overlap.
lock = threading.Lock()

# def predict(statement, do_save=True):
#     output_raw = interactive.run(statement)
#     output = {
#         'True': output_raw['score_calibrated'],
#         'False': 1 - output_raw['score_calibrated'],
#     }
#     if do_save:
#         with open(DATA_PATH, 'a') as f:
#             json.dump(output_raw, f, ensure_ascii=False)
#             f.write('\n')
#         commit_url = repo.push_to_hub()
#         print('Logged statement to dataset:')
#         print('Commit URL:', commit_url)
#         print(output_raw)
#         print()
#     return output, output_raw, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value='Please provide your feedback before trying out another statement.')

# def record_feedback(output_raw, feedback, do_save=True):
#     if do_save:
#         output_raw.update({ 'feedback': feedback })
#         with open(DATA_PATH, 'a') as f:
#             json.dump(output_raw, f, ensure_ascii=False)
#             f.write('\n')
#         commit_url = repo.push_to_hub()
#         print('Logged feedback to dataset:')
#         print('Commit URL:', commit_url)
#         print(output_raw)
#         print()
#     return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value='Thanks for your feedback! Now you can enter another statement.')
# def record_feedback_agree(output_raw, do_save=True):
#     return record_feedback(output_raw, 'agree', do_save)
# def record_feedback_disagree(output_raw, do_save=True):
#     return record_feedback(output_raw, 'disagree', do_save)

def predict(statements, do_saves):
    """Score a batch of statements and log the ones whose author opted in.

    Args:
        statements: batch of statement strings (a tuple when called by Gradio).
        do_saves: per-statement flags; a truthy flag appends that statement's
            raw result to DATA_PATH.

    Returns:
        An 8-tuple of per-statement lists: the True/False label scores, the
        raw result dicts, five visibility updates (the first hides, the other
        four show) and a message update asking for feedback.
    """
    output_raws = interactive.runs(list(statements)) # statements is a tuple, but tokenizer takes a list
    outputs = [{
        'True': output_raw['score_calibrated'],
        'False': 1 - output_raw['score_calibrated'],
    } for output_raw in output_raws]
    to_save = [output_raw for output_raw, do_save in zip(output_raws, do_saves) if do_save]
    print('Logging statements to dataset:')
    # `with` releases the lock even if a write fails; a bare acquire/release
    # pair would leave it held and block every later request and push.
    with lock:
        if to_save:
            # ensure_ascii=False writes raw non-ASCII text, so pin the encoding.
            with open(DATA_PATH, 'a', encoding='utf-8') as f:
                for output_raw in to_save:
                    print(output_raw)
                    json.dump(output_raw, f, ensure_ascii=False)
                    f.write('\n')
        print()
    return outputs, output_raws, \
        [gr.update(visible=False) for _ in statements], \
        [gr.update(visible=True) for _ in statements], \
        [gr.update(visible=True) for _ in statements], \
        [gr.update(visible=True) for _ in statements], \
        [gr.update(visible=True) for _ in statements], \
        [gr.update(value='Please share your feedback before trying out another statement.') for _ in statements]

def record_feedback(output_raws, feedback, do_saves):
    """Log user feedback for a batch of earlier predictions.

    Args:
        output_raws: raw result dicts, one per prediction.
        feedback: label stored under the 'feedback' key of each logged record.
        do_saves: per-prediction flags; only truthy entries are appended to
            DATA_PATH.

    Returns:
        A 6-tuple of per-prediction lists: five visibility updates (the first
        shows, the other four hide) and a thank-you message update.
    """
    # Build new dicts rather than mutating the caller's objects.
    records = [
        {**output_raw, 'feedback': feedback}
        for output_raw, do_save in zip(output_raws, do_saves)
        if do_save
    ]
    print('Logging feedbacks to dataset:')
    # `with` releases the lock even if a write fails; a bare acquire/release
    # pair would leave it held and block every later request and push.
    with lock:
        if records:
            # ensure_ascii=False writes raw non-ASCII text, so pin the encoding.
            with open(DATA_PATH, 'a', encoding='utf-8') as f:
                for record in records:
                    print(record)
                    json.dump(record, f, ensure_ascii=False)
                    f.write('\n')
        print()
    return [gr.update(visible=True) for _ in output_raws], \
        [gr.update(visible=False) for _ in output_raws], \
        [gr.update(visible=False) for _ in output_raws], \
        [gr.update(visible=False) for _ in output_raws], \
        [gr.update(visible=False) for _ in output_raws], \
        [gr.update(value='Thanks for sharing your feedback! You can now enter another statement.') for _ in output_raws]
def record_feedback_agree(output_raws, do_saves):
    """Record the 'agree' feedback label for a batch of predictions."""
    return record_feedback(output_raws=output_raws, feedback='agree', do_saves=do_saves)
def record_feedback_disagree(output_raws, do_saves):
    """Record the 'disagree' feedback label for a batch of predictions."""
    return record_feedback(output_raws=output_raws, feedback='disagree', do_saves=do_saves)
def record_feedback_uncertain(output_raws, do_saves):
    """Record the 'uncertain' feedback label for a batch of predictions."""
    return record_feedback(output_raws=output_raws, feedback='uncertain', do_saves=do_saves)
def record_feedback_outofscope(output_raws, do_saves):
    """Record the 'outofscope' feedback label for a batch of predictions."""
    return record_feedback(output_raws=output_raws, feedback='outofscope', do_saves=do_saves)

def push():
    """Push newly recorded data to the dataset repository, if there is any.

    Does nothing beyond printing a notice when the local clone is clean. If
    the push fails, the local clone is deleted and re-cloned; records that
    had not been pushed are discarded with it.
    """
    global repo
    # `with` releases the lock even if the cleanliness check or the re-clone
    # raises; otherwise every later predict/feedback call would block forever.
    with lock:
        if repo.is_repo_clean():
            print('No new data recorded, skipping git push ...')
            print()
            return
        try:
            repo.push_to_hub()
        except Exception as e:
            print('Failed to push to git:', e)
            # Best-effort recovery: start over from a fresh clone.
            shutil.rmtree(DATA_DIR)
            repo = huggingface_hub.Repository(
                local_dir=DATA_DIR,
                clone_from=DATASET_REPO_URL,
                token=HF_TOKEN_UPLOAD,
                repo_type='dataset',
            )
            repo.git_pull()

# Example statements offered to the user through gr.Examples in the UI below.
# The comment above each pair appears to name the dataset the pair was drawn
# from (presumably one plausible and one implausible statement — verify).
examples = [
    # # openbookqa
    # 'If a person walks in the opposite direction of a compass arrow they are walking south.',
    # 'If a person walks in the opposite direction of a compass arrow they are walking north.',
    # arc_easy
    'A pond is different from a lake because ponds are smaller and shallower.',
    'A pond is different from a lake because ponds have moving water.',
    # arc_hard
    'Hunting strategies are more likely to be learned rather than inherited.',
    'A spotted coat is more likely to be learned rather than inherited.',
    # ai2_science_elementary
    'Photosynthesis uses carbon from the air to make food for plants.',
    'Respiration uses carbon from the air to make food for plants.',
    # ai2_science_middle
    'The barometer measures atmospheric pressure.',
    'The thermometer measures atmospheric pressure.',
    # commonsenseqa
    'People aim to complete a job at work.',
    'People aim to kill animals at work.',
    # qasc
    'Climate is generally described in terms of local weather conditions.',
    'Climate is generally described in terms of forests.',
    # physical_iqa
    'ice box will turn into a cooler if you add water to it.',
    'ice box will turn into a cooler if you add soda to it.',
    # social_iqa
    'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very aggressive and talkative person.',
    'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very quiet person.',
    # winogrande_xl
    'Sarah was a much better surgeon than Maria so Maria always got the easier cases.',
    'Sarah was a much better surgeon than Maria so Sarah always got the easier cases.',
    # com2sense_paired
    'If you want a quick snack, getting one banana would be a good choice generally.',
    'If you want a snack, getting twenty bananas would be a good choice generally.',
    # sciq
    'Each specific polypeptide has a unique linear sequence of amino acids.',
    'Each specific polypeptide has a unique linear sequence of fatty acids.',
    # quarel
    'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because wet floor has more resistance.',
    'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because marble floor has more resistance.',
    # quartz
    'If less waters falls on an area of land it will cause less plants to grow in that area.',
    'If less waters falls on an area of land it will cause more plants to grow in that area.',
    # cycic_mc
    'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the park on January 20.',
    'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the financial district on January 20.',
    # comve_a
    'Summer in North America is great for swimming,  boating, and fishing.',
    'Summer in North America is great for skiing,  snowshoeing,  and making a snowman.',
    # csqa2
    'Gas is always capable of turning into liquid under high pressure.',
    'Cotton candy is sometimes made out of cotton.',
    # symkd_anno
    'James visits a famous landmark. As a result, James learns about the world.',
    'Cliff and Andrew enter the castle. But before, Cliff needed to have been a student at the school.',
    # gengen_anno
    'Generally, bar patrons are capable of taking care of their own drinks.',
    'Generally, ocean currents have little influence over storm intensity.',

    # 'If A sits next to B and B sits next to C, then A must sit next to C.',
    # 'If A sits next to B and B sits next to C, then A might not sit next to C.',
]

# input_statement = gr.Dropdown(choices=examples, label='Statement:')
# input_model = gr.Textbox(label='Commonsense statement verification model:', value=MODEL_NAME, interactive=False)
# output = gr.outputs.Label(num_top_classes=2)

# description = '''This is a demo for Vera, a commonsense statement verification model. Under development.
# ⚠️ Data Collection: by default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app!'''

# gr.Interface(
#     fn=predict,
#     inputs=[input_statement, input_model],
#     outputs=output,
#     title="Vera",
#     description=description,
# ).launch()

# Build the Gradio UI: a statement box with an opt-out checkbox on the left,
# the prediction and feedback buttons on the right, example statements below.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            '''# Vera

            Vera is a commonsense statement verification model. Under development.

            Type a commonsense statement in the box below and click the submit button to see Vera's prediction on its correctness. You can try both correct and incorrect statements. If you are looking for inspiration, try the examples at the bottom of the page!

            We'd love your feedback! Please indicate whether you agree or disagree with Vera's prediction (and don't mind the percentage numbers).

            ⚠️ **Intended Use**: Vera is a research prototype and may make mistakes. Do not use for making critical decisions. It is intended to predict the correctness of commonsense statements, and may be unreliable when taking input out of this scope.

            ⚠️ **Data Collection**: by default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app! You can opt out of this data collection by removing the checkbox below:
            '''
        )
        with gr.Row():
            with gr.Column(scale=2):
                do_save = gr.Checkbox(
                    value=True,
                    label="Store data",
                    info="You agree to the storage of your input for research and development purposes:")
                statement = gr.Textbox(placeholder='Enter a commonsense statement here, or select an example from below', label='Statement', interactive=True)
                submit = gr.Button(value='Submit', variant='primary', visible=True)
            with gr.Column(scale=1):
                output = gr.Label(num_top_classes=2, interactive=False)
                # Hidden holder for the raw prediction so the feedback handlers can log it.
                output_raw = gr.JSON(visible=False)
                # Feedback buttons start hidden; predict reveals them. The
                # emoji labels were mis-encoded (mojibake) and are restored.
                with gr.Row():
                    feedback_agree = gr.Button(value='👍 Agree', variant='secondary', visible=False)
                    feedback_uncertain = gr.Button(value='🤔 Uncertain', variant='secondary', visible=False)
                    feedback_disagree = gr.Button(value='👎 Disagree', variant='secondary', visible=False)
                feedback_outofscope = gr.Button(value='🚫 This is not a statement about commonsense', variant='stop', visible=False)
                feedback_ack = gr.Markdown(value='', visible=True, interactive=False)
        gr.Markdown('\n---\n')
        with gr.Row():
            # NOTE(review): `inputs`/`outputs` here do not match predict's
            # signature (statements, do_saves) or its 8 return values.
            # Presumably harmless while run_on_click and cache_examples are
            # both False; fix the wiring before enabling either.
            gr.Examples(
                examples=examples,
                fn=predict,
                inputs=[statement],
                outputs=[output, output_raw, statement, submit, feedback_agree, feedback_disagree, feedback_ack],
                examples_per_page=100,
                cache_examples=False,
                run_on_click=False, # If we want this to be True, I suspect we need to enable the statement.submit()
            )
    # Components whose visibility/text every predict and feedback event updates,
    # in the order the handlers return them.
    feedback_outputs = [submit, feedback_agree, feedback_uncertain, feedback_disagree, feedback_outofscope, feedback_ack]
    submit.click(predict, inputs=[statement, do_save], outputs=[output, output_raw] + feedback_outputs, batch=True, max_batch_size=16)
    # statement.submit(predict, inputs=[statement], outputs=[output, output_raw])
    for button, handler in (
        (feedback_agree, record_feedback_agree),
        (feedback_uncertain, record_feedback_uncertain),
        (feedback_disagree, record_feedback_disagree),
        (feedback_outofscope, record_feedback_outofscope),
    ):
        button.click(handler, inputs=[output_raw, do_save], outputs=feedback_outputs, batch=True, max_batch_size=16)

    demo.load(push, inputs=None, outputs=None, every=60) # Push to git every 60 seconds

demo.queue(concurrency_count=1).launch(debug=True)