HenryStephen committed on
Commit
43515a8
1 Parent(s): c831d35

Deploying RepoSnipy

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,163 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Streamlit configs
+ .streamlit/
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 RepoSnipy
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,96 @@
- ---
- title: RepoSnipy
- emoji: 📈
- colorFrom: yellow
- colorTo: red
- sdk: streamlit
- sdk_version: 1.31.1
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # RepoSnipy 🐉
+ Neural search engine for discovering semantically similar Python repositories on GitHub.
+
+ ## Demo
+ **TODO --- Update the gif file!!!**
+
+ Searching an indexed repository:
+
+ ![Search Indexed Repo Demo](assets/search.gif)
+
+ ## About
+
+ RepoSnipy is a neural search engine built with [streamlit](https://github.com/streamlit/streamlit) and [docarray](https://github.com/docarray/docarray). You can query a public Python repository hosted on GitHub and find popular repositories that are semantically similar to it.
+
+ Compared to the previous generation of [RepoSnipy](https://github.com/RepoAnalysis/RepoSnipy), the latest version adds the following features:
+ * It uses [RepoSim4Py](https://github.com/RepoMining/RepoSim4Py), backed by the [RepoSim4Py pipeline](https://huggingface.co/Henry65/RepoSim4Py), to create multi-level embeddings for Python repositories.
+ * Multi-level embeddings --- code, docstring, readme, requirement, and repository.
+ * It uses the [SciBERT](https://arxiv.org/abs/1903.10676) model to analyse repository topics and to generate topic embeddings.
+ * Topic clustering --- a [KMeans](data/kmeans_model_scibert.pkl) model groups repositories into clusters based on their topic embeddings (see the sketch below).
+ * **SimilarityCal --- TODO update!!!**
+
+ We have created a [vector dataset](data/index.bin) (stored as a docarray index) of approximately 9,700 GitHub Python repositories, each with a license and more than 300 stars as of February 2024. The corresponding clusters are stored in a [JSON dataset](data/repo_clusters.json) that maps each repository to its cluster number.
+
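+ Under the hood, the app combines these models roughly as sketched below. This is a minimal, illustrative snippet mirroring what [app.py](app.py) does; `numpy/numpy` is only an example repository and the GitHub token is optional (but helps avoid rate limits):
+ ```python
+ import joblib
+ import torch
+ from transformers import pipeline, AutoTokenizer, AutoModel
+
+ # RepoSim4Py pipeline: multi-level embeddings for one repository
+ repo_sim = pipeline(model="Henry65/RepoSim4Py", trust_remote_code=True, device_map="auto")
+ extracted = repo_sim.preprocess("numpy/numpy", github_token="")  # set a personal access token to avoid rate limits
+ repo_info = repo_sim.forward(extracted)[0]  # topics, stars, license and mean_*_embedding arrays
+
+ # SciBERT topic embedding (mean pooling over the last hidden state)
+ tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
+ scibert = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
+ inputs = tokenizer(" ".join(repo_info["topics"]), return_tensors="pt", padding=True, truncation=True, max_length=512)
+ with torch.no_grad():
+     topic_embedding = scibert(**inputs).last_hidden_state.mean(dim=1).numpy()
+
+ # KMeans cluster number derived from the topic embedding
+ kmeans = joblib.load("data/kmeans_model_scibert.pkl")
+ cluster_number = int(kmeans.predict(topic_embedding)[0])
+ ```
+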
+ ## Installation
+
+ ### Prerequisites
+ * Python 3.11
+ * pip
+
+ ### Installation with code
+ We recommend first creating a [conda](https://conda.io/projects/conda/en/latest/index.html) environment with Python 3.11 and then cloning the repository:
+ ```bash
+ conda create --name py311 python=3.11
+ conda activate py311
+ git clone https://github.com/RepoMining/RepoSnipy
+ ```
+ After cloning the repository, install the required packages. **Make sure the python and pip you use both come from the conda environment!**
+ ```bash
+ cd RepoSnipy
+ pip install -r requirements.txt
+ ```
+
+ ### Usage
+ Then run the app on your local machine using:
+ ```bash
+ streamlit run app.py
+ ```
+ or
+ ```bash
+ python -m streamlit run app.py
+ ```
+ To avoid unnecessary conflicts (such as version or package-location conflicts), make sure that the **streamlit you run also comes from the conda environment**!
+
+ ### Dataset
+ As mentioned above, RepoSnipy needs the [vector dataset](data/index.bin), the [JSON dataset](data/repo_clusters.json) and the [KMeans](data/kmeans_model_scibert.pkl) model when it starts up. For convenience, we have uploaded them to the [data](data) folder of this repository.
+
+ For research purposes, we also provide the following scripts to recreate them:
+ ```bash
+ cd data
+ python create_index.py # creates the vector dataset (binary files)
+ python generate_cluster.py # creates the cluster model and information (KMeans model and the repo-clusters JSON file)
+ ```
+
+ Refer to the two scripts above for more details. Running them produces the following files:
+ 1. Generated by [create_index.py](data/create_index.py):
+ ```bash
+ repositories.txt # the original repositories file
+ invalid_repositories.txt # the invalid repositories file
+ filtered_repositories.txt # the final repositories file, with duplicated and invalid repositories removed
+ index{i}_{i * target_sub_length}.bin # the sub-index files, where i is the sub-index number and target_sub_length is the sub-index length
+ index.bin # the index file merged from the sub-index files, with all-zero numpy arrays removed
+ ```
+ 2. Generated by [generate_cluster.py](data/generate_cluster.py):
+ ```
+ repo_clusters.json # a JSON file mapping each repository to its cluster
+ kmeans_model_scibert.pkl # a pickle file storing the KMeans model fitted on SciBERT topic embeddings
+ ```
+
+
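+ For reference, loading these artifacts looks roughly like this (a minimal sketch mirroring the loading code in [app.py](app.py)):
+ ```python
+ import json
+ import joblib
+ from docarray.index import InMemoryExactNNIndex
+ from data.repo_doc import RepoDoc
+
+ # docarray index holding the repositories and their multi-level embeddings
+ index = InMemoryExactNNIndex[RepoDoc](index_file_path="data/index.bin")
+
+ # repository name -> cluster number mapping
+ with open("data/repo_clusters.json") as file:
+     repo_clusters = json.load(file)
+
+ # KMeans model fitted on SciBERT topic embeddings
+ kmeans = joblib.load("data/kmeans_model_scibert.pkl")
+ ```
+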
+ ## Evaluation
+ **TODO ---- update!!!**
+
+ The [evaluation script](evaluate.py) enumerates all pairs of repositories in the dataset and calculates the cosine similarity between their embeddings. It also checks whether each pair shares at least one topic (ignoring `python` and `python3`), and then compares the two signals using the ROC AUC score to evaluate the embedding quality. The resulting dataframe with the cosine and topic similarities of all pairs, covering both the code and the docstring embeddings, can be downloaded from [here](https://huggingface.co/datasets/Lazyhope/RepoSnipy_eval/tree/main). The ROC AUC score is around 0.84 for the code embeddings and around 0.81 for the docstring embeddings.
+
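+ The core of that evaluation can be sketched as follows (a simplified illustration, not the exact `evaluate.py` code; `docs` is assumed to be the list of indexed `RepoDoc` entries):
+ ```python
+ from itertools import combinations
+
+ import numpy as np
+ from sklearn.metrics import roc_auc_score
+
+ def cosine(a, b):
+     return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+
+ ignored = {"python", "python3"}
+ scores, labels = [], []
+ for d1, d2 in combinations(docs, 2):  # docs: RepoDoc entries loaded from the index
+     if d1.code_embedding is None or d2.code_embedding is None:
+         continue
+     scores.append(cosine(d1.code_embedding, d2.code_embedding))
+     shared = (set(d1.topics) & set(d2.topics)) - ignored
+     labels.append(int(bool(shared)))  # 1 if the pair shares a topic other than python/python3
+
+ print("ROC AUC:", roc_auc_score(labels, scores))  # around 0.84 for code embeddings
+ ```
+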
+ ## License
+
+ Distributed under the MIT License. See [LICENSE](LICENSE) for more information.
+
+ ## Acknowledgments
+
+ The models and the fine-tuning dataset used:
+
+ * [UniXCoder](https://arxiv.org/abs/2203.03850)
+ * [AdvTest](https://arxiv.org/abs/1909.09436)
+ * [SciBERT](https://arxiv.org/abs/1903.10676)
app.py ADDED
@@ -0,0 +1,442 @@
+ import re
+ import json
+ import nltk
+ import joblib
+ import torch
+ import pandas as pd
+ import numpy as np
+ import streamlit as st
+ from pathlib import Path
+ from torch import nn
+ from docarray import DocList
+ from docarray.index import InMemoryExactNNIndex
+ from transformers import pipeline
+ from transformers import AutoTokenizer, AutoModel
+ from data.repo_doc import RepoDoc
+ from data.pair_classifier import PairClassifier
+ from nltk.stem import WordNetLemmatizer
+
+ nltk.download("wordnet")
+ KMEANS_MODEL_PATH = Path(__file__).parent.joinpath("data/kmeans_model_scibert.pkl")
+ SIMILARITY_CAL_MODEL_PATH = Path(__file__).parent.joinpath("data/SimilarityCal_model_NO1.pt")
+ device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+
+ # 1. Production environment
+ # INDEX_PATH = Path(__file__).parent.joinpath("data/index.bin")
+ # CLUSTER_PATH = Path(__file__).parent.joinpath("data/repo_clusters.json")
+ SCIBERT_MODEL_PATH = "allenai/scibert_scivocab_uncased"
+
+
+ # 2. Development environment
+ INDEX_PATH = Path(__file__).parent.joinpath("data/index_test.bin")
+ CLUSTER_PATH = Path(__file__).parent.joinpath("data/repo_clusters_test.json")
+ # SCIBERT_MODEL_PATH = Path(__file__).parent.joinpath("data/scibert_scivocab_uncased")  # Download locally
+
+
+ @st.cache_resource(show_spinner="Loading repositories basic information...")
+ def load_index():
+     """
+     The function to load the index file and return a RepoDoc object with default values
+     :return: the index and a RepoDoc object with default values
+     """
+     default_doc = RepoDoc(
+         name="",
+         topics=[],
+         stars=0,
+         license="",
+         code_embedding=None,
+         doc_embedding=None,
+         readme_embedding=None,
+         requirement_embedding=None,
+         repository_embedding=None
+     )
+
+     return InMemoryExactNNIndex[RepoDoc](index_file_path=INDEX_PATH), default_doc
+
+
+ @st.cache_resource(show_spinner="Loading repositories clusters...")
+ def load_repo_clusters():
+     """
+     The function to load the repo-clusters file
+     :return: a dictionary with the repo-clusters
+     """
+     with open(CLUSTER_PATH, "r") as file:
+         repo_clusters = json.load(file)
+
+     return repo_clusters
+
+
+ @st.cache_resource(show_spinner="Loading RepoSim4Py pipeline model...")
+ def load_pipeline_model():
+     """
+     The function to load the RepoSim4Py pipeline model
+     :return: a HuggingFace pipeline
+     """
+     # Option 1 --- Download the model by HuggingFace username/model_name
+     model_path = "Henry65/RepoSim4Py"
+
+     # Option 2 --- Load the model locally
+     # model_path = Path(__file__).parent.joinpath("data/RepoSim4Py")
+
+     return pipeline(
+         model=model_path,
+         trust_remote_code=True,
+         device_map="auto"
+     )
+
+
+ @st.cache_resource(show_spinner="Loading SciBERT model...")
+ def load_scibert_model():
+     """
+     The function to load the SciBERT model
+     :return: tokenizer and model
+     """
+     tokenizer = AutoTokenizer.from_pretrained(SCIBERT_MODEL_PATH)
+     scibert_model = AutoModel.from_pretrained(SCIBERT_MODEL_PATH).to(device)
+     return tokenizer, scibert_model
+
+
+ @st.cache_resource(show_spinner="Loading KMeans model...")
+ def load_kmeans_model():
+     """
+     The function to load the KMeans model
+     :return: a KMeans model
+     """
+     return joblib.load(KMEANS_MODEL_PATH)
+
+
+ @st.cache_resource(show_spinner="Loading SimilarityCal model...")
+ def load_similaritycal_model():
+     sim_cal_model = PairClassifier()
+     sim_cal_model.load_state_dict(torch.load(SIMILARITY_CAL_MODEL_PATH))
+     sim_cal_model = sim_cal_model.to(device)
+     sim_cal_model = sim_cal_model.eval()
+     return sim_cal_model
+
+
+ def generate_scibert_embedding(tokenizer, scibert_model, text):
+     """
+     The function for generating SciBERT embeddings based on topic text
+     :param tokenizer: the SciBERT tokenizer
+     :param scibert_model: the SciBERT model
+     :param text: the topic text
+     :return: topic embeddings
+     """
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+     outputs = scibert_model(**inputs)
+     # Use mean pooling for sentence representation
+     embeddings = outputs.last_hidden_state.mean(dim=1).cpu().detach().numpy()
+     return embeddings
+
+
+ @st.cache_data(show_spinner=False)
+ def run_pipeline_model(_model, repo_name, github_token):
+     """
+     The function to generate repo_info by using the pipeline model
+     :param _model: pipeline
+     :param repo_name: the name of the repository
+     :param github_token: GitHub token
+     :return: the information generated by the pipeline
+     """
+     with st.spinner(
+         f"Downloading and extracting the {repo_name}, this may take a while..."
+     ):
+         extracted_infos = _model.preprocess(repo_name, github_token=github_token)
+
+     if not extracted_infos:
+         return None
+
+     with st.spinner(f"Generating embeddings for {repo_name}..."):
+         repo_info = _model.forward(extracted_infos)[0]
+
+     return repo_info
+
+
+ def run_index_search(index, query, search_field, limit):
+     """
+     The function to search the index based on the query and limit
+     :param index: the index
+     :param query: the query document
+     :param search_field: which field to search on
+     :param limit: page limit
+     :return: a dataframe with the search results
+     """
+     top_matches, scores = index.find(
+         query=query, search_field=search_field, limit=limit
+     )
+
+     search_results = top_matches.to_dataframe()
+     search_results["scores"] = scores
+
+     return search_results
+
+
+ def run_cluster_search(repo_clusters, repo_name_list):
+     """
+     The function to look up the cluster number of each given repository.
+     :param repo_clusters: dictionary with repo-clusters
+     :param repo_name_list: list or array of repository names
+     :return: cluster number list
+     """
+     clusters = []
+     for repo_name in repo_name_list:
+         clusters.append(repo_clusters[repo_name])
+     return clusters
+
+
+ def run_similaritycal_search(index, repo_clusters, model, query_doc, query_cluster_number, limit, same_cluster=True):
+     """
+     The function to run the SimilarityCal model.
+     :param index: index file
+     :param repo_clusters: repo-clusters json file
+     :param model: SimilarityCal model
+     :param query_doc: query repo doc
+     :param query_cluster_number: query repo cluster number
+     :param limit: limit
+     :param same_cluster: whether to search within the same cluster only
+     :return: result dataframe
+     """
+     docs = index._docs
+     input_embeddings_list = []
+     result_dl = DocList[RepoDoc]()
+     for doc in docs:
+         if same_cluster and query_cluster_number != repo_clusters[doc.name]:
+             continue
+         if doc.name != query_doc.name:
+             e1, e2 = (torch.Tensor(query_doc.repository_embedding),
+                       torch.Tensor(doc.repository_embedding))
+             input_embeddings = torch.cat([e1, e2])
+             input_embeddings_list.append(input_embeddings)
+             result_dl.append(doc)
+
+     input_embeddings_list = torch.stack(input_embeddings_list).to(device)
+     softmax = nn.Softmax(dim=1).to(device)
+     model_output = model(input_embeddings_list)
+     similarity_scores = softmax(model_output)[:, 1].cpu().detach().numpy()
+     df = result_dl.to_dataframe()
+     df["scores"] = similarity_scores
+     return df.sort_values(by='scores', ascending=False).reset_index(drop=True).head(limit)
+
+
+ if __name__ == "__main__":
+     # Loading dataset and models
+     index, default_doc = load_index()
+     repo_clusters = load_repo_clusters()
+     pipeline_model = load_pipeline_model()
+     lemmatizer = WordNetLemmatizer()
+     tokenizer, scibert_model = load_scibert_model()
+     kmeans = load_kmeans_model()
+     sim_cal_model = load_similaritycal_model()
+
+     # Setting the sidebar
+     with st.sidebar:
+         st.text_input(
+             label="GitHub Token",
+             key="github_token",
+             type="password",
+             placeholder="Paste your GitHub token here",
+             help="Consider setting GitHub token to avoid hitting rate limits: https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token",
+         )
+
+         st.slider(
+             label="Search results limit",
+             min_value=1,
+             max_value=100,
+             value=10,
+             step=1,
+             key="search_results_limit",
+             help="Limit the number of search results",
+         )
+
+         st.multiselect(
+             label="Display columns",
+             options=["scores", "name", "topics", "cluster number", "stars", "license"],
+             default=["scores", "name", "topics", "cluster number", "stars", "license"],
+             help="Select columns to display in the search results",
+             key="display_columns",
+         )
+
+     # Setting the main content
+     st.title("RepoSnipy")
+
+     st.text_input(
+         "Enter a GitHub repository URL or owner/repository (case-sensitive):",
+         value="",
+         max_chars=200,
+         placeholder="numpy/numpy",
+         key="repo_input",
+     )
+
+     st.checkbox(
+         label="Add/Update this repository to the index",
+         value=False,
+         key="update_index",
+         help="Encode the latest version of this repository and add/update it to the index",
+     )
+
+     # Setting the search button
+     search = st.button("Search")
+     # The regular expression for the repository input
+     repo_regex = r"^((git@|http(s)?://)?(github\.com)(/|:))?(?P<owner>[\w.-]+)(/)(?P<repo>[\w.-]+?)(\.git)?(/)?$"
+
+     if search:
+         match_res = re.match(repo_regex, st.session_state.repo_input)
+         # 1. The repository input can be matched
+         if match_res is not None:
+             repo_name = f"{match_res.group('owner')}/{match_res.group('repo')}"
+             records = index.filter({"name": {"$eq": repo_name}})
+             # 1) Building the query information
+             query_doc = default_doc.copy() if not records else records[0]
+             # 2) Recording the cluster number
+             cluster_number = -1 if not records else repo_clusters[repo_name]
+
+             # Case 1 --- the repository information and cluster number need to be (re)computed
+             if st.session_state.update_index or not records:
+                 # 1) Updating the repository information using the RepoSim4Py pipeline
+                 repo_info = run_pipeline_model(pipeline_model, repo_name, st.session_state.github_token)
+                 if repo_info is None:
+                     st.error("Repository not found or invalid GitHub token!")
+                     st.stop()
+
+                 query_doc.name = repo_info["name"]
+                 query_doc.topics = repo_info["topics"]
+                 query_doc.stars = repo_info["stars"]
+                 query_doc.license = repo_info["license"]
+                 query_doc.code_embedding = None if np.all(repo_info["mean_code_embedding"] == 0) else repo_info[
+                     "mean_code_embedding"].reshape(-1)
+                 query_doc.doc_embedding = None if np.all(repo_info["mean_doc_embedding"] == 0) else repo_info[
+                     "mean_doc_embedding"].reshape(-1)
+                 query_doc.readme_embedding = None if np.all(repo_info["mean_readme_embedding"] == 0) else repo_info[
+                     "mean_readme_embedding"].reshape(-1)
+                 query_doc.requirement_embedding = None if np.all(repo_info["mean_requirement_embedding"] == 0) else \
+                     repo_info["mean_requirement_embedding"].reshape(-1)
+                 query_doc.repository_embedding = None if np.all(repo_info["mean_repo_embedding"] == 0) else repo_info[
+                     "mean_repo_embedding"].reshape(-1)
+
+                 # 2) Updating the cluster number
+                 topics_text = ' '.join(
+                     [lemmatizer.lemmatize(topic.lower().replace('-', ' ')) for topic in query_doc.topics])
+                 topic_embeddings = generate_scibert_embedding(tokenizer, scibert_model, topics_text)
+                 cluster_number = int(kmeans.predict(topic_embeddings)[0])
+
+             # Case 2 --- persisting the updated index and repository clusters
+             if st.session_state.update_index:
+                 if not query_doc.license:
+                     st.warning(
+                         "License is missing in this repository and will not be persisted!"
+                     )
+                 elif (query_doc.code_embedding is None) and (query_doc.doc_embedding is None) and (
+                         query_doc.requirement_embedding is None) and (query_doc.readme_embedding is None) and (
+                         query_doc.repository_embedding is None):
+                     st.warning(
+                         "This repository has no useful information (code, docstring, readme or requirements) extracted and will not be persisted!"
+                     )
+                 else:
+                     index.index(query_doc)
+                     repo_clusters[query_doc.name] = cluster_number
+
+                     with st.spinner("Persisting the index and repository clusters..."):
+                         index.persist(str(INDEX_PATH))
+                         with open(CLUSTER_PATH, "w") as file:
+                             json.dump(repo_clusters, file, indent=4)
+                     st.success("Repository updated to the index!")
+
+                     load_index.clear()
+                     load_repo_clusters.clear()
+
+             st.session_state["query_doc"] = query_doc
+             st.session_state["cluster_number"] = cluster_number
+
+         # 2. The repository input cannot be matched
+         else:
+             st.error("Invalid input!")
+
+     # Starting to query
+     if "query_doc" in st.session_state:
+         query_doc = st.session_state.query_doc
+         cluster_number = st.session_state.cluster_number
+         limit = st.session_state.search_results_limit
+
+         # Showing the query repository information
+         st.dataframe(
+             pd.DataFrame(
+                 [
+                     {
+                         "name": query_doc.name,
+                         "topics": query_doc.topics,
+                         "cluster number": cluster_number,
+                         "stars": query_doc.stars,
+                         "license": query_doc.license,
+                     }
+                 ],
+             )
+         )
+
+         display_columns = st.session_state.display_columns
+         code_sim_tab, doc_sim_tab, readme_sim_tab, requirement_sim_tab, repo_sim_tab, same_cluster_tab, diff_cluster_tab = st.tabs(
+             ["Code_sim", "Docstring_sim", "Readme_sim", "Requirement_sim",
+              "Repository_sim", "Same_cluster", "Different_cluster"])
+
+         if query_doc.code_embedding is not None:
+             code_sim_res = run_index_search(index, query_doc, "code_embedding", limit)
+             cluster_numbers = run_cluster_search(repo_clusters, code_sim_res["name"])
+             code_sim_res["cluster number"] = cluster_numbers
+             code_sim_tab.dataframe(code_sim_res[display_columns])
+         else:
+             code_sim_tab.error("No function code was extracted for this repository!")
+
+         if query_doc.doc_embedding is not None:
+             doc_sim_res = run_index_search(index, query_doc, "doc_embedding", limit)
+             cluster_numbers = run_cluster_search(repo_clusters, doc_sim_res["name"])
+             doc_sim_res["cluster number"] = cluster_numbers
+             doc_sim_tab.dataframe(doc_sim_res[display_columns])
+         else:
+             doc_sim_tab.error("No function docstring was extracted for this repository!")
+
+         if query_doc.readme_embedding is not None:
+             readme_sim_res = run_index_search(index, query_doc, "readme_embedding", limit)
+             cluster_numbers = run_cluster_search(repo_clusters, readme_sim_res["name"])
+             readme_sim_res["cluster number"] = cluster_numbers
+             readme_sim_tab.dataframe(readme_sim_res[display_columns])
+         else:
+             readme_sim_tab.error("No readme file was extracted for this repository!")
+
+         if query_doc.requirement_embedding is not None:
+             requirement_sim_res = run_index_search(index, query_doc, "requirement_embedding", limit)
+             cluster_numbers = run_cluster_search(repo_clusters, requirement_sim_res["name"])
+             requirement_sim_res["cluster number"] = cluster_numbers
+             requirement_sim_tab.dataframe(requirement_sim_res[display_columns])
+         else:
+             requirement_sim_tab.error("No requirement file was extracted for this repository!")
+
+         if query_doc.repository_embedding is not None:
+             repo_sim_res = run_index_search(index, query_doc, "repository_embedding", limit)
+             cluster_numbers = run_cluster_search(repo_clusters, repo_sim_res["name"])
+             repo_sim_res["cluster number"] = cluster_numbers
+             repo_sim_tab.dataframe(repo_sim_res[display_columns])
+         else:
+             repo_sim_tab.error("No useful information was extracted for this repository!")
+
+         if cluster_number is not None and query_doc.repository_embedding is not None:
+             same_cluster_df = run_similaritycal_search(index, repo_clusters, sim_cal_model,
+                                                        query_doc, cluster_number, limit,
+                                                        same_cluster=True)
+             diff_cluster_df = run_similaritycal_search(index, repo_clusters, sim_cal_model,
+                                                        query_doc, cluster_number, limit,
+                                                        same_cluster=False)
+             same_cluster_numbers = run_cluster_search(repo_clusters, same_cluster_df["name"])
+             same_cluster_df["cluster number"] = same_cluster_numbers
+
+             diff_cluster_numbers = run_cluster_search(repo_clusters, diff_cluster_df["name"])
+             diff_cluster_df["cluster number"] = diff_cluster_numbers
+
+             same_cluster_tab.dataframe(same_cluster_df[display_columns])
+             diff_cluster_tab.dataframe(diff_cluster_df[display_columns])
+
+         else:
+             same_cluster_tab.error("No useful information was extracted for this repository!")
+             diff_cluster_tab.error("No useful information was extracted for this repository!")
assets/search.gif ADDED

Git LFS Details

  • SHA256: 98ca3ea97923fb15842bef8278d55e9255b36750b03f234c649f93ea06ea7842
  • Pointer size: 132 Bytes
  • Size of remote file: 6.07 MB
data/SimilarityCal_model_NO1.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9146d0736261db38bb6fe6d4d6dd17797c01980be23b114af4b86a18589af632
+ size 102423158
data/index.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3837b4cb3f10cd0ff035201ef44ab655608b2877e5c89efc5cc63a69b666c415
+ size 226172318
data/index_test.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3837b4cb3f10cd0ff035201ef44ab655608b2877e5c89efc5cc63a69b666c415
+ size 226172318
data/kmeans_model_scibert.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b561ee3342b0b8646533e6b7ffd451234d76ce3695862fd17fad18787a3b47c
+ size 967215
data/pair_classifier.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+ from torch import nn
+
+
+ class EmbeddingMLP(nn.Module):
+     """MLP that compresses a repository's concatenated multi-level embedding (768 * size dims) into 300 * size dims."""
+
+     def __init__(self, size=4):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(768 * size, 900 * size),
+             nn.BatchNorm1d(900 * size),
+             nn.ReLU(),
+             nn.Linear(900 * size, 300 * size)
+         )
+
+     def forward(self, data):
+         res = self.net(data)
+         return res
+
+
+ class PairClassifier(nn.Module):
+     """Siamese-style classifier: encodes two repository embeddings with a shared EmbeddingMLP and outputs two logits (index 1 = similar)."""
+
+     def __init__(self, size=4):
+         super().__init__()
+         self.encoder = EmbeddingMLP(size)
+         self.net = nn.Sequential(
+             nn.Linear(300 * size * 2, 3000),
+             nn.ReLU(),
+             nn.Linear(3000, 1000),
+             nn.ReLU(),
+             nn.Linear(1000, 2),
+         )
+
+     def forward(self, data):
+         # data: (batch, 2 * 768 * 4) --- the two repositories' embeddings concatenated along the feature dimension
+         e1 = self.encoder(data[:, :768 * 4])
+         e2 = self.encoder(data[:, 768 * 4:])
+         twins = torch.cat([e1, e2], dim=1)
+         res = self.net(twins)
+         return res
data/repo_clusters.json ADDED
The diff for this file is too large to render. See raw diff
 
data/repo_clusters_test.json ADDED
The diff for this file is too large to render. See raw diff
 
data/repo_doc.py ADDED
@@ -0,0 +1,18 @@
+ from typing import List, Optional
+ from docarray import BaseDoc
+ from docarray.typing import NdArray
+
+
+ class RepoDoc(BaseDoc):
+     """
+     The class representing the basic data structure of an indexed repository.
+     """
+     name: str
+     topics: List[str]
+     stars: int
+     license: str
+     code_embedding: Optional[NdArray[768]]
+     doc_embedding: Optional[NdArray[768]]
+     readme_embedding: Optional[NdArray[768]]
+     requirement_embedding: Optional[NdArray[768]]
+     repository_embedding: Optional[NdArray[3072]]
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ accelerate
+ docarray
+ pandas
+ numpy
+ streamlit
+ torch
+ transformers
+ tqdm
+ scikit-learn
+ nltk
+ plotly
+ joblib