File size: 1,342 Bytes
7fdb8e9 cc3f1e1 7fdb8e9 cc3f1e1 7fdb8e9 cc3f1e1 7fdb8e9 cc3f1e1 7fdb8e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import os
from typing import Any
from openai import OpenAI
from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from rag_demo.rag.prompt_templates import QueryExpansionTemplate
class QueryExpansion(RAGStep):
    """RAG step that asks an OpenAI chat model for rephrasings of a query."""

    def generate(self, query: Query, expand_to_n: int) -> Any:
        """Return the original query followed by model-generated variants.

        Args:
            query: The query to expand; its ``content`` is inserted into the
                prompt as the question.
            expand_to_n: Passed to the prompt as ``expand_to_n``; the template
                itself is created for ``expand_to_n - 1`` queries (presumably
                the total wanted, original included -- confirm with
                ``QueryExpansionTemplate``).

        Returns:
            A list whose first element is ``query`` itself, followed by one
            ``query.replace_content(...)`` result per non-blank piece of the
            model's reply. If the reply carries no text the list contains
            only ``query``.
        """
        # The API key is read from the environment on every call.
        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        query_expansion_template = QueryExpansionTemplate()
        prompt = query_expansion_template.create_template(expand_to_n - 1)
        separator = query_expansion_template.separator

        user_message = {
            "role": "user",
            "content": prompt.template.format(
                question=query.content,
                expand_to_n=expand_to_n,
                separator=separator,
            ),
        }
        response = api.chat.completions.create(
            model="gpt-4o-mini",
            messages=[user_message],
            max_tokens=8192,
        )

        # ``message.content`` is optional in the SDK (it can be None when the
        # model returns no text), so fall back to an empty string instead of
        # crashing on ``None.split``.
        result = response.choices[0].message.content or ""

        queries = [query]
        # Blank pieces (e.g. from a trailing separator) are skipped.
        queries += [
            query.replace_content(stripped_content)
            for content in result.split(separator)
            if (stripped_content := content.strip())
        ]
        return queries
|