File size: 4,882 Bytes
0169c8b
 
6589e60
 
 
 
 
b31f6f5
6589e60
 
 
 
 
2944273
a35f64b
d70f60e
 
 
 
 
 
 
 
 
 
2944273
6589e60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27bb768
bcc9a7a
27bb768
6589e60
 
 
 
 
5b69170
 
6589e60
 
 
 
 
 
 
 
b31f6f5
 
 
6589e60
 
7d3e0da
 
 
6589e60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0169c8b
f229ac4
89418a5
0169c8b
d70f60e
0169c8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr

# cell 1
from typing import Annotated
from langchain_experimental.tools import PythonREPLTool, PythonAstREPLTool
import pandas as pd
import statsmodels as sm
import os

# Input data for the whole script. The HOUST.csv line is an alternative input
# that is currently disabled; gen_plot below reads the DATE and USSTHPI columns.
# df = pd.read_csv("HOUST.csv")
df = pd.read_csv("USSTHPI.csv")

# REPL tool handed to the agent; `df` is the only name pre-bound in its locals.
python_repl_tool = PythonAstREPLTool(locals=dict(df=df))

import matplotlib.pyplot as nplt
def gen_plot(name):
    """Plot the USSTHPI series from the module-level ``df`` and save it as a PNG.

    Args:
        name: Unused. The function ignores its argument.

    Returns:
        The relative path of the written image, ``"plots/plot.png"``.
    """
    out_path = 'plots/plot.png'
    # savefig does not create missing directories, so make sure the target
    # folder exists before drawing anything.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    fig = nplt.figure(figsize=(10, 5))
    try:
        nplt.plot(df['DATE'], df['USSTHPI'], label='USSTHPI')
        nplt.title('Time Series of USSTHPI')
        nplt.xlabel('Date')
        nplt.ylabel('USSTHPI')
        nplt.legend()
        nplt.xticks(rotation=45)
        nplt.tight_layout()
        nplt.savefig(out_path)
    finally:
        # pyplot keeps every figure registered until it is closed; without
        # this, each call would leave one more figure in memory.
        nplt.close(fig)
    return out_path

# cell 2
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
import functools
import operator
from typing import Sequence, TypedDict

# System prompt shared by every agent built with create_agent. The `{dhead}`
# slot is filled once, at import time, with the markdown rendering of the first
# rows of `df`.
# NOTE(review): the prompt text contains typos ("do not have use",
# "exporatory"); they are left as-is here because this is runtime text.
system_prompt = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
                It is important to understand the attributes of the dataframe before working with it. This is the result of running `df.head().to_markdown()`

                <df>
                {dhead}
                </df>

                You are not meant to use only these rows to answer questions - they are meant as a way of telling you about the shape and schema of the dataframe. 
                You also do not have use only the information here to answer questions - you can run intermediate queries to do exporatory data analysis to give you more information as needed. """.format(dhead=df.head().to_markdown())

# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    """Shared state dictionary used as the schema of the StateGraph built in this file."""
    # The annotation tells the graph that new messages will always be added to the current states
    # (operator.add is the combiner attached through Annotated).
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The 'next' field indicates where to route to next
    # NOTE(review): nothing visible in this file sets or reads `next` - confirm it is still needed.
    next: str

# part of the problem might be that I'm passing a PromptTemplate object for the system_prompt here
# not everything needs to be an openai tools agent
def create_agent(llm: ChatOpenAI, tools: list, task: str):
    """Build an OpenAI-tools agent executor dedicated to one fixed task.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call; the same list is given to the executor.
        task: Instruction inserted as a human message right after the system prompt.

    Returns:
        An ``AgentExecutor`` created with ``verbose=True`` and
        ``return_intermediate_steps=True``.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            # The module-level system_prompt has already been formatted with
            # the dataframe preview. Pass it as a concrete message (like `task`
            # below) so literal braces in that preview are never re-parsed as
            # template variables.
            SystemMessage(content=system_prompt),
            HumanMessage(content=task),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    # verbose output and intermediate steps are kept on for debugging
    executor = AgentExecutor(agent=agent, tools=tools, verbose=True, return_intermediate_steps=True)
    return executor

