import re
import numpy as np
import pandas as pd
import torch
import gradio as gr
import plotly.graph_objects as go
from scipy.sparse import load_npz
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer, AutoModel  # used by the commented-out SciBERT path
from sentence_transformers import CrossEncoder, SentenceTransformer
from datasets import load_dataset
from gensim.models import KeyedVectors

class ArxivSearch:
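    """Keyword and embedding search over arXiv papers with a Gradio UI.

    Precomputed TF-IDF, Word2Vec, BERT, and SBERT document representations
    are loaded from local files; SBERT results are re-ranked with a
    cross-encoder, optionally restricted to the best-matching precomputed
    cluster. A 3D PCA plot visualises documents, results, and the query.
    """
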
    def __init__(self, dataset, embedding="sbert"):
        self.dataset = dataset
        self.embedding = embedding
        self.query = None
        self.documents = []
        self.titles = []
        self.raw_texts = []
        self.arxiv_ids = []
        self.last_results = []
        self.query_encoding = None

        # model selection
        self.embedding_dropdown = gr.Dropdown(
            choices=["tfidf", "word2vec", "bert", "sbert", "clustered sbert"],
            value="sbert",
            label="Model",
        )

        self.plot_button = gr.Button("Show 3D Plot")

        # Gradio blocks for UI elements
        with gr.Blocks() as self.iface:
            gr.Markdown("# arXiv Search Engine")
            gr.Markdown("Search arXiv papers by keyword and embedding model.")

            self.plot_output = gr.Plot()
            
            with gr.Row():
                self.query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
                self.embedding_dropdown.render()
                self.plot_button.render()
                with gr.Column():
                    self.search_button = gr.Button("Search")

            self.output_md = gr.Markdown()

            self.query_box.submit(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md
            )
            # self.embedding_dropdown.change(
            #     self.model_switch,
            #     inputs=[self.embedding_dropdown],
            #     outputs=self.output_md
            # )
            self.embedding_dropdown.change(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md
            )
            self.plot_button.click(
                self.plot_3d_embeddings,
                inputs=[],
                outputs=self.plot_output
            )
            self.search_button.click(
                self.search_function,
                inputs=[self.query_box, self.embedding_dropdown],
                outputs=self.output_md
            )

        self.load_data(dataset)
        # Preload every model so switching the dropdown is instant
        # ('clustered sbert' also loads the plain sbert artifacts).
        self.load_model('tfidf')
        self.load_model('word2vec')
        self.load_model('bert')
        # self.load_model('scibert')
        self.load_model('clustered sbert')
        # load_model() overwrites self.embedding, so restore the requested default;
        # the interface is launched from __main__ rather than here.
        self.embedding = embedding

    def load_data(self, dataset):
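        """Parse each training example into title, arXiv id, and raw text.

        The title is taken as every line before the first "arXiv:" line; the
        id is extracted from that line when it matches "arxiv:YYMM.NNNNNvN".
        Texts shorter than 10 characters are skipped.
        """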
        train_data = dataset["train"]
        for item in train_data:
            text = item["text"]
            if not text or len(text.strip()) < 10:
                continue

            lines = text.splitlines()
            title_lines = []
            found_arxiv = False
            arxiv_id = None

            for line in lines:
                line_strip = line.strip()
                if not found_arxiv and line_strip.lower().startswith("arxiv:"):
                    found_arxiv = True
                    match = re.search(r'arxiv:\d{4}\.\d{4,5}v\d', line_strip, flags=re.IGNORECASE)
                    if match:
                        arxiv_id = match.group(0).lower()
                elif not found_arxiv:
                    title_lines.append(line_strip)
                else:
                    if line_strip.lower().startswith("abstract"):
                        break

            title = " ".join(title_lines).strip()

            self.raw_texts.append(text.strip())
            self.titles.append(title)
            self.documents.append(text.strip())
            self.arxiv_ids.append(arxiv_id)

    def plot_dense(self, embedding, pca, results_indices):
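        """Fit PCA on up to the first 5000 dense embeddings plus the current
        result embeddings, then return the reduced background points, result
        points, and query point for plotting."""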
        all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
        all_data = embedding[all_indices]
        pca.fit(all_data)
        reduced_data = pca.transform(embedding[:5000])
        reduced_results_points = pca.transform(embedding[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
        query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))
        return reduced_data, reduced_results_points, query_point

    def plot_3d_embeddings(self):
        """Project the active embedding space to 3D with PCA and build the Plotly figure."""
        pca = PCA(n_components=3)
        results_indices = [i[0] for i in self.last_results]
        
        cluster_colors = None
        if self.embedding == "tfidf":
            all_indices = list(set(results_indices) | set(range(min(5000, self.tfidf_matrix.shape[0]))))
            all_data = self.tfidf_matrix[all_indices].toarray()
            pca.fit(all_data)
            reduced_data = pca.transform(self.tfidf_matrix[:5000].toarray())
            reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
        elif self.embedding == "word2vec":
            reduced_data, reduced_results_points, query_point = self.plot_dense(self.word2vec_embeddings, pca, results_indices)
        elif self.embedding == "bert":
            reduced_data, reduced_results_points, query_point = self.plot_dense(self.bert_embeddings, pca, results_indices)
        elif self.embedding in ("sbert", "clustered sbert"):
            reduced_data, reduced_results_points, query_point = self.plot_dense(self.sbert_embedding, pca, results_indices)
            if self.embedding == "clustered sbert" and hasattr(self, "top_cluster_index"):
                # Highlight the cluster the last query was routed to; the
                # membership set is computed once, not per document.
                top_cluster_docs = set(np.where(self.clusters == self.top_cluster_index)[0])
                cluster_colors = ["#00b7ff" if i in top_cluster_docs else "#ffffff" for i in range(len(self.documents))]
        # elif self.embedding == "scibert":
        #     reduced_data, reduced_results_points, query_point = self.plot_dense(self.scibert_embeddings, pca, results_indices)
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")
        
        results_scores = [i[1] for i in self.last_results]

        traces = []

        trace = go.Scatter3d(
            x=reduced_data[:, 0],
            y=reduced_data[:, 1],
            z=reduced_data[:, 2],
            mode='markers',
            marker=dict(size=3.5,
                        color=cluster_colors if cluster_colors is not None else "#ffffff",
                        opacity=0.2),
            name='All Documents',
            text=[self.arxiv_ids[i] if self.arxiv_ids[i] else " ".join(self.documents[i].split()[:10]) for i in range(len(self.documents))],
            hoverinfo='text'
        )

        traces.append(trace)

        layout = go.Layout(
            margin=dict(l=0, r=0, b=0, t=0),
            scene=dict(
                xaxis_title='PCA 1',
                yaxis_title='PCA 2',
                zaxis_title='PCA 3',
                xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
                zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
            ),
            paper_bgcolor='black',    # outside the plotting area
            plot_bgcolor='black',     # plotting area
            font=dict(color='white'), # axis and legend text
            legend=dict(
                bgcolor='rgba(0,0,0,0)',     # transparent legend background
                bordercolor='rgba(0,0,0,0)', # no border
                x=0.01,  # place legend inside the plot area
                y=0.99,
                xanchor='left',
                yanchor='top'
            )
        )

        if len(reduced_results_points) > 0:
            custom_colorscale = [
                [0.0, "#00ffea"],  # Start color (e.g., bright cyan)
                [1.0, "#ffea00"],  # End color (e.g., bright yellow)
            ]

            results_trace = go.Scatter3d(
                x=reduced_results_points[:, 0],
                y=reduced_results_points[:, 1],
                z=reduced_results_points[:, 2],
                mode='markers',
                marker=dict(size=4.25, 
                            color=results_scores, 
                            colorscale=custom_colorscale, 
                            opacity=0.99, 
                            colorbar=dict(
                                    title="Score",
                                    bgcolor='rgba(0,0,0,0)',  # <-- Transparent background for colorbar
                                    bordercolor='rgba(0,0,0,0)' # No border

                            )
                        ),
                name='Results',
                text=[f"<br>{self.documents[i][:100]}" for i in results_indices],
                hoverinfo='text'
            )

            traces.append(results_trace)

            if self.embedding != "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
                query_trace = go.Scatter3d(
                    x=query_point[:, 0],
                    y=query_point[:, 1],
                    z=query_point[:, 2],
                    mode='markers',
                    marker=dict(size=5, color='red', opacity=0.8),
                    name='Query',
                    text=[f"<br>Query: {self.query}"],
                    hoverinfo='text'
                )
                traces.append(query_trace)

        fig = go.Figure(data=traces, layout=layout)

        return fig
    
    def keyword_match_ranking(self, query, top_n=10):
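        """Score each document by summing the TF-IDF weights of the query
        terms found in the vocabulary; return the top_n (doc_index, score)
        pairs, or an empty list if no term is in the vocabulary."""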
        query_terms = query.lower().split()
        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
        if not query_indices:
            return []
        scores = []
        for doc_idx in range(self.tfidf_matrix.shape[0]):
            doc_vector = self.tfidf_matrix[doc_idx]
            doc_score = sum(doc_vector[0, i] for i in query_indices)
            if doc_score > 0:
                scores.append((doc_idx, doc_score))
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_n]
    
    def word2vec_search(self, query, top_n=10):
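        """Rank documents by cosine similarity between the mean Word2Vec
        vector of the in-vocabulary query tokens and the precomputed
        document embeddings."""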
        tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
        if not tokens:
            return []
        vectors = np.array([self.wv_model[word] for word in tokens])
        query_vec = np.mean(vectors, axis=0).reshape(1, -1)
        self.query_encoding = query_vec
        sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]

    def bert_search(self, query, top_n=10):
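        """Rank documents by cosine similarity between the BERT [CLS]
        encoding of the query and the precomputed document embeddings."""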
        with torch.no_grad():
            # Note: the query is duplicated and padded to max_length; kept
            # as-is, presumably to match how the document encodings were built.
            inputs = self.tokenizer((query + ' ') * 2, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
            outputs = self.model(**inputs)
            query_vec = outputs.last_hidden_state[:, 0, :].numpy()

        self.query_encoding = query_vec
        sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        print(f"sim, top_indices: {sims}, {top_indices}")
        return [(i, sims[i]) for i in top_indices]

    # def scibert_search(self, query, top_n=10):
    #     with torch.no_grad():
    #         inputs = self.sci_tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=512)
    #         outputs = self.sci_model(**inputs)
    #         query_vec = outputs.last_hidden_state[:, 0, :].numpy()

    #     self.query_encoding = query_vec
    #     sims = cosine_similarity(query_vec, self.scibert_embeddings).flatten()
    #     top_indices = sims.argsort()[::-1][:top_n]
    #     print(f"sim, top_indices: {sims}, {top_indices}")
    #     return [(i, sims[i]) for i in top_indices]

    def sbert_search(self, query, top_n=10):
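        """Two-stage retrieval: SBERT cosine similarity selects the top 50
        candidates, then a cross-encoder re-ranks them; the final score is
        0.7 * cross-encoder + 0.3 * cosine similarity. Candidate texts are
        fetched by embedding row index, which assumes the SBERT embedding
        rows align with dataset['train'] rows."""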
        query_vec = self.sbert_model.encode([query])
        self.query_encoding = query_vec
        cos_scores = cosine_similarity(query_vec, self.sbert_embedding)[0]
        top_k_indices = np.argsort(cos_scores)[-50:][::-1]
        candidates = [self.dataset['train'][int(i)]['text'] for i in top_k_indices]
        scores = self.cross_encoder.predict([(query, doc) for doc in candidates])
        final_scores = 0.7 * scores + 0.3 * cos_scores[top_k_indices]
        order = final_scores.argsort()[::-1][:top_n]
        print(f"sim, top_indices: {final_scores}, {top_k_indices[order]}")
        return [(int(top_k_indices[i]), final_scores[i]) for i in order]

    def clustered_sbert_search(self, query, top_n=10):
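        """Like sbert_search, but the query is first routed to the cluster
        whose center it is most similar to, and only documents in that
        cluster are scored; cluster-internal indices are mapped back to
        full-dataset indices before returning."""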
        query_vec = self.sbert_model.encode([query])
        self.query_encoding = query_vec  # stored for plotting
        # Route the query to the most similar cluster center.
        cos_cluster_scores = cosine_similarity(query_vec, self.cluster_centers)[0]
        self.top_cluster_index = np.argmax(cos_cluster_scores)
        # Cosine similarity within the top cluster (cluster-internal indices).
        cos_scores = cosine_similarity(query_vec, self.clustered_embeddings[self.top_cluster_index])[0]
        top_k_indices = np.argsort(cos_scores)[-50:][::-1]
        # Map cluster-internal indices back to full-dataset indices.
        cluster_doc_indices = np.where(self.clusters == self.top_cluster_index)[0]
        top_full_dataset_indices = cluster_doc_indices[top_k_indices]
        candidates = [self.dataset['train'][int(i)]['text'] for i in top_full_dataset_indices]
        # Re-rank the candidates with the cross-encoder and blend the scores.
        scores = self.cross_encoder.predict([(query, doc) for doc in candidates])
        final_scores = 0.7 * scores + 0.3 * cos_scores[top_k_indices]
        order = final_scores.argsort()[::-1][:top_n]
        top_indices_full = cluster_doc_indices[top_k_indices[order]]
        print(f"sim, top_indices: {final_scores}, {top_k_indices[order]}")
        # Pair each returned document with its own blended score.
        return [(int(i), final_scores[j]) for i, j in zip(top_indices_full, order)]

    def model_switch(self, embedding, progress=gr.Progress()):
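        """Switch the active embedding, load its artifacts, free the previous
        model's, and re-run the last query if there is one. Unused in the
        current wiring: the dropdown change event calls search_function
        directly and all models are preloaded in __init__."""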
        if self.embedding != embedding:
            old_embedding = self.embedding
            print(f"Switching model to {embedding}")
            self.load_model(embedding)
            print(f"Loaded {embedding} model")
            self.embedding = embedding
            if old_embedding == "tfidf":
                del self.tfidf_matrix
                del self.feature_names
            if old_embedding == "word2vec":
                del self.word2vec_embeddings
                del self.wv_model
            if old_embedding == "bert":
                del self.bert_embeddings
                del self.tokenizer
                del self.model
            if old_embedding == "scibert":
                del self.scibert_embeddings
                del self.sci_tokenizer
                del self.sci_model
            if old_embedding == "sbert":
                del self.sbert_model
                del self.sbert_embedding
                del self.cross_encoder
            print(f"old embedding removed")
            if hasattr(self, "query") and self.query:
                return self.search_function(self.query, self.embedding)
            else:
                return ""  # Or a message like "Model switched. Please enter a query."
        return gr.update()  # No change if embedding is the same

    def load_model(self, embedding):
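        """Load the precomputed embeddings and supporting artifacts for the
        requested embedding type from local files, and set self.embedding."""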
        self.embedding = embedding
        if self.embedding == "tfidf":
            self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
            with open("TF-IDF embeddings/feature_names.txt", "r") as f:
                self.feature_names = [line.strip() for line in f.readlines()]
        elif self.embedding == "word2vec":
            # Use trimmed model here
            self.word2vec_embeddings = np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"]
            self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
        elif self.embedding == "bert":
            self.bert_embeddings = np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"]
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.model = BertModel.from_pretrained('bert-base-uncased')
            self.model.eval()
        # elif self.embedding == "scibert":
        #     self.scibert_embeddings = np.load("SciBERT_embeddings/scibert_embedding.npz")["bert_embedding"]
        #     self.sci_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        #     self.sci_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
        #     self.sci_model.eval()
        elif self.embedding == "sbert" or self.embedding == "clustered sbert":
            self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
            self.sbert_embedding = np.load("BERT embeddings/sbert_embedding.npz")["sbert_embedding"]
            # self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
            self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2")
            if self.embedding == "clustered sbert":
                self.clusters = pd.read_csv('raf_clusters/cluster_labels_sbert.csv')['cluster_label'].values
                self.cluster_centers = pd.read_csv('BERT embeddings/sbert_cluster_centers.csv').values
                self.clustered_embeddings = [self.sbert_embedding[self.clusters == i] for i in np.unique(self.clusters)]
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")
        
    def snippet_before_abstract(self, text):
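        """Return the text preceding the "Abstract"/"Introduction" heading
        (tolerating spaced-out letters like "a b s t r a c t"), falling back
        to a short prefix when the heading is missing or too far in."""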
        pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
        match = pattern.search(text)
        if match:
            return text[:match.start()].strip() if match.start() < 1000 else text[:100].strip()
        else:
            return text[:300].strip()

    def set_embedding(self, embedding):
        self.embedding = embedding

    def search_function(self, query, embedding, progress=gr.Progress()):
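        """Gradio callback: dispatch the query to the search method for the
        selected embedding and render the results as Markdown."""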
        self.set_embedding(embedding)
        self.query = query
        query = query.encode().decode('unicode_escape')  # Interpret escape sequences
        search_methods = {
            "tfidf": self.keyword_match_ranking,
            "word2vec": self.word2vec_search,
            "bert": self.bert_search,
            # "scibert": self.scibert_search,  # Uncomment if implemented
            "sbert": self.sbert_search,
            "clustered sbert": self.clustered_sbert_search,
        }

        results = search_methods.get(self.embedding, lambda q: [])(query)

        if not results:
            self.last_results = []
            return "No results found."

        self.last_results = results

        output = ""
        display_rank = 1
        for idx, score in results:
            if not self.arxiv_ids[idx]:
                output += f"### Document {display_rank}\n"
                output += f"<pre>{self.documents[idx][:200]}</pre>\n\n"
            else:
                link = f"https://arxiv.org/abs/{self.arxiv_ids[idx].replace('arxiv:', '')}"
                snippet = self.snippet_before_abstract(self.documents[idx]).replace('\n', '<br>')
                output += f"### Document {display_rank}\n"
                output += f"[arXiv Link]({link})\n\n"
                output += f"<pre>{snippet}</pre>\n\n---\n"
            display_rank += 1

        return output


if __name__ == "__main__":
    dataset = load_dataset("ccdv/arxiv-classification", "no_ref")  # replace with your dataset
    search_engine = ArxivSearch(dataset)
    search_engine.iface.launch()