Spaces:
Running
Running
# routers/find_related.py | |
import os | |
import pickle | |
import torch | |
import re | |
from typing import List | |
from datetime import datetime, timedelta | |
from enum import Enum | |
from sentence_transformers import util | |
from fastapi import APIRouter | |
from fastapi.responses import PlainTextResponse | |
try: | |
from .rag import EMBEDDING_CTX | |
from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get | |
except: | |
from rag import EMBEDDING_CTX | |
from utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get | |
# APIRouter instance exported by this module.
router = APIRouter()

# Issue attributes passed as ``issue_attr_filter`` to ``gitea_fetch_issues``;
# presumably the only fields kept from each Gitea issue -- NOTE(review): confirm.
issue_attr_filter = {
    'body',
    'created_at',
    'number',
    'state',
    'title',
    'updated_at',
}
class State(str, Enum):
    """Which issues ``find_related`` may return.

    ``opened`` and ``closed`` are also the keys of the per-repo boolean masks
    kept by ``_Data``; ``all`` selects the union of the two masks.
    """
    opened = "opened"
    closed = "closed"
    all = "all"


class _Data(dict):
    """Embedding index of Gitea issues, keyed by repository name.

    Each value is a dict with these keys; the arrays are indexed by issue
    number, and slots of unknown numbers hold None / uninitialized data / False:

    * ``'updated_at'``  -- latest Gitea ``updated_at`` string already processed, or None
    * ``'arrays_size'`` -- allocated length of the arrays below
    * ``'titles'``      -- list of issue titles
    * ``'embeddings'``  -- tensor of issue embeddings produced by ``EMBEDDING_CTX``
    * ``'opened'`` / ``'closed'`` -- bool tensors flagging each issue's state

    Entries can be pre-loaded from the pickle file at ``cache_path``.
    """

    cache_path = "routers/rag/embeddings_issues.pkl"

    # Bug-report template text that says nothing about the issue itself.
    # Removed in this order by ``_create_issue_string``.
    _BODY_BOILERPLATE = (
        '**System Information**\n',
        '**Blender Version**\n',
        'Worked: (newest version of Blender that worked as expected)\n',
        '**Short description of error**\n',
        '**Addon Information**\n',
        '**Exact steps for others to reproduce the error**\n',
        '[Please describe the exact steps needed to reproduce the issue]\n',
        '[Please fill out a short description of the error here]\n',
        '[Based on the default startup or an attached .blend file (as simple as possible)]\n',
    )

    @staticmethod
    def _create_issue_string(title, body):
        """Return the text that is embedded for an issue: ``title``, a newline, the cleaned ``body``.

        Cleaning removes carriage returns and template boilerplate, drops the
        ``, branch: ..., commit date: ..., hash: `...``` build suffix, collapses
        attachment paths to the word ``attachment`` and shortens every URL to
        its last path segment. A ``None`` body is treated as empty.
        """
        cleaned_body = (body or '').replace('\r', '')
        for boilerplate in _Data._BODY_BOILERPLATE:
            cleaned_body = cleaned_body.replace(boilerplate, '')
        cleaned_body = re.sub(
            r', branch: .+?, commit date: \d{4}-\d{2}-\d{2} \d{2}:\d{2}, hash: `.+?`', '', cleaned_body)
        cleaned_body = re.sub(
            r'\/?attachments\/[a-zA-Z0-9\-]+', 'attachment', cleaned_body)
        # Keep only the last path segment of each URL (e.g. an issue number).
        cleaned_body = re.sub(
            r'https?:\/\/[^\s/]+(?:\/[^\s/]+)*\/([^\s/]+)', lambda match: match.group(1), cleaned_body)
        return title + '\n' + cleaned_body

    @staticmethod
    def _find_latest_date(issues, default_str=None):
        """Return the greatest ``updated_at`` value in ``issues``, or ``default_str`` if there are none.

        The values are compared as strings; NOTE(review): this orders them
        chronologically only if Gitea returns them in one ISO-8601 format/timezone.
        """
        if not issues:  # also covers issues=None
            return default_str
        return max((issue['updated_at'] for issue in issues), default=default_str)

    @staticmethod
    def _max_issue_number(issues):
        """Return the largest issue number in the non-empty sequence ``issues``."""
        return max(int(issue['number']) for issue in issues)

    @classmethod
    def _create_strings_to_embbed(cls, issues):
        """Return one ``_create_issue_string`` text per issue, in the same order."""
        return [cls._create_issue_string(issue['title'], issue['body']) for issue in issues]

    def _data_ensure_size(self, repo, size_new):
        """Make index ``size_new`` (an issue number) valid in every array of ``repo``.

        Creates the entry if ``repo`` is not present. Otherwise, when the arrays
        are too short, reallocates them to the next multiple of 4096 slots
        greater than ``size_new`` and copies the existing contents over.
        """
        ARRAY_CHUNK_SIZE = 4096
        old = self.get(repo)
        if old is not None and size_new < old['arrays_size']:
            return  # index size_new already fits

        arrays_size_old = old['arrays_size'] if old is not None else 0
        updated_at_old = old['updated_at'] if old is not None else None
        titles_old = old['titles'] if old is not None else []

        arrays_size_new = ARRAY_CHUNK_SIZE * (size_new // ARRAY_CHUNK_SIZE + 1)
        data_new = {
            'updated_at': updated_at_old,
            'arrays_size': arrays_size_new,
            'titles': titles_old + [None] * (arrays_size_new - arrays_size_old),
            'embeddings': torch.empty((arrays_size_new, *EMBEDDING_CTX.embedding_shape),
                                      dtype=EMBEDDING_CTX.embedding_dtype,
                                      device=EMBEDDING_CTX.embedding_device),
            'opened': torch.zeros(arrays_size_new, dtype=torch.bool),
            'closed': torch.zeros(arrays_size_new, dtype=torch.bool),
        }
        if old is not None:
            data_new['embeddings'][:arrays_size_old] = old['embeddings']
            data_new['opened'][:arrays_size_old] = old['opened']
            data_new['closed'][:arrays_size_old] = old['closed']
        self[repo] = data_new

    def _embeddings_generate(self, repo):
        """Create the entry for ``repo``.

        The entry comes from the pickle cache when the cache has it. Otherwise
        every issue is fetched (``state='all'``) from Gitea and embedded.
        """
        if os.path.exists(self.cache_path):
            # NOTE: unpickling can run arbitrary code; the cache file must be a
            # trusted, locally produced file.
            with open(self.cache_path, 'rb') as file:
                cached = pickle.load(file)
            # Only fill in repos missing from memory. Entries already loaded may be
            # newer than the cache and must not be overwritten.
            for cached_repo, cached_entry in cached.items():
                self.setdefault(cached_repo, cached_entry)
            if repo in self:
                return

        issues = gitea_fetch_issues('blender', repo, state='all', since=None,
                                    issue_attr_filter=issue_attr_filter)
        if not issues:
            # Nothing to embed: create an empty entry whose 'updated_at' is None,
            # so the next sync calls gitea_fetch_issues with since=None.
            self._data_ensure_size(repo, 0)
            return

        print("Embedding Issues...")
        embeddings = EMBEDDING_CTX.encode(self._create_strings_to_embbed(issues))

        self._data_ensure_size(repo, self._max_issue_number(issues))
        entry = self[repo]
        entry['updated_at'] = self._find_latest_date(issues)
        titles = entry['titles']
        embeddings_new = entry['embeddings']
        opened = entry['opened']
        closed = entry['closed']
        for issue, embedding in zip(issues, embeddings):
            number = int(issue['number'])
            titles[number] = issue['title']
            embeddings_new[number] = embedding
            # The entry was just allocated, so the masks start out all False.
            opened[number] = issue['state'] == 'open'
            closed[number] = issue['state'] == 'closed'

    def _embeddings_updated_get(self, repo):
        """Sync ``self[repo]`` with the issues updated on Gitea since the last sync, and return it.

        Runs under ``EMBEDDING_CTX.lock``. Issues whose title changed, or whose
        body update time is at or after the previous sync date, are re-embedded.
        ``'updated_at'`` and the titles are only committed after embedding
        succeeds, so a failure is retried on the next call.
        """
        with EMBEDDING_CTX.lock:
            if repo not in self:
                self._embeddings_generate(repo)

            date_old = self[repo]['updated_at']
            issues = gitea_fetch_issues(
                'blender', repo, since=date_old, issue_attr_filter=issue_attr_filter)

            date_new = self._find_latest_date(issues, date_old)
            if date_new == date_old:
                return self[repo]  # nothing changed

            # Consider that if the time hasn't changed, it's the same issue.
            # The list stays non-empty: the issue stamped date_new != date_old survives.
            issues = [issue for issue in issues if issue['updated_at'] != date_old]

            self._data_ensure_size(repo, self._max_issue_number(issues))
            entry = self[repo]
            updated_at = gitea_issues_body_updated_at_get(issues)

            issues_to_embed = []
            for i, issue in enumerate(issues):
                number = int(issue['number'])
                entry['opened'][number] = issue['state'] == 'open'
                entry['closed'][number] = issue['state'] == 'closed'
                title_changed = entry['titles'][number] != issue['title']
                if (title_changed or not updated_at or date_old is None
                        or updated_at[i] >= date_old):
                    issues_to_embed.append(issue)

            if issues_to_embed:
                print(f"Embedding {len(issues_to_embed)} issue{'s' if len(issues_to_embed) > 1 else ''}")
                embeddings = EMBEDDING_CTX.encode(
                    self._create_strings_to_embbed(issues_to_embed))
                for issue, embedding in zip(issues_to_embed, embeddings):
                    number = int(issue['number'])
                    entry['titles'][number] = issue['title']
                    entry['embeddings'][number] = embedding

            entry['updated_at'] = date_new
            return entry

    def _sort_similarity(self,
                         repo: str,
                         query_emb: List[torch.Tensor],
                         limit: int,
                         state: State = State.opened) -> list:
        """Return up to ``limit`` ``"#<number>: <title>"`` strings, most similar first.

        Only issues selected by ``state`` are searched (dot-product score).
        Issues not flagged open are wrapped in ``~~`` (e.g. ``~~#12~~: title``).
        ``state`` may be a ``State`` or its string value.
        """
        state = State(state)
        data = self[repo]
        mask_opened = data['opened']
        if state == State.all:
            mask = mask_opened | data['closed']
        else:
            mask = data[state.value]

        true_indices = mask.nonzero(as_tuple=True)[0]
        if true_indices.numel() == 0:
            return []  # no issue matches the requested state

        ret = util.semantic_search(
            query_emb, data['embeddings'][mask], top_k=limit, score_function=util.dot_score)

        duplicates = []
        for score in ret[0]:
            # corpus_id indexes the masked corpus; map it back to the issue number.
            number = true_indices[score['corpus_id']].item()
            closed_char = "" if mask_opened[number] else "~~"
            duplicates.append(f"{closed_char}#{number}{closed_char}: {data['titles'][number]}")
        return duplicates

    def find_relatedness(self, repo: str, number: int, limit: int = 20, state: State = State.opened):
        """Return the issues of ``blender/<repo>`` related to issue ``number``, one per line.

        The cached embedding is used when the issue is indexed. Otherwise the
        issue is fetched from Gitea and embedded on the fly. If the best match
        is the issue itself, it is dropped. Returns '' when nothing matches.
        """
        data = self._embeddings_updated_get(repo)
        titles = data['titles']
        if 0 <= number < len(titles) and titles[number] is not None:
            new_embedding = data['embeddings'][number]
        else:
            gitea_issue = gitea_json_issue_get('blender', repo, number)
            text_to_embed = self._create_issue_string(
                gitea_issue['title'], gitea_issue['body'])
            new_embedding = EMBEDDING_CTX.encode([text_to_embed])

        duplicates = self._sort_similarity(
            repo, new_embedding, limit=limit, state=state)
        if not duplicates:
            return ''

        if match := re.search(r'(~~)?#(\d+)(~~)?:', duplicates[0]):
            if int(match.group(2)) == number:
                return '\n'.join(duplicates[1:])
        return '\n'.join(duplicates)
# Module-level index shared by every call to ``find_related``.
G_data = _Data()


def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15, state: State = State.opened) -> str:
    """Return the issues of ``blender/<repo>`` most similar to issue ``number``.

    Thin wrapper around ``G_data.find_relatedness``. The result has one
    ``#<number>: <title>`` line per related issue; issues not flagged open are
    wrapped in ``~~``. An empty string means nothing matched.
    """
    return G_data.find_relatedness(repo=repo, number=number, limit=limit, state=state)
