Spaces:
Runtime error
Runtime error
File size: 18,580 Bytes
5d92357 c5c065e 5d92357 8185fe8 295e94f 62bce9d 5d92357 11a61b4 8185fe8 2600399 8185fe8 b8a7cf9 8185fe8 b7fb5d6 8185fe8 5d92357 8185fe8 11a61b4 ea1fca9 c752f9e b921db9 5d92357 11a61b4 05b7e51 11a61b4 481afa0 b921db9 5d92357 c752f9e b921db9 5d92357 c752f9e e150a4a c752f9e 24bf24b c752f9e 5d92357 6a4226d 8224785 6a4226d c752f9e 06cefb6 b206d70 c752f9e 295d753 62bce9d c752f9e b57bf75 c752f9e 62bce9d c752f9e 3cd31d8 15a601f c752f9e 8224785 295d753 62bce9d c752f9e 7916def c752f9e 62bce9d c752f9e 3cd31d8 c752f9e 15a601f c752f9e 3cd31d8 5d92357 62bce9d 8224785 62bce9d 295d753 62bce9d 6a4226d 62bce9d 5d92357 7916def 90d2ad7 5d92357 7128b84 5d92357 7916def 71f7d47 7916def b7fb5d6 7128b84 295d753 7128b84 8b35b55 7916def b206d70 f8a5c23 295d753 f8a5c23 6c09b42 7128b84 f8a5c23 3cd31d8 f8a5c23 71f7d47 0ef49e6 3cd31d8 7128b84 0ef49e6 f8a5c23 0ef49e6 7128b84 3cd31d8 8b35b55 3cd31d8 5d92357 62bce9d c752f9e 62bce9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 |
import gradio as gr
import os
import torch
import transformers
import huggingface_hub
import datetime
import json
import shutil
import threading
# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Silence this huggingface/tokenizers warning, printed when the process forks
# after tokenizer parallelism has already been used:
#   huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Required settings; a missing variable raises KeyError at import time.
HF_TOKEN_DOWNLOAD = os.environ['HF_TOKEN_DOWNLOAD']
HF_TOKEN_UPLOAD = os.environ['HF_TOKEN_UPLOAD']
MODE = os.environ['MODE']  # 'debug' or 'prod'

MODEL_NAME = 'liujch1998/vera'
DATASET_REPO_URL = "https://huggingface.co/datasets/liujch1998/cd-pi-dataset"

