Update pipeline.py
Browse files- pipeline.py +9 -17
pipeline.py
CHANGED
@@ -5,8 +5,14 @@ import getpass
|
|
5 |
import pandas as pd
|
6 |
from typing import Optional, Dict, Any
|
7 |
|
8 |
-
#
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
from langchain.docstore.document import Document
|
12 |
from langchain.embeddings import HuggingFaceEmbeddings
|
@@ -16,7 +22,6 @@ from langchain.chains import RetrievalQA
|
|
16 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
|
17 |
import litellm
|
18 |
|
19 |
-
# Classification/Refusal/Tailor/Cleaner
|
20 |
from classification_chain import get_classification_chain
|
21 |
from refusal_chain import get_refusal_chain
|
22 |
from tailor_chain import get_tailor_chain
|
@@ -129,15 +134,10 @@ def do_web_search(query: str) -> str:
|
|
129 |
# 6) Orchestrator function: returns a dict => {"answer": "..."}
|
130 |
###############################################################################
|
131 |
def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
132 |
-
"""
|
133 |
-
Called by the Runnable.
|
134 |
-
inputs: { "input": <user_query>, "chat_history": <list of messages> (optional) }
|
135 |
-
Output: { "answer": <final string> }
|
136 |
-
"""
|
137 |
user_query = inputs["input"]
|
138 |
chat_history = inputs.get("chat_history", [])
|
139 |
|
140 |
-
#
|
141 |
class_result = classification_chain.invoke({"query": user_query})
|
142 |
classification = class_result.get("text", "").strip()
|
143 |
|
@@ -157,7 +157,6 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
157 |
web_answer = do_web_search(user_query)
|
158 |
else:
|
159 |
web_answer = ""
|
160 |
-
|
161 |
final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
|
162 |
final_answer = tailor_chain.run({"response": final_merged}).strip()
|
163 |
return {"answer": final_answer}
|
@@ -169,7 +168,6 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
169 |
final_answer = tailor_chain.run({"response": final_merged}).strip()
|
170 |
return {"answer": final_answer}
|
171 |
|
172 |
-
# fallback
|
173 |
refusal_text = refusal_chain.run({})
|
174 |
final_refusal = tailor_chain.run({"response": refusal_text}).strip()
|
175 |
return {"answer": final_refusal}
|
@@ -177,14 +175,8 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
177 |
###############################################################################
|
178 |
# 7) Build a "Runnable" wrapper so .with_listeners() works
|
179 |
###############################################################################
|
180 |
-
|
181 |
class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
|
182 |
-
"""
|
183 |
-
Wraps run_with_chain_context(...) in a Runnable
|
184 |
-
so that RunnableWithMessageHistory can attach listeners.
|
185 |
-
"""
|
186 |
def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
|
187 |
return run_with_chain_context(input)
|
188 |
|
189 |
-
# Export an instance of PipelineRunnable for use in my_memory_logic.py
|
190 |
pipeline_runnable = PipelineRunnable()
|
|
|
5 |
import pandas as pd
|
6 |
from typing import Optional, Dict, Any
|
7 |
|
8 |
+
# Conditional import for Runnable from available locations
|
9 |
+
try:
|
10 |
+
from langchain_core.runnables.base import Runnable
|
11 |
+
except ImportError:
|
12 |
+
try:
|
13 |
+
from langchain.runnables.base import Runnable
|
14 |
+
except ImportError:
|
15 |
+
raise ImportError("Cannot find Runnable class. Please upgrade LangChain or check your installation.")
|
16 |
|
17 |
from langchain.docstore.document import Document
|
18 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
22 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
|
23 |
import litellm
|
24 |
|
|
|
25 |
from classification_chain import get_classification_chain
|
26 |
from refusal_chain import get_refusal_chain
|
27 |
from tailor_chain import get_tailor_chain
|
|
|
134 |
# 6) Orchestrator function: returns a dict => {"answer": "..."}
|
135 |
###############################################################################
|
136 |
def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
|
|
|
|
|
|
|
|
|
137 |
user_query = inputs["input"]
|
138 |
chat_history = inputs.get("chat_history", [])
|
139 |
|
140 |
+
# Classification step
|
141 |
class_result = classification_chain.invoke({"query": user_query})
|
142 |
classification = class_result.get("text", "").strip()
|
143 |
|
|
|
157 |
web_answer = do_web_search(user_query)
|
158 |
else:
|
159 |
web_answer = ""
|
|
|
160 |
final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
|
161 |
final_answer = tailor_chain.run({"response": final_merged}).strip()
|
162 |
return {"answer": final_answer}
|
|
|
168 |
final_answer = tailor_chain.run({"response": final_merged}).strip()
|
169 |
return {"answer": final_answer}
|
170 |
|
|
|
171 |
refusal_text = refusal_chain.run({})
|
172 |
final_refusal = tailor_chain.run({"response": refusal_text}).strip()
|
173 |
return {"answer": final_refusal}
|
|
|
175 |
###############################################################################
|
176 |
# 7) Build a "Runnable" wrapper so .with_listeners() works
|
177 |
###############################################################################
|
|
|
178 |
class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
|
|
|
|
|
|
|
|
|
179 |
def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
|
180 |
return run_with_chain_context(input)
|
181 |
|
|
|
182 |
pipeline_runnable = PipelineRunnable()
|