from langchain.chains import LLMChain
from langchain_groq import ChatGroq
import os
from prompts import contextualize_prompt  # A new prompt for contextualizing the query

class ContextualizeChain(LLMChain):
    def format_query(self, user_query: str, chat_history: list) -> str:
        # invoke() returns a dict; pull the rewritten query from LLMChain's default "text" output key
        result = self.invoke({"input": user_query, "chat_history": chat_history})
        return result["text"]

def get_contextualize_chain():
    # Build the chain with a Groq-hosted Gemma 2 model; the API key is read from the environment
    chat_groq_model = ChatGroq(model="Gemma2-9b-It", groq_api_key=os.environ["GROQ_API_KEY"])
    chain = ContextualizeChain(
        llm=chat_groq_model,
        prompt=contextualize_prompt  # A prompt for contextualizing the question
    )
    return chain
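
# Example usage (a minimal sketch, not part of the chain itself): assumes GROQ_API_KEY is set
# in the environment and that contextualize_prompt expects "input" and "chat_history" variables
# (e.g. via a MessagesPlaceholder for the history).
if __name__ == "__main__":
    contextualize_chain = get_contextualize_chain()
    history = [
        ("human", "Who wrote Pride and Prejudice?"),
        ("ai", "Pride and Prejudice was written by Jane Austen."),
    ]
    # Rewrites the follow-up question into a standalone query using the chat history
    standalone_question = contextualize_chain.format_query("When was she born?", history)
    print(standalone_question)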