pmthangk09 commited on
Commit
f3ec673
·
2 Parent(s): 33d0c27 2daf62b

Resolve conflicts

Browse files
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.cvs filter=lfs diff=lfs merge=lfs -text
36
+ data/EURUSD_Candlestick_1_M_BID_01.01.2021-04.02.2023.csv filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -7,8 +7,11 @@ data = pd.read_csv("./data/EURUSD_Candlestick_1_M_BID_01.01.2021-04.02.2023.csv"
7
  # The UI of the demo defines here.
8
  with gr.Blocks() as demo:
9
  gr.Markdown("Auto trade bot.")
10
-
11
- gr.Markdown("Candle plots may goes here.")
 
 
 
12
  gr.components.Image()
13
  # for plotly it should follow this: https://gradio.app/plot-component-for-maps/
14
  # trade_plot = gr.Plot()
 
7
  # The UI of the demo defines here.
8
  with gr.Blocks() as demo:
9
  gr.Markdown("Auto trade bot.")
10
+
11
+ with gr.Row():
12
+ gr.Markdown("Trade data settings.")
13
+ time_slider = gr.components.Slider(minimum=1, maximum=20, value=5, label="Time interval", interactive=True)
14
+ gr.Markdown("Candle polts may goes here.")
15
  gr.components.Image()
16
  # for plotly it should follow this: https://gradio.app/plot-component-for-maps/
17
  # trade_plot = gr.Plot()
data/EURUSD_Candlestick_1_M_BID_01.01.2021-04.02.2023.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:571cbd2e9950b74879480c9d5ee721570fb769ef221acf7cad5247e48824f5ec
3
+ size 56272445
data_management.py ADDED
File without changes
rl_agent/env.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ class Environment:
5
+
6
+ def __init__(self, data, history_t=90):
7
+ self.data = data
8
+ self.history_t = history_t
9
+ self.reset()
10
+
11
+ def reset(self):
12
+ self.t = 0
13
+ self.done = False
14
+ self.profits = 0
15
+ self.positions = []
16
+ self.position_value = 0
17
+ self.history = [0 for _ in range(self.history_t)]
18
+ return [self.position_value] + self.history # obs
19
+
20
+ def step(self, act):
21
+ reward = 0
22
+
23
+ # act = 0: stay, 1: buy, -1: sell
24
+ if act == 1:
25
+ self.positions.append(self.data.iloc[self.t, :]['Close'])
26
+ elif act == 2: # sell
27
+ if len(self.positions) == 0:
28
+ reward = -1
29
+ else:
30
+ profits = 0
31
+ for p in self.positions:
32
+ profits += (self.data.iloc[self.t, :]['Close'] - p)
33
+ reward += profits
34
+ self.profits += profits
35
+ self.positions = []
36
+
37
+ # set next time
38
+ self.t += 1
39
+ self.position_value = 0
40
+ for p in self.positions:
41
+ self.position_value += (self.data.iloc[self.t, :]['Close'] - p)
42
+ self.history.pop(0)
43
+ self.history.append(self.data.iloc[self.t, :]['Close'] - self.data.iloc[(self.t-1), :]['Close'])
44
+
45
+ # clipping reward
46
+ if reward > 0:
47
+ reward = 1
48
+ elif reward < 0:
49
+ reward = -1
50
+
51
+ return [self.position_value] + self.history, reward, self.done # obs, reward, done
52
+
53
+
54
+
55
+
56
+ if __name__ == "__main__":
57
+ data = pd.read_csv('./data/EURUSD_Candlestick_1_M_BID_01.01.2021-04.02.2023.csv')
58
+ # data['Local time'] = pd.to_datetime(data['Local time'])
59
+ data = data.set_index('Local time')
60
+ print(data.index.min(), data.index.max())
61
+
62
+ date_split = '19.09.2022 17:55:00.000 GMT-0500'
63
+ train = data[:date_split]
64
+ test = data[date_split:]
65
+ print(train.head(10))
66
+
67
+ env = Environment(train)
68
+ print(env.reset())
69
+ for _ in range(3):
70
+ pact = np.random.randint(3)
71
+ print(env.step(pact))
72
+
rl_agent/policy.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Policy(nn.Module):
7
+ def __init__(self, input_channels=8):
8
+
9
+ super(Policy, self).__init__()
10
+
11
+ self.layer1 = nn.Linear(input_channels, 2 * input_channels)
12
+ self.tanh1 = nn.Tanh()
13
+ self.layer2 = nn.linear(2 * input_channels, 1)
14
+ self.tanh2 = nn.Tanh()
15
+
16
+ def forward(self, state):
17
+
18
+ hidden = self.layer1(state)
19
+ hidden = self.tanh1(hidden)
20
+ hidden = self.layer2(hidden)
21
+ action = self.tanh2(hidden)
22
+
23
+ return action
24
+
25
+
26
+
27
+