asataura committed
Commit a5f2527 · 1 Parent(s): dd07314

Adding Streamlit Pyplot

Files changed (3)
  1. app.py +12 -9
  2. tester.py +26 -17
  3. trainer.py +25 -16
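
The common thread across all three files is replacing plt.show() with Streamlit's in-page rendering via st.pyplot. A minimal, self-contained sketch of that pattern as this commit uses it (the reward data and the output file name below are illustrative placeholders, not values from the repo):

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

rewards = np.random.randint(60, 100, size=50)                    # illustrative episode rewards
rolling = np.convolve(rewards, np.ones(10) / 10, mode='valid')   # 10-episode moving average

fig = plt.figure()                       # build an explicit figure object
plt.plot(rewards, label='Rewards')
plt.plot(rolling, color='black', label='Rolling Average')
plt.xlabel('Episodes')
plt.ylabel('Rewards')
plt.legend()

st.pyplot(fig)                           # render the figure inside the Streamlit page
plt.savefig('rewards.png', bbox_inches='tight')   # optionally persist it to disk as well
plt.close(fig)                           # release the figure once it has been drawn

Passing the figure object to st.pyplot explicitly is what avoids the global-pyplot deprecation warning that the committed code additionally silences with st.set_option('deprecation.showPyplotGlobalUse', False).
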
app.py CHANGED
@@ -25,17 +25,20 @@ def main():
     st.write(f"Channel Switching Cost: {channel_switching_cost}")
 
     if st.button('Train'):
-        st.write("==================================================")
-        st.write('Training Starting')
-        trainer.train(jammer_type, channel_switching_cost)
-        st.write("Training completed")
-        st.write("==================================================")
-
+        agentName = f'DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
+        if os.path.exists(agentName):
+            st.write("Agent has been trained already!!!")
+        else:
+            st.write("==================================================")
+            st.write('Training Starting')
+            trainer.train(jammer_type, channel_switching_cost)
+            st.write("Training completed")
+            st.write("==================================================")
     if st.button('Test'):
-        st.write("==================================================")
-        st.write('Testing Starting')
-        agentName = f'savedAgents/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
+        agentName = f'DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
         if os.path.exists(agentName):
+            st.write("==================================================")
+            st.write('Testing Starting')
             tester.test(jammer_type, channel_switching_cost)
             st.write("Testing completed")
             st.write("==================================================")
tester.py CHANGED
@@ -4,6 +4,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import json
+import streamlit as st
 from DDQN import DoubleDeepQNetwork
 from antiJamEnv import AntiJamEnv
 
@@ -12,12 +13,11 @@ def test(jammer_type, channel_switching_cost):
     env = AntiJamEnv(jammer_type, channel_switching_cost)
     ob_space = env.observation_space
     ac_space = env.action_space
-    print("Observation space: ", ob_space, ob_space.dtype)
-    print("Action space: ", ac_space, ac_space.n)
+    st.write(f"Observation space: {ob_space}")
+    st.write(f"Action space: {ac_space}")
 
     s_size = ob_space.shape[0]
     a_size = ac_space.n
-    total_episodes = 200
     max_env_steps = 100
     TEST_Episodes = 100
     env._max_episode_steps = max_env_steps
@@ -28,7 +28,7 @@ def test(jammer_type, channel_switching_cost):
     discount_rate = 0.95
     lr = 0.001
 
-    agentName = f'savedAgents/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
+    agentName = f'DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
     DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
     DDQN_agent.model = DDQN_agent.load_saved_model(agentName)
     rewards = []  # Store rewards for graphing
@@ -45,8 +45,7 @@ def test(jammer_type, channel_switching_cost):
             if done or t_test == max_env_steps - 1:
                 rewards.append(tot_rewards)
                 epsilons.append(0)  # We are doing full exploit
-                print("episode: {}/{}, score: {}, e: {}"
-                      .format(e_test, TEST_Episodes, tot_rewards, 0))
+                st.write(f"episode: {e_test}/{TEST_Episodes}, score: {tot_rewards}, e: {DDQN_agent.epsilon}")
                 break
             next_state = np.reshape(next_state, [1, s_size])
             tot_rewards += reward
@@ -54,21 +53,31 @@ def test(jammer_type, channel_switching_cost):
             state = next_state
 
     # Plotting
-    plotName = f'results/test/rewards_{jammer_type}_csc_{channel_switching_cost}.png'
-    rolling_average = np.convolve(rewards, np.ones(10) / 10)
-    plt.plot(rewards)
-    plt.plot(rolling_average, color='black')
-    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-')  # Solved Line
-    # Scale Epsilon (0.001 - 1.0) to match reward (0 - 200) range
-    eps_graph = [200 * x for x in epsilons]
-    plt.plot(eps_graph, color='g', linestyle='-')
+    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
+
+    # Create a new Matplotlib figure for Streamlit
+    fig = plt.figure()
+    plt.plot(rewards, label='Rewards')
+    plt.plot(rolling_average, color='black', label='Rolling Average')
+    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
+    eps_graph = [100 * x for x in epsilons]
+    plt.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
     plt.xlabel('Episodes')
     plt.ylabel('Rewards')
-    plt.savefig(plotName, bbox_inches='tight')
-    plt.show()
+    plt.title(f'Testing Rewards - {jammer_type}, CSC: {channel_switching_cost}')
+    plt.legend()
+
+    # Display the figure using streamlit.pyplot
+    st.set_option('deprecation.showPyplotGlobalUse', False)
+    st.pyplot(fig)
+
+    # Save the figure
+    plot_name = f'test_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
+    plt.savefig(plot_name, bbox_inches='tight')
+    plt.close(fig)  # Close the figure to release resources
 
     # Save Results
     # Rewards
-    fileName = f'results/test/rewards_{jammer_type}_csc_{channel_switching_cost}.json'
+    fileName = f'test_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
     with open(fileName, 'w') as f:
         json.dump(rewards, f)
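
The smoothing used in both plotting blocks is a plain 10-episode moving average: the reward series is convolved with a length-10 box kernel of weights 1/10, and mode='valid' keeps only the windows that fit entirely inside the series (len(rewards) - 9 points, so the smoothed curve is slightly shorter than the raw one). The epsilon overlay is simply the epsilon values scaled by 100 so they share the 0-100 reward axis. A quick self-check of the moving average with made-up numbers:

import numpy as np

rewards = np.arange(1, 16)                                      # 1..15, illustrative
rolling = np.convolve(rewards, np.ones(10) / 10, mode='valid')  # 10-wide box filter

print(rolling)                          # [ 5.5  6.5  7.5  8.5  9.5 10.5]  (mean of each window)
print(len(rolling), len(rewards) - 9)   # 6 6
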
trainer.py CHANGED
@@ -4,6 +4,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import json
+import streamlit as st
 from DDQN import DoubleDeepQNetwork
 from antiJamEnv import AntiJamEnv
 
@@ -12,8 +13,8 @@ def train(jammer_type, channel_switching_cost):
     env = AntiJamEnv(jammer_type, channel_switching_cost)
     ob_space = env.observation_space
     ac_space = env.action_space
-    print("Observation space: ", ob_space, ob_space.dtype)
-    print("Action space: ", ac_space, ac_space.n)
+    st.write(f"Observation space: {ob_space}")
+    st.write(f"Action space: {ac_space}")
 
     s_size = ob_space.shape[0]
     a_size = ac_space.n
@@ -38,7 +39,6 @@ def train(jammer_type, channel_switching_cost):
         # print(f"Initial state is: {state}")
         state = np.reshape(state, [1, s_size])  # Resize to store in memory to pass to .predict
         tot_rewards = 0
-        previous_action = 0
         for time in range(max_env_steps):  # 200 is when you "solve" the game. This can continue forever as far as I know
             action = DDQN_agent.action(state)
             next_state, reward, done, _ = env.step(action)
@@ -48,8 +48,7 @@ def train(jammer_type, channel_switching_cost):
             if done or time == max_env_steps - 1:
                 rewards.append(tot_rewards)
                 epsilons.append(DDQN_agent.epsilon)
-                print("episode: {}/{}, score: {}, e: {}"
-                      .format(e, TRAIN_Episodes, tot_rewards, DDQN_agent.epsilon))
+                st.write(f"episode: {e}/{TRAIN_Episodes}, score: {tot_rewards}, e: {DDQN_agent.epsilon}")
                 break
             # Applying channel switching cost
             next_state = np.reshape(next_state, [1, s_size])
@@ -68,25 +67,35 @@ def train(jammer_type, channel_switching_cost):
             break
 
     # Plotting
-    plotName = f'results/train/rewards_{jammer_type}_csc_{channel_switching_cost}.png'
-    rolling_average = np.convolve(rewards, np.ones(10) / 10)
-    plt.plot(rewards)
-    plt.plot(rolling_average, color='black')
-    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-')  # Solved Line
-    # Scale Epsilon (0.001 - 1.0) to match reward (0 - 100) range
+    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
+
+    # Create a new Matplotlib figure for Streamlit
+    fig = plt.figure()
+    plt.plot(rewards, label='Rewards')
+    plt.plot(rolling_average, color='black', label='Rolling Average')
+    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
     eps_graph = [100 * x for x in epsilons]
-    plt.plot(eps_graph, color='g', linestyle='-')
+    plt.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
     plt.xlabel('Episodes')
     plt.ylabel('Rewards')
-    plt.savefig(plotName, bbox_inches='tight')
-    plt.show()
+    plt.title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
+    plt.legend()
+
+    # Display the figure using streamlit.pyplot
+    st.set_option('deprecation.showPyplotGlobalUse', False)
+    st.pyplot(fig)
+
+    # Save the figure
+    plot_name = f'train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
+    plt.savefig(plot_name, bbox_inches='tight')
+    plt.close(fig)  # Close the figure to release resources
 
     # Save Results
     # Rewards
-    fileName = f'results/train/rewards_{jammer_type}_csc_{channel_switching_cost}.json'
+    fileName = f'train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
     with open(fileName, 'w') as f:
         json.dump(rewards, f)
 
     # Save the agent as a SavedAgent.
-    agentName = f'savedAgents/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
+    agentName = f'DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
     DDQN_agent.save_model(agentName)