# Local clone of the dataset repo; debug runs log to their own file.
DATA_DIR = 'data'
DATA_FILENAME = 'data_debug.jsonl' if MODE == 'debug' else 'data.jsonl'
DATA_PATH = os.path.join(DATA_DIR, DATA_FILENAME)
class Interactive:
    """Loads the Vera model and scores commonsense statements.

    When ``MODE == 'debug'`` only the tokenizer is loaded, and every statement
    gets a fixed placeholder result (logits 0.0, scores 0.5).
    """

    def __init__(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
        if MODE == 'debug':
            return
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            use_auth_token=HF_TOKEN_DOWNLOAD,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
            offload_folder='offload',
        )
        self.model.D = self.model.shared.embedding_dim
        # The scoring head's weight and bias, and the calibration temperature,
        # are read from rows 32099, 32098 and 32097 of the shared embedding matrix.
        self.linear = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype).to(device)
        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0))  # (1)
        self.model.eval()
        # Temperature: calibrated logit = logit / t.
        self.t = self.model.shared.weight[32097, 0].item()

    @staticmethod
    def _make_result(statement, logit, logit_calibrated, score, score_calibrated):
        """Build the per-statement result dict returned by ``run``/``runs``."""
        return {
            'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
            'statement': statement,
            'logit': logit,
            'logit_calibrated': logit_calibrated,
            'score': score,
            'score_calibrated': score_calibrated,
        }

    def run(self, statement):
        """Score a single statement.

        Equivalent to ``self.runs([statement])[0]``: for a batch of one there is
        no padding, so the last non-padding token is the last token.
        """
        return self.runs([statement])[0]

    def runs(self, statements):
        """Score a batch of statements.

        Args:
            statements: list of statement strings.

        Returns:
            One dict per statement, in input order, with keys ``timestamp``,
            ``statement``, ``logit``, ``logit_calibrated``, ``score`` (sigmoid of
            the logit) and ``score_calibrated`` (sigmoid of ``logit / t``).
        """
        if not statements:
            # Nothing to encode; the tokenizer cannot build tensors from [].
            return []
        if MODE == 'debug':
            return [self._make_result(s, 0.0, 0.0, 0.5, 0.5) for s in statements]
        # Truncate to at most 128 tokens so arbitrarily long user input cannot
        # exhaust memory, and so run() (which delegates here) keeps its limit.
        tok = self.tokenizer.batch_encode_plus(
            statements,
            return_tensors='pt',
            padding='longest',
            truncation='longest_first',
            max_length=128,
        )
        input_ids = tok.input_ids.to(device)
        attention_mask = tok.attention_mask.to(device)
        with torch.no_grad():
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            # Position of the last non-padding token in each row; this assumes
            # the tokenizer pads on the right (TODO confirm for other tokenizers).
            last_indices = attention_mask.sum(dim=1, keepdim=True) - 1  # (B, 1)
            last_indices = last_indices.unsqueeze(-1).expand(-1, -1, self.model.D)  # (B, 1, D)
            last_hidden_state = output.last_hidden_state.to(device)  # (B, L, D)
            hidden = last_hidden_state.gather(dim=1, index=last_indices).squeeze(1)  # (B, D)
            logits = self.linear(hidden).squeeze(-1)  # (B)
            logits_calibrated = logits / self.t
            scores = logits.sigmoid()
            scores_calibrated = logits_calibrated.sigmoid()
        return [
            self._make_result(
                statement,
                logit.item(),
                logit_calibrated.item(),
                score.item(),
                score_calibrated.item(),
            )
            for statement, logit, logit_calibrated, score, score_calibrated
            in zip(statements, logits, logits_calibrated, scores, scores_calibrated)
        ]
# Model singleton used by the Gradio handlers.
interactive = Interactive()

# Start from a fresh clone of the dataset repo, discarding any stale local copy.
# ignore_errors keeps this best-effort (e.g. the directory may not exist yet)
# without a bare ``except:`` that would also swallow KeyboardInterrupt/SystemExit.
shutil.rmtree(DATA_DIR, ignore_errors=True)
repo = huggingface_hub.Repository(
    local_dir=DATA_DIR,
    clone_from=DATASET_REPO_URL,
    token=HF_TOKEN_UPLOAD,
    repo_type='dataset',
)
repo.git_pull()

# Serializes appends to DATA_PATH with the periodic git push.
lock = threading.Lock()
# def predict(statement, do_save=True):
# output_raw = interactive.run(statement)
# output = {
# 'True': output_raw['score_calibrated'],
# 'False': 1 - output_raw['score_calibrated'],
# }
# if do_save:
# with open(DATA_PATH, 'a') as f:
# json.dump(output_raw, f, ensure_ascii=False)
# f.write('\n')
# commit_url = repo.push_to_hub()
# print('Logged statement to dataset:')
# print('Commit URL:', commit_url)
# print(output_raw)
# print()
# return output, output_raw, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value='Please provide your feedback before trying out another statement.')
# def record_feedback(output_raw, feedback, do_save=True):
# if do_save:
# output_raw.update({ 'feedback': feedback })
# with open(DATA_PATH, 'a') as f:
# json.dump(output_raw, f, ensure_ascii=False)
# f.write('\n')
# commit_url = repo.push_to_hub()
# print('Logged feedback to dataset:')
# print('Commit URL:', commit_url)
# print(output_raw)
# print()
# return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value='Thanks for your feedback! Now you can enter another statement.')
# def record_feedback_agree(output_raw, do_save=True):
# return record_feedback(output_raw, 'agree', do_save)
# def record_feedback_disagree(output_raw, do_save=True):
# return record_feedback(output_raw, 'disagree', do_save)
def predict(statements, do_saves):
    """Score a batch of statements with Vera and log the opted-in results.

    Registered as a batched Gradio handler (``batch=True``), so each argument
    and each returned element is a list with one entry per queued request.

    Args:
        statements: statements to score, one per request.
        do_saves: per-request "Store data" checkbox values; a result is
            appended to ``DATA_PATH`` only when its flag is truthy.

    Returns:
        A tuple of eight lists matching the ``submit.click`` outputs:
        label dicts ({'True': p, 'False': 1 - p}), raw result dicts, then
        updates that hide Submit, show the four feedback buttons and set the
        acknowledgement text.
    """
    # statements is a tuple, but the tokenizer takes a list.
    output_raws = interactive.runs(list(statements))
    outputs = [
        {'True': raw['score_calibrated'], 'False': 1 - raw['score_calibrated']}
        for raw in output_raws
    ]

    print(f'Logging statements to {DATA_FILENAME}:')
    to_save = [raw for raw, do_save in zip(output_raws, do_saves) if do_save]
    if to_save:
        # ``with`` releases the lock even if the write fails; a leaked lock would
        # block push() and every later write forever. Explicit UTF-8 because
        # ensure_ascii=False writes non-ASCII statements verbatim.
        with lock, open(DATA_PATH, 'a', encoding='utf-8') as f:
            for raw in to_save:
                print(raw)
                json.dump(raw, f, ensure_ascii=False)
                f.write('\n')
    print()

    def _updates(**changes):
        # One gr.update per request in the batch.
        return [gr.update(**changes) for _ in statements]

    return (
        outputs,
        output_raws,
        _updates(visible=False),  # Submit
        _updates(visible=True),  # Agree
        _updates(visible=True),  # Uncertain
        _updates(visible=True),  # Disagree
        _updates(visible=True),  # Out of scope
        _updates(value='Please share your feedback before trying out another statement.'),
    )
