Stefan Dumitrescu commited on
Commit
2f0ed55
·
1 Parent(s): 9506fdb
Files changed (1) hide show
  1. app.py +73 -20
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import streamlit as st
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  st.set_page_config(
@@ -64,16 +65,24 @@ def setModel(model_checkpoint):
64
  model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
65
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
66
  return model, tokenizer
 
 
67
  #############################################
 
 
 
 
 
 
68
 
69
  col1, _, col2 = st.columns([10, 1, 16])
70
 
71
  with col1:
72
- st.write("Step 1: Select model")
73
 
74
  model_checkpoint = st.selectbox("Select model", model_list)
75
 
76
- st.write("Step 2: Adjust specific text generation parameters")
77
 
78
  tab_greedy, tab_beamsearch, tab_sampling, tab_typical = st.tabs(["Greedy", "Beam-search", "Sampling", "Typical Sampling"])
79
 
@@ -92,20 +101,19 @@ with col1:
92
 
93
  st.markdown("""---""")
94
 
95
- st.write("Step 3: Adjust common text generation parameters")
96
 
97
  no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)
98
  temperature = st.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
99
- max_length = st.slider("Max Length", value=20, min_value=10, max_value=200)
100
 
 
101
 
102
- with col2:
103
- with st.container():
104
- button_greedy = st.button("Greedy")
105
- button_beam_search = st.button("Beam-search")
106
- button_sampling = st.button("Sampling")
107
- button_typical = st.button("Typical sampling")
108
 
 
 
 
 
109
 
110
  @st.cache(allow_output_mutation=True)
111
  def setModel(model_checkpoint):
@@ -114,19 +122,56 @@ def setModel(model_checkpoint):
114
  return model, tokenizer
115
 
116
  #####################################################
117
- # run-time
118
 
119
  if 'text' not in st.session_state:
120
  st.session_state['text'] = 'Acesta este un exemplu de text generat de un model de limbă.'
121
 
122
  details = ""
 
123
 
124
- if button_greedy:
125
  model, tokenizer = setModel(model_checkpoint)
 
126
  tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
