4kasha committed
Commit · bcdda34 · Parent(s): eba5fa8

initial commit

Files changed:
- .gitattributes +1 -0
- _arxiv.py +42 -0
- app.py +119 -0
- data/Docs-rep-2019-h500.npy +3 -0
- data/Rmatrix-2019-h500.dat +3 -0
- data/acl-pub-info-2019-title-abst.parquet +3 -0
- data/vocab_2019.npy +3 -0
- requirements.txt +6 -0
- utils.py +73 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.dat filter=lfs diff=lfs merge=lfs -text
_arxiv.py
ADDED
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import logging
+from typing import Optional
+import string
+import nltk
+import arxiv
+
+logger = logging.getLogger(__name__)
+
+def extract_title_abst(arxiv_id: str):
+    try:
+        paper = next(arxiv.Search(id_list=[arxiv_id]).results())
+        doc = paper.title + ' ' + paper.summary
+    except Exception:
+        doc = None
+    return doc
+
+def doc_to_ids(
+    doc: Optional[str],
+    word_to_id_: dict[str, int],
+    stemming: bool,
+    lower: bool = True,
+):
+    from nltk.stem.porter import PorterStemmer
+
+    if not doc:
+        y = []
+    else:
+        if lower:
+            doc = doc.lower()
+        doc = "".join([char for char in doc if char not in string.punctuation])
+        words = nltk.word_tokenize(doc)
+        if stemming:
+            porter = PorterStemmer()
+            words = [porter.stem(word) for word in words]
+
+        # Consider out-of-vocabulary cases; if y == []: no matched results
+        y = [word_to_id_[word] for word in words if word in word_to_id_]
+        # pick up keywords only once
+        #y = list(set(y))
+    return y
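As a quick sanity check, the two helpers compose like this (a minimal sketch, not part of the commit; the arXiv id is only an illustration, and nltk's punkt tokenizer must already be downloaded, as app.py does on startup):

import numpy as np
from _arxiv import extract_title_abst, doc_to_ids

# vocabulary dict {word: id}, loaded the same way app.py loads it
word_to_id_ = np.load('data/vocab_2019.npy', allow_pickle=True).item()

doc = extract_title_abst('1706.03762')           # None if the arXiv lookup fails
y = doc_to_ids(doc, word_to_id_, stemming=True)  # [] when nothing is in-vocabulary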
app.py
ADDED
@@ -0,0 +1,119 @@
+import logging
+import re
+import nltk
+import pandas as pd
+import numpy as np
+import gradio as gr
+
+fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
+logging.basicConfig(level=logging.WARNING, format=fmt)
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+from utils import load_matrix, query_to_ids, search
+from _arxiv import extract_title_abst, doc_to_ids
+
+nltk.download('punkt')
+
+def get_args():
+    return {
+        'acl_data_file': 'data/acl-pub-info-2019-title-abst.parquet',
+        'docs_rep_file': 'data/Docs-rep-2019-h500.npy',
+        'r_matrix_file': 'data/Rmatrix-2019-h500.dat',
+        'vocab_file': 'data/vocab_2019.npy',
+        'topk': 20,
+        'metric': 'INNER_PRODUCT',  # choices=['COSINE', 'INNER_PRODUCT']
+    }
+
+class ObjectView(object):
+    def __init__(self, d): self.__dict__ = d
+
+
+def _format(s: float, year: str, authors: str, title: str, url: str):
+    authors = ', '.join(authors.replace(',', '').replace('\\', '').split(' and\n'))
+    authors = re.sub('[{}]', '', authors)
+    title = re.sub('[{}]', '', title)
+    title_with_url_markdown = f'[{title}]({url})'
+    url = url.rstrip('/')
+    pdf_url = f'[click]({url}.pdf)'
+    return [round(s, 2), year, title_with_url_markdown, authors, pdf_url]
+
+def main(args: ObjectView):
+    df = pd.read_parquet(args.acl_data_file)
+    #logger.info(f'document size: {len(df)}')
+    word_to_id_ = np.load(args.vocab_file, allow_pickle=True).item()
+    D, R = load_matrix(args.docs_rep_file, args.r_matrix_file, word_to_id_)
+
+    def _search(query: str):
+        results = []
+        y = query_to_ids(query, word_to_id_, stemming=True)
+        if y == []:
+            return [[None, 'N/A', 'N/A', 'N/A', 'N/A']]
+        else:
+            scores, docids = search(args, df, args.topk, y, R, D)
+            for s, year, authors, title, url in zip(scores[docids], df.iloc[docids]["year"], df.iloc[docids]["author"], df.iloc[docids]["title"], df.iloc[docids]["url"]):
+                results.append(_format(s, year, authors, title, url))
+            return results
+
+    def _search_arxiv(arxiv_id: str):
+        results = []
+        doc = extract_title_abst(arxiv_id)
+        y = doc_to_ids(doc, word_to_id_, stemming=True)
+        if y == []:
+            return [[None, 'N/A', 'N/A', 'N/A', 'N/A']]
+        else:
+            scores, docids = search(args, df, args.topk, y, R, D)
+            for s, year, authors, title, url in zip(scores[docids], df.iloc[docids]["year"], df.iloc[docids]["author"], df.iloc[docids]["title"], df.iloc[docids]["url"]):
+                results.append(_format(s, year, authors, title, url))
+            return results
+
+    with gr.Blocks() as demo:
+        gr.HTML(
+            """
+            <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+              <div
+                style="
+                  display: inline-flex;
+                  align-items: center;
+                  gap: 1rem;
+                  font-size: 1.75rem;
+                "
+              >
+                <svg width="68" height="46" xmlns="http://www.w3.org/2000/svg">
+                  <path
+                    d="M 41.977553,-2.8421709e-014 C 41.977553,1.76178 41.977553,1.44211 41.977553,3.0158 L 7.4869054,3.0158 L 0,3.0158 L 0,10.50079 L 0,38.47867 L 0,46 L 7.4869054,46 L 49.500802,46 L 56.987708,46 L 68,46 L 68,30.99368 L 56.987708,30.99368 L 56.987708,10.50079 L 56.987708,3.0158 C 56.987708,1.44211 56.987708,1.76178 56.987708,-2.8421709e-014 L 41.977553,-2.8421709e-014 z M 15.010155,17.98578 L 41.977553,17.98578 L 41.977553,30.99368 L 15.010155,30.99368 L 15.010155,17.98578 z "
+                    style="fill:#ed1c24;fill-opacity:1;fill-rule:evenodd;stroke:none;stroke-width:12.89541149;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
+                  />
+                </svg>
+                <h1 style="font-weight: 900; margin-bottom: 0">
+                  ACL2Vec
+                </h1>
+              </div>
+              <p style="margin: 15px 0 5px; font-size: 100%; text-align: justify">
+                This is a lightweight version of <a href=http://clml.ism.ac.jp/ACL2Vec/>ACL2Vec keyword search</a>, implemented in a purely statistical manner.
+                Start typing below to search papers limited to 2019 onwards and up to September 2022.
+              </p>
+            </div>
+            """)
+        with gr.Row():
+            inputs = gr.Textbox(placeholder="Input keywords separated by spaces.", show_label=False)
+            inputs_arxiv = gr.Textbox(placeholder="Input an arXiv number and press Enter to find similar papers.", show_label=False)
+
+        outputs = gr.Dataframe(
+            headers=['score', 'year', 'title', 'authors', 'PDF'],
+            datatype=["number", "str", "markdown", "str", "markdown"],
+            col_count=(5, "fixed"),
+            wrap=True,
+            label=f"top-{args.topk} results"
+        )
+        inputs.change(_search, inputs, outputs)
+        inputs_arxiv.submit(_search_arxiv, inputs_arxiv, outputs)
+
+    demo.launch(
+        #share=True,
+        debug=True
+    )
+
+if __name__ == '__main__':
+    args = ObjectView(get_args())
+    main(args)
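For local testing, the Space runs as a plain script: install the pinned dependencies with pip install -r requirements.txt, fetch the LFS-tracked files under data/ (e.g. via git lfs pull), and run python app.py; demo.launch(debug=True) then serves the keyword box and the arXiv-similarity box. Note that get_args() pins the metric to INNER_PRODUCT; switching it to COSINE routes scoring through scikit-learn's cosine_similarity in utils.search.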
data/Docs-rep-2019-h500.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea0d111544401419d19ced34fa254e6f1f4a9a1f21056c45857a0290bc0735f3
+size 39740128
data/Rmatrix-2019-h500.dat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbcf1fdb7633825254a2de82008be03371ccabc600704382334161baa6c0cfb0
+size 5610000
data/acl-pub-info-2019-title-abst.parquet
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc0e1f45980d44ebbfb265a4e56b4cd694d791ad695c8156f8ca09ca71d57146
+size 12547039
data/vocab_2019.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2dcf906ee75dbc2c8f2c09a2de25c8b90b104e48cf7bd1581953210af6ebabd6
+size 53098
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+arxiv==1.4.3
+gradio==3.18.0
+nltk==3.7
+numpy==1.21.6
+pandas==1.3.5
+scikit-learn==1.0.2
utils.py
ADDED
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import logging
+import argparse
+import string
+
+import nltk
+import pandas
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
+logger = logging.getLogger(__name__)
+
+def load_matrix(
+    d_file: str,
+    r_file: str,
+    word_to_id_: dict[str, int]
+):
+    D = np.load(d_file)
+    R = np.memmap(r_file, dtype='float32', mode='r', shape=(D.shape[-1], len(word_to_id_)))
+    logger.info(f'D size: {D.shape}, R size: {R.shape}')
+    return D, R
+
+def query_to_ids(
+    query: str,
+    word_to_id_: dict[str, int],
+    stemming: bool,
+    lower: bool = True,
+):
+    from nltk.stem.porter import PorterStemmer
+
+    if lower:
+        query = query.lower()
+    # TODO: weight "*" process
+    query = "".join([char for char in query if char not in string.punctuation])
+    words = nltk.word_tokenize(query)
+    if stemming:
+        porter = PorterStemmer()
+        words = [porter.stem(word) for word in words]
+
+    # Consider out-of-vocabulary cases; if y == []: no matched results
+    y = [word_to_id_[word] for word in words if word in word_to_id_]
+
+    return y
+
+def query_to_vec(
+    R: np.ndarray,
+    y: list[int]
+):
+    qvec = np.zeros((R.shape[0], ))  # sum of R's columns over the query word ids
+    for ind in y:
+        qvec += R[:, ind]
+    return qvec
+
+
+def search(
+    args: argparse.Namespace,
+    df: pandas.DataFrame,
+    k: int,
+    y: list[int],
+    R: np.ndarray,
+    D: np.ndarray
+):
+    qvec = query_to_vec(R, y)
+    if args.metric == 'COSINE':
+        scores = cosine_similarity([qvec], D)[0]
+    elif args.metric == 'INNER_PRODUCT':
+        scores = D @ qvec
+    else:
+        raise ValueError(f'unsupported metric: {args.metric}')
+    docids = np.argsort(scores)[::-1][:k]
+
+    return scores, docids
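Taken together, the retrieval path is: tokenize and stem the query, map each surviving word to its vocabulary id, sum the matching columns of R into a query vector, and rank the document rows of D by inner product (or cosine). A minimal end-to-end sketch outside Gradio, assuming the same data/ layout that app.py's get_args() points at (the query string is illustrative):

import nltk
import numpy as np
import pandas as pd
from utils import load_matrix, query_to_ids, search

nltk.download('punkt')  # tokenizer model required by nltk.word_tokenize

class Args:  # minimal stand-in for app.py's ObjectView
    metric = 'INNER_PRODUCT'

df = pd.read_parquet('data/acl-pub-info-2019-title-abst.parquet')
word_to_id_ = np.load('data/vocab_2019.npy', allow_pickle=True).item()
D, R = load_matrix('data/Docs-rep-2019-h500.npy', 'data/Rmatrix-2019-h500.dat', word_to_id_)

y = query_to_ids('neural machine translation', word_to_id_, stemming=True)
scores, docids = search(Args(), df, 5, y, R, D)  # scores for all docs, top-5 indices
for i in docids:
    print(round(scores[i], 2), df.iloc[i]['title'])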