File size: 4,023 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
081073f
7f7285f
 
 
a9a78de
7f7285f
 
 
 
 
 
 
 
773328a
 
0d63b55
a9a78de
 
 
7f7285f
 
773328a
 
 
 
 
7f7285f
773328a
 
 
 
7f7285f
 
 
 
773328a
 
 
 
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dc8f22
3e34cb7
7f7285f
 
 
 
 
 
 
 
 
 
 
 
a9a78de
 
 
7f7285f
 
 
 
 
 
 
 
 
081073f
 
7f7285f
 
 
 
 
 
 
 
 
 
66eb2db
 
 
7f7285f
3e34cb7
7f7285f
 
66eb2db
 
0d63b55
7f7285f
 
 
 
 
081073f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-

"""
@Author     : Jiangjie Chen
@Time       : 2021/12/13 17:17
@Contact    : [email protected]
@Description: Gradio demo app for LOREN, an interpretable fact verification model against Wikipedia.
"""

import os
import gradio as gr
from huggingface_hub import snapshot_download
from prettytable import PrettyTable
import pandas as pd
import torch

# Settings handed to ``Loren(config)`` further down in this script. The model
# directory entries (``fc_dir``, ``mrc_dir``, ``er_dir``) are added later, once
# the checkpoint snapshot has been downloaded.
config = dict(
    model_type="roberta",
    model_name_or_path="roberta-large",
    logic_lambda=0.5,
    prior="random",
    mask_rate=0.0,
    cand_k=3,
    max_seq1_length=256,
    max_seq2_length=128,
    max_num_questions=8,
    do_lower_case=False,
    seed=42,
    # Number of CUDA devices torch reports when the script starts.
    n_gpu=torch.cuda.device_count(),
)

# Start-up bootstrap, run through the shell in this exact order:
#   1. clone the LOREN repository into ./LOREN,
#   2. delete its data/, results/ and models/ directories,
#   3. move the remaining (non-hidden) top-level entries into the working
#      directory -- presumably so the later ``from src.loren import Loren``
#      resolves.
# Exit codes are not checked; each command runs regardless of earlier failures.
for _shell_cmd in (
    'git clone https://github.com/jiangjiechen/LOREN/',
    'rm -r LOREN/data/',
    'rm -r LOREN/results/',
    'rm -r LOREN/models/',
    'mv LOREN/* ./',
):
    os.system(_shell_cmd)

# os.makedirs('data/', exist_ok=True)
# os.system('wget -O data/fever.db https://s3-eu-west-1.amazonaws.com/fever.public/wiki_index/fever.db')

# Download (or reuse the cached copy of) the 'Jiangjie/loren' snapshot from the
# Hugging Face Hub and record where each sub-model lives inside it.
model_dir = snapshot_download('Jiangjie/loren')
config.update(
    fc_dir=os.path.join(model_dir, 'fact_checking/roberta-large/'),
    mrc_dir=os.path.join(model_dir, 'mrc_seq2seq/bart-base/'),
    er_dir=os.path.join(model_dir, 'evidence_retrieval/'),
)


from src.loren import Loren


# Build the pipeline once at start-up; ``run`` reuses this instance for every
# request.
loren = Loren(config)

# Warm-up check: verify one claim before the interface is created. Any failure
# here aborts start-up with a ValueError.
#
# A hand-written sample result kept from the original source for reference.
# NOTE(review): its singular key names ('claim_phrase', 'local_premise') differ
# from the plural 'claim_phrases' / 'local_premises' keys that
# ``gradio_formatter`` reads -- confirm which is current.
#     js = {
#         'id': 0,
#         'evidence': ['EVIDENCE1', 'EVIDENCE2'],
#         'question': ['QUESTION1', 'QUESTION2'],
#         'claim_phrase': ['CLAIMPHRASE1', 'CLAIMPHRASE2'],
#         'local_premise': [['E1 ' * 100, 'E1' * 100, 'E1' * 10], ['E2', 'E2', 'E2']],
#         'phrase_veracity': [[0.1, 0.5, 0.4], [0.1, 0.7, 0.2]],
#         'claim_veracity': 'SUPPORT'
#     }
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Same exception type and message as before, but chained explicitly so the
    # traceback reports the underlying error as the direct cause.
    raise ValueError(e) from e