127
- input_ids = tokenized_text.input_ids
128
- attention_mask = tokenized_text.attention_mask
129
- output = greedy_search(model, input_ids, attention_mask, no_repeat_ngrams, max_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
131
  details = "Text generated using greedy decoding"
132
 
@@ -135,19 +180,27 @@ if button_sampling:
135
  tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
136
  input_ids = tokenized_text.input_ids
137
  attention_mask = tokenized_text.attention_mask
138
- output = sampling(model, input_ids, attention_mask, no_repeat_ngrams, max_length, temperature, top_k, top_p)
 
139
  st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
140
  details = f"Text generated using sampling, top-p={top_p:.2f}, top-k={top_k:.2f}, temperature={temperature:.2f}"
141
 
142
  if button_typical:
143
  model, tokenizer = setModel(model_checkpoint)
144
  tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
145
- input_ids = tokenized_text.input_ids
146
- attention_mask = tokenized_text.attention_mask
147
- output = typical_sampling(model, input_ids, attention_mask, no_repeat_ngrams, max_length, temperature, typical_p)
148
  st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
149
  details = f"Text generated using typical sampling, typical-p={typical_p:.2f}, temperature={temperature:.2f}"
 
150
 
151
  text_element = col2.text_area('Text:', height=400, key="text")
 
 
152
  if details != "":
153
- col2.write(details)
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
  st.set_page_config(
 
65
  model = AutoModelForCausalLM.from_pretrained(model_checkpoint)
66
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
67
  return model, tokenizer
68
+
69
+
70
  #############################################
71
+ col_title, _, col_b1, col_b2, col_b3, _ = st.columns([18, 1, 8, 8, 8, 1])
72
+ col_title.markdown("**Playground for text generation with Romanian models**")
73
+ button_greedy = col_b1.button("Greedy generation")
74
+ button_sampling = col_b2.button("Sampling generation")
75
+ button_typical = col_b3.button("Typical sampling generation")
76
+
77
 
78
  col1, _, col2 = st.columns([10, 1, 16])
79
 
80
  with col1:
81
+ st.markdown("**Step 1: Select model**")
82
 
83
  model_checkpoint = st.selectbox("Select model", model_list)
84
 
85
+ st.markdown("**Step 2: Adjust specific text generation parameters**")
86
 
87
  tab_greedy, tab_beamsearch, tab_sampling, tab_typical = st.tabs(["Greedy", "Beam-search", "Sampling", "Typical Sampling"])
88
 
 
101
 
102
  st.markdown("""---""")
103
 
104
+ st.markdown("**Step 3: Adjust common text generation parameters**")
105
 
106
  no_repeat_ngrams = st.slider("No repeat n-grams", value=2, min_value=0, max_value=3)
107
  temperature = st.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
108
+ max_length = st.slider("Number of tokens to generate", value=50, min_value=10, max_value=256)
109
 
110
+ st.markdown("**Step 4: Select a prompt or input your own text, and click generate in the left panel**")
111
 
 
 
 
 
 
 
112
 
113
+ def update_prompt():
114
+ st.session_state['text'] = prompt
115
+
116
+ prompt = st.selectbox("Select prompt", model_list, on_change=update_prompt)
117
 
118
  @st.cache(allow_output_mutation=True)
119
  def setModel(model_checkpoint):
 
122
  return model, tokenizer
123
 
124
  #####################################################
125
+ # show-time
126
 
127
  if 'text' not in st.session_state:
128
  st.session_state['text'] = 'Acesta este un exemplu de text generat de un model de limbă.'
129
 
130
  details = ""
131
+ tokenized_text = None
132
 
133
+ if button_greedy or button_sampling or button_typical:
134
  model, tokenizer = setModel(model_checkpoint)
135
+
136
  tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
137
+
138
+ if len(tokenized_text.input_ids[0]) + max_length > 512: # need to keep less words
139
+ keep_last = 512 - max_length
140
+ print(f"keep last: {keep_last}")
141
+ input_ids, attention_mask = tokenized_text.input_ids[0][:-keep_last], tokenized_text.attention_mask[0][:-keep_last]
142
+ previous_ids = tokenized_text.input_ids[0][:keep_last]
143
+ st.warning(f"kept last {keep_last}")
144
+ else:
145
+ input_ids, attention_mask = tokenized_text.input_ids[0], tokenized_text.attention_mask[0]
146
+ previous_ids = None
147
+
148
+ length = min(512, len(input_ids)+max_length)
149
+ output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
150
+
151
+ if previous_ids is not None:
152
+ new_text = tokenizer.decode(torch.cat([previous_ids, output[0]], dim=1), skip_special_tokens=True)
153
+ else:
154
+ new_text = tokenizer.decode(output[0], skip_special_tokens=True)
155
+
156
+ st.session_state['text'] = new_text
157
+ details = "Text generated using greedy decoding"
158
+
159
+ """
160
+ if button_greedy:
161
+
162
+ tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
163
+ print(f"len text: {len(tokenized_text.input_ids[0])}")
164
+ print(f"max_len : {max_length}")
165
+ if len(tokenized_text.input_ids[0]) + max_length > 512: # need to keep less words
166
+ keep_last = 512 - max_length
167
+ print(f"keep last: {keep_last}")
168
+ input_ids, attention_mask = tokenized_text.input_ids[0][:-keep_last], tokenized_text.attention_mask[0][:-keep_last]
169
+ st.warning(f"kept last {keep_last}")
170
+ else:
171
+ input_ids, attention_mask = tokenized_text.input_ids[0], tokenized_text.attention_mask[0]
172
+
173
+ length = min(512, len(input_ids)+max_length)
174
+ output = greedy_search(model, input_ids.unsqueeze(dim=0), attention_mask.unsqueeze(dim=0), no_repeat_ngrams, length)
175
  st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
176
  details = "Text generated using greedy decoding"
177
 
 
180
  tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
181
  input_ids = tokenized_text.input_ids
182
  attention_mask = tokenized_text.attention_mask
183
+ length = min(512, len(input_ids[0]) + max_length)
184
+ output = sampling(model, input_ids, attention_mask, no_repeat_ngrams, length, temperature, top_k, top_p)
185
  st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
186
  details = f"Text generated using sampling, top-p={top_p:.2f}, top-k={top_k:.2f}, temperature={temperature:.2f}"
187
 
188
  if button_typical:
189
  model, tokenizer = setModel(model_checkpoint)
190
  tokenized_text = tokenizer(st.session_state['text'], add_special_tokens=False, return_tensors="pt")
191
+ input_ids, attention_mask = tokenized_text.input_ids, tokenized_text.attention_mask
192
+ length = min(512, len(input_ids[0]) + max_length)
193
+ output = typical_sampling(model, input_ids, attention_mask, no_repeat_ngrams, length, temperature, typical_p)
194
  st.session_state['text'] = tokenizer.decode(output[0], skip_special_tokens=True)
195
  details = f"Text generated using typical sampling, typical-p={typical_p:.2f}, temperature={temperature:.2f}"
196
+ """
197
 
198
  text_element = col2.text_area('Text:', height=400, key="text")
199
+ col2.markdown("""---""")
200
+ col2.text("Statistics and details:")
201
  if details != "":
202
+ col2.caption("\tGeneration details: " + details)
203
+ if tokenized_text is None:
204
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
205
+ tt = tokenizer(text_element, add_special_tokens=False, return_tensors="pt")
206
+ col2.caption(f"\tText length is {len(text_element)} characters, {len(tt.input_ids[0])} tokens.")