3v324v23 committed
Commit 73f5c98 · 1 Parent(s): bb8fa77

Deploy Space: robust paths + FAISS-or-CSV fallback + Refined v2

Files changed (3)
  1. app.py +300 -273
  2. cleaned_data.csv +0 -0
  3. requirements.txt +7 -8
app.py CHANGED
@@ -1,277 +1,304 @@
- \
- import os, json, numpy as np, pandas as pd
- import gradio as gr
- import faiss
- import re
- from sentence_transformers import SentenceTransformer
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-
- from logic.cleaning import clean_dataframe
- from logic.search import SloganSearcher
-
- # -------------------- Config --------------------
- ASSETS_DIR = "assets"
- DATA_PATH = "data/slogan.csv"
- PROMPT_PATH = "data/prompt.txt"
-
- MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
- NORMALIZE = True
-
- GEN_MODEL = "google/flan-t5-base"
- NUM_GEN_CANDIDATES = 12
- MAX_NEW_TOKENS = 18
- TEMPERATURE = 0.7
- TOP_P = 0.9
- REPETITION_PENALTY = 1.15
-
- # choose the most relevant yet non-duplicate candidate
- RELEVANCE_WEIGHT = 0.7
- NOVELTY_WEIGHT = 0.3
- DUPLICATE_MAX_SIM = 0.92
- NOVELTY_SIM_THRESHOLD = 0.80  # keep some distance from retrieved
-
- META_PATH = os.path.join(ASSETS_DIR, "meta.json")
- PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
- INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
- EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")
-
- def _log(m): print(f"[SLOGAN-SPACE] {m}", flush=True)
-
- # -------------------- Asset build --------------------
- def _build_assets():
-     if not os.path.exists(DATA_PATH):
-         raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
-     os.makedirs(ASSETS_DIR, exist_ok=True)
-
-     _log(f"Loading dataset: {DATA_PATH}")
-     df = pd.read_csv(DATA_PATH)
-
-     _log(f"Rows before cleaning: {len(df)}")
-     df = clean_dataframe(df)
-     _log(f"Rows after cleaning: {len(df)}")

-     if "description" in df.columns and df["description"].notna().any():
-         texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
-         text_col, fallback_col = "description", "tagline"
-     else:
-         texts = df["tagline"].astype(str).tolist()
-         text_col, fallback_col = "tagline", "tagline"
-
-     _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
-     encoder = SentenceTransformer(MODEL_NAME)
-     emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)
-
-     dim = emb.shape[1]
-     index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
-     index.add(emb)
-
-     _log("Persisting assets …")
-     df.to_parquet(PARQUET_PATH, index=False)
-     faiss.write_index(index, INDEX_PATH)
-     np.save(EMB_PATH, emb)
-
-     meta = {
-         "model_name": MODEL_NAME,
-         "dim": int(dim),
-         "normalized": NORMALIZE,
-         "metric": "ip" if NORMALIZE else "l2",
-         "row_count": int(len(df)),
-         "text_col": text_col,
-         "fallback_col": fallback_col,
-     }
-     with open(META_PATH, "w") as f:
-         json.dump(meta, f, indent=2)
-     _log("Assets built successfully.")

- def _ensure_assets():
-     need = False
-     for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
-         if not os.path.exists(p):
-             _log(f"Missing asset: {p}")
-             need = True
-     if need:
-         _log("Building assets from scratch …")
-         _build_assets()
-         return
      try:
-         pd.read_parquet(PARQUET_PATH)
-     except Exception as e:
-         _log(f"Parquet read failed ({e}); rebuilding assets.")
-         _build_assets()
-
- # Build before UI
- _ensure_assets()
-
- # -------------------- Retrieval --------------------
- searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
- meta = json.load(open(META_PATH))
- _encoder = SentenceTransformer(meta["model_name"])
-
- # -------------------- Generator --------------------
- _gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
- _gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
-
- # keep this list small so we don't nuke relevant outputs
- _BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
- _PUNCT_CHARS = ":;—–-,.!?“”\"'`"
- _PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")
-
- _MIN_WORDS, _MAX_WORDS = 2, 8
-
- def _load_prompt():
-     if os.path.exists(PROMPT_PATH):
-         with open(PROMPT_PATH, "r", encoding="utf-8") as f:
-             return f.read()
-     return (
-         "You are a professional slogan writer.\n"
-         "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
-         "Do not copy examples.\n"
-         "Description:\n{description}\nSlogan:"
-     )
-
- def _render_prompt(description: str, retrieved=None) -> str:
-     tmpl = _load_prompt()
-     if "{description}" in tmpl:
-         prompt = tmpl.replace("{description}", description)
-     else:
-         prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
-     if retrieved:
-         prompt += "\n\nDo NOT copy these existing slogans:\n"
-         for s in retrieved[:3]:
-             prompt += f"- {s}\n"
-     return prompt
-
- def _title_case(s: str) -> str:
-     small = {"and","or","for","of","the","to","in","on","with","a","an"}
-     words = [w for w in s.split() if w]
-     out = []
-     for i,w in enumerate(words):
-         lw = w.lower()
-         if i>0 and lw in small: out.append(lw)
-         else: out.append(lw.capitalize())
-     return " ".join(out)
-
- def _strip_punct(s: str) -> str:
-     return _PUNCT_RE.sub("", s)
-
- def _strict_ok(s: str) -> bool:
-     if not s: return False
-     wc = len(s.split())
-     if wc < _MIN_WORDS or wc > _MAX_WORDS: return False
-     lo = s.lower()
-     if any(term in lo for term in _BANNED_TERMS): return False
-     if lo in {"the","a","an"}: return False
-     return True
-
- def _postprocess_strict(texts):
-     cleaned, seen = [], set()
-     for t in texts:
-         s = t.replace("Slogan:", "").strip().strip('"').strip("'")
-         s = " ".join(s.split())
-         s = _strip_punct(s)  # remove punctuation instead of rejecting
-         s = _title_case(s)
-         if _strict_ok(s):
-             k = s.lower()
-             if k not in seen:
-                 seen.add(k); cleaned.append(s)
-     return cleaned
-
- def _postprocess_relaxed(texts):
-     # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
-     cleaned, seen = [], set()
-     for t in texts:
-         s = t.strip().strip('"').strip("'")
-         s = _strip_punct(s)
-         s = " ".join(s.split())
-         wc = len(s.split())
-         if _MIN_WORDS <= wc <= _MAX_WORDS:
-             s = _title_case(s)
-             k = s.lower()
-             if k not in seen:
-                 seen.add(k); cleaned.append(s)
-     return cleaned
-
- def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
-     prompt = _render_prompt(description, retrieved_texts)
-
-     # only block very generic junk at decode time
-     bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
-
-     inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
-     outputs = _gen_model.generate(
-         **inputs,
-         do_sample=True,
-         temperature=TEMPERATURE,
-         top_p=TOP_P,
-         num_return_sequences=n,
-         max_new_tokens=MAX_NEW_TOKENS,
-         no_repeat_ngram_size=3,
-         repetition_penalty=REPETITION_PENALTY,
-         bad_words_ids=bad_ids if bad_ids else None,
-         eos_token_id=_gen_tokenizer.eos_token_id,
      )
-     texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
-     cands = _postprocess_strict(texts)
-     if not cands:
-         cands = _postprocess_relaxed(texts)  # <- graceful fallback
-     return cands
-
- def _pick_best(candidates, retrieved_texts, description):
-     """Weighted relevance to description minus duplication vs retrieved."""
-     if not candidates:
-         return None
-     c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
-     d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
-     rel = c_emb @ d_emb  # cosine sim to description
-
-     if retrieved_texts:
-         R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
-         dup = np.max(R @ c_emb.T, axis=0)  # max sim to any retrieved
-     else:
-         dup = np.zeros(len(candidates), dtype=np.float32)
-
-     # penalize near-duplicates outright
-     mask = dup < DUPLICATE_MAX_SIM
-     if mask.any():
-         scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
-         best_idx = np.argmax(scores)
-         return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
-
-     # else: pick most relevant that still clears a basic novelty bar, else top score
-     scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
-     order = np.argsort(-scores)
-     for i in order:
-         if dup[i] < NOVELTY_SIM_THRESHOLD:
-             return candidates[i]
-     return candidates[order[0]]
-
- # -------------------- Inference pipeline --------------------
- def run_pipeline(user_description: str):
-     if not user_description or not user_description.strip():
-         return "Please enter a description."
-     retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
-     retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
-     gens = _generate_candidates(user_description, retrieved_texts, NUM_GEN_CANDIDATES)
-     chosen = _pick_best(gens, retrieved_texts, user_description) or (gens[0] if gens else "—")
-     lines = []
-     lines.append("### 🔎 Top 3 similar slogans")
-     if retrieved_texts:
-         for i, s in enumerate(retrieved_texts, 1):
-             lines.append(f"{i}. {s}")
-     else:
-         lines.append("No similar slogans found.")
-     lines.append("\n### ✨ AI-generated suggestion")
-     lines.append(chosen)
-     return "\n".join(lines)
-
- # -------------------- UI --------------------
- with gr.Blocks(title="Slogan Finder") as demo:
-     gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
-     query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
-     btn = gr.Button("Get slogans", variant="primary")
-     out = gr.Markdown()
-     btn.click(run_pipeline, inputs=[query], outputs=out)
-
- demo.queue(max_size=64).launch()
-
+ import os, re, json, numpy as np, pandas as pd, gradio as gr, faiss, torch
+ from typing import List
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ # =========================
+ # Config
+ # =========================
+ FLAN_PRIMARY = os.getenv("FLAN_PRIMARY", "google/flan-t5-large")
+ FLAN_FALLBACK = "google/flan-t5-base"
+ EMBED_NAME = "sentence-transformers/all-mpnet-base-v2"
+ RERANK_NAME = "cross-encoder/stsb-roberta-base"
+
+ NUM_SLOGAN_SAMPLES = int(os.getenv("NUM_SLOGAN_SAMPLES", "16"))
+ INDEX_ROOT = os.path.join(os.path.dirname(__file__), "vector_store")
+ DEFAULT_MODEL_FOR_INDEX = EMBED_NAME
+ CSV_PATH = os.path.join(os.path.dirname(__file__), "cleaned_data.csv")
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # =========================
+ # Lazy models
+ # =========================
+ _GEN_TOK = None
+ _GEN_MODEL = None
+ _EMBED_MODEL = None
+ _RERANKER = None
+
+ def _ensure_models():
+     global _GEN_TOK, _GEN_MODEL, _EMBED_MODEL, _RERANKER
+     if _EMBED_MODEL is None:
+         _EMBED_MODEL = SentenceTransformer(EMBED_NAME)
+     if _RERANKER is None:
+         _RERANKER = CrossEncoder(RERANK_NAME)
+     if _GEN_MODEL is None:
+         try:
+             tok = AutoTokenizer.from_pretrained(FLAN_PRIMARY)
+             mdl = AutoModelForSeq2SeqLM.from_pretrained(FLAN_PRIMARY)
+             _GEN_TOK, _GEN_MODEL = tok, mdl.to(DEVICE)
+             print(f"[INFO] Loaded generator: {FLAN_PRIMARY}")
+         except Exception as e:
+             print(f"[WARN] {e}; fallback to {FLAN_FALLBACK}")
+             tok = AutoTokenizer.from_pretrained(FLAN_FALLBACK)
+             mdl = AutoModelForSeq2SeqLM.from_pretrained(FLAN_FALLBACK)
+             _GEN_TOK, _GEN_MODEL = tok, mdl.to(DEVICE)
+
+ # =========================
+ # Index cache
+ # =========================
+ _INDEX_CACHE = {}  # mkey -> (faiss_index, meta_df)
+
+ def _model_key(name: str) -> str:
+     return name.replace("/", "_")
+
+ def _format_for_e5(texts, as_query=False):
+     prefix = "query: " if as_query else "passage: "
+     return [prefix + str(t) for t in texts]
+
+ def _build_memory_index_from_csv(model_name: str):
+     if not os.path.exists(CSV_PATH):
+         return None
+     df = pd.read_csv(CSV_PATH)
+     for col in ("name","tagline","description"):
+         if col not in df.columns: df[col] = ""
+     texts = df["description"].astype(str).tolist()
+     embedder = SentenceTransformer(model_name) if model_name != EMBED_NAME else _EMBED_MODEL
+     if model_name.startswith("intfloat/e5"):
+         texts = _format_for_e5(texts, as_query=False)
+     vecs = embedder.encode(texts, normalize_embeddings=True)
+     vecs = np.asarray(vecs, dtype=np.float32)
+     idx = faiss.IndexFlatIP(vecs.shape[1])
+     idx.add(vecs)
+     return idx, df[["name","tagline","description"]].copy()
+
+ def _load_index_for_model(model_name: str = DEFAULT_MODEL_FOR_INDEX):
+     mkey = _model_key(model_name)
+     if mkey in _INDEX_CACHE: return _INDEX_CACHE[mkey]
+
+     base = os.path.join(INDEX_ROOT, mkey)
+     idx_path = os.path.join(base, "index.faiss")
+     meta_path = os.path.join(base, "meta.parquet")
+
+     if os.path.exists(idx_path) and os.path.exists(meta_path):
+         index = faiss.read_index(idx_path)
+         meta = pd.read_parquet(meta_path)
+         _INDEX_CACHE[mkey] = (index, meta)
+         return _INDEX_CACHE[mkey]
+
+     # fallback: build from CSV if available
+     built = _build_memory_index_from_csv(model_name)
+     if built is not None:
+         _INDEX_CACHE[mkey] = built
+         return built
+
+     # last fallback: tiny demo
+     print("[WARN] FAISS & CSV missing — using tiny demo index")
+     demo = pd.DataFrame({
+         "name":["HowDidIDo","Museotainment","Movitr"],
+         "tagline":["Online evaluation platform","PacMan & Louvre meet","Crowdsourced video translation"],
+         "description":[
+             "Public speaking, Presentation skills and interview practice",
+             "Interactive AR museum tours",
+             "Video translation with voice and subtitles"
+         ]
+     })
+     embedder = SentenceTransformer(model_name) if model_name != EMBED_NAME else _EMBED_MODEL
+     vecs = embedder.encode(demo["description"].tolist(), normalize_embeddings=True)
+     vecs = np.asarray(vecs, dtype=np.float32)
+     idx = faiss.IndexFlatIP(vecs.shape[1]); idx.add(vecs)
+     _INDEX_CACHE[mkey] = (idx, demo)
+     return _INDEX_CACHE[mkey]
+
+ # =========================
+ # Recommend
+ # =========================
+ def recommend(query_text: str, model_name: str = DEFAULT_MODEL_FOR_INDEX, top_k: int = 3) -> pd.DataFrame:
+     _ensure_models()
+     index, meta = _load_index_for_model(model_name)
+     q_inp = _format_for_e5([query_text], as_query=True) if model_name.startswith("intfloat/e5") else [query_text]
+     q_vec = _EMBED_MODEL.encode(q_inp, normalize_embeddings=True)
+     q_vec = np.asarray(q_vec, dtype=np.float32)
+     scores, idxs = index.search(q_vec, top_k)
+     out = meta.iloc[idxs[0]].copy()
+     out["score"] = scores[0]
+     for col in ("name","tagline","description"):
+         if col not in out.columns: out[col] = ""
+     cols = ["name","tagline","description","score"]
+     return out[cols]
+
+ # =========================
+ # Refined v2 – helpers
+ # =========================
+ BLOCK_PATTERNS = [
+     r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
+     r"^[A-Z][a-z]+ [A-Z][a-z]+$",
+     r"^[A-Z][a-z]+$",
+ ]
+ HARD_BLOCK_WORDS = {
+     "platform","solution","system","application","marketplace",
+     "ai-powered","ai powered","empower","empowering",
+     "artificial intelligence","machine learning","augmented reality","virtual reality",
+ }
+ GENERIC_WORDS = {"app","assistant","smart","ai","ml","ar","vr","decentralized","blockchain"}
+ MARKETING_VERBS = {"build","grow","simplify","discover","create","connect","transform","unlock","boost","learn","move","clarify"}
+ BENEFIT_WORDS = {"faster","smarter","easier","better","safer","clearer","stronger","together","confidently","simply","instantly"}
+ GOOD_SLOGANS_TO_AVOID_DUP = {
+     "smarter care, faster decisions","checkout built for small brands","less guessing. more healing.",
+     "built to grow with your cart.","stand tall. feel better.","train your brain to win.",
+     "your body. your algorithm.","play smarter. grow brighter.","style that thinks with you."
+ }
+
+ def _tokens(s: str) -> List[str]: return re.findall(r"[a-z0-9]{3,}", s.lower())
+ def _jaccard(a: List[str], b: List[str]) -> float:
+     A,B=set(a),set(b); return 0.0 if not A or not B else len(A&B)/len(A|B)
+ def _titlecase_soft(s: str) -> str:
+     return " ".join(w if w.isupper() else w.capitalize() for w in s.split())
+ def _is_blocked_slogan(s: str) -> bool:
+     if not s: return True
+     s_strip=s.strip()
+     for pat in BLOCK_PATTERNS:
+         if re.match(pat, s_strip): return True
+     s_low=s_strip.lower()
+     if any(w in s_low for w in HARD_BLOCK_WORDS): return True
+     return s_low in GOOD_SLOGANS_TO_AVOID_DUP
+
+ def _generic_penalty(s: str) -> float:
+     hits=sum(1 for w in GENERIC_WORDS if w in s.lower()); return min(1.0, 0.25*hits)
+ def _for_penalty(s: str) -> float: return 0.3 if re.search(r"\bfor\b", s.lower()) else 0.0
+
+ def _neighbor_context(neighbors_df: pd.DataFrame) -> str:
+     if neighbors_df is None or neighbors_df.empty: return ""
+     ex=[]
+     for _,row in neighbors_df.head(3).iterrows():
+         tg=str(row.get("tagline","")).strip()
+         if 5<=len(tg)<=70: ex.append(f"- {tg}")
+     return "\n".join(ex)
+
+ def _copies_neighbor(s: str, neighbors_df: pd.DataFrame) -> bool:
+     if neighbors_df is None or neighbors_df.empty: return False
+     s_low=s.lower(); s_toks=_tokens(s_low)
+     for _,row in neighbors_df.iterrows():
+         t=str(row.get("tagline","")).strip()
+         if not t: continue
+         t_low=t.lower()
+         if s_low==t_low: return True
+         if _jaccard(s_toks,_tokens(t_low))>=0.7: return True
      try:
+         em=SentenceTransformer(EMBED_NAME)
+         s_vec=em.encode([s])[0]; s_vec=s_vec/np.linalg.norm(s_vec)
+         for _,row in neighbors_df.head(3).iterrows():
+             t=str(row.get("tagline","")).strip()
+             if not t: continue
+             t_vec=em.encode([t])[0]; t_vec=t_vec/np.linalg.norm(t_vec)
+             if float(np.dot(s_vec,t_vec))>=0.85: return True
+     except: pass
+     return False
+
+ def _clean_slogan(text: str, max_words: int = 8) -> str:
+     text=text.strip().split("\n")[0]
+     text=re.sub(r"[\"“”‘’]","",text); text=re.sub(r"\s+"," ",text).strip()
+     words=text.split()
+     return " ".join(words[:max_words]) if len(words)>max_words else text
+
+ def _score_candidates(query: str, cands: List[str], neighbors_df: pd.DataFrame) -> List[tuple]:
+     if not cands: return []
+     ce_scores=np.asarray(CrossEncoder(RERANK_NAME).predict([(query,s) for s in cands]),dtype=np.float32)/5.0
+     q_toks=_tokens(query); results=[]
+
+     em=SentenceTransformer(EMBED_NAME)
+     neighbor_vecs=[]
+     if neighbors_df is not None and not neighbors_df.empty:
+         for _,row in neighbors_df.head(3).iterrows():
+             t=str(row.get("tagline","")).strip()
+             if t:
+                 v=em.encode([t])[0]; neighbor_vecs.append(v/np.linalg.norm(v))
+
+     for i,s in enumerate(cands):
+         words=s.split()
+         brev=1.0-min(1.0,abs(len(words)-5)/5.0)
+         wl=set(w.lower() for w in words)
+         m_hits=len(wl & MARKETING_VERBS); b_hits=len(wl & BENEFIT_WORDS)
+         marketing=min(1.0,0.2*m_hits+0.2*b_hits)
+         g_pen=_generic_penalty(s); f_pen=_for_penalty(s)
+         n_pen=0.0
+         if neighbor_vecs:
+             try:
+                 s_vec=em.encode([s])[0]; s_vec=s_vec/np.linalg.norm(s_vec)
+                 sim_max=max(float(np.dot(s_vec,nv)) for nv in neighbor_vecs) if neighbor_vecs else 0.0
+                 n_pen=sim_max
+             except: n_pen=0.0
+         overlap=_jaccard(q_toks,_tokens(s)); anti_copy=1.0-overlap
+         score=0.55*float(ce_scores[i])+0.20*brev+0.15*marketing+0.03*anti_copy-0.07*g_pen-0.03*f_pen-0.10*n_pen
+         results.append((s,float(score)))
+     return results
+
+ def generate_slogan(query_text: str, neighbors_df: pd.DataFrame = None, n_samples: int = NUM_SLOGAN_SAMPLES) -> str:
+     _ensure_models()
+     ctx=_neighbor_context(neighbors_df)
+     prompt=(
+         "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
+         "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
+         "Focus on clear benefits and vivid verbs. Do not copy the description. Return ONLY a list, one slogan per line.\n\n"
+         "Good Examples:\nDescription: AI assistant for doctors to prioritize patient cases\nSlogan: Less Guessing. More Healing.\n\n"
+         "Description: Payments for small online stores\nSlogan: Built to Grow with Your Cart.\n\n"
+         "Description: Neurotech headset to boost focus\nSlogan: Train Your Brain to Win.\n\n"
+         "Description: Interior design suggestions with AI\nSlogan: Style That Thinks With You.\n\n"
+         "Bad Examples (avoid these): Innovative AI Platform / Smart App for Everyone / Empowering Small Businesses\n\n"
      )
+     if ctx: prompt+=f"Similar taglines (style only):\n{ctx}\n\n"
+     prompt+=f"Description: {query_text}\nSlogans:"
+     input_ids=_GEN_TOK(prompt,return_tensors="pt").input_ids.to(DEVICE)
+     outputs=_GEN_MODEL.generate(input_ids,max_new_tokens=24,do_sample=True,top_k=60,top_p=0.92,temperature=1.2,num_return_sequences=n_samples,repetition_penalty=1.08)
+     raw=[_GEN_TOK.decode(o,skip_special_tokens=True) for o in outputs]
+     cand=set()
+     for txt in raw:
+         for line in txt.split("\n"):
+             s=_clean_slogan(line)
+             if not s: continue
+             if len(s.split())<2 or len(s.split())>8: continue
+             if _is_blocked_slogan(s): continue
+             if _copies_neighbor(s,neighbors_df): continue
+             cand.add(_titlecase_soft(s))
+     if not cand: return _clean_slogan(_GEN_TOK.decode(outputs[0],skip_special_tokens=True))
+     scored=_score_candidates(query_text,sorted(cand),neighbors_df)
+     if not scored: return _clean_slogan(_GEN_TOK.decode(outputs[0],skip_special_tokens=True))
+     scored.sort(key=lambda x:x[1],reverse=True)
+     return scored[0][0]
+
+ # =========================
+ # Gradio
+ # =========================
+ EXAMPLES=[
+     "AI coach for improving public speaking skills",
+     "Augmented reality app for interactive museum tours",
+     "Voice-controlled task manager for remote teams",
+     "Machine learning system for predicting crop yields",
+     "Platform for AI-assisted interior design suggestions",
+ ]
+
+ def pipeline(user_input: str):
+     recs=recommend(user_input, model_name=DEFAULT_MODEL_FOR_INDEX, top_k=3)
+     slogan=generate_slogan(user_input, neighbors_df=recs, n_samples=NUM_SLOGAN_SAMPLES)
+     recs=recs.reset_index(drop=True)
+     for col in ("name","tagline","description"):
+         if col not in recs.columns: recs[col]=""
+     recs.loc[len(recs)]={"name":"Synthetic Example","tagline":slogan,"description":user_input,"score":np.nan}
+     return recs[["name","tagline","description","score"]], slogan
+
+ with gr.Blocks(title="SloganAI — Recommendations + Slogan Generator") as demo:
+     gr.Markdown("## SloganAI Top-3 Recommendations + A High-Quality Generated Slogan")
+     with gr.Row():
+         with gr.Column(scale=1):
+             inp=gr.Textbox(label="Enter a startup description", lines=3, placeholder="e.g., AI coach for improving public speaking skills")
+             gr.Examples(EXAMPLES, inputs=inp, label="One-click examples")
+             btn=gr.Button("Submit", variant="primary")
+         with gr.Column(scale=2):
+             out_df=gr.Dataframe(headers=["Name","Tagline","Description","Score"], label="Top 3 + Generated")
+             out_sg=gr.Textbox(label="Generated Slogan", interactive=False)
+     btn.click(fn=pipeline, inputs=inp, outputs=[out_df, out_sg])
+
+ if __name__ == "__main__":
+     _ensure_models()
+     demo.queue().launch()
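
A quick way to exercise the new entry points outside the Gradio UI (a minimal sketch: `recommend` and `generate_slogan` and their signatures come from the diff above; importing the module as `app` and the example description are assumptions):

    # Assumes app.py is importable as `app`; models load lazily on first call.
    from app import recommend, generate_slogan

    desc = "AI coach for improving public speaking skills"
    neighbors = recommend(desc, top_k=3)        # on-disk FAISS index, else CSV-built index, else tiny demo index
    slogan = generate_slogan(desc, neighbors)   # sampled candidates, filtered and rerank-scored
    print(neighbors[["name", "tagline", "score"]])
    print("Generated slogan:", slogan)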
cleaned_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,9 +1,8 @@
- gradio==5.43.1
- huggingface_hub>=0.23.0
- sentence-transformers>=2.6.0
- faiss-cpu>=1.8.0
- pandas>=2.1.0
- numpy>=1.26.0
- pyarrow>=14.0.1
  torch
- transformers>=4.40.0

+ gradio>=4.36.1,<5
+ transformers>=4.42,<5
+ sentence-transformers>=2.3.1
+ faiss-cpu
+ pandas
+ numpy
  torch
+ pyarrow