#!/usr/bin/env python3
# -*- coding: utf-8 -*-
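"""Streamlit training page for a Double Deep Q-Network (DDQN) anti-jamming agent.

Trains the agent on the local AntiJamEnv for the selected jammer type and
channel-switching cost, reports live progress in the Streamlit UI, plots the
reward curve, and saves the rewards, the plot, and the trained model to ./data/.
"""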
import numpy as np
import matplotlib.pyplot as plt
import json
import streamlit as st
from DDQN import DoubleDeepQNetwork
from antiJamEnv import AntiJamEnv
def train(jammer_type, channel_switching_cost):
    st.subheader("Training Progress")
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Build the anti-jamming environment and read its state/action dimensions.
    env = AntiJamEnv(jammer_type, channel_switching_cost)
    ob_space = env.observation_space
    ac_space = env.action_space
    s_size = ob_space.shape[0]
    a_size = ac_space.n

    # Training hyperparameters.
    max_env_steps = 100
    TRAIN_Episodes = 20
    env._max_episode_steps = max_env_steps

    epsilon = 1.0          # starting exploration rate
    epsilon_min = 0.01     # floor for the exploration rate
    epsilon_decay = 0.999  # multiplicative decay applied by the agent
    discount_rate = 0.95
    lr = 0.001
    batch_size = 32

    DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate,
                                    epsilon, epsilon_min, epsilon_decay)
    rewards = []   # total reward per training episode
    epsilons = []  # epsilon value recorded at the end of each episode

    for e in range(TRAIN_Episodes):
        state = env.reset()
        state = np.reshape(state, [1, s_size])
        tot_rewards = 0
        for time in range(max_env_steps):
            # Epsilon-greedy action, environment step, and replay-buffer storage.
            action = DDQN_agent.action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, s_size])
            tot_rewards += reward
            DDQN_agent.store(state, action, reward, next_state, done)
            state = next_state

            # Learn from a minibatch once enough transitions have been collected.
            if len(DDQN_agent.memory) > batch_size:
                DDQN_agent.experience_replay(batch_size)

            if done or time == max_env_steps - 1:
                rewards.append(tot_rewards)
                epsilons.append(DDQN_agent.epsilon)
                status_text.text(f"Episode: {e + 1}/{TRAIN_Episodes}, "
                                 f"Reward: {tot_rewards}, Epsilon: {DDQN_agent.epsilon:.3f}")
                progress_bar.progress((e + 1) / TRAIN_Episodes)
                break

        # Sync the target network with the online network after each episode.
        DDQN_agent.update_target_from_model()

        # Stop early once the average reward over the last 10 episodes reaches
        # the solved threshold (90% of max_env_steps).
        if len(rewards) > 10 and np.average(rewards[-10:]) >= max_env_steps - 0.10 * max_env_steps:
            break
    st.sidebar.success("Training completed!")

    # Plot per-episode rewards, their 10-episode rolling average, the "solved"
    # threshold, and the epsilon schedule (scaled by 100 to fit the reward axis).
    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(rewards, label='Rewards')
    ax.plot(rolling_average, color='black', label='Rolling Average')
    ax.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
    eps_graph = [100 * x for x in epsilons]
    ax.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Rewards')
    ax.set_title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
    ax.legend()

    # Display the figure in the Streamlit app.
    st.set_option('deprecation.showPyplotGlobalUse', False)
    st.subheader("Training Graph")
    st.pyplot(fig)

    # Save the figure.
    plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
    fig.savefig(plot_name, bbox_inches='tight')
    plt.close(fig)  # Close the figure to release resources
    # Save the per-episode rewards as JSON.
    fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
    with open(fileName, 'w') as f:
        json.dump(rewards, f)

    # Save the trained agent.
    agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
    DDQN_agent.save_model(agentName)
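
# A minimal sketch of how train() might be wired into a Streamlit page when this
# file is run directly. The jammer types and channel-switching-cost range below
# are illustrative assumptions, not values defined by this module.
if __name__ == "__main__":
    st.title("DDQN Anti-Jamming Training")
    jammer = st.sidebar.selectbox("Jammer type", ["constant", "sweeping", "random", "dynamic"])
    csc = st.sidebar.slider("Channel switching cost", 0.0, 1.0, 0.1)
    if st.sidebar.button("Start training"):
        train(jammer, csc)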