Update src/streamlit_app.py

src/streamlit_app.py (+95 -22)
@@ -26,11 +26,11 @@ GENERATION_CONFIG: Dict[str, Any] = {
 # ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
 @st.cache_resource
 def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
-    # Avoid CUDA initialization if no driver; select device explicitly.
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
     model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
     model.to(device)
+    model.eval()
     return tok, model, device
 
 # ------------------ GENERATION HELPERS ------------------
@@ -61,18 +61,68 @@ def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
     count = torch.where(count.eq(0), torch.ones_like(count), count)
     return [(lp / c).item() for lp, c in zip(sum_logp, count)]
 
-# --- UPDATED sampling_generate function ---
+# --- UPDATED sampling_generate function (Deep Analysis) ---
 def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty, bad_words_ids: List[List[int]] = None):
-
-
-
-
-
+    kwargs = dict(
+        max_length=MAX_TARGET_LENGTH,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        num_return_sequences=top_n,
+        return_dict_in_generate=True,
+        output_scores=True
+    )
+    if int(no_repeat_ngram_size) > 0:
+        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
+    if float(repetition_penalty) != 1.0:
+        kwargs["repetition_penalty"] = float(repetition_penalty)
+    if bad_words_ids:
+        kwargs["bad_words_ids"] = bad_words_ids
 
     gen = model.generate(**inputs, **kwargs)
     return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
 
-def normalize_text(s: str) -> str:
+def normalize_text(s: str) -> str:
+    return " ".join(s.strip().lower().split())
+
+# --- Beam-based quick function (from old script) ---
+def generate_expansions_beam(url: str, query: str, tok: MT5Tokenizer, model: MT5ForConditionalGeneration, device: torch.device, num_return_sequences: int = 10) -> List[str]:
+    input_text = f"For URL: {url} diversify query: {query}"
+    inputs = tok(input_text, max_length=MAX_INPUT_LENGTH, truncation=True, return_tensors="pt")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_length=MAX_TARGET_LENGTH,
+            num_return_sequences=num_return_sequences,
+            num_beams=num_return_sequences * 2,
+            num_beam_groups=num_return_sequences,
+            diversity_penalty=0.5,
+            temperature=0.8,
+            do_sample=False,
+            early_stopping=True,
+            pad_token_id=tok.pad_token_id,
+            eos_token_id=tok.eos_token_id,
+            forced_eos_token_id=tok.eos_token_id,
+            max_new_tokens=MAX_TARGET_LENGTH,
+        )
+
+    # Decode and simple post-filter
+    expansions: List[str] = []
+    for seq in outputs:
+        s = tok.decode(seq, skip_special_tokens=True)
+        if s and normalize_text(s) != normalize_text(query):
+            expansions.append(s)
+
+    # Deduplicate preserve order
+    seen = set()
+    uniq = []
+    for s in expansions:
+        if s not in seen:
+            seen.add(s)
+            uniq.append(s)
+    return uniq
 
 # ------------------ STREAMLIT APP ------------------
 st.set_page_config(
@@ -92,22 +142,25 @@ st.title("Query Fanout Generator")
 st.markdown("Enter a URL and a query to generate a diverse set of related queries.")
 
 col1, col2 = st.columns(2)
-
 with col1:
-    url = st.text_input("URL", value="dejan.ai")
+    url = st.text_input("URL", value="dejan.ai", help="Target URL that provides context for the query.")
 with col2:
-    query = st.text_input("Query", value="ai seo agency")
+    query = st.text_input("Query", value="ai seo agency", help="The search query you want to expand.")
 
-
+# --- Two actions side by side ---
+bcol1, bcol2 = st.columns(2)
+with bcol1:
+    deep_btn = st.button("Deep Analysis")
+with bcol2:
+    quick_btn = st.button("Quick Fan-Out")
 
-
+# ---- Deep Analysis path (sampling, large batches) ----
+if deep_btn:
     cfg = GENERATION_CONFIG
-
     with st.spinner("Generating queries..."):
         start_ts = time.time()
         inputs = build_inputs(tok, url, query, device)
 
-        # --- UPDATED BATCHING LOGIC WITH `bad_words_ids` ---
         all_texts, all_scores = [], []
         seen_texts_for_bad_words = set()
 
@@ -117,11 +170,17 @@ if run_button:
         for i in range(num_batches):
             current_seed = cfg["seed"] + i
             torch.manual_seed(current_seed)
-            if torch.cuda.is_available():
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed_all(current_seed)
 
             bad_words_ids = None
             if seen_texts_for_bad_words:
-                bad_words_ids = tok(
+                bad_words_ids = tok(
+                    list(seen_texts_for_bad_words),
+                    add_special_tokens=False,
+                    padding=True,
+                    truncation=True
+                )["input_ids"]
 
             batch_texts, batch_scores = sampling_generate(
                 tok, model, device, inputs,
@@ -132,15 +191,16 @@ if run_button:
                 repetition_penalty=float(cfg["repetition_penalty"]),
                 bad_words_ids=bad_words_ids
             )
-
+
             all_texts.extend(batch_texts)
             all_scores.extend(batch_scores)
             for txt in batch_texts:
-                if txt:
-
+                if txt:
+                    seen_texts_for_bad_words.add(txt)
+
             progress_bar.progress((i + 1) / num_batches)
 
-        # Deduplicate and finalize
+        # Deduplicate and finalize
         final_enriched = []
         final_seen_normalized = set()
         for txt, sc in zip(all_texts, all_scores):
@@ -151,7 +211,7 @@ if run_button:
 
         if cfg["sort_by"] == "logp/len":
             final_enriched.sort(key=lambda x: x["logp/len"], reverse=True)
-
+
         final_enriched = final_enriched[:TOTAL_DESIRED_CANDIDATES]
 
         if not final_enriched:
@@ -161,3 +221,16 @@
         df = pd.DataFrame(output_texts, columns=["Generated Query"])
         df.index = range(1, len(df) + 1)
         st.dataframe(df, use_container_width=True)
+
+# ---- Quick Fan-Out path (beam-based, small and simple) ----
+if quick_btn:
+    with st.spinner("Generating quick fan-out..."):
+        start_time = time.time()
+        expansions = generate_expansions_beam(url, query, tok, model, device, num_return_sequences=10)
+
+        if expansions:
+            df = pd.DataFrame(expansions, columns=["Generated Query"])
+            df.index = range(1, len(df) + 1)
+            st.dataframe(df, use_container_width=True)
+        else:
+            st.warning("No valid fan-outs generated. Try a different query.")
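Notes on the change:

The Deep Analysis path's sampling_generate now assembles its generate() kwargs conditionally, so no_repeat_ngram_size, repetition_penalty, and bad_words_ids are passed only when they actually constrain decoding; everything else stays at library defaults. A minimal standalone sketch of the same sampling call, assuming the public google/mt5-small checkpoint as a stand-in for MODEL_PATH (which is not shown in this diff) and illustrative temperature/top_p values rather than the app's GENERATION_CONFIG:

import torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

tok = MT5Tokenizer.from_pretrained("google/mt5-small")   # stand-in checkpoint
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
model.eval()  # as in load_model(): disables dropout, so seeded runs are repeatable

inputs = tok("For URL: dejan.ai diversify query: ai seo agency", return_tensors="pt")
torch.manual_seed(42)  # per-batch reseeding, as in the Deep Analysis loop
with torch.no_grad():
    gen = model.generate(
        **inputs,
        do_sample=True,                # nucleus sampling
        temperature=0.9,               # illustrative value
        top_p=0.95,                    # illustrative value
        num_return_sequences=4,
        max_length=64,                 # stand-in for MAX_TARGET_LENGTH
        return_dict_in_generate=True,
        output_scores=True,            # gen.scores feeds the logp/len ranking
    )
print(tok.batch_decode(gen.sequences, skip_special_tokens=True))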
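The Deep Analysis dedup keys on normalize_text, which trims, lowercases, and collapses internal whitespace, so candidates differing only in case or spacing count as one query. Copied from the diff, with a quick check:

def normalize_text(s: str) -> str:
    return " ".join(s.strip().lower().split())

assert normalize_text("  AI   SEO Agency ") == "ai seo agency"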
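One caveat on the bad_words_ids construction in the batching loop: calling the tokenizer with padding=True pads every seen string to the longest one in the batch, so pad token ids land inside the shorter banned sequences, and generated text will never contain those pad ids, leaving the padded entries unable to match. A possible refinement, not part of this commit, keeping the app's variable names:

# Hypothetical alternative to the padded batch call: per-string encoding
# keeps pad ids out of the ban list.
bad_words_ids = [
    tok(text, add_special_tokens=False).input_ids
    for text in seen_texts_for_bad_words
]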
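On the Quick Fan-Out path: generate_expansions_beam uses diverse (group) beam search, which requires do_sample=False, num_beam_groups > 1, and num_beams divisible by num_beam_groups; with num_beams = num_return_sequences * 2 and num_beam_groups = num_return_sequences, that constraint holds. Two of its arguments are effectively inert: temperature is ignored whenever do_sample=False, and when max_length and max_new_tokens are both set, recent transformers releases warn and let max_new_tokens take precedence. A trimmed equivalent, reusing tok, model, and inputs from the sampling sketch above with illustrative constants:

# Diverse beam search: groups decode in turn, and diversity_penalty pushes
# later groups away from tokens earlier groups already chose.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        num_beams=20,             # 2 x num_return_sequences
        num_beam_groups=10,       # one group per returned sequence
        num_return_sequences=10,
        diversity_penalty=0.5,
        do_sample=False,          # required for group beam search
        early_stopping=True,
        max_new_tokens=64,        # stand-in for MAX_TARGET_LENGTH
    )
print(tok.batch_decode(outputs, skip_special_tokens=True))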