# Source: Chatbot2 / classification_chain.py
# Last commit: "Update classification_chain.py" (47b1df8, verified) by Phoenix21
# classification_chain.py
import os
from langchain.chains import LLMChain
from langchain_groq import ChatGroq
from prompts import classification_prompt
# classification_chain.py
def get_classification_chain(model: str = "Gemma2-9b-It") -> LLMChain:
    """
    Build the classification chain: a ChatGroq chat model wrapped in an
    LLMChain together with ``classification_prompt``.

    Args:
        model: Groq model identifier passed to ``ChatGroq(model=...)``.
            Defaults to ``"Gemma2-9b-It"``, the model this module has always
            used, so existing zero-argument callers are unaffected.

    Returns:
        A newly constructed ``LLMChain`` on every call.

    Raises:
        KeyError: If the ``GROQ_API_KEY`` environment variable is unset or
            empty.
    """
    # Read the key up front so a missing value fails here with an actionable
    # message (still a KeyError, as before). An empty value is also rejected
    # rather than being handed to ChatGroq.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise KeyError(
            "GROQ_API_KEY environment variable must be set (non-empty) "
            "to build the classification chain"
        )

    chat_groq_model = ChatGroq(model=model, groq_api_key=api_key)
    return LLMChain(llm=chat_groq_model, prompt=classification_prompt)
def classify_with_history(query: str, chat_history: list, chain=None) -> str:
    """
    Classify a user query in the context of the preceding conversation.

    Earlier turns are rendered one per line as ``"<Role>: <content>"``, and the
    new query is appended as a final ``"User: <query>"`` line. The resulting
    transcript is passed to the classification chain as its ``query`` input.

    Args:
        query: The latest user message to classify.
        chat_history: Earlier messages, oldest first. Each entry must be a
            dict with a ``"content"`` key. An optional ``"role"`` key (for
            example ``"user"`` or ``"assistant"``) sets the speaker label; it
            defaults to ``"user"``. ``None`` or an empty list means no history.
        chain: Optional pre-built chain exposing ``run(inputs) -> str``. When
            omitted, a new chain is built via ``get_classification_chain()``
            on each call. Pass one in to reuse it across calls.

    Returns:
        The classification chain's raw string output.
    """
    # Label each turn with its actual speaker. Labelling every turn "User"
    # would present assistant replies to the classifier as user messages.
    lines = [
        f"{str(msg.get('role') or 'user').capitalize()}: {msg['content']}"
        for msg in (chat_history or [])
    ]
    lines.append(f"User: {query}")
    # Joining all lines together avoids a stray leading newline when the
    # history is empty.
    context = "\n".join(lines)

    if chain is None:
        chain = get_classification_chain()
    # Only the single "query" input is supplied; it carries the full transcript.
    return chain.run({"query": context})