Stefan Dumitrescu commited on
Commit
19c9e19
·
1 Parent(s): 2f0ed55
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -84,13 +84,10 @@ with col1:
84
 
85
  st.markdown("**Step 2: Adjust specific text generation parameters**")
86
 
87
- tab_greedy, tab_beamsearch, tab_sampling, tab_typical = st.tabs(["Greedy", "Beam-search", "Sampling", "Typical Sampling"])
88
 
89
  with tab_greedy:
90
- st.write("as")
91
-
92
- with tab_beamsearch:
93
- num_beams = st.slider("Num beams", min_value=1, max_value=30, step=5, value=5)
94
 
95
  with tab_sampling:
96
  top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
@@ -104,12 +101,13 @@ with col1:
104
  st.markdown("**Step 3: Adjust common text generation parameters**")
105
 
106
  no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)
107
- temperature = st.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
108
  max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
109
 
110
  st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
111
 
112
 
 
113
  def update_prompt():
114
  st.session_state['text'] = prompt
115
 
@@ -138,7 +136,7 @@ if button_greedy or button_sampling or button_typical:
138
  if len(tokenized_text.input_ids[0]) + max_length > 512: # need to keep less words
139
  keep_last = 512 - max_length
140
  print(f"keep last: {keep_last}")
141
- input_ids, attention_mask = tokenized_text.input_ids[0][:-keep_last], tokenized_text.attention_mask[0][:-keep_last]
142
  previous_ids = tokenized_text.input_ids[0][:keep_last]
143
  st.warning(f"kept last {keep_last}")
144
  else:
@@ -149,7 +147,9 @@ if button_greedy or button_sampling or button_typical:
149
  output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
150
 
151
  if previous_ids is not None:
152
- new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=1), skip_special_tokens=True)
 
 
153
  else:
154
  new_text = tokenizer.decode(output[0], skip_special_tokens=True)
155
 
@@ -199,8 +199,8 @@ text_element = col2.text_area('Text:', height=400, key="text")
199
  col2.markdown("""---""")
200
  col2.text("Statistics and details:")
201
  if details != "":
202
- col2.caption("\tGeneration details: " + details)
203
  if tokenized_text is None:
204
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
205
  tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
206
- col2.caption(f"\tText length is {len(text_element)} characters, {len(tt.input_ids[0])} tokens.")
 
84
 
85
  st.markdown("**Step 2: Adjust specific text generation parameters**")
86
 
87
+ tab_greedy, tab_sampling, tab_typical = st.tabs(["Greedy", "Sampling", "Typical Sampling"])
88
 
89
  with tab_greedy:
90
+ st.caption("Greedy decoding does not have any special parameters.")
 
 
 
91
 
92
  with tab_sampling:
93
  top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
 
101
  st.markdown("**Step 3: Adjust common text generation parameters**")
102
 
103
  no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)
104
+ temperature = st.slider("Temperature", value=1.0, min_value=0.1, max_value=1.0, step=0.1)
105
  max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
106
 
107
  st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
108
 
109
 
110
+
111
  def update_prompt():
112
  st.session_state['text'] = prompt
113
 
 
136
  if len(tokenized_text.input_ids[0]) + max_length > 512: # need to keep less words
137
  keep_last = 512 - max_length
138
  print(f"keep last: {keep_last}")
139
+ input_ids, attention_mask = tokenized_text.input_ids[0][-keep_last:], tokenized_text.attention_mask[0][-keep_last:]
140
  previous_ids = tokenized_text.input_ids[0][:keep_last]
141
  st.warning(f"kept last {keep_last}")
142
  else:
 
147
  output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
148
 
149
  if previous_ids is not None:
150
+ print(f"\nConcat prev id: "+tokenizer.decode(previous_ids, skip_special_tokens=True))
151
+ print(f"\nWith current decode: " + tokenizer.decode(output[0], skip_special_tokens=True))
152
+ new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=-1), skip_special_tokens=True)
153
  else:
154
  new_text = tokenizer.decode(output[0], skip_special_tokens=True)
155
 
 
199
  col2.markdown("""---""")
200
  col2.text("Statistics and details:")
201
  if details != "":
202
+ col2.caption("   Generation details: " + details)
203
  if tokenized_text is None:
204
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
205
  tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
206
+ col2.caption(f"   Text length is {len(text_element)} characters, {len(tt.input_ids[0])} tokens.")