# tools/routers/tool_find_related.py
# Last commit: 125e8d3 "Find Related: Update Cache" by Germano Cavalcante
# routers/find_related.py
import os
import pickle
import torch
import re
from typing import List
from datetime import datetime, timedelta
from enum import Enum
from sentence_transformers import util
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
# Prefer package-relative imports (module loaded as part of the ``routers``
# package); fall back to plain imports when that fails, e.g. when this file
# is executed directly as a script. Only ImportError triggers the fallback so
# that genuine errors raised while importing these modules are not hidden.
try:
    from .rag import EMBEDDING_CTX
    from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
except ImportError:
    from rag import EMBEDDING_CTX
    from utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
# FastAPI router on which this module registers its ``/find_related`` route.
router = APIRouter()
# Issue fields handed to ``gitea_fetch_issues`` as its attribute filter
# (presumably the only fields kept for each fetched issue - confirm in
# utils_gitea). Every field read from an issue below appears here.
issue_attr_filter = {
    'body',
    'created_at',
    'number',
    'state',
    'title',
    'updated_at',
}
class State(str, Enum):
    """Which issues a similarity search should consider.

    The ``str`` mix-in makes every member equal to its raw value, which is
    also the key of the matching boolean mask in the embeddings cache
    (``"opened"`` / ``"closed"``); ``all`` is the union of both masks.
    """

    opened = "opened"  # issues whose state is 'open'
    closed = "closed"  # issues whose state is 'closed'
    all = "all"        # both of the above