def record_feedback(output_raws, feedback, do_saves):
    """Append the user's verdict on earlier predictions to the data file.

    Batched Gradio handler shared by the four feedback buttons.

    Args:
        output_raws: raw result dicts produced by ``predict`` and read back from
            the hidden JSON component, one per request.
        feedback: verdict stored under the ``'feedback'`` key ('agree',
            'disagree', 'uncertain' or 'outofscope').
        do_saves: per-request "Store data" flags; only flagged records are written.

    Returns:
        Six lists of updates (one entry per request): show Submit, hide the four
        feedback buttons, and set the acknowledgement text.
    """
    print(f'Logging feedbacks to {DATA_FILENAME}:')
    # Copy rather than mutate the incoming dicts. Skip holders with no value
    # (presumably None if no prediction populated the JSON component yet)
    # instead of crashing on them.
    records = [
        {**output_raw, 'feedback': feedback}
        for output_raw, do_save in zip(output_raws, do_saves)
        if do_save and output_raw is not None
    ]
    if records:
        # ``with`` releases the lock even if the write fails; explicit UTF-8
        # because ensure_ascii=False writes non-ASCII statements verbatim.
        with lock, open(DATA_PATH, 'a', encoding='utf-8') as f:
            for record in records:
                print(record)
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')
    print()

    def _updates(**changes):
        # One gr.update per request in the batch.
        return [gr.update(**changes) for _ in output_raws]

    return (
        _updates(visible=True),  # Submit
        _updates(visible=False),  # Agree
        _updates(visible=False),  # Uncertain
        _updates(visible=False),  # Disagree
        _updates(visible=False),  # Out of scope
        _updates(value='Thanks for sharing your feedback! You can now enter another statement.'),
    )
def record_feedback_agree(output_raws, do_saves):
    """Batched handler for the "Agree" button: log the verdict 'agree'."""
    return record_feedback(output_raws=output_raws, feedback='agree', do_saves=do_saves)
def record_feedback_disagree(output_raws, do_saves):
    """Batched handler for the "Disagree" button: log the verdict 'disagree'."""
    return record_feedback(output_raws=output_raws, feedback='disagree', do_saves=do_saves)
def record_feedback_uncertain(output_raws, do_saves):
    """Batched handler for the "Uncertain" button: log the verdict 'uncertain'."""
    return record_feedback(output_raws=output_raws, feedback='uncertain', do_saves=do_saves)
def record_feedback_outofscope(output_raws, do_saves):
    """Batched handler for the out-of-scope button: log the verdict 'outofscope'."""
    return record_feedback(output_raws=output_raws, feedback='outofscope', do_saves=do_saves)
