Spaces:
Runtime error
Runtime error
File size: 8,671 Bytes
a5e5bde 33d0c27 ee56cf8 a5e5bde d0d0416 0d1bfaa ceeea5f 0d1bfaa a5e5bde ee56cf8 c3e8fb1 53b9021 ceeea5f 0d1bfaa ceeea5f 53b9021 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ceeea5f 0d1bfaa ee56cf8 0d1bfaa ee56cf8 53b9021 ceeea5f 53b9021 ceeea5f 53b9021 ceeea5f 53b9021 ceeea5f 53b9021 ceeea5f 53b9021 ceeea5f ee56cf8 ceeea5f ee56cf8 ceeea5f 53b9021 ee56cf8 33d0c27 a5e5bde 53b9021 ee56cf8 e2f5cf3 c3e8fb1 ee56cf8 a5e5bde ee56cf8 a5e5bde 53b9021 a5e5bde 53b9021 a5e5bde 53b9021 a5e5bde 33d0c27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from rl_agent.env import Environment
from rl_agent.policy import Policy
from rl_agent.utils import myOptimizer
import torch
from collections import OrderedDict
from tqdm import tqdm
import statistics
import datetime
def get_time():
    """Return the current local wall-clock time as a ``datetime.time``."""
    current = datetime.datetime.now()
    return current.time()
def get_profit():
    """Return the current value of the module-level ``profit``.

    Registered as a polling callback for the profit textbox in the UI.
    """
    # Reading a module-level name does not require a ``global`` declaration.
    return profit
# def update_table():
# global
def pretrain_rl_agent():
    """Run one pass over the training environment, updating ``agent`` online.

    For every step in ``range(state_size, len(train))`` the agent produces an
    action, ``env_train.step`` returns the next observations and a reward, and
    every named parameter is shifted by the increment that ``optimizer.step``
    returns. The shifted parameters are collected in ``checkpoint`` and loaded
    back into the agent before the next step.

    Reads and mutates the module-level ``agent``, ``optimizer``, ``env_train``,
    ``checkpoint`` and ``model_gradients_history``. Returns ``None``.
    """
    global equity  # only referenced by the disabled equity bookkeeping below
    state = env_train.reset()
    for _ in tqdm(range(state_size, len(train))):
        state = torch.as_tensor(state).float()
        action = agent(state)
        env_action = action.data.to("cpu").numpy()
        state, reward, _ = env_train.step(env_action)
        # reward *= 1e3

        # Back-propagate from the action so that every parameter has a ``.grad``.
        # NOTE(review): no zero_grad() call is visible here, so gradients may
        # accumulate across steps unless Policy/myOptimizer reset them - confirm.
        action.backward()
        for name, param in agent.named_parameters():
            grad = param.grad
            delta = optimizer.step(grad, reward, state[-1], model_gradients_history[name])
            checkpoint[name] = param + delta
            model_gradients_history[name] = grad

        # equity += env_train.profit
        optimizer.after_step(reward)
        agent.load_state_dict(checkpoint)
def make_prediction(observations):
    """Take one step in the test environment and update ``agent`` online.

    Args:
        observations: The current observation window; it is converted to a
            float tensor and fed to ``agent``. (Original author's note: this
            covers steps 0-14 and the returned window covers steps 1-15.)

    Returns:
        A tuple ``(action, next_observations)`` where ``action`` is the tensor
        produced by ``agent`` and ``next_observations`` is what
        ``env_test.step`` returned. (Original author's note on the action:
        ``[-1.0, 1.0] * leverage``.)

    Reads and mutates the module-level ``agent``, ``optimizer``, ``env_test``,
    ``checkpoint`` and ``model_gradients_history``.
    """
    model_input = torch.as_tensor(observations).float()
    action = agent(model_input)
    env_action = action.data.to("cpu").numpy()
    next_observations, reward, _ = env_test.step(env_action)

    # Back-propagate from the action so that every parameter has a ``.grad``.
    action.backward()
    for name, param in agent.named_parameters():
        grad = param.grad
        delta = optimizer.step(grad, reward, next_observations[-1], model_gradients_history[name])
        checkpoint[name] = param + delta
        model_gradients_history[name] = grad

    # equity += env_test.profit
    optimizer.after_step(reward)
    agent.load_state_dict(checkpoint)
    return action, next_observations
