File size: 1,124 Bytes
38d8e22
 
eb56186
e69e8e4
 
 
eb56186
 
38d8e22
eb56186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e69e8e4
 
 
 
 
eb56186
 
 
 
 
 
51ec97a
eb56186
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-

"""
@Author     : Jiangjie Chen
@Time       : 2021/12/13 17:17
@Contact    : [email protected]
@Description:
"""

import os
import gradio as gr
from huggingface_hub import snapshot_download
from prettytable import PrettyTable
import pandas as pd
import torch
import traceback

config = {
    "model_type": "roberta",
    "model_name_or_path": "roberta-large",
    "logic_lambda": 0.5,
    "prior": "random",
    "mask_rate": 0.0,
    "cand_k": 1,
    "max_seq1_length": 256,
    "max_seq2_length": 128,
    "max_num_questions": 8,
    "do_lower_case": False,
    "seed": 42,
    "n_gpu": torch.cuda.device_count(),
}

os.system('git clone https://github.com/jiangjiechen/LOREN/')
os.system('rm -r LOREN/data/')
os.system('rm -r LOREN/results/')
os.system('rm -r LOREN/models/')
os.system('mv LOREN/* ./')

model_dir = snapshot_download('Jiangjie/loren')
config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')

from src.loren import Loren