Yeb Havinga commited on
Commit
5cf4ee2
·
1 Parent(s): a9f2b23

Make seed configurable

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -87,12 +87,6 @@ def instantiate_models():
87
  p["pipeline"].load()
88
 
89
 
90
- def set_new_seed():
91
- seed = randint(0, 2**32 - 1)
92
- set_seed(seed)
93
- return seed
94
-
95
-
96
  def main():
97
  st.set_page_config( # Alternate names: setup_page, page, layout
98
  page_title="Netherator", # String or None. Strings get appended with "• Streamlit".
@@ -122,9 +116,6 @@ def main():
122
 
123
  st.session_state["text"] = st.text_area("Enter text", st.session_state.prompt_box)
124
 
125
- # min_length = st.sidebar.number_input(
126
- # "Min length", min_value=10, max_value=150, value=75
127
- # )
128
  max_length = st.sidebar.number_input(
129
  "Lengte van de tekst",
130
  value=200,
@@ -140,8 +131,28 @@ def main():
140
  "Num return sequences", min_value=1, max_value=5, value=1
141
  )
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  if sampling_mode := st.sidebar.selectbox(
144
- "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
145
  ):
146
  if sampling_mode == "Beam Search":
147
  num_beams = st.sidebar.number_input(
@@ -200,7 +211,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
200
  estimate = int(estimate)
201
 
202
  with st.spinner(
203
- text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
204
  ):
205
  memory = psutil.virtual_memory()
206
  generator = next(
@@ -211,7 +222,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
211
  ),
212
  None,
213
  )
214
- seed = set_new_seed()
215
  time_start = time.time()
216
  result = generator.get_text(text=st.session_state.text, **params)
217
  time_end = time.time()
 
87
  p["pipeline"].load()
88
 
89
 
 
 
 
 
 
 
90
  def main():
91
  st.set_page_config( # Alternate names: setup_page, page, layout
92
  page_title="Netherator", # String or None. Strings get appended with "• Streamlit".
 
116
 
117
  st.session_state["text"] = st.text_area("Enter text", st.session_state.prompt_box)
118
 
 
 
 
119
  max_length = st.sidebar.number_input(
120
  "Lengte van de tekst",
121
  value=200,
 
131
  "Num return sequences", min_value=1, max_value=5, value=1
132
  )
133
 
134
+ seed_placeholder = st.sidebar.empty()
135
+ if "seed" not in st.session_state:
136
+ print(f"Session state {st.session_state} does not contain seed")
137
+ st.session_state["seed"] = 4162549114
138
+ print(f"Seed is set to: {st.session_state['seed']}")
139
+
140
+ seed = seed_placeholder.number_input(
141
+ "Seed", min_value=0, max_value=2 ** 32 - 1, value=st.session_state["seed"]
142
+ )
143
+
144
+ def set_random_seed():
145
+ st.session_state["seed"] = randint(0, 2 ** 32 - 1)
146
+ seed = seed_placeholder.number_input(
147
+ "Seed", min_value=0, max_value=2 ** 32 - 1, value=st.session_state["seed"]
148
+ )
149
+ print(f"New random seed set to: {seed}")
150
+
151
+ if st.button("New random seed?"):
152
+ set_random_seed()
153
+
154
  if sampling_mode := st.sidebar.selectbox(
155
+ "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
156
  ):
157
  if sampling_mode == "Beam Search":
158
  num_beams = st.sidebar.number_input(
 
211
  estimate = int(estimate)
212
 
213
  with st.spinner(
214
+ text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
215
  ):
216
  memory = psutil.virtual_memory()
217
  generator = next(
 
222
  ),
223
  None,
224
  )
225
+ set_seed(seed)
226
  time_start = time.time()
227
  result = generator.get_text(text=st.session_state.text, **params)
228
  time_end = time.time()