File size: 6,428 Bytes
40e4418
7bed91f
2db50b5
d34229c
 
2db50b5
46247a1
40e4418
 
7bed91f
 
40e4418
2db50b5
7bed91f
e43f483
7bed91f
 
 
46247a1
701a70b
2db50b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46247a1
 
2db50b5
 
 
7bed91f
 
 
 
 
 
 
 
9b95f10
8d6b3c3
2db50b5
40e4418
 
 
 
 
891f199
117c8b1
40e4418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937de7f
 
40e4418
 
 
 
 
 
 
 
937de7f
40e4418
 
 
 
 
937de7f
 
40e4418
 
 
 
117c8b1
2db50b5
40e4418
 
 
2db50b5
40e4418
 
 
 
 
caad63d
2db50b5
caad63d
 
2db50b5
614af9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f16285
 
2db50b5
 
 
 
 
 
 
 
9b95f10
2db50b5
46247a1
2db50b5
9b95f10
2db50b5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import gradio as gr
import os
import csv
import json
import uuid
import random
import pickle
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from googleapiclient.discovery import build
from google.oauth2 import service_account

# Anonymous per-session identifier recorded with every vote.
USER_ID = uuid.uuid4()
# Google Sheets service-account credentials are injected via the environment.
SERVICE_ACCOUNT_JSON = os.environ.get('GOOGLE_SHEETS_CREDENTIALS')
creds = service_account.Credentials.from_service_account_info(json.loads(SERVICE_ACCOUNT_JSON))
# Spreadsheet that collects the arena votes (see send_result).
SPREADSHEET_ID = '1o0iKPxWYKYKEPjqB2YwrTgrLzvGyb9ULj9tnw_cfJb0'
service = build('sheets', 'v4', credentials=creds)

# Pool of abstracts shown to users, loaded once at startup.
# NOTE(review): pickle.load assumes article_list.pkl is trusted local data.
with open("article_list.pkl","rb") as articles:
    article_list = tuple(pickle.load(articles))
# Local FAISS index directories, kept in the same order as MODELS below.
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
# One embedder per model name; embeddings are deliberately left unnormalised.
faiss_embedders = [HuggingFaceEmbeddings(
    model_name=name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs) for name in MODELS]

# vecdbs[i] is the FAISS store for INDEXES[i], built with faiss_embedders[i].
vecdbs = [FAISS.load_local(index_name, faiss_embedder)
          for index_name, faiss_embedder in zip(INDEXES, faiss_embedders)]

def get_matchup():
    """Draw two distinct index names at random for an A/B matchup."""
    first, second = random.sample(INDEXES, 2)
    return first, second

def get_comp(prompt):
    """Run *prompt* through a random pair of models and return both results."""
    model_a, model_b = get_matchup()
    output_a = inference(prompt, model_a)
    output_b = inference(prompt, model_b)
    return output_a, output_b

def get_article():
    """Return one abstract picked uniformly from the preloaded article pool."""
    pool = article_list
    return random.choice(pool)


def send_result(l_output, r_output, prompt, pick):
    """Append one arena vote to the results Google Sheet and rotate the prompt.

    Args:
        l_output: table shown for Model A.
        r_output: table shown for Model B.
        prompt: the abstract the two models were compared on.
        pick: 'left' or 'right' — which side the user preferred.

    Returns:
        (new_prompt, state_update) so the UI can show a fresh article and
        keep the prompt State in sync; wire both Textbox and State as outputs.
    """
    # One row per vote: session id, both outputs, the prompt, and the pick.
    row = [str(x) for x in (USER_ID, l_output, r_output, prompt, pick)]
    body = {'values': [row]}
    result = service.spreadsheets().values().append(
        spreadsheetId=SPREADSHEET_ID,
        range='A1:E1',
        majorDimension="ROWS",
        valueInputOption='RAW',
        body=body,
    ).execute()
    print(f"Appended {result['updates']['updatedCells']} cells.")
    new_prompt = get_article()
    return new_prompt, gr.State.update(value=new_prompt)


def get_matches(query, db_name="miread_contrastive"):
    """
    Wrapper to call the similarity search on the required index
    """
    store = vecdbs[INDEXES.index(db_name)]
    return store.similarity_search_with_score(query, k=30)


def inference(query, model="miread_contrastive"):
    """
    Process the matches retrieved by get_matches() into a ranked table.

    Scores are min-max normalised so the closest match scores 1.0 and the
    farthest 0.0, and at most two entries per first author are kept.

    Returns - Gradio update command for the tabular output (top 10 rows).
    """
    matches = get_matches(query, model)
    if not matches:
        # No results for this query — show an empty table instead of crashing.
        return gr.Dataframe.update(value=[], visible=True)

    scores = [round(m[1].item(), 3) for m in matches]
    min_score = min(scores)
    max_score = max(scores)
    span = max_score - min_score

    def normaliser(x):
        # Min-max normalise onto [0, 1]. The previous version divided by
        # max_score instead of the range, so scores never reached 0.
        # A zero span (all matches equally distant) maps everything to 1.0.
        return round(1 - (x - min_score) / span, 3) if span else 1.0

    auth_counts = {}
    n_table = []
    rank = 1
    for doc, raw_score in matches:
        author = doc.metadata['authors'][0].title()
        if auth_counts.get(author, 0) >= 2:
            continue  # cap each author at two entries
        n_table.append([rank,
                        normaliser(round(raw_score.item(), 3)),
                        author,
                        doc.metadata['title'],
                        doc.metadata.get('link', 'None'),
                        doc.metadata.get('date', 'None')])
        rank += 1
        auth_counts[author] = auth_counts.get(author, 0) + 1
    return gr.Dataframe.update(value=n_table[:10], visible=True)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    gr.Markdown("NBDT Recommendation Engine for Editors is a tool for neuroscience authors/abstracts/journals recommendation built for NBDT journal editors. \
    It aims to help an editor to find similar reviewers, abstracts, and journals to a given submitted abstract.\
    To find a recommendation, paste a `title[SEP]abstract` or `abstract` in the text box below and click on the appropriate \"Find Matches\" button.\
    Then, you can hover to authors/abstracts/journals tab to find a suggested list.\
    The data in our current demo includes authors associated with the NBDT Journal. We will update the data monthly for an up-to-date publications.")
    # Seed the arena with a random article and a random model matchup.
    article = get_article()
    models = gr.State(value=get_matchup())
    prompt = gr.State(value=article)
    abst = gr.Textbox(value=article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")
    with gr.Group():
        with gr.Row().style(equal_height=True):
            with gr.Column(scale=1):
                l_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model A',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1,
                )
            with gr.Column(scale=1):
                r_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model B',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1,
                )
    with gr.Row().style(equal_height=True):
        l_btn = gr.Button(value="Model A is better", scale=1)
        r_btn = gr.Button(value="Model B is better", scale=1)

    action_btn.click(fn=get_comp,
                     inputs=[prompt,],
                     outputs=[l_output, r_output],
                     api_name="arena")
    # send_result returns (new_abstract, new_prompt_state), so BOTH buttons
    # must declare two outputs; the left button previously listed only [abst,],
    # which mismatched the return arity and never refreshed the prompt State.
    l_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'left'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedleft")
    r_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'right'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedright")

demo.launch(debug=True)