class _Data(dict):
    """Cache of issue embeddings, one entry per repository name.

    Each ``self[repo]`` is a dict with these keys:

    - ``'updated_at'``: most recent ``updated_at`` value seen among the
      repository's issues (``None`` before anything was fetched).
    - ``'arrays_size'``: allocated length of the per-issue arrays below.
    - ``'titles'``: list indexed by issue number; ``None`` marks a slot with
      no cached issue.
    - ``'embeddings'``: tensor of shape
      ``(arrays_size, *EMBEDDING_CTX.embedding_shape)`` indexed by issue number.
    - ``'opened'`` / ``'closed'``: boolean masks indexed by issue number.
    """

    # Pickle file holding previously computed entries. The path is relative
    # to the current working directory.
    cache_path = "routers/rag/embeddings_issues.pkl"

    @staticmethod
    def _create_issue_string(title, body):
        """Return ``title`` and a cleaned-up ``body`` joined by a newline.

        Template headings and placeholder text are removed, and build details
        (branch / commit date / hash) are dropped. Attachment paths collapse
        to ``attachment``, and URLs with a path are reduced to their last path
        segment. This keeps the embedding focused on the issue's own wording.
        """
        # Tolerate a missing description (``None``) instead of crashing.
        cleaned_body = (body or '').replace('\r', '')
        # Boilerplate headings and placeholder text, removed in this order.
        for snippet in (
            '**System Information**\n',
            '**Blender Version**\n',
            'Worked: (newest version of Blender that worked as expected)\n',
            '**Short description of error**\n',
            '**Addon Information**\n',
            '**Exact steps for others to reproduce the error**\n',
            '[Please describe the exact steps needed to reproduce the issue]\n',
            '[Please fill out a short description of the error here]\n',
            '[Based on the default startup or an attached .blend file (as simple as possible)]\n',
        ):
            cleaned_body = cleaned_body.replace(snippet, '')
        cleaned_body = re.sub(
            r', branch: .+?, commit date: \d{4}-\d{2}-\d{2} \d{2}:\d{2}, hash: `.+?`', '', cleaned_body)
        cleaned_body = re.sub(
            r'\/?attachments\/[a-zA-Z0-9\-]+', 'attachment', cleaned_body)
        cleaned_body = re.sub(
            r'https?:\/\/[^\s/]+(?:\/[^\s/]+)*\/([^\s/]+)', lambda match: match.group(1), cleaned_body)
        return title + '\n' + cleaned_body

    @staticmethod
    def _find_latest_date(issues, default_str=None):
        """Return the largest ``updated_at`` among ``issues``, or ``default_str`` if there are none.

        The values are compared with ``max`` as-is. This presumes uniformly
        formatted timestamp strings (e.g. ISO-8601); TODO: confirm against
        the Gitea API.
        """
        if not issues:
            return default_str
        return max((issue['updated_at'] for issue in issues), default=default_str)

    @classmethod
    def _create_strings_to_embbed(cls, issues):
        """Return the text to embed for each issue, in the same order as ``issues``."""
        return [cls._create_issue_string(issue['title'], issue['body']) for issue in issues]

    def _data_ensure_size(self, repo, size_new):
        """Make ``size_new`` a valid index into ``self[repo]``'s arrays.

        Creates the entry when ``repo`` is not cached yet. Otherwise the
        arrays grow to the next multiple of ``ARRAY_CHUNK_SIZE`` strictly
        above ``size_new``. Existing titles, embeddings and state masks are
        copied over, and ``'updated_at'`` is preserved. New embedding rows
        come from ``torch.empty`` and are uninitialised until written.
        The entry is replaced by a new dict, so callers must look up
        ``self[repo]`` again afterwards.
        """
        ARRAY_CHUNK_SIZE = 4096

        data_old = self.get(repo)
        if data_old is None:
            updated_at_old = None
            arrays_size_old = 0
            titles_old = []
        else:
            arrays_size_old = data_old['arrays_size']
            # Callers index the arrays with ``size_new``, so it must be
            # strictly smaller than the current size (``<=`` would let
            # ``size_new == arrays_size_old`` through and overflow).
            if size_new < arrays_size_old:
                return
            updated_at_old = data_old['updated_at']
            titles_old = data_old['titles']

        arrays_size_new = ARRAY_CHUNK_SIZE * (size_new // ARRAY_CHUNK_SIZE + 1)
        data_new = {
            'updated_at': updated_at_old,
            'arrays_size': arrays_size_new,
            'titles': titles_old + [None] * (arrays_size_new - arrays_size_old),
            'embeddings': torch.empty((arrays_size_new, *EMBEDDING_CTX.embedding_shape),
                                      dtype=EMBEDDING_CTX.embedding_dtype,
                                      device=EMBEDDING_CTX.embedding_device),
            'opened': torch.zeros(arrays_size_new, dtype=torch.bool),
            'closed': torch.zeros(arrays_size_new, dtype=torch.bool),
        }
        if data_old is not None:
            data_new['embeddings'][:arrays_size_old] = data_old['embeddings']
            data_new['opened'][:arrays_size_old] = data_old['opened']
            data_new['closed'][:arrays_size_old] = data_old['closed']
        self[repo] = data_new

    def _embeddings_generate(self, repo):
        """Populate ``self[repo]``, from the pickle cache when possible.

        Every entry in the cache file is merged into ``self``. If ``repo`` is
        still missing afterwards, all of its issues are fetched from Gitea and
        embedded from scratch.
        """
        if os.path.exists(self.cache_path):
            # NOTE(review): pickle.load can execute arbitrary code. This is only
            # safe because the cache file is produced by this module's script.
            with open(self.cache_path, 'rb') as file:
                cached = pickle.load(file)
            self.update(cached)
            if repo in self:
                return

        issues = gitea_fetch_issues('blender', repo, state='all', since=None,
                                    issue_attr_filter=issue_attr_filter)
        if not issues:
            # Nothing to embed. Still create an (empty) entry so callers can
            # read ``self[repo]``.
            self._data_ensure_size(repo, 0)
            return

        print("Embedding Issues...")
        texts_to_embed = self._create_strings_to_embbed(issues)
        embeddings = EMBEDDING_CTX.encode(texts_to_embed)

        numbers = [int(issue['number']) for issue in issues]
        # Size by the highest issue number instead of assuming the API
        # returned it first.
        self._data_ensure_size(repo, max(numbers))

        data = self[repo]
        data['updated_at'] = self._find_latest_date(issues)
        titles = data['titles']
        embeddings_new = data['embeddings']
        opened = data['opened']
        closed = data['closed']
        for i, (number, issue) in enumerate(zip(numbers, issues)):
            titles[number] = issue['title']
            embeddings_new[number] = embeddings[i]
            if issue['state'] == 'open':
                opened[number] = True
            if issue['state'] == 'closed':
                closed[number] = True

    def _embeddings_updated_get(self, repo):
        """Bring ``self[repo]`` up to date with Gitea and return it.

        Runs while holding ``EMBEDDING_CTX.lock``. Only issues fetched with
        ``since=<stored updated_at>`` are processed, and their state masks
        are refreshed. An issue is re-embedded when its title changed. It is
        also re-embedded when ``gitea_issues_body_updated_at_get`` returns
        nothing, or returns a value at or after the previous update for it.
        """
        with EMBEDDING_CTX.lock:
            if repo not in self:
                self._embeddings_generate(repo)

            date_old = self[repo]['updated_at']
            issues = gitea_fetch_issues(
                'blender', repo, since=date_old, issue_attr_filter=issue_attr_filter)

            # Get the most recent date.
            date_new = self._find_latest_date(issues, date_old)
            if date_new == date_old:
                # Nothing changed.
                return self[repo]

            self[repo]['updated_at'] = date_new

            # Consider that if the time hasn't changed, it's the same issue
            # (already handled by the previous update).
            issues = [issue for issue in issues if issue['updated_at'] != date_old]

            # ``date_new != date_old``, so the newest issue survived the filter
            # and ``max`` has input. Size by the highest number rather than
            # trusting the API's ordering.
            self._data_ensure_size(repo, max(int(issue['number']) for issue in issues))
            # ``_data_ensure_size`` may replace the entry: look it up afterwards.
            data = self[repo]

            # Presumably one body-update timestamp per issue, aligned with
            # ``issues`` (falsy when unavailable). TODO: confirm in utils_gitea.
            updated_at = gitea_issues_body_updated_at_get(issues)

            issues_to_embed = []
            for i, issue in enumerate(issues):
                number = int(issue['number'])
                data['opened'][number] = issue['state'] == 'open'
                data['closed'][number] = issue['state'] == 'closed'

                title_old = data['titles'][number]
                if title_old != issue['title']:
                    data['titles'][number] = issue['title']
                    issues_to_embed.append(issue)
                elif not updated_at or updated_at[i] >= date_old:
                    issues_to_embed.append(issue)

            if issues_to_embed:
                print(f"Embedding {len(issues_to_embed)} issue{'s' if len(issues_to_embed) > 1 else ''}")
                texts_to_embed = self._create_strings_to_embbed(issues_to_embed)
                embeddings = EMBEDDING_CTX.encode(texts_to_embed)
                for i, issue in enumerate(issues_to_embed):
                    number = int(issue['number'])
                    data['embeddings'][number] = embeddings[i]

            return data

    def _sort_similarity(self,
                         repo: str,
                         query_emb: List[torch.Tensor],
                         limit: int,
                         state: State = State.opened) -> list:
        """Return up to ``limit`` cached issues of ``repo`` most similar to ``query_emb``.

        Only issues matching ``state`` are searched. Each result is formatted
        as ``"#<number>: <title>"``, and issues not marked open are wrapped in
        ``~~``. Scores use the dot product (presumably the embeddings are
        normalized, so this equals cosine similarity; confirm).
        """
        data = self[repo]
        mask_opened = data["opened"]
        if state == State.all:
            mask = mask_opened | data["closed"]
        else:
            mask = data[state.value]

        true_indices = mask.nonzero(as_tuple=True)[0]
        if true_indices.numel() == 0:
            # No cached issue matches the requested state.
            return []
        embeddings = data['embeddings'][mask]

        ret = util.semantic_search(
            query_emb, embeddings, top_k=limit, score_function=util.dot_score)

        duplicates = []
        for score in ret[0]:
            # ``corpus_id`` indexes the masked subset; map it back to the issue number.
            number = true_indices[score['corpus_id']].item()
            closed_char = "" if mask_opened[number] else "~~"
            duplicates.append(f"{closed_char}#{number}{closed_char}: {data['titles'][number]}")
        return duplicates

    def find_relatedness(self, repo: str, number: int, limit: int = 20, state: State = State.opened):
        """Return the issues related to ``repo``'s issue ``number`` as newline-separated text.

        The cache is refreshed first. The issue's own embedding is reused when
        cached; otherwise the issue is fetched from Gitea and embedded on the
        fly. If the best match is the issue itself, it is left out of the result.
        """
        data = self._embeddings_updated_get(repo)

        # Check if the embedding already exists. The bounds check avoids an
        # IndexError for numbers outside the allocated arrays.
        titles = data['titles']
        if 0 <= number < len(titles) and titles[number] is not None:
            new_embedding = data['embeddings'][number]
        else:
            gitea_issue = gitea_json_issue_get('blender', repo, number)
            text_to_embed = self._create_issue_string(
                gitea_issue['title'], gitea_issue['body'])
            new_embedding = EMBEDDING_CTX.encode([text_to_embed])

        duplicates = self._sort_similarity(
            repo, new_embedding, limit=limit, state=state)
        if not duplicates:
            return ''

        # Drop the first result when it is the queried issue itself.
        if match := re.search(r'(~~)?#(\d+)(~~)?:', duplicates[0]):
            if int(match.group(2)) == number:
                return '\n'.join(duplicates[1:])

        return '\n'.join(duplicates)
# Module-wide cache instance; every request goes through it and refreshes it.
G_data = _Data()


@router.get("/find_related/{repo}/{number}", response_class=PlainTextResponse)
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15, state: State = State.opened) -> str:
    """Plain-text endpoint listing the issues most related to ``repo``'s issue ``number``.

    Returns one issue per line; issues not marked open are wrapped in ``~~``.
    """
    return G_data.find_relatedness(repo, number, limit=limit, state=state)
