File size: 3,291 Bytes
7f7285f
 
 
 
 
 
 
 
 
 
081073f
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
081073f
 
7f7285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
081073f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# -*- coding: utf-8 -*-

"""
@Author     : Jiangjie Chen
@Time       : 2021/12/13 17:17
@Contact    : [email protected]
@Description:
"""

import os
import gradio as gr
from src.loren import Loren
from huggingface_hub import snapshot_download
from prettytable import PrettyTable
import pandas as pd

# Inference configuration handed to src.loren.Loren. The meaning of each key
# is defined inside Loren (not visible here).
config = {
    "input": "demo",
    "model_type": "roberta",
    "model_name_or_path": "roberta-large",
    "logic_lambda": 0.5,
    "prior": "random",
    "mask_rate": 0.0,
    "cand_k": 3,
    "max_seq2_length": 256,
    "max_seq1_length": 128,
    "max_num_questions": 8
}

# snapshot_download returns a local directory holding the 'Jiangjie/loren'
# repository snapshot; each sub-model directory is resolved relative to it.
model_dir = snapshot_download('Jiangjie/loren')

# Directory names suggest: fact-checking model, MRC seq2seq model and
# evidence-retrieval resources (TODO confirm against src.loren). The trailing
# slashes are kept as-is in case Loren concatenates onto these paths.
config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')

loren = Loren(config)

# Run one claim through the pipeline at startup so any failure surfaces here,
# before the Gradio interface is launched. Illustrative shape of the result
# (placeholder values):
#   {'id': 0,
#    'evidence': ['EVIDENCE1', 'EVIDENCE2'],
#    'question': ['QUESTION1', 'QUESTION2'],
#    'claim_phrase': ['CLAIMPHRASE1', 'CLAIMPHRASE2'],
#    'local_premise': [['E1', 'E1', 'E1'], ['E2', 'E2', 'E2']],
#    'phrase_veracity': [[0.1, 0.5, 0.4], [0.1, 0.7, 0.2]],
#    'claim_veracity': 'SUPPORT'}
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Chain the cause so the original traceback is kept on __cause__ instead
    # of being replaced by the ValueError's own (less useful) stack.
    raise ValueError(e) from e


def gradio_formatter(js, output_type):
    """Render one view of a LOREN result dict as an HTML table string.

    Args:
        js: result mapping from ``loren.check`` (assumed to be a dict).
        output_type: ``'e'`` renders the ``evidence`` list as a single
            column; ``'z'`` renders one row per claim phrase with its first
            local premise and rounded veracity probabilities.

    Returns:
        The HTML string produced by ``PrettyTable.get_html_string``.

    Raises:
        NotImplementedError: if ``output_type`` is neither ``'e'`` nor ``'z'``.
        KeyError: if ``js`` lacks a key the requested view reads.
    """
    if output_type not in ('e', 'z'):
        raise NotImplementedError

    if output_type == 'e':
        columns = {'Evidence': js['evidence']}
    else:
        columns = {
            'Claim Phrase': js['claim_phrase'],
            # Only the first premise listed for each phrase is displayed.
            'Local Premise': [premises[0] for premises in js['local_premise']],
            # Each phrase_veracity row is read as [p_REF, p_NEI, p_SUP].
            'p_SUP': [round(probs[2], 4) for probs in js['phrase_veracity']],
            'p_REF': [round(probs[0], 4) for probs in js['phrase_veracity']],
            'p_NEI': [round(probs[1], 4) for probs in js['phrase_veracity']],
        }

    # Round-tripping through a DataFrame also rejects columns of unequal length.
    frame = pd.DataFrame(columns)
    table = PrettyTable(field_names=list(frame.columns))
    for row in frame.values:
        table.add_row(row)

    table_attrs = {'style': 'border-width: 1px; border-collapse: collapse'}
    return table.get_html_string(attributes=table_attrs, format=True)


def run(claim):
    """Fact-check ``claim`` and produce the four values shown by the UI.

    Returns:
        A 4-tuple: evidence HTML table, phrase-level HTML table, the result's
        ``claim_veracity`` value, and the full result dict. The tuple is in
        the same order as the interface's ['html', 'html', 'label', 'json']
        outputs.
    """
    result = loren.check(claim)
    evidence_html, phrase_html = [gradio_formatter(result, view) for view in ('e', 'z')]
    return evidence_html, phrase_html, result['claim_veracity'], result


# Web UI wiring: one text box in, and one output component per element of
# the tuple returned by run().
# NOTE(review): layout, allow_flagging=True and enable_queue look like
# arguments from older Gradio releases; confirm against the pinned gradio
# version before upgrading.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=['html', 'html', 'label', 'json'],
    title="LOREN",
    description=(
        "LOREN is an interpretable Fact Verification model against Wikipedia. "
        'This is a demo system for "LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification". '
        "See the paper for technical details. You can add FLAG on the bottom to record interesting or bad cases!"
    ),
    layout='vertical',
    examples=[
        'Donald Trump won the U.S. 2020 presidential election.',
        'The first inauguration of Bill Clinton was in the United States.',
    ],
    # Flagged samples are stored under this directory, tagged with one of
    # the listed labels.
    allow_flagging=True,
    flagging_dir='results/flagged/',
    flagging_options=[
        'Good Case!',
        'Error: MRC',
        'Error: Parsing',
        'Error: Commonsense',
        'Error: Evidence',
        'Error: Other',
    ],
    enable_queue=True,
)
iface.launch()