SkazuHD's picture
init space
d660b02
import opik
from langchain_ollama import ChatOllama
from loguru import logger
from llm_engineering.domain.queries import Query
from llm_engineering.settings import settings
from .base import RAGStep
from .prompt_templates import QueryExpansionTemplate
class QueryExpansion(RAGStep):
@opik.track(name="QueryExpansion.generate")
def generate(self, query: Query, expand_to_n: int) -> list[Query]:
assert expand_to_n > 0, f"'expand_to_n' should be greater than 0. Got {expand_to_n}."
if self._mock:
return [query for _ in range(expand_to_n)]
query_expansion_template = QueryExpansionTemplate()
prompt = query_expansion_template.create_template(expand_to_n - 1)
model = ChatOllama(model=settings.LLAMA_MODEL_ID, temperature=0)
chain = prompt | model
response = chain.invoke({"question": query})
result = response.content
queries_content = result.strip().split(query_expansion_template.separator)
queries = [query]
queries += [
query.replace_content(stripped_content)
for content in queries_content
if (stripped_content := content.strip())
]
return queries
if __name__ == "__main__":
query = Query.from_str("Write an article about the best types of advanced RAG methods.")
query_expander = QueryExpansion()
expanded_queries = query_expander.generate(query, expand_to_n=3)
for expanded_query in expanded_queries:
logger.info(expanded_query.content)