Spaces:
Configuration error
Configuration error
File size: 5,108 Bytes
7bd11ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import string
from typing import List, Optional, Tuple
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from loguru import logger
from app.chroma import ChromaDenseVectorDB
from app.config.models.configs import (
ResponseModel,
Config, SemanticSearchConfig,
)
from app.ranking import BCEReranker, rerank
from app.splade import SpladeSparseVectorDB
class LLMBundle:
    """Bundle the components used to retrieve documents and answer a question.

    Holds the answering chain, the dense and sparse document stores, the
    reranker, the chunk sizes to search over, and an optional HyDE chain
    used to expand the query before retrieval.
    """

    def __init__(
        self,
        chain: Chain,
        dense_db: ChromaDenseVectorDB,
        reranker: BCEReranker,
        sparse_db: SpladeSparseVectorDB,
        chunk_sizes: List[int],
        hyde_chain: Optional[LLMChain] = None,
    ) -> None:
        """Store the collaborators; no work is performed at construction time."""
        self.chain = chain
        self.dense_db = dense_db
        self.reranker = reranker
        self.sparse_db = sparse_db
        self.chunk_sizes = chunk_sizes
        self.hyde_chain = hyde_chain

    def get_relevant_documents(
        self,
        original_query: str,
        query: str,
        config: SemanticSearchConfig,
        label: str,
    ) -> Tuple[List[str], float]:
        """Retrieve, rerank and size-limit the documents for a query.

        For every configured chunk size the sparse and the dense stores are
        queried with ``query``, the union of the returned document ids is
        fetched from the dense store and the result is reranked against
        ``original_query``. The candidate set with the highest reranker score
        wins and is then trimmed to fit ``config.max_char_size``.

        Args:
            original_query: Query text handed to the reranker.
            query: Query text handed to the sparse and dense searches;
                ``config.query_prefix`` is prepended to it when set.
            config: Supplies ``query_prefix``, ``max_k`` and ``max_char_size``.
            label: Optional label restricting the search; a falsy value
                disables label filtering of the dense search.

        Returns:
            A tuple ``(documents, score)``: the selected documents and the
            best reranker score seen (``-1e5`` if no score exceeded it).
            NOTE(review): the annotation says ``List[str]``, but the items
            are the document objects returned by ``rerank`` (their
            ``page_content`` attribute is read here) - confirm and adjust.
        """
        # Apply the retrieval prefix exactly once. Doing this inside the loop
        # prepended the prefix again for every additional chunk size.
        if config.query_prefix:
            logger.info(f"Adding query prefix for retrieval: {config.query_prefix}")
            query = config.query_prefix + query

        best_docs = []
        best_score = -1e5

        for chunk_size in self.chunk_sizes:
            logger.debug("Evaluating query: {}", query)

            sparse_doc_ids, _ = self.sparse_db.query(
                search=query, n=config.max_k, label=label, chunk_size=chunk_size
            )
            logger.info(f"Stage 1: Got {len(sparse_doc_ids)} documents.")

            # Only constrain by chunk size when there is more than one to pick from.
            metadata_filter = (
                {"chunk_size": chunk_size} if len(self.chunk_sizes) > 1 else {}
            )
            if label:
                metadata_filter["label"] = label
            # The dense store is given None rather than an empty filter.
            dense_filter = metadata_filter or None
            logger.info(f"Dense embeddings filter: {dense_filter}")

            dense_hits = self.dense_db.similarity_search_with_relevance_scores(
                query, filter=dense_filter
            )
            dense_doc_ids = [hit[0].metadata["document_id"] for hit in dense_hits]

            # Union of the ids found by the sparse and the dense search.
            candidate_ids = set(sparse_doc_ids).union(dense_doc_ids)

            candidate_docs = []
            if candidate_ids:
                candidate_docs.extend(
                    self.dense_db.get_documents_by_id(
                        document_ids=list(candidate_ids)
                    )
                )

            # Re-rank embeddings
            reranker_score, reranked_docs = rerank(
                rerank_model=self.reranker,
                query=original_query,
                docs=candidate_docs,
            )
            if reranker_score > best_score:
                best_docs = reranked_docs
                best_score = reranker_score

        # Keep documents, in reranked order, while they fit into the character
        # budget. A document that does not fit is skipped, not a stop signal,
        # so a shorter document further down may still be included.
        most_relevant_docs = []
        used_chars = 0
        for doc in best_docs:
            doc_length = len(doc.page_content)
            if used_chars + doc_length < config.max_char_size:
                most_relevant_docs.append(doc)
                used_chars += doc_length

        return most_relevant_docs, best_score

    def get_and_parse_response(
        self,
        query: str,
        config: Config,
        label: str = "",
    ) -> ResponseModel:
        """Answer ``query`` using the documents retrieved for it.

        When a HyDE chain is configured its output is appended to the query
        used for retrieval; the answering chain always receives the original
        question together with the retrieved documents.

        Args:
            query: The user's question.
            config: Application config; ``config.semantic_search`` drives
                the retrieval step.
            label: Optional label forwarded to the retrieval step.

        Returns:
            A ``ResponseModel`` carrying the chain's ``output_text``, the
            (possibly HyDE-expanded) retrieval query as ``question``, the
            reranker score, and the page content of every input document
            appended to ``semantic_search``.
        """
        original_query = query

        # Add HyDE queries. The chain is optional (defaults to None in
        # __init__), so skip the expansion instead of failing on None.
        if self.hyde_chain is not None:
            hyde_response = self.hyde_chain.run(query)
            query += hyde_response

        logger.info(f"query: {query}")
        semantic_search_config = config.semantic_search
        most_relevant_docs, score = self.get_relevant_documents(
            original_query, query, semantic_search_config, label
        )

        res = self.chain(
            {"input_documents": most_relevant_docs, "question": original_query},
        )

        out = ResponseModel(
            response=res["output_text"],
            question=query,
            average_score=score,
            hyde_response="",
        )
        for doc in res["input_documents"]:
            out.semantic_search.append(doc.page_content)

        return out
class PartialFormatter(string.Formatter):
    """A ``string.Formatter`` that tolerates missing fields and bad format specs.

    Fields that cannot be resolved are rendered as ``missing``; fields whose
    format spec is rejected with ``ValueError`` are rendered as ``bad_fmt``.
    A field whose value is ``None`` is rendered as ``missing`` as well.
    """

    def __init__(self, missing="~~", bad_fmt="!!"):
        """Configure the placeholders.

        Args:
            missing: Text emitted for a field that cannot be resolved.
            bad_fmt: Text emitted when a format spec raises ``ValueError``;
                pass ``None`` to let that ``ValueError`` propagate.
        """
        self.missing, self.bad_fmt = missing, bad_fmt

    def get_field(self, field_name, args, kwargs):
        """Resolve ``field_name``, returning ``(None, field_name)`` on failure.

        Covers a missing keyword or mapping key (``KeyError``), a missing
        attribute (``AttributeError``) and a missing positional argument or
        out-of-range index (``IndexError``).
        """
        try:
            return super().get_field(field_name, args, kwargs)
        except (KeyError, AttributeError, IndexError):
            # None is the sentinel that format_field turns into `missing`.
            return None, field_name

    def format_field(self, value, spec):
        """Format ``value`` with ``spec``, substituting the placeholders.

        Raises:
            ValueError: If the spec is invalid for ``value`` and ``bad_fmt``
                is ``None``.
        """
        if value is None:
            return self.missing
        try:
            return super().format_field(value, spec)
        except ValueError:
            if self.bad_fmt is None:
                raise
            return self.bad_fmt
|