File size: 4,659 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
081073f
7f7285f
 
 
a9a78de
7f7285f
 
 
 
 
 
 
8c0f1a8
773328a
 
0d63b55
a9a78de
 
 
7f7285f
 
773328a
 
 
 
 
7f7285f
773328a
7f7285f
 
 
 
773328a
 
 
 
1717bf9
7f7285f
 
 
 
 
 
1717bf9
f62e242
1717bf9
 
 
f62e242
 
 
 
7f7285f
1717bf9
 
 
 
 
 
 
7f7285f
f62e242
7f7285f
f62e242
 
 
 
 
 
 
 
 
7f7285f
1717bf9
 
f62e242
 
 
7f7285f
 
 
 
8c0f1a8
 
7f7285f
 
 
8c0f1a8
7f7285f
1717bf9
 
7f7285f
 
 
 
fc4b043
 
02cd36f
 
 
9d6c6a0
02cd36f
 
7f7285f
 
02cd36f
081073f
 
7f7285f
 
 
 
1717bf9
7f7285f
 
 
 
66eb2db
 
 
7f7285f
9d6c6a0
 
35e858b
02cd36f
d3bdb9d
0d63b55
7f7285f
d3bdb9d
 
7f7285f
 
081073f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# -*- coding: utf-8 -*-

"""
@Author     : Jiangjie Chen
@Time       : 2021/12/13 17:17
@Contact    : [email protected]
@Description: Gradio web demo for LOREN, an interpretable fact verification model (AAAI 2022).
"""

import os
import gradio as gr
from huggingface_hub import snapshot_download
from prettytable import PrettyTable
import pandas as pd
import torch

# Settings handed to ``Loren(config, ...)`` below. The model directory keys
# (``fc_dir``, ``mrc_dir``, ``er_dir``) are filled in after the model snapshot
# has been downloaded.
config = dict(
    model_type="roberta",
    model_name_or_path="roberta-large",
    logic_lambda=0.5,
    prior="random",
    mask_rate=0.0,
    cand_k=1,
    max_seq1_length=256,
    max_seq2_length=128,
    max_num_questions=8,
    do_lower_case=False,
    seed=42,
    # Number of CUDA devices torch can see (0 on a CPU-only host).
    n_gpu=torch.cuda.device_count(),
)

# Fetch the LOREN source tree, drop its bundled data/results/models folders and
# move the remaining (non-hidden, per shell glob) entries into the working
# directory, presumably so that ``src/`` is importable below. Exit statuses of
# these shell commands are ignored, so a failing step (e.g. a folder that is
# already gone) does not abort startup.
os.system('git clone https://github.com/jiangjiechen/LOREN/')
os.system('rm -r LOREN/data/')
os.system('rm -r LOREN/results/')
os.system('rm -r LOREN/models/')
os.system('mv LOREN/* ./')

# Download the pretrained checkpoints from the Hugging Face Hub and point the
# config at the sub-model directories inside the snapshot.
model_dir = snapshot_download('Jiangjie/loren')
config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')


# This import must come after the clone/move above: ``src`` only exists once the
# LOREN repository contents have been moved into the working directory.
from src.loren import Loren


loren = Loren(config, verbose=False)
# Warm-up / smoke test: run one claim end-to-end at startup so a broken setup
# fails immediately instead of on the first user request.
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Chain explicitly so the traceback shows the original failure as the cause.
    raise ValueError(e) from e


def highlight_phrase(text, phrase):
    """Detokenize ``text`` and show ``phrase`` in bold italics at each ``<mask>``.

    Args:
        text: A cloze question containing the literal ``<mask>`` placeholder.
        phrase: The string to display in place of the placeholder.

    Returns:
        ``text`` after the fact-checking tokenizer's ``clean_up_tokenization``,
        with every ``<mask>`` replaced by ``<i><b>phrase</b></i>``.
    """
    tokenizer = loren.fc_client.tokenizer
    cleaned = tokenizer.clean_up_tokenization(text)
    emphasized = '<i><b>{}</b></i>'.format(phrase)
    return cleaned.replace('<mask>', emphasized)


def highlight_entity(text, entity):
    """Wrap every occurrence of ``entity`` in ``text`` with ``<i><b>`` tags.

    Args:
        text: The text to decorate.
        entity: The substring to emphasize.

    Returns:
        ``text`` with each occurrence of ``entity`` replaced by
        ``<i><b>entity</b></i>``. If ``entity`` is empty or ``None``, ``text``
        is returned unchanged. Without that check, ``str.replace`` with an empty
        pattern would insert the tags between every character, and ``None``
        would raise ``TypeError``.
    """
    if not entity:
        return text
    return text.replace(entity, f'<i><b>{entity}</b></i>')


