dejanseo committed on
Commit
6fe4e82
·
verified ·
1 Parent(s): f17abc6

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +84 -86
src/streamlit_app.py CHANGED
@@ -108,14 +108,12 @@ def generate_expansions_beam(url: str, query: str, tok: MT5Tokenizer, model: MT5
108
  max_new_tokens=MAX_TARGET_LENGTH,
109
  )
110
 
111
- # Decode and simple post-filter
112
  expansions: List[str] = []
113
  for seq in outputs:
114
  s = tok.decode(seq, skip_special_tokens=True)
115
  if s and normalize_text(s) != normalize_text(query):
116
  expansions.append(s)
117
 
118
- # Deduplicate preserve order
119
  seen = set()
120
  uniq = []
121
  for s in expansions:
@@ -141,95 +139,95 @@ tok, model, device = load_model()
141
  st.title("Query Fanout Generator")
142
  st.markdown("Enter a URL and a query to generate a diverse set of related queries.")
143
 
 
144
  col1, col2 = st.columns(2)
145
  with col1:
146
  url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
147
- b1, b2 = st.columns(2)
148
- with b1:
149
- deep_btn = st.button("Deep Analysis", use_container_width=True)
150
- with b2:
151
- quick_btn = st.button("Quick Fan-Out", use_container_width=True)
152
-
153
  with col2:
154
  query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")
155
 
156
- # ---- Deep Analysis path (sampling, large batches) ----
157
- if deep_btn:
158
- cfg = GENERATION_CONFIG
159
- with st.spinner("Generating queries..."):
160
- start_ts = time.time()
161
- inputs = build_inputs(tok, url, query, device)
162
-
163
- all_texts, all_scores = [], []
164
- seen_texts_for_bad_words = set()
165
-
166
- num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
167
- progress_bar = st.progress(0)
168
-
169
- for i in range(num_batches):
170
- current_seed = cfg["seed"] + i
171
- torch.manual_seed(current_seed)
172
- if torch.cuda.is_available():
173
- torch.cuda.manual_seed_all(current_seed)
174
-
175
- bad_words_ids = None
176
- if seen_texts_for_bad_words:
177
- bad_words_ids = tok(
178
- list(seen_texts_for_bad_words),
179
- add_special_tokens=False,
180
- padding=True,
181
- truncation=True
182
- )["input_ids"]
183
-
184
- batch_texts, batch_scores = sampling_generate(
185
- tok, model, device, inputs,
186
- top_n=GENERATION_BATCH_SIZE,
187
- temperature=float(cfg["temperature"]),
188
- top_p=float(cfg["top_p"]),
189
- no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
190
- repetition_penalty=float(cfg["repetition_penalty"]),
191
- bad_words_ids=bad_words_ids
192
- )
193
-
194
- all_texts.extend(batch_texts)
195
- all_scores.extend(batch_scores)
196
- for txt in batch_texts:
197
- if txt:
198
- seen_texts_for_bad_words.add(txt)
199
-
200
- progress_bar.progress((i + 1) / num_batches)
201
-
202
- # Deduplicate and finalize
203
- final_enriched = []
204
- final_seen_normalized = set()
205
- for txt, sc in zip(all_texts, all_scores):
206
- norm = normalize_text(txt)
207
- if norm and norm not in final_seen_normalized and norm != query.lower():
208
- final_seen_normalized.add(norm)
209
- final_enriched.append({"logp/len": sc, "text": txt})
210
-
211
- if cfg["sort_by"] == "logp/len":
212
- final_enriched.sort(key=lambda x: x["logp/len"], reverse=True)
213
-
214
- final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]
215
-
216
- if not final_enriched:
217
- st.warning("No queries were generated. Try a different input.")
218
- else:
219
- output_texts = [item['text'] for item in final_enriched]
220
- df = pd.DataFrame(output_texts, columns=["Generated Query"])
221
- df.index = range(1, len(df) + 1)
222
- st.dataframe(df, use_container_width=True)
223
-
224
- # ---- Quick Fan-Out path (beam-based, small and simple) ----
225
- if quick_btn:
226
- with st.spinner("Generating quick fan-out..."):
227
- start_time = time.time()
228
- expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
229
 
230
- if expansions:
231
- df = pd.DataFrame(expansions, columns=["Generated Query"])
232
- df.index = range(1, len(df) + 1)
233
- st.dataframe(df, use_container_width=True)
234
  else:
