Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,14 +1,18 @@
|
|
1 |
import gradio as gr
|
2 |
from langgraph.graph import StateGraph, MessagesState, START, END
|
3 |
from langgraph.types import Command
|
4 |
-
from langchain_core.messages import
|
5 |
from langgraph.prebuilt import create_react_agent
|
6 |
from langchain_anthropic import ChatAnthropic
|
7 |
import os
|
8 |
-
|
|
|
|
|
9 |
|
|
|
|
|
10 |
|
11 |
-
#
|
12 |
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
|
13 |
|
14 |
def make_system_prompt(suffix: str) -> str:
|
@@ -22,35 +26,28 @@ def make_system_prompt(suffix: str) -> str:
|
|
22 |
f"\n{suffix}"
|
23 |
)
|
24 |
|
25 |
-
# Research agent and node
|
26 |
def research_node(state: MessagesState) -> Command[str]:
|
27 |
agent = create_react_agent(
|
28 |
llm,
|
29 |
-
tools=[],
|
30 |
state_modifier=make_system_prompt("You can only do research.")
|
31 |
)
|
32 |
result = agent.invoke(state)
|
33 |
goto = END if "FINAL ANSWER" in result["messages"][-1].content else "chart_generator"
|
34 |
-
result["messages"][-1] = HumanMessage(
|
35 |
-
content=result["messages"][-1].content, name="researcher"
|
36 |
-
)
|
37 |
return Command(update={"messages": result["messages"]}, goto=goto)
|
38 |
|
39 |
-
# Chart generator agent and node
|
40 |
def chart_node(state: MessagesState) -> Command[str]:
|
41 |
agent = create_react_agent(
|
42 |
llm,
|
43 |
-
tools=[],
|
44 |
state_modifier=make_system_prompt("You can only generate charts.")
|
45 |
)
|
46 |
result = agent.invoke(state)
|
47 |
goto = END if "FINAL ANSWER" in result["messages"][-1].content else "researcher"
|
48 |
-
result["messages"][-1] = HumanMessage(
|
49 |
-
content=result["messages"][-1].content, name="chart_generator"
|
50 |
-
)
|
51 |
return Command(update={"messages": result["messages"]}, goto=goto)
|
52 |
|
53 |
-
# Initialize the LangGraph workflow
|
54 |
workflow = StateGraph(MessagesState)
|
55 |
workflow.add_node("researcher", research_node)
|
56 |
workflow.add_node("chart_generator", chart_node)
|
@@ -59,7 +56,32 @@ workflow.add_edge("researcher", "chart_generator")
|
|
59 |
workflow.add_edge("chart_generator", END)
|
60 |
graph = workflow.compile()
|
61 |
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def run_langgraph(user_input):
|
64 |
events = graph.stream(
|
65 |
{"messages": [("user", user_input)]},
|
@@ -73,20 +95,27 @@ def run_langgraph(user_input):
|
|
73 |
|
74 |
return final_message or "No output generated"
|
75 |
|
76 |
-
|
77 |
-
# Create Gradio interface
|
78 |
def process_input(user_input):
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
interface = gr.Interface(
|
83 |
fn=process_input,
|
84 |
inputs="text",
|
85 |
-
outputs=
|
|
|
|
|
|
|
86 |
title="LangGraph Research Automation",
|
87 |
-
description="Enter your research task (e.g., 'Get GDP data for the USA over the past 5 years and create a chart.')
|
88 |
)
|
89 |
|
90 |
-
# Launch the Gradio interface
|
91 |
if __name__ == "__main__":
|
92 |
interface.launch()
|
|
|
|
1 |
import gradio as gr
|
2 |
from langgraph.graph import StateGraph, MessagesState, START, END
|
3 |
from langgraph.types import Command
|
4 |
+
from langchain_core.messages import HumanMessage
|
5 |
from langgraph.prebuilt import create_react_agent
|
6 |
from langchain_anthropic import ChatAnthropic
|
7 |
import os
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from io import BytesIO
|
10 |
+
import base64
|
11 |
|
12 |
+
# Load API Key
|
13 |
+
os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY")
|
14 |
|
15 |
+
# LangGraph setup
|
16 |
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
|
17 |
|
18 |
def make_system_prompt(suffix: str) -> str:
|
|
|
26 |
f"\n{suffix}"
|
27 |
)
|
28 |
|
|
|
29 |
def research_node(state: MessagesState) -> Command[str]:
|
30 |
agent = create_react_agent(
|
31 |
llm,
|
32 |
+
tools=[],
|
33 |
state_modifier=make_system_prompt("You can only do research.")
|
34 |
)
|
35 |
result = agent.invoke(state)
|
36 |
goto = END if "FINAL ANSWER" in result["messages"][-1].content else "chart_generator"
|
37 |
+
result["messages"][-1] = HumanMessage(content=result["messages"][-1].content, name="researcher")
|
|
|
|
|
38 |
return Command(update={"messages": result["messages"]}, goto=goto)
|
39 |
|
|
|
40 |
def chart_node(state: MessagesState) -> Command[str]:
|
41 |
agent = create_react_agent(
|
42 |
llm,
|
43 |
+
tools=[],
|
44 |
state_modifier=make_system_prompt("You can only generate charts.")
|
45 |
)
|
46 |
result = agent.invoke(state)
|
47 |
goto = END if "FINAL ANSWER" in result["messages"][-1].content else "researcher"
|
48 |
+
result["messages"][-1] = HumanMessage(content=result["messages"][-1].content, name="chart_generator")
|
|
|
|
|
49 |
return Command(update={"messages": result["messages"]}, goto=goto)
|
50 |
|
|
|
51 |
workflow = StateGraph(MessagesState)
|
52 |
workflow.add_node("researcher", research_node)
|
53 |
workflow.add_node("chart_generator", chart_node)
|
|
|
56 |
workflow.add_edge("chart_generator", END)
|
57 |
graph = workflow.compile()
|
58 |
|
59 |
+
def extract_chart_data(text):
|
60 |
+
"""
|
61 |
+
Try to extract something like:
|
62 |
+
2018: 20
|
63 |
+
2019: 21.5
|
64 |
+
2020: 18
|
65 |
+
"""
|
66 |
+
import re
|
67 |
+
matches = re.findall(r'(\d{4})\s*[:\-]?\s*\$?([\d\.]+)', text)
|
68 |
+
if not matches:
|
69 |
+
return None, None
|
70 |
+
years = [m[0] for m in matches]
|
71 |
+
values = [float(m[1]) for m in matches]
|
72 |
+
return years, values
|
73 |
+
|
74 |
+
def generate_plot(years, values):
|
75 |
+
fig, ax = plt.subplots()
|
76 |
+
ax.bar(years, values)
|
77 |
+
ax.set_title("Generated Chart")
|
78 |
+
ax.set_xlabel("Year")
|
79 |
+
ax.set_ylabel("Value")
|
80 |
+
buf = BytesIO()
|
81 |
+
plt.savefig(buf, format="png")
|
82 |
+
buf.seek(0)
|
83 |
+
return buf
|
84 |
+
|
85 |
def run_langgraph(user_input):
|
86 |
events = graph.stream(
|
87 |
{"messages": [("user", user_input)]},
|
|
|
95 |
|
96 |
return final_message or "No output generated"
|
97 |
|
|
|
|
|
98 |
def process_input(user_input):
|
99 |
+
result_text = run_langgraph(user_input)
|
100 |
+
years, values = extract_chart_data(result_text)
|
101 |
+
|
102 |
+
if years and values:
|
103 |
+
chart = generate_plot(years, values)
|
104 |
+
return result_text, chart
|
105 |
+
else:
|
106 |
+
return result_text, None
|
107 |
|
108 |
interface = gr.Interface(
|
109 |
fn=process_input,
|
110 |
inputs="text",
|
111 |
+
outputs=[
|
112 |
+
gr.Textbox(label="Generated Response"),
|
113 |
+
gr.Image(type="pil", label="Generated Chart")
|
114 |
+
],
|
115 |
title="LangGraph Research Automation",
|
116 |
+
description="Enter your research task (e.g., 'Get GDP data for the USA over the past 5 years and create a chart.')"
|
117 |
)
|
118 |
|
|
|
119 |
if __name__ == "__main__":
|
120 |
interface.launch()
|
121 |
+
|