if __name__ == "__main__":
    update_cache = True
    if update_cache:
        G_data._embeddings_updated_get('blender')
        G_data._embeddings_updated_get('blender-addons')
        # The cache is stored with CPU tensors, as the virtual machine in use
        # currently only supports the CPU. Pickle CPU copies so the live entries
        # are not mutated while being written.
        cpu_snapshot = {
            repo: {**entry, 'embeddings': entry['embeddings'].to(torch.device('cpu'))}
            for repo, entry in G_data.items()
        }
        cache_dir = os.path.dirname(G_data.cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(G_data.cache_path, "wb") as file:
            pickle.dump(cpu_snapshot, file, protocol=pickle.HIGHEST_PROTOCOL)
        # Keep the in-memory embeddings on the embedding model's device instead of
        # assuming CUDA exists (entries loaded from the cache start out on the CPU).
        for entry in G_data.values():
            entry['embeddings'] = entry['embeddings'].to(EMBEDDING_CTX.embedding_device)

    # 'blender/blender/111434' must print #96153, #83604 and #79762
    related1 = G_data.find_relatedness(
        'blender', 111434, limit=20, state=State.all)
    related2 = G_data.find_relatedness('blender-addons', 104399, limit=20)

    print("These are the 20 most related issues:")
    print(related1)
    print()
    print("These are the 20 most related issues:")
    print(related2)