235
- st.warning("No valid fan-outs generated. Try a different query.")
 
 
 
 
 
 
 
 
 
 
 
108
  max_new_tokens=MAX_TARGET_LENGTH,
109
  )
110
 
 
111
  expansions: List[str] = []
112
  for seq in outputs:
113
  s = tok.decode(seq, skip_special_tokens=True)
114
  if s and normalize_text(s) != normalize_text(query):
115
  expansions.append(s)
116
 
 
117
  seen = set()
118
  uniq = []
119
  for s in expansions:
 
139
  st.title("Query Fanout Generator")
140
  st.markdown("Enter a URL and a query to generate a diverse set of related queries.")
141
 
142
+ # Inputs
143
  col1, col2 = st.columns(2)
144
  with col1:
145
  url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
 
 
 
 
 
 
146
  with col2:
147
  query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")
148
 
149
+ # Mode + single Run button
150
+ mode_high_effort = st.toggle("High Effort", value=False, help="On = Deep Analysis (stochastic sampling, large batch). Off = Quick Fan-Out (beam-based).")
151
+ run_btn = st.button("Generate", type="primary")
152
+
153
+ if run_btn:
154
+ if mode_high_effort:
155
+ # ---- Deep Analysis path (sampling, large batches) ----
156
+ cfg = GENERATION_CONFIG
157
+ with st.spinner("Generating queries..."):
158
+ start_ts = time.time()
159
+ inputs = build_inputs(tok, url, query, device)
160
+
161
+ all_texts, all_scores = [], []
162
+ seen_texts_for_bad_words = set()
163
+
164
+ num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
165
+ progress_bar = st.progress(0)
166
+
167
+ for i in range(num_batches):
168
+ current_seed = cfg["seed"] + i
169
+ torch.manual_seed(current_seed)
170
+ if torch.cuda.is_available():
171
+ torch.cuda.manual_seed_all(current_seed)
172
+
173
+ bad_words_ids = None
174
+ if seen_texts_for_bad_words:
175
+ bad_words_ids = tok(
176
+ list(seen_texts_for_bad_words),
177
+ add_special_tokens=False,
178
+ padding=True,
179
+ truncation=True
180
+ )["input_ids"]
181
+
182
+ batch_texts, batch_scores = sampling_generate(
183
+ tok, model, device, inputs,
184
+ top_n=GENERATION_BATCH_SIZE,
185
+ temperature=float(cfg["temperature"]),
186
+ top_p=float(cfg["top_p"]),
187
+ no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
188
+ repetition_penalty=float(cfg["repetition_penalty"]),
189
+ bad_words_ids=bad_words_ids
190
+ )
191
+
192
+ all_texts.extend(batch_texts)
193
+ all_scores.extend(batch_scores)
194
+ for txt in batch_texts:
195
+ if txt:
196
+ seen_texts_for_bad_words.add(txt)
197
+
198
+ progress_bar.progress((i + 1) / num_batches)
199
+
200
+ # Deduplicate and finalize
201
+ final_enriched = []
202
+ final_seen_normalized = set()
203
+ for txt, sc in zip(all_texts, all_scores):
204
+ norm = normalize_text(txt)
205
+ if norm and norm not in final_seen_normalized and norm != query.lower():
206
+ final_seen_normalized.add(norm)
207
+ final_enriched.append({"logp/len": sc, "text": txt})
208
+
209
+ if cfg["sort_by"] == "logp/len":
210
+ final_enriched.sort(key=lambda x: x["logp/len"], reverse=True)
211
+
212
+ final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]
213
+
214
+ if not final_enriched:
215
+ st.warning("No queries were generated. Try a different input.")
216
+ else:
217
+ output_texts = [item['text'] for item in final_enriched]
218
+ df = pd.DataFrame(output_texts, columns=["Generated Query"])
219
+ df.index = range(1, len(df) + 1)
220
+ st.dataframe(df, use_container_width=True)
 
221
 
 
 
 
 
222
  else:
223
+ # ---- Quick Fan-Out path (beam-based, small and simple) ----
224
+ with st.spinner("Generating quick fan-out..."):
225
+ start_time = time.time()
226
+ expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
227
+
228
+ if expansions:
229
+ df = pd.DataFrame(expansions, columns=["Generated Query"])
230
+ df.index = range(1, len(df) + 1)
231
+ st.dataframe(df, use_container_width=True)
232
+ else:
233
+ st.warning("No valid fan-outs generated. Try a different query.")