#!/usr/bin/env python3
# -*- coding: utf-8 -*-
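"""Streamlit training page: trains a Double Deep Q-Network (DDQN) agent in the
anti-jamming environment and plots per-episode rewards alongside the
exploration rate."""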

import numpy as np
import matplotlib.pyplot as plt
import json  # used only by the (commented-out) results-saving block below
import streamlit as st
from DDQN import DoubleDeepQNetwork
from antiJamEnv import AntiJamEnv


def train(jammer_type, channel_switching_cost):
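    """Train a DDQN agent against the given jammer type and channel-switching
    cost, reporting live progress through the Streamlit UI.

    Returns:
        (DDQN_agent, rewards): the trained agent and the per-episode reward history.
    """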
    st.subheader("DRL Training Progress")
    progress_bar = st.progress(0)
    status_text = st.empty()

    env = AntiJamEnv(jammer_type, channel_switching_cost)
    ob_space = env.observation_space
    ac_space = env.action_space

    s_size = ob_space.shape[0]  # length of the state (observation) vector
    a_size = ac_space.n         # number of discrete actions (channels)
    max_env_steps = 100         # step cap per episode
    TRAIN_Episodes = 5          # small episode budget keeps the interactive demo responsive
    env._max_episode_steps = max_env_steps

    epsilon = 1.0           # initial exploration rate
    epsilon_min = 0.01      # exploration floor
    epsilon_decay = 0.999   # multiplicative decay applied inside the agent
    discount_rate = 0.95    # gamma, the future-reward discount
    lr = 0.001              # learning rate
    batch_size = 32         # replay mini-batch size

    # Double DQN keeps separate online and target networks so that action
    # selection and action evaluation are decoupled, reducing Q-value overestimation.
    DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
    rewards = []   # total reward collected in each episode
    epsilons = []  # epsilon value at the end of each episode

    for e in range(TRAIN_Episodes):
        state = env.reset()  # classic Gym API: reset() returns only the observation
        state = np.reshape(state, [1, s_size])  # add a batch dimension for the network
        tot_rewards = 0
        for time in range(max_env_steps):
            action = DDQN_agent.action(state)  # epsilon-greedy action selection
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, s_size])
            tot_rewards += reward
            DDQN_agent.store(state, action, reward, next_state, done)  # record transition in replay memory
            state = next_state

            # Train on a random mini-batch once the replay memory holds more
            # transitions than one batch.
            if len(DDQN_agent.memory) > batch_size:
                DDQN_agent.experience_replay(batch_size)

            # Episode ends on a terminal state or at the step cap: record the
            # episode statistics and refresh the progress widgets.
            if done or time == max_env_steps - 1:
                rewards.append(tot_rewards)
                epsilons.append(DDQN_agent.epsilon)
                status_text.text(f"Episode: {e+1}/{TRAIN_Episodes}, Reward: {tot_rewards}, Epsilon: {DDQN_agent.epsilon:.3f}")
                progress_bar.progress((e + 1) / TRAIN_Episodes)
                break

        # Double DQN: copy the online network's weights into the target
        # network once per episode.
        DDQN_agent.update_target_from_model()

        # Early stop once the 10-episode average reward reaches 90% of the
        # per-episode maximum (the "Solved Line" on the plot). With only
        # 5 training episodes this check cannot fire, but it matters for longer runs.
        if len(rewards) > 10 and np.average(rewards[-10:]) >= max_env_steps - 0.10 * max_env_steps:
            break

    st.sidebar.success("DRL Training completed!")

    # Plotting: rolling average of the rewards. Clamp the window so it never
    # exceeds the number of recorded episodes (with only 5 training episodes,
    # a fixed window of 10 would produce an empty array and no plotted line).
    window = min(10, len(rewards))
    rolling_average = np.convolve(rewards, np.ones(window) / window, mode='valid')

    # Create a new Streamlit figure for the training graph
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(rewards, label='Rewards')
    ax.plot(rolling_average, color='black', label='Rolling Average')
    ax.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
    eps_graph = [100 * x for x in epsilons]  # scale epsilon onto the reward axis
    ax.plot(eps_graph, color='g', linestyle='-', label='Epsilon (x100)')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Rewards')
    ax.set_title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
    ax.legend()

    # Use Streamlit layout to create two side-by-side containers
    with st.container():
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Training Graph")
            # Passing the figure explicitly avoids relying on matplotlib's
            # deprecated global-figure behaviour in st.pyplot.
            st.pyplot(fig)

        with col2:
            st.subheader("Graph Explanation")
            st.write("""
               The training graph shows the reward the agent earned in each episode of training.
               The blue line is the raw per-episode reward, while the black line is a rolling average (over up to 10 episodes).
               The red horizontal line marks the threshold at which the task is considered solved.
               The green line shows the agent's epsilon (exploration rate), scaled by 100 to share the reward axis; higher values mean the agent takes random actions more often.
               """)

    # Save the figure
    # plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
    # plt.savefig(plot_name, bbox_inches='tight')
    plt.close(fig)  # Close the figure to release resources

    # Save Results
    # Rewards
    # fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
    # with open(fileName, 'w') as f:
    #     json.dump(rewards, f)
    #
    # # Save the agent as a SavedAgent.
    # agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
    # DDQN_agent.save_model(agentName)
    return DDQN_agent, rewards
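

# Example invocation from the app's UI layer (values are illustrative only;
# the valid jammer types and switching costs are defined by AntiJamEnv):
#     agent, episode_rewards = train(jammer_type="sweeping", channel_switching_cost=0.1)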