pwilczewski committed on
Commit 6589e60 · 1 Parent(s): 0169c8b
Files changed (2)
  1. app.py +88 -1
  2. requirements.txt +184 -1
app.py CHANGED
@@ -1,7 +1,94 @@
 import gradio as gr
 
+# cell 1
+from typing import Annotated
+from langchain_experimental.tools import PythonREPLTool, PythonAstREPLTool
+import pandas as pd
+import statsmodels as sm
+
+# df = pd.read_csv("HOUST.csv")
+df = pd.read_csv("USSTHPI.csv")
+python_repl_tool = PythonAstREPLTool(locals={"df": df})
+
+# cell 2
+from langchain.agents import AgentExecutor, create_openai_tools_agent
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_openai import ChatOpenAI
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
+import functools
+import operator
+from typing import Sequence, TypedDict
+
+system_prompt = """You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
+It is important to understand the attributes of the dataframe before working with it. This is the result of running `df.head().to_markdown()`:
+
+<df>
+{dhead}
+</df>
+
+You are not meant to use only these rows to answer questions - they are meant as a way of telling you about the shape and schema of the dataframe.
+You also do not have to use only the information here to answer questions - you can run intermediate queries to do exploratory data analysis to give you more information as needed."""
+system_prompt = system_prompt.format(dhead=df.head().to_markdown())
+
+# The agent state is the input to each node in the graph
+class AgentState(TypedDict):
+    # The annotation tells the graph that new messages will always be added to the current state
+    messages: Annotated[Sequence[BaseMessage], operator.add]
+    # The 'next' field indicates where to route to next
+    next: str
+
+# part of the problem might be that I'm passing a PromptTemplate object for the system_prompt here
+# not everything needs to be an openai tools agent
+def create_agent(llm: ChatOpenAI, tools: list, task: str):
+    # Each worker node will be given a name and some tools.
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt),  # using a global system_prompt
+            HumanMessage(content=task),
+            MessagesPlaceholder(variable_name="messages"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+    agent = create_openai_tools_agent(llm, tools, prompt)
+    executor = AgentExecutor(agent=agent, tools=tools)
+    return executor
+
+# AIMessage will have all kinds of metadata, so treat it all as HumanMessage I suppose?
+def agent_node(state: AgentState, agent, name):
+    result = agent.invoke(state)
+    return {"messages": [HumanMessage(content=result["output"], name=name)]}
+
+# I need to write the message to state here? or is that handled automatically?
+def chain_node(state: AgentState, chain, name):
+
+    result = chain.invoke(input={"detail": "medium", "messages": state["messages"]})
+    return {"messages": [HumanMessage(content=result.content, name=name)]}
+
+# cell 3
+llm = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0)
+llm_big = ChatOpenAI(model="gpt-4o", temperature=0)
+
+eda_task = """Using the data in the dataframe `df` and the package statsmodels, first run an augmented Dickey-Fuller test on the data.
+Using matplotlib plot the time series, display it and save it to 'plot.png'.
+Next use the statsmodels package to generate an ACF plot with the zero flag set to False, display it and save it to 'acf.png'.
+Then use the statsmodels package to generate a PACF plot with the zero flag set to False, display it and save it to 'pacf.png'."""
+eda_agent = create_agent(llm, [python_repl_tool], task=eda_task)
+eda_node = functools.partial(agent_node, agent=eda_agent, name="EDA")
+
+from langgraph.graph import END, StateGraph, START
+
+# add a chain to the node to analyze the ACF plot?
+workflow = StateGraph(AgentState)
+workflow.add_node("EDA", eda_node)
+
+# conditional_edge to refit and the loop refit with resid?
+workflow.add_edge(START, "EDA")
+workflow.add_edge("EDA", END)
+
+graph = workflow.compile()
+
 def greet(name):
-    return "Hello " + name + "!!"
+    return graph.invoke({"messages": [HumanMessage(content="Run the analysis")]})
 
 demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 demo.launch()
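
A note on the new `greet`: it ignores the `name` argument from the Gradio textbox and returns the raw graph state (a dict), so the output box ends up showing the whole state rather than the analysis text. A minimal sketch of one way to pass the user's input through and return only the final message, as an assumption about the intended behavior rather than part of this commit:

def greet(name):
    # hypothetical variant (not in the commit): forward the user's text into the graph state
    result = graph.invoke({"messages": [HumanMessage(content=name)]})
    # return only the content of the last message so the text output stays readable
    return result["messages"][-1].content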
requirements.txt CHANGED
@@ -1 +1,184 @@
-pandas
+# This file may be used to create an environment using:
+# $ conda create --name <env> --file <this file>
+# platform: win-64
+aiofiles=23.2.1=pypi_0
+aiohappyeyeballs=2.4.0=pypi_0
+aiohttp=3.10.5=pypi_0
+aiosignal=1.3.1=pypi_0
+annotated-types=0.7.0=pypi_0
+anyio=4.4.0=pypi_0
+asttokens=2.4.1=pyhd8ed1ab_0
+attrs=24.2.0=pypi_0
+blas=1.0=mkl
+bottleneck=1.3.7=py312he558020_0
+brotli=1.0.9=h2bbff1b_8
+brotli-bin=1.0.9=h2bbff1b_8
+bzip2=1.0.8=h2bbff1b_6
+ca-certificates=2024.7.2=haa95532_0
+certifi=2024.8.30=pypi_0
+charset-normalizer=3.3.2=pypi_0
+click=8.1.7=pypi_0
+colorama=0.4.6=pyhd8ed1ab_0
+comm=0.2.2=pyhd8ed1ab_0
+contourpy=1.2.0=py312h59b6b97_0
+cycler=0.11.0=pyhd3eb1b0_0
+dataclasses-json=0.6.7=pypi_0
+debugpy=1.6.7=py312hd77b12b_0
+decorator=5.1.1=pyhd8ed1ab_0
+distro=1.9.0=pypi_0
+exceptiongroup=1.2.2=pyhd8ed1ab_0
+executing=2.1.0=pyhd8ed1ab_0
+expat=2.6.2=hd77b12b_0
+fastapi=0.115.0=pypi_0
+ffmpy=0.4.0=pypi_0
+filelock=3.16.1=pypi_0
+fonttools=4.51.0=py312h2bbff1b_0
+freetype=2.12.1=ha860e81_0
+frozenlist=1.4.1=pypi_0
+fsspec=2024.9.0=pypi_0
+gradio=4.44.0=pypi_0
+gradio-client=1.3.0=pypi_0
+greenlet=3.0.3=pypi_0
+h11=0.14.0=pypi_0
+httpcore=1.0.5=pypi_0
+httpx=0.27.2=pypi_0
+huggingface-hub=0.25.1=pypi_0
+icc_rt=2022.1.0=h6049295_2
+icu=73.1=h6c2663c_0
+idna=3.8=pypi_0
+importlib-metadata=8.5.0=pyha770c72_0
+importlib-resources=6.4.5=pypi_0
+importlib_metadata=8.5.0=hd8ed1ab_0
+intel-openmp=2023.1.0=h59b6b97_46320
+ipykernel=6.29.5=pyh4bbf305_0
+ipython=8.27.0=pyh7428d3b_0
+jedi=0.19.1=pyhd8ed1ab_0
+jinja2=3.1.4=pypi_0
+jiter=0.5.0=pypi_0
+jpeg=9e=h827c3e9_3
+jsonpatch=1.33=pypi_0
+jsonpointer=3.0.0=pypi_0
+jupyter_client=8.6.2=pyhd8ed1ab_0
+jupyter_core=5.7.2=py312haa95532_0
+kiwisolver=1.4.4=py312hd77b12b_0
+krb5=1.20.1=h5b6d351_0
+langchain=0.3.0=pypi_0
+langchain-community=0.3.0=pypi_0
+langchain-core=0.3.1=pypi_0
+langchain-experimental=0.3.0=pypi_0
+langchain-openai=0.2.0=pypi_0
+langchain-text-splitters=0.3.0=pypi_0
+langgraph=0.2.22=pypi_0
+langgraph-checkpoint=1.0.10=pypi_0
+langsmith=0.1.121=pypi_0
+lcms2=2.12=h83e58a3_0
+lerc=3.0=hd77b12b_0
+libbrotlicommon=1.0.9=h2bbff1b_8
+libbrotlidec=1.0.9=h2bbff1b_8
+libbrotlienc=1.0.9=h2bbff1b_8
+libclang=14.0.6=default_hb5a9fac_1
+libclang13=14.0.6=default_h8e68704_1
+libdeflate=1.17=h2bbff1b_1
+libffi=3.4.4=hd77b12b_1
+libpng=1.6.39=h8cc25b3_0
+libpq=12.17=h906ac69_0
+libsodium=1.0.18=h8d14728_1
+libtiff=4.5.1=hd77b12b_0
+libwebp-base=1.3.2=h2bbff1b_0
+lz4-c=1.9.4=h2bbff1b_1
+markdown-it-py=3.0.0=pypi_0
+markupsafe=2.1.5=pypi_0
+marshmallow=3.22.0=pypi_0
+matplotlib=3.9.2=py312haa95532_0
+matplotlib-base=3.9.2=py312hbdc63d0_0
+matplotlib-inline=0.1.7=pyhd8ed1ab_0
+mdurl=0.1.2=pypi_0
+mkl=2023.1.0=h6b88ed4_46358
+mkl-service=2.4.0=py312h2bbff1b_1
+mkl_fft=1.3.10=py312h827c3e9_0
+mkl_random=1.2.7=py312h0158946_0
+msgpack=1.1.0=pypi_0
+multidict=6.0.5=pypi_0
+mypy-extensions=1.0.0=pypi_0
+nest-asyncio=1.6.0=pyhd8ed1ab_0
+numexpr=2.8.7=py312h96b7d27_0
+numpy=1.26.4=py312hfd52020_0
+numpy-base=1.26.4=py312h4dde369_0
+openai=1.46.0=pypi_0
+openjpeg=2.5.2=hae555c5_0
+openssl=3.0.15=h827c3e9_0
+orjson=3.10.7=pypi_0
+packaging=24.1=pyhd8ed1ab_0
+pandas=2.2.2=py312h0158946_0
+parso=0.8.4=pyhd8ed1ab_0
+patsy=0.5.6=py312haa95532_0
+pickleshare=0.7.5=py_1003
+pillow=10.4.0=py312h827c3e9_0
+pip=24.2=py312haa95532_0
+platformdirs=4.3.3=pyhd8ed1ab_0
+ply=3.11=py312haa95532_1
+prompt-toolkit=3.0.47=pyha770c72_0
+psutil=5.9.0=py312h2bbff1b_0
+pure_eval=0.2.3=pyhd8ed1ab_0
+pybind11-abi=5=hd3eb1b0_0
+pydantic=2.8.2=pypi_0
+pydantic-core=2.20.1=pypi_0
+pydantic-settings=2.5.2=pypi_0
+pydub=0.25.1=pypi_0
+pygments=2.18.0=pyhd8ed1ab_0
+pyparsing=3.1.2=py312haa95532_0
+pyqt=5.15.10=py312hd77b12b_0
+pyqt5-sip=12.13.0=py312h2bbff1b_0
+python=3.12.4=h14ffc60_1
+python-dateutil=2.9.0post0=py312haa95532_2
+python-dotenv=1.0.1=pypi_0
+python-multipart=0.0.10=pypi_0
+python-tzdata=2023.3=pyhd3eb1b0_0
+pytz=2024.1=py312haa95532_0
+pywin32=305=py312h2bbff1b_0
+pyyaml=6.0.2=pypi_0
+pyzmq=25.1.2=py312hd77b12b_0
+qt-main=5.15.2=h19c9488_10
+regex=2024.9.11=pypi_0
+requests=2.32.3=pypi_0
+rich=13.8.1=pypi_0
+ruff=0.6.7=pypi_0
+scipy=1.13.1=py312hbb039d4_0
+semantic-version=2.10.0=pypi_0
+setuptools=72.1.0=py312haa95532_0
+shellingham=1.5.4=pypi_0
+sip=6.7.12=py312hd77b12b_0
+six=1.16.0=pyhd3eb1b0_1
+sniffio=1.3.1=pypi_0
+sqlalchemy=2.0.33=pypi_0
+sqlite=3.45.3=h2bbff1b_0
+stack_data=0.6.2=pyhd8ed1ab_0
+starlette=0.38.6=pypi_0
+statsmodels=0.14.2=py312h4b0e54e_0
+tabulate=0.9.0=py312haa95532_0
+tbb=2021.8.0=h59b6b97_0
+tenacity=8.5.0=pypi_0
+tiktoken=0.7.0=pypi_0
+tk=8.6.14=h0416ee5_0
+tomlkit=0.12.0=pypi_0
+tornado=6.4.1=py312h827c3e9_0
+tqdm=4.66.5=pypi_0
+traitlets=5.14.3=pyhd8ed1ab_0
+typer=0.12.5=pypi_0
+typing-inspect=0.9.0=pypi_0
+typing_extensions=4.12.2=pyha770c72_0
+tzdata=2024a=h04d1e81_0
+unicodedata2=15.1.0=py312h2bbff1b_0
+urllib3=2.2.2=pypi_0
+uvicorn=0.30.6=pypi_0
+vc=14.40=h2eaa2aa_0
+vs2015_runtime=14.40.33807=h98bb1dd_0
+wcwidth=0.2.13=pyhd8ed1ab_0
+websockets=12.0=pypi_0
+wheel=0.43.0=py312haa95532_0
+xz=5.4.6=h8cc25b3_1
+yarl=1.9.7=pypi_0
+zeromq=4.3.5=hd77b12b_0
+zipp=3.20.2=pyhd8ed1ab_0
+zlib=1.2.13=h8cc25b3_1
+zstd=1.5.5=hd43e919_2
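
One caveat on the new requirements.txt: it is a conda export for win-64 (note the name=version=build triplets and conda-only entries such as vc, mkl, and qt-main), while Gradio Spaces install requirements.txt with pip, which cannot parse that format. A pip-style file limited to the packages app.py actually imports might look roughly like the sketch below (versions copied from the listing above, untested assumption; tabulate is included because df.head().to_markdown() requires it):

gradio==4.44.0
pandas==2.2.2
tabulate==0.9.0
matplotlib==3.9.2
statsmodels==0.14.2
langchain==0.3.0
langchain-experimental==0.3.0
langchain-openai==0.2.0
langgraph==0.2.22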