File size: 1,971 Bytes
1ef298a
0f046d0
1ef298a
0f046d0
c1a68be
0f046d0
 
 
 
c1a68be
c3c7abe
 
 
0f046d0
c3c7abe
 
 
0f046d0
c3c7abe
0f046d0
c3c7abe
 
 
 
 
 
 
0f046d0
c3c7abe
 
0f046d0
c3c7abe
 
0f046d0
 
 
 
 
 
 
 
 
 
 
 
1ef298a
 
0f046d0
 
 
 
 
 
 
1ef298a
0f046d0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import List
import json

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from backend.app.vectorstore import get_vector_db

# System message for the chat prompt: frames the model as a question generator
# that works from a supplied context.
SYSTEM_ROLE_PROMPT = """
    You are a helpful assistant that generates questions based on a given context.
"""

# User message template for the chat prompt.
# - `{query}` and `{context}` are template placeholders that are filled in when
#   the prompt is rendered.
# - The doubled braces (`{{` / `}}`) around the JSON example are escapes, so the
#   rendered prompt shows literal `{` / `}` instead of being treated as a
#   placeholder.
# - The requested reply shape is a JSON object with a single "questions" key
#   holding a list of five strings.
USER_ROLE_PROMPT = """
    Based on the following context about {query}, generate 5 relevant and specific questions.
    Make sure the questions can be answered using only the provided context.

    Context: {context}

    Generate 5 questions that test understanding of the material in the context.
    
    Return only a json object with the following format:
    {{
        "questions": ["question1", "question2", "question3", "question4", "question5"]
    }}
"""

class ProblemGenerationPipeline:
    """Retrieval-augmented pipeline that generates questions about a topic.

    The chain retrieves documents for the query, renders them into the chat
    prompt built from ``SYSTEM_ROLE_PROMPT`` and ``USER_ROLE_PROMPT``, calls
    the chat model, and returns the model reply as a string.
    ``generate_problems`` then parses that reply as JSON.
    """

    def __init__(
        self,
        *,
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        k: int = 2,
    ) -> None:
        """Build the prompt, model, retriever and the composed RAG chain.

        Args:
            model: Chat model name passed to ``ChatOpenAI``.
            temperature: Sampling temperature passed to ``ChatOpenAI``.
            k: Number of documents requested from the retriever per query.
        """
        self.chat_prompt = ChatPromptTemplate.from_messages([
            ("system", SYSTEM_ROLE_PROMPT),
            ("user", USER_ROLE_PROMPT),
        ])

        self.llm = ChatOpenAI(model=model, temperature=temperature)
        self.retriever = get_vector_db().as_retriever(search_kwargs={"k": k})

        self.rag_chain = (
            {
                # Flatten the retrieved documents to plain text; otherwise the
                # prompt receives the repr of a list of document objects.
                "context": self.retriever | self._format_docs,
                "query": RunnablePassthrough(),
            }
            | self.chat_prompt
            | self.llm
            | StrOutputParser()
        )

    @staticmethod
    def _format_docs(docs) -> str:
        """Join the text of the retrieved documents into one context string.

        NOTE(review): assumes each item exposes its text as ``page_content``
        (the LangChain ``Document`` convention); items without that attribute
        fall back to ``str(item)``.
        """
        return "\n\n".join(
            getattr(doc, "page_content", str(doc)) for doc in docs
        )

    @staticmethod
    def _extract_json_object(text: str) -> str:
        """Return the outermost ``{...}`` span of *text*, or *text* unchanged.

        Chat models often wrap JSON in a Markdown code fence or add a short
        sentence around it; slicing from the first ``{`` to the last ``}``
        drops that wrapping while leaving an already-clean object untouched.
        """
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            # No object delimiters found: let json.loads report the problem.
            return text
        return text[start:end + 1]

    def generate_problems(self, query: str) -> List[str]:
        """
        Generate problems based on the user's query using RAG.

        Args:
            query (str): The topic to generate questions about

        Returns:
            List[str]: A list of generated questions

        Raises:
            json.JSONDecodeError: If the model reply contains no valid JSON.
            KeyError: If the parsed object has no "questions" key.
        """
        raw_result = self.rag_chain.invoke(query)
        result = json.loads(self._extract_json_object(raw_result))
        return result["questions"]