def gradio_formatter(js, output_type):
    """Render one view of a LOREN result as an HTML table.

    Args:
        js: Result mapping as returned by ``loren.check``. For ``'e'`` only
            ``js['evidence']`` is read; for ``'z'`` the keys
            ``'claim_phrases'``, ``'local_premises'`` and
            ``'phrase_veracity'`` are read.
        output_type: ``'e'`` for a single-column evidence table, ``'z'`` for
            a per-phrase table with the columns Claim Phrase, Local Premise,
            p_SUP, p_REF and p_NEI.

    Returns:
        The HTML string produced by ``PrettyTable.get_html_string``.

    Raises:
        NotImplementedError: If ``output_type`` is neither ``'e'`` nor ``'z'``.
    """
    if output_type not in ('e', 'z'):
        raise NotImplementedError

    if output_type == 'e':
        columns = {'Evidence': js['evidence']}
    else:
        phrases = js['claim_phrases']
        premises = js['local_premises']
        veracity = js['phrase_veracity']

        def rounded(position):
            # One probability per phrase, taken from the given slot of each
            # veracity entry and rounded to four decimals.
            return [round(scores[position], 4) for scores in veracity]

        # Slot 2 is shown as SUP, slot 0 as REF and slot 1 as NEI.
        columns = {
            'Claim Phrase': phrases,
            'Local Premise': premises,
            'p_SUP': rounded(2),
            'p_REF': rounded(0),
            'p_NEI': rounded(1),
        }

    # Going through a DataFrame turns the column-wise dict into rows.
    frame = pd.DataFrame(columns)
    table = PrettyTable(field_names=list(frame.columns))
    for row in frame.values:
        table.add_row(row)

    table_attributes = {
        'style': 'border-width: 1px; border-collapse: collapse',
        'align': 'left',
        'border': '1'
    }
    return table.get_html_string(attributes=table_attributes, format=True)


def run(claim):
    """Check one claim and return the four values shown by the interface.

    Args:
        claim: The claim text to verify.

    Returns:
        A 4-tuple ``(evidence_html, phrase_html, claim_veracity, result)``:
        the two HTML tables built by ``gradio_formatter`` (``'e'`` then
        ``'z'``), the value of ``result['claim_veracity']``, and the full
        result returned by ``loren.check``.
    """
    result = loren.check(claim)
    evidence_html, phrase_html = (
        gradio_formatter(result, view) for view in ('e', 'z')
    )
    return evidence_html, phrase_html, result['claim_veracity'], result


# Claims passed to the interface as ``examples``.
_EXAMPLE_CLAIMS = [
    'Donald Trump won the U.S. 2020 presidential election.',
    'The first inauguration of Bill Clinton was in the United States.',
    'The Cry of the Owl is based on a book by an American.',
    'Smriti Mandhana is an Indian woman.',
]

# Labels passed to the interface as ``flagging_options``.
_FLAG_CHOICES = [
    'Good Case!',
    'Error: MRC',
    'Error: Parsing',
    'Error: Commonsense',
    'Error: Evidence',
    'Error: Other',
]

# Text passed to the interface as ``description``.
_DESCRIPTION = (
    "LOREN is an interpretable Fact Verification model against Wikipedia. "
    "This is a demo system for \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\". "
    "See the paper for technical details. You can add FLAG on the bottom to record interesting or bad cases! \n"
    "*Note that the demo system directly retrieves evidence from an up-to-date Wikipedia, which is different from the evidence used in the paper."
)

# One text input; the four outputs line up with the tuple returned by ``run``:
# evidence table, phrase table, claim veracity, full result.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=['html', 'html', 'label', 'json'],
    examples=_EXAMPLE_CLAIMS,
    title="LOREN",
    layout='horizontal',
    description=_DESCRIPTION,
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=_FLAG_CHOICES,
    enable_queue=True
)
iface.launch()