MaryamKarimi080 commited on
Commit
3310b03
·
verified ·
1 Parent(s): 83380dd

Update scripts/router_chain.py

Browse files
Files changed (1) hide show
  1. scripts/router_chain.py +39 -10
scripts/router_chain.py CHANGED
@@ -44,23 +44,52 @@ User request: {input}
44
  # chain = prompt | llm | StrOutputParser()
45
  # return {"result": chain.invoke({"input": input_dict["input"]})}
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  elif category == "summarize":
48
- # 1. Use RAG to retrieve relevant docs
49
  rag_result = general_qa({"query": input_dict["input"]})
50
 
51
- # 2. Extract docs and prepare text
52
- source_docs = rag_result.get("source_documents", [])
53
- combined_text = "\n\n".join([doc.page_content for doc in source_docs])
54
 
55
- # 3. Run the summarizer chain on the retrieved text
 
56
  from scripts.summarizer import get_summarizer
 
57
  summarizer_chain = get_summarizer()
58
- summary = summarizer_chain.run(combined_text)
59
 
60
- # 4. Add sources if any
61
- sources = list({str(doc.metadata.get("source", "unknown")) for doc in source_docs})
62
- if sources:
63
- summary += f"\n\n📚 Sources: {', '.join(sources)}"
 
 
 
 
 
 
 
 
64
 
65
  return {"result": summary}
66
 
 
44
  # chain = prompt | llm | StrOutputParser()
45
  # return {"result": chain.invoke({"input": input_dict["input"]})}
46
 
47
+ #elif category == "summarize":
48
+ # # 1. Use RAG to retrieve relevant docs
49
+ # rag_result = general_qa({"query": input_dict["input"]})
50
+
51
+ # # 2. Extract docs and prepare text
52
+ # source_docs = rag_result.get("source_documents", [])
53
+ # combined_text = "\n\n".join([doc.page_content for doc in source_docs])
54
+
55
+ # # 3. Run the summarizer chain on the retrieved text
56
+ # from scripts.summarizer import get_summarizer
57
+ # summarizer_chain = get_summarizer()
58
+ # summary = summarizer_chain.run(combined_text)
59
+
60
+ # # 4. Add sources if any
61
+ # sources = list({str(doc.metadata.get("source", "unknown")) for doc in source_docs})
62
+ # if sources:
63
+ # summary += f"\n\n📚 Sources: {', '.join(sources)}"
64
+
65
+ # return {"result": summary}
66
+
67
+
68
  elif category == "summarize":
69
+ # 1) Retrieve relevant documents via your existing RAG chain
70
  rag_result = general_qa({"query": input_dict["input"]})
71
 
72
+ # 2) Get the retrieved docs (already LangChain Document objects)
73
+ source_docs = rag_result.get("source_documents", []) or []
 
74
 
75
+ # 3) Build the summarizer and prepare the docs list
76
+ from langchain.docstore.document import Document
77
  from scripts.summarizer import get_summarizer
78
+
79
  summarizer_chain = get_summarizer()
 
80
 
81
+ # If retrieval returned nothing, fall back to summarizing the user’s text
82
+ docs = source_docs if source_docs else [Document(page_content=input_dict["input"])]
83
+
84
+ # 4) Summarize load_summarize_chain returns {"output_text": "..."}
85
+ out = summarizer_chain.invoke(docs)
86
+ summary = out["output_text"] if isinstance(out, dict) and "output_text" in out else str(out)
87
+
88
+ # 5) Append sources (only if we actually had retrieved docs)
89
+ if source_docs:
90
+ sources = sorted({str(d.metadata.get("source", "unknown")) for d in source_docs})
91
+ if sources:
92
+ summary += f"\n\n📚 Sources: {', '.join(sources)}"
93
 
94
  return {"result": summary}
95