Updating the trainer UI
- app.py +24 -25
- trainer.py +18 -5
app.py
CHANGED
@@ -9,30 +9,6 @@ import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-#
-# model = "tiiuae/falcon-7b-instruct"
-#
-# tokenizer = AutoTokenizer.from_pretrained(model)
-# pipeline = transformers.pipeline(
-#     "text-generation",
-#     model=model,
-#     tokenizer=tokenizer,
-#     torch_dtype=torch.bfloat16,
-#     trust_remote_code=True,
-#     device_map="auto",
-# )
-# sequences = pipeline(
-#     "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
-#     max_length=200,
-#     do_sample=True,
-#     top_k=10,
-#     num_return_sequences=1,
-#     eos_token_id=tokenizer.eos_token_id,
-# )
-# st.title("Beyond the Anti-Jam: Integration of DRL with LLM")
-# for seq in sequences:
-#     st.write(f"Result: {seq['generated_text']}")
-
 
 def perform_training(jammer_type, channel_switching_cost):
     agent = train(jammer_type, channel_switching_cost)
@@ -68,8 +44,31 @@ st.sidebar.write(f"Channel Switching Cost: {channel_switching_cost}")
 start_button = st.sidebar.button('Start')
 
 if start_button:
-    agent = perform_training(jammer_type, channel_switching_cost)
+    agent, rewards = perform_training(jammer_type, channel_switching_cost)
     st.subheader("Generating Insights of the DRL-Training")
     # text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
     # st.write(text)
     test(agent, jammer_type, channel_switching_cost)
+
+# model = "tiiuae/falcon-7b-instruct"
+#
+# tokenizer = AutoTokenizer.from_pretrained(model)
+# pipeline = transformers.pipeline(
+#     "text-generation",
+#     model=model,
+#     tokenizer=tokenizer,
+#     torch_dtype=torch.bfloat16,
+#     trust_remote_code=True,
+#     device_map="auto",
+# )
+# sequences = pipeline(
+#     "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
+#     max_length=200,
+#     do_sample=True,
+#     top_k=10,
+#     num_return_sequences=1,
+#     eos_token_id=tokenizer.eos_token_id,
+# )
+# st.title("Beyond the Anti-Jam: Integration of DRL with LLM")
+# for seq in sequences:
+#     st.write(f"Result: {seq['generated_text']}")
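Note on the app.py change: the new call site unpacks two values (agent, rewards), which matches train() in trainer.py (below) now returning a (DDQN_agent, rewards) tuple. The hunk context shows only the first line of the perform_training wrapper, so its updated body is not visible in this diff; a minimal sketch of what the wrapper presumably looks like after this commit, where everything past the first line is an assumption:

def perform_training(jammer_type, channel_switching_cost):
    # Assumed: train() now returns the trained agent plus the per-episode
    # rewards, and the wrapper forwards both to the caller in app.py.
    agent, rewards = train(jammer_type, channel_switching_cost)
    return agent, rewards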
trainer.py
CHANGED
@@ -79,10 +79,23 @@ def train(jammer_type, channel_switching_cost):
     ax.set_title(f'Training Rewards - {jammer_type}, CSC: {channel_switching_cost}')
     ax.legend()
 
-    #
-    st.
-
-
+    # Use Streamlit layout to create two side-by-side containers
+    with st.beta_container():
+        col1, col2 = st.beta_columns(2)
+
+        with col1:
+            st.subheader("Training Graph")
+            st.set_option('deprecation.showPyplotGlobalUse', False)
+            st.pyplot(fig)
+
+        with col2:
+            st.subheader("Graph Explanation")
+            st.write("""
+            The training graph shows the rewards received by the agent in each episode of the training process.
+            The blue line represents the actual reward values, while the black line represents a rolling average.
+            The red horizontal line indicates the threshold for considering the task solved.
+            The green line represents the epsilon (exploration rate) values for the agent, indicating how often it takes random actions.
+            """)
 
     # Save the figure
     # plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
@@ -98,4 +111,4 @@ def train(jammer_type, channel_switching_cost):
     # # Save the agent as a SavedAgent.
     # agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
     # DDQN_agent.save_model(agentName)
-    return DDQN_agent
+    return DDQN_agent, rewards
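Note on the trainer.py change: st.beta_container() and st.beta_columns() were Streamlit's pre-1.0 interim names and have since been removed in favor of st.container() and st.columns(). Also, st.set_option('deprecation.showPyplotGlobalUse', False) only silences the warning for calling st.pyplot() with no arguments; passing the figure explicitly, as this code already does, makes that option unnecessary. A minimal sketch of the same two-column layout on a current Streamlit release (the outer container is optional, since columns can be created at the top level):

import streamlit as st

col1, col2 = st.columns(2)  # replaces st.beta_columns(2)

with col1:
    st.subheader("Training Graph")
    st.pyplot(fig)  # fig is the matplotlib figure built earlier in train()

with col2:
    st.subheader("Graph Explanation")
    st.write("Reward per episode, rolling average, solved threshold, and epsilon curves.")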