asataura committed on
Commit
8b7d658
·
1 Parent(s): c91ca50

Updating the Tokenizer to TF version

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. trainer.py +2 -2
app.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  from trainer import train
7
  from tester import test
8
  import transformers
9
- from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
 
12
  def main():
@@ -34,9 +34,10 @@ def main():
34
 
35
  def perform_training(jammer_type, channel_switching_cost):
36
  agent = train(jammer_type, channel_switching_cost)
 
37
  model_name = "tiiuae/falcon-7b-instruct"
38
- model = AutoModelForCausalLM.from_pretrained(model_name)
39
- tokenizer = AutoTokenizer.from_pretrained(model_name)
40
  pipeline = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=100, temperature=0.7)
41
  text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
42
  st.write(text)
 
6
  from trainer import train
7
  from tester import test
8
  import transformers
9
+ from transformers import TFAutoModelForCausalLM, TFAutoTokenizer
10
 
11
 
12
  def main():
 
34
 
35
  def perform_training(jammer_type, channel_switching_cost):
36
  agent = train(jammer_type, channel_switching_cost)
37
+ st.subheader("Generating Insights of the DRL-Training")
38
  model_name = "tiiuae/falcon-7b-instruct"
39
+ model = TFAutoModelForCausalLM.from_pretrained(model_name)
40
+ tokenizer = TFAutoTokenizer.from_pretrained(model_name)
41
  pipeline = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=100, temperature=0.7)
42
  text = pipeline("Discuss this topic: Integrating LLMs to DRL-based anti-jamming.")
43
  st.write(text)
trainer.py CHANGED
@@ -10,7 +10,7 @@ from antiJamEnv import AntiJamEnv
10
 
11
 
12
  def train(jammer_type, channel_switching_cost):
13
- st.subheader("Training Progress")
14
  progress_bar = st.progress(0)
15
  status_text = st.empty()
16
 
@@ -62,7 +62,7 @@ def train(jammer_type, channel_switching_cost):
62
  if len(rewards) > 10 and np.average(rewards[-10:]) >= max_env_steps - 0.10 * max_env_steps:
63
  break
64
 
65
- st.sidebar.success("Training completed!")
66
 
67
  # Plotting
68
  rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')
 
10
 
11
 
12
  def train(jammer_type, channel_switching_cost):
13
+ st.subheader("DRL Training Progress")
14
  progress_bar = st.progress(0)
15
  status_text = st.empty()
16
 
 
62
  if len(rewards) > 10 and np.average(rewards[-10:]) >= max_env_steps - 0.10 * max_env_steps:
63
  break
64
 
65
+ st.sidebar.success("DRL Training completed!")
66
 
67
  # Plotting
68
  rolling_average = np.convolve(rewards, np.ones(10) / 10, mode='valid')