# vera / app.py — Hugging Face Space app by liujch1998 (commit 0ef49e6, "WIP"; 11.4 kB).
import gradio as gr
import os
import torch
import transformers
import huggingface_hub
import datetime
import json
import shutil
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Silence the Hugging Face tokenizers fork warning:
# "huggingface/tokenizers: The current process just got forked, after parallelism
# has already been used. Disabling parallelism to avoid deadlocks..."
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Required settings, read from the environment (a missing one raises KeyError at startup).
HF_TOKEN_DOWNLOAD = os.environ['HF_TOKEN_DOWNLOAD']  # passed to from_pretrained for the model/tokenizer
HF_TOKEN_UPLOAD = os.environ['HF_TOKEN_UPLOAD']  # passed to huggingface_hub.Repository for the log dataset
MODE = os.environ['MODE']  # 'debug' or 'prod'

# Model to serve and the dataset repository that user inputs/feedback are logged to.
MODEL_NAME = 'liujch1998/vera'
DATASET_REPO_URL = "https://huggingface.co/datasets/liujch1998/cd-pi-dataset"
DATA_DIR = 'data'
DATA_PATH = os.path.join(DATA_DIR, 'data.jsonl')
# Start every run from a fresh clone of the logging dataset. Removing a stale
# local checkout is best-effort (e.g. it does not exist on the first run), but
# unlike a bare `except:` this does not swallow KeyboardInterrupt/SystemExit.
shutil.rmtree(DATA_DIR, ignore_errors=True)
repo = huggingface_hub.Repository(
    local_dir=DATA_DIR,
    clone_from=DATASET_REPO_URL,
    token=HF_TOKEN_UPLOAD,
    repo_type='dataset',
)
repo.git_pull()
class Interactive:
    """Wrap the Vera model and turn one statement into plausibility scores.

    When ``MODE == 'debug'`` only the tokenizer is loaded and :meth:`run`
    returns fixed placeholder scores, so the UI can be exercised without
    loading the model.
    """

    def __init__(self):
        """Load the tokenizer and, unless MODE == 'debug', the encoder and scoring head."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
        if MODE == 'debug':
            return
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            use_auth_token=HF_TOKEN_DOWNLOAD,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
        )
        # The scoring head's parameters are stored inside the shared embedding
        # matrix: row 32099 is the weight vector and entry [32098, 0] the bias.
        # NOTE(review): `.to(device)` only moves the freshly initialised weights,
        # which are replaced right after; the real head lives wherever
        # device_map placed the embedding. run() moves hidden states to `device`,
        # so both must end up on the same device — confirm on multi-GPU hosts.
        self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1, dtype=self.model.dtype).to(device)
        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0))  # (1)
        self.model.eval()
        # Scalar read from entry [32097, 0]; run() divides the logit by it to
        # produce the calibrated logit (temperature scaling).
        self.t = self.model.shared.weight[32097, 0].item()

    def run(self, statement):
        """Score how plausible *statement* is.

        Args:
            statement: a single natural-language statement.

        Returns:
            dict of Python floats: 'logit', 'logit_calibrated' (logit / self.t),
            'score' (sigmoid of the logit) and 'score_calibrated' (sigmoid of
            the calibrated logit). In debug mode both logits are 0.0 and both
            scores are 0.5.
        """
        if MODE == 'debug':
            return {
                'logit': 0.0,
                'logit_calibrated': 0.0,
                'score': 0.5,
                'score_calibrated': 0.5,
            }
        input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
        # Run the whole computation, including the linear head, without autograd:
        # the head's nn.Parameters require grad, so computing it outside no_grad
        # would record a throw-away graph on every request.
        with torch.no_grad():
            output = self.model(input_ids)
            last_hidden_state = output.last_hidden_state.to(device)  # (B=1, L, D)
            # Encoder state at the final position (presumably the tokenizer's EOS token — confirm).
            hidden = last_hidden_state[0, -1, :]  # (D)
            logit = self.linear(hidden).squeeze(-1)  # ()
            logit_calibrated = logit / self.t
            score = logit.sigmoid()
            score_calibrated = logit_calibrated.sigmoid()
        return {
            'logit': logit.item(),
            'logit_calibrated': logit_calibrated.item(),
            'score': score.item(),
            'score_calibrated': score_calibrated.item(),
        }
# Module-level singleton: the tokenizer (and, outside debug mode, the model) is
# loaded once at import time and reused by predict() for every request.
interactive = Interactive()
def predict(statement, do_save=True):
    """Score *statement* with Vera and optionally log the result to the dataset repo.

    Args:
        statement: the commonsense statement entered by the user.
        do_save: when truthy, append the raw result as one JSON line to
            DATA_PATH and push the dataset repo to the Hub.

    Returns:
        A 6-tuple matching the Gradio outputs wired to this handler:
        the label dict {'True': p, 'False': 1 - p} with p the calibrated score,
        the raw result dict (timestamp, statement, logits and scores), then
        updates that hide Submit, show both feedback buttons and set the
        feedback prompt text.
    """
    result = interactive.run(statement)
    p_true = result['score_calibrated']
    output = {
        'True': p_true,
        'False': 1 - p_true,
    }
    output_raw = {
        'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
        'statement': statement,
    }
    output_raw.update(result)
    if do_save:
        # ensure_ascii=False writes non-ASCII characters verbatim, so the log must
        # be opened as UTF-8 rather than in the platform's default encoding.
        # NOTE(review): the queue runs up to 16 handlers concurrently; appends and
        # pushes to the shared git checkout are not serialized here.
        with open(DATA_PATH, 'a', encoding='utf-8') as f:
            json.dump(output_raw, f, ensure_ascii=False)
            f.write('\n')
        commit_url = repo.push_to_hub()
        print('Logged statement to dataset:')
        print('Commit URL:', commit_url)
        print(output_raw)
        print()
    return (
        output,
        output_raw,
        gr.update(visible=False),  # submit
        gr.update(visible=True),  # feedback_agree
        gr.update(visible=True),  # feedback_disagree
        gr.update(value='Please provide your feedback before trying out another statement.'),  # feedback_ack
    )
def record_feedback(output_raw, feedback, do_save=True):
    """Optionally log user feedback on a prediction, then reset the UI for a new statement.

    Args:
        output_raw: the raw result dict produced by predict().
        feedback: the feedback label to attach (the wrappers pass 'agree' or 'disagree').
        do_save: when truthy, append output_raw plus a 'feedback' key as one
            JSON line to DATA_PATH and push the dataset repo to the Hub.

    Returns:
        A 4-tuple of Gradio updates: show Submit, hide both feedback buttons,
        and set the acknowledgment text.
    """
    if do_save:
        # Build a new record instead of mutating the caller's dict in place.
        record = {**output_raw, 'feedback': feedback}
        # ensure_ascii=False writes non-ASCII characters verbatim, so the log must
        # be opened as UTF-8 rather than in the platform's default encoding.
        with open(DATA_PATH, 'a', encoding='utf-8') as f:
            json.dump(record, f, ensure_ascii=False)
            f.write('\n')
        commit_url = repo.push_to_hub()
        print('Logged feedback to dataset:')
        print('Commit URL:', commit_url)
        print(record)
        print()
    return (
        gr.update(visible=True),  # submit
        gr.update(visible=False),  # feedback_agree
        gr.update(visible=False),  # feedback_disagree
        gr.update(value='Thanks for your feedback! Now you can enter another statement.'),  # feedback_ack
    )
def record_feedback_agree(output_raw, do_save=True):
    """Handle the Agree button: record 'agree' feedback for *output_raw*."""
    return record_feedback(output_raw, feedback='agree', do_save=do_save)
def record_feedback_disagree(output_raw, do_save=True):
    """Handle the Disagree button: record 'disagree' feedback for *output_raw*."""
    return record_feedback(output_raw, feedback='disagree', do_save=do_save)
# Example statements listed under the input box (passed to gr.Examples).
# Each group is labelled with the dataset it is taken from; most groups are
# a pair of closely related statements that differ in one key detail.
examples = [
    # openbookqa
    'If a person walks in the opposite direction of a compass arrow they are walking south.',
    'If a person walks in the opposite direction of a compass arrow they are walking north.',
    # arc_easy
    'A pond is different from a lake because ponds are smaller and shallower.',
    'A pond is different from a lake because ponds have moving water.',
    # arc_hard
    'Hunting strategies are more likely to be learned rather than inherited.',
    'A spotted coat is more likely to be learned rather than inherited.',
    # ai2_science_elementary
    'Photosynthesis uses carbon from the air to make food for plants.',
    'Respiration uses carbon from the air to make food for plants.',
    # ai2_science_middle
    'The barometer measures atmospheric pressure.',
    'The thermometer measures atmospheric pressure.',
    # commonsenseqa
    'People aim to complete a job at work.',
    'People aim to kill animals at work.',
    # qasc
    'Climate is generally described in terms of local weather conditions.',
    'Climate is generally described in terms of forests.',
    # physical_iqa
    'ice box will turn into a cooler if you add water to it.',
    'ice box will turn into a cooler if you add soda to it.',
    # social_iqa
    'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very aggressive and talkative person.',
    'Kendall opened their mouth to speak and what came out shocked everyone. Kendall is a very quiet person.',
    # winogrande_xl
    'Sarah was a much better surgeon than Maria so Maria always got the easier cases.',
    'Sarah was a much better surgeon than Maria so Sarah always got the easier cases.',
    # com2sense_paired
    'If you want a quick snack, getting one banana would be a good choice generally.',
    'If you want a snack, getting twenty bananas would be a good choice generally.',
    # sciq
    'Each specific polypeptide has a unique linear sequence of amino acids.',
    'Each specific polypeptide has a unique linear sequence of fatty acids.',
    # quarel
    'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because wet floor has more resistance.',
    'Tommy glided across the marble floor with ease, but slipped and fell on the wet floor because marble floor has more resistance.',
    # quartz
    'If less waters falls on an area of land it will cause less plants to grow in that area.',
    'If less waters falls on an area of land it will cause more plants to grow in that area.',
    # cycic_mc
    'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the park on January 20.',
    'In U.S. spring, Rob visits the financial district every day. In U.S. winter, Rob visits the park every day. Rob will go to the financial district on January 20.',
    # comve_a
    'Summer in North America is great for swimming, boating, and fishing.',
    'Summer in North America is great for skiing, snowshoeing, and making a snowman.',
    # csqa2
    'Gas is always capable of turning into liquid under high pressure.',
    'Cotton candy is sometimes made out of cotton.',
    # symkd_anno
    'James visits a famous landmark. As a result, James learns about the world.',
    'Cliff and Andrew enter the castle. But before, Cliff needed to have been a student at the school.',
    # gengen_anno
    'Generally, bar patrons are capable of taking care of their own drinks.',
    'Generally, ocean currents have little influence over storm intensity.',
    # Disabled examples (kept for reference):
    # 'If A sits next to B and B sits next to C, then A must sit next to C.',
    # 'If A sits next to B and B sits next to C, then A might not sit next to C.',
]
# input_statement = gr.Dropdown(choices=examples, label='Statement:')
# input_model = gr.Textbox(label='Commonsense statement verification model:', value=MODEL_NAME, interactive=False)
# output = gr.outputs.Label(num_top_classes=2)
# description = '''This is a demo for Vera, a commonsense statement verification model. Under development.
# ⚠️ Data Collection: by default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app!'''
# gr.Interface(
# fn=predict,
# inputs=[input_statement, input_model],
# outputs=output,
# title="Vera",
# description=description,
# ).launch()
# Gradio UI: a statement box with Submit, a label showing the True/False
# probabilities, and Agree/Disagree feedback buttons. predict() and
# record_feedback() toggle which of Submit / feedback buttons are visible.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            '# Vera\n'
            'This is a demo for Vera, a commonsense statement verification model. Under development.\n'
            '⚠️ Data Collection: by default, we are collecting the inputs entered in this app to further improve and evaluate the model. Do not share any personal or sensitive information while using the app!\n'
        )
        with gr.Row():
            with gr.Column(scale=2):
                do_save = gr.Checkbox(
                    value=True,
                    label="Store data",
                    info="You agree to the storage of your input data for research and development purposes:")
                statement = gr.Textbox(placeholder='Enter a commonsense statement here', label='Statement', interactive=True)
                submit = gr.Button(value='Submit', variant='primary', visible=True)
            with gr.Column(scale=1):
                output = gr.Label(num_top_classes=2, interactive=False)
                # Hidden holder that carries predict()'s raw result to the feedback handlers.
                output_raw = gr.JSON(visible=False)
                with gr.Row():
                    # Labels restored from mojibake ('πŸ‘' / 'πŸ‘Ž') to the intended emoji.
                    feedback_agree = gr.Button(value='👍 Agree', variant='secondary', visible=False)
                    feedback_disagree = gr.Button(value='👎 Disagree', variant='secondary', visible=False)
                feedback_ack = gr.Markdown(value='', visible=True, interactive=False)
        with gr.Row():
            # With run_on_click=False and no caching, clicking an example only fills
            # the Statement box. `outputs` mirrors submit.click below so it matches
            # the 6 values predict() returns, should running on click be enabled.
            gr.Examples(
                examples=examples,
                fn=predict,
                inputs=[statement],
                outputs=[output, output_raw, submit, feedback_agree, feedback_disagree, feedback_ack],
                examples_per_page=100,
                cache_examples=False,
                run_on_click=False,  # If we want this to be True, I suspect we need to enable the statement.submit()
            )
        submit.click(predict, inputs=[statement, do_save], outputs=[output, output_raw, submit, feedback_agree, feedback_disagree, feedback_ack])
        # statement.submit(predict, inputs=[statement], outputs=[output, output_raw])
        feedback_agree.click(record_feedback_agree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack])
        feedback_disagree.click(record_feedback_disagree, inputs=[output_raw, do_save], outputs=[submit, feedback_agree, feedback_disagree, feedback_ack])

# Serve through the request queue, processing up to 16 events concurrently.
demo.queue(concurrency_count=16).launch(debug=True)