3v324v23 committed on
Commit c17e99d · 1 Parent(s): 73f5c98

Deploy refined v2 slogan generator with Gradio UI

Files changed (2)
  1. app.py +119 -270
  2. requirements.txt +3 -4
app.py CHANGED
@@ -1,304 +1,153 @@
 
- import os, re, json, numpy as np, pandas as pd, gradio as gr, faiss, torch
- from typing import List
  from sentence_transformers import SentenceTransformer, CrossEncoder
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
- # =========================
- # Config
- # =========================
- FLAN_PRIMARY = os.getenv("FLAN_PRIMARY", "google/flan-t5-large")
- FLAN_FALLBACK = "google/flan-t5-base"
- EMBED_NAME = "sentence-transformers/all-mpnet-base-v2"
- RERANK_NAME = "cross-encoder/stsb-roberta-base"
-
- NUM_SLOGAN_SAMPLES = int(os.getenv("NUM_SLOGAN_SAMPLES", "16"))
- INDEX_ROOT = os.path.join(os.path.dirname(__file__), "vector_store")
- DEFAULT_MODEL_FOR_INDEX = EMBED_NAME
- CSV_PATH = os.path.join(os.path.dirname(__file__), "cleaned_data.csv")
-
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # =========================
- # Lazy models
- # =========================
- _GEN_TOK = None
- _GEN_MODEL = None
- _EMBED_MODEL = None
- _RERANKER = None
-
- def _ensure_models():
-     global _GEN_TOK, _GEN_MODEL, _EMBED_MODEL, _RERANKER
-     if _EMBED_MODEL is None:
-         _EMBED_MODEL = SentenceTransformer(EMBED_NAME)
-     if _RERANKER is None:
-         _RERANKER = CrossEncoder(RERANK_NAME)
-     if _GEN_MODEL is None:
-         try:
-             tok = AutoTokenizer.from_pretrained(FLAN_PRIMARY)
-             mdl = AutoModelForSeq2SeqLM.from_pretrained(FLAN_PRIMARY)
-             _GEN_TOK, _GEN_MODEL = tok, mdl.to(DEVICE)
-             print(f"[INFO] Loaded generator: {FLAN_PRIMARY}")
-         except Exception as e:
-             print(f"[WARN] {e}; fallback to {FLAN_FALLBACK}")
-             tok = AutoTokenizer.from_pretrained(FLAN_FALLBACK)
-             mdl = AutoModelForSeq2SeqLM.from_pretrained(FLAN_FALLBACK)
-             _GEN_TOK, _GEN_MODEL = tok, mdl.to(DEVICE)
-
- # =========================
- # Index cache
- # =========================
- _INDEX_CACHE = {}  # mkey -> (faiss_index, meta_df)
-
- def _model_key(name: str) -> str:
-     return name.replace("/", "_")
-
- def _format_for_e5(texts, as_query=False):
-     prefix = "query: " if as_query else "passage: "
-     return [prefix + str(t) for t in texts]
-
- def _build_memory_index_from_csv(model_name: str):
-     if not os.path.exists(CSV_PATH):
-         return None
-     df = pd.read_csv(CSV_PATH)
-     for col in ("name","tagline","description"):
-         if col not in df.columns: df[col] = ""
-     texts = df["description"].astype(str).tolist()
-     embedder = SentenceTransformer(model_name) if model_name != EMBED_NAME else _EMBED_MODEL
-     if model_name.startswith("intfloat/e5"):
-         texts = _format_for_e5(texts, as_query=False)
-     vecs = embedder.encode(texts, normalize_embeddings=True)
-     vecs = np.asarray(vecs, dtype=np.float32)
-     idx = faiss.IndexFlatIP(vecs.shape[1])
-     idx.add(vecs)
-     return idx, df[["name","tagline","description"]].copy()
-
- def _load_index_for_model(model_name: str = DEFAULT_MODEL_FOR_INDEX):
-     mkey = _model_key(model_name)
-     if mkey in _INDEX_CACHE: return _INDEX_CACHE[mkey]
-
-     base = os.path.join(INDEX_ROOT, mkey)
-     idx_path = os.path.join(base, "index.faiss")
-     meta_path = os.path.join(base, "meta.parquet")
-
-     if os.path.exists(idx_path) and os.path.exists(meta_path):
-         index = faiss.read_index(idx_path)
-         meta = pd.read_parquet(meta_path)
-         _INDEX_CACHE[mkey] = (index, meta)
-         return _INDEX_CACHE[mkey]
-
-     # fallback: build from CSV if available
-     built = _build_memory_index_from_csv(model_name)
-     if built is not None:
-         _INDEX_CACHE[mkey] = built
-         return built
-
-     # last fallback: tiny demo
-     print("[WARN] FAISS & CSV missing — using tiny demo index")
-     demo = pd.DataFrame({
-         "name":["HowDidIDo","Museotainment","Movitr"],
-         "tagline":["Online evaluation platform","PacMan & Louvre meet","Crowdsourced video translation"],
-         "description":[
-             "Public speaking, Presentation skills and interview practice",
-             "Interactive AR museum tours",
-             "Video translation with voice and subtitles"
-         ]
-     })
-     embedder = SentenceTransformer(model_name) if model_name != EMBED_NAME else _EMBED_MODEL
-     vecs = embedder.encode(demo["description"].tolist(), normalize_embeddings=True)
-     vecs = np.asarray(vecs, dtype=np.float32)
-     idx = faiss.IndexFlatIP(vecs.shape[1]); idx.add(vecs)
-     _INDEX_CACHE[mkey] = (idx, demo)
-     return _INDEX_CACHE[mkey]
-
- # =========================
- # Recommend
- # =========================
- def recommend(query_text: str, model_name: str = DEFAULT_MODEL_FOR_INDEX, top_k: int = 3) -> pd.DataFrame:
-     _ensure_models()
-     index, meta = _load_index_for_model(model_name)
-     q_inp = _format_for_e5([query_text], as_query=True) if model_name.startswith("intfloat/e5") else [query_text]
-     q_vec = _EMBED_MODEL.encode(q_inp, normalize_embeddings=True)
-     q_vec = np.asarray(q_vec, dtype=np.float32)
-     scores, idxs = index.search(q_vec, top_k)
-     out = meta.iloc[idxs[0]].copy()
-     out["score"] = scores[0]
-     for col in ("name","tagline","description"):
-         if col not in out.columns: out[col] = ""
-     cols = ["name","tagline","description","score"]
-     return out[cols]
-
- # =========================
- # Refined v2 – helpers
- # =========================
  BLOCK_PATTERNS = [
      r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
      r"^[A-Z][a-z]+ [A-Z][a-z]+$",
      r"^[A-Z][a-z]+$",
  ]
- HARD_BLOCK_WORDS = {
-     "platform","solution","system","application","marketplace",
      "ai-powered","ai powered","empower","empowering",
-     "artificial intelligence","machine learning","augmented reality","virtual reality",
- }
  GENERIC_WORDS = {"app","assistant","smart","ai","ml","ar","vr","decentralized","blockchain"}
- MARKETING_VERBS = {"build","grow","simplify","discover","create","connect","transform","unlock","boost","learn","move","clarify"}
- BENEFIT_WORDS = {"faster","smarter","easier","better","safer","clearer","stronger","together","confidently","simply","instantly"}
- GOOD_SLOGANS_TO_AVOID_DUP = {
-     "smarter care, faster decisions","checkout built for small brands","less guessing. more healing.",
-     "built to grow with your cart.","stand tall. feel better.","train your brain to win.",
-     "your body. your algorithm.","play smarter. grow brighter.","style that thinks with you."
- }
 
- def _tokens(s: str) -> List[str]: return re.findall(r"[a-z0-9]{3,}", s.lower())
- def _jaccard(a: List[str], b: List[str]) -> float:
-     A,B=set(a),set(b); return 0.0 if not A or not B else len(A&B)/len(A|B)
- def _titlecase_soft(s: str) -> str:
-     return " ".join(w if w.isupper() else w.capitalize() for w in s.split())
  def _is_blocked_slogan(s: str) -> bool:
-     if not s: return True
-     s_strip=s.strip()
      for pat in BLOCK_PATTERNS:
-         if re.match(pat, s_strip): return True
-     s_low=s_strip.lower()
-     if any(w in s_low for w in HARD_BLOCK_WORDS): return True
-     return s_low in GOOD_SLOGANS_TO_AVOID_DUP
-
- def _generic_penalty(s: str) -> float:
-     hits=sum(1 for w in GENERIC_WORDS if w in s.lower()); return min(1.0, 0.25*hits)
- def _for_penalty(s: str) -> float: return 0.3 if re.search(r"\bfor\b", s.lower()) else 0.0
-
- def _neighbor_context(neighbors_df: pd.DataFrame) -> str:
-     if neighbors_df is None or neighbors_df.empty: return ""
-     ex=[]
-     for _,row in neighbors_df.head(3).iterrows():
-         tg=str(row.get("tagline","")).strip()
-         if 5<=len(tg)<=70: ex.append(f"- {tg}")
-     return "\n".join(ex)
-
- def _copies_neighbor(s: str, neighbors_df: pd.DataFrame) -> bool:
-     if neighbors_df is None or neighbors_df.empty: return False
-     s_low=s.lower(); s_toks=_tokens(s_low)
-     for _,row in neighbors_df.iterrows():
-         t=str(row.get("tagline","")).strip()
-         if not t: continue
-         t_low=t.lower()
-         if s_low==t_low: return True
-         if _jaccard(s_toks,_tokens(t_low))>=0.7: return True
-     try:
-         em=SentenceTransformer(EMBED_NAME)
-         s_vec=em.encode([s])[0]; s_vec=s_vec/np.linalg.norm(s_vec)
-         for _,row in neighbors_df.head(3).iterrows():
-             t=str(row.get("tagline","")).strip()
-             if not t: continue
-             t_vec=em.encode([t])[0]; t_vec=t_vec/np.linalg.norm(t_vec)
-             if float(np.dot(s_vec,t_vec))>=0.85: return True
-     except: pass
      return False
 
- def _clean_slogan(text: str, max_words: int = 8) -> str:
-     text=text.strip().split("\n")[0]
-     text=re.sub(r"[\"“”‘’]","",text); text=re.sub(r"\s+"," ",text).strip()
-     words=text.split()
-     return " ".join(words[:max_words]) if len(words)>max_words else text
-
- def _score_candidates(query: str, cands: List[str], neighbors_df: pd.DataFrame) -> List[tuple]:
-     if not cands: return []
-     ce_scores=np.asarray(CrossEncoder(RERANK_NAME).predict([(query,s) for s in cands]),dtype=np.float32)/5.0
-     q_toks=_tokens(query); results=[]
-
-     em=SentenceTransformer(EMBED_NAME)
-     neighbor_vecs=[]
-     if neighbors_df is not None and not neighbors_df.empty:
-         for _,row in neighbors_df.head(3).iterrows():
-             t=str(row.get("tagline","")).strip()
-             if t:
-                 v=em.encode([t])[0]; neighbor_vecs.append(v/np.linalg.norm(v))
-
-     for i,s in enumerate(cands):
-         words=s.split()
-         brev=1.0-min(1.0,abs(len(words)-5)/5.0)
-         wl=set(w.lower() for w in words)
-         m_hits=len(wl & MARKETING_VERBS); b_hits=len(wl & BENEFIT_WORDS)
-         marketing=min(1.0,0.2*m_hits+0.2*b_hits)
-         g_pen=_generic_penalty(s); f_pen=_for_penalty(s)
-         n_pen=0.0
-         if neighbor_vecs:
-             try:
-                 s_vec=em.encode([s])[0]; s_vec=s_vec/np.linalg.norm(s_vec)
-                 sim_max=max(float(np.dot(s_vec,nv)) for nv in neighbor_vecs) if neighbor_vecs else 0.0
-                 n_pen=sim_max
-             except: n_pen=0.0
-         overlap=_jaccard(q_toks,_tokens(s)); anti_copy=1.0-overlap
-         score=0.55*float(ce_scores[i])+0.20*brev+0.15*marketing+0.03*anti_copy-0.07*g_pen-0.03*f_pen-0.10*n_pen
-         results.append((s,float(score)))
      return results
 
- def generate_slogan(query_text: str, neighbors_df: pd.DataFrame = None, n_samples: int = NUM_SLOGAN_SAMPLES) -> str:
-     _ensure_models()
-     ctx=_neighbor_context(neighbors_df)
-     prompt=(
          "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
          "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
-         "Focus on clear benefits and vivid verbs. Do not copy the description. Return ONLY a list, one slogan per line.\n\n"
-         "Good Examples:\nDescription: AI assistant for doctors to prioritize patient cases\nSlogan: Less Guessing. More Healing.\n\n"
-         "Description: Payments for small online stores\nSlogan: Built to Grow with Your Cart.\n\n"
-         "Description: Neurotech headset to boost focus\nSlogan: Train Your Brain to Win.\n\n"
-         "Description: Interior design suggestions with AI\nSlogan: Style That Thinks With You.\n\n"
-         "Bad Examples (avoid these): Innovative AI Platform / Smart App for Everyone / Empowering Small Businesses\n\n"
      )
-     if ctx: prompt+=f"Similar taglines (style only):\n{ctx}\n\n"
-     prompt+=f"Description: {query_text}\nSlogans:"
-     input_ids=_GEN_TOK(prompt,return_tensors="pt").input_ids.to(DEVICE)
-     outputs=_GEN_MODEL.generate(input_ids,max_new_tokens=24,do_sample=True,top_k=60,top_p=0.92,temperature=1.2,num_return_sequences=n_samples,repetition_penalty=1.08)
-     raw=[_GEN_TOK.decode(o,skip_special_tokens=True) for o in outputs]
-     cand=set()
-     for txt in raw:
          for line in txt.split("\n"):
-             s=_clean_slogan(line)
              if not s: continue
-             if len(s.split())<2 or len(s.split())>8: continue
              if _is_blocked_slogan(s): continue
-             if _copies_neighbor(s,neighbors_df): continue
-             cand.add(_titlecase_soft(s))
-     if not cand: return _clean_slogan(_GEN_TOK.decode(outputs[0],skip_special_tokens=True))
-     scored=_score_candidates(query_text,sorted(cand),neighbors_df)
-     if not scored: return _clean_slogan(_GEN_TOK.decode(outputs[0],skip_special_tokens=True))
-     scored.sort(key=lambda x:x[1],reverse=True)
-     return scored[0][0]
 
- # =========================
- # Gradio
- # =========================
- EXAMPLES=[
      "AI coach for improving public speaking skills",
      "Augmented reality app for interactive museum tours",
      "Voice-controlled task manager for remote teams",
      "Machine learning system for predicting crop yields",
-     "Platform for AI-assisted interior design suggestions",
  ]
 
- def pipeline(user_input: str):
-     recs=recommend(user_input, model_name=DEFAULT_MODEL_FOR_INDEX, top_k=3)
-     slogan=generate_slogan(user_input, neighbors_df=recs, n_samples=NUM_SLOGAN_SAMPLES)
-     recs=recs.reset_index(drop=True)
-     for col in ("name","tagline","description"):
-         if col not in recs.columns: recs[col]=""
-     recs.loc[len(recs)]={"name":"Synthetic Example","tagline":slogan,"description":user_input,"score":np.nan}
-     return recs[["name","tagline","description","score"]], slogan
-
- with gr.Blocks(title="SloganAI — Recommendations + Slogan Generator") as demo:
-     gr.Markdown("## SloganAI — Top-3 Recommendations + A High-Quality Generated Slogan")
-     with gr.Row():
-         with gr.Column(scale=1):
-             inp=gr.Textbox(label="Enter a startup description", lines=3, placeholder="e.g., AI coach for improving public speaking skills")
-             gr.Examples(EXAMPLES, inputs=inp, label="One-click examples")
-             btn=gr.Button("Submit", variant="primary")
-         with gr.Column(scale=2):
-             out_df=gr.Dataframe(headers=["Name","Tagline","Description","Score"], label="Top 3 + Generated")
-             out_sg=gr.Textbox(label="Generated Slogan", interactive=False)
-     btn.click(fn=pipeline, inputs=inp, outputs=[out_df, out_sg])
 
  if __name__ == "__main__":
-     _ensure_models()
-     demo.queue().launch()
 
+ import gradio as gr
+ import pandas as pd
+ import numpy as np
+ import faiss, re, torch
  from sentence_transformers import SentenceTransformer, CrossEncoder
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
+ # ------------------ Models ------------------
+ GEN_TOK = AutoTokenizer.from_pretrained("google/flan-t5-large")
+ GEN_MODEL = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ GEN_MODEL = GEN_MODEL.to(DEVICE)
+
+ EMBED_MODEL = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
+ RERANKER = CrossEncoder("cross-encoder/stsb-roberta-base")
+
+ # ------------------ Dummy dataset (for demo) ------------------
+ data = pd.DataFrame({
+     "name": ["HowDidIDo", "Museotainment", "Movitr"],
+     "tagline": ["Online evaluation platform", "PacMan & Louvre meet", "Crowdsourced video translation"],
+     "description": [
+         "Public speaking, Presentation skills and interview practice",
+         "Interactive AR museum tours",
+         "Video translation with voice and subtitles"
+     ]
+ })
+
+ # Build FAISS index
+ data_vecs = EMBED_MODEL.encode(data["description"].tolist())
+ faiss.normalize_L2(data_vecs)
+ index = faiss.IndexFlatIP(data_vecs.shape[1])
+ index.add(data_vecs)
+
+ def recommend(query, top_k=3):
+     query_vec = EMBED_MODEL.encode([query])
+     faiss.normalize_L2(query_vec)
+     scores, idx = index.search(query_vec, top_k)
+     results = data.iloc[idx[0]].copy()
+     results["score"] = scores[0]
+     return results[["name", "tagline", "description", "score"]]
+
+ # ------------------ Helpers ------------------
  BLOCK_PATTERNS = [
      r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
      r"^[A-Z][a-z]+ [A-Z][a-z]+$",
      r"^[A-Z][a-z]+$",
  ]
+
+ HARD_BLOCK_WORDS = {"platform","solution","system","application","marketplace",
      "ai-powered","ai powered","empower","empowering",
+     "artificial intelligence","machine learning","augmented reality","virtual reality"}
  GENERIC_WORDS = {"app","assistant","smart","ai","ml","ar","vr","decentralized","blockchain"}
+ MARKETING_VERBS = {"build","grow","simplify","discover","create","connect","transform","unlock","boost","learn"}
+ BENEFIT_WORDS = {"faster","smarter","easier","better","safer","clearer"}
+
+ def _clean_slogan(text: str, max_words: int = 8) -> str:
+     text = text.strip().split("\n")[0]
+     text = re.sub(r"[\"“”‘’]", "", text)
+     text = re.sub(r"\s+", " ", text).strip()
+     words = text.split()
+     if len(words) > max_words:
+         text = " ".join(words[:max_words])
+     return text
 
  def _is_blocked_slogan(s: str) -> bool:
+     s_low = s.lower()
+     if any(w in s_low for w in HARD_BLOCK_WORDS):
+         return True
      for pat in BLOCK_PATTERNS:
+         if re.match(pat, s.strip()):
+             return True
      return False
 
+ def _score_candidates(query: str, cands: list) -> list:
+     if not cands:
+         return []
+     ce_scores = np.asarray(RERANKER.predict([(query, s) for s in cands]), dtype=np.float32) / 5.0
+     results = []
+     for i, s in enumerate(cands):
+         words = s.split()
+         brevity = 1.0 - min(1.0, abs(len(words) - 5) / 5.0)
+         marketing = 0.2*len(set(words) & MARKETING_VERBS) + 0.2*len(set(words) & BENEFIT_WORDS)
+         score = 0.6*float(ce_scores[i]) + 0.2*brevity + 0.2*marketing
+         results.append((s, float(score)))
      return results
 
+ # ------------------ Generator ------------------
+ def generate_slogan(query_text: str, n_samples: int = 16) -> str:
+     prompt = (
          "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
          "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
+         "Focus on benefits and vivid verbs. Do not copy the description.\n\n"
+         f"Description: {query_text}\nSlogans:"
      )
+
+     input_ids = GEN_TOK(prompt, return_tensors="pt").input_ids.to(DEVICE)
+     outputs = GEN_MODEL.generate(
+         input_ids,
+         max_new_tokens=24,
+         do_sample=True,
+         top_k=60,
+         top_p=0.92,
+         temperature=1.2,
+         num_return_sequences=n_samples
+     )
+
+     raw_cands = [GEN_TOK.decode(o, skip_special_tokens=True) for o in outputs]
+
+     cand_set = set()
+     for txt in raw_cands:
          for line in txt.split("\n"):
+             s = _clean_slogan(line)
              if not s: continue
+             if len(s.split()) < 2 or len(s.split()) > 8: continue
              if _is_blocked_slogan(s): continue
+             cand_set.add(s.capitalize())
 
+     if not cand_set:
+         return "Fresh Ideas, Built To Scale"
+
+     scored = _score_candidates(query_text, sorted(cand_set))
+     scored.sort(key=lambda x: x[1], reverse=True)
+     return scored[0][0] if scored else "Fresh Ideas, Built To Scale"
+
+ # ------------------ Pipeline ------------------
+ def pipeline(user_input):
+     recs = recommend(user_input, top_k=3)
+     slogan = generate_slogan(user_input)
+     recs = recs.reset_index(drop=True)
+     recs.loc[len(recs)] = ["Generated Slogan", slogan, user_input, np.nan]
+     return recs
+
+ # ------------------ Gradio UI ------------------
+ examples = [
      "AI coach for improving public speaking skills",
      "Augmented reality app for interactive museum tours",
      "Voice-controlled task manager for remote teams",
      "Machine learning system for predicting crop yields",
+     "Platform for AI-assisted interior design suggestions"
  ]
 
+ demo = gr.Interface(
+     fn=pipeline,
+     inputs=gr.Textbox(label="Enter a startup description"),
+     outputs=gr.Dataframe(headers=["Name", "Tagline", "Description", "Score"]),
+     examples=examples,
+     title="SloganAI Startup Recommendation & Slogan Generator",
+     description="Enter a startup idea and get top-3 similar startups + 1 generated slogan."
+ )
 
  if __name__ == "__main__":
+     demo.launch()
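
The new app.py keeps recommend(), generate_slogan(), and pipeline() at module level, so the change can be smoke-tested without launching the Gradio UI. A minimal sketch of such a check (hypothetical, not part of this commit; it assumes app.py is importable from the working directory and that the model downloads succeed):

# Hypothetical smoke test for the refactored app.py (not part of this commit).
# Importing app loads flan-t5-large, the MPNet embedder, and the cross-encoder,
# then builds the in-memory FAISS index over the three demo rows.
import app

table = app.pipeline("AI coach for improving public speaking skills")
print(table)  # top-3 neighbours plus the appended "Generated Slogan" row
print(app.generate_slogan("Payments for small online stores"))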
 
requirements.txt CHANGED
@@ -1,8 +1,7 @@
- gradio>=4.36.1,<5
- transformers>=4.42,<5
- sentence-transformers>=2.3.1
+ gradio
+ transformers
+ sentence-transformers
  faiss-cpu
  pandas
  numpy
  torch
- pyarrow
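
The updated requirements.txt drops the version bounds and the pyarrow entry (the new app.py no longer reads the parquet metadata that needed it). If reproducible rebuilds of the Space matter, the bounds from the previous revision could be kept, e.g.:

gradio>=4.36.1,<5
transformers>=4.42,<5
sentence-transformers>=2.3.1
faiss-cpu
pandas
numpy
torch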