3v324v23 committed on
Commit
c29bac5
·
1 Parent(s): 408e06f

Deploy Space: full FAISS recommend + advanced slogan generator (Refined v2) with vector_store

Browse files
Files changed (3) hide show
  1. .gitattributes +2 -34
  2. app.py +330 -92
  3. requirements.txt +1 -0
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.faiss filter=lfs diff=lfs merge=lfs -text
2
+ *.npy filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  *.parquet filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,153 +1,391 @@
1
 
2
- import gradio as gr
3
- import pandas as pd
4
  import numpy as np
5
- import faiss, re, torch
 
 
 
 
6
  from sentence_transformers import SentenceTransformer, CrossEncoder
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
 
9
- # ------------------ Models ------------------
10
- GEN_TOK = AutoTokenizer.from_pretrained("google/flan-t5-large")
11
- GEN_MODEL = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
 
 
 
 
 
 
 
 
 
 
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- GEN_MODEL = GEN_MODEL.to(DEVICE)
14
-
15
- EMBED_MODEL = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
16
- RERANKER = CrossEncoder("cross-encoder/stsb-roberta-base")
17
-
18
- # ------------------ Dummy dataset (for demo) ------------------
19
- data = pd.DataFrame({
20
- "name": ["HowDidIDo", "Museotainment", "Movitr"],
21
- "tagline": ["Online evaluation platform", "PacMan & Louvre meet", "Crowdsourced video translation"],
22
- "description": [
23
- "Public speaking, Presentation skills and interview practice",
24
- "Interactive AR museum tours",
25
- "Video translation with voice and subtitles"
26
- ]
27
- })
28
-
29
- # Build FAISS index
30
- data_vecs = EMBED_MODEL.encode(data["description"].tolist())
31
- faiss.normalize_L2(data_vecs)
32
- index = faiss.IndexFlatIP(data_vecs.shape[1])
33
- index.add(data_vecs)
34
-
35
- def recommend(query, top_k=3):
36
- query_vec = EMBED_MODEL.encode([query])
37
- faiss.normalize_L2(query_vec)
38
- scores, idx = index.search(query_vec, top_k)
39
- results = data.iloc[idx[0]].copy()
40
- results["score"] = scores[0]
41
- return results[["name", "tagline", "description", "score"]]
42
-
43
- # ------------------ Helpers ------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  BLOCK_PATTERNS = [
45
  r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
46
  r"^[A-Z][a-z]+ [A-Z][a-z]+$",
47
  r"^[A-Z][a-z]+$",
48
  ]
49
-
50
- HARD_BLOCK_WORDS = {"platform","solution","system","application","marketplace",
51
  "ai-powered","ai powered","empower","empowering",
52
- "artificial intelligence","machine learning","augmented reality","virtual reality"}
 
53
  GENERIC_WORDS = {"app","assistant","smart","ai","ml","ar","vr","decentralized","blockchain"}