if __name__ == "__main__":
    # Script entry point: refresh and save the embeddings cache, then print a
    # quick sanity check.
    update_cache = True
    if update_cache:
        G_data._embeddings_updated_get('blender')
        G_data._embeddings_updated_get('blender-addons')

        # The cache is saved with CPU tensors, because the virtual machine in
        # use currently only supports the CPU. Convert before opening the file
        # so a failed conversion cannot leave a truncated cache behind.
        cpu = torch.device('cpu')
        for val in G_data.values():
            val['embeddings'] = val['embeddings'].to(cpu)
        with open(G_data.cache_path, "wb") as file:
            pickle.dump(dict(G_data), file, protocol=pickle.HIGHEST_PROTOCOL)

        # Move the embeddings back to the device the embedding context uses,
        # instead of a hard-coded 'cuda' that fails on CPU-only machines.
        for val in G_data.values():
            val['embeddings'] = val['embeddings'].to(EMBEDDING_CTX.embedding_device)

    # 'blender/blender/111434' must print #96153, #83604 and #79762
    related1 = G_data.find_relatedness(
        'blender', 111434, limit=20, state=State.all)
    related2 = G_data.find_relatedness('blender-addons', 104399, limit=20)

    print("These are the 20 most related issues:")
    print(related1)
    print()
    print("These are the 20 most related issues:")
    print(related2)