# ----------------------------------------------------------------------------------------------------------------------
# For visualization
# ----------------------------------------------------------------------------------------------------------------------
# Module-level state read and updated by the UI callbacks.
profit = 0.0
counter = 0
start_year, test_year = 2021, 2023
datetime_column = "Date"
df_data = pd.read_csv(f"./data/EURUSD_Candlestick_1_M_BID_01.01.{start_year}-04.02.2023_processed.csv")
df_data[datetime_column] = pd.to_datetime(df_data[datetime_column], format="%Y-%m-%d")  # %d.%m.%Y %H:%M:%S.000 GMT%z

# Removing all empty dates
# Build complete timeline from start date to end date
dt_all = pd.date_range(start=df_data[datetime_column].iloc[0], end=df_data[datetime_column].iloc[-1])
# Retrieve the dates that ARE in the original dataset (the column is already datetime here)
dt_obs = {d.strftime("%Y-%m-%d") for d in df_data[datetime_column]}
# Define dates with missing values; membership is tested against the set directly
# instead of rebuilding a list from it for every candidate date.
dt_breaks = [d for d in dt_all.strftime("%Y-%m-%d") if d not in dt_obs]

# Rows from ``test_year`` are held back for the test frame; all other rows form the train frame.
df_data_test = df_data[df_data[datetime_column].dt.year == test_year]
df_data_train = df_data[df_data[datetime_column].dt.year != test_year]
df_data_train_viz = pd.DataFrame(columns=["Action", "Amount", "Profit"])
# ----------------------------------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------------------------------
# For RL Agent
# ----------------------------------------------------------------------------------------------------------------------
# Load the candle file, keep the first 600000 rows and index them by their timestamp string.
data = (
    pd.read_csv(f'./data/EURUSD_Candlestick_1_M_BID_01.01.{start_year}-04.02.2023.csv')
    .head(600000)
    .set_index('Local time')
)

date_split = '31.01.2022 03:29:00.000 GMT-0600'

# Hyper-parameters handed to the optimizer below.
learning_rate = 0.001
first_momentum = 0.0
second_momentum = 0.0001
transaction_cost = 0.0001
adaptation_rate = 0.01
state_size = 15
equity = 1.0

# NOTE(review): pandas label slices include both endpoints, so the ``date_split``
# row ends up in ``train`` and again in ``test`` - confirm this overlap is intended.
train = data[:date_split]
# The test frame is prefixed with the last ``state_size`` training rows.
test = pd.concat([train.tail(state_size), data[date_split:]])

# Initialize agent and optimizer
agent = Policy(input_channels=state_size)
optimizer = myOptimizer(learning_rate, first_momentum, second_momentum, adaptation_rate, transaction_cost)

# Initialize train and test environments, each seeded with the first
# ``state_size - 1`` consecutive differences of its 'Close' column.
history = [train.iloc[step]['Close'] - train.iloc[step - 1]['Close'] for step in range(1, state_size)]
env_train = Environment(train, history=history, state_size=state_size)

history = [test.iloc[step]['Close'] - test.iloc[step - 1]['Close'] for step in range(1, state_size)]
env_test = Environment(test, history=history, state_size=state_size)

# Previous-gradient store (zeros to start with) and the state dict that the
# update loops fill and load back into the agent.
model_gradients_history = {name: torch.zeros_like(param) for name, param in agent.named_parameters()}
checkpoint = OrderedDict()

