Update scripts/router_chain.py
Browse files - scripts/router_chain.py +39 -10
scripts/router_chain.py
CHANGED
@@ -44,23 +44,52 @@ User request: {input}
|
|
44 |
# chain = prompt | llm | StrOutputParser()
|
45 |
# return {"result": chain.invoke({"input": input_dict["input"]})}
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
elif category == "summarize":
|
48 |
-
# 1. Use RAG to retrieve relevant docs
|
49 |
rag_result = general_qa({"query": input_dict["input"]})
|
50 |
|
51 |
-
# 2. Extract docs and prepare text
|
52 |
-
source_docs = rag_result.get("source_documents", [])
|
53 |
-
combined_text = "\n\n".join([doc.page_content for doc in source_docs])
|
54 |
|
55 |
-
# 3. Run the summarizer chain on the retrieved text
|
|
|
56 |
from scripts.summarizer import get_summarizer
|
|
|
57 |
summarizer_chain = get_summarizer()
|
58 |
-
summary = summarizer_chain.run(combined_text)
|
59 |
|
60 |
-
# 4. Add sources if any
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
return {"result": summary}
|
66 |
|
|
|
44 |
# chain = prompt | llm | StrOutputParser()
|
45 |
# return {"result": chain.invoke({"input": input_dict["input"]})}
|
46 |
|
47 |
+
#elif category == "summarize":
|
48 |
+
# # 1. Use RAG to retrieve relevant docs
|
49 |
+
# rag_result = general_qa({"query": input_dict["input"]})
|
50 |
+
|
51 |
+
# # 2. Extract docs and prepare text
|
52 |
+
# source_docs = rag_result.get("source_documents", [])
|
53 |
+
# combined_text = "\n\n".join([doc.page_content for doc in source_docs])
|
54 |
+
|
55 |
+
# # 3. Run the summarizer chain on the retrieved text
|
56 |
+
# from scripts.summarizer import get_summarizer
|
57 |
+
# summarizer_chain = get_summarizer()
|
58 |
+
# summary = summarizer_chain.run(combined_text)
|
59 |
+
|
60 |
+
# # 4. Add sources if any
|
61 |
+
# sources = list({str(doc.metadata.get("source", "unknown")) for doc in source_docs})
|
62 |
+
# if sources:
|
63 |
+
# summary += f"\n\n📚 Sources: {', '.join(sources)}"
|
64 |
+
|
65 |
+
# return {"result": summary}
|
66 |
+
|
67 |
+
|
68 |
elif category == "summarize":
|
69 |
+
# 1) Retrieve relevant documents via your existing RAG chain
|
70 |
rag_result = general_qa({"query": input_dict["input"]})
|
71 |
|
72 |
+
# 2) Get the retrieved docs (already LangChain Document objects)
|
73 |
+
source_docs = rag_result.get("source_documents", []) or []
|
|
|
74 |
|
75 |
+
# 3) Build the summarizer and prepare the docs list
|
76 |
+
from langchain.docstore.document import Document
|
77 |
from scripts.summarizer import get_summarizer
|
78 |
+
|
79 |
summarizer_chain = get_summarizer()
|
|
|
80 |
|
81 |
+
# If retrieval returned nothing, fall back to summarizing the user’s text
|
82 |
+
docs = source_docs if source_docs else [Document(page_content=input_dict["input"])]
|
83 |
+
|
84 |
+
# 4) Summarize — load_summarize_chain returns {"output_text": "..."}
|
85 |
+
out = summarizer_chain.invoke(docs)
|
86 |
+
summary = out["output_text"] if isinstance(out, dict) and "output_text" in out else str(out)
|
87 |
+
|
88 |
+
# 5) Append sources (only if we actually had retrieved docs)
|
89 |
+
if source_docs:
|
90 |
+
sources = sorted({str(d.metadata.get("source", "unknown")) for d in source_docs})
|
91 |
+
if sources:
|
92 |
+
summary += f"\n\n📚 Sources: {', '.join(sources)}"
|
93 |
|
94 |
return {"result": summary}
|
95 |
|