# -*- coding: utf-8 -*-

"""
@Author     : Jiangjie Chen
@Time       : 2021/12/13 17:17
@Contact    : [email protected]
@Description: Gradio demo app for LOREN fact verification against Wikipedia.
"""

import os
import gradio as gr
from huggingface_hub import snapshot_download
from prettytable import ALL, PrettyTable
import pandas as pd
import torch

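# Configuration for the LOREN pipeline: backbone model, decoding limits, and runtime settings.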
config = {
    "model_type": "roberta",
    "model_name_or_path": "roberta-large",
    "logic_lambda": 0.5,
    "prior": "random",
    "mask_rate": 0.0,
    "cand_k": 1,
    "max_seq1_length": 256,
    "max_seq2_length": 128,
    "max_num_questions": 8,
    "do_lower_case": False,
    "seed": 42,
    "n_gpu": torch.cuda.device_count(),
}

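# Fetch the LOREN source code, drop the bundled data/results/models folders
# (pretrained weights are pulled from the Hugging Face Hub below), and move the code here.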
os.system('git clone https://github.com/jiangjiechen/LOREN/')
os.system('rm -r LOREN/data/')
os.system('rm -r LOREN/results/')
os.system('rm -r LOREN/models/')
os.system('mv LOREN/* ./')

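# Download pretrained checkpoints from the Hugging Face Hub and point the config
# at the fact-checking, MRC, and evidence-retrieval modules.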
model_dir = snapshot_download('Jiangjie/loren')
config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')


from src.loren import Loren


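# Build the pipeline once at start-up and run a single claim through it as a sanity check.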
loren = Loren(config, verbose=False)
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    raise ValueError(f'LOREN sanity check failed: {e}') from e


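# HTML helpers: wrap the matched phrase/entity in <i><b>...</b></i> for display.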
def highlight_phrase(text, phrase):
    text = loren.fc_client.tokenizer.clean_up_tokenization(text)
    return text.replace('<mask>', f'<i><b>{phrase}</b></i>')


def highlight_entity(text, entity):
    return text.replace(entity, f'<i><b>{entity}</b></i>')


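# Render LOREN's output as an HTML table: 'e' for retrieved evidence,
# 'z' for per-phrase veracity probabilities.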
def gradio_formatter(js, output_type):
    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead{
        background: #f1f1f1;
    }'''
    if output_type == 'e':
        data = {'Evidence': [highlight_entity(x, e) for x, e in zip(js['evidence'], js['entities'])]}
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for x in js['phrase_veracity']:
            max_idx = torch.argmax(torch.tensor(x)).tolist()
            x = ['%.4f' % xx for xx in x]
            x[max_idx] = f'<i><b>{x[max_idx]}</b></i>'
            p_sup.append(x[2])
            p_ref.append(x[0])
            p_nei.append(x[1])

        data = {
            'Claim Phrase': js['claim_phrases'],
            'Local Premise': [highlight_phrase(q, x[0]) for q, x in zip(js['cloze_qs'], js['evidential'])],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError
    data = pd.DataFrame(data)
    pt = PrettyTable(field_names=list(data.columns),
                     align='l', border=True, hrules=ALL, vrules=ALL)
    for v in data.values:
        pt.add_row(v)
    html = pt.get_html_string(attributes={
        'style': 'border-width: 2px; border-color: black'
    }, format=True)
    html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
    html = html.replace('&lt;', '<').replace('&gt;', '>')
    return html


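# Gradio callback: verify a claim and return the overall verdict plus the
# phrase-level veracity and evidence HTML tables.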
def run(claim):
    # js = {
    #     'id': 0,
    #     'evidence': ['EVIDENCE1', 'EVIDENCE2'],
    #     'question': ['QUESTION1', 'QUESTION2'],
    #     'claim_phrases': ['CLAIMPHRASE1', 'CLAIMPHRASE2'],
    #     'local_premises': [['E1 ' * 100, 'E1 ' * 100, 'E1 ' * 10], ['E2', 'E2', 'E2']],
    #     'phrase_veracity': [[0.1, 0.5, 0.4], [0.1, 0.7, 0.2]],
    #     'claim_veracity': 'SUPPORT'
    # }
    js = loren.check(claim)
    loren.logger.warning(str(js))
    ev_html = gradio_formatter(js, 'e')
    z_html = gradio_formatter(js, 'z')
    return js['claim_veracity'], z_html, ev_html


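# Assemble the Gradio demo; flagged examples are written to results/flagged/.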
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=[
        'label',
        'html',
        'html',
    ],
    examples=['Donald Trump won the U.S. 2020 presidential election.',
              'The first inauguration of Bill Clinton was in the United States.',
              'The Cry of the Owl is based on a book by an American.',
              'Smriti Mandhana is an Indian woman.'],
    title="LOREN",
    layout='horizontal',
    description="LOREN is an interpretable Fact Verification model against Wikipedia. "
                "This is a demo system for \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\". "
                "See the paper for technical details. You can add a *FLAG* on the bottom to record interesting or bad cases! "
                "(Note that the demo system directly retrieves evidence from an up-to-date Wikipedia, which is different from the evidence used in the paper.)",
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=['Interesting!', 'Error: Claim Phrase Parsing', 'Error: Local Premise',
                      'Error: Require Commonsense', 'Error: Evidence Retrieval'],
    enable_queue=True
)
iface.launch()