import os | |
from typing import Any | |
from openai import OpenAI | |
from rag_demo.rag.base.query import Query | |
from rag_demo.rag.base.template_factory import RAGStep | |
from rag_demo.rag.prompt_templates import QueryExpansionTemplate | |
class QueryExpansion(RAGStep):
    """RAG step that asks an OpenAI chat model for alternative phrasings of a query."""

    def generate(self, query: Query, expand_to_n: int) -> Any:
        """Expand ``query`` into a list of related queries.

        The original ``query`` is always the first element. The model is asked
        for ``expand_to_n - 1`` additional phrasings, so that together with the
        original the target total is ``expand_to_n``. The model's reply is split
        on the template's separator; empty fragments are dropped and each
        remaining fragment becomes a query via ``query.replace_content``.
        The model may still return a different number of fragments than
        requested; no truncation or padding is applied.

        Args:
            query: The user query to expand.
            expand_to_n: Target total number of queries, including the original.

        Returns:
            A list whose first element is ``query``, followed by the generated
            variants in the order the model produced them.

        Raises:
            ValueError: If ``expand_to_n`` is less than 1.
        """
        if expand_to_n < 1:
            raise ValueError(f"'expand_to_n' must be at least 1, got {expand_to_n}.")

        # The original query already counts toward the total, so only
        # expand_to_n - 1 new variants are requested from the model.
        n_variants = expand_to_n - 1
        if n_variants == 0:
            # Nothing to generate; skip the network round-trip entirely.
            return [query]

        query_expansion_template = QueryExpansionTemplate()
        separator = query_expansion_template.separator
        prompt = query_expansion_template.create_template(n_variants)
        # Render with the same count given to create_template so the prompt
        # and the template agree on how many variants to produce.
        content = prompt.template.format(
            question=query.content,
            expand_to_n=n_variants,
            separator=separator,
        )

        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        response = api.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": content}],
            max_tokens=8192,
        )
        # message.content is Optional in the OpenAI SDK (e.g. on refusals);
        # fall back to "" so an empty reply yields just [query] instead of
        # raising AttributeError on .split.
        result = response.choices[0].message.content or ""

        queries = [query]
        queries += [
            query.replace_content(stripped_content)
            for fragment in result.split(separator)
            if (stripped_content := fragment.strip())
        ]
        return queries