Updating the Tokenizer to TF version
- app.py +4 -3
- trainer.py +2 -2
app.py
CHANGED
@@ -6,7 +6,7 @@ import os
 from trainer import train
 from tester import test
 import transformers
-from transformers import
+from transformers import TFAutoModelForCausalLM, TFAutoTokenizer
 
 
 def main():
@@ -34,9 +34,10 @@ def main():
 
 def perform_training(jammer_type, channel_switching_cost):
     agent = train(jammer_type, channel_switching_cost)
+    st.subheader("Generating Insights of the DRL-Training")
     model_name = "tiiuae/falcon-7b-instruct"
-    model =
-    tokenizer =
+    model = TFAutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = TFAutoTokenizer.from_pretrained(model_name)
     pipeline = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=100, temperature=0.7)
     text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
     st.write(text)
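Side note on the new import: transformers does expose TFAutoModelForCausalLM, but it has no TFAutoTokenizer class; tokenizers are framework-agnostic, so AutoTokenizer serves both the PyTorch and TensorFlow model classes. A minimal sketch of the intended load under that correction, additionally assuming the PyTorch checkpoint can be converted on load with from_pt=True (Falcon publishes PyTorch weights only):

import transformers
from transformers import AutoTokenizer, TFAutoModelForCausalLM

model_name = "tiiuae/falcon-7b-instruct"

# Tokenizers are framework-agnostic: AutoTokenizer covers both the PyTorch
# and TensorFlow model classes (transformers has no TFAutoTokenizer).
tokenizer = AutoTokenizer.from_pretrained(model_name)

# from_pt=True asks transformers to convert the PyTorch weights on load.
# This assumes the architecture has a TensorFlow implementation, which may
# not hold for a custom-code model like Falcon.
model = TFAutoModelForCausalLM.from_pretrained(model_name, from_pt=True)

pipeline = transformers.pipeline("text-generation", model=model,
                                 tokenizer=tokenizer, max_length=100,
                                 temperature=0.7)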
trainer.py
CHANGED
@@ -10,7 +10,7 @@ from antiJamEnv import AntiJamEnv
 
 
 def train(jammer_type, channel_switching_cost):
-    st.subheader("Training Progress")
+    st.subheader("DRL Training Progress")
     progress_bar = st.progress(0)
     status_text = st.empty()
 
@@ -62,7 +62,7 @@ def train(jammer_type, channel_switching_cost):
         if len(rewards) > 10 and np.average(rewards[-10:]) >= max_env_steps - 0.10 * max_env_steps:
             break
 
-    st.sidebar.success("Training completed!")
+    st.sidebar.success("DRL Training completed!")
 
     # Plotting
     rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
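For context on the widgets touched here, a hypothetical, self-contained sketch of how st.progress and st.empty are typically driven from a training loop, ending with the same 10-episode moving average the trainer computes; max_episodes and the placeholder reward are illustrative assumptions, not taken from this repo:

import numpy as np
import streamlit as st

def train_progress_demo(max_episodes=100):
    # Hypothetical loop showing how the progress widgets are updated;
    # the real reward logic lives in the trainer's DRL episode loop.
    st.subheader("DRL Training Progress")
    progress_bar = st.progress(0)
    status_text = st.empty()

    rewards = []
    for e in range(max_episodes):
        episode_reward = float(np.random.rand())  # placeholder for an env rollout
        rewards.append(episode_reward)
        progress_bar.progress((e + 1) / max_episodes)
        status_text.text(f"Episode {e + 1}/{max_episodes}, reward {episode_reward:.2f}")

    st.sidebar.success("DRL Training completed!")

    # Same smoothing as the trainer: a 10-episode moving average via convolution.
    rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
    return rolling_average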