pvanand commited on
Commit
fcbf0a3
·
verified ·
1 Parent(s): 038065a

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +186 -0
main.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from fastapi import FastAPI
3
+ from fastapi.responses import StreamingResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
6
+ from langchain_core.tools import tool
7
+ from langchain_openai import ChatOpenAI
8
+ from langgraph.checkpoint.memory import MemorySaver
9
+ from langgraph.prebuilt import create_react_agent
10
+ from pydantic import BaseModel
11
+ from typing import Optional
12
+ import json
13
+ from sse_starlette.sse import EventSourceResponse
14
+ import io
15
+ import sys
16
+ from contextlib import redirect_stdout, redirect_stderr
17
+ from langchain_core.runnables import RunnableConfig
18
+ import requests
19
+ import uvicorn
20
+ import re
21
+ app = FastAPI()
22
+
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ class CodeExecutionResult:
32
+ def __init__(self, output: str, error: str = None):
33
+ self.output = output
34
+ self.error = error
35
+
36
+ API_URL = "https://pvanand-code-execution-files-v4.hf.space"
37
+
38
+ @tool
39
+ def execute_python(code: str) -> str:
40
+ """Execute Python code in an IPython interactiveshell and return the output.
41
+
42
+ Args:
43
+ code: The Python code to execute
44
+
45
+ Available Libraries:
46
+ # Use plotly as the default charting library
47
+ # While using yfinance to pull stock data, Always clean the multiindex columns as this might cause issues in plotting plotly charts
48
+ # Remove the ticker level from columns if it exists
49
+ yf_data = yf.download(symbol, start=start_date, end=end_date)
50
+ if isinstance(yf_data.columns, pd.MultiIndex):
51
+ yf_data.columns = yf_data.columns.get_level_values(0)
52
+
53
+ matplotlib
54
+ pandas
55
+ plotly
56
+ groq
57
+ yfinance
58
+ numpy
59
+ seaborn
60
+ numpy
61
+ scikit-learn
62
+ statsmodels
63
+ geopandas
64
+ folium
65
+ fpdf
66
+ kaleido
67
+ scipy
68
+ geopy
69
+ mapbox
70
+
71
+ Artifacts are automatically rendered in the UI hence no need to provide links to them.
72
+ """
73
+
74
+ #print(config)
75
+
76
+ headers = {
77
+ 'accept': 'application/json',
78
+ 'Content-Type': 'application/json'
79
+ }
80
+ data = {
81
+ "session_token": "test12345", #config.configurable.get("thread_id", "test"),
82
+ "code": code
83
+ }
84
+ response = requests.post(
85
+ f'{API_URL}/v0/execute',
86
+ headers=headers,
87
+ data=json.dumps(data)
88
+ )
89
+
90
+ if response.status_code != 200:
91
+ return f"Error: Request failed with status code {response.status_code}. Response: {response.text}"
92
+ else:
93
+ response_json = response.json()
94
+ return f"data: {json.dumps(response_json)} \ndata:"
95
+
96
+ # Configure the memory and model"
97
+ memory = MemorySaver()
98
+ model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
99
+
100
+ def state_modifier(state) -> list[BaseMessage]:
101
+ return trim_messages(
102
+ state["messages"],
103
+ token_counter=len,
104
+ max_tokens=16000,
105
+ strategy="last",
106
+ start_on="human",
107
+ include_system=True,
108
+ allow_partial=False,
109
+ )
110
+
111
+ # Create the agent with the Python execution tool
112
+ agent = create_react_agent(
113
+ model,
114
+ tools=[execute_python],
115
+ checkpointer=memory,
116
+ state_modifier=state_modifier,
117
+ )
118
+
119
+ class ChatInput(BaseModel):
120
+ message: str
121
+ thread_id: Optional[str] = None
122
+
123
+ @app.post("/chat")
124
+ async def chat(input_data: ChatInput):
125
+ thread_id = input_data.thread_id or str(uuid.uuid4())
126
+
127
+ config = {
128
+ "configurable": {
129
+ "thread_id": thread_id
130
+ }
131
+ }
132
+
133
+ input_message = HumanMessage(content=input_data.message)
134
+
135
+ async def generate():
136
+ async for event in agent.astream_events(
137
+ {"messages": [input_message]},
138
+ config,
139
+ version="v2"
140
+ ):
141
+ kind = event["event"]
142
+
143
+ if kind == "on_chat_model_stream":
144
+ content = event["data"]["chunk"].content
145
+ if content:
146
+ yield f"{json.dumps({'type': 'token', 'content': content})}\n"
147
+
148
+ elif kind == "on_tool_start":
149
+ tool_input = event['data'].get('input', '')
150
+ yield f"{json.dumps({'type': 'tool_start', 'tool': event['name'], 'input': tool_input})}\n"
151
+
152
+ elif kind == "on_tool_end":
153
+ tool_output = event['data'].get('output', '').content
154
+ #print(type(tool_output))
155
+ #print(dir(tool_output))
156
+ #print the keys
157
+ pattern = r'data: (.*?)\ndata:'
158
+ match = re.search(pattern, tool_output)
159
+ print(tool_output)
160
+
161
+ if match:
162
+ tool_output_json = match.group(1).strip()
163
+ try:
164
+ tool_output = json.loads(tool_output_json)
165
+ if "artifacts" in tool_output:
166
+ for artifact in tool_output["artifacts"]:
167
+ artifact_content = requests.get(f"{API_URL}/artifact/{artifact['artifact_id']}").content
168
+ print(artifact_content)
169
+ tool_output["artifacts"][artifact["artifact_id"]] = artifact_content
170
+ except Exception as e:
171
+ print(e)
172
+ print("Error parsing tool output as json: ", tool_output)
173
+ else:
174
+ print("No match found in tool output")
175
+ yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': tool_output})}\n"
176
+ return EventSourceResponse(
177
+ generate(),
178
+ media_type="text/event-stream"
179
+ )
180
+
181
+ @app.get("/health")
182
+ async def health_check():
183
+ return {"status": "healthy"}
184
+
185
+ if __name__ == "__main__":
186
+ uvicorn.run(app, host="0.0.0.0", port=9000)