# gradiobox / app.py
# pwilczewski: added missing argument (a35f64b)
# raw | history blame | 4.88 kB
import gradio as gr
# cell 1
from typing import Annotated
from langchain_experimental.tools import PythonREPLTool, PythonAstREPLTool
import pandas as pd
import statsmodels as sm
import os
# Alternative dataset, currently disabled:
# df = pd.read_csv("HOUST.csv")
# Module-level dataframe shared by gen_plot, the system prompt preview and the
# REPL tool below.
df = pd.read_csv("USSTHPI.csv")
# The REPL tool receives the dataframe in its locals under the name `df`, the
# same name the system prompt tells the model to use.
python_repl_tool = PythonAstREPLTool(locals={"df": df})
import matplotlib.pyplot as nplt
def gen_plot(name):
    """Plot the USSTHPI column of the module-level ``df`` and save it as a PNG.

    Args:
        name: Text received from the Gradio text input. It is not used; the
            parameter exists because the interface passes one input value.

    Returns:
        The relative path of the saved image, ``"plots/plot.png"``.
    """
    out_path = 'plots/plot.png'
    # savefig does not create missing directories, so make sure the target
    # folder exists before drawing anything.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    fig = nplt.figure(figsize=(10, 5))
    try:
        # NOTE(review): DATE is plotted exactly as read from the CSV; if it is
        # a string column, converting it with pd.to_datetime would give a
        # proper time axis - confirm against the data file.
        nplt.plot(df['DATE'], df['USSTHPI'], label='USSTHPI')
        nplt.title('Time Series of USSTHPI')
        nplt.xlabel('Date')
        nplt.ylabel('USSTHPI')
        nplt.legend()
        nplt.xticks(rotation=45)
        nplt.tight_layout()
        nplt.savefig(out_path)
    finally:
        # Every call opens a new figure; close it so repeated requests do not
        # accumulate open figures in pyplot's global registry.
        nplt.close(fig)
    return out_path
# cell 2
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
import functools
import operator
from typing import Sequence, TypedDict
# Prompt template for the system message. The {dhead} placeholder is replaced
# below with a markdown rendering of the first rows of `df`.
_SYSTEM_PROMPT_TEMPLATE = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
It is important to understand the attributes of the dataframe before working with it. This is the result of running `df.head().to_markdown()`
<df>
{dhead}
</df>
You are not meant to use only these rows to answer questions - they are meant as a way of telling you about the shape and schema of the dataframe.
You also do not have use only the information here to answer questions - you can run intermediate queries to do exporatory data analysis to give you more information as needed. """

# Final system prompt text (a plain str) with the dataframe preview filled in.
_df_preview = df.head().to_markdown()
system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(dhead=_df_preview)
class AgentState(TypedDict):
    """State dictionary that is passed as input to each node of the graph."""

    # Messages exchanged so far. The operator.add annotation marks how updates
    # are merged: messages returned by a node are added to the current list
    # instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the node to route to next.
    next: str
# Design note from the original author: not everything needs to be an OpenAI
# tools agent.
def create_agent(llm: ChatOpenAI, tools: list, task: str):
    """Build an ``AgentExecutor`` for a worker that carries out ``task``.

    The prompt consists of the module-level ``system_prompt``, the task as a
    human message, the running ``messages`` from the graph state and the
    ``agent_scratchpad`` placeholder.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent is allowed to call.
        task: Instruction text inserted verbatim as a human message.

    Returns:
        An ``AgentExecutor`` created with ``verbose=True`` and
        ``return_intermediate_steps=True``.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            # system_prompt has already been rendered with str.format, so it
            # is passed as a finished message. A ("system", text) tuple would
            # be parsed as a template a second time, and any "{" or "}" coming
            # from the dataframe preview would be treated as prompt variables.
            SystemMessage(content=system_prompt),
            HumanMessage(content=task),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    # verbose and return_intermediate_steps are switched on for debugging.
    executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        return_intermediate_steps=True,
    )
    return executor
