"""Context retrieval pipeline: query expansion, self-query metadata
extraction, parallel vector search, and reranking."""
import concurrent.futures
import opik
from loguru import logger
from qdrant_client.models import FieldCondition, Filter, MatchValue
from llm_engineering.application import utils
from llm_engineering.application.preprocessing.dispatchers import EmbeddingDispatcher
from llm_engineering.domain.embedded_chunks import (
EmbeddedArticleChunk,
EmbeddedChunk,
EmbeddedPostChunk,
EmbeddedRepositoryChunk,
)
from llm_engineering.domain.queries import EmbeddedQuery, Query
from .query_expanison import QueryExpansion
from .reranking import Reranker
from .self_query import SelfQuery
class ContextRetriever:
    """Retrieve and rerank document chunks relevant to a user query.

    Pipeline: extract metadata from the query (self-querying), expand the
    query into several variants, vector-search each variant in parallel,
    deduplicate the union of results, then rerank and keep the top-k chunks.
    """

    def __init__(self, mock: bool = False) -> None:
        # `mock=True` is propagated to every LLM-backed step so the whole
        # pipeline can run without calling a real model.
        self._query_expander = QueryExpansion(mock=mock)
        self._metadata_extractor = SelfQuery(mock=mock)
        self._reranker = Reranker(mock=mock)

    @opik.track(name="ContextRetriever.search")
    def search(
        self,
        query: str,
        k: int = 3,
        expand_to_n_queries: int = 3,
    ) -> list:
        """Return up to `k` reranked chunks relevant to `query`.

        Args:
            query: Raw user query string.
            k: Number of chunks kept after reranking; also forwarded to each
                per-variant vector search (must be >= 3, see `_search`).
            expand_to_n_queries: Number of expanded query variants to search.

        Returns:
            A list of reranked `EmbeddedChunk`s, empty if nothing was retrieved.
        """
        query_model = Query.from_str(query)
        query_model = self._metadata_extractor.generate(query_model)
        logger.info(
            f"Successfully extracted the author_full_name = {query_model.author_full_name} from the query.",
        )

        n_generated_queries = self._query_expander.generate(query_model, expand_to_n=expand_to_n_queries)
        logger.info(
            f"Successfully generated {len(n_generated_queries)} search queries.",
        )
        logger.info(f"The generated queries are \n {n_generated_queries}")

        # Fan the expanded queries out across threads; the searches are
        # I/O-bound calls to the vector store.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [executor.submit(self._search, _query_model, k) for _query_model in n_generated_queries]
            n_k_documents = [task.result() for task in concurrent.futures.as_completed(search_tasks)]
            n_k_documents = utils.misc.flatten(n_k_documents)
            # Deduplicate chunks retrieved by more than one query variant.
            n_k_documents = list(set(n_k_documents))

        logger.info(f"{len(n_k_documents)} documents retrieved successfully")

        if n_k_documents:
            k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
        else:
            k_documents = []

        return k_documents

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
        """Embed `query` and vector-search the enabled data categories.

        Raises:
            ValueError: If `k` < 3 — `k` is split across the three data
                categories (`k // 3` per category), so smaller values would
                request zero results per category. A raise (not `assert`) so
                the check survives `python -O`.
        """
        if k < 3:
            raise ValueError("k should be >= 3")

        def _search_data_category(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            # NOTE(review): author_id filtering is deliberately disabled;
            # re-enable by restoring the commented block below.
            # if embedded_query.author_id:
            #     query_filter = Filter(
            #         must=[
            #             FieldCondition(
            #                 key="author_id",
            #                 match=MatchValue(
            #                     value=str(embedded_query.author_id),
            #                 ),
            #             )
            #         ]
            #     )
            # else:
            query_filter = None

            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k // 3,
                query_filter=query_filter,
            )

        embedded_query: EmbeddedQuery = EmbeddingDispatcher.dispatch(query)

        # NOTE(review): only repository chunks are searched for now; post and
        # article categories are deliberately disabled — TODO confirm before
        # re-enabling.
        # post_chunks = _search_data_category(EmbeddedPostChunk, embedded_query)
        # articles_chunks = _search_data_category(EmbeddedArticleChunk, embedded_query)
        repositories_chunks = _search_data_category(EmbeddedRepositoryChunk, embedded_query)
        retrieved_chunks = repositories_chunks  # + post_chunks + articles_chunks

        logger.info(f"Retrieved {len(retrieved_chunks)} chunks")

        return retrieved_chunks

    def rerank(self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int) -> list[EmbeddedChunk]:
        """Rerank `chunks` against `query` and keep the `keep_top_k` best.

        Accepts either a raw query string (wrapped into a `Query`) or an
        already-built `Query` model.
        """
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(query=query, chunks=chunks, keep_top_k=keep_top_k)

        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents