# Spaces:
# Sleeping
# Sleeping
import logging

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_ollama import ChatOllama


def create_query_rewriter(llm):
    """
    Create a query rewriter chain that optimizes questions for retrieval.

    The chain is ``prompt | llm | StrOutputParser()``. It formats the question
    into a system + human chat prompt, sends it to ``llm``, and parses the
    model's reply into a string.

    Args:
        llm: Chat model placed in the middle of the chain (e.g. a
            ``ChatOllama`` instance). It must support LangChain ``|``
            composition.

    Returns:
        A LangChain runnable pipeline, not a plain callable. Run it with
        ``.invoke({"question": <str>})``; the trailing ``StrOutputParser``
        makes the result a ``str`` containing the rewritten question.
    """
    # System instructions describing the rewriting task.
    system = """You are a question re-writer that converts an input question to a better version that is optimized
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
    # "{question}" is filled from the dict passed to .invoke().
    re_write_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            (
                "human",
                "Here is the initial question: \n\n {question} \n Formulate an improved question.",
            ),
        ]
    )
    # Compose prompt -> chat model -> string output parser.
    return re_write_prompt | llm | StrOutputParser()


def rewrite_query(question: str, llm) -> str:
    """
    Rewrite a given query to optimize retrieval.

    This is a best-effort step. If the question is blank, the LLM call fails,
    or the model returns an empty completion, the original question is
    returned unchanged so retrieval can still proceed.

    Args:
        question (str): Original user question.
        llm: Chat model used by ``create_query_rewriter`` to build the
            rewriting chain (e.g. a ``ChatOllama`` instance).

    Returns:
        str: The rewritten query with surrounding whitespace stripped, or the
        original ``question`` if there was nothing to rewrite or rewriting
        failed / produced no text.
    """
    # Nothing to rewrite: avoid an LLM call that could only invent a question.
    if not question or not question.strip():
        return question

    query_rewriter = create_query_rewriter(llm)
    try:
        rewritten_query = query_rewriter.invoke({"question": question})
    except Exception as exc:
        # Deliberate best-effort boundary: backend/model errors degrade to the
        # unmodified question instead of aborting retrieval. Log (with
        # traceback) rather than printing to stdout.
        logging.getLogger(__name__).warning(
            "Query rewriting error: %s", exc, exc_info=True
        )
        return question

    rewritten_query = rewritten_query.strip()
    # An empty completion would send an empty query to the retriever;
    # the original question is the better fallback.
    return rewritten_query or question


if __name__ == "__main__":
    # Demo: rewrite a few sample questions and show before/after.
    sample_questions = (
        "Tell me about AI agents",
        "What do we know about memory in AI systems?",
        "Bears draft strategy",
    )
    # Model name and generation options are passed straight to ChatOllama.
    chat_model = ChatOllama(
        model="llama3.2",
        temperature=0.1,
        num_predict=256,
        top_p=0.5,
    )
    for original in sample_questions:
        improved = rewrite_query(original, chat_model)
        # One print producing the same two lines plus a blank separator.
        print(f"Original: {original}\nRewritten: {improved}\n")