#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json

import numpy as np
import matplotlib.pyplot as plt
import streamlit as st

from DDQN import DoubleDeepQNetwork
from antiJamEnv import AntiJamEnv


def train(jammer_type, channel_switching_cost):
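    """Train a Double Deep Q-Network (DDQN) agent on the anti-jamming environment.

    Training progress is reported through Streamlit widgets; the trained agent and
    the list of per-episode rewards are returned.
    """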
    st.subheader("DRL Training Progress")
    progress_bar = st.progress(0)
    status_text = st.empty()

    env = AntiJamEnv(jammer_type, channel_switching_cost)
    ob_space = env.observation_space
    ac_space = env.action_space
    s_size = ob_space.shape[0]
    a_size = ac_space.n
    max_env_steps = 100
    TRAIN_Episodes = 5
    env._max_episode_steps = max_env_steps

    # Epsilon-greedy exploration schedule and learning hyperparameters
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.999
    discount_rate = 0.95
    lr = 0.001
    batch_size = 32

    DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
    rewards = []   # total reward collected in each episode
    epsilons = []  # epsilon value recorded at the end of each episode
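
    # Training loop: at each step the agent selects an epsilon-greedy action, stores the
    # transition in replay memory, and learns from a sampled mini-batch once the memory
    # holds more than one batch. The target network is refreshed from the online network
    # after every episode, and training stops early once the average reward over the last
    # 10 episodes is within 10% of max_env_steps.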
    for e in range(TRAIN_Episodes):
        state = env.reset()
        state = np.reshape(state, [1, s_size])
        tot_rewards = 0
        for time in range(max_env_steps):
            action = DDQN_agent.action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, s_size])
            tot_rewards += reward
            DDQN_agent.store(state, action, reward, next_state, done)
            state = next_state
            if len(DDQN_agent.memory) > batch_size:
                DDQN_agent.experience_replay(batch_size)
            if done or time == max_env_steps - 1:
                rewards.append(tot_rewards)
                epsilons.append(DDQN_agent.epsilon)
                status_text.text(f"Episode: {e + 1}/{TRAIN_Episodes}, Reward: {tot_rewards}, "
                                 f"Epsilon: {DDQN_agent.epsilon:.3f}")
                progress_bar.progress((e + 1) / TRAIN_Episodes)
                break
        DDQN_agent.update_target_from_model()
        if len(rewards) > 10 and np.average(rewards[-10:]) >= max_env_steps - 0.10 * max_env_steps:
            break
    st.sidebar.success("DRL Training completed!")

    # Plotting: 10-episode rolling average of the rewards
    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')

    # Create a Matplotlib figure for the training graph
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(rewards, label='Rewards')
    ax.plot(rolling_average, color='black', label='Rolling Average')
    ax.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
    # Scale epsilon (0-1) by 100 so it is visible on the reward axis
    eps_graph = [100 * x for x in epsilons]
    ax.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Rewards')
    ax.set_title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
    ax.legend()
    # Use Streamlit layout to create two side-by-side containers
    with st.container():
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Training Graph")
            st.set_option('deprecation.showPyplotGlobalUse', False)
            st.pyplot(fig)

        with col2:
            st.subheader("Graph Explanation")
            st.write("""
            The training graph shows the reward received by the agent in each training episode.
            The blue line shows the raw episode rewards, while the black line is a 10-episode rolling average.
            The red horizontal line marks the threshold at which the task is considered solved.
            The green line shows the agent's epsilon (exploration rate), scaled by 100 for visibility,
            indicating how often the agent still takes random actions.
            """)
    # Save the figure
    # plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
    # plt.savefig(plot_name, bbox_inches='tight')
    plt.close(fig)  # Close the figure to release resources

    # Save Results
    # Rewards
    # fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
    # with open(fileName, 'w') as f:
    #     json.dump(rewards, f)
    #
    # # Save the agent as a SavedAgent.
    # agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
    # DDQN_agent.save_model(agentName)

    return DDQN_agent, rewards
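

# Minimal usage sketch (assumption: this module is imported by a Streamlit app that lets
# the user choose the jammer type and channel-switching cost; the module name, widget
# labels, and option values below are illustrative, not taken from the original app).
#
# import streamlit as st
# from trainer import train  # hypothetical module name for this file
#
# jammer_type = st.sidebar.selectbox("Jammer type", ["constant", "sweeping", "random"])
# channel_switching_cost = st.sidebar.slider("Channel switching cost", 0.0, 1.0, 0.1)
# if st.sidebar.button("Start training"):
#     agent, episode_rewards = train(jammer_type, channel_switching_cost)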