asataura committed
Commit a86d2bc · Parent(s): d779fa9

Commenting out the testing part

Files changed (3):
  1. app.py +1 -1
  2. tester.py +0 -3
  3. trainer.py +5 -17
app.py CHANGED
@@ -27,7 +27,7 @@ def main():
 
     if start_button:
         agent = perform_training(jammer_type, channel_switching_cost)
-        test(agent, jammer_type, channel_switching_cost)
+        # test(agent, jammer_type, channel_switching_cost)
 
 
 def perform_training(jammer_type, channel_switching_cost):
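
This commit disables the post-training evaluation by commenting out the call. If the step is meant to come back, a softer option (a sketch only, not part of this commit; the checkbox label and the run_tests name are made up) would be to gate it behind a Streamlit toggle inside main(), so it defaults to off without touching the code:

    import streamlit as st

    # Hypothetical toggle; start_button, perform_training, and test are the
    # names already used in app.py and tester.py.
    run_tests = st.checkbox("Evaluate agent after training", value=False)

    if start_button:
        agent = perform_training(jammer_type, channel_switching_cost)
        if run_tests:  # off by default, matching this commit's behavior
            test(agent, jammer_type, channel_switching_cost)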
tester.py CHANGED
@@ -2,10 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import numpy as np
-import matplotlib.pyplot as plt
-import json
 import streamlit as st
-from DDQN import DoubleDeepQNetwork
 from antiJamEnv import AntiJamEnv
 
 
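These three imports became dead once the testing logic was removed. As an aside (a minimal sketch, not part of the commit), dead imports like these can be spotted mechanically by parsing the file and listing every name an import binds, then checking whether each name is referenced again; linters such as pyflakes do this properly:

    import ast

    def imported_names(source: str) -> list[str]:
        """List every name bound by import statements in the given source."""
        names = []
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                for alias in node.names:
                    names.append(alias.asname or alias.name)
        return names

    # On the trimmed tester.py this would report ['np', 'st', 'AntiJamEnv'].
    print(imported_names(open("tester.py").read()))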
trainer.py CHANGED
@@ -3,7 +3,6 @@
 
 import numpy as np
 import matplotlib.pyplot as plt
-import json
 import streamlit as st
 from DDQN import DoubleDeepQNetwork
 from antiJamEnv import AntiJamEnv
@@ -108,30 +107,18 @@ def train(jammer_type, channel_switching_cost):
     st.subheader("Graph Explanation")
     st.write(insights)
 
-    # Save the figure
-    # plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
-    # plt.savefig(plot_name, bbox_inches='tight')
     plt.close(fig)  # Close the figure to release resources
 
-    # Save Results
-    # Rewards
-    # fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
-    # with open(fileName, 'w') as f:
-    #     json.dump(rewards, f)
-    #
-    # # Save the agent as a SavedAgent.
-    # agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
-    # DDQN_agent.save_model(agentName)
     return DDQN_agent
 
 
 def generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold):
     data_description = (
         f"The graph represents training rewards over episodes. "
-        f"The actual rewards range from {min(rewards)} to {max(rewards)} with an average of {np.mean(rewards):.2f}. "
-        f"The rolling average values range from {min(rolling_average)} to {max(rolling_average)} with an average of {np.mean(rolling_average):.2f}. "
-        f"The epsilon values range from {min(epsilons)} to {max(epsilons)} with an average exploration rate of {np.mean(epsilons):.2f}. "
-        f"The solved threshold is set at {solved_threshold}."
+        f"The actual rewards range from {min(rewards):.2f} to {max(rewards):.2f} with an average of {np.mean(rewards):.2f}. "
+        f"The rolling average values range from {min(rolling_average):.2f} to {max(rolling_average):.2f} with an average of {np.mean(rolling_average):.2f}. "
+        f"The epsilon values range from {min(epsilons):.2f} to {max(epsilons):.2f} with an average exploration rate of {np.mean(epsilons):.2f}. "
+        f"The solved threshold is set at {solved_threshold:.2f}."
     )
 
     result = llm_chain.predict(data=data_description)
@@ -139,3 +126,4 @@ def generate_insights_langchain(rewards, rolling_average, epsilons, solved_thres
 
 
 
+
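
Apart from deleting the dead save/export block, the substantive change here is the :.2f format spec, which rounds every number to two decimals before the prompt reaches llm_chain. A standalone illustration with made-up values:

    import numpy as np

    rewards = [0.3333, 1.5, 2.71828]  # hypothetical episode rewards
    # Same format specs as the data_description f-string above:
    print(f"range {min(rewards):.2f} to {max(rewards):.2f}, "
          f"mean {np.mean(rewards):.2f}")
    # -> range 0.33 to 2.72, mean 1.52

Without the spec, raw floats such as 2.71828 (or numpy scalars) would be interpolated at full precision, making the generated prompt noisier for no benefit.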