from operator import itemgetter
from typing import List, Optional
import json

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from backend.app.vectorstore import get_vector_db

MODEL = "gpt-3.5-turbo"
SYSTEM_ROLE_PROMPT = """
You are a helpful assistant that generates questions based on a given context.
"""
USER_ROLE_PROMPT = """
Based on the following context about {query}, generate 5 relevant and specific questions.
Make sure the questions can be answered using only the provided context.
Context: {context}
Generate 5 questions that test understanding of the material in the context.
Return only a JSON object with the following format:
{{
"questions": ["question1", "question2", "question3", "question4", "question5"]
}}
"""
class ProblemGenerationPipeline:
    def __init__(self, return_context: bool = False, embedding_model_id: Optional[str] = None):
        self.chat_prompt = ChatPromptTemplate.from_messages(
            [("system", SYSTEM_ROLE_PROMPT), ("user", USER_ROLE_PROMPT)]
        )
        self.llm = ChatOpenAI(model=MODEL, temperature=0.7)
        # Retrieve the two most similar chunks from the vector store.
        self.retriever = get_vector_db(embedding_model_id).as_retriever(
            search_kwargs={"k": 2}
        )
        # TODO: This is a hack to get the context for the questions. Very messy interface.
        self.return_context = return_context
        if not return_context:
            # Plain chain: takes the query string, returns the raw LLM output string.
            self.rag_chain = (
                {"context": self.retriever, "query": RunnablePassthrough()}
                | self.chat_prompt
                | self.llm
                | StrOutputParser()
            )
        else:
            # Context-preserving chain: takes {"query": str} and returns
            # {"response": str, "context": List[Document]}.
            self.rag_chain = (
                {
                    "context": itemgetter("query") | self.retriever,
                    "query": itemgetter("query"),
                }
                # This assign is a no-op (it reassigns "context" to itself), but
                # sitting between the two dict literals it forces both to be
                # coerced into runnables instead of being merged by dict union.
                | RunnablePassthrough.assign(context=itemgetter("context"))
                | {
                    "response": self.chat_prompt | self.llm | StrOutputParser(),
                    "context": itemgetter("context"),
                }
            )
    def generate_problems(self, query: str, debug: bool = False) -> List[str]:
        """
        Generate problems based on the user's query using RAG.

        Args:
            query (str): The topic to generate questions about
            debug (bool): If True, print the raw chain output

        Returns:
            List[str]: A list of generated questions; when return_context is
            True, a dict with "response" and "context" keys instead
        """
        if self.return_context:
            # The itemgetter-based chain expects a dict input, not a bare string.
            raw_result = self.rag_chain.invoke({"query": query})
        else:
            raw_result = self.rag_chain.invoke(query)
        if debug:
            print(raw_result)
        if self.return_context:
            # Dict with "response" and "context" keys.
            return raw_result
        # The plain chain returns a JSON string; extract the question list.
        return json.loads(raw_result)["questions"]
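

# Hedged usage sketch, not part of the original module: it assumes OPENAI_API_KEY
# is set in the environment and that get_vector_db() returns an already-populated
# vector store; the topic string below is purely illustrative.
if __name__ == "__main__":
    pipeline = ProblemGenerationPipeline()
    for i, question in enumerate(pipeline.generate_problems("retrieval augmented generation"), start=1):
        print(f"{i}. {question}")

    # With return_context=True, the raw response and the retrieved documents
    # come back together in one dict.
    ctx_pipeline = ProblemGenerationPipeline(return_context=True)
    result = ctx_pipeline.generate_problems("retrieval augmented generation")
    print(result["response"])
    print(f"{len(result['context'])} context documents retrieved")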