pretrain_rl_agent()
observations = env_test.reset()
# ----------------------------------------------------------------------------------------------------------------------
def trading_plot():
    """Advance the simulation by one test candle and rebuild the chart.

    While unseen rows remain in ``df_data_test``, the next row is added to
    ``df_data_train``, the agent is stepped 1440 times through
    ``make_prediction`` and ``profit`` is refreshed from ``env_test.profits``.
    Once every test row has been consumed, the chart is drawn from the full
    ``df_data`` instead.

    Updates the module-level ``counter``, ``df_data_train``, ``observations``
    and ``profit``.

    Returns:
        A Plotly figure with a candlestick chart in the first row and a volume
        bar chart in the second.
    """
    global counter
    global df_data_train
    global observations
    global profit

    if counter < len(df_data_test):
        # ``DataFrame.append`` was removed in pandas 2.0; concatenating a
        # one-row frame adds the same row under the same index label.
        df_data_train = pd.concat([df_data_train, df_data_test.iloc[[counter]]])
        counter += 1

        # presumably 1440 one-minute environment steps per displayed daily candle - TODO confirm
        for _ in range(1440):
            _, observations = make_prediction(observations)

        # profit += -1.0 * (last_observation - observations[-1]) * position
        profit = env_test.profits
    else:
        df_data_train = df_data

    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.02, row_heights=[0.7, 0.3],
                        subplot_titles=['OHLC chart', ''])

    # Plot OHLC on 1st subplot
    fig.add_trace(go.Candlestick(x=df_data_train[datetime_column].tolist(),
                                 open=df_data_train["Open"].tolist(), close=df_data_train["Close"].tolist(),
                                 high=df_data_train["High"].tolist(), low=df_data_train["Low"].tolist(),
                                 name=""), row=1, col=1)

    # Plot volume trace on 2nd row: red when Open >= Close, green otherwise.
    # Zipping the two columns avoids building a Series per row with iterrows().
    colors = ['red' if open_price - close_price >= 0 else 'green'
              for open_price, close_price in zip(df_data_train['Open'], df_data_train['Close'])]
    fig.add_trace(go.Bar(x=df_data_train[datetime_column], y=df_data_train['Volume'], name="", marker_color=colors,
                         hovertemplate="%{x}<br>Volume: %{y}"), row=2, col=1)

    # Add chart title and Hide dates with no values and remove rangeslider
    fig.update_layout(title="", height=600, showlegend=False,
                      xaxis_rangeslider_visible=False,
                      xaxis_rangebreaks=[dict(values=dt_breaks)])

    # Update y-axis label
    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)

    fig.update_xaxes(showspikes=True, spikecolor="green", spikesnap="cursor", spikemode="across")
    fig.update_yaxes(showspikes=True, spikecolor="orange", spikethickness=2)
    fig.update_layout(spikedistance=1000, hoverdistance=100)

    # Initial visible window on the x axis.
    fig.layout.xaxis.range = ("2022-12-01", "2023-03-01")

    return fig
# The UI of the demo is defined here.
with gr.Blocks() as demo:
    gr.Markdown("Auto AI Trading Bot")
    gr.Markdown("Investment: $100,000")

    # Profit read-out, refreshed every second from ``get_profit``.
    dt = gr.Textbox(label="Total profit (Amount of profit in PIPS that the agent makes in EUR/USD)")
    demo.queue().load(get_profit, inputs=None, outputs=dt, every=1)

    # for plotly it should follow this: https://gradio.app/plot-component-for-maps/
    # The argument-less ``.style()`` call was dropped: it configured nothing and
    # the method is not available in newer Gradio releases.
    candlestick_plot = gr.Plot()
    demo.queue().load(trading_plot, [], candlestick_plot, every=1)

    with gr.Row():
        with gr.Column():
            gr.Markdown("User Interactive panel")
            amount = gr.components.Textbox(value="", label="Amount", interactive=True)
            with gr.Row():
                # ``label`` and ``inputs`` are not documented Button parameters
                # (event inputs belong on ``.click``), so they are not passed.
                # No click handlers are wired up yet.
                buy_btn = gr.components.Button("Buy", interactive=True)
                sell_btn = gr.components.Button("Sell", interactive=True)
                hold_btn = gr.components.Button("Hold", interactive=True)
        with gr.Column():
            gr.Markdown("Trade bot history")
            # trade_bot_table = gr.Dataframe(df_data_train_viz)
            # demo.queue().load(update_table, inputs=None, outputs=trade_bot_table, every=1)
            # Show trade box history in a table or something
            # gr.components.Textbox(value="Some history? Need to decide how to show bot history", label="History", interactive=True)

demo.launch()
|