Lazyhope committed on
Commit
18247d7
·
1 Parent(s): 71a5415

Use docarray.index.InMemoryExactNNIndex instead of DocList as index

Browse files
Files changed (1) hide show
  1. app.py +62 -50
app.py CHANGED
@@ -4,25 +4,34 @@ from typing import List, Optional
4
 
5
  import pandas as pd
6
  import streamlit as st
7
- from docarray import BaseDoc, DocList
8
  from docarray.typing import TorchTensor
9
- from docarray.utils.find import find
10
  from transformers import pipeline
11
 
12
- DATASET_PATH = Path(__file__).parent.joinpath("data/index.bin")
13
 
14
 
15
  @st.cache_resource(show_spinner="Loading dataset...")
16
  def load_index():
17
  class RepoDoc(BaseDoc):
18
  name: str
19
- topics: list # TODO: List[str]
20
  stars: int
21
  license: str
22
  code_embedding: Optional[TorchTensor[768]]
23
  doc_embedding: Optional[TorchTensor[768]]
24
 
25
- return DocList[RepoDoc].load_binary(DATASET_PATH)
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  @st.cache_resource(show_spinner="Loading RepoSim pipeline...")
@@ -31,7 +40,6 @@ def load_model():
31
  model="Lazyhope/RepoSim",
32
  trust_remote_code=True,
33
  device_map="auto",
34
- use_auth_token=st.secrets.hf_token, # TODO: delete this line when the pipeline is public
35
  )
36
 
37
 
@@ -52,8 +60,8 @@ def run_model(_model, repo_name, github_token):
52
 
53
 
54
  def run_search(index, query, search_field, limit):
55
- top_matches, scores = find(
56
- index=index, query=query, search_field=search_field, limit=limit
57
  )
58
 
59
  search_results = top_matches.to_dataframe()
@@ -62,7 +70,7 @@ def run_search(index, query, search_field, limit):
62
  return search_results
63
 
64
 
65
- index = load_index()
66
  model = load_model()
67
 
68
  with st.sidebar:
@@ -81,12 +89,14 @@ with st.sidebar:
81
  value=10,
82
  step=1,
83
  key="search_results_limit",
 
84
  )
85
 
86
  st.multiselect(
87
  label="Display columns",
88
  options=["scores", "name", "topics", "stars", "license"],
89
- default=["scores", "name", "topics"],
 
90
  key="display_columns",
91
  )
92
 
@@ -107,7 +117,7 @@ st.checkbox(
107
  label="Add/Update this repo to the index",
108
  value=False,
109
  key="update_index",
110
- help="Update index by generating embeddings for the latest version of this repo",
111
  )
112
 
113
 
@@ -117,55 +127,57 @@ if search:
117
  if match_res is not None:
118
  repo_name = f"{match_res.group('owner')}/{match_res.group('repo')}"
119
 
120
- doc_index = -1
121
- update_index = st.session_state.update_index
122
- try:
123
- doc_index = index.name.index(repo_name)
124
- assert update_index is False
125
-
126
- repo_doc = index[doc_index]
127
- except (ValueError, AssertionError):
128
  repo_info = run_model(model, repo_name, st.session_state.github_token)
129
  if repo_info is None:
130
  st.error("Repo not found or invalid GitHub token!")
131
  st.stop()
132
 
133
- repo_doc = index.doc_type(
134
- name=repo_info["name"],
135
- topics=repo_info["topics"],
136
- stars=repo_info["stars"],
137
- license=repo_info["license"],
138
- code_embedding=repo_info["mean_code_embedding"],
139
- doc_embedding=repo_info["mean_doc_embedding"],
140
- )
141
-
142
- if update_index:
143
- if not repo_doc.license:
144
- st.warning("License is missing in this repo!")
145
-
146
- if doc_index == -1:
147
- index.append(repo_doc)
148
- st.success("Repo added to the index!")
 
 
 
 
 
 
 
149
  else:
150
- index[doc_index] = repo_doc
151
  st.success("Repo updated in the index!")
152
 
153
- st.session_state["query"] = repo_doc
 
 
154
  else:
155
  st.error("Invalid input!")
156
 
157
- if "query" in st.session_state:
158
- query = st.session_state.query
159
-
160
  limit = st.session_state.search_results_limit
