import re
import json
import nltk
import joblib
import torch
import pandas as pd
import numpy as np
import streamlit as st
from pathlib import Path
from torch import nn
from docarray import DocList
from docarray.index import InMemoryExactNNIndex
from transformers import pipeline, AutoTokenizer, AutoModel
from data.repo_doc import RepoDoc
from data.pair_classifier import PairClassifier
from nltk.stem import WordNetLemmatizer

nltk.download("wordnet")
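
# WordNet data is required by the WordNetLemmatizer used below to normalise
# repository topics before they are embedded with SciBERT and assigned to a
# KMeans cluster.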

KMEANS_MODEL_PATH = Path(__file__).parent.joinpath("data/kmeans_model_scibert.pkl")
SIMILARITY_CAL_MODEL_PATH = Path(__file__).parent.joinpath("data/SimilarityCal_model_NO1.pt")

# Pick the best available device: CUDA, then Apple MPS, then CPU
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

SCIBERT_MODEL_PATH = "allenai/scibert_scivocab_uncased"

INDEX_PATH = Path(__file__).parent.joinpath("data/index_test.bin")
CLUSTER_PATH = Path(__file__).parent.joinpath("data/repo_clusters_test.json")


@st.cache_resource(show_spinner="Loading repositories basic information...")
def load_index():
    """
    Load the index file and build a RepoDoc object with default values.

    :return: the loaded index and a default RepoDoc
    """
    default_doc = RepoDoc(
        name="",
        topics=[],
        stars=0,
        license="",
        code_embedding=None,
        doc_embedding=None,
        readme_embedding=None,
        requirement_embedding=None,
        repository_embedding=None,
    )

    return InMemoryExactNNIndex[RepoDoc](index_file_path=INDEX_PATH), default_doc


@st.cache_resource(show_spinner="Loading repositories clusters...")
def load_repo_clusters():
    """
    Load the repository-clusters file.

    :return: a dictionary mapping repository names to cluster numbers
    """
    with open(CLUSTER_PATH, "r") as file:
        repo_clusters = json.load(file)

    return repo_clusters
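
# The clusters file is a flat JSON object keyed by "owner/repo"; an illustrative
# entry (hypothetical value): {"numpy/numpy": 3}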


@st.cache_resource(show_spinner="Loading RepoSim4Py pipeline model...")
def load_pipeline_model():
    """
    Load the RepoSim4Py pipeline model.

    :return: a HuggingFace pipeline
    """
    model_path = "Henry65/RepoSim4Py"

    return pipeline(
        model=model_path,
        trust_remote_code=True,
        device_map="auto"
    )


@st.cache_resource(show_spinner="Loading SciBERT model...")
def load_scibert_model():
    """
    Load the SciBERT tokenizer and model.

    :return: the tokenizer and the model
    """
    tokenizer = AutoTokenizer.from_pretrained(SCIBERT_MODEL_PATH)
    scibert_model = AutoModel.from_pretrained(SCIBERT_MODEL_PATH).to(device)
    return tokenizer, scibert_model


@st.cache_resource(show_spinner="Loading KMeans model...")
def load_kmeans_model():
    """
    Load the pretrained KMeans model.

    :return: a KMeans model
    """
    return joblib.load(KMEANS_MODEL_PATH)


@st.cache_resource(show_spinner="Loading SimilarityCal model...")
def load_similaritycal_model():
    """
    Load the SimilarityCal pair classifier and switch it to evaluation mode.

    :return: the SimilarityCal model
    """
    sim_cal_model = PairClassifier()
    # map_location lets the checkpoint load even when it was saved on a
    # different device than the current one
    sim_cal_model.load_state_dict(torch.load(SIMILARITY_CAL_MODEL_PATH, map_location=device))
    sim_cal_model = sim_cal_model.to(device)
    sim_cal_model.eval()
    return sim_cal_model


def generate_scibert_embedding(tokenizer, scibert_model, text):
    """
    Generate a SciBERT embedding for the given topic text.

    :param tokenizer: the SciBERT tokenizer
    :param scibert_model: the SciBERT model
    :param text: the topic text
    :return: the topic embedding, mean-pooled over tokens
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = scibert_model(**inputs)

    # Mean-pool the last hidden states over the token dimension
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embeddings
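
# Illustrative usage (names as defined above):
#   tokenizer, scibert_model = load_scibert_model()
#   vec = generate_scibert_embedding(tokenizer, scibert_model, "machine learning")
#   vec.shape  # (1, 768) for SciBERT-base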


@st.cache_data(show_spinner=False)
def run_pipeline_model(_model, repo_name, github_token):
    """
    Generate repository information using the pipeline model.

    :param _model: the pipeline
    :param repo_name: the name of the repository
    :param github_token: a GitHub token
    :return: the information generated by the pipeline, or None if extraction failed
    """
    with st.spinner(
        f"Downloading and extracting {repo_name}, this may take a while..."
    ):
        extracted_infos = _model.preprocess(repo_name, github_token=github_token)

    if not extracted_infos:
        return None

    with st.spinner(f"Generating embeddings for {repo_name}..."):
        repo_info = _model.forward(extracted_infos)[0]

    return repo_info


def run_index_search(index, query, search_field, limit):
    """
    Search the index with the given query.

    :param index: the index
    :param query: the query document
    :param search_field: which embedding field to search on
    :param limit: maximum number of results
    :return: a dataframe with the search results
    """
    top_matches, scores = index.find(
        query=query, search_field=search_field, limit=limit
    )

    search_results = top_matches.to_dataframe()
    search_results["scores"] = scores

    return search_results
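
# Note: InMemoryExactNNIndex scores matches with cosine similarity by default,
# so higher scores mean more similar repositories (assuming the default metric
# has not been overridden).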


def run_cluster_search(repo_clusters, repo_name_list):
    """
    Look up the cluster number of each given repository.

    :param repo_clusters: dictionary mapping repository names to cluster numbers
    :param repo_name_list: list or array of repository names
    :return: a list of cluster numbers
    """
    return [repo_clusters[repo_name] for repo_name in repo_name_list]


def run_similaritycal_search(index, repo_clusters, model, query_doc, query_cluster_number, limit, same_cluster=True):
    """
    Run the SimilarityCal model against every candidate repository in the index.

    :param index: the index
    :param repo_clusters: dictionary mapping repository names to cluster numbers
    :param model: the SimilarityCal model
    :param query_doc: the query RepoDoc
    :param query_cluster_number: the cluster number of the query repository
    :param limit: maximum number of results
    :param same_cluster: if True, only repositories in the query's cluster are
        considered; if False, repositories from all clusters are considered
    :return: a dataframe of results sorted by similarity score
    """
    docs = index._docs
    input_embeddings_list = []
    result_dl = DocList[RepoDoc]()
    for doc in docs:
        if same_cluster and query_cluster_number != repo_clusters[doc.name]:
            continue
        if doc.name != query_doc.name:
            # The classifier takes the query and candidate repository embeddings
            # concatenated into a single input vector
            e1, e2 = (torch.Tensor(query_doc.repository_embedding),
                      torch.Tensor(doc.repository_embedding))
            input_embeddings = torch.cat([e1, e2])
            input_embeddings_list.append(input_embeddings)
            result_dl.append(doc)

    input_embeddings_list = torch.stack(input_embeddings_list).to(device)
    softmax = nn.Softmax(dim=1).to(device)
    with torch.no_grad():
        model_output = model(input_embeddings_list)
    # Column 1 of the softmax output is read as the probability that the pair is similar
    similarity_scores = softmax(model_output)[:, 1].cpu().numpy()
    df = result_dl.to_dataframe()
    df["scores"] = similarity_scores
    return df.sort_values(by="scores", ascending=False).reset_index(drop=True).head(limit)
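
# Illustrative usage (cluster number hypothetical):
#   df = run_similaritycal_search(index, repo_clusters, sim_cal_model, query_doc, 3, 10)
#   df[["name", "scores"]].head()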


if __name__ == "__main__":
    # Load all cached resources up front
    index, default_doc = load_index()
    repo_clusters = load_repo_clusters()
    pipeline_model = load_pipeline_model()
    lemmatizer = WordNetLemmatizer()
    tokenizer, scibert_model = load_scibert_model()
    kmeans = load_kmeans_model()
    sim_cal_model = load_similaritycal_model()

    with st.sidebar:
        st.text_input(
            label="GitHub Token",
            key="github_token",
            type="password",
            placeholder="Paste your GitHub token here",
            help="Consider setting a GitHub token to avoid hitting rate limits: "
                 "https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token",
        )

        st.slider(
            label="Search results limit",
            min_value=1,
            max_value=100,
            value=10,
            step=1,
            key="search_results_limit",
            help="Limit the number of search results",
        )

        st.multiselect(
            label="Display columns",
            options=["scores", "name", "topics", "cluster number", "stars", "license"],
            default=["scores", "name", "topics", "cluster number", "stars", "license"],
            help="Select columns to display in the search results",
            key="display_columns",
        )

    st.title("RepoSnipy")

    st.text_input(
        "Enter a GitHub repository URL or owner/repository (case-sensitive):",
        value="",
        max_chars=200,
        placeholder="numpy/numpy",
        key="repo_input",
    )

    st.checkbox(
        label="Add/Update this repository to the index",
        value=False,
        key="update_index",
        help="Encode the latest version of this repository and add/update it in the index",
    )

    search = st.button("Search")

    repo_regex = r"^((git@|http(s)?://)?(github\.com)(/|:))?(?P<owner>[\w.-]+)(/)(?P<repo>[\w.-]+?)(\.git)?(/)?$"
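    # Illustrative inputs the pattern accepts:
    #   numpy/numpy
    #   https://github.com/numpy/numpy
    #   git@github.com:numpy/numpy.git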

    if search:
        match_res = re.match(repo_regex, st.session_state.repo_input)

        if match_res is not None:
            repo_name = f"{match_res.group('owner')}/{match_res.group('repo')}"
            records = index.filter({"name": {"$eq": repo_name}})

            query_doc = default_doc.copy() if not records else records[0]
            cluster_number = -1 if not records else repo_clusters[repo_name]

            if st.session_state.update_index or not records:
                repo_info = run_pipeline_model(pipeline_model, repo_name, st.session_state.github_token)
                if repo_info is None:
                    st.error("Repository not found or invalid GitHub token!")
                    st.stop()

                # An all-zero embedding means the pipeline extracted nothing
                # for that field, so store None instead
                query_doc.name = repo_info["name"]
                query_doc.topics = repo_info["topics"]
                query_doc.stars = repo_info["stars"]
                query_doc.license = repo_info["license"]
                query_doc.code_embedding = (
                    None
                    if np.all(repo_info["mean_code_embedding"] == 0)
                    else repo_info["mean_code_embedding"].reshape(-1)
                )
                query_doc.doc_embedding = (
                    None
                    if np.all(repo_info["mean_doc_embedding"] == 0)
                    else repo_info["mean_doc_embedding"].reshape(-1)
                )
                query_doc.readme_embedding = (
                    None
                    if np.all(repo_info["mean_readme_embedding"] == 0)
                    else repo_info["mean_readme_embedding"].reshape(-1)
                )
                query_doc.requirement_embedding = (
                    None
                    if np.all(repo_info["mean_requirement_embedding"] == 0)
                    else repo_info["mean_requirement_embedding"].reshape(-1)
                )
                query_doc.repository_embedding = (
                    None
                    if np.all(repo_info["mean_repo_embedding"] == 0)
                    else repo_info["mean_repo_embedding"].reshape(-1)
                )

                # Lemmatize the topics, embed them with SciBERT, and assign the
                # repository to the nearest KMeans cluster
                topics_text = " ".join(
                    [lemmatizer.lemmatize(topic.lower().replace("-", " ")) for topic in query_doc.topics]
                )
                topic_embeddings = generate_scibert_embedding(tokenizer, scibert_model, topics_text)
                cluster_number = int(kmeans.predict(topic_embeddings)[0])

            if st.session_state.update_index:
                if not query_doc.license:
                    st.warning(
                        "This repository has no license, so it will not be persisted!"
                    )
                elif (
                    query_doc.code_embedding is None
                    and query_doc.doc_embedding is None
                    and query_doc.readme_embedding is None
                    and query_doc.requirement_embedding is None
                    and query_doc.repository_embedding is None
                ):
                    st.warning(
                        "No useful information (code, docstrings, readme, or requirements) "
                        "could be extracted from this repository, so it will not be persisted!"
                    )
                else:
                    index.index(query_doc)
                    repo_clusters[query_doc.name] = cluster_number

                    with st.spinner("Persisting the index and repository clusters..."):
                        index.persist(str(INDEX_PATH))
                        with open(CLUSTER_PATH, "w") as file:
                            json.dump(repo_clusters, file, indent=4)
                    st.success("Repository added/updated in the index!")

                    # Clear the cached resources so the next run reloads the
                    # freshly persisted index and clusters
                    load_index.clear()
                    load_repo_clusters.clear()

            st.session_state["query_doc"] = query_doc
            st.session_state["cluster_number"] = cluster_number

        else:
            st.error("Invalid input!")

    if "query_doc" in st.session_state:
        query_doc = st.session_state.query_doc
        cluster_number = st.session_state.cluster_number
        limit = st.session_state.search_results_limit

        # Show the query repository itself first
        st.dataframe(
            pd.DataFrame(
                [
                    {
                        "name": query_doc.name,
                        "topics": query_doc.topics,
                        "cluster number": cluster_number,
                        "stars": query_doc.stars,
                        "license": query_doc.license,
                    }
                ],
            )
        )

        display_columns = st.session_state.display_columns
        code_sim_tab, doc_sim_tab, readme_sim_tab, requirement_sim_tab, repo_sim_tab, same_cluster_tab, diff_cluster_tab = st.tabs(
            ["Code_sim", "Docstring_sim", "Readme_sim", "Requirement_sim",
             "Repository_sim", "Same_cluster", "Different_cluster"])

        if query_doc.code_embedding is not None:
            code_sim_res = run_index_search(index, query_doc, "code_embedding", limit)
            cluster_numbers = run_cluster_search(repo_clusters, code_sim_res["name"])
            code_sim_res["cluster number"] = cluster_numbers
            code_sim_tab.dataframe(code_sim_res[display_columns])
        else:
            code_sim_tab.error("No function code was extracted from this repository!")

        if query_doc.doc_embedding is not None:
            doc_sim_res = run_index_search(index, query_doc, "doc_embedding", limit)
            cluster_numbers = run_cluster_search(repo_clusters, doc_sim_res["name"])
            doc_sim_res["cluster number"] = cluster_numbers
            doc_sim_tab.dataframe(doc_sim_res[display_columns])
        else:
            doc_sim_tab.error("No function docstring was extracted from this repository!")

        if query_doc.readme_embedding is not None:
            readme_sim_res = run_index_search(index, query_doc, "readme_embedding", limit)
            cluster_numbers = run_cluster_search(repo_clusters, readme_sim_res["name"])
            readme_sim_res["cluster number"] = cluster_numbers
            readme_sim_tab.dataframe(readme_sim_res[display_columns])
        else:
            readme_sim_tab.error("No readme file was extracted from this repository!")

        if query_doc.requirement_embedding is not None:
            requirement_sim_res = run_index_search(index, query_doc, "requirement_embedding", limit)
            cluster_numbers = run_cluster_search(repo_clusters, requirement_sim_res["name"])
            requirement_sim_res["cluster number"] = cluster_numbers
            requirement_sim_tab.dataframe(requirement_sim_res[display_columns])
        else:
            requirement_sim_tab.error("No requirements file was extracted from this repository!")

        if query_doc.repository_embedding is not None:
            repo_sim_res = run_index_search(index, query_doc, "repository_embedding", limit)
            cluster_numbers = run_cluster_search(repo_clusters, repo_sim_res["name"])
            repo_sim_res["cluster number"] = cluster_numbers
            repo_sim_tab.dataframe(repo_sim_res[display_columns])
        else:
            repo_sim_tab.error("No useful information was extracted from this repository!")

        if cluster_number is not None and query_doc.repository_embedding is not None:
            same_cluster_df = run_similaritycal_search(index, repo_clusters, sim_cal_model,
                                                       query_doc, cluster_number, limit,
                                                       same_cluster=True)
            diff_cluster_df = run_similaritycal_search(index, repo_clusters, sim_cal_model,
                                                       query_doc, cluster_number, limit,
                                                       same_cluster=False)
            same_cluster_numbers = run_cluster_search(repo_clusters, same_cluster_df["name"])
            same_cluster_df["cluster number"] = same_cluster_numbers

            diff_cluster_numbers = run_cluster_search(repo_clusters, diff_cluster_df["name"])
            diff_cluster_df["cluster number"] = diff_cluster_numbers

            same_cluster_tab.dataframe(same_cluster_df[display_columns])
            diff_cluster_tab.dataframe(diff_cluster_df[display_columns])
        else:
            same_cluster_tab.error("No useful information was extracted from this repository!")
            diff_cluster_tab.error("No useful information was extracted from this repository!")