MaryamKarimi080 committed (verified)
Commit faf34fe · 1 Parent(s): 8a603a3

Update scripts/router_chain.py

Files changed (1): scripts/router_chain.py +38 -13
scripts/router_chain.py CHANGED
@@ -1,32 +1,57 @@
 from typing import Dict, Any
 from langchain_openai import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate
+from langchain.schema import StrOutputParser
 from scripts.rag_chat import build_general_qa_chain
 
 def build_router_chain(model_name=None):
     general_qa = build_general_qa_chain(model_name=model_name)
     llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)
 
+    # This prompt asks the LLM to choose which "mode" to use
+    router_prompt = ChatPromptTemplate.from_template("""
+    You are a routing assistant for a chatbot.
+    Classify the following user request into one of these categories:
+    - "code" for programming or debugging
+    - "summarize" for summary requests
+    - "calculate" for math or numeric calculations
+    - "general" for general Q&A using course files
+
+    Return ONLY the category word.
+
+    User request: {input}
+    """)
+
+    router_chain = router_prompt | llm | StrOutputParser()
+
     class Router:
         def invoke(self, input_dict: Dict[str, Any]):
-            text = input_dict.get("input", "").lower()
-            if "code" in text or "program" in text or "debug" in text:
+            category = router_chain.invoke({"input": input_dict["input"]}).strip().lower()
+
+            print(f"[ROUTER] User query routed to category: {category}")
+
+            if category == "code":
                 prompt = ChatPromptTemplate.from_template(
                     "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:"
                 )
-                chain = prompt | llm
-                return {"result": chain.invoke({"input": input_dict["input"]}).content}
-            elif "summarize" in text or "summary" in text:
+                chain = prompt | llm | StrOutputParser()
+                return {"result": chain.invoke({"input": input_dict["input"]})}
+
+            elif category == "summarize":
                 prompt = ChatPromptTemplate.from_template(
                     "Provide a concise summary about: {input}\nSummary:"
                 )
-                chain = prompt | llm
-                return {"result": chain.invoke({"input": input_dict["input"]}).content}
-            elif "calculate" in text or any(char.isdigit() for char in text):
-                return {"result": "For calculations, please ask a specific calculation or provide more context."}
-            else:
-                # Use RAG chain
-                result = general_qa({"query": input_dict["input"]})
-                return result
+                chain = prompt | llm | StrOutputParser()
+                return {"result": chain.invoke({"input": input_dict["input"]})}
+
+            elif category == "calculate":
+                prompt = ChatPromptTemplate.from_template(
+                    "Solve the following calculation step-by-step:\n{input}"
+                )
+                chain = prompt | llm | StrOutputParser()
+                return {"result": chain.invoke({"input": input_dict["input"]})}
+
+            else:  # "general"
+                return general_qa({"query": input_dict["input"]})
 
     return Router()
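
For reference, a minimal usage sketch of the updated router (not part of this commit). It assumes OPENAI_API_KEY is set, the project root is on PYTHONPATH, and scripts/rag_chat.py is importable; the example queries are hypothetical, and which branch fires depends on the LLM's classification.

# Usage sketch, illustrative only; not included in this commit.
# Assumes OPENAI_API_KEY is exported and the repo root is on PYTHONPATH.
from scripts.router_chain import build_router_chain

router = build_router_chain(model_name="gpt-4o-mini")

# The router first classifies the request, then dispatches to the matching branch.
print(router.invoke({"input": "Write a Python function that reverses a string"}))  # likely "code"
print(router.invoke({"input": "Summarize the lecture on vector databases"}))       # likely "summarize"
print(router.invoke({"input": "What is 15% of 240?"}))                             # likely "calculate"
print(router.invoke({"input": "What topics does the course cover?"}))              # likely "general" (RAG)

Note that the "code", "summarize", and "calculate" branches return {"result": <string>}, while the "general" branch returns whatever build_general_qa_chain's RAG chain returns for {"query": ...}, so callers may need to handle both shapes.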