161
  st.dataframe(
162
  pd.DataFrame(
163
  [
164
  {
165
- "name": query.name,
166
- "topics": query.topics,
167
- "stars": query.stars,
168
- "license": query.license,
169
  }
170
  ],
171
  )
@@ -174,14 +186,14 @@ if "query" in st.session_state:
174
  display_columns = st.session_state.display_columns
175
  code_sim_tab, doc_sim_tab = st.tabs(["Code Similarity", "Docstring Similarity"])
176
 
177
- if query.code_embedding is not None:
178
- code_sim_res = run_search(index, query, "code_embedding", limit)
179
  code_sim_tab.dataframe(code_sim_res[display_columns])
180
  else:
181
- code_sim_tab.error("No code was extracted for this repo!")
182
 
183
- if query.doc_embedding is not None:
184
- doc_sim_res = run_search(index, query, "doc_embedding", limit)
185
  doc_sim_tab.dataframe(doc_sim_res[display_columns])
186
  else:
187
- doc_sim_tab.error("No docstring was extracted for this repo!")
 
4
 
5
  import pandas as pd
6
  import streamlit as st
7
+ from docarray import BaseDoc
8
  from docarray.typing import TorchTensor
9
+ from docarray.index import InMemoryExactNNIndex
10
  from transformers import pipeline
11
 
12
+ INDEX_PATH = Path(__file__).parent.joinpath("data/index.bin")
13
 
14
 
15
  @st.cache_resource(show_spinner="Loading dataset...")
16
  def load_index():
17
  class RepoDoc(BaseDoc):
18
  name: str
19
+ topics: list # List[str]
20
  stars: int
21
  license: str
22
  code_embedding: Optional[TorchTensor[768]]
23
  doc_embedding: Optional[TorchTensor[768]]
24
 
25
+ default_doc = RepoDoc(
26
+ name="",
27
+ topics=[],
28
+ stars=0,
29
+ license="",
30
+ code_embedding=None,
31
+ doc_embedding=None,
32
+ )
33
+
34
+ return InMemoryExactNNIndex[RepoDoc](index_file_path=INDEX_PATH), default_doc
35
 
36
 
37
  @st.cache_resource(show_spinner="Loading RepoSim pipeline...")
 
40
  model="Lazyhope/RepoSim",
41
  trust_remote_code=True,
42
  device_map="auto",
 
43
  )
44
 
45
 
 
60
 
61
 
62
  def run_search(index, query, search_field, limit):
63
+ top_matches, scores = index.find(
64
+ query=query, search_field=search_field, limit=limit
65
  )
66
 
67
  search_results = top_matches.to_dataframe()
 
70
  return search_results
71
 
72
 
73
+ index, default_doc = load_index()
74
  model = load_model()
75
 
76
  with st.sidebar:
 
89
  value=10,
90
  step=1,
91
  key="search_results_limit",
92
+ help="Limit the number of search results",
93
  )
94
 
95
  st.multiselect(
96
  label="Display columns",
97
  options=["scores", "name", "topics", "stars", "license"],
98
+ default=["scores", "name", "topics", "stars", "license"],
99
+ help="Select columns to display in the search results",
100
  key="display_columns",
101
  )
102
 
 
117
  label="Add/Update this repo to the index",
118
  value=False,
119
  key="update_index",
120
+ help="Encode the latest version of this repo and add/update it to the index",
121
  )
122
 
123
 
 
127
  if match_res is not None:
128
  repo_name = f"{match_res.group('owner')}/{match_res.group('repo')}"
129
 
130
+ records = index.filter({"name": {"$eq": repo_name}})
131
+ query_doc = default_doc.copy() if not records else records[0]
132
+ if st.session_state.update_index or not records:
 
 
 
 
 
133
  repo_info = run_model(model, repo_name, st.session_state.github_token)
134
  if repo_info is None:
135
  st.error("Repo not found or invalid GitHub token!")
136
  st.stop()
137
 
138
+ # Update document inplace
139
+ query_doc.name = repo_info["name"]
140
+ query_doc.topics = repo_info["topics"]
141
+ query_doc.stars = repo_info["stars"]
142
+ query_doc.license = repo_info["license"]
143
+ query_doc.code_embedding = repo_info["mean_code_embedding"]
144
+ query_doc.doc_embedding = repo_info["mean_doc_embedding"]
145
+
146
+ if st.session_state.update_index:
147
+ if not records:
148
+ if not query_doc.license:
149
+ st.warning(
150
+ "License is missing in this repo and will not be persisted!"
151
+ )
152
+ elif (
153
+ query_doc.code_embedding is None and query_doc.doc_embedding is None
154
+ ):
155
+ st.warning(
156
+ "This repo has no function code or docstring extracted and will not be persisted!"
157
+ )
158
+ else:
159
+ index.index(query_doc)
160
+ st.success("Repo added to the index!")
161
  else:
 
162
  st.success("Repo updated in the index!")
163
 
164
+ index.persist(file=INDEX_PATH)
165
+
166
+ st.session_state["query_doc"] = query_doc
167
  else:
168
  st.error("Invalid input!")
169
 
170
+ if "query_doc" in st.session_state:
171
+ query_doc = st.session_state.query_doc
 
172
  limit = st.session_state.search_results_limit
173
  st.dataframe(
174
  pd.DataFrame(
175
  [
176
  {
177
+ "name": query_doc.name,
178
+ "topics": query_doc.topics,
179
+ "stars": query_doc.stars,
180
+ "license": query_doc.license,
181
  }
182
  ],
183
  )
 
186
  display_columns = st.session_state.display_columns
187
  code_sim_tab, doc_sim_tab = st.tabs(["Code Similarity", "Docstring Similarity"])
188
 
189
+ if query_doc.code_embedding is not None:
190
+ code_sim_res = run_search(index, query_doc, "code_embedding", limit)
191
  code_sim_tab.dataframe(code_sim_res[display_columns])
192
  else:
193
+ code_sim_tab.error("No function code was extracted for this repo!")
194
 
195
+ if query_doc.doc_embedding is not None:
196
+ doc_sim_res = run_search(index, query_doc, "doc_embedding", limit)
197
  doc_sim_tab.dataframe(doc_sim_res[display_columns])
198
  else:
199
+ doc_sim_tab.error("No function docstring was extracted for this repo!")