# Uploaded by coderpotter using huggingface_hub (commit 7b2e5db, verified).
from langchain_core.messages import HumanMessage
from langgraph.graph import END, START, StateGraph
from research_assistant.app_logging import app_logger
from research_assistant.components.agent import Agent
from research_assistant.components.agent_tools import get_arxiv_tool, get_qa_tool
from research_assistant.components.pdfParser import pdf_parser
from research_assistant.components.planner import get_planner
from research_assistant.components.solver import get_solver
from research_assistant.components.state import ResearchSummary
from research_assistant.config.configuration import ConfigurationManager
from research_assistant.utils.state_utils import SummaryStateUtils
class ArticleSummarization:
    """Plan / tool-execution / solve workflow that summarizes one article.

    The workflow is a LangGraph ``StateGraph`` over ``ResearchSummary`` with
    three nodes: ``plan`` (``get_plan``), ``tool`` (``tool_execution``) and
    ``solve`` (``solve``). Routing out of the ``tool`` node is delegated to
    ``SummaryStateUtils.route``.
    """

    # Maps a workflow component name to the ConfigurationManager method that
    # returns that component's configuration. Stored as method *names* so the
    # lookup does not touch ``self.config`` until the name is known to be valid.
    _CONFIG_GETTERS = {
        "planner": "get_planner_config",
        "qa_tool": "get_qa_tool_config",
        "solver": "get_solver_config",
    }

    def __init__(self, file_path):
        """Store the article path and create the configuration/state helpers.

        Args:
            file_path: Path of the article; it is handed to ``pdf_parser``
                when ``get_summary`` runs.
        """
        self.article_path = file_path
        self.config = ConfigurationManager()
        self.summary_utils = SummaryStateUtils()

    def get_model(self, component: str):
        """Return the model requested for a component of the workflow.

        Args:
            component: One of ``"planner"``, ``"qa_tool"`` or ``"solver"``.

        Returns:
            The result of ``Agent(model_name).get_model()`` for the model name
            found in the component's configuration.

        Raises:
            ValueError: If ``component`` is not one of the known names.
        """
        getter_name = self._CONFIG_GETTERS.get(component)
        if getter_name is None:
            raise ValueError("Invalid component name for getting the Model")
        component_config = getattr(self.config, getter_name)()
        return Agent(component_config.model_name).get_model()

    def get_plan(self, state: ResearchSummary):
        """Generate the plan for the article; attached to the ``plan`` node.

        Args:
            state: Workflow state; ``state["article_text"]`` is passed to the
                planner.

        Returns:
            A state update with the keys ``plan_string``, ``dependencies``,
            ``arguments`` and ``tools`` taken from the planner response.

        Raises:
            ValueError: If the planner returned a different number of tools
                and arguments, i.e. the plan could not be parsed consistently.
        """
        planner = get_planner(llm=self.get_model("planner"))
        response = planner.invoke({"article_text": state["article_text"]})

        # Every tool call needs exactly one argument; a mismatch means the
        # plan string was not parsed into aligned steps.
        if len(response.tools) != len(response.arguments):
            raise ValueError("The Plan string is not parsed properly")

        app_logger.info(f"The plan produced is: {response.plan_str}")
        return {
            "plan_string": response.plan_str,
            "dependencies": response.dependencies,
            "arguments": response.arguments,
            "tools": response.tools,
        }

    def tool_execution(self, state: ResearchSummary):
        """Worker node that executes the current step of the plan.

        The step number comes from ``SummaryStateUtils.get_current_task`` and
        is treated as 1-based: step ``k`` uses ``tools[k - 1]`` and
        ``arguments[k - 1]``.

        Args:
            state: Workflow state holding ``arguments``, ``tools`` and,
                optionally, the ``results`` gathered so far.

        Returns:
            A state update ``{"results": ...}`` containing the previous
            results plus the stringified result of this step, keyed by the
            step number.

        Raises:
            ValueError: If the tool named for the current step is neither
                ``"Arxiv"`` nor ``"LLM"``.
        """
        current_step = self.summary_utils.get_current_task(state)
        arguments, tools = state["arguments"], state["tools"]

        # Work on a copy so the incoming state is not mutated in place; the
        # node's return value is what carries the update.
        previous_results = (state["results"] or {}) if "results" in state else {}
        results_dict = dict(previous_results)

        step_index = current_step - 1
        tool_name = tools[step_index]

        if tool_name == "Arxiv":
            result = get_arxiv_tool().run(arguments[step_index])
        elif tool_name == "LLM":
            qa_tool = get_qa_tool(llm=self.get_model("qa_tool"))
            result = qa_tool.invoke(
                {
                    "question": arguments[step_index],
                    "context": self.summary_utils.get_current_dependencies(
                        state, current_step
                    ),
                }
            )
        else:
            raise ValueError(
                f"Unknown tool {tool_name!r} requested for step {current_step}; "
                "expected 'Arxiv' or 'LLM'"
            )

        # Store the result in the results dictionary with the step number as key.
        results_dict[current_step] = str(result)
        return {"results": results_dict}

    def solve(self, state: ResearchSummary):
        """Produce the final answer from the tool results; the ``solve`` node.

        Args:
            state: Workflow state; it is condensed by
                ``SummaryStateUtils.get_plan_results`` before being passed to
                the solver.

        Returns:
            A state update ``{"result": answer}`` where ``answer`` is the
            ``answer`` attribute of the solver response.
        """
        solver = get_solver(llm=self.get_model("solver"))
        solver_response = solver.invoke(self.summary_utils.get_plan_results(state))
        return {"result": solver_response.answer}

    def get_graph(self):
        """Build and compile the execution graph of the summarization workflow.

        Returns:
            The compiled graph: ``START -> plan -> tool``, a conditional edge
            out of ``tool`` decided by ``SummaryStateUtils.route``, and
            ``solve -> END``.
        """
        graph = StateGraph(ResearchSummary)

        graph.add_node("plan", self.get_plan)
        graph.add_node("tool", self.tool_execution)
        graph.add_node("solve", self.solve)

        graph.add_edge(START, "plan")
        graph.add_edge("plan", "tool")
        graph.add_conditional_edges("tool", self.summary_utils.route)
        graph.add_edge("solve", END)

        return graph.compile()

    def get_summary(self):
        """Run the workflow on the article and return the final summary.

        The article is parsed with ``pdf_parser`` and streamed through the
        compiled graph; the last streamed event is expected to be the output
        of the ``solve`` node.

        Returns:
            The ``result`` value produced by the ``solve`` node.

        Raises:
            RuntimeError: If the graph stream produced no events at all.
        """
        app = self.get_graph()

        final_output = None
        for event in app.stream({"article_text": pdf_parser(self.article_path)}):
            final_output = event

        # Without this guard an empty stream would surface as an opaque
        # UnboundLocalError on the return line.
        if final_output is None:
            raise RuntimeError(
                "The summarization workflow finished without producing any output"
            )
        return final_output["solve"]["result"]