asaf1602 committed
Commit 06a5663 · verified · 1 parent: c17e99d

Upload folder using huggingface_hub

Files changed (2):
  1. app.py +269 -145
  2. requirements.txt +8 -6
app.py CHANGED
@@ -1,153 +1,277 @@
-
  import gradio as gr
- import pandas as pd
- import numpy as np
- import faiss, re, torch
- from sentence_transformers import SentenceTransformer, CrossEncoder
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

- # ------------------ Models ------------------
- GEN_TOK = AutoTokenizer.from_pretrained("google/flan-t5-large")
- GEN_MODEL = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- GEN_MODEL = GEN_MODEL.to(DEVICE)
-
- EMBED_MODEL = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
- RERANKER = CrossEncoder("cross-encoder/stsb-roberta-base")
-
- # ------------------ Dummy dataset (for demo) ------------------
- data = pd.DataFrame({
-     "name": ["HowDidIDo", "Museotainment", "Movitr"],
-     "tagline": ["Online evaluation platform", "PacMan & Louvre meet", "Crowdsourced video translation"],
-     "description": [
-         "Public speaking, Presentation skills and interview practice",
-         "Interactive AR museum tours",
-         "Video translation with voice and subtitles"
-     ]
- })
-
- # Build FAISS index
- data_vecs = EMBED_MODEL.encode(data["description"].tolist())
- faiss.normalize_L2(data_vecs)
- index = faiss.IndexFlatIP(data_vecs.shape[1])
- index.add(data_vecs)
-
- def recommend(query, top_k=3):
-     query_vec = EMBED_MODEL.encode([query])
-     faiss.normalize_L2(query_vec)
-     scores, idx = index.search(query_vec, top_k)
-     results = data.iloc[idx[0]].copy()
-     results["score"] = scores[0]
-     return results[["name", "tagline", "description", "score"]]
-
- # ------------------ Helpers ------------------
- BLOCK_PATTERNS = [
-     r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
-     r"^[A-Z][a-z]+ [A-Z][a-z]+$",
-     r"^[A-Z][a-z]+$",
- ]
-
- HARD_BLOCK_WORDS = {"platform","solution","system","application","marketplace",
-                     "ai-powered","ai powered","empower","empowering",
-                     "artificial intelligence","machine learning","augmented reality","virtual reality"}
- GENERIC_WORDS = {"app","assistant","smart","ai","ml","ar","vr","decentralized","blockchain"}
- MARKETING_VERBS = {"build","grow","simplify","discover","create","connect","transform","unlock","boost","learn"}
- BENEFIT_WORDS = {"faster","smarter","easier","better","safer","clearer"}
-
- def _clean_slogan(text: str, max_words: int = 8) -> str:
-     text = text.strip().split("\n")[0]
-     text = re.sub(r"[\"“”‘’]", "", text)
-     text = re.sub(r"\s+", " ", text).strip()
-     words = text.split()
-     if len(words) > max_words:
-         text = " ".join(words[:max_words])
-     return text
-
- def _is_blocked_slogan(s: str) -> bool:
-     s_low = s.lower()
-     if any(w in s_low for w in HARD_BLOCK_WORDS):
-         return True
-     for pat in BLOCK_PATTERNS:
-         if re.match(pat, s.strip()):
-             return True
-     return False
-
- def _score_candidates(query: str, cands: list) -> list:
-     if not cands:
-         return []
-     ce_scores = np.asarray(RERANKER.predict([(query, s) for s in cands]), dtype=np.float32) / 5.0
-     results = []
-     for i, s in enumerate(cands):
-         words = s.split()
-         brevity = 1.0 - min(1.0, abs(len(words) - 5) / 5.0)
-         marketing = 0.2*len(set(words) & MARKETING_VERBS) + 0.2*len(set(words) & BENEFIT_WORDS)
-         score = 0.6*float(ce_scores[i]) + 0.2*brevity + 0.2*marketing
-         results.append((s, float(score)))
-     return results
-
- # ------------------ Generator ------------------
- def generate_slogan(query_text: str, n_samples: int = 16) -> str:
-     prompt = (
-         "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
-         "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
-         "Focus on benefits and vivid verbs. Do not copy the description.\n\n"
-         f"Description: {query_text}\nSlogans:"
      )

-     input_ids = GEN_TOK(prompt, return_tensors="pt").input_ids.to(DEVICE)
-     outputs = GEN_MODEL.generate(
-         input_ids,
-         max_new_tokens=24,
          do_sample=True,
-         top_k=60,
-         top_p=0.92,
-         temperature=1.2,
-         num_return_sequences=n_samples
      )

-     raw_cands = [GEN_TOK.decode(o, skip_special_tokens=True) for o in outputs]
-
-     cand_set = set()
-     for txt in raw_cands:
-         for line in txt.split("\n"):
-             s = _clean_slogan(line)
-             if not s: continue
-             if len(s.split()) < 2 or len(s.split()) > 8: continue
-             if _is_blocked_slogan(s): continue
-             cand_set.add(s.capitalize())
-
-     if not cand_set:
-         return "Fresh Ideas, Built To Scale"
-
-     scored = _score_candidates(query_text, sorted(cand_set))
-     scored.sort(key=lambda x: x[1], reverse=True)
-     return scored[0][0] if scored else "Fresh Ideas, Built To Scale"
-
- # ------------------ Pipeline ------------------
- def pipeline(user_input):
-     recs = recommend(user_input, top_k=3)
-     slogan = generate_slogan(user_input)
-     recs = recs.reset_index(drop=True)
-     recs.loc[len(recs)] = ["Generated Slogan", slogan, user_input, np.nan]
-     return recs
-
- # ------------------ Gradio UI ------------------
- examples = [
-     "AI coach for improving public speaking skills",
-     "Augmented reality app for interactive museum tours",
-     "Voice-controlled task manager for remote teams",
-     "Machine learning system for predicting crop yields",
-     "Platform for AI-assisted interior design suggestions"
- ]
-
- demo = gr.Interface(
-     fn=pipeline,
-     inputs=gr.Textbox(label="Enter a startup description"),
-     outputs=gr.Dataframe(headers=["Name", "Tagline", "Description", "Score"]),
-     examples=examples,
-     title="SloganAI – Startup Recommendation & Slogan Generator",
-     description="Enter a startup idea and get top-3 similar startups + 1 generated slogan."
- )
-
- if __name__ == "__main__":
-     demo.launch()
 
+ \
+ import os, json, numpy as np, pandas as pd
  import gradio as gr
+ import faiss
+ import re
+ from sentence_transformers import SentenceTransformer
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

+ from logic.cleaning import clean_dataframe
+ from logic.search import SloganSearcher
+
+ # -------------------- Config --------------------
+ ASSETS_DIR = "assets"
+ DATA_PATH = "data/slogan.csv"
+ PROMPT_PATH = "data/prompt.txt"
+
+ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+ NORMALIZE = True
+
+ GEN_MODEL = "google/flan-t5-base"
+ NUM_GEN_CANDIDATES = 12
+ MAX_NEW_TOKENS = 18
+ TEMPERATURE = 0.7
+ TOP_P = 0.9
+ REPETITION_PENALTY = 1.15
+
+ # choose the most relevant yet non-duplicate candidate
+ RELEVANCE_WEIGHT = 0.7
+ NOVELTY_WEIGHT = 0.3
+ DUPLICATE_MAX_SIM = 0.92
+ NOVELTY_SIM_THRESHOLD = 0.80  # keep some distance from retrieved
+
+ META_PATH = os.path.join(ASSETS_DIR, "meta.json")
+ PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
+ INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
+ EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")
+
+ def _log(m): print(f"[SLOGAN-SPACE] {m}", flush=True)
+
+ # -------------------- Asset build --------------------
+ def _build_assets():
+     if not os.path.exists(DATA_PATH):
+         raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
+     os.makedirs(ASSETS_DIR, exist_ok=True)
+
+     _log(f"Loading dataset: {DATA_PATH}")
+     df = pd.read_csv(DATA_PATH)
+
+     _log(f"Rows before cleaning: {len(df)}")
+     df = clean_dataframe(df)
+     _log(f"Rows after cleaning: {len(df)}")
+
+     if "description" in df.columns and df["description"].notna().any():
+         texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
+         text_col, fallback_col = "description", "tagline"
+     else:
+         texts = df["tagline"].astype(str).tolist()
+         text_col, fallback_col = "tagline", "tagline"
+
+     _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
+     encoder = SentenceTransformer(MODEL_NAME)
+     emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)
+
+     dim = emb.shape[1]
+     index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
+     index.add(emb)
+
+     _log("Persisting assets …")
+     df.to_parquet(PARQUET_PATH, index=False)
+     faiss.write_index(index, INDEX_PATH)
+     np.save(EMB_PATH, emb)
+
+     meta = {
+         "model_name": MODEL_NAME,
+         "dim": int(dim),
+         "normalized": NORMALIZE,
+         "metric": "ip" if NORMALIZE else "l2",
+         "row_count": int(len(df)),
+         "text_col": text_col,
+         "fallback_col": fallback_col,
+     }
+     with open(META_PATH, "w") as f:
+         json.dump(meta, f, indent=2)
+     _log("Assets built successfully.")
+
+ def _ensure_assets():
+     need = False
+     for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
+         if not os.path.exists(p):
+             _log(f"Missing asset: {p}")
+             need = True
+     if need:
+         _log("Building assets from scratch …")
+         _build_assets()
+         return
+     try:
+         pd.read_parquet(PARQUET_PATH)
+     except Exception as e:
+         _log(f"Parquet read failed ({e}); rebuilding assets.")
+         _build_assets()
+
+ # Build before UI
+ _ensure_assets()
+
+ # -------------------- Retrieval --------------------
+ searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
+ meta = json.load(open(META_PATH))
+ _encoder = SentenceTransformer(meta["model_name"])
+
+ # -------------------- Generator --------------------
+ _gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
+ _gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
+
+ # keep this list small so we don't nuke relevant outputs
+ _BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
+ _PUNCT_CHARS = ":;—–-,.!?“”\"'`"
+ _PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")
+
+ _MIN_WORDS, _MAX_WORDS = 2, 8
+
+ def _load_prompt():
+     if os.path.exists(PROMPT_PATH):
+         with open(PROMPT_PATH, "r", encoding="utf-8") as f:
+             return f.read()
+     return (
+         "You are a professional slogan writer.\n"
+         "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
+         "Do not copy examples.\n"
+         "Description:\n{description}\nSlogan:"
      )

+ def _render_prompt(description: str, retrieved=None) -> str:
+     tmpl = _load_prompt()
+     if "{description}" in tmpl:
+         prompt = tmpl.replace("{description}", description)
+     else:
+         prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
+     if retrieved:
+         prompt += "\n\nDo NOT copy these existing slogans:\n"
+         for s in retrieved[:3]:
+             prompt += f"- {s}\n"
+     return prompt
+
+ def _title_case(s: str) -> str:
+     small = {"and","or","for","of","the","to","in","on","with","a","an"}
+     words = [w for w in s.split() if w]
+     out = []
+     for i, w in enumerate(words):
+         lw = w.lower()
+         if i > 0 and lw in small: out.append(lw)
+         else: out.append(lw.capitalize())
+     return " ".join(out)
+
+ def _strip_punct(s: str) -> str:
+     return _PUNCT_RE.sub("", s)
+
+ def _strict_ok(s: str) -> bool:
+     if not s: return False
+     wc = len(s.split())
+     if wc < _MIN_WORDS or wc > _MAX_WORDS: return False
+     lo = s.lower()
+     if any(term in lo for term in _BANNED_TERMS): return False
+     if lo in {"the","a","an"}: return False
+     return True
+
+ def _postprocess_strict(texts):
+     cleaned, seen = [], set()
+     for t in texts:
+         s = t.replace("Slogan:", "").strip().strip('"').strip("'")
+         s = " ".join(s.split())
+         s = _strip_punct(s)  # remove punctuation instead of rejecting
+         s = _title_case(s)
+         if _strict_ok(s):
+             k = s.lower()
+             if k not in seen:
+                 seen.add(k); cleaned.append(s)
+     return cleaned
+
+ def _postprocess_relaxed(texts):
+     # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
+     cleaned, seen = [], set()
+     for t in texts:
+         s = t.strip().strip('"').strip("'")
+         s = _strip_punct(s)
+         s = " ".join(s.split())
+         wc = len(s.split())
+         if _MIN_WORDS <= wc <= _MAX_WORDS:
+             s = _title_case(s)
+             k = s.lower()
+             if k not in seen:
+                 seen.add(k); cleaned.append(s)
+     return cleaned
+
+ def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
+     prompt = _render_prompt(description, retrieved_texts)
+
+     # only block very generic junk at decode time
+     bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
+
+     inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
+     outputs = _gen_model.generate(
+         **inputs,
          do_sample=True,
+         temperature=TEMPERATURE,
+         top_p=TOP_P,
+         num_return_sequences=n,
+         max_new_tokens=MAX_NEW_TOKENS,
+         no_repeat_ngram_size=3,
+         repetition_penalty=REPETITION_PENALTY,
+         bad_words_ids=bad_ids if bad_ids else None,
+         eos_token_id=_gen_tokenizer.eos_token_id,
      )
+     texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+     cands = _postprocess_strict(texts)
+     if not cands:
+         cands = _postprocess_relaxed(texts)  # <- graceful fallback
+     return cands
+
+ def _pick_best(candidates, retrieved_texts, description):
+     """Weighted relevance to description minus duplication vs retrieved."""
+     if not candidates:
+         return None
+     c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
+     d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
+     rel = c_emb @ d_emb  # cosine sim to description
+
+     if retrieved_texts:
+         R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
+         dup = np.max(R @ c_emb.T, axis=0)  # max sim to any retrieved
+     else:
+         dup = np.zeros(len(candidates), dtype=np.float32)
+
+     # penalize near-duplicates outright
+     mask = dup < DUPLICATE_MAX_SIM
+     if mask.any():
+         scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
+         best_idx = np.argmax(scores)
+         return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
+
+     # else: pick most relevant that still clears a basic novelty bar, else top score
+     scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
+     order = np.argsort(-scores)
+     for i in order:
+         if dup[i] < NOVELTY_SIM_THRESHOLD:
+             return candidates[i]
+     return candidates[order[0]]
+
+ # -------------------- Inference pipeline --------------------
+ def run_pipeline(user_description: str):
+     if not user_description or not user_description.strip():
+         return "Please enter a description."
+     retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
+     retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
+     gens = _generate_candidates(user_description, retrieved_texts, NUM_GEN_CANDIDATES)
+     chosen = _pick_best(gens, retrieved_texts, user_description) or (gens[0] if gens else "—")
+     lines = []
+     lines.append("### 🔎 Top 3 similar slogans")
+     if retrieved_texts:
+         for i, s in enumerate(retrieved_texts, 1):
+             lines.append(f"{i}. {s}")
+     else:
+         lines.append("No similar slogans found.")
+     lines.append("\n### ✨ AI-generated suggestion")
+     lines.append(chosen)
+     return "\n".join(lines)
+
+ # -------------------- UI --------------------
+ with gr.Blocks(title="Slogan Finder") as demo:
+     gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
+     query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
+     btn = gr.Button("Get slogans", variant="primary")
+     out = gr.Markdown()
+     btn.click(run_pipeline, inputs=[query], outputs=out)
+
+ demo.queue(max_size=64).launch()
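
Note: the new app.py imports clean_dataframe and SloganSearcher from a logic/ package that is not part of this commit. The sketch below is a hypothetical reconstruction of the interface the code above relies on, inferred purely from the call sites and the meta.json written by _build_assets; the real logic/search.py may differ.

# logic/search.py (hypothetical sketch, not the shipped module).
# app.py only relies on: SloganSearcher(assets_dir=..., use_rerank=False)
# and .search(query, top_k=3, rerank_top_n=10) returning a DataFrame
# with a "display" column.
import os, json
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer

class SloganSearcher:
    def __init__(self, assets_dir="assets", use_rerank=False):
        meta = json.load(open(os.path.join(assets_dir, "meta.json")))
        self.df = pd.read_parquet(os.path.join(assets_dir, "slogans_clean.parquet"))
        self.index = faiss.read_index(os.path.join(assets_dir, "faiss.index"))
        self.encoder = SentenceTransformer(meta["model_name"])
        self.normalize = meta["normalized"]
        self.use_rerank = use_rerank  # ignored in this sketch

    def search(self, query, top_k=3, rerank_top_n=10):
        q = self.encoder.encode([query], convert_to_numpy=True,
                                normalize_embeddings=self.normalize)
        _, idx = self.index.search(q, top_k)
        out = self.df.iloc[idx[0]].copy()
        out["display"] = out["tagline"].astype(str)  # "display" column is assumed
        return out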
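The RELEVANCE_WEIGHT / NOVELTY_WEIGHT / DUPLICATE_MAX_SIM constants drive the candidate selection in _pick_best. A toy trace with invented similarity numbers shows why a slightly less relevant but novel candidate can win:

# Toy numbers, not real model output: candidate A is highly relevant but
# nearly duplicates a retrieved slogan; B is less relevant but novel.
import numpy as np

rel = np.array([0.82, 0.74])  # cosine sim of A, B to the description
dup = np.array([0.95, 0.40])  # max sim of A, B to any retrieved slogan

mask = dup < 0.92             # DUPLICATE_MAX_SIM drops A outright
scores = 0.7 * rel[mask] - 0.3 * dup[mask]  # RELEVANCE_WEIGHT / NOVELTY_WEIGHT
print(scores)                 # [0.398] -> B is chosen despite lower relevance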
requirements.txt CHANGED
@@ -1,7 +1,9 @@
- gradio
- transformers
- sentence-transformers
- faiss-cpu
- pandas
- numpy
  torch

+ gradio==5.43.1
+ huggingface_hub>=0.23.0
+ sentence-transformers>=2.6.0
+ faiss-cpu>=1.8.0
+ pandas>=2.1.0
+ numpy>=1.26.0
+ pyarrow>=14.0.1
  torch
+ transformers>=4.40.0
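
pyarrow is new here because the rebuilt app.py persists the cleaned dataset with df.to_parquet and re-reads it with pd.read_parquet, and pandas needs a Parquet engine installed at runtime. A quick smoke test of that round-trip (the file path is arbitrary):

# Smoke test for the Parquet round-trip that _ensure_assets() depends on.
import pandas as pd

df = pd.DataFrame({"tagline": ["Just Do It"], "description": ["sportswear"]})
df.to_parquet("/tmp/slogans_check.parquet", index=False)    # needs pyarrow (or fastparquet)
print(pd.read_parquet("/tmp/slogans_check.parquet").shape)  # (1, 2)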