from langchain_core.messages import HumanMessage
from langgraph.graph import END, START, StateGraph

from research_assistant.app_logging import app_logger
from research_assistant.components.agent import Agent
from research_assistant.components.agent_tools import get_arxiv_tool, get_qa_tool
from research_assistant.components.pdfParser import pdf_parser
from research_assistant.components.planner import get_planner
from research_assistant.components.solver import get_solver
from research_assistant.components.state import ResearchSummary
from research_assistant.config.configuration import ConfigurationManager
from research_assistant.utils.state_utils import SummaryStateUtils
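
# Note: ResearchSummary (imported above) is assumed to be a state schema carrying at
# least the keys read and written below: "article_text", "plan_string", "dependencies",
# "arguments", "tools", "results", and "result".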


class ArticleSummarization:
    def __init__(self, file_path):
        self.article_path = file_path
        self.config = ConfigurationManager()
        self.summary_utils = SummaryStateUtils()

    # Returns the model instance configured for the requested workflow component.
    def get_model(self, component: str):
        if component == "planner":
            config = self.config.get_planner_config()
        elif component == "qa_tool":
            config = self.config.get_qa_tool_config()
        elif component == "solver":
            config = self.config.get_solver_config()
        else:
            raise ValueError("Invalid component name for getting the Model")
        agent = Agent(config.model_name)
        return agent.get_model()

    # Generates the plan for the given article using the planner. Attached to the
    # "plan" node of the graph.
    def get_plan(self, state: ResearchSummary):
        response = get_planner(llm=self.get_model("planner")).invoke(
            {"article_text": state["article_text"]}
        )
        if len(response.tools) != len(response.arguments):
            raise ValueError("The Plan string is not parsed properly")
        app_logger.info(f"The plan produced is: {response.plan_str}")
        return {
            "plan_string": response.plan_str,
            "dependencies": response.dependencies,
            "arguments": response.arguments,
            "tools": response.tools,
        }

    # Executes the tool for the current step of the plan. Attached to the "tool" node.
    def tool_execution(self, state: ResearchSummary):
        """Worker node that executes the tools of a given plan."""
        current_step = self.summary_utils.get_current_task(state)
        arg, tools = state["arguments"], state["tools"]
        results_dict = (state["results"] or {}) if "results" in state else {}
        # Call the tool chosen by the planner for this step.
        if tools[current_step - 1] == "Arxiv":
            result = get_arxiv_tool().run(arg[current_step - 1])
        elif tools[current_step - 1] == "LLM":
            result = get_qa_tool(llm=self.get_model("qa_tool")).invoke(
                {
                    "question": arg[current_step - 1],
                    "context": self.summary_utils.get_current_dependencies(
                        state, current_step
                    ),
                }
            )
        else:
            raise ValueError(f"Unknown tool in plan: {tools[current_step - 1]}")
        # Store the result in the results dictionary with the step number as key.
        results_dict[current_step] = str(result)
        return {"results": results_dict}

    # Generates the final answer from the results of the tool executions.
    # Attached to the "solve" node.
    def solve(self, state: ResearchSummary):
        return {
            "result": get_solver(llm=self.get_model("solver"))
            .invoke(self.summary_utils.get_plan_results(state))
            .answer
        }

    # Builds the execution graph for the article summarization workflow:
    # START -> plan -> tool -(route)-> tool | solve -> END
    def get_graph(self):
        graph = StateGraph(ResearchSummary)
        graph.add_node("plan", self.get_plan)
        graph.add_node("tool", self.tool_execution)
        graph.add_node("solve", self.solve)
        graph.add_edge(START, "plan")
        graph.add_edge("plan", "tool")
        # route() is expected to return the next node: "tool" while planned steps
        # remain, "solve" once every step has a result.
        graph.add_conditional_edges("tool", self.summary_utils.route)
        graph.add_edge("solve", END)
        return graph.compile()

    # Runs the full workflow on the parsed article and returns the final summary.
    def get_summary(self):
        app = self.get_graph()
        final_output = None
        for s in app.stream({"article_text": pdf_parser(self.article_path)}):
            final_output = s
        return final_output["solve"]["result"]
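

# Example usage (illustrative sketch; "paper.pdf" is a hypothetical local PDF path):
if __name__ == "__main__":
    summarizer = ArticleSummarization("paper.pdf")
    summary = summarizer.get_summary()
    app_logger.info(f"Final summary: {summary}")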