4kasha committed on
Commit
bcdda34
·
1 Parent(s): eba5fa8

initial commit

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.dat filter=lfs diff=lfs merge=lfs -text
_arxiv.py ADDED
@@ -0,0 +1,42 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Optional
+ import string
+ import nltk
+ import arxiv
+
+ logger = logging.getLogger(__name__)
+
+ def extract_title_abst(arxiv_id: str):
+     try:
+         paper = next(arxiv.Search(id_list=[arxiv_id]).results())
+         doc = paper.title + ' ' + paper.summary
+     except Exception:
+         doc = None
+     return doc
+
+ def doc_to_ids(
+     doc: Optional[str],
+     word_to_id_: dict[str, int],
+     stemming: bool,
+     lower: bool = True,
+ ):
+     from nltk.stem.porter import PorterStemmer
+
+     if not doc:
+         y = []
+     else:
+         if lower:
+             doc = doc.lower()
+         doc = "".join([char for char in doc if char not in string.punctuation])
+         words = nltk.word_tokenize(doc)
+         if stemming:
+             porter = PorterStemmer()
+             words = [porter.stem(word) for word in words]
+
+         # Consider out-of-vocabulary cases, if y == []: no matched results
+         y = [word_to_id_[word] for word in words if word in word_to_id_]
+         # pick up keywords only once
+         #y = list(set(y))
+     return y
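Taken together, extract_title_abst and doc_to_ids turn an arXiv identifier into the vocabulary ids the search index expects. A minimal sketch of how they compose, assuming the data/vocab_2019.npy file from this commit is available locally (the arXiv id below is only an illustration):

import nltk
import numpy as np
from _arxiv import extract_title_abst, doc_to_ids

nltk.download('punkt')  # tokenizer model used by nltk.word_tokenize

doc = extract_title_abst("1706.03762")  # illustrative id; returns title + abstract, or None on failure
word_to_id_ = np.load("data/vocab_2019.npy", allow_pickle=True).item()
y = doc_to_ids(doc, word_to_id_, stemming=True)
print(y[:10])  # an empty list means nothing matched the vocabulary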
app.py ADDED
@@ -0,0 +1,119 @@
+ import logging
+ import re
+ import nltk
+ import pandas as pd
+ import numpy as np
+ import gradio as gr
+
+ fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
+ logging.basicConfig(level=logging.WARNING, format=fmt)
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+
+ from utils import load_matrix, query_to_ids, search
+ from _arxiv import extract_title_abst, doc_to_ids
+
+ nltk.download('punkt')
+
+ def get_args():
+     return {
+         'acl_data_file': 'data/acl-pub-info-2019-title-abst.parquet',
+         'docs_rep_file': 'data/Docs-rep-2019-h500.npy',
+         'r_matrix_file': 'data/Rmatrix-2019-h500.dat',
+         'vocab_file': 'data/vocab_2019.npy',
+         'topk': 20,
+         'metric': 'INNER_PRODUCT', # choices=['COSINE', 'INNER_PRODUCT']
+     }
+
+ class ObjectView(object):
+     def __init__(self, d): self.__dict__ = d
+
+
+ def _format(s: float, year: str, authors: str, title: str, url: str):
+     authors = ', '.join(authors.replace(',','').replace('\\', '').split(' and\n'))
+     authors = re.sub('[{}]', '', authors)
+     title = re.sub('[{}]', '', title)
+     title_with_url_markdown = f'[{title}]({url})'
+     url = url.rstrip('/')
+     pdf_url = f'[click]({url}.pdf)'
+     return [round(s,2), year, title_with_url_markdown, authors, pdf_url]
+
+ def main(args: ObjectView):
+     df = pd.read_parquet(args.acl_data_file)
+     #logger.info(f'document size: {len(df)}')
+     word_to_id_ = np.load(args.vocab_file, allow_pickle=True).item()
+     D, R = load_matrix(args.docs_rep_file, args.r_matrix_file, word_to_id_)
+
+     def _search(query: str):
+         results = []
+         y = query_to_ids(query, word_to_id_, stemming=True)
+         if y==[]:
+             return [[None,'N/A', 'N/A', 'N/A', 'N/A']]
+         else:
+             scores, docids = search(args, df, args.topk, y, R, D)
+             for s, year, authors, title, url in zip(scores[docids], df.iloc[docids]["year"], df.iloc[docids]["author"], df.iloc[docids]["title"], df.iloc[docids]["url"]):
+                 results.append(_format(s, year, authors, title, url))
+             return results
+
+     def _search_arxiv(arxiv_id: str):
+         results = []
+         doc = extract_title_abst(arxiv_id)
+         y = doc_to_ids(doc, word_to_id_, stemming=True)
+         if y==[]:
+             return [[None,'N/A', 'N/A', 'N/A', 'N/A']]
+         else:
+             scores, docids = search(args, df, args.topk, y, R, D)
+             for s, year, authors, title, url in zip(scores[docids], df.iloc[docids]["year"], df.iloc[docids]["author"], df.iloc[docids]["title"], df.iloc[docids]["url"]):
+                 results.append(_format(s, year, authors, title, url))
+             return results
+
+     with gr.Blocks() as demo:
+         gr.HTML(
+             """
+             <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+               <div
+                 style="
+                   display: inline-flex;
+                   align-items: center;
+                   gap: 1rem;
+                   font-size: 1.75rem;
+                 "
+               >
+                 <svg width="68" height="46" xmlns="http://www.w3.org/2000/svg">
+                   <path
+                     d="M 41.977553,-2.8421709e-014 C 41.977553,1.76178 41.977553,1.44211 41.977553,3.0158 L 7.4869054,3.0158 L 0,3.0158 L 0,10.50079 L 0,38.47867 L 0,46 L 7.4869054,46 L 49.500802,46 L 56.987708,46 L 68,46 L 68,30.99368 L 56.987708,30.99368 L 56.987708,10.50079 L 56.987708,3.0158 C 56.987708,1.44211 56.987708,1.76178 56.987708,-2.8421709e-014 L 41.977553,-2.8421709e-014 z M 15.010155,17.98578 L 41.977553,17.98578 L 41.977553,30.99368 L 15.010155,30.99368 L 15.010155,17.98578 z "
+                     style="fill:#ed1c24;fill-opacity:1;fill-rule:evenodd;stroke:none;stroke-width:12.89541149;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
+                   />
+                 </svg>
+                 <h1 style="font-weight: 900; margin-bottom: 0">
+                   ACL2Vec
+                 </h1>
+               </div>
+               <p style="margin: 15px 0 5px; font-size: 100%; text-align: justify">
+                 This is a lightweight version of <a href=http://clml.ism.ac.jp/ACL2Vec/>ACL2Vec keyword search</a>, implemented in a totally statistical manner.
+                 Start typing below to search papers limited to 2019 onwards and up to September 2022.
+               </p>
+             </div>
+             """)
+         with gr.Row():
+             inputs = gr.Textbox(placeholder="Input keywords separated by spaces.", show_label=False)
+             inputs_arxiv = gr.Textbox(placeholder="Input arxiv number and press Enter to find similar papers.", show_label=False)
+
+         outputs = gr.Dataframe(
+             headers=['score', 'year', 'title', 'authors', 'PDF'],
+             datatype=["number", "str", "markdown", "str", "markdown"],
+             col_count=(5, "fixed"),
+             wrap=True,
+             label=f"top-{args.topk} results"
+         )
+         inputs.change(_search, inputs, outputs)
+         inputs_arxiv.submit(_search_arxiv, inputs_arxiv, outputs)
+
+     demo.launch(
+         #share=True,
+         debug=True
+     )
+
+ if __name__ == '__main__':
+     args = ObjectView(get_args())
+     main(args)
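For reference, the retrieval path that _search wires into Gradio can also be exercised directly. A minimal sketch, assuming the data files added in this commit are present under data/ (the query string is only an example):

from types import SimpleNamespace
import nltk
import numpy as np
import pandas as pd
from utils import load_matrix, query_to_ids, search

nltk.download('punkt')  # tokenizer model used by query_to_ids

args = SimpleNamespace(**{  # same settings as get_args() above
    'acl_data_file': 'data/acl-pub-info-2019-title-abst.parquet',
    'docs_rep_file': 'data/Docs-rep-2019-h500.npy',
    'r_matrix_file': 'data/Rmatrix-2019-h500.dat',
    'vocab_file': 'data/vocab_2019.npy',
    'topk': 20,
    'metric': 'INNER_PRODUCT',
})

df = pd.read_parquet(args.acl_data_file)
word_to_id_ = np.load(args.vocab_file, allow_pickle=True).item()
D, R = load_matrix(args.docs_rep_file, args.r_matrix_file, word_to_id_)

y = query_to_ids("machine translation", word_to_id_, stemming=True)
if y:
    scores, docids = search(args, df, args.topk, y, R, D)
    for rank, i in enumerate(docids, 1):
        print(rank, round(float(scores[i]), 2), df.iloc[i]["title"])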
data/Docs-rep-2019-h500.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea0d111544401419d19ced34fa254e6f1f4a9a1f21056c45857a0290bc0735f3
+ size 39740128
data/Rmatrix-2019-h500.dat ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fbcf1fdb7633825254a2de82008be03371ccabc600704382334161baa6c0cfb0
+ size 5610000
data/acl-pub-info-2019-title-abst.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc0e1f45980d44ebbfb265a4e56b4cd694d791ad695c8156f8ca09ca71d57146
+ size 12547039
data/vocab_2019.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2dcf906ee75dbc2c8f2c09a2de25c8b90b104e48cf7bd1581953210af6ebabd6
+ size 53098
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ arxiv==1.4.3
+ gradio==3.18.0
+ nltk==3.7
+ numpy==1.21.6
+ pandas==1.3.5
+ scikit-learn==1.0.2
utils.py ADDED
@@ -0,0 +1,73 @@
+ from __future__ import annotations
+
+ import logging
+ import argparse
+ import re
+ import string
+
+ import nltk
+ import pandas
+ import pandas as pd
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ logger = logging.getLogger(__name__)
+
+ def load_matrix(
+     d_file: str,
+     r_file: str,
+     word_to_id_: dict[str, int]
+ ):
+     D = np.load(d_file)
+     R = np.memmap(r_file, dtype='float32', mode='r', shape=(D.shape[-1],len(word_to_id_)))
+     logger.info(f'D size: {D.shape}, R size: {R.shape}')
+     return D, R
+
+ def query_to_ids(
+     query: str,
+     word_to_id_: dict[str, int],
+     stemming: bool,
+     lower: bool = True,
+ ):
+     from nltk.stem.porter import PorterStemmer
+
+     if lower:
+         query = query.lower()
+     # TODO: weight "*" process
+     query = "".join([char for char in query if char not in string.punctuation])
+     words = nltk.word_tokenize(query)
+     if stemming:
+         porter = PorterStemmer()
+         words = [porter.stem(word) for word in words]
+
+     # Consider out-of-vocabulary cases, if y == []: no matched results
+     y = [word_to_id_[word] for word in words if word in word_to_id_]
+
+     return y
+
+ def query_to_vec(
+     R: np.ndarray,
+     y: list[int]
+ ):
+     qvec = np.zeros((R.shape[0], ))
+     for ind in y:
+         qvec += R[:,ind]
+     return qvec
+
+
+ def search(
+     args: argparse.Namespace,
+     df: pandas.DataFrame,
+     k: int,
+     y: list[int],
+     R: np.ndarray,
+     D: np.ndarray
+ ):
+     qvec = query_to_vec(R, y)
+     if args.metric=='COSINE':
+         scores = cosine_similarity([qvec], D)[0]
+     elif args.metric=='INNER_PRODUCT':
+         scores = D @ qvec
+     docids = np.argsort(scores)[::-1][:k]
+
+     return scores, docids
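The ranking in search is plain linear algebra: query_to_vec sums one column of R per query id into an h-dimensional query vector, and with metric='INNER_PRODUCT' documents are scored as D @ qvec and sorted. A tiny synthetic shape check (random matrices standing in for the real data):

import numpy as np
from utils import query_to_vec

rng = np.random.default_rng(0)
n_docs, h, vocab_size = 4, 8, 10
D = rng.standard_normal((n_docs, h))      # one h-dimensional row per document
R = rng.standard_normal((h, vocab_size))  # one column per vocabulary word

y = [2, 5]                                # ids of the (stemmed) query words
qvec = query_to_vec(R, y)                 # == R[:, 2] + R[:, 5], shape (h,)
scores = D @ qvec                         # inner-product scoring, shape (n_docs,)
print(np.argsort(scores)[::-1])           # documents ranked best first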