# AIMessage will have all kinds of metadata, so treat it all as HumanMessage I suppose?
def agent_node(state: AgentState, agent, name):
    """Run ``agent`` on the current state and wrap its result as a state update.

    Args:
        state: Current graph state, handed to the agent unchanged.
        agent: Any object exposing ``invoke(state)``.
        name: Node label. Unused at present; only the disabled variant noted
            below would attach it to the produced message.

    Returns:
        ``{"messages": [result]}`` where ``result`` is exactly what
        ``agent.invoke`` returned.
    """
    # NOTE(review): the raw invoke() result is stored, not a message object.
    # The disabled variant wrapped the agent output in a named HumanMessage:
    # return {"messages": [HumanMessage(content=result["output"], name=name)]}
    return {"messages": [agent.invoke(state)]}

# I need to write the message to state here? or is that handled automatically?
def chain_node(state: AgentState, chain, name):
    """Invoke ``chain`` on the state's messages and return its reply as a state update.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.
        chain: Object exposing ``invoke(input=...)`` whose result has a
            ``content`` attribute.
        name: Name attached to the produced message.

    Returns:
        ``{"messages": [HumanMessage]}`` carrying the chain's content under ``name``.
    """
    chain_input = {"detail": "medium", "messages": state["messages"]}
    reply = chain.invoke(input=chain_input)
    wrapped = HumanMessage(content=reply.content, name=name)
    return {"messages": [wrapped]}

# cell 3
# API key comes from the environment (None when the variable is unset).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Two chat models with identical settings apart from the model name.
llm = ChatOpenAI(
    model="gpt-4o-mini-2024-07-18",
    temperature=0,
    api_key=OPENAI_API_KEY,
)
llm_big = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    api_key=OPENAI_API_KEY,
)

# Fixed instruction for the exploratory-analysis agent.
eda_task = """Using the data in the dataframe `df` and the package statsmodels, first run an augmented dickey fuller test on the data.
            Using matplotlib plot the time series, display it and save it to 'plots/plot.png'.
            Next use the statsmodel package to generate an ACF plot with zero flag set to False, display it and save it to 'plots/acf.png'.
            Then use the statsmodel package to generate a PACF plot with zero flag set to False, display it and save it to 'plots/pacf.png'"""

# The EDA agent runs on the smaller model with the Python REPL as its only tool;
# eda_node is the graph-ready callable that takes just the state.
eda_agent = create_agent(llm, tools=[python_repl_tool], task=eda_task)
eda_node = functools.partial(agent_node, name="EDA", agent=eda_agent)

from langgraph.graph import END, StateGraph, START

# TODO: add a chain to the node to analyze the ACF plot?
workflow = StateGraph(AgentState)
workflow.add_node("EDA", eda_node)

# Single-node pipeline for now: START -> EDA -> END.
# TODO: conditional_edge to refit and then loop the refit with resid?
for edge_start, edge_end in ((START, "EDA"), ("EDA", END)):
    workflow.add_edge(edge_start, edge_end)

graph = workflow.compile()

def greet(name):
    """Run the compiled graph once, seeded with a fixed kick-off message.

    Args:
        name: Unused. The function ignores its argument.

    Returns:
        Whatever ``graph.invoke`` returns for the run (invoked with ``debug=True``).
    """
    kickoff = HumanMessage(content="Run the analysis")
    initial_state = {"messages": [kickoff]}
    return graph.invoke(initial_state, debug=True)

# Text-in / text-out UI around gen_plot, started as soon as this file runs.
# NOTE(review): greet() is defined above but is not wired to the UI, and the
# "text" output displays the path string that gen_plot returns - confirm intent.
demo = gr.Interface(
    fn=gen_plot,
    inputs="text",
    outputs="text",
)
demo.launch()