File size: 3,470 Bytes
d67dca9
 
 
 
 
 
a5f2527
9ca06b7
c4ae7c9
d67dca9
 
c4ae7c9
90f9ad1
 
 
6b43fcc
c4ae7c9
d67dca9
 
 
 
 
 
6b43fcc
d67dca9
 
6b43fcc
d67dca9
 
 
 
 
 
 
6b43fcc
 
d67dca9
 
 
6b43fcc
d67dca9
6b43fcc
d67dca9
 
 
 
6b43fcc
d67dca9
 
 
 
6b43fcc
 
 
 
 
 
 
 
d67dca9
6b43fcc
d67dca9
 
 
6b43fcc
 
d67dca9
a5f2527
 
6b43fcc
 
 
 
 
c4ae7c9
6b43fcc
 
 
 
 
a5f2527
 
 
039e45e
 
a5f2527
 
6b43fcc
a5f2527
 
d67dca9
 
 
6b43fcc
d67dca9
 
c4ae7c9
d67dca9
6b43fcc
d67dca9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

from DDQN import DoubleDeepQNetwork
from antiJamEnv import AntiJamEnv


def _plot_training_rewards(rewards, epsilons, solved_threshold, title, window=10):
    """Build a figure of per-episode rewards, their rolling mean, and epsilon.

    Args:
        rewards: Total reward of each training episode, in episode order.
        epsilons: Agent epsilon recorded at the end of each episode.
        solved_threshold: Reward level drawn as the horizontal "Solved Line".
        title: Plot title.
        window: Number of episodes in the rolling average.

    Returns:
        The matplotlib figure. The caller is responsible for closing it.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(rewards, label='Rewards')

    # A 'valid' convolution yields len(rewards) - window + 1 points, and the
    # first one averages episodes 0..window-1. Plot it at x = window - 1 so it
    # lines up with the episodes it summarizes. With fewer rewards than the
    # window there is no meaningful rolling mean: np.convolve would swap its
    # operands, so skip the curve instead.
    if len(rewards) >= window:
        rolling_average = np.convolve(rewards, np.ones(window) / window, mode='valid')
        ax.plot(range(window - 1, len(rewards)), rolling_average,
                color='black', label='Rolling Average')

    ax.axhline(y=solved_threshold, color='r', linestyle='-', label='Solved Line')
    # Scale epsilon by 100 so it is visible on the same axis as the rewards.
    ax.plot([100 * eps for eps in epsilons], color='g', linestyle='-', label='Epsilons')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Rewards')
    ax.set_title(title)
    ax.legend()
    return fig


def train(jammer_type, channel_switching_cost):
    """Train a Double DQN agent on ``AntiJamEnv`` and persist the results.

    Progress is reported through Streamlit widgets while training runs.
    Afterwards the function does three things:
      * renders a reward/epsilon plot in the app and saves it as a PNG;
      * writes the per-episode rewards to a JSON file;
      * saves the agent via ``DoubleDeepQNetwork.save_model``.

    All artifacts are written under ``./data/``, which is created if missing.

    Args:
        jammer_type: Jammer type passed to ``AntiJamEnv``. Also used in the
            plot title and in the output file names.
        channel_switching_cost: Channel-switching cost passed to
            ``AntiJamEnv``. Also used in the plot title and in the output
            file names.
    """
    st.subheader("Training Progress")
    progress_bar = st.progress(0)
    status_text = st.empty()

    env = AntiJamEnv(jammer_type, channel_switching_cost)
    s_size = env.observation_space.shape[0]
    a_size = env.action_space.n
    max_env_steps = 100
    train_episodes = 20
    env._max_episode_steps = max_env_steps
    # Reward level treated as "solved". It drives early stopping and is also
    # drawn as the reference line in the plot.
    solved_threshold = max_env_steps - 0.10 * max_env_steps

    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.999
    discount_rate = 0.95
    lr = 0.001
    batch_size = 32

    agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
    rewards = []
    epsilons = []

    for e in range(train_episodes):
        state = np.reshape(env.reset(), [1, s_size])
        tot_rewards = 0
        for step in range(max_env_steps):
            action = agent.action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, s_size])
            tot_rewards += reward
            agent.store(state, action, reward, next_state, done)
            state = next_state

            # Start replaying only once memory holds more than one batch.
            if len(agent.memory) > batch_size:
                agent.experience_replay(batch_size)

            if done or step == max_env_steps - 1:
                rewards.append(tot_rewards)
                epsilons.append(agent.epsilon)
                status_text.text(f"Episode: {e+1}/{train_episodes}, Reward: {tot_rewards}, Epsilon: {agent.epsilon:.3f}")
                progress_bar.progress((e + 1) / train_episodes)
                break

        agent.update_target_from_model()

        # Stop early once the mean of the last 10 episode rewards reaches the threshold.
        if len(rewards) > 10 and np.average(rewards[-10:]) >= solved_threshold:
            break

    st.sidebar.success("Training completed!")

    fig = _plot_training_rewards(
        rewards,
        epsilons,
        solved_threshold,
        f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}',
    )

    # The figure is passed explicitly, so no global-pyplot deprecation option is needed.
    st.subheader("Training Graph")
    st.pyplot(fig)

    # Make sure the output directory exists before writing any artifact.
    os.makedirs('./data', exist_ok=True)

    # Save this specific figure rather than whatever pyplot considers "current".
    plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
    fig.savefig(plot_name, bbox_inches='tight')
    plt.close(fig)  # Release the figure's resources.

    # Save the per-episode rewards. Convert in case the environment returns
    # NumPy scalar rewards: NumPy integer scalars are not JSON-serializable.
    serializable_rewards = [r.item() if isinstance(r, np.generic) else r for r in rewards]
    file_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(serializable_rewards, f)

    # Save the trained agent.
    agent_name = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
    agent.save_model(agent_name)