Spaces:
Running
Running
Germano Cavalcante
commited on
Commit
·
f92bafd
1
Parent(s):
17b9710
Find Related: Support to closed issues
Browse files- routers/__init__.py +1 -0
- routers/tool_find_related.py +113 -151
- routers/tool_find_related_cache.pkl +2 -2
routers/__init__.py
CHANGED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# __init__.py
|
routers/tool_find_related.py
CHANGED
@@ -5,7 +5,9 @@ import pickle
|
|
5 |
import re
|
6 |
import torch
|
7 |
import threading
|
|
|
8 |
from datetime import datetime, timedelta
|
|
|
9 |
from sentence_transformers import SentenceTransformer, util
|
10 |
from fastapi import APIRouter
|
11 |
|
@@ -53,6 +55,7 @@ class EmbeddingContext:
|
|
53 |
# These don't change
|
54 |
TOKEN_LEN_MAX_FOR_EMBEDDING = 512
|
55 |
TOKEN_LEN_MAX_BALCKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
|
|
|
56 |
issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
|
57 |
cache_path = "routers/tool_find_related_cache.pkl"
|
58 |
|
@@ -62,6 +65,9 @@ class EmbeddingContext:
|
|
62 |
openai_client = None
|
63 |
model_name = ''
|
64 |
config_type = ''
|
|
|
|
|
|
|
65 |
|
66 |
# Updates constantly
|
67 |
data = {}
|
@@ -102,6 +108,11 @@ class EmbeddingContext:
|
|
102 |
self.model_name = model_name
|
103 |
self.config_type = config_type
|
104 |
|
|
|
|
|
|
|
|
|
|
|
105 |
def encode(self, texts_to_embed):
|
106 |
pass
|
107 |
|
@@ -171,6 +182,40 @@ class EmbeddingContext:
|
|
171 |
|
172 |
return texts_to_embed
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
def embeddings_generate(self, repo):
|
175 |
if os.path.exists(self.cache_path):
|
176 |
with open(self.cache_path, 'rb') as file:
|
@@ -183,24 +228,31 @@ class EmbeddingContext:
|
|
183 |
|
184 |
black_list = self.black_list[repo]
|
185 |
|
186 |
-
issues = gitea_fetch_issues('blender', repo, state='
|
187 |
issue_attr_filter=self.issue_attr_filter, exclude=black_list)
|
188 |
|
189 |
-
issues = sorted(issues, key=lambda issue: int(issue['number']))
|
190 |
|
191 |
print("Embedding Issues...")
|
192 |
texts_to_embed = self.create_strings_to_embbed(issues, black_list)
|
193 |
embeddings = self.encode(texts_to_embed)
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
def embeddings_updated_get(self, repo):
|
206 |
with self.lock:
|
@@ -230,97 +282,19 @@ class EmbeddingContext:
|
|
230 |
# Consider that if the time hasn't changed, it's the same issue.
|
231 |
issues = [issue for issue in issues if issue['updated_at'] != date_old]
|
232 |
|
233 |
-
|
234 |
-
titles_old = data['titles']
|
235 |
-
embeddings_old = data['embeddings']
|
236 |
-
|
237 |
-
last_index = len(numbers_old) - 1
|
238 |
-
|
239 |
-
issues = sorted(issues, key=lambda issue: int(issue['number']))
|
240 |
-
issues_clos = [issue for issue in issues if issue['state'] == 'closed']
|
241 |
-
issues_open = [issue for issue in issues if issue['state'] == 'open']
|
242 |
-
|
243 |
-
numbers_clos = [int(issue['number']) for issue in issues_clos]
|
244 |
-
numbers_open = [int(issue['number']) for issue in issues_open]
|
245 |
-
|
246 |
-
old_closed = []
|
247 |
-
for number_clos in numbers_clos:
|
248 |
-
for i_old in range(last_index, -1, -1):
|
249 |
-
number_old = numbers_old[i_old]
|
250 |
-
if number_old < number_clos:
|
251 |
-
break
|
252 |
-
if number_old == number_clos:
|
253 |
-
old_closed.append(i_old)
|
254 |
-
break
|
255 |
-
|
256 |
-
if not old_closed and not issues_open:
|
257 |
-
return data
|
258 |
-
|
259 |
-
mask_open = torch.ones(len(numbers_open), dtype=torch.bool)
|
260 |
-
need_sort = False
|
261 |
-
change_map = []
|
262 |
-
for i_open, number_open in enumerate(numbers_open):
|
263 |
-
for i_old in range(last_index, -1, -1):
|
264 |
-
number_old = numbers_old[i_old]
|
265 |
-
if number_old < number_open:
|
266 |
-
need_sort = need_sort or (i_old != last_index)
|
267 |
-
break
|
268 |
-
if number_old == number_open:
|
269 |
-
change_map.append((i_old, i_open))
|
270 |
-
mask_open[i_open] = False
|
271 |
-
break
|
272 |
-
|
273 |
-
if issues_open:
|
274 |
-
texts_to_embed = self.create_strings_to_embbed(issues_open, black_list)
|
275 |
-
embeddings = self.encode(texts_to_embed)
|
276 |
-
|
277 |
-
for i_old, i_open in change_map:
|
278 |
-
titles_old[i_old] = issues_open[i_open]['title']
|
279 |
-
embeddings_old[i_old] = embeddings[i_open]
|
280 |
-
|
281 |
-
if old_closed:
|
282 |
-
total = (len(numbers_old) - len(old_closed)) + (len(numbers_open) - len(change_map))
|
283 |
-
numbers_new = [None] * total
|
284 |
-
titles_new = [None] * total
|
285 |
-
embeddings_new = torch.empty((total, *embeddings_old.shape[1:]), dtype=embeddings_old.dtype, device=embeddings_old.device)
|
286 |
-
|
287 |
-
i_new = 0
|
288 |
-
i_old = 0
|
289 |
-
for i_closed in old_closed + [len(numbers_old)]:
|
290 |
-
while i_old < i_closed:
|
291 |
-
numbers_new[i_new] = numbers_old[i_old]
|
292 |
-
titles_new[i_new] = titles_old[i_old]
|
293 |
-
embeddings_new[i_new] = embeddings_old[i_old]
|
294 |
-
i_new += 1
|
295 |
-
i_old += 1
|
296 |
-
i_old += 1
|
297 |
-
|
298 |
-
for i_open in range(len(numbers_open)):
|
299 |
-
if not mask_open[i_open]:
|
300 |
-
continue
|
301 |
-
titles_new[i_new] = issues_open[i_open]['title']
|
302 |
-
numbers_new[i_new] = numbers_open[i_open]
|
303 |
-
embeddings_new[i_new] = embeddings[i_open]
|
304 |
-
i_new += 1
|
305 |
-
|
306 |
-
assert i_new == total
|
307 |
-
elif mask_open.any():
|
308 |
-
titles_new = titles_old + [issue['title'] for i, issue in enumerate(issues_open) if mask_open[i]]
|
309 |
-
numbers_new = numbers_old + [number for i, number in enumerate(numbers_open) if mask_open[i]]
|
310 |
-
embeddings_new = torch.cat([embeddings_old, embeddings[mask_open]])
|
311 |
-
else:
|
312 |
-
# Only Updated Data changed
|
313 |
-
return data
|
314 |
|
315 |
-
|
316 |
-
|
317 |
-
titles_new = [titles_new[i] for i in sorted_indices]
|
318 |
-
numbers_new = [numbers_new[i] for i in sorted_indices]
|
319 |
-
embeddings_new = embeddings_new[sorted_indices]
|
320 |
|
321 |
-
|
322 |
-
|
323 |
-
|
|
|
|
|
|
|
|
|
|
|
324 |
|
325 |
# autopep8: on
|
326 |
return data
|
@@ -332,64 +306,55 @@ EMBEDDING_CTX = EmbeddingContext()
|
|
332 |
# EMBEDDING_CTX.embeddings_generate('blender', 'blender-addons')
|
333 |
|
334 |
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
duplicates = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
ret = util.semantic_search(
|
338 |
-
query_emb,
|
|
|
339 |
for score in ret[0]:
|
340 |
corpus_id = score['corpus_id']
|
341 |
-
|
|
|
|
|
342 |
duplicates.append(text)
|
343 |
|
344 |
return duplicates
|
345 |
|
346 |
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
def text_search(owner, repo, text_to_embed, limit=None):
|
351 |
-
global cached_search
|
352 |
-
global EMBEDDING_CTX
|
353 |
-
if not text_to_embed:
|
354 |
-
return []
|
355 |
-
|
356 |
-
if text_to_embed == cached_search['text'] and repo == cached_search['repo']:
|
357 |
-
return cached_search['issues'][:limit]
|
358 |
-
|
359 |
-
data = EMBEDDING_CTX.embeddings_updated_get(owner, repo)
|
360 |
-
|
361 |
-
new_embedding = EMBEDDING_CTX.encode([text_to_embed])
|
362 |
-
result = _sort_similarity(data, new_embedding, 500)
|
363 |
-
|
364 |
-
cached_search = {'text': text_to_embed, 'repo': repo, 'issues': result}
|
365 |
-
return result[:limit]
|
366 |
-
|
367 |
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
|
376 |
-
data = EMBEDDING_CTX.embeddings_updated_get(repo)
|
377 |
-
new_embedding = None
|
378 |
-
|
379 |
-
# Check if the embedding already exist.
|
380 |
-
for i in range(len(data['numbers']) - 1, -1, -1):
|
381 |
-
number_cached = data['numbers'][i]
|
382 |
-
if number_cached < number:
|
383 |
-
break
|
384 |
-
if number_cached == number:
|
385 |
-
new_embedding = data['embeddings'][i]
|
386 |
-
break
|
387 |
-
|
388 |
-
if new_embedding is None:
|
389 |
-
text_to_embed = _create_issue_string(title, body)
|
390 |
new_embedding = EMBEDDING_CTX.encode([text_to_embed])
|
391 |
|
392 |
-
duplicates = _sort_similarity(
|
|
|
|
|
393 |
if not duplicates:
|
394 |
return ''
|
395 |
|
@@ -401,8 +366,8 @@ def find_relatedness(repo, number, limit=20):
|
|
401 |
|
402 |
|
403 |
@router.get("/find_related/{repo}/{number}")
|
404 |
-
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15):
|
405 |
-
related = find_relatedness(repo, number, limit=limit)
|
406 |
return related
|
407 |
|
408 |
|
@@ -425,11 +390,8 @@ if __name__ == "__main__":
|
|
425 |
val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
|
426 |
|
427 |
# 'blender/blender/111434' must print #96153, #83604 and #79762
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
related1 = find_relatedness(issue1, limit=20)
|
432 |
-
related2 = find_relatedness(issue2, limit=20)
|
433 |
|
434 |
print("These are the 20 most related issues:")
|
435 |
print(related1)
|
|
|
5 |
import re
|
6 |
import torch
|
7 |
import threading
|
8 |
+
|
9 |
from datetime import datetime, timedelta
|
10 |
+
from enum import Enum
|
11 |
from sentence_transformers import SentenceTransformer, util
|
12 |
from fastapi import APIRouter
|
13 |
|
|
|
55 |
# These don't change
|
56 |
TOKEN_LEN_MAX_FOR_EMBEDDING = 512
|
57 |
TOKEN_LEN_MAX_BALCKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
|
58 |
+
ARRAY_CHUNK_SIZE = 4096
|
59 |
issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
|
60 |
cache_path = "routers/tool_find_related_cache.pkl"
|
61 |
|
|
|
65 |
openai_client = None
|
66 |
model_name = ''
|
67 |
config_type = ''
|
68 |
+
embedding_shape = None
|
69 |
+
embedding_dtype = None
|
70 |
+
embedding_device = None
|
71 |
|
72 |
# Updates constantly
|
73 |
data = {}
|
|
|
108 |
self.model_name = model_name
|
109 |
self.config_type = config_type
|
110 |
|
111 |
+
tmp = self.encode(['tmp'])
|
112 |
+
self.embedding_shape = tmp.shape[1:]
|
113 |
+
self.embedding_dtype = tmp.dtype
|
114 |
+
self.embedding_device = tmp.device
|
115 |
+
|
116 |
def encode(self, texts_to_embed):
|
117 |
pass
|
118 |
|
|
|
182 |
|
183 |
return texts_to_embed
|
184 |
|
185 |
+
def data_ensure_size(self, repo, size_new):
|
186 |
+
updated_at_old = None
|
187 |
+
arrays_size_old = 0
|
188 |
+
titles_old = []
|
189 |
+
try:
|
190 |
+
arrays_size_old = self.data[repo]['arrays_size']
|
191 |
+
if size_new <= arrays_size_old:
|
192 |
+
return
|
193 |
+
except:
|
194 |
+
pass
|
195 |
+
|
196 |
+
arrays_size_new = self.ARRAY_CHUNK_SIZE * \
|
197 |
+
(int(size_new / self.ARRAY_CHUNK_SIZE) + 1)
|
198 |
+
|
199 |
+
data_new = {
|
200 |
+
'updated_at': updated_at_old,
|
201 |
+
'arrays_size': arrays_size_new,
|
202 |
+
'titles': titles_old + [None] * (arrays_size_new - arrays_size_old),
|
203 |
+
'embeddings': torch.empty((arrays_size_new, *self.embedding_shape),
|
204 |
+
dtype=self.embedding_dtype,
|
205 |
+
device=self.embedding_device),
|
206 |
+
'opened': torch.zeros(arrays_size_new, dtype=torch.bool),
|
207 |
+
'closed': torch.zeros(arrays_size_new, dtype=torch.bool),
|
208 |
+
}
|
209 |
+
|
210 |
+
try:
|
211 |
+
data_new['embeddings'][:arrays_size_old] = self.data[repo]['embeddings']
|
212 |
+
data_new['opened'][:arrays_size_old] = self.data[repo]['opened']
|
213 |
+
data_new['closed'][:arrays_size_old] = self.data[repo]['closed']
|
214 |
+
except:
|
215 |
+
pass
|
216 |
+
|
217 |
+
self.data[repo] = data_new
|
218 |
+
|
219 |
def embeddings_generate(self, repo):
|
220 |
if os.path.exists(self.cache_path):
|
221 |
with open(self.cache_path, 'rb') as file:
|
|
|
228 |
|
229 |
black_list = self.black_list[repo]
|
230 |
|
231 |
+
issues = gitea_fetch_issues('blender', repo, state='all', since=None,
|
232 |
issue_attr_filter=self.issue_attr_filter, exclude=black_list)
|
233 |
|
234 |
+
# issues = sorted(issues, key=lambda issue: int(issue['number']))
|
235 |
|
236 |
print("Embedding Issues...")
|
237 |
texts_to_embed = self.create_strings_to_embbed(issues, black_list)
|
238 |
embeddings = self.encode(texts_to_embed)
|
239 |
|
240 |
+
self.data_ensure_size(repo, int(issues[0]['number']))
|
241 |
+
self.data[repo]['updated_at'] = _find_latest_date(issues)
|
242 |
+
|
243 |
+
titles = self.data[repo]['titles']
|
244 |
+
embeddings_new = self.data[repo]['embeddings']
|
245 |
+
opened = self.data[repo]['opened']
|
246 |
+
closed = self.data[repo]['closed']
|
247 |
|
248 |
+
for i, issue in enumerate(issues):
|
249 |
+
number = int(issue['number'])
|
250 |
+
titles[number] = issue['title']
|
251 |
+
embeddings_new[number] = embeddings[i]
|
252 |
+
if issue['state'] == 'open':
|
253 |
+
opened[number] = True
|
254 |
+
if issue['state'] == 'closed':
|
255 |
+
closed[number] = True
|
256 |
|
257 |
def embeddings_updated_get(self, repo):
|
258 |
with self.lock:
|
|
|
282 |
# Consider that if the time hasn't changed, it's the same issue.
|
283 |
issues = [issue for issue in issues if issue['updated_at'] != date_old]
|
284 |
|
285 |
+
self.data_ensure_size(repo, int(issues[0]['number']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
+
texts_to_embed = self.create_strings_to_embbed(issues, black_list)
|
288 |
+
embeddings = self.encode(texts_to_embed)
|
|
|
|
|
|
|
289 |
|
290 |
+
for i, issue in enumerate(issues):
|
291 |
+
number = int(issue['number'])
|
292 |
+
data['titles'][number] = issue['title']
|
293 |
+
data['embeddings'][number] = embeddings[i]
|
294 |
+
if issue['state'] == 'open':
|
295 |
+
data['opened'][number] = True
|
296 |
+
if issue['state'] == 'closed':
|
297 |
+
data['closed'][number] = True
|
298 |
|
299 |
# autopep8: on
|
300 |
return data
|
|
|
306 |
# EMBEDDING_CTX.embeddings_generate('blender', 'blender-addons')
|
307 |
|
308 |
|
309 |
+
# Define your Enum class
|
310 |
+
class State(str, Enum):
|
311 |
+
opened = "opened"
|
312 |
+
closed = "closed"
|
313 |
+
all = "all"
|
314 |
+
|
315 |
+
|
316 |
+
def _sort_similarity(data: dict,
|
317 |
+
query_emb: torch.Tensor,
|
318 |
+
limit: int,
|
319 |
+
state: State = State.opened) -> list:
|
320 |
duplicates = []
|
321 |
+
embeddings = data['embeddings']
|
322 |
+
true_indices = None
|
323 |
+
|
324 |
+
if state != State.all:
|
325 |
+
mask = data[state.value]
|
326 |
+
embeddings = embeddings[mask]
|
327 |
+
true_indices = mask.nonzero(as_tuple=True)[0]
|
328 |
+
|
329 |
ret = util.semantic_search(
|
330 |
+
query_emb, embeddings, top_k=limit, score_function=util.dot_score)
|
331 |
+
|
332 |
for score in ret[0]:
|
333 |
corpus_id = score['corpus_id']
|
334 |
+
number = true_indices[corpus_id].item(
|
335 |
+
) if true_indices is not None else corpus_id
|
336 |
+
text = f"#{number}: {data['titles'][number]}"
|
337 |
duplicates.append(text)
|
338 |
|
339 |
return duplicates
|
340 |
|
341 |
|
342 |
+
def find_relatedness(repo: str, number: int, limit: int = 20, state: State = State.opened):
|
343 |
+
data = EMBEDDING_CTX.embeddings_updated_get(repo)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
+
# Check if the embedding already exists.
|
346 |
+
if data['titles'][number] is not None:
|
347 |
+
new_embedding = data['embeddings'][number]
|
348 |
+
else:
|
349 |
+
gitea_issue = gitea_json_issue_get('blender', repo, number)
|
350 |
+
text_to_embed = _create_issue_string(
|
351 |
+
gitea_issue['title'], gitea_issue['body'])
|
352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
new_embedding = EMBEDDING_CTX.encode([text_to_embed])
|
354 |
|
355 |
+
duplicates = _sort_similarity(
|
356 |
+
data, new_embedding, limit=limit, state=state)
|
357 |
+
|
358 |
if not duplicates:
|
359 |
return ''
|
360 |
|
|
|
366 |
|
367 |
|
368 |
@router.get("/find_related/{repo}/{number}")
|
369 |
+
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15, state: State = State.opened):
|
370 |
+
related = find_relatedness(repo, number, limit=limit, state=state)
|
371 |
return related
|
372 |
|
373 |
|
|
|
390 |
val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
|
391 |
|
392 |
# 'blender/blender/111434' must print #96153, #83604 and #79762
|
393 |
+
related1 = find_relatedness('blender', 111434, limit=20)
|
394 |
+
related2 = find_relatedness('blender-addons', 104399, limit=20)
|
|
|
|
|
|
|
395 |
|
396 |
print("These are the 20 most related issues:")
|
397 |
print(related1)
|
routers/tool_find_related_cache.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:46c42973a8caaa2f0d4a76ebc6ff16c0b8df927c9b16ba645c3f7155cce84f6a
|
3 |
+
size 723382452
|