54
- MARKETING_VERBS = {"build","grow","simplify","discover","create","connect","transform","unlock","boost","learn"}
55
- BENEFIT_WORDS = {"faster","smarter","easier","better","safer","clearer"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def _clean_slogan(text: str, max_words: int = 8) -> str:
58
  text = text.strip().split("\n")[0]
59
  text = re.sub(r"[\"“”‘’]", "", text)
60
  text = re.sub(r"\s+", " ", text).strip()
 
61
  words = text.split()
62
  if len(words) > max_words:
63
  text = " ".join(words[:max_words])
64
  return text
65
 
66
- def _is_blocked_slogan(s: str) -> bool:
67
- s_low = s.lower()
68
- if any(w in s_low for w in HARD_BLOCK_WORDS):
69
- return True
70
- for pat in BLOCK_PATTERNS:
71
- if re.match(pat, s.strip()):
72
- return True
73
- return False
74
-
75
- def _score_candidates(query: str, cands: list) -> list:
76
  if not cands:
77
  return []
78
- ce_scores = np.asarray(RERANKER.predict([(query, s) for s in cands]), dtype=np.float32) / 5.0
 
79
  results = []
 
 
 
 
 
 
 
 
 
80
  for i, s in enumerate(cands):
81
  words = s.split()
82
- brevity = 1.0 - min(1.0, abs(len(words) - 5) / 5.0)
83
- marketing = 0.2*len(set(words) & MARKETING_VERBS) + 0.2*len(set(words) & BENEFIT_WORDS)
84
- score = 0.6*float(ce_scores[i]) + 0.2*brevity + 0.2*marketing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  results.append((s, float(score)))
86
  return results
87
 
88
- # ------------------ Generator ------------------
89
- def generate_slogan(query_text: str, n_samples: int = 16) -> str:
 
90
  prompt = (
91
  "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
92
  "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
93
- "Focus on benefits and vivid verbs. Do not copy the description.\n\n"
94
- f"Description: {query_text}\nSlogans:"
 
 
 
 
 
 
 
 
 
95
  )
 
 
 
96
 
97
- input_ids = GEN_TOK(prompt, return_tensors="pt").input_ids.to(DEVICE)
98
- outputs = GEN_MODEL.generate(
99
  input_ids,
100
  max_new_tokens=24,
101
  do_sample=True,
102
  top_k=60,
103
  top_p=0.92,
104
  temperature=1.2,
105
- num_return_sequences=n_samples
 
106
  )
107
-
108
- raw_cands = [GEN_TOK.decode(o, skip_special_tokens=True) for o in outputs]
109
 
110
  cand_set = set()
111
  for txt in raw_cands:
112
  for line in txt.split("\n"):
113
  s = _clean_slogan(line)
114
- if not s: continue
115
- if len(s.split()) < 2 or len(s.split()) > 8: continue
116
- if _is_blocked_slogan(s): continue
117
- cand_set.add(s.capitalize())
 
 
 
 
 
118
 
119
  if not cand_set:
120
- return "Fresh Ideas, Built To Scale"
121
 
122
- scored = _score_candidates(query_text, sorted(cand_set))
123
- scored.sort(key=lambda x: x[1], reverse=True)
124
- return scored[0][0] if scored else "Fresh Ideas, Built To Scale"
125
 
126
- # ------------------ Pipeline ------------------
127
- def pipeline(user_input):
128
- recs = recommend(user_input, top_k=3)
129
- slogan = generate_slogan(user_input)
130
- recs = recs.reset_index(drop=True)
131
- recs.loc[len(recs)] = ["Generated Slogan", slogan, user_input, np.nan]
132
- return recs
133
 
134
- # ------------------ Gradio UI ------------------
135
- examples = [
 
 
136
  "AI coach for improving public speaking skills",
137
  "Augmented reality app for interactive museum tours",
138
  "Voice-controlled task manager for remote teams",
139
  "Machine learning system for predicting crop yields",
140
- "Platform for AI-assisted interior design suggestions"
141
  ]
142
 
143
- demo = gr.Interface(
144
- fn=pipeline,
145
- inputs=gr.Textbox(label="Enter a startup description"),
146
- outputs=gr.Dataframe(headers=["Name", "Tagline", "Description", "Score"]),
147
- examples=examples,
148
- title="SloganAI Startup Recommendation & Slogan Generator",
149
- description="Enter a startup idea and get top-3 similar startups + 1 generated slogan."
150
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  if __name__ == "__main__":
153
- demo.launch()
 
 
1
 
2
+ import os, re, json
 
3
  import numpy as np
4
+ import pandas as pd
5
+ import gradio as gr
6
+ import faiss
7
+ import torch
8
+ from typing import List
9
  from sentence_transformers import SentenceTransformer, CrossEncoder
10
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
11
 
12
# =========================
# Global Config
# =========================
# Models (same setup as in the notebook; there is a fallback to the "base"
# checkpoint in case the "large" one does not fit in memory).
FLAN_PRIMARY = os.environ.get("FLAN_PRIMARY", "google/flan-t5-large")
FLAN_FALLBACK = "google/flan-t5-base"
EMBED_NAME = "sentence-transformers/all-mpnet-base-v2"
RERANK_NAME = "cross-encoder/stsb-roberta-base"

# Number of slogans sampled per request (can be raised to 32 when a GPU is available).
NUM_SLOGAN_SAMPLES = int(os.environ.get("NUM_SLOGAN_SAMPLES", "16"))
# Directory (next to this file) where the prebuilt indexes were placed.
INDEX_ROOT = os.path.join(os.path.dirname(__file__), "vector_store")
DEFAULT_MODEL_FOR_INDEX = EMBED_NAME

# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26
+
27
# =========================
# Lazy model loading (first call only)
# =========================
# All four stay None until _ensure_models() populates them.
_GEN_TOK = None
_GEN_MODEL = None
_EMBED_MODEL = None
_RERANKER = None


def _load_generator(model_name):
    """Load the tokenizer and seq2seq model for ``model_name``; the model is moved to DEVICE."""
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tok, mdl.to(DEVICE)


def _ensure_models():
    """Populate the module-level model globals, loading each one at most once.

    The embedder and the reranker are loaded first.  The generator is loaded
    from ``FLAN_PRIMARY``; if that raises, ``FLAN_FALLBACK`` is tried instead.

    Raises:
        Exception: whatever the underlying loaders raise when the fallback
            (or a primary that is identical to the fallback) cannot be loaded.
    """
    global _GEN_TOK, _GEN_MODEL, _EMBED_MODEL, _RERANKER
    if _EMBED_MODEL is None:
        _EMBED_MODEL = SentenceTransformer(EMBED_NAME)
    if _RERANKER is None:
        _RERANKER = CrossEncoder(RERANK_NAME)

    if _GEN_MODEL is not None:
        return

    try:
        _GEN_TOK, _GEN_MODEL = _load_generator(FLAN_PRIMARY)
        print(f"[INFO] Loaded generator: {FLAN_PRIMARY}")
    except Exception as e:  # deliberate catch-all: any load failure triggers the fallback
        if FLAN_PRIMARY == FLAN_FALLBACK:
            # Retrying the very same checkpoint cannot succeed; surface the error.
            raise
        print(f"[WARN] Failed to load {FLAN_PRIMARY}. Falling back to {FLAN_FALLBACK}. Error: {e}")
        _GEN_TOK, _GEN_MODEL = _load_generator(FLAN_FALLBACK)
        print(f"[INFO] Loaded generator: {FLAN_FALLBACK}")
54
+
55
# =========================
# Index cache (so we don't read multiple times)
# =========================
# Maps model_key -> (faiss_index, meta_df).
_INDEX_CACHE = {}


def _model_key(name: str) -> str:
    """Return ``name`` with every "/" turned into "_" (used as cache key and folder name)."""
    return "_".join(name.split("/"))
62
+
63
def _format_for_e5(texts, as_query=False):
    """Prefix every item of ``texts`` with "query: " or "passage: ".

    Items are converted with ``str`` first; ``as_query`` selects the "query: "
    prefix, otherwise "passage: " is used.  Returns a new list.
    """
    if as_query:
        marker = "query: "
    else:
        marker = "passage: "
    return [f"{marker}{item!s}" for item in texts]
66
+
67
def _load_index_for_model(model_name: str = DEFAULT_MODEL_FOR_INDEX):
    """Return the ``(faiss_index, meta_df)`` pair for ``model_name``, loading it once.

    The pair is read from ``INDEX_ROOT/<model_key>/index.faiss`` and
    ``meta.parquet``.  When either file is missing, a tiny three-row demo index
    is built in memory with ``model_name`` instead.  Results are memoised in
    ``_INDEX_CACHE``.
    """
    mkey = _model_key(model_name)
    cached = _INDEX_CACHE.get(mkey)
    if cached is not None:
        return cached

    base = os.path.join(INDEX_ROOT, mkey)
    idx_path = os.path.join(base, "index.faiss")
    meta_path = os.path.join(base, "meta.parquet")

    if os.path.exists(idx_path) and os.path.exists(meta_path):
        entry = (faiss.read_index(idx_path), pd.read_parquet(meta_path))
    else:
        # fallback: tiny demo index (3 rows) if user didn't push vector_store
        print(f"[WARN] Missing index for {model_name}. Using tiny demo in-memory index.")
        # Named demo_df so it cannot be confused with the module-level Gradio `demo`.
        demo_df = pd.DataFrame({
            "name": ["HowDidIDo", "Museotainment", "Movitr"],
            "tagline": ["Online evaluation platform", "PacMan & Louvre meet", "Crowdsourced video translation"],
            "description": [
                "Public speaking, Presentation skills and interview practice",
                "Interactive AR museum tours",
                "Video translation with voice and subtitles"
            ]
        })
        passages = demo_df["description"].tolist()
        if model_name.startswith("intfloat/e5"):
            # Queries for E5 models get a "query: " prefix at search time, so the
            # indexed side must carry the matching "passage: " prefix.
            passages = _format_for_e5(passages)
        model = SentenceTransformer(model_name)
        vecs = np.asarray(model.encode(passages, normalize_embeddings=True), dtype=np.float32)
        index = faiss.IndexFlatIP(vecs.shape[1])  # inner product on unit vectors
        index.add(vecs)
        entry = (index, demo_df)

    _INDEX_CACHE[mkey] = entry
    return entry
101
+
102
# =========================
# Recommendation (top-3) using FAISS index you generated
# =========================
# Encoders for index models other than the default one, keyed by model name.
_ENCODER_CACHE = {}


def _encoder_for(model_name: str):
    """Return the sentence encoder matching ``model_name`` (cached after first use)."""
    if model_name == EMBED_NAME:
        return _EMBED_MODEL
    if model_name not in _ENCODER_CACHE:
        _ENCODER_CACHE[model_name] = SentenceTransformer(model_name)
    return _ENCODER_CACHE[model_name]


def recommend(query_text: str, model_name: str = DEFAULT_MODEL_FOR_INDEX, top_k: int = 3) -> pd.DataFrame:
    """Return up to ``top_k`` metadata rows most similar to ``query_text``.

    Args:
        query_text: free-text description to search for.
        model_name: embedding model whose index is searched; the query is
            encoded with the same model so the vectors are comparable.
        top_k: maximum number of neighbours to return.

    Returns:
        A copy of the matching metadata rows with an added ``score`` column,
        restricted to whichever of ``row_id``, ``name``, ``tagline``,
        ``description`` exist, plus ``score``.  May hold fewer than ``top_k``
        rows (or none) when the index is smaller than ``top_k``.
    """
    _ensure_models()
    index, meta = _load_index_for_model(model_name)

    # format for E5 if needed
    if model_name.startswith("intfloat/e5"):
        q_inp = _format_for_e5([query_text], as_query=True)
    else:
        q_inp = [query_text]

    encoder = _encoder_for(model_name)
    q_vec = np.asarray(encoder.encode(q_inp, normalize_embeddings=True), dtype=np.float32)

    k = min(int(top_k), int(index.ntotal))
    if k > 0:
        scores, idxs = index.search(q_vec, k)
        scores, idxs = scores[0], idxs[0]
        # FAISS pads missing results with id -1, which iloc would read as
        # "last row"; also drop ids that fall outside the metadata table.
        keep = (idxs >= 0) & (idxs < len(meta))
        scores, idxs = scores[keep], idxs[keep]
    else:
        scores = np.empty(0, dtype=np.float32)
        idxs = np.empty(0, dtype=np.int64)

    out = meta.iloc[idxs].copy()
    out["score"] = scores
    # Keep only the known columns that are actually present.
    cols = [c for c in ("row_id", "name", "tagline", "description", "score") if c in out.columns]
    return out[cols]
124
+
125
# =========================
# Advanced Slogan Generator (your Refined v2 logic)
# =========================
# Whole-string shapes that are rejected outright (matched with re.match).
BLOCK_PATTERNS = [
    r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
    r"^[A-Z][a-z]+ [A-Z][a-z]+$",
    r"^[A-Z][a-z]+$",
]

# A candidate containing any of these (case-insensitive substring) is rejected.
HARD_BLOCK_WORDS = {
    "platform",
    "solution",
    "system",
    "application",
    "marketplace",
    "ai-powered",
    "ai powered",
    "empower",
    "empowering",
    "artificial intelligence",
    "machine learning",
    "augmented reality",
    "virtual reality",
}

# Words that lower a candidate's score.
GENERIC_WORDS = {
    "app", "assistant", "smart", "ai", "ml", "ar", "vr", "decentralized", "blockchain",
}

# Words that raise a candidate's score.
MARKETING_VERBS = {
    "build", "grow", "simplify", "discover", "create", "connect",
    "transform", "unlock", "boost", "learn", "move", "clarify",
}
BENEFIT_WORDS = {
    "faster", "smarter", "easier", "better", "safer", "clearer",
    "stronger", "together", "confidently", "simply", "instantly",
}

# Known slogans (lower-cased) that must not be returned again verbatim.
GOOD_SLOGANS_TO_AVOID_DUP = {
    "smarter care, faster decisions",
    "checkout built for small brands",
    "less guessing. more healing.",
    "built to grow with your cart.",
    "stand tall. feel better.",
    "train your brain to win.",
    "your body. your algorithm.",
    "play smarter. grow brighter.",
    "style that thinks with you.",
}
152
+
153
def _tokens(s: str) -> List[str]:
    """Lower-case ``s`` and return its runs of 3+ characters from a-z / 0-9, in order."""
    lowered = s.lower()
    return [hit.group(0) for hit in re.finditer(r"[a-z0-9]{3,}", lowered)]
155
+
156
def _jaccard(a: List[str], b: List[str]) -> float:
    """Jaccard similarity of the two token collections; 0.0 when either one is empty."""
    left, right = set(a), set(b)
    if not left or not right:
        return 0.0
    shared = left & right
    combined = left | right
    return len(shared) / len(combined)
159
+
160
def _titlecase_soft(s: str) -> str:
    """Capitalize each whitespace-separated word, leaving all-uppercase words untouched.

    Words are re-joined with single spaces.
    """
    return " ".join(
        word if word.isupper() else word.capitalize()
        for word in s.split()
    )
165
+
166
def _is_blocked_slogan(s: str) -> bool:
    """Return True when candidate ``s`` must be rejected.

    A candidate is rejected when it is empty, matches one of
    ``BLOCK_PATTERNS``, contains one of ``HARD_BLOCK_WORDS`` (case-insensitive
    substring), or repeats one of ``GOOD_SLOGANS_TO_AVOID_DUP``.
    """
    if not s:
        return True
    s_strip = s.strip()
    if any(re.match(pat, s_strip) for pat in BLOCK_PATTERNS):
        return True
    s_low = s_strip.lower()
    if any(w in s_low for w in HARD_BLOCK_WORDS):
        return True
    # Candidates arrive with their edge punctuation already stripped, while the
    # known slogans mostly end in "."; compare both sides without edge
    # punctuation so e.g. "less guessing. more healing" is still recognised.
    edge = r"^\W+|\W+$"
    canon = re.sub(edge, "", s_low)
    return any(canon == re.sub(edge, "", known) for known in GOOD_SLOGANS_TO_AVOID_DUP)
179
+
180
def _generic_penalty(s: str, generic_words=None) -> float:
    """Return a penalty in [0, 1]: 0.25 per distinct generic word found in ``s``.

    Matching is done on whole lower-cased alphanumeric tokens, so short entries
    such as "ar" or "ai" no longer fire inside ordinary words ("smarter",
    "train").

    Args:
        s: candidate slogan.
        generic_words: optional collection of lower-case words to penalise;
            defaults to the module-level ``GENERIC_WORDS``.
    """
    vocab = GENERIC_WORDS if generic_words is None else generic_words
    present = set(re.findall(r"[a-z0-9]+", s.lower()))
    hits = len(present & set(vocab))
    return min(1.0, 0.25 * hits)
183
+
184
def _for_penalty(s: str) -> float:
    """Return 0.3 when ``s`` contains the standalone word "for" (any case), else 0.0."""
    has_for = re.search(r"\bfor\b", s.lower()) is not None
    if has_for:
        return 0.3
    return 0.0
186
+
187
def _neighbor_context(neighbors_df: pd.DataFrame) -> str:
    """Build a "- tagline" bullet list from the first three neighbour rows.

    Only taglines whose stripped length is between 5 and 70 characters are
    kept.  Returns "" when ``neighbors_df`` is None or empty.
    """
    if neighbors_df is None or neighbors_df.empty:
        return ""
    taglines = (
        str(row.get("tagline", "")).strip()
        for _, row in neighbors_df.head(3).iterrows()
    )
    return "\n".join(f"- {tg}" for tg in taglines if 5 <= len(tg) <= 70)
196
+
197
def _copies_neighbor(s: str, neighbors_df: pd.DataFrame) -> bool:
    """Return True when ``s`` is (nearly) a copy of a neighbour's tagline.

    Lexical check, over every row: case-insensitive equality, or token Jaccard
    similarity >= 0.7.  Semantic check, over the first three rows only: cosine
    similarity of sentence embeddings >= 0.85.  The semantic check is
    best-effort; if encoding fails a warning is printed and it is skipped.
    """
    if neighbors_df is None or neighbors_df.empty:
        return False

    def _tagline(row):
        # Missing values would otherwise be compared as the text "nan"/"None".
        raw = row.get("tagline", "")
        return "" if raw is None or pd.isna(raw) else str(raw).strip()

    s_low = s.lower()
    s_toks = _tokens(s_low)
    for _, row in neighbors_df.iterrows():
        t_low = _tagline(row).lower()
        if not t_low:
            continue
        if s_low == t_low or _jaccard(s_toks, _tokens(t_low)) >= 0.7:
            return True

    top_taglines = [t for t in (_tagline(row) for _, row in neighbors_df.head(3).iterrows()) if t]
    if not top_taglines:
        return False
    try:
        # One batched call; row 0 is the candidate, the rest are the taglines.
        vecs = np.asarray(_EMBED_MODEL.encode([s] + top_taglines, normalize_embeddings=True))
    except Exception as exc:  # best-effort: never let the similarity check break generation
        print(f"[WARN] Semantic neighbor check skipped: {exc}")
        return False
    sims = vecs[1:] @ vecs[0]
    return bool(np.any(sims >= 0.85))
222
 
223
def _clean_slogan(text: str, max_words: int = 8) -> str:
    """Normalise one raw model output into a candidate slogan.

    Keeps only the first line, removes straight double quotes and curly
    quotes, collapses whitespace, strips non-word characters from both ends and
    truncates to ``max_words`` words.  After truncation the trailing
    punctuation is stripped again so the cut cannot leave a dangling "," or "-".

    Args:
        text: raw generated text.
        max_words: maximum number of whitespace-separated words to keep.

    Returns:
        The cleaned slogan; may be an empty string.
    """
    first_line = text.strip().split("\n")[0]
    cleaned = re.sub(r"[\"“”‘’]", "", first_line)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    cleaned = re.sub(r"^\W+|\W+$", "", cleaned)
    words = cleaned.split()
    if len(words) > max_words:
        cleaned = " ".join(words[:max_words])
        # The cut may end on punctuation that used to be mid-sentence.
        cleaned = re.sub(r"\W+$", "", cleaned)
    return cleaned
232
 
233
def _score_candidates(query: str, cands: List[str], neighbors_df: pd.DataFrame) -> List[tuple]:
    """Score every candidate slogan and return ``(slogan, score)`` pairs in input order.

    The score is a weighted sum of the reranker score, brevity (best around
    five words), marketing/benefit word hits and dissimilarity to the query,
    minus penalties for generic words, the word "for", and similarity to the
    taglines of the first three neighbours.

    Args:
        query: the user's description.
        cands: candidate slogans.
        neighbors_df: recommended rows (may be None or empty).

    Returns:
        A list of ``(candidate, float score)`` tuples; empty when ``cands`` is empty.
    """
    if not cands:
        return []
    # NOTE(review): the division by 5 assumes the reranker scores on a 0-5
    # scale; confirm against the model's documented output range.
    ce_scores = np.asarray(_RERANKER.predict([(query, s) for s in cands]), dtype=np.float32) / 5.0
    q_toks = _tokens(query)

    taglines = []
    if neighbors_df is not None and not neighbors_df.empty:
        for _, row in neighbors_df.head(3).iterrows():
            t = str(row.get("tagline", "")).strip()
            if t:
                taglines.append(t)

    # Max cosine similarity of each candidate to any neighbour tagline.
    neighbor_pen = np.zeros(len(cands), dtype=np.float32)
    if taglines:
        n_vecs = np.asarray(_EMBED_MODEL.encode(taglines, normalize_embeddings=True))
        try:
            # All candidates are encoded in one batch instead of one call each.
            c_vecs = np.asarray(_EMBED_MODEL.encode(list(cands), normalize_embeddings=True))
            neighbor_pen = (c_vecs @ n_vecs.T).max(axis=1)
        except Exception as exc:  # best-effort: fall back to "no neighbour penalty"
            print(f"[WARN] Neighbor-similarity penalty skipped: {exc}")

    results = []
    for i, s in enumerate(cands):
        words = s.split()
        brevity = 1.0 - min(1.0, abs(len(words) - 5) / 5.0)  # best ~5 words
        wl = {w.lower() for w in words}
        m_hits = len(wl & MARKETING_VERBS)
        b_hits = len(wl & BENEFIT_WORDS)
        marketing = min(1.0, 0.2 * m_hits + 0.2 * b_hits)
        g_pen = _generic_penalty(s)
        f_pen = _for_penalty(s)
        n_pen = float(neighbor_pen[i])

        # Reward candidates that do not simply echo the query's words.
        overlap = _jaccard(q_toks, _tokens(s))
        anti_copy = 1.0 - overlap

        score = (
            0.55 * float(ce_scores[i])
            + 0.20 * brevity
            + 0.15 * marketing
            + 0.03 * anti_copy
            - 0.07 * g_pen
            - 0.03 * f_pen
            - 0.10 * n_pen
        )
        results.append((s, float(score)))
    return results
281
 
282
def generate_slogan(query_text: str, neighbors_df: pd.DataFrame = None, n_samples: int = NUM_SLOGAN_SAMPLES) -> str:
    """Generate one slogan for ``query_text``.

    Samples ``n_samples`` outputs from the seq2seq generator, cleans and
    filters them (2-8 words, not blocked, not a copy of a neighbour tagline),
    scores the survivors and returns the best one.

    Args:
        query_text: the startup description.
        neighbors_df: recommended rows whose taglines are shown to the model as
            style examples and used for copy detection (may be None).
        n_samples: number of sequences to sample; values below 1 are treated as 1.

    Returns:
        The highest-scoring candidate.  If nothing survives filtering, the
        cleaned first raw output; if that is empty too, a fixed default slogan,
        so the result is never an empty string.
    """
    _ensure_models()
    ctx = _neighbor_context(neighbors_df)
    prompt = (
        "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
        "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
        "Focus on clear benefits and vivid verbs. Do not copy the description. Return ONLY a list, one slogan per line.\n\n"
        "Good Examples:\n"
        "Description: AI assistant for doctors to prioritize patient cases\n"
        "Slogan: Less Guessing. More Healing.\n\n"
        "Description: Payments for small online stores\n"
        "Slogan: Built to Grow with Your Cart.\n\n"
        "Description: Neurotech headset to boost focus\n"
        "Slogan: Train Your Brain to Win.\n\n"
        "Description: Interior design suggestions with AI\n"
        "Slogan: Style That Thinks With You.\n\n"
        "Bad Examples (avoid these): Innovative AI Platform / Smart App for Everyone / Empowering Small Businesses\n\n"
    )
    if ctx:
        prompt += f"Similar taglines (style only):\n{ctx}\n\n"
    prompt += f"Description: {query_text}\nSlogans:"

    input_ids = _GEN_TOK(prompt, return_tensors="pt").input_ids.to(DEVICE)
    outputs = _GEN_MODEL.generate(
        input_ids,
        max_new_tokens=24,
        do_sample=True,
        top_k=60,
        top_p=0.92,
        temperature=1.2,
        num_return_sequences=max(1, int(n_samples)),
        repetition_penalty=1.08,
    )
    raw_cands = [_GEN_TOK.decode(o, skip_special_tokens=True) for o in outputs]

    cand_set = set()
    for txt in raw_cands:
        for line in txt.split("\n"):
            s = _clean_slogan(line)
            if not 2 <= len(s.split()) <= 8:  # also rejects the empty string
                continue
            if _is_blocked_slogan(s):
                continue
            if _copies_neighbor(s, neighbors_df):
                continue
            cand_set.add(_titlecase_soft(s))

    # sorted() makes the candidate order (and thus tie-breaking) deterministic.
    scored = _score_candidates(query_text, sorted(cand_set), neighbors_df) if cand_set else []
    if scored:
        # max() keeps the first of equally scored candidates, like a stable sort.
        return max(scored, key=lambda pair: pair[1])[0]

    # Nothing survived filtering: fall back to the first raw output, and to a
    # fixed slogan if even that cleans down to nothing.
    fallback = _clean_slogan(raw_cands[0]) if raw_cands else ""
    return fallback or "Fresh Ideas, Built To Scale"
 
 
 
 
 
340
 
341
# =========================
# Gradio Pipeline
# =========================
# One-click example inputs offered in the UI.
EXAMPLES = [
    "AI coach for improving public speaking skills",
    "Augmented reality app for interactive museum tours",
    "Voice-controlled task manager for remote teams",
    "Machine learning system for predicting crop yields",
    "Platform for AI-assisted interior design suggestions",
]


def pipeline(user_input: str):
    """Run recommendation + slogan generation for one description.

    Args:
        user_input: the startup description typed by the user.

    Returns:
        A tuple ``(table, slogan)``: ``table`` is a DataFrame with columns
        ``name``, ``tagline``, ``description``, ``score`` holding the top-3
        neighbours followed by one "Synthetic Example" row carrying the
        generated slogan; ``slogan`` is the generated slogan string.  For
        blank input an empty table and an empty string are returned.
    """
    out_cols = ["name", "tagline", "description", "score"]
    if not user_input or not user_input.strip():
        # Nothing to search or generate for.
        return pd.DataFrame(columns=out_cols), ""

    # 1) Top-3 recommendations from your FAISS index (mpnet by default)
    recs = recommend(user_input, model_name=DEFAULT_MODEL_FOR_INDEX, top_k=3)

    # 2) Generate slogan using the neighbors as style context
    slogan = generate_slogan(user_input, neighbors_df=recs, n_samples=NUM_SLOGAN_SAMPLES)

    # 3) Append the generated item as the 4th row
    recs = recs.reset_index(drop=True)
    # Ensure columns exist
    for col in ("name", "tagline", "description"):
        if col not in recs.columns:
            recs[col] = ""
    if "score" not in recs.columns:
        recs["score"] = np.nan

    generated = {
        "name": "Synthetic Example",
        "tagline": slogan,
        "description": user_input,
    }
    # Build the row positionally from the frame's own columns; columns without
    # a generated value (row_id, score, ...) get NaN.  This avoids depending on
    # how pandas aligns a dict whose keys differ from the columns.
    recs.loc[len(recs)] = [generated.get(col, np.nan) for col in recs.columns]

    # Second output: the slogan itself (visible headline)
    return recs[out_cols], slogan
375
+
376
# UI layout: input column (textbox, examples, submit button) on the left,
# results (table + slogan headline) on the right.
with gr.Blocks(title="SloganAI — Recommendations + Slogan Generator") as demo:
    gr.Markdown("## SloganAI — Top-3 Recommendations + A High-Quality Generated Slogan\nEnter a startup idea, click **Submit**, or try an example.")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Textbox(label="Enter a startup description", lines=3, placeholder="e.g., AI coach for improving public speaking skills")
            # The Examples component registers itself; no handle is needed.
            gr.Examples(EXAMPLES, inputs=inp, label="One‑click examples")
            btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=2):
            out_df = gr.Dataframe(headers=["Name","Tagline","Description","Score"], label="Top 3 + Generated")
            out_sg = gr.Textbox(label="Generated Slogan", interactive=False)

    # pipeline returns (table, slogan), matching the two outputs below.
    btn.click(fn=pipeline, inputs=inp, outputs=[out_df, out_sg])

if __name__ == "__main__":
    # Load all models up front so the first request does not pay for it.
    _ensure_models()
    demo.queue().launch()
requirements.txt CHANGED
@@ -5,3 +5,4 @@ faiss-cpu
5
  pandas
6
  numpy
7
  torch
 
 
5
  pandas
6
  numpy
7
  torch
8
+ pyarrow