# classification_chain.py | |
import os | |
from langchain.chains import LLMChain | |
from langchain_groq import ChatGroq | |
from prompts import classification_prompt | |
# classification_chain.py | |
def get_classification_chain(model: str = "Gemma2-9b-It") -> "LLMChain":
    """
    Build the classification chain (LLMChain) from a ChatGroq model and the classification prompt.

    Args:
        model: Groq model identifier passed to ChatGroq. Defaults to
            "Gemma2-9b-It", the model this module has always used.

    Returns:
        An LLMChain that feeds ``classification_prompt`` to the ChatGroq model.

    Raises:
        KeyError: if the GROQ_API_KEY environment variable is unset or empty.
    """
    # Read the key explicitly so a missing or empty value fails here with an
    # actionable message, rather than as a bare KeyError('GROQ_API_KEY') or by
    # handing an empty key to ChatGroq. KeyError is kept for caller compatibility.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise KeyError(
            "GROQ_API_KEY environment variable is not set; "
            "it is required to build the classification chain"
        )

    chat_groq_model = ChatGroq(
        model=model,
        groq_api_key=api_key,
    )
    return LLMChain(
        llm=chat_groq_model,
        prompt=classification_prompt,
    )
def classify_with_history(query: str, chat_history: list, *, chain=None) -> str:
    """
    Classify a user query in the context of the preceding conversation.

    The history and the new query are rendered as a transcript with one
    "<Role>: <content>" line per message. The transcript is sent to the
    classification chain as its ``query`` input.

    Args:
        query: The new user message to classify.
        chat_history: Prior messages (presumably oldest first; confirm with
            callers). Each item is expected to be a mapping with a "content"
            key and, optionally, a "role" key such as "user" or "assistant".
            Messages without a role are labelled "User". ``None`` or an empty
            list means there is no prior context.
        chain: Optional pre-built chain exposing ``run(inputs)``. When omitted,
            a fresh chain is built via ``get_classification_chain()``. Pass one
            in to avoid rebuilding the model client on every call.

    Returns:
        The string returned by the chain's ``run`` call.
    """
    # Label each turn with its own role so assistant replies are not presented
    # to the classifier as things the user said.
    lines = [
        f"{str(msg.get('role') or 'user').capitalize()}: {msg['content']}"
        for msg in (chat_history or [])
    ]
    lines.append(f"User: {query}")
    # Joining a single list avoids a leading newline when the history is empty.
    context = "\n".join(lines)

    if chain is None:
        chain = get_classification_chain()
    return chain.run({"query": context})