Lazyhope committed
Commit 71a5415
1 Parent(s): a0e0d2b

Upload the first complete version of RepoSnipy

Files changed (3)
  1. .gitignore +3 -0
  2. app.py +187 -0
  3. requirements.txt +4 -0
.gitignore CHANGED
@@ -158,3 +158,6 @@ cython_debug/
  # and can be added to the global gitignore or merged into this file. For a more nuclear
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
  #.idea/
+
+ # Streamlit configs
+ .streamlit/
app.py ADDED
@@ -0,0 +1,187 @@
+ import re
+ from pathlib import Path
+ from typing import List, Optional
+
+ import pandas as pd
+ import streamlit as st
+ from docarray import BaseDoc, DocList
+ from docarray.typing import TorchTensor
+ from docarray.utils.find import find
+ from transformers import pipeline
+
+ DATASET_PATH = Path(__file__).parent.joinpath("data/index.bin")
+
+
+ @st.cache_resource(show_spinner="Loading dataset...")
+ def load_index():
+     class RepoDoc(BaseDoc):
+         name: str
+         topics: list  # TODO: List[str]
+         stars: int
+         license: str
+         code_embedding: Optional[TorchTensor[768]]
+         doc_embedding: Optional[TorchTensor[768]]
+
+     return DocList[RepoDoc].load_binary(DATASET_PATH)
+
+
+ @st.cache_resource(show_spinner="Loading RepoSim pipeline...")
+ def load_model():
+     return pipeline(
+         model="Lazyhope/RepoSim",
+         trust_remote_code=True,
+         device_map="auto",
+         use_auth_token=st.secrets.hf_token,  # TODO: delete this line when the pipeline is public
+     )
+
+
+ @st.cache_data(show_spinner=False)
+ def run_model(_model, repo_name, github_token):
+     with st.spinner(
+         f"Downloading and extracting the {repo_name}, this may take a while..."
+     ):
+         extracted_infos = _model.preprocess(repo_name, github_token=github_token)
+
+     if not extracted_infos:
+         return None
+
+     with st.spinner(f"Generating embeddings for {repo_name}..."):
+         repo_info = _model.forward(extracted_infos, st_progress=st.progress(0.0))[0]
+
+     return repo_info
+
+
+ def run_search(index, query, search_field, limit):
+     top_matches, scores = find(
+         index=index, query=query, search_field=search_field, limit=limit
+     )
+
+     search_results = top_matches.to_dataframe()
+     search_results["scores"] = scores
+
+     return search_results
+
+
+ index = load_index()
+ model = load_model()
+
+ with st.sidebar:
+     st.text_input(
+         label="GitHub Token",
+         key="github_token",
+         type="password",
+         placeholder="Paste your GitHub token here",
+         help="Consider setting GitHub token to avoid hitting rate limits: https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token",
+     )
+
+     st.slider(
+         label="Search results limit",
+         min_value=1,
+         max_value=100,
+         value=10,
+         step=1,
+         key="search_results_limit",
+     )
+
+     st.multiselect(
+         label="Display columns",
+         options=["scores", "name", "topics", "stars", "license"],
+         default=["scores", "name", "topics"],
+         key="display_columns",
+     )
+
+
+ repo_regex = r"^((git@|http(s)?://)?(github\.com)(/|:))?(?P<owner>[\w.-]+)(/)(?P<repo>[\w.-]+?)(\.git)?(/)?$"
+
+ st.title("RepoSnipy")
+
+ st.text_input(
+     "Enter a GitHub repo URL or owner/repo (case-sensitive):",
+     value="",
+     max_chars=200,
+     placeholder="huggingface/transformers",
+     key="repo_input",
+ )
+
+ st.checkbox(
+     label="Add/Update this repo to the index",
+     value=False,
+     key="update_index",
+     help="Update index by generating embeddings for the latest version of this repo",
+ )
+
+
+ search = st.button("Search")
+ if search:
+     match_res = re.match(repo_regex, st.session_state.repo_input)
+     if match_res is not None:
+         repo_name = f"{match_res.group('owner')}/{match_res.group('repo')}"
+
+         doc_index = -1
+         update_index = st.session_state.update_index
+         try:
+             doc_index = index.name.index(repo_name)
+             assert update_index is False
+
+             repo_doc = index[doc_index]
+         except (ValueError, AssertionError):
+             repo_info = run_model(model, repo_name, st.session_state.github_token)
+             if repo_info is None:
+                 st.error("Repo not found or invalid GitHub token!")
+                 st.stop()
+
+             repo_doc = index.doc_type(
+                 name=repo_info["name"],
+                 topics=repo_info["topics"],
+                 stars=repo_info["stars"],
+                 license=repo_info["license"],
+                 code_embedding=repo_info["mean_code_embedding"],
+                 doc_embedding=repo_info["mean_doc_embedding"],
+             )
+
+         if update_index:
+             if not repo_doc.license:
+                 st.warning("License is missing in this repo!")
+
+             if doc_index == -1:
+                 index.append(repo_doc)
+                 st.success("Repo added to the index!")
+             else:
+                 index[doc_index] = repo_doc
+                 st.success("Repo updated in the index!")
+
+         st.session_state["query"] = repo_doc
+     else:
+         st.error("Invalid input!")
+
+ if "query" in st.session_state:
+     query = st.session_state.query
+
+     limit = st.session_state.search_results_limit
+     st.dataframe(
+         pd.DataFrame(
+             [
+                 {
+                     "name": query.name,
+                     "topics": query.topics,
+                     "stars": query.stars,
+                     "license": query.license,
+                 }
+             ],
+         )
+     )
+
+     display_columns = st.session_state.display_columns
+     code_sim_tab, doc_sim_tab = st.tabs(["Code Similarity", "Docstring Similarity"])
+
+     if query.code_embedding is not None:
+         code_sim_res = run_search(index, query, "code_embedding", limit)
+         code_sim_tab.dataframe(code_sim_res[display_columns])
+     else:
+         code_sim_tab.error("No code was extracted for this repo!")
+
+     if query.doc_embedding is not None:
+         doc_sim_res = run_search(index, query, "doc_embedding", limit)
+         doc_sim_tab.dataframe(doc_sim_res[display_columns])
+     else:
+         doc_sim_tab.error("No docstring was extracted for this repo!")
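
A note on the repo_regex defined in app.py above: it accepts both bare owner/repo names and full GitHub URLs (HTTPS or SSH). The sketch below is illustrative only and not part of the commit; the example inputs are assumptions, and each should normalize to the owner/repo form that app.py uses as the index key.

    import re

    repo_regex = r"^((git@|http(s)?://)?(github\.com)(/|:))?(?P<owner>[\w.-]+)(/)(?P<repo>[\w.-]+?)(\.git)?(/)?$"

    # Example inputs are assumptions for illustration; each prints "huggingface/transformers".
    for candidate in [
        "huggingface/transformers",
        "https://github.com/huggingface/transformers",
        "git@github.com:huggingface/transformers.git",
    ]:
        match_res = re.match(repo_regex, candidate)
        if match_res is not None:
            print(f"{match_res.group('owner')}/{match_res.group('repo')}")
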
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ accelerate
+ docarray
+ torch
+ transformers
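
Worth noting: app.py loads its index from data/index.bin (DATASET_PATH), but that file is not among the three files in this commit, and streamlit and pandas are imported without being listed in requirements.txt, presumably because the hosting runtime provides them. Below is a hypothetical bootstrap sketch, not part of the commit, that writes an empty index with the same RepoDoc schema so load_index() has something to read; it assumes docarray's DocList.save_binary as the counterpart of the load_binary call used in app.py.

    # Hypothetical helper, not part of this commit.
    from pathlib import Path
    from typing import Optional

    from docarray import BaseDoc, DocList
    from docarray.typing import TorchTensor


    class RepoDoc(BaseDoc):
        # Mirrors the schema defined inside load_index() in app.py
        name: str
        topics: list
        stars: int
        license: str
        code_embedding: Optional[TorchTensor[768]]
        doc_embedding: Optional[TorchTensor[768]]


    out_path = Path(__file__).parent.joinpath("data/index.bin")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    DocList[RepoDoc]().save_binary(out_path)  # empty index that load_index() can read back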