Spaces:
Sleeping
Sleeping
Commenting out the testing part
Browse files- app.py +1 -1
- tester.py +0 -3
- trainer.py +5 -17
app.py
CHANGED
@@ -27,7 +27,7 @@ def main():
|
|
27 |
|
28 |
if start_button:
|
29 |
agent = perform_training(jammer_type, channel_switching_cost)
|
30 |
-
test(agent, jammer_type, channel_switching_cost)
|
31 |
|
32 |
|
33 |
def perform_training(jammer_type, channel_switching_cost):
|
|
|
27 |
|
28 |
if start_button:
|
29 |
agent = perform_training(jammer_type, channel_switching_cost)
|
30 |
+
# test(agent, jammer_type, channel_switching_cost)
|
31 |
|
32 |
|
33 |
def perform_training(jammer_type, channel_switching_cost):
|
tester.py
CHANGED
@@ -2,10 +2,7 @@
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
|
4 |
import numpy as np
|
5 |
-
import matplotlib.pyplot as plt
|
6 |
-
import json
|
7 |
import streamlit as st
|
8 |
-
from DDQN import DoubleDeepQNetwork
|
9 |
from antiJamEnv import AntiJamEnv
|
10 |
|
11 |
|
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
|
4 |
import numpy as np
|
|
|
|
|
5 |
import streamlit as st
|
|
|
6 |
from antiJamEnv import AntiJamEnv
|
7 |
|
8 |
|
trainer.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
|
4 |
import numpy as np
|
5 |
import matplotlib.pyplot as plt
|
6 |
-
import json
|
7 |
import streamlit as st
|
8 |
from DDQN import DoubleDeepQNetwork
|
9 |
from antiJamEnv import AntiJamEnv
|
@@ -108,30 +107,18 @@ def train(jammer_type, channel_switching_cost):
|
|
108 |
st.subheader("Graph Explanation")
|
109 |
st.write(insights)
|
110 |
|
111 |
-
# Save the figure
|
112 |
-
# plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
|
113 |
-
# plt.savefig(plot_name, bbox_inches='tight')
|
114 |
plt.close(fig) # Close the figure to release resources
|
115 |
|
116 |
-
# Save Results
|
117 |
-
# Rewards
|
118 |
-
# fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
|
119 |
-
# with open(fileName, 'w') as f:
|
120 |
-
# json.dump(rewards, f)
|
121 |
-
#
|
122 |
-
# # Save the agent as a SavedAgent.
|
123 |
-
# agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
|
124 |
-
# DDQN_agent.save_model(agentName)
|
125 |
return DDQN_agent
|
126 |
|
127 |
|
128 |
def generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold):
|
129 |
data_description = (
|
130 |
f"The graph represents training rewards over episodes. "
|
131 |
-
f"The actual rewards range from {min(rewards)} to {max(rewards)} with an average of {np.mean(rewards):.2f}. "
|
132 |
-
f"The rolling average values range from {min(rolling_average)} to {max(rolling_average)} with an average of {np.mean(rolling_average):.2f}. "
|
133 |
-
f"The epsilon values range from {min(epsilons)} to {max(epsilons)} with an average exploration rate of {np.mean(epsilons):.2f}. "
|
134 |
-
f"The solved threshold is set at {solved_threshold}."
|
135 |
)
|
136 |
|
137 |
result = llm_chain.predict(data=data_description)
|
@@ -139,3 +126,4 @@ def generate_insights_langchain(rewards, rolling_average, epsilons, solved_thres
|
|
139 |
|
140 |
|
141 |
|
|
|
|
3 |
|
4 |
import numpy as np
|
5 |
import matplotlib.pyplot as plt
|
|
|
6 |
import streamlit as st
|
7 |
from DDQN import DoubleDeepQNetwork
|
8 |
from antiJamEnv import AntiJamEnv
|
|
|
107 |
st.subheader("Graph Explanation")
|
108 |
st.write(insights)
|
109 |
|
|
|
|
|
|
|
110 |
plt.close(fig) # Close the figure to release resources
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
return DDQN_agent
|
113 |
|
114 |
|
115 |
def generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold):
|
116 |
data_description = (
|
117 |
f"The graph represents training rewards over episodes. "
|
118 |
+
f"The actual rewards range from {min(rewards):.2f} to {max(rewards):.2f} with an average of {np.mean(rewards):.2f}. "
|
119 |
+
f"The rolling average values range from {min(rolling_average):.2f} to {max(rolling_average):.2f} with an average of {np.mean(rolling_average):.2f}. "
|
120 |
+
f"The epsilon values range from {min(epsilons):.2f} to {max(epsilons):.2f} with an average exploration rate of {np.mean(epsilons):.2f}. "
|
121 |
+
f"The solved threshold is set at {solved_threshold:.2f}."
|
122 |
)
|
123 |
|
124 |
result = llm_chain.predict(data=data_description)
|
|
|
126 |
|
127 |
|
128 |
|
129 |
+
|