File size: 1,342 Bytes
7fdb8e9
 
cc3f1e1
7fdb8e9
 
 
 
 
 
 
 
cc3f1e1
7fdb8e9
 
cc3f1e1
 
 
7fdb8e9
 
 
 
 
 
 
 
cc3f1e1
 
7fdb8e9
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
from typing import Any
from openai import OpenAI

from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from rag_demo.rag.prompt_templates import QueryExpansionTemplate


class QueryExpansion(RAGStep):
    """RAG step that rewrites one query into several alternatives via an OpenAI chat model."""

    def generate(self, query: Query, expand_to_n: int) -> Any:
        """Expand ``query`` into at most ``expand_to_n`` queries.

        The original ``query`` is always the first element of the result; the
        remaining elements are model-generated rewrites, each built with
        ``query.replace_content``.

        Args:
            query: The query to expand.
            expand_to_n: Total number of queries wanted, counting the original.
                Must be at least 1.

        Returns:
            A list of ``Query`` objects whose first element is ``query``. It
            holds at most ``expand_to_n`` items; it may hold fewer if the model
            returns fewer (or no) non-blank rewrites.

        Raises:
            ValueError: If ``expand_to_n`` is less than 1.
        """
        if expand_to_n < 1:
            raise ValueError(
                f"'expand_to_n' should be greater than 0. Got {expand_to_n}."
            )
        # Only the original query was requested: no rewrites needed, so skip
        # the network round trip entirely.
        if expand_to_n == 1:
            return [query]

        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        query_expansion_template = QueryExpansionTemplate()
        # Ask for ``expand_to_n - 1`` rewrites; the original query fills the
        # remaining slot when the result list is assembled below.
        prompt = query_expansion_template.create_template(expand_to_n - 1)
        response = api.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "user",
                    "content": prompt.template.format(
                        question=query.content,
                        expand_to_n=expand_to_n,
                        separator=query_expansion_template.separator,
                    ),
                }
            ],
            max_tokens=8192,
        )
        # ``message.content`` is Optional in the OpenAI SDK (e.g. it can be
        # None on a refusal); degrade to "no rewrites" rather than crashing
        # on ``.split`` below.
        result = response.choices[0].message.content or ""
        queries_content = result.split(query_expansion_template.separator)

        queries = [query]
        # Drop blank fragments (e.g. produced by leading/trailing separators).
        queries += [
            query.replace_content(stripped_content)
            for content in queries_content
            if (stripped_content := content.strip())
        ]
        # The model may over-produce; honour the requested total.
        return queries[:expand_to_n]