Phoenix21 commited on
Commit
5067009
·
verified ·
1 Parent(s): d25ef9b

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +9 -17
pipeline.py CHANGED
@@ -5,8 +5,14 @@ import getpass
5
  import pandas as pd
6
  from typing import Optional, Dict, Any
7
 
8
- # Correct import for Runnable
9
- from langchain.schema import Runnable
 
 
 
 
 
 
10
 
11
  from langchain.docstore.document import Document
12
  from langchain.embeddings import HuggingFaceEmbeddings
@@ -16,7 +22,6 @@ from langchain.chains import RetrievalQA
16
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
17
  import litellm
18
 
19
- # Classification/Refusal/Tailor/Cleaner
20
  from classification_chain import get_classification_chain
21
  from refusal_chain import get_refusal_chain
22
  from tailor_chain import get_tailor_chain
@@ -129,15 +134,10 @@ def do_web_search(query: str) -> str:
129
  # 6) Orchestrator function: returns a dict => {"answer": "..."}
130
  ###############################################################################
131
  def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
132
- """
133
- Called by the Runnable.
134
- inputs: { "input": <user_query>, "chat_history": <list of messages> (optional) }
135
- Output: { "answer": <final string> }
136
- """
137
  user_query = inputs["input"]
138
  chat_history = inputs.get("chat_history", [])
139
 
140
- # 1) Classification
141
  class_result = classification_chain.invoke({"query": user_query})
142
  classification = class_result.get("text", "").strip()
143
 
@@ -157,7 +157,6 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
157
  web_answer = do_web_search(user_query)
158
  else:
159
  web_answer = ""
160
-
161
  final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
162
  final_answer = tailor_chain.run({"response": final_merged}).strip()
163
  return {"answer": final_answer}
@@ -169,7 +168,6 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
169
  final_answer = tailor_chain.run({"response": final_merged}).strip()
170
  return {"answer": final_answer}
171
 
172
- # fallback
173
  refusal_text = refusal_chain.run({})
174
  final_refusal = tailor_chain.run({"response": refusal_text}).strip()
175
  return {"answer": final_refusal}
@@ -177,14 +175,8 @@ def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
177
  ###############################################################################
178
  # 7) Build a "Runnable" wrapper so .with_listeners() works
179
  ###############################################################################
180
-
181
  class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
182
- """
183
- Wraps run_with_chain_context(...) in a Runnable
184
- so that RunnableWithMessageHistory can attach listeners.
185
- """
186
  def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
187
  return run_with_chain_context(input)
188
 
189
- # Export an instance of PipelineRunnable for use in my_memory_logic.py
190
  pipeline_runnable = PipelineRunnable()
 
5
  import pandas as pd
6
  from typing import Optional, Dict, Any
7
 
8
+ # Conditional import for Runnable from available locations
9
+ try:
10
+ from langchain_core.runnables.base import Runnable
11
+ except ImportError:
12
+ try:
13
+ from langchain.runnables.base import Runnable
14
+ except ImportError:
15
+ raise ImportError("Cannot find Runnable class. Please upgrade LangChain or check your installation.")
16
 
17
  from langchain.docstore.document import Document
18
  from langchain.embeddings import HuggingFaceEmbeddings
 
22
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
23
  import litellm
24
 
 
25
  from classification_chain import get_classification_chain
26
  from refusal_chain import get_refusal_chain
27
  from tailor_chain import get_tailor_chain
 
134
  # 6) Orchestrator function: returns a dict => {"answer": "..."}
135
  ###############################################################################
136
  def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
 
 
 
 
 
137
  user_query = inputs["input"]
138
  chat_history = inputs.get("chat_history", [])
139
 
140
+ # Classification step
141
  class_result = classification_chain.invoke({"query": user_query})
142
  classification = class_result.get("text", "").strip()
143
 
 
157
  web_answer = do_web_search(user_query)
158
  else:
159
  web_answer = ""
 
160
  final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
161
  final_answer = tailor_chain.run({"response": final_merged}).strip()
162
  return {"answer": final_answer}
 
168
  final_answer = tailor_chain.run({"response": final_merged}).strip()
169
  return {"answer": final_answer}
170
 
 
171
  refusal_text = refusal_chain.run({})
172
  final_refusal = tailor_chain.run({"response": refusal_text}).strip()
173
  return {"answer": final_refusal}
 
175
  ###############################################################################
176
  # 7) Build a "Runnable" wrapper so .with_listeners() works
177
  ###############################################################################
 
178
  class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
 
 
 
 
179
  def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
180
  return run_with_chain_context(input)
181
 
 
182
  pipeline_runnable = PipelineRunnable()