def gradio_formatter(js, output_type):
    """Render one part of a LOREN result as an HTML table for Gradio.

    Args:
        js: Result dict returned by ``loren.check``. Keys read here are
            ``'evidence'`` and ``'entities'`` (for ``'e'``), and
            ``'phrase_veracity'``, ``'claim_phrases'``, ``'cloze_qs'`` and
            ``'evidential'`` (for ``'z'``).
        output_type: ``'e'`` for the evidence table, ``'z'`` for the
            per-claim-phrase veracity table.

    Returns:
        An HTML string: a ``<head>`` with zebra-striping CSS, followed by the
        PrettyTable-generated table.

    Raises:
        NotImplementedError: If ``output_type`` is neither ``'e'`` nor ``'z'``.
    """
    if output_type == 'e':
        # One row per evidence sentence, with its paired entity emphasized.
        data = {'Evidence': [highlight_entity(x, e) for x, e in zip(js['evidence'], js['entities'])]}
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for x in js['phrase_veracity']:
            # Bold the most probable class; index mapping (see appends below):
            # 0 -> REF, 1 -> NEI, 2 -> SUP.
            max_idx = torch.argmax(torch.tensor(x)).tolist()
            x = ['%.4f' % xx for xx in x]
            x[max_idx] = f'<i><b>{x[max_idx]}</b></i>'
            p_sup.append(x[2])
            p_ref.append(x[0])
            p_nei.append(x[1])

        data = {
            'Claim Phrase': js['claim_phrases'],
            # Fill each cloze question with the first entry of its 'evidential'
            # list (presumably the top-ranked answer -- TODO confirm).
            'Local Premise': [highlight_phrase(q, x[0]) for q, x in zip(js['cloze_qs'], js['evidential'])],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError(
            f"Unsupported output_type: {output_type!r} (expected 'e' or 'z')")

    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead{
        background: #f1f1f1;
    }'''
    # The DataFrame step also rejects columns of unequal length (pandas raises
    # ValueError).
    data = pd.DataFrame(data)
    # hrules/vrules=1 is prettytable.ALL: rules between every row and column.
    pt = PrettyTable(field_names=list(data.columns),
        align='l', border=True, hrules=1, vrules=1)
    for v in data.values:
        pt.add_row(v)
    html = pt.get_html_string(attributes={
        # 'border-color' is the valid CSS property; the former 'bordercolor'
        # was silently ignored by browsers.
        'style': 'border-width: 2px; border-color: black'
    }, format=True)
    html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
    # Undo PrettyTable's escaping of '<' and '>' so the <i><b> highlight tags
    # render as markup.
    # NOTE(review): this also un-escapes any markup inside the claim/evidence
    # text itself, which is user- or web-derived. Consider html-escaping the raw
    # text before highlighting instead of un-escaping the whole table.
    html = html.replace('&lt;', '<').replace('&gt;', '>')
    return html


def run(claim):
    """Gradio callback: fact-check ``claim`` and build the three UI outputs.

    Args:
        claim: Free-text claim entered by the user.

    Returns:
        A tuple ``(label, phrase_table_html, evidence_table_html)`` in the same
        order as the interface's ``outputs``. If ``loren.check`` raises, the
        error and the claim are logged, and a fallback message plus two empty
        strings are returned instead.
    """
    try:
        result = loren.check(claim)
    except Exception as err:
        for message in (str(err), claim):
            loren.logger.error(message)
        return 'Oops, something went wrong.', '', ''

    verdict = result['claim_veracity']
    loren.logger.warning(verdict + str(result))
    # Evidence table is built first, then the phrase table, as in the original.
    evidence_table = gradio_formatter(result, 'e')
    phrase_table = gradio_formatter(result, 'z')
    return verdict, phrase_table, evidence_table


# Example claims shown under the input box.
DEMO_EXAMPLES = [
    'Donald Trump won the U.S. 2020 presidential election.',
    'The first inauguration of Bill Clinton was in the United States.',
    'The Cry of the Owl is based on a book by an American.',
    'Smriti Mandhana is an Indian woman.',
]

# Reasons a user can attach when flagging a result.
FLAG_OPTIONS = [
    'Interesting!',
    'Error: Claim Phrase Parsing',
    'Error: Local Premise',
    'Error: Require Commonsense',
    'Error: Evidence Retrieval',
]

DEMO_DESCRIPTION = (
    "LOREN is an interpretable Fact Verification model using Wikipedia as its knowledge source. "
    "This is a demo system for the AAAI 2022 paper: \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\"(https://arxiv.org/abs/2012.13577). "
    "See the paper for more details. You can add a *FLAG* on the bottom to record interesting or bad cases! "
    "(Note that the demo system directly retrieves evidence from an up-to-date Wikipedia, which is different from the evidence used in the paper.)"
)

# The output components line up with run()'s return tuple:
# (label, phrase-veracity table HTML, evidence table HTML).
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=['label', 'html', 'html'],
    examples=DEMO_EXAMPLES,
    title="LOREN",
    layout='horizontal',
    description=DEMO_DESCRIPTION,
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=FLAG_OPTIONS,
    enable_queue=True,
)
iface.launch()