def agent_node(state: AgentState, agent, name):
    """Run ``agent`` on the current graph state and return a state update.

    Args:
        state: Current graph state, handed to ``agent.invoke`` unchanged.
        agent: Object exposing ``invoke(state)``.
        name: Node name. Currently unused.

    Returns:
        A dict with the single key ``"messages"`` holding a one-element list
        that contains whatever ``agent.invoke`` returned.
    """
    # Original author's open question: an AIMessage carries all kinds of
    # metadata, so should everything be treated as a HumanMessage instead?
    # A disabled earlier variant did that:
    #   return {"messages": [HumanMessage(content=result["output"], name=name)]}
    # NOTE(review): AgentState declares `messages` as a sequence of
    # BaseMessage, but the raw invoke() result is appended here unwrapped -
    # confirm that downstream consumers expect this.
    outcome = agent.invoke(state)
    update = {"messages": [outcome]}
    return update
def chain_node(state: AgentState, chain, name):
    """Invoke ``chain`` with the state's messages and wrap its reply.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.
        chain: Object exposing ``invoke(input=...)`` whose result has a
            ``content`` attribute.
        name: Node name, attached to the returned message.

    Returns:
        A dict with the single key ``"messages"`` holding one ``HumanMessage``
        built from the reply's ``content`` and tagged with ``name``.
    """
    # Original author's open question: does the message need to be written to
    # the state here, or is that handled automatically?
    payload = {"detail": "medium", "messages": state["messages"]}
    reply = chain.invoke(input=payload)
    wrapped = HumanMessage(content=reply.content, name=name)
    return {"messages": [wrapped]}
# cell 3
# API key comes from the environment; both chat models use temperature 0.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
llm = ChatOpenAI(
    model="gpt-4o-mini-2024-07-18",
    temperature=0,
    api_key=OPENAI_API_KEY,
)
# Larger model; not referenced by the rest of the visible code.
llm_big = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    api_key=OPENAI_API_KEY,
)

# Instructions handed to the EDA worker as its task message.
eda_task = """Using the data in the dataframe `df` and the package statsmodels, first run an augmented dickey fuller test on the data.
Using matplotlib plot the time series, display it and save it to 'plots/plot.png'.
Next use the statsmodel package to generate an ACF plot with zero flag set to False, display it and save it to 'plots/acf.png'.
Then use the statsmodel package to generate a PACF plot with zero flag set to False, display it and save it to 'plots/pacf.png'"""

# The EDA worker runs on the small model with the Python REPL as its only tool.
eda_agent = create_agent(llm, tools=[python_repl_tool], task=eda_task)
# Bind the agent and the node name so the graph can call the node with just
# the state.
eda_node = functools.partial(agent_node, agent=eda_agent, name="EDA")
from langgraph.graph import END, StateGraph, START
# Single-node workflow: START -> "EDA" -> END.
# Open idea: add a chain to the node to analyze the ACF plot?
workflow = StateGraph(AgentState)
# Register the EDA worker under the node name "EDA".
workflow.add_node("EDA", eda_node)
# Open idea: a conditional edge to refit, then loop the refit on the residuals?
workflow.add_edge(START, "EDA")
workflow.add_edge("EDA", END)
# Compiled graph; this is the object greet() invokes.
graph = workflow.compile()
def greet(name):
    """Run the compiled graph once with a fixed kick-off message.

    Args:
        name: Unused.

    Returns:
        Whatever ``graph.invoke`` returns for the run (invoked with
        ``debug=True``).
    """
    kickoff = {"messages": [HumanMessage(content="Run the analysis")]}
    return graph.invoke(kickoff, debug=True)
# Gradio UI: one text input wired to gen_plot; its return value (the path of
# the saved plot) is displayed as text. greet(), which runs the agent graph,
# is not connected to the interface.
# NOTE(review): outputs="text" shows the file path rather than the image -
# confirm this is intended.
demo = gr.Interface(fn=gen_plot, inputs="text", outputs="text")
# Launches the app as soon as this module is executed.
demo.launch()