def push():
    """Commit and push records appended locally to the dataset repo, if any.

    Scheduled from the UI via ``demo.load(push, ..., every=60)``. Holds ``lock``
    for the whole operation, so no handler appends to ``DATA_PATH`` mid-push.
    If the push fails, the local clone is deleted and re-cloned from the remote
    so that later pushes start from a clean state.

    NOTE(review): records appended locally but never pushed are discarded by
    that recovery path.
    """
    global repo
    # ``with`` guarantees release even if the push *and* the recovery below
    # raise. A lock left held would deadlock every later write.
    with lock:
        if repo.is_repo_clean():
            return  # Nothing new recorded since the last push.
        try:
            repo.push_to_hub()
        except Exception as e:
            print('Failed to push to git:', e)
            shutil.rmtree(DATA_DIR)
            repo = huggingface_hub.Repository(
                local_dir=DATA_DIR,
                clone_from=DATASET_REPO_URL,
                token=HF_TOKEN_UPLOAD,
                repo_type='dataset',
            )
            repo.git_pull()
# Example statements shown below the input box, grouped by source dataset
# (two per source). Dicts keep insertion order, so the flattened list
# follows the order of the entries below.
examples = [
    example
    for pair in {
        # 'openbookqa': (
        #     'If a person walks in the opposite direction of a compass arrow they are walking south.',
        #     'If a person walks in the opposite direction of a compass arrow they are walking north.',
        # ),
        'arc_easy': (
            'A pond is different from a lake because ponds are smaller and shallower.',
            'A pond is different from a lake because ponds have moving water.',
        ),
        'arc_hard': (
            'Hunting strategies are more likely to be learned rather than inherited.',
            'A spotted coat is more likely to be learned rather than inherited.',
        ),
        'ai2_science_elementary': (
            'Photosynthesis uses carbon from the air to make food for plants.',
            'Respiration uses carbon from the air to make food for plants.',
        ),
        'ai2_science_middle': (
            'The barometer measures atmospheric pressure.',
            'The thermometer measures atmospheric pressure.',
        ),
        'commonsenseqa': (
            'People aim to complete a job at work.',
            'People aim to kill animals at work.',
        ),
        'qasc': (
            'Climate is generally described in terms of local weather conditions.',
            'Climate is generally described in terms of forests.',
        ),
        'physical_iqa': (
            'ice box will turn into a cooler if you add water to it.',
            'ice box will turn into a cooler if you add soda to it.',
        ),
        'social_iqa': (
            'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very aggressive and talkative person.',
            'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very quiet person.',
        ),
        'winogrande_xl': (
            'Sarah was a much better surgeon than Maria so Maria always got the easier cases.',
            'Sarah was a much better surgeon than Maria so Sarah always got the easier cases.',
        ),
        'com2sense_paired': (
            'If you want a quick snack, getting one banana would be a good choice generally.',
            'If you want a snack, getting twenty bananas would be a good choice generally.',
        ),
        'sciq': (
            'Each specific polypeptide has a unique linear sequence of amino acids.',
            'Each specific polypeptide has a unique linear sequence of fatty acids.',
        ),
        'quarel': (
            'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because wet floor has more resistance.',
            'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because marble floor has more resistance.',
        ),
        'quartz': (
            'If less waters falls on an area of land it will cause less plants to grow in that area.',
            'If less waters falls on an area of land it will cause more plants to grow in that area.',
        ),
        'cycic_mc': (
            'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the park on January 20.',
            'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the financial district on January 20.',
        ),
        'comve_a': (
            'Summer in North America is great for swimming, boating, and fishing.',
            'Summer in North America is great for skiing, snowshoeing, and making a snowman.',
        ),
        'csqa2': (
            'Gas is always capable of turning into liquid under high pressure.',
            'Cotton candy is sometimes made out of cotton.',
        ),
        'symkd_anno': (
            'James visits a famous landmark. As a result, James learns about the world.',
            'Cliff and Andrew enter the castle. But before, Cliff needed to have been a student at the school.',
        ),
        'gengen_anno': (
            'Generally, bar patrons are capable of taking care of their own drinks.',
            'Generally, ocean currents have little influence over storm intensity.',
        ),
        # (
        #     'If A sits next to B and B sits next to C, then A must sit next to C.',
        #     'If A sits next to B and B sits next to C, then A might not sit next to C.',
        # ),
    }.values()
    for example in pair
]
# input_statement = gr.Dropdown(choices=examples, label='Statement:')
# input_model = gr.Textbox(label='Commonsense statement verification model:', value=MODEL_NAME, interactive=False)
# output = gr.outputs.Label(num_top_classes=2)
# description = '''This is a demo for Vera, a commonsense statement verification model. Under development.
# β οΈ Data Collection: by default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app!'''
# gr.Interface(
# fn=predict,
# inputs=[input_statement, input_model],
# outputs=output,
# title="Vera",
# description=description,
# ).launch()
with gr.Blocks() as demo:
    with gr.Column():
        # Blank lines separate the Markdown paragraphs; without them Markdown
        # merges the whole description into a single paragraph.
        gr.Markdown(
            '''# Vera

Vera is a commonsense statement verification model. Under development.

Type a commonsense statement in the box below and click the submit button to see Vera's prediction on its correctness. You can try both correct and incorrect statements. If you are looking for inspiration, try the examples at the bottom of the page!

We'd love your feedback! Please indicate whether you agree or disagree with Vera's prediction (and don't mind the percentage numbers). If you're unsure or the statement doesn't have a certain correctness label, please select "Uncertain". If your input is actually not a statement about commonsense, please select "I don't think this is a statement about commonsense".

⚠️ **Intended Use**: Vera is a research prototype and may make mistakes. Do not use for making critical decisions. It is intended to predict the correctness of commonsense statements, and may be unreliable when taking input out of this scope.

⚠️ **Data Collection**: By default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app! You can opt out of this data collection by unchecking the checkbox below:
'''
        )
        with gr.Row():
            with gr.Column(scale=3):
                do_save = gr.Checkbox(
                    value=True,
                    label="Store data",
                    info="You agree to the storage of your input for research and development purposes:")
                statement = gr.Textbox(placeholder='Enter a commonsense statement here, or select an example from below', label='Statement', interactive=True)
                submit = gr.Button(value='Submit', variant='primary', visible=True)
            with gr.Column(scale=2):
                output = gr.Label(num_top_classes=2, interactive=False)
                # Hidden holder for predict()'s raw result dicts; the feedback
                # handlers read it back to log the verdict alongside them.
                output_raw = gr.JSON(visible=False)
                with gr.Row():
                    feedback_agree = gr.Button(value='👍 Agree', variant='secondary', visible=False)
                    feedback_uncertain = gr.Button(value='🤔 Uncertain', variant='secondary', visible=False)
                    feedback_disagree = gr.Button(value='👎 Disagree', variant='secondary', visible=False)
                    # Label must match the wording quoted in the description above.
                    feedback_outofscope = gr.Button(value='🚫 I don\'t think this is a statement about commonsense', variant='secondary', visible=False)
                feedback_ack = gr.Markdown(value='', visible=True, interactive=False)
        gr.Markdown('\n---\n')
        with gr.Row():
            # Clicking an example only fills the textbox (run_on_click=False and
            # no caching), so no fn/outputs are wired here. predict is a batched
            # two-input function and cannot be driven from this single input.
            gr.Examples(
                examples=examples,
                inputs=[statement],
                examples_per_page=100,
                cache_examples=False,
                run_on_click=False,  # If we want this to be True, I suspect we need to enable the statement.submit()
            )

    # Order matches the visibility updates returned by predict/record_feedback.
    feedback_buttons = [feedback_agree, feedback_uncertain, feedback_disagree, feedback_outofscope]
    submit.click(predict, inputs=[statement, do_save], outputs=[output, output_raw, submit, *feedback_buttons, feedback_ack], batch=True, max_batch_size=16)
    # statement.submit(predict, inputs=[statement], outputs=[output, output_raw])
    feedback_handlers = [record_feedback_agree, record_feedback_uncertain, record_feedback_disagree, record_feedback_outofscope]
    for button, handler in zip(feedback_buttons, feedback_handlers):
        button.click(handler, inputs=[output_raw, do_save], outputs=[submit, *feedback_buttons, feedback_ack], batch=True, max_batch_size=16)
    demo.load(push, inputs=None, outputs=None, every=60)  # Push to git every 60 seconds

demo.queue(concurrency_count=1).launch(debug=True)
|