Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +84 -86
src/streamlit_app.py
CHANGED
@@ -108,14 +108,12 @@ def generate_expansions_beam(url: str, query: str, tok: MT5Tokenizer, model: MT5
|
|
108 |
max_new_tokens=MAX_TARGET_LENGTH,
|
109 |
)
|
110 |
|
111 |
-
# Decode and simple post-filter
|
112 |
expansions: List[str] = []
|
113 |
for seq in outputs:
|
114 |
s = tok.decode(seq, skip_special_tokens=True)
|
115 |
if s and normalize_text(s) != normalize_text(query):
|
116 |
expansions.append(s)
|
117 |
|
118 |
-
# Deduplicate preserve order
|
119 |
seen = set()
|
120 |
uniq = []
|
121 |
for s in expansions:
|
@@ -141,95 +139,95 @@ tok, model, device = load_model()
|
|
141 |
st.title("Query Fanout Generator")
|
142 |
st.markdown("Enter a URL and a query to generate a diverse set of related queries.")
|
143 |
|
|
|
144 |
col1, col2 = st.columns(2)
|
145 |
with col1:
|
146 |
url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
|
147 |
-
b1, b2 = st.columns(2)
|
148 |
-
with b1:
|
149 |
-
deep_btn = st.button("Deep Analysis", use_container_width=True)
|
150 |
-
with b2:
|
151 |
-
quick_btn = st.button("Quick Fan-Out", use_container_width=True)
|
152 |
-
|
153 |
with col2:
|
154 |
query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")
|
155 |
|
156 |
-
#
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
|
229 |
|
230 |
-
if expansions:
|
231 |
-
df = pd.DataFrame(expansions, columns=["Generated Query"])
|
232 |
-
df.index = range(1, len(df) + 1)
|
233 |
-
st.dataframe(df, use_container_width=True)
|
234 |
else:
|
235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
max_new_tokens=MAX_TARGET_LENGTH,
|
109 |
)
|
110 |
|
|
|
111 |
expansions: List[str] = []
|
112 |
for seq in outputs:
|
113 |
s = tok.decode(seq, skip_special_tokens=True)
|
114 |
if s and normalize_text(s) != normalize_text(query):
|
115 |
expansions.append(s)
|
116 |
|
|
|
117 |
seen = set()
|
118 |
uniq = []
|
119 |
for s in expansions:
|
|
|
139 |
st.title("Query Fanout Generator")
|
140 |
st.markdown("Enter a URL and a query to generate a diverse set of related queries.")
|
141 |
|
142 |
+
# Inputs
|
143 |
col1, col2 = st.columns(2)
|
144 |
with col1:
|
145 |
url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
with col2:
|
147 |
query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")
|
148 |
|
149 |
+
# Mode + single Run button
|
150 |
+
mode_high_effort = st.toggle("High Effort", value=False, help="On = Deep Analysis (stochastic sampling, large batch). Off = Quick Fan-Out (beam-based).")
|
151 |
+
run_btn = st.button("Generate", type="primary")
|
152 |
+
|
153 |
+
if run_btn:
|
154 |
+
if mode_high_effort:
|
155 |
+
# ---- Deep Analysis path (sampling, large batches) ----
|
156 |
+
cfg = GENERATION_CONFIG
|
157 |
+
with st.spinner("Generating queries..."):
|
158 |
+
start_ts = time.time()
|
159 |
+
inputs = build_inputs(tok, url, query, device)
|
160 |
+
|
161 |
+
all_texts, all_scores = [], []
|
162 |
+
seen_texts_for_bad_words = set()
|
163 |
+
|
164 |
+
num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
|
165 |
+
progress_bar = st.progress(0)
|
166 |
+
|
167 |
+
for i in range(num_batches):
|
168 |
+
current_seed = cfg["seed"] + i
|
169 |
+
torch.manual_seed(current_seed)
|
170 |
+
if torch.cuda.is_available():
|
171 |
+
torch.cuda.manual_seed_all(current_seed)
|
172 |
+
|
173 |
+
bad_words_ids = None
|
174 |
+
if seen_texts_for_bad_words:
|
175 |
+
bad_words_ids = tok(
|
176 |
+
list(seen_texts_for_bad_words),
|
177 |
+
add_special_tokens=False,
|
178 |
+
padding=True,
|
179 |
+
truncation=True
|
180 |
+
)["input_ids"]
|
181 |
+
|
182 |
+
batch_texts, batch_scores = sampling_generate(
|
183 |
+
tok, model, device, inputs,
|
184 |
+
top_n=GENERATION_BATCH_SIZE,
|
185 |
+
temperature=float(cfg["temperature"]),
|
186 |
+
top_p=float(cfg["top_p"]),
|
187 |
+
no_repeat_ngram_size=int(cfg["no_repeat_ngram_size"]),
|
188 |
+
repetition_penalty=float(cfg["repetition_penalty"]),
|
189 |
+
bad_words_ids=bad_words_ids
|
190 |
+
)
|
191 |
+
|
192 |
+
all_texts.extend(batch_texts)
|
193 |
+
all_scores.extend(batch_scores)
|
194 |
+
for txt in batch_texts:
|
195 |
+
if txt:
|
196 |
+
seen_texts_for_bad_words.add(txt)
|
197 |
+
|
198 |
+
progress_bar.progress((i + 1) / num_batches)
|
199 |
+
|
200 |
+
# Deduplicate and finalize
|
201 |
+
final_enriched = []
|
202 |
+
final_seen_normalized = set()
|
203 |
+
for txt, sc in zip(all_texts, all_scores):
|
204 |
+
norm = normalize_text(txt)
|
205 |
+
if norm and norm not in final_seen_normalized and norm != query.lower():
|
206 |
+
final_seen_normalized.add(norm)
|
207 |
+
final_enriched.append({"logp/len": sc, "text": txt})
|
208 |
+
|
209 |
+
if cfg["sort_by"] == "logp/len":
|
210 |
+
final_enriched.sort(key=lambda x: x["logp/len"], reverse=True)
|
211 |
+
|
212 |
+
final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]
|
213 |
+
|
214 |
+
if not final_enriched:
|
215 |
+
st.warning("No queries were generated. Try a different input.")
|
216 |
+
else:
|
217 |
+
output_texts = [item['text'] for item in final_enriched]
|
218 |
+
df = pd.DataFrame(output_texts, columns=["Generated Query"])
|
219 |
+
df.index = range(1, len(df) + 1)
|
220 |
+
st.dataframe(df, use_container_width=True)
|
|
|
221 |
|
|
|
|
|
|
|
|
|
222 |
else:
|
223 |
+
# ---- Quick Fan-Out path (beam-based, small and simple) ----
|
224 |
+
with st.spinner("Generating quick fan-out..."):
|
225 |
+
start_time = time.time()
|
226 |
+
expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
|
227 |
+
|
228 |
+
if expansions:
|
229 |
+
df = pd.DataFrame(expansions, columns=["Generated Query"])
|
230 |
+
df.index = range(1, len(df) + 1)
|
231 |
+
st.dataframe(df, use_container_width=True)
|
232 |
+
else:
|
233 |
+
st.warning("No valid fan-outs generated. Try a different query.")
|