ngthanhtinqn committed on
Commit
c5f80c4
·
1 Parent(s): 7844376
Files changed (1) hide show
  1. rl_agent/env.py +72 -0
rl_agent/env.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
class Environment:
    """Minimal trading environment that walks row by row over a price table.

    Actions understood by :meth:`step`:

    * ``1`` -- buy: open one position at the current ``'Close'`` price.
    * ``2`` -- sell: close every open position at the current ``'Close'``.
    * anything else -- hold.

    An observation is the list ``[position_value] + history`` where
    ``position_value`` is the unrealised profit of the open positions at the
    new time step and ``history`` holds the last ``history_t`` one-step
    ``'Close'`` differences (zero-filled after a reset).
    """

    def __init__(self, data: pd.DataFrame, history_t: int = 90):
        """Store the price table and start a fresh episode.

        Args:
            data: Table with a ``'Close'`` column, read by position.
            history_t: Number of past ``'Close'`` differences in each
                observation.
        """
        self.data = data
        self.history_t = history_t
        self.reset()

    def _close_at(self, t):
        """Return the ``'Close'`` value of row ``t`` (positional).

        Selecting the column first avoids materialising a whole row Series on
        every lookup. Raises ``IndexError`` when ``t`` is past the last row.
        """
        return self.data['Close'].iloc[t]

    def reset(self):
        """Rewind to the first row, clear all positions and return the
        initial observation."""
        self.t = 0
        # With fewer than two rows there is no "next" row to step to.
        self.done = len(self.data) < 2
        self.profits = 0
        self.positions = []
        self.position_value = 0
        self.history = [0 for _ in range(self.history_t)]
        return [self.position_value] + self.history  # obs

    def step(self, act):
        """Apply ``act`` at the current row and advance one row.

        Args:
            act: ``1`` to buy, ``2`` to sell everything, any other value to
                hold.

        Returns:
            ``(obs, reward, done)``. ``reward`` is the sign of the profit
            realised by a sell (``-1`` for selling with nothing open, ``0``
            otherwise). ``done`` becomes ``True`` once the last row of
            ``data`` has been reached.

        Raises:
            IndexError: if called when there is no next row (i.e. after
                ``done`` was reported). The environment state is left
                untouched in that case.
        """
        # Read both prices up front so that running off the end of the data
        # fails before any state has been modified.
        current = self._close_at(self.t)
        upcoming = self._close_at(self.t + 1)

        reward = 0

        # act = 0: stay, 1: buy, 2: sell
        if act == 1:
            self.positions.append(current)
        elif act == 2:  # sell
            if not self.positions:
                # Penalise selling when nothing is held.
                reward = -1
            else:
                profits = sum(current - p for p in self.positions)
                reward += profits
                self.profits += profits
                self.positions = []

        # set next time
        self.t += 1
        self.position_value = sum(upcoming - p for p in self.positions)
        if self.history_t > 0:
            # Slide the fixed-size window of one-step price changes.
            self.history.pop(0)
            self.history.append(upcoming - current)

        # The episode ends on the last row: there is nothing left to step to.
        self.done = self.t >= len(self.data) - 1

        # clipping reward
        if reward > 0:
            reward = 1
        elif reward < 0:
            reward = -1

        return [self.position_value] + self.history, reward, self.done  # obs, reward, done
52
+
53
+
54
+
55
+
56
if __name__ == "__main__":
    # Manual smoke test: load the EURUSD minute candles, split them into a
    # train and a test part, and push a few random actions through the
    # environment.
    data = pd.read_csv(
        './data/EURUSD_Candlestick_1_M_BID_01.01.2021-04.02.2023.csv'
    ).set_index('Local time')
    # The index is left as raw strings (not parsed to datetimes), so min/max
    # below are plain string comparisons.
    print(data.index.min(), data.index.max())

    # NOTE(review): label-based slices include both endpoints, so the row at
    # `date_split` appears to land in both `train` and `test` -- confirm this
    # is intended.
    date_split = '19.09.2022 17:55:00.000 GMT-0500'
    train, test = data[:date_split], data[date_split:]
    print(train.head(10))

    env = Environment(train)
    print(env.reset())
    n_random_steps = 3
    for _ in range(n_random_steps):
        # Uniformly random action from {0, 1, 2}.
        random_action = np.random.randint(3)
        print(env.step(random_action))
72
+