Amitontheweb commited on
Commit
0cab724
·
verified ·
1 Parent(s): 12a5174

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -14,8 +14,8 @@ token = os.environ.get("HF_TOKEN")
14
  #tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
15
  #model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
16
 
17
- tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
18
- model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
19
 
20
  #tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
21
  #model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
@@ -138,9 +138,9 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
138
 
139
  def load_model(model_selected):
140
 
141
- if model_selected == "GPT2":
142
- tokenizer = tokenizer_gpt2
143
- model = model_gpt2
144
  #print (model_selected + " loaded")
145
 
146
  #if model_selected == "Gemma 2":
@@ -304,8 +304,8 @@ with gr.Blocks() as demo:
304
 
305
  No_beam_group_list = [2]
306
 
307
- tokenizer = tokenizer_gpt2
308
- model = model_gpt2
309
 
310
  with gr.Row():
311
 
 
14
  #tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
15
  #model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
16
 
17
+ #tokenizer_gpt2 = AutoTokenizer.from_pretrained("openai-community/gpt2")
18
+ #model_gpt2 = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
19
 
20
  #tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-2b")
21
  #model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
 
138
 
139
  def load_model(model_selected):
140
 
141
+ #if model_selected == "GPT2":
142
+ #tokenizer = tokenizer_gpt2
143
+ #model = model_gpt2
144
  #print (model_selected + " loaded")
145
 
146
  #if model_selected == "Gemma 2":
 
304
 
305
  No_beam_group_list = [2]
306
 
307
+ tokenizer = tokenizer_qwen
308
+ model = model_qwen
309
 
310
  with gr.Row():
311