dejanseo committed
Commit 9198c88 · verified · 1 Parent(s): f10444d

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +95 -22
src/streamlit_app.py CHANGED
@@ -26,11 +26,11 @@ GENERATION_CONFIG: Dict[str, Any] = {
 # ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
 @st.cache_resource
 def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
-    # Avoid CUDA initialization if no driver; select device explicitly.
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
     model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
     model.to(device)
+    model.eval()
     return tok, model, device

 # ------------------ GENERATION HELPERS ------------------
@@ -61,18 +61,68 @@ def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
     count = torch.where(count.eq(0), torch.ones_like(count), count)
     return [(lp / c).item() for lp, c in zip(sum_logp, count)]

-# --- UPDATED sampling_generate function ---
+# --- UPDATED sampling_generate function (Deep Analysis) ---
 def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty, bad_words_ids: List[List[int]] = None):
-    """Now accepts a list of 'bad_words_ids' to forbid certain sequences."""
-    kwargs = dict(max_length=MAX_TARGET_LENGTH, do_sample=True, temperature=temperature, top_p=top_p, num_return_sequences=top_n, return_dict_in_generate=True, output_scores=True)
-    if no_repeat_ngram_size > 0: kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
-    if repetition_penalty != 1.0: kwargs["repetition_penalty"] = float(repetition_penalty)
-    if bad_words_ids: kwargs["bad_words_ids"] = bad_words_ids
+    kwargs = dict(
+        max_length=MAX_TARGET_LENGTH,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        num_return_sequences=top_n,
+        return_dict_in_generate=True,
+        output_scores=True
+    )
+    if int(no_repeat_ngram_size) > 0:
+        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
+    if float(repetition_penalty) != 1.0:
+        kwargs["repetition_penalty"] = float(repetition_penalty)
+    if bad_words_ids:
+        kwargs["bad_words_ids"] = bad_words_ids

     gen = model.generate(**inputs, **kwargs)
     return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)

-def normalize_text(s: str) -> str: return " ".join(s.strip().lower().split())
+def normalize_text(s: str) -> str:
+    return " ".join(s.strip().lower().split())
+
+# --- Beam-based quick function (from old script) ---
+def generate_expansions_beam(url: str, query: str, tok: MT5Tokenizer, model: MT5ForConditionalGeneration, device: torch.device, num_return_sequences: int = 10) -> List[str]:
+    input_text = f"For URL: {url} diversify query: {query}"
+    inputs = tok(input_text, max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_length=MAX_TARGET_LENGTH,
+            num_return_sequences=num_return_sequences,
+            num_beams=num_return_sequences * 2,
+            num_beam_groups=num_return_sequences,
+            diversity_penalty=0.5,
+            temperature=0.8,
+            do_sample=False,
+            early_stopping=True,
+            pad_token_id=tok.pad_token_id,
+            eos_token_id=tok.eos_token_id,
+            forced_eos_token_id=tok.eos_token_id,
+            max_new_tokens=MAX_TARGET_LENGTH,
+        )
+
+    # Decode and filter out outputs that merely echo the input query.
+    expansions: List[str] = []
+    for seq in outputs:
+        s = tok.decode(seq, skip_special_tokens=True)
+        if s and normalize_text(s) != normalize_text(query):
+            expansions.append(s)
+
+    # Deduplicate while preserving order.
+    seen = set()
+    uniq = []
+    for s in expansions:
+        if s not in seen:
+            seen.add(s)
+            uniq.append(s)
+    return uniq

 # ------------------ STREAMLIT APP ------------------
 st.set_page_config(
@@ -92,22 +142,25 @@ st.title("Query Fanout Generator")
 st.markdown("Enter a URL and a query to generate a diverse set of related queries.")

 col1, col2 = st.columns(2)
-
 with col1:
-    url = st.text_input("URL", value="dejan.ai")
+    url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
 with col2:
-    query = st.text_input("Query", value="ai seo agency")
+    query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")

-run_button = st.button("Generate Fan-out Queries")
+# --- Two actions side by side ---
+bcol1, bcol2 = st.columns(2)
+with bcol1:
+    deep_btn = st.button("Deep Analysis")
+with bcol2:
+    quick_btn = st.button("Quick Fan-Out")

-if run_button:
+# ---- Deep Analysis path (sampling, large batches) ----
+if deep_btn:
     cfg = GENERATION_CONFIG
-
     with st.spinner("Generating queries..."):
         start_ts = time.time()
         inputs = build_inputs(tok, url, query, device)

-        # --- UPDATED BATCHING LOGIC WITH `bad_words_ids` ---
         all_texts, all_scores = [], []
         seen_texts_for_bad_words = set()

@@ -117,11 +170,17 @@ if run_button:
         for i in range(num_batches):
             current_seed = cfg["seed"] + i
             torch.manual_seed(current_seed)
-            if torch.cuda.is_available(): torch.cuda.manual_seed_all(current_seed)
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed_all(current_seed)

             bad_words_ids = None
             if seen_texts_for_bad_words:
-                bad_words_ids = tok(list(seen_texts_for_bad_words), add_special_tokens=False, padding=True, truncation=True)["input_ids"]
+                bad_words_ids = tok(
+                    list(seen_texts_for_bad_words),
+                    add_special_tokens=False,
+                    padding=True,
+                    truncation=True
+                )["input_ids"]

             batch_texts, batch_scores = sampling_generate(
                 tok, model, device, inputs,
@@ -132,15 +191,16 @@ if run_button:
                 repetition_penalty=float(cfg["repetition_penalty"]),
                 bad_words_ids=bad_words_ids
             )
-
+
             all_texts.extend(batch_texts)
             all_scores.extend(batch_scores)
             for txt in batch_texts:
-                if txt: seen_texts_for_bad_words.add(txt)
-
+                if txt:
+                    seen_texts_for_bad_words.add(txt)
+
             progress_bar.progress((i + 1) / num_batches)

-        # Deduplicate and finalize the list
+        # Deduplicate and finalize
         final_enriched = []
         final_seen_normalized = set()
         for txt, sc in zip(all_texts, all_scores):
@@ -151,7 +211,7 @@ if run_button:

         if cfg["sort_by"] == "logp/len":
             final_enriched.sort(key=lambda x: x["logp/len"], reverse=True)
-
+
         final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]

         if not final_enriched:
@@ -161,3 +221,16 @@ if run_button:
         df = pd.DataFrame(output_texts, columns=["Generated Query"])
         df.index = range(1, len(df) + 1)
         st.dataframe(df, use_container_width=True)
+
+# ---- Quick Fan-Out path (beam-based, small and simple) ----
+if quick_btn:
+    with st.spinner("Generating quick fan-out..."):
+        start_time = time.time()
+        expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
+
+    if expansions:
+        df = pd.DataFrame(expansions, columns=["Generated Query"])
+        df.index = range(1, len(df) + 1)
+        st.dataframe(df, use_container_width=True)
+    else:
+        st.warning("No valid fan-outs generated. Try a different query.")
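
A few notes on this change, with small illustrative sketches. The sketches use the public google/mt5-small checkpoint as a stand-in, since MODEL_PATH is defined outside this diff.

First, the model.eval() added to load_model(). generate() does not toggle train/eval mode itself, and from_pretrained() already returns the model in eval mode, so the call is belt-and-braces: it documents intent and guards against anything having flipped the model back to train mode, where live dropout would add noise on top of the sampler. A plain-torch illustration (no model download needed):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    drop = nn.Dropout(p=0.5)
    x = torch.ones(4)

    drop.train()    # training mode: random zeroing, survivors scaled by 1/(1-p)
    print(drop(x))

    drop.eval()     # inference mode: dropout becomes the identity
    print(drop(x))  # tensor([1., 1., 1., 1.])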
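
Second, the bad_words_ids mechanism that drives the batched Deep Analysis loop. Every output seen in earlier batches is tokenized into an exact token-id sequence, and generate() then refuses to reproduce that exact sequence; substrings and paraphrases are not blocked, which is why the normalized dedup pass afterwards is still needed. One caveat: tokenizing the banned list with padding=True (as the hunk above does) appends pad ids to the shorter phrases, and a padded id sequence never matches a real generation, so the sketch below omits padding. A minimal sketch:

    import torch
    from transformers import MT5Tokenizer, MT5ForConditionalGeneration

    tok = MT5Tokenizer.from_pretrained("google/mt5-small")  # stand-in checkpoint
    model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

    seen = ["ai seo agency", "ai seo services"]  # illustrative earlier outputs
    # No padding: each banned phrase must stay an exact, unpadded id sequence.
    bad_words_ids = tok(seen, add_special_tokens=False)["input_ids"]

    enc = tok("For URL: dejan.ai diversify query: ai seo agency", return_tensors="pt")
    with torch.no_grad():
        out = model.generate(**enc, do_sample=True, top_p=0.95,
                             max_new_tokens=24, bad_words_ids=bad_words_ids)
    print(tok.decode(out[0], skip_special_tokens=True))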
 
127
  # ------------------ STREAMLIT APP ------------------
128
  st.set_page_config(
 
142
  st.markdown("Enter a URL and a query to generate a diverse set of related queries.")
143
 
144
  col1, col2 = st.columns(2)
 
145
  with col1:
146
+ url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
147
  with col2:
148
+ query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")
149
 
150
+ # --- Two actions side by side ---
151
+ bcol1, bcol2 = st.columns(2)
152
+ with bcol1:
153
+ deep_btn = st.button("Deep Analysis")
154
+ with bcol2:
155
+ quick_btn = st.button("Quick Fan-Out")
156
 
157
+ # ---- Deep Analysis path (sampling, large batches) ----
158
+ if deep_btn:
159
  cfg = GENERATION_CONFIG
 
160
  with st.spinner("Generating queries..."):
161
  start_ts = time.time()
162
  inputs = build_inputs(tok, url, query, device)
163
 
 
164
  all_texts, all_scores = [], []
165
  seen_texts_for_bad_words = set()
166
 
 
170
  for i in range(num_batches):
171
  current_seed = cfg["seed"] + i
172
  torch.manual_seed(current_seed)
173
+ if torch.cuda.is_available():
174
+ torch.cuda.manual_seed_all(current_seed)
175
 
176
  bad_words_ids = None
177
  if seen_texts_for_bad_words:
178
+ bad_words_ids = tok(
179
+ list(seen_texts_for_bad_words),
180
+ add_special_tokens=False,
181
+ padding=True,
182
+ truncation=True
183
+ )["input_ids"]
184
 
185
  batch_texts, batch_scores = sampling_generate(
186
  tok, model, device, inputs,
 
191
  repetition_penalty=float(cfg["repetition_penalty"]),
192
  bad_words_ids=bad_words_ids
193
  )
194
+
195
  all_texts.extend(batch_texts)
196
  all_scores.extend(batch_scores)
197
  for txt in batch_texts:
198
+ if txt:
199
+ seen_texts_for_bad_words.add(txt)
200
+
201
  progress_bar.progress((i + 1) / num_batches)
202
 
203
+ # Deduplicate and finalize
204
  final_enriched = []
205
  final_seen_normalized = set()
206
  for txt, sc in zip(all_texts, all_scores):
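
Fourth, Quick Fan-Out relies on diverse (group) beam search: the beams are split into num_beam_groups groups, and diversity_penalty docks a group for choosing tokens that earlier groups already picked at the same step, so the ten returned queries spread out instead of clustering. num_beams must be an integer multiple of num_beam_groups, which the num_return_sequences * 2 pairing in generate_expansions_beam guarantees; and because group beam search runs with do_sample=False, the temperature=0.8 passed there is effectively inert. The call shape, continuing from the sketch above:

    with torch.no_grad():
        outs = model.generate(
            **enc,
            num_return_sequences=10,
            num_beams=20,            # must divide evenly across the groups
            num_beam_groups=10,      # 2 beams per group
            diversity_penalty=0.5,   # larger values push the groups further apart
            do_sample=False,
            early_stopping=True,
            max_new_tokens=24,
        )
    for seq in outs:
        print(tok.decode(seq, skip_special_tokens=True))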
 
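Finally, the per-batch re-seeding (cfg["seed"] + i) is what makes a Deep Analysis run reproducible across Streamlit reruns while still varying between batches: the same batch index always replays the same RNG stream. A compact, self-contained illustration:

    import torch

    def sample_batch(seed: int) -> torch.Tensor:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        return torch.rand(3)  # stand-in for one sampling_generate() batch

    assert torch.equal(sample_batch(7), sample_batch(7))      # same seed -> same batch
    assert not torch.equal(sample_batch(7), sample_batch(8))  # shifted seed -> new batch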