#!/usr/bin/env python3
# -*- coding: utf-8 -*-
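"""Train a Double Deep Q-Network (DDQN) agent on the anti-jamming environment
and report training progress, plots, and saved artifacts through Streamlit."""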

import numpy as np
import matplotlib.pyplot as plt
import json
import streamlit as st
from DDQN import DoubleDeepQNetwork
from antiJamEnv import AntiJamEnv


def train(jammer_type, channel_switching_cost):
    env = AntiJamEnv(jammer_type, channel_switching_cost)
    ob_space = env.observation_space
    ac_space = env.action_space
    st.write(f"Observation space: , {ob_space}")
    st.write(f"Action space: {ac_space}")

    s_size = ob_space.shape[0]
    a_size = ac_space.n
    max_env_steps = 100
    TRAIN_Episodes = 10
    env._max_episode_steps = max_env_steps

    epsilon = 1.0  # exploration rate
    epsilon_min = 0.01
    epsilon_decay = 0.999
    discount_rate = 0.95
    lr = 0.001
    batch_size = 32

    DDQN_agent = DoubleDeepQNetwork(s_size, a_size, lr, discount_rate, epsilon, epsilon_min, epsilon_decay)
    rewards = []  # Store rewards for graphing
    epsilons = []  # Store the Explore/Exploit

    # Training agent
    for e in range(TRAIN_Episodes):
        state = env.reset()
        # print(f"Initial state is: {state}")
        state = np.reshape(state, [1, s_size])  # Reshape to (1, s_size) so it can be stored in memory and passed to predict
        tot_rewards = 0
        for time in range(max_env_steps):  # run each episode for at most max_env_steps steps
            action = DDQN_agent.action(state)
            next_state, reward, done, _ = env.step(action)
            # print(f'The next state is: {next_state}')
            # done: Three collisions occurred in the last 10 steps.
            # time == max_env_steps - 1 : No collisions occurred
            if done or time == max_env_steps - 1:
                rewards.append(tot_rewards)
                epsilons.append(DDQN_agent.epsilon)
                st.write(f"episode: {e}/{TRAIN_Episodes}, score: {tot_rewards}, e: {DDQN_agent.epsilon}")
                break
            # Reshape next_state for storage in replay memory
            next_state = np.reshape(next_state, [1, s_size])
            tot_rewards += reward
            DDQN_agent.store(state, action, reward, next_state, done)  # Store the transition in replay memory
            state = next_state

            # Experience Replay
            if len(DDQN_agent.memory) > batch_size:
                DDQN_agent.experience_replay(batch_size)
        # Sync the target network from the online model after each episode (this could also be done every x steps)
        DDQN_agent.update_target_from_model()
        # If our current NN passes we are done
        # Early stopping criterion: average reward of the last 10 episodes within 10% of the maximum
        if len(rewards) > 10 and np.average(rewards[-10:]) >= max_env_steps - 0.10 * max_env_steps:
            break

    # Plotting
    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')

    # Create a new matplotlib figure
    fig = plt.figure()
    plt.plot(rewards, label='Rewards')
    plt.plot(rolling_average, color='black', label='Rolling Average')
    plt.axhline(y=max_env_steps - 0.10 * max_env_steps, color='r', linestyle='-', label='Solved Line')
    eps_graph = [100 * x for x in epsilons]
    plt.plot(eps_graph, color='g', linestyle='-', label='Epsilons')
    plt.xlabel('Episodes')
    plt.ylabel('Rewards')
    plt.title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
    plt.legend()

    # Display the figure in the Streamlit app via st.pyplot
    st.set_option('deprecation.showPyplotGlobalUse', False)
    st.pyplot(fig)

    # Save the figure
    plot_name = f'train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
    plt.savefig(plot_name, bbox_inches='tight')
    plt.close(fig)  # Close the figure to release resources

    # Save Results
    # Rewards
    fileName = f'train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
    with open(fileName, 'w') as f:
        json.dump(rewards, f)
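    # The saved rewards can be reloaded later for offline analysis, e.g.:
    #     with open(fileName) as f:
    #         rewards = json.load(f)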

    # Save the trained agent model
    agentName = f'DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
    DDQN_agent.save_model(agentName)
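

if __name__ == "__main__":
    # Hypothetical standalone entry point: in the Streamlit app, train() is
    # presumably called from a UI page with user-selected parameters; the
    # argument values below are placeholders for illustration only.
    train(jammer_type="dynamic", channel_switching_cost=0.1)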