Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,25 @@
|
|
1 |
# app.py โ GIftyPlus (lean)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import os, re, json, hashlib, pathlib, random
|
3 |
from typing import Dict, List, Tuple, Optional, Any
|
4 |
import numpy as np, pandas as pd, gradio as gr, torch
|
@@ -13,6 +34,8 @@ MAX_ROWS = int(os.getenv("MAX_ROWS", "12000"))
|
|
13 |
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "sentence-transformers/all-MiniLM-L12-v2")
|
14 |
|
15 |
def resolve_cache_dir():
|
|
|
|
|
16 |
for p in [os.getenv("EMBED_CACHE_DIR"), os.path.join(os.getcwd(), ".gifty_cache"), "/tmp/.gifty_cache"]:
|
17 |
if not p: continue
|
18 |
pathlib.Path(p).mkdir(parents=True, exist_ok=True)
|
@@ -22,6 +45,7 @@ def resolve_cache_dir():
|
|
22 |
return os.getcwd()
|
23 |
EMBED_CACHE_DIR = resolve_cache_dir()
|
24 |
|
|
|
25 |
INTEREST_OPTIONS = ["Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion","Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food","Home decor","Science"]
|
26 |
OCCASION_UI = ["Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming","Retirement","Holidays","Valentineโs Day","Promotion / New job","Get well soon"]
|
27 |
OCCASION_CANON = {"Birthday":"birthday","Wedding / Engagement":"wedding","Anniversary":"anniversary","Graduation":"graduation","New baby":"new_baby","Housewarming":"housewarming","Retirement":"retirement","Holidays":"holidays","Valentineโs Day":"valentines","Promotion / New job":"promotion","Get well soon":"get_well"}
|
@@ -30,9 +54,11 @@ MESSAGE_TONES = ["Formal","Casual","Funny","Heartfelt","Inspirational","Playful"
|
|
30 |
AGE_OPTIONS = {"any":"any","kid (3โ12)":"kids","teen (13โ17)":"teens","adult (18โ64)":"adult","senior (65+)":"senior"}
|
31 |
GENDER_OPTIONS = ["any","female","male","nonbinary"]
|
32 |
|
|
|
33 |
SYNONYMS = {"sports":["fitness","outdoor","training","yoga","run"],"travel":["luggage","passport","map","trip","vacation"],"cooking":["kitchen","cookware","chef","baking"],"technology":["electronics","gadgets","device","smart","computer"],"music":["audio","headphones","earbuds","speaker","vinyl"],"art":["painting","drawing","sketch","canvas"],"reading":["book","novel","literature"],"gardening":["plants","planter","seeds","garden","indoor"],"fashion":["style","accessory","jewelry"],"gaming":["board game","puzzle","video game","controller"],"photography":["camera","lens","tripod","film"],"hiking":["outdoor","camping","backpack","trek"],"movies":["film","cinema","blu-ray","poster"],"crafts":["diy","handmade","kit","knitting"],"pets":["dog","cat","pet"],"wellness":["relaxation","spa","aromatherapy","self-care"],"collecting":["display","collector","limited edition"],"food":["gourmet","snack","treats","chocolate"],"home decor":["home","decor","wall art","candle"],"science":["lab","experiment","STEM","microscope"]}
|
34 |
REL_TO_TOKENS = {"Family - Parent":["parent","family"],"Family - Sibling":["sibling","family"],"Family - Child":["kids","play","family"],"Family - Other relative":["family","relative"],"Friend":["friendly"],"Colleague":["office","work","professional"],"Boss":["executive","professional","premium"],"Romantic partner":["romantic","couple"],"Teacher / Mentor":["teacher","mentor","thank_you"],"Neighbor":["neighbor","housewarming"],"Client / Business partner":["professional","thank_you","premium"]}
|
35 |
|
|
|
36 |
_CURRENCY_RE = re.compile(r"[^\d.,\-]+"); _NUM_RE = re.compile(r"(\d+(?:[.,]\d+)?)"); _RANGE_SEP = re.compile(r"\s*(?:-|โ|โ|to)\s*")
|
37 |
def _to_price_usd(x):
|
38 |
if pd.isna(x): return np.nan
|
@@ -42,6 +68,7 @@ def _to_price_usd(x):
|
|
42 |
return float(m.group(1)) if m else np.nan
|
43 |
|
44 |
def _first_present(df, cands):
|
|
|
45 |
lower = {c.lower(): c for c in df.columns}
|
46 |
for c in cands:
|
47 |
if c in df.columns: return c
|
@@ -49,6 +76,7 @@ def _first_present(df, cands):
|
|
49 |
return None
|
50 |
|
51 |
def _auto_price_col(df):
|
|
|
52 |
for c in df.columns:
|
53 |
s = df[c]
|
54 |
if pd.api.types.is_numeric_dtype(s) and not s.dropna().empty and (s.dropna().between(0.5, 10000)).mean() > .6: return c
|
@@ -57,21 +85,25 @@ def _auto_price_col(df):
|
|
57 |
return None
|
58 |
|
59 |
def map_amazon_to_schema(raw: pd.DataFrame) -> pd.DataFrame:
|
|
|
60 |
name_c=_first_present(raw,["product name","title","name","product_title"]); desc_c=_first_present(raw,["description","product_description","feature","about"])
|
61 |
cat_c=_first_present(raw,["category","categories","main_cat","product_category"]); price_c=_first_present(raw,["selling price","price","current_price","list_price","price_amount","actual_price","price_usd"]) or _auto_price_col(raw)
|
62 |
img_c=_first_present(raw,["image","image_url","imageurl","imUrl","img","img_url"])
|
63 |
df=pd.DataFrame({"name":raw.get(name_c,""),"short_desc":raw.get(desc_c,""),"tags":raw.get(cat_c,""),"price_usd":raw.get(price_c,np.nan),"image_url":raw.get(img_c,"")})
|
|
|
64 |
df["price_usd"]=df["price_usd"].map(_to_price_usd); df["name"]=df["name"].astype(str).str.strip().str.slice(0,160)
|
65 |
df["short_desc"]=df["short_desc"].astype(str).str.strip().str.slice(0,600); df["tags"]=df["tags"].astype(str).str.replace("|",", ").str.lower()
|
66 |
return df
|
67 |
|
68 |
def extract_top_cat(tags:str)->str:
|
|
|
69 |
s=(tags or "").lower()
|
70 |
for sep in ["|",">"]:
|
71 |
if sep in s: return s.split(sep,1)[0].strip()
|
72 |
return s.strip().split(",")[0] if s else ""
|
73 |
|
74 |
def load_catalog()->pd.DataFrame:
|
|
|
75 |
df=map_amazon_to_schema(load_dataset(DATASET_ID, split=DATASET_SPLIT).to_pandas()).drop_duplicates(subset=["name","short_desc"])
|
76 |
df=df[pd.notna(df["price_usd"])]; df=df[(df["price_usd"]>0)&(df["price_usd"]<=500)].reset_index(drop=True)
|
77 |
if len(df)>MAX_ROWS: df=df.sample(n=MAX_ROWS,random_state=42).reset_index(drop=True)
|
@@ -81,6 +113,9 @@ def load_catalog()->pd.DataFrame:
|
|
81 |
return df
|
82 |
CATALOG=load_catalog()
|
83 |
|
|
|
|
|
|
|
84 |
class EmbeddingBank:
|
85 |
def __init__(s, docs, model_id, dataset_tag):
|
86 |
s.model_id=model_id; s.dataset_tag=dataset_tag; s.model=SentenceTransformer(model_id); s.embs=s._load_or_build(docs)
|
@@ -95,10 +130,12 @@ class EmbeddingBank:
|
|
95 |
def query_vec(s,text): return s.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0]
|
96 |
EMB=EmbeddingBank(CATALOG["doc"].tolist(), EMBED_MODEL_ID, DATASET_ID)
|
97 |
|
|
|
98 |
_tok_rx = re.compile(r"[a-z0-9][a-z0-9\-']*")
|
99 |
if "tok_set" not in CATALOG.columns:
|
100 |
CATALOG["tok_set"]=(CATALOG["name"].fillna("")+" "+CATALOG["tags"].fillna("")+" "+CATALOG["short_desc"].fillna("")).map(lambda t:set(_tok_rx.findall(str(t).lower())))
|
101 |
|
|
|
102 |
try:
|
103 |
from sentence_transformers import CrossEncoder
|
104 |
except:
|
@@ -111,6 +148,7 @@ def _load_cross_encoder():
|
|
111 |
_CE_MODEL=CrossEncoder(RERANK_MODEL_ID, device="cpu")
|
112 |
return _CE_MODEL
|
113 |
|
|
|
114 |
OCCASION_PRIORS={"valentines":[("jewelry",.12),("chocolate",.10),("candle",.08),("romantic",.08),("couple",.08),("heart",.06)],
|
115 |
"birthday":[("fun",.06),("game",.06),("personalized",.06),("gift set",.05),("surprise",.04)],
|
116 |
"anniversary":[("couple",.10),("jewelry",.10),("photo",.08),("frame",.06),("memory",.06),("candle",.06)],
|
@@ -123,6 +161,7 @@ OCCASION_PRIORS={"valentines":[("jewelry",.12),("chocolate",.10),("candle",.08),
|
|
123 |
"get_well":[("cozy",.10),("tea",.08),("soothing",.06),("care",.06)]}
|
124 |
|
125 |
def expand_with_synonyms(tokens: List[str])->List[str]:
|
|
|
126 |
out=[];
|
127 |
for t in tokens:
|
128 |
t=t.strip().lower()
|
@@ -130,12 +169,14 @@ def expand_with_synonyms(tokens: List[str])->List[str]:
|
|
130 |
return out
|
131 |
|
132 |
def profile_to_query(p:Dict)->str:
|
|
|
133 |
inter=[i.lower() for i in p.get("interests",[]) if i]; expanded=expand_with_synonyms(inter)*3
|
134 |
parts=[", ".join(expanded) if expanded else "", ", ".join(REL_TO_TOKENS.get(p.get("relationship","Friend"),[])), OCCASION_CANON.get(p.get("occ_ui","Birthday"),"birthday")]
|
135 |
tail=f"gift ideas for a {p.get('relationship','Friend')} for {parts[-1]}; likes {', '.join(inter) or 'general'}"
|
136 |
return " | ".join([x for x in parts if x])+" | "+tail
|
137 |
|
138 |
def _gender_ok_mask(g:str)->np.ndarray:
|
|
|
139 |
g=(g or "any").lower(); bl=CATALOG["blob"]
|
140 |
has_m=bl.str.contains(r"\b(men|man's|mens|male|for men)\b",regex=True,na=False)
|
141 |
has_f=bl.str.contains(r"\b(women|woman's|womens|female|for women|dress)\b",regex=True,na=False)
|
@@ -145,6 +186,7 @@ def _gender_ok_mask(g:str)->np.ndarray:
|
|
145 |
return np.ones(len(bl),bool)
|
146 |
|
147 |
def _mask_by_age(age:str, blob:pd.Series)->np.ndarray:
|
|
|
148 |
kids=blob.str.contains(r"\b(?:kid|kids|child|children|toddler|baby|boys?|girls?|kid's|children's)\b",regex=True,na=False)
|
149 |
teen=blob.str.contains(r"\b(?:teen|teens|young adult|ya)\b",regex=True,na=False)
|
150 |
if age in ("adult","senior"): return (~kids).to_numpy()
|
@@ -153,11 +195,13 @@ def _mask_by_age(age:str, blob:pd.Series)->np.ndarray:
|
|
153 |
return np.ones(len(blob),bool)
|
154 |
|
155 |
def _interest_bonus(p:Dict, idx:np.ndarray)->np.ndarray:
|
|
|
156 |
ints=[i.lower() for i in p.get("interests",[]) if i]; syns=[s for it in ints for s in SYNONYMS.get(it,[])]; vocab=set(ints+syns)
|
157 |
if not vocab or idx.size==0: return np.zeros(len(idx),"float32")
|
158 |
counts=np.array([len(CATALOG["tok_set"].iat[i] & vocab) for i in idx],"float32"); return .10*np.clip(counts,0,6)
|
159 |
|
160 |
def _occasion_bonus(idx:np.ndarray, occ_ui:str)->np.ndarray:
|
|
|
161 |
pri=OCCASION_PRIORS.get(OCCASION_CANON.get(occ_ui or "Birthday","birthday"),[])
|
162 |
if not pri or idx.size==0: return np.zeros(len(idx),"float32")
|
163 |
bl=CATALOG["blob"].to_numpy(); out=np.zeros(len(idx),"float32")
|
@@ -166,11 +210,13 @@ def _occasion_bonus(idx:np.ndarray, occ_ui:str)->np.ndarray:
|
|
166 |
return out
|
167 |
|
168 |
def _minmax(x:np.ndarray)->np.ndarray:
|
|
|
169 |
if x.size==0: return x
|
170 |
lo,hi=float(np.min(x)),float(np.max(x));
|
171 |
return np.zeros_like(x) if hi<=lo+1e-9 else (x-lo)/(hi-lo)
|
172 |
|
173 |
def _mmr_select(cand_idx:np.ndarray, scores:np.ndarray, k:int, lambda_:float=.7)->np.ndarray:
|
|
|
174 |
if cand_idx.size<=k: return cand_idx[np.argsort(-scores)][:k]
|
175 |
picked=[]; rest=list(range(len(cand_idx))); rel=_minmax(scores)
|
176 |
V=np.asarray(EMB.embs,"float32")[cand_idx]; V/=np.linalg.norm(V,axis=1,keepdims=True)+1e-8
|
@@ -180,44 +226,154 @@ def _mmr_select(cand_idx:np.ndarray, scores:np.ndarray, k:int, lambda_:float=.7)
|
|
180 |
j=int(np.argmax(lambda_*rel[rest]-(1-lambda_)*sim_to_sel)); picked.append(rest.pop(j))
|
181 |
return cand_idx[np.array(picked,int)]
|
182 |
|
183 |
-
def recommend_top3_budget_first(
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
q=profile_to_query(p); qv=EMB.query_vec(q).astype("float32")
|
195 |
emb_sims=np.asarray(EMB.embs,"float32")[idx]@qv
|
196 |
target=(lo+hi)/2.0 if hi>lo else hi; prices=CATALOG.iloc[idx]["price_usd"].to_numpy()
|
|
|
197 |
price_bonus=np.clip(.12-np.abs(prices-target)/max(target,1.0),0,.12).astype("float32")
|
198 |
int_bonus=_interest_bonus(p,idx); occ_bonus=_occasion_bonus(idx,p.get("occ_ui","Birthday"))
|
199 |
pre=emb_sims+price_bonus+int_bonus+occ_bonus
|
|
|
200 |
K1=min(48,idx.size); top_local=np.argpartition(-pre,K1-1)[:K1]; cand_idx=idx[top_local]
|
201 |
emb_n=_minmax(emb_sims[top_local]); price_n=_minmax(price_bonus[top_local]); int_n=_minmax(int_bonus[top_local]); occ_n=_minmax(occ_bonus[top_local])
|
202 |
ce=_load_cross_encoder();
|
203 |
if ce is not None:
|
|
|
204 |
docs=CATALOG.loc[cand_idx,"doc"].tolist(); pairs=[(q,d) for d in docs]
|
205 |
k_ce=min(24,len(pairs)); tl=np.argpartition(-emb_n,k_ce-1)[:k_ce]; ce_raw=np.array(ce.predict([pairs[i] for i in tl]),"float32"); ce_n=np.zeros_like(emb_n); ce_n[tl]=_minmax(ce_raw)
|
206 |
else:
|
207 |
ce_n=np.zeros_like(emb_n)
|
|
|
208 |
final=(.56*emb_n+.26*ce_n+.10*int_n+.05*occ_n+.03*price_n).astype("float32")
|
209 |
pick=_mmr_select(cand_idx,final,k=min(3,cand_idx.size))
|
210 |
res=CATALOG.loc[pick].copy(); pos={int(cand_idx[i]):i for i in range(len(cand_idx))}; res["similarity"]=[float(final[pos[int(i)]]) for i in pick]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
return res[["name","short_desc","price_usd","image_url","similarity"]].reset_index(drop=True)
|
212 |
|
213 |
# ===== DIY (FLAN-only) =====
|
214 |
DIY_MODEL_ID=os.getenv("DIY_MODEL_ID","google/flan-t5-small"); DIY_DEVICE=torch.device("cpu")
|
215 |
MAX_INPUT_TOKENS=int(os.getenv("MAX_INPUT_TOKENS","384")); DIY_MAX_NEW_TOKENS=int(os.getenv("DIY_MAX_NEW_TOKENS","120"))
|
|
|
216 |
INTEREST_ALIASES={"Reading":["book","novel","literary"],"Fashion":["style","chic","silk"],"Home decor":["candle","wall","jar"],"Technology":["tech","gadget","usb"],"Movies":["film","cinema","poster"]}
|
217 |
FALLBACK_NOUNS=["Kit","Set","Bundle","Box","Pack"]
|
218 |
|
219 |
_diy_cache_model={}
|
220 |
def _load_flan(mid:str):
|
|
|
221 |
if mid in _diy_cache_model: return _diy_cache_model[mid]
|
222 |
tok=AutoTokenizer.from_pretrained(mid, use_fast=True, trust_remote_code=True)
|
223 |
mdl=AutoModelForSeq2SeqLM.from_pretrained(mid, trust_remote_code=True, use_safetensors=True).to(DIY_DEVICE).eval()
|
@@ -225,6 +381,7 @@ def _load_flan(mid:str):
|
|
225 |
|
226 |
@torch.inference_mode()
|
227 |
def _gen(tok, mdl, prompt, max_new_tokens=64, do_sample=False, temperature=.9, top_p=.95, seed=None):
|
|
|
228 |
if seed is None: seed=random.randint(1,10_000_000)
|
229 |
random.seed(seed); torch.manual_seed(seed)
|
230 |
enc=tok(prompt, truncation=True, max_length=MAX_INPUT_TOKENS, return_tensors="pt"); enc={k:v.to(DIY_DEVICE) for k,v in enc.items()}
|
@@ -232,11 +389,13 @@ def _gen(tok, mdl, prompt, max_new_tokens=64, do_sample=False, temperature=.9, t
|
|
232 |
return tok.decode(out[0], skip_special_tokens=True).strip()
|
233 |
|
234 |
def _choose_interest_token(interests):
|
|
|
235 |
for it in interests:
|
236 |
if INTEREST_ALIASES.get(it): return random.choice(INTEREST_ALIASES[it])
|
237 |
return (interests[0].split()[0].lower() if interests else "gift")
|
238 |
def _title_case(s): s=re.sub(r'\s+',' ',s).strip(); s=re.sub(r'["โโโโ]+','',s); return " ".join([w.capitalize() for w in s.split()])
|
239 |
def _sanitize_name(name, interests):
|
|
|
240 |
for b in [r"^the name\b",r"\bmember of the family\b",r"^name\b",r"^title\b"]: name=re.sub(b,"",name,flags=re.I).strip()
|
241 |
name=re.sub(r'[:\-โโ]+$',"",name).strip(); alias=_choose_interest_token(interests)
|
242 |
if alias not in name.lower():
|
@@ -247,6 +406,7 @@ def _sanitize_name(name, interests):
|
|
247 |
return name
|
248 |
|
249 |
def _split_list_text(s,seps):
|
|
|
250 |
s=s.strip()
|
251 |
for sep in seps:
|
252 |
if sep in s:
|
@@ -255,6 +415,7 @@ def _split_list_text(s,seps):
|
|
255 |
return [p.strip(" -โข*.,;:") for p in re.split(r"[\n\r;]+", s) if p.strip(" -โข*.,;:")]
|
256 |
|
257 |
def _coerce_materials(items):
|
|
|
258 |
out=[]
|
259 |
for it in items:
|
260 |
it=re.sub(r'\s+',' ',it).strip(" -โข*.,;:");
|
@@ -271,6 +432,7 @@ def _coerce_materials(items):
|
|
271 |
return out[:8]
|
272 |
|
273 |
def _coerce_steps(items):
|
|
|
274 |
out=[]
|
275 |
for it in items:
|
276 |
it=it.strip(" -โข*.,;:");
|
@@ -284,10 +446,12 @@ def _coerce_steps(items):
|
|
284 |
|
285 |
def _only_int(s): m=re.search(r"-?\d+",s); return int(m.group()) if m else None
|
286 |
def _clamp_num(v,lo,hi,default):
|
|
|
287 |
try: x=float(v); return int(min(max(x,lo),hi))
|
288 |
except: return int((lo+hi)/2 if default is None else default)
|
289 |
|
290 |
def diy_generate(profile:Dict)->Tuple[dict,str]:
|
|
|
291 |
tok,mdl=_load_flan(DIY_MODEL_ID)
|
292 |
p={"recipient_name":profile.get("recipient_name","Recipient"),"relationship":profile.get("relationship","Friend"),
|
293 |
"occ_ui":profile.get("occ_ui","Birthday"),"occasion":profile.get("occ_ui","Birthday"),"interests":profile.get("interests",[]),
|
@@ -311,8 +475,33 @@ def diy_generate(profile:Dict)->Tuple[dict,str]:
|
|
311 |
"estimated_cost_usd":_clamp_num(cost,p["budget_min"],p["budget_max"],None),"estimated_time_minutes":_clamp_num(minutes,20,180,60)}
|
312 |
return idea,"ok"
|
313 |
|
314 |
-
|
315 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
MSG_MODEL_ID = "google/flan-t5-small"
|
317 |
MSG_DEVICE = "cpu"
|
318 |
TEMP_RANGE = (0.88, 1.10)
|
@@ -414,6 +603,7 @@ CLOSERS = [
|
|
414 |
]
|
415 |
|
416 |
def _msg_load():
|
|
|
417 |
global _msg_tok, _msg_mdl
|
418 |
if _msg_tok is None or _msg_mdl is None:
|
419 |
_msg_tok = AutoTokenizer.from_pretrained(MSG_MODEL_ID)
|
@@ -422,16 +612,20 @@ def _msg_load():
|
|
422 |
return _msg_tok, _msg_mdl
|
423 |
|
424 |
def _norm(s: str) -> str:
|
|
|
425 |
return re.sub(r"\s+", " ", s or "").strip()
|
426 |
|
427 |
def _sentences_n(s: str) -> int:
|
|
|
428 |
return len([p for p in re.split(r"(?<=[.!?])\s+", s.strip()) if p])
|
429 |
|
430 |
def _contains_any(text: str, terms: List[str]) -> bool:
|
|
|
431 |
t = text.lower()
|
432 |
return any(term for term in terms if term) and any((term or "").lower() in t for term in terms)
|
433 |
|
434 |
def _too_similar(a: str, b: str, n=3, thr=0.85) -> bool:
|
|
|
435 |
def ngrams(txt):
|
436 |
toks = re.findall(r"[a-zA-Z']+", txt.lower())
|
437 |
return set(tuple(toks[i:i+n]) for i in range(max(0, len(toks)-n+1)))
|
@@ -441,9 +635,11 @@ def _too_similar(a: str, b: str, n=3, thr=0.85) -> bool:
|
|
441 |
return j >= thr
|
442 |
|
443 |
def _clean_occasion(occ: str) -> str:
|
|
|
444 |
return (occ or "").replace("โ","'").strip()
|
445 |
|
446 |
def _build_prompt(profile: Dict[str, Any]) -> Tuple[str, Dict[str,str]]:
|
|
|
447 |
name = profile.get("recipient_name", "Friend")
|
448 |
rel = profile.get("relationship", "Friend")
|
449 |
occ = _clean_occasion(profile.get("occ_ui") or profile.get("occasion") or "Birthday")
|
@@ -476,6 +672,7 @@ def _build_prompt(profile: Dict[str, Any]) -> Tuple[str, Dict[str,str]]:
|
|
476 |
|
477 |
@torch.inference_mode()
|
478 |
def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None, previous_message: Optional[str]=None) -> Dict[str, Any]:
|
|
|
479 |
global _last_msg
|
480 |
tok, mdl = _msg_load()
|
481 |
if seed is None:
|
@@ -501,7 +698,7 @@ def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None,
|
|
501 |
)
|
502 |
text = _norm(tok.decode(out_ids[0], skip_special_tokens=True))
|
503 |
|
504 |
-
# ===== Validators (
|
505 |
ok_len = 1 <= _sentences_n(text) <= 3
|
506 |
name_ok = _contains_any(text, [need["name"].lower()])
|
507 |
occ_ok = _contains_any(text, [need["occ"].lower(), need["occ"].split()[0].lower()])
|
@@ -516,6 +713,7 @@ def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None,
|
|
516 |
"seed": seed, "attempt": attempt, "model": MSG_MODEL_ID}}
|
517 |
tried.append({"text": text}); seed += 17
|
518 |
|
|
|
519 |
fallback = tried[-1]["text"] if tried else f"Happy {(_clean_occasion(profile.get('occ_ui') or 'day')).lower()}, {profile.get('recipient_name','Friend')}!"
|
520 |
_last_msg = fallback
|
521 |
return {"message": fallback, "meta": {"failed": True, "model": MSG_MODEL_ID, "tone": profile.get("tone","Heartfelt")}}
|
@@ -524,23 +722,26 @@ def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None,
|
|
524 |
|
525 |
# ===== Rendering & UI =====
|
526 |
def first_sentence(s,max_chars=140):
|
|
|
527 |
s=(s or "").strip();
|
528 |
if not s: return ""
|
529 |
cut=s.split(". ")[0];
|
530 |
return cut if len(cut)<=max_chars else cut[:max_chars-1]+"โฆ"
|
531 |
|
532 |
def render_top3_html(df, age_label):
|
|
|
533 |
if df is None or df.empty: return "<em>No results found within the current filters.</em>"
|
534 |
rows=[]
|
535 |
-
for
|
536 |
name=str(r.get("name","")).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
537 |
desc=str(first_sentence(r.get("short_desc",""))).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
538 |
price=r.get("price_usd"); sim=r.get("similarity"); img=r.get("image_url","") or ""
|
539 |
price_str=f"${price:.0f}" if pd.notna(price) else "N/A"; sim_str=f"{sim:.3f}" if pd.notna(sim) else "โ"
|
540 |
img_html=f'<img src="{img}" alt="" style="width:84px;height:84px;object-fit:cover;border-radius:10px;margin-left:12px;" />' if img else ""
|
|
|
541 |
rows.append(f"""
|
542 |
<div style="display:flex;align-items:flex-start;justify-content:space-between;gap:10px;padding:10px;border:1px solid #eee;border-radius:12px;margin-bottom:8px;background:#fff;">
|
543 |
-
<div style="flex:1;min-width:0;"><div style="font-weight:700;">{name}</div>
|
544 |
<div style="font-size:0.95em;margin-top:4px;">{desc}</div>
|
545 |
<div style="font-size:0.9em;margin-top:6px;opacity:0.8;">Price: <b>{price_str}</b> ยท Age: <code>{age_label}</code> ยท Score: <code>{sim_str}</code></div>
|
546 |
</div>{img_html}
|
@@ -586,17 +787,20 @@ with gr.Blocks(title="๐ GIfty โ Recommender + DIY", css="""
|
|
586 |
tone=gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Funny")
|
587 |
|
588 |
go=gr.Button("Get GIfty!")
|
|
|
589 |
gr.Markdown("### ๐ฏ Recommendations"); out_top3=gr.HTML()
|
590 |
gr.Markdown("### ๐ ๏ธ DIY Gift"); out_diy_md=gr.Markdown()
|
591 |
gr.Markdown("### ๐ Personalized Message"); out_msg=gr.Markdown()
|
592 |
run_token=gr.State(0)
|
593 |
|
594 |
def _on_example_select(evt: gr.SelectData):
|
|
|
595 |
r=int(evt.index[0] if isinstance(evt.index,(list,tuple)) else evt.index); row=EX_DF.iloc[r]; ints=[s.strip() for s in str(row["Interests"]).split("+")]
|
596 |
return (ints,row["Occasion"],int(row["Min $"]),int(row["Max $"]),row["Recipient"],row["Relationship"],row["Age group"],row["Gender"],row["Tone"])
|
597 |
ex_df.select(_on_example_select, outputs=[interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone])
|
598 |
|
599 |
def render_diy_md(j:dict)->str:
|
|
|
600 |
if not j: return "_DIY generation failed._"
|
601 |
steps=j.get('step_by_step_instructions', j.get('steps', []))
|
602 |
parts = [
|
@@ -607,26 +811,105 @@ with gr.Blocks(title="๐ GIfty โ Recommender + DIY", css="""
|
|
607 |
f"**Estimated cost:** ${j.get('estimated_cost_usd','?')} ยท **Time:** {j.get('estimated_time_minutes','?')} min"
|
608 |
]
|
609 |
return "\n".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
610 |
|
611 |
def _build_profile(ints, occ, bmin, bmax, name, rel, age_label, gender_val, tone_val):
|
|
|
612 |
try: bmin=float(bmin); bmax=float(bmax)
|
613 |
except: bmin,bmax=5.0,500.0
|
614 |
if bmin>bmax: bmin,bmax=bmax,bmin
|
615 |
return {"recipient_name":name or "Friend","relationship":rel or "Friend","interests":ints or [],"occ_ui":occ or "Birthday","budget_min":bmin,"budget_max":bmax,"age_range":AGE_OPTIONS.get(age_label,"any"),"gender":(gender_val or "any").lower(),"tone":tone_val or "Heartfelt"}
|
616 |
|
617 |
-
def start_run(curr):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
618 |
|
619 |
def predict_recs_only(rt, *args):
|
620 |
-
p=_build_profile(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
def predict_diy_only(rt, *args):
|
622 |
-
p=_build_profile(*args)
|
623 |
-
|
624 |
-
|
625 |
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
630 |
|
631 |
if __name__=="__main__":
|
632 |
demo.launch()
|
|
|
1 |
# app.py โ GIftyPlus (lean)
|
2 |
+
# -----------------------------------------------------------------------------
|
3 |
+
# High-level overview
|
4 |
+
# -----------------------------------------------------------------------------
|
5 |
+
# GIftyPlus is a lightweight gift recommender + DIY generator.
|
6 |
+
# Pipeline:
|
7 |
+
# 1) Load & normalize an Amazon-like product dataset (name/desc/tags/price/img).
|
8 |
+
# 2) Build sentence embeddings for semantic retrieval (cached to .npy).
|
9 |
+
# 3) Rank items with a weighted score (embeddings + optional cross-encoder +
|
10 |
+
# interest/occasion/price bonuses) and diversify with MMR.
|
11 |
+
# 4) Generate a DIY gift idea (FLAN-T5), then embed 10 candidates and append
|
12 |
+
# the best one as a "Generated" #4 result.
|
13 |
+
# 5) Generate a short personalized message (FLAN-T5) with basic validators.
|
14 |
+
# 6) Gradio UI: input form, input summary, top-3 + generated #4, DIY section,
|
15 |
+
# and personalized message section.
|
16 |
+
#
|
17 |
+
# Env vars you can override:
|
18 |
+
# DATASET_ID, DATASET_SPLIT, MAX_ROWS,
|
19 |
+
# EMBED_MODEL_ID, RERANK_MODEL_ID,
|
20 |
+
# DIY_MODEL_ID, MAX_INPUT_TOKENS, DIY_MAX_NEW_TOKENS.
|
21 |
+
# -----------------------------------------------------------------------------
|
22 |
+
|
23 |
import os, re, json, hashlib, pathlib, random
|
24 |
from typing import Dict, List, Tuple, Optional, Any
|
25 |
import numpy as np, pandas as pd, gradio as gr, torch
|
|
|
34 |
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "sentence-transformers/all-MiniLM-L12-v2")
|
35 |
|
36 |
def resolve_cache_dir():
|
37 |
+
# Choose the first writable cache directory:
|
38 |
+
# 1) EMBED_CACHE_DIR env, 2) project .gifty_cache, 3) /tmp/.gifty_cache
|
39 |
for p in [os.getenv("EMBED_CACHE_DIR"), os.path.join(os.getcwd(), ".gifty_cache"), "/tmp/.gifty_cache"]:
|
40 |
if not p: continue
|
41 |
pathlib.Path(p).mkdir(parents=True, exist_ok=True)
|
|
|
45 |
return os.getcwd()
|
46 |
EMBED_CACHE_DIR = resolve_cache_dir()
|
47 |
|
48 |
+
# UI vocab / options
|
49 |
INTEREST_OPTIONS = ["Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion","Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food","Home decor","Science"]
|
50 |
OCCASION_UI = ["Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming","Retirement","Holidays","Valentineโs Day","Promotion / New job","Get well soon"]
|
51 |
OCCASION_CANON = {"Birthday":"birthday","Wedding / Engagement":"wedding","Anniversary":"anniversary","Graduation":"graduation","New baby":"new_baby","Housewarming":"housewarming","Retirement":"retirement","Holidays":"holidays","Valentineโs Day":"valentines","Promotion / New job":"promotion","Get well soon":"get_well"}
|
|
|
54 |
AGE_OPTIONS = {"any":"any","kid (3โ12)":"kids","teen (13โ17)":"teens","adult (18โ64)":"adult","senior (65+)":"senior"}
|
55 |
GENDER_OPTIONS = ["any","female","male","nonbinary"]
|
56 |
|
57 |
+
# Light synonym expansion for interests; used to enrich queries and "hit" checks
|
58 |
SYNONYMS = {"sports":["fitness","outdoor","training","yoga","run"],"travel":["luggage","passport","map","trip","vacation"],"cooking":["kitchen","cookware","chef","baking"],"technology":["electronics","gadgets","device","smart","computer"],"music":["audio","headphones","earbuds","speaker","vinyl"],"art":["painting","drawing","sketch","canvas"],"reading":["book","novel","literature"],"gardening":["plants","planter","seeds","garden","indoor"],"fashion":["style","accessory","jewelry"],"gaming":["board game","puzzle","video game","controller"],"photography":["camera","lens","tripod","film"],"hiking":["outdoor","camping","backpack","trek"],"movies":["film","cinema","blu-ray","poster"],"crafts":["diy","handmade","kit","knitting"],"pets":["dog","cat","pet"],"wellness":["relaxation","spa","aromatherapy","self-care"],"collecting":["display","collector","limited edition"],"food":["gourmet","snack","treats","chocolate"],"home decor":["home","decor","wall art","candle"],"science":["lab","experiment","STEM","microscope"]}
|
59 |
REL_TO_TOKENS = {"Family - Parent":["parent","family"],"Family - Sibling":["sibling","family"],"Family - Child":["kids","play","family"],"Family - Other relative":["family","relative"],"Friend":["friendly"],"Colleague":["office","work","professional"],"Boss":["executive","professional","premium"],"Romantic partner":["romantic","couple"],"Teacher / Mentor":["teacher","mentor","thank_you"],"Neighbor":["neighbor","housewarming"],"Client / Business partner":["professional","thank_you","premium"]}
|
60 |
|
61 |
+
# --- Price parsing helpers (robust to currency symbols and ranges) ---
|
62 |
_CURRENCY_RE = re.compile(r"[^\d.,\-]+"); _NUM_RE = re.compile(r"(\d+(?:[.,]\d+)?)"); _RANGE_SEP = re.compile(r"\s*(?:-|โ|โ|to)\s*")
|
63 |
def _to_price_usd(x):
|
64 |
if pd.isna(x): return np.nan
|
|
|
68 |
return float(m.group(1)) if m else np.nan
|
69 |
|
70 |
def _first_present(df, cands):
|
71 |
+
# Return the first column name that exists in df out of candidates (case-insensitive)
|
72 |
lower = {c.lower(): c for c in df.columns}
|
73 |
for c in cands:
|
74 |
if c in df.columns: return c
|
|
|
76 |
return None
|
77 |
|
78 |
def _auto_price_col(df):
|
79 |
+
# Heuristics for price column detection when column name is unknown
|
80 |
for c in df.columns:
|
81 |
s = df[c]
|
82 |
if pd.api.types.is_numeric_dtype(s) and not s.dropna().empty and (s.dropna().between(0.5, 10000)).mean() > .6: return c
|
|
|
85 |
return None
|
86 |
|
87 |
def map_amazon_to_schema(raw: pd.DataFrame) -> pd.DataFrame:
|
88 |
+
# Map arbitrary Amazon-like columns into a compact schema suitable for retrieval
|
89 |
name_c=_first_present(raw,["product name","title","name","product_title"]); desc_c=_first_present(raw,["description","product_description","feature","about"])
|
90 |
cat_c=_first_present(raw,["category","categories","main_cat","product_category"]); price_c=_first_present(raw,["selling price","price","current_price","list_price","price_amount","actual_price","price_usd"]) or _auto_price_col(raw)
|
91 |
img_c=_first_present(raw,["image","image_url","imageurl","imUrl","img","img_url"])
|
92 |
df=pd.DataFrame({"name":raw.get(name_c,""),"short_desc":raw.get(desc_c,""),"tags":raw.get(cat_c,""),"price_usd":raw.get(price_c,np.nan),"image_url":raw.get(img_c,"")})
|
93 |
+
# Light normalization / truncation to keep UI compact
|
94 |
df["price_usd"]=df["price_usd"].map(_to_price_usd); df["name"]=df["name"].astype(str).str.strip().str.slice(0,160)
|
95 |
df["short_desc"]=df["short_desc"].astype(str).str.strip().str.slice(0,600); df["tags"]=df["tags"].astype(str).str.replace("|",", ").str.lower()
|
96 |
return df
|
97 |
|
98 |
def extract_top_cat(tags:str)->str:
|
99 |
+
# Extract a "top-level" category token for quick grouping/labeling
|
100 |
s=(tags or "").lower()
|
101 |
for sep in ["|",">"]:
|
102 |
if sep in s: return s.split(sep,1)[0].strip()
|
103 |
return s.strip().split(",")[0] if s else ""
|
104 |
|
105 |
def load_catalog()->pd.DataFrame:
|
106 |
+
# Load dataset โ normalize schema โ filter โ light feature engineering
|
107 |
df=map_amazon_to_schema(load_dataset(DATASET_ID, split=DATASET_SPLIT).to_pandas()).drop_duplicates(subset=["name","short_desc"])
|
108 |
df=df[pd.notna(df["price_usd"])]; df=df[(df["price_usd"]>0)&(df["price_usd"]<=500)].reset_index(drop=True)
|
109 |
if len(df)>MAX_ROWS: df=df.sample(n=MAX_ROWS,random_state=42).reset_index(drop=True)
|
|
|
113 |
return df
|
114 |
CATALOG=load_catalog()
|
115 |
|
116 |
+
# -----------------------------------------------------------------------------
|
117 |
+
# Embedding bank with on-disk caching
|
118 |
+
# -----------------------------------------------------------------------------
|
119 |
class EmbeddingBank:
|
120 |
def __init__(s, docs, model_id, dataset_tag):
|
121 |
s.model_id=model_id; s.dataset_tag=dataset_tag; s.model=SentenceTransformer(model_id); s.embs=s._load_or_build(docs)
|
|
|
130 |
def query_vec(s,text): return s.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0]
|
131 |
EMB=EmbeddingBank(CATALOG["doc"].tolist(), EMBED_MODEL_ID, DATASET_ID)
|
132 |
|
133 |
+
# Token set for light lexical checks (used by interest Hit@k)
|
134 |
_tok_rx = re.compile(r"[a-z0-9][a-z0-9\-']*")
|
135 |
if "tok_set" not in CATALOG.columns:
|
136 |
CATALOG["tok_set"]=(CATALOG["name"].fillna("")+" "+CATALOG["tags"].fillna("")+" "+CATALOG["short_desc"].fillna("")).map(lambda t:set(_tok_rx.findall(str(t).lower())))
|
137 |
|
138 |
+
# Optional cross-encoder for re-ranking (small CPU-friendly model by default)
|
139 |
try:
|
140 |
from sentence_transformers import CrossEncoder
|
141 |
except:
|
|
|
148 |
_CE_MODEL=CrossEncoder(RERANK_MODEL_ID, device="cpu")
|
149 |
return _CE_MODEL
|
150 |
|
151 |
+
# Occasion-specific keyword priors (light bonus shaping)
|
152 |
OCCASION_PRIORS={"valentines":[("jewelry",.12),("chocolate",.10),("candle",.08),("romantic",.08),("couple",.08),("heart",.06)],
|
153 |
"birthday":[("fun",.06),("game",.06),("personalized",.06),("gift set",.05),("surprise",.04)],
|
154 |
"anniversary":[("couple",.10),("jewelry",.10),("photo",.08),("frame",.06),("memory",.06),("candle",.06)],
|
|
|
161 |
"get_well":[("cozy",.10),("tea",.08),("soothing",.06),("care",.06)]}
|
162 |
|
163 |
def expand_with_synonyms(tokens: List[str])->List[str]:
|
164 |
+
# Expand user-provided interests with synonyms to enrich the query
|
165 |
out=[];
|
166 |
for t in tokens:
|
167 |
t=t.strip().lower()
|
|
|
169 |
return out
|
170 |
|
171 |
def profile_to_query(p:Dict)->str:
|
172 |
+
# Construct a dense query string from profile information
|
173 |
inter=[i.lower() for i in p.get("interests",[]) if i]; expanded=expand_with_synonyms(inter)*3
|
174 |
parts=[", ".join(expanded) if expanded else "", ", ".join(REL_TO_TOKENS.get(p.get("relationship","Friend"),[])), OCCASION_CANON.get(p.get("occ_ui","Birthday"),"birthday")]
|
175 |
tail=f"gift ideas for a {p.get('relationship','Friend')} for {parts[-1]}; likes {', '.join(inter) or 'general'}"
|
176 |
return " | ".join([x for x in parts if x])+" | "+tail
|
177 |
|
178 |
def _gender_ok_mask(g:str)->np.ndarray:
|
179 |
+
# Gender-aware filter: exclude items explicitly labeled for the opposite gender unless unisex
|
180 |
g=(g or "any").lower(); bl=CATALOG["blob"]
|
181 |
has_m=bl.str.contains(r"\b(men|man's|mens|male|for men)\b",regex=True,na=False)
|
182 |
has_f=bl.str.contains(r"\b(women|woman's|womens|female|for women|dress)\b",regex=True,na=False)
|
|
|
186 |
return np.ones(len(bl),bool)
|
187 |
|
188 |
def _mask_by_age(age:str, blob:pd.Series)->np.ndarray:
|
189 |
+
# Age-aware filter: crude regex to separate kids/teens/adults
|
190 |
kids=blob.str.contains(r"\b(?:kid|kids|child|children|toddler|baby|boys?|girls?|kid's|children's)\b",regex=True,na=False)
|
191 |
teen=blob.str.contains(r"\b(?:teen|teens|young adult|ya)\b",regex=True,na=False)
|
192 |
if age in ("adult","senior"): return (~kids).to_numpy()
|
|
|
195 |
return np.ones(len(blob),bool)
|
196 |
|
197 |
def _interest_bonus(p:Dict, idx:np.ndarray)->np.ndarray:
|
198 |
+
# Soft bonus if catalog tokens overlap with interest vocabulary (synonyms included)
|
199 |
ints=[i.lower() for i in p.get("interests",[]) if i]; syns=[s for it in ints for s in SYNONYMS.get(it,[])]; vocab=set(ints+syns)
|
200 |
if not vocab or idx.size==0: return np.zeros(len(idx),"float32")
|
201 |
counts=np.array([len(CATALOG["tok_set"].iat[i] & vocab) for i in idx],"float32"); return .10*np.clip(counts,0,6)
|
202 |
|
203 |
def _occasion_bonus(idx:np.ndarray, occ_ui:str)->np.ndarray:
|
204 |
+
# Soft bonus based on occasion priors (keywords found in item blob)
|
205 |
pri=OCCASION_PRIORS.get(OCCASION_CANON.get(occ_ui or "Birthday","birthday"),[])
|
206 |
if not pri or idx.size==0: return np.zeros(len(idx),"float32")
|
207 |
bl=CATALOG["blob"].to_numpy(); out=np.zeros(len(idx),"float32")
|
|
|
210 |
return out
|
211 |
|
212 |
def _minmax(x:np.ndarray)->np.ndarray:
|
213 |
+
# Normalize to [0,1] with safe guard for constant vectors
|
214 |
if x.size==0: return x
|
215 |
lo,hi=float(np.min(x)),float(np.max(x));
|
216 |
return np.zeros_like(x) if hi<=lo+1e-9 else (x-lo)/(hi-lo)
|
217 |
|
218 |
def _mmr_select(cand_idx:np.ndarray, scores:np.ndarray, k:int, lambda_:float=.7)->np.ndarray:
|
219 |
+
# MMR selection to maintain diversity in the final top-k
|
220 |
if cand_idx.size<=k: return cand_idx[np.argsort(-scores)][:k]
|
221 |
picked=[]; rest=list(range(len(cand_idx))); rel=_minmax(scores)
|
222 |
V=np.asarray(EMB.embs,"float32")[cand_idx]; V/=np.linalg.norm(V,axis=1,keepdims=True)+1e-8
|
|
|
226 |
j=int(np.argmax(lambda_*rel[rest]-(1-lambda_)*sim_to_sel)); picked.append(rest.pop(j))
|
227 |
return cand_idx[np.array(picked,int)]
|
228 |
|
229 |
+
def recommend_top3_budget_first(
|
230 |
+
p: Dict,
|
231 |
+
include_synth: bool = True,
|
232 |
+
synth_n: int = 10,
|
233 |
+
widen_budget_frac: float = 0.5
|
234 |
+
) -> pd.DataFrame:
|
235 |
+
"""
|
236 |
+
Retrieve โ score โ diversify. Always returns semantically-ranked results
|
237 |
+
from the catalog (no โcheapest-3โ fallback). If strict filters empty the
|
238 |
+
pool, we progressively relax them but still rank by embeddings + bonuses.
|
239 |
+
Optionally appends a 4th 'Generated' item (DIY) when include_synth=True.
|
240 |
+
"""
|
241 |
+
# ---------- Filters (progressive relaxations) ----------
|
242 |
+
lo, hi = float(p.get("budget_min", 0)), float(p.get("budget_max", 1e9))
|
243 |
+
blob = CATALOG["blob"]
|
244 |
+
price = CATALOG["price_usd"].values
|
245 |
+
age_ok = _mask_by_age(p.get("age_range", "any"), blob)
|
246 |
+
gen_ok = _gender_ok_mask(p.get("gender", "any"))
|
247 |
+
price_ok_strict = (price >= lo) & (price <= hi)
|
248 |
+
price_ok_wide = (price >= max(0, lo * (1 - widen_budget_frac))) & \
|
249 |
+
(price <= (hi * (1 + widen_budget_frac) if hi < 1e8 else hi))
|
250 |
+
|
251 |
+
mask_chain = [
|
252 |
+
price_ok_strict & age_ok & gen_ok, # ืืื ืงืฉืื
|
253 |
+
price_ok_strict & gen_ok, # ืืื ืืื
|
254 |
+
price_ok_wide & gen_ok, # ืืจืืืช ืืืื ืชืงืฆืื
|
255 |
+
age_ok & gen_ok, # ืืื ืชืงืฆืื
|
256 |
+
gen_ok, # ืจืง ืืืืจ
|
257 |
+
np.ones(len(CATALOG), bool), # ืืื
|
258 |
+
]
|
259 |
+
idx = np.array([], dtype=int)
|
260 |
+
for m in mask_chain:
|
261 |
+
cand = np.where(m)[0]
|
262 |
+
if cand.size:
|
263 |
+
idx = cand
|
264 |
+
break
|
265 |
+
|
266 |
+
# ---------- Query & base similarities ----------
|
267 |
+
q = profile_to_query(p)
|
268 |
+
qv = EMB.query_vec(q).astype("float32")
|
269 |
+
embs = np.asarray(EMB.embs, "float32")
|
270 |
+
emb_sims = embs[idx] @ qv
|
271 |
+
|
272 |
+
# ---------- Bonuses (ืขืืืื ืืืืฉืืื ืขื ืืืืขืืืื ืฉื ืืืจื) ----------
|
273 |
+
target = (lo + hi) / 2.0 if hi > lo else hi
|
274 |
+
prices = CATALOG.iloc[idx]["price_usd"].to_numpy()
|
275 |
+
price_bonus = np.clip(.12 - np.abs(prices - target) / max(target, 1.0), 0, .12).astype("float32")
|
276 |
+
int_bonus = _interest_bonus(p, idx)
|
277 |
+
occ_bonus = _occasion_bonus(idx, p.get("occ_ui", "Birthday"))
|
278 |
+
|
279 |
+
# Pre-score ืขื ืืื ืืช ื-NaN/Inf
|
280 |
+
pre = np.nan_to_num(emb_sims + price_bonus + int_bonus + occ_bonus, nan=0.0, posinf=0.0, neginf=0.0)
|
281 |
+
|
282 |
+
# ---------- Local candidate pool ----------
|
283 |
+
K1 = max(1, min(48, idx.size))
|
284 |
+
try:
|
285 |
+
top_local = np.argpartition(-pre, K1 - 1)[:K1]
|
286 |
+
except Exception:
|
287 |
+
top_local = np.argsort(-pre)[:K1]
|
288 |
+
cand_idx = idx[top_local]
|
289 |
+
|
290 |
+
# ---------- Feature normalization ----------
|
291 |
+
emb_n = _minmax(np.nan_to_num(emb_sims[top_local], nan=0.0))
|
292 |
+
price_n = _minmax(np.nan_to_num(price_bonus[top_local],nan=0.0))
|
293 |
+
int_n = _minmax(np.nan_to_num(int_bonus[top_local], nan=0.0))
|
294 |
+
occ_n = _minmax(np.nan_to_num(occ_bonus[top_local], nan=0.0))
|
295 |
+
|
296 |
+
# ---------- Optional cross-encoder ----------
|
297 |
+
ce = _load_cross_encoder()
|
298 |
+
if ce is not None:
|
299 |
+
docs = CATALOG.loc[cand_idx, "doc"].tolist()
|
300 |
+
pairs = [(q, d) for d in docs]
|
301 |
+
k_ce = min(24, len(pairs))
|
302 |
+
tl = np.argpartition(-emb_n, k_ce - 1)[:k_ce]
|
303 |
+
ce_raw = np.array(ce.predict([pairs[i] for i in tl]), "float32")
|
304 |
+
ce_n = np.zeros_like(emb_n)
|
305 |
+
ce_n[tl] = _minmax(ce_raw)
|
306 |
+
else:
|
307 |
+
ce_n = np.zeros_like(emb_n)
|
308 |
+
|
309 |
+
# ---------- Final score ----------
|
310 |
+
final = np.nan_to_num(.56*emb_n + .26*ce_n + .10*int_n + .05*occ_n + .03*price_n, nan=0.0)
|
311 |
+
|
312 |
+
# ---------- Select top-3 with diversity ----------
|
313 |
+
k = int(min(3, cand_idx.size))
|
314 |
+
pick = _mmr_select(cand_idx, final, k=k) if k > 0 else np.array([], dtype=int)
|
315 |
+
if pick.size == 0:
|
316 |
+
pick = cand_idx[np.argsort(-final)[:min(3, cand_idx.size)]]
|
317 |
+
|
318 |
+
# ---------- Build result ----------
|
319 |
+
res = CATALOG.loc[pick].copy()
|
320 |
+
pos = {int(cand_idx[i]): i for i in range(len(cand_idx))}
|
321 |
+
res["similarity"] = [float(final[pos[int(i)]]) if int(i) in pos else np.nan for i in pick]
|
322 |
+
|
323 |
+
# ---------- Optional synthetic #4 ----------
|
324 |
+
if include_synth:
|
325 |
+
try:
|
326 |
+
synth = pick_best_synthetic(p, qv, generate_synthetic_candidates(p, n=int(max(1, synth_n))))
|
327 |
+
if synth is not None:
|
328 |
+
res = pd.concat(
|
329 |
+
[res, pd.DataFrame([synth])[["name","short_desc","price_usd","image_url","similarity"]]],
|
330 |
+
ignore_index=True
|
331 |
+
)
|
332 |
+
except Exception:
|
333 |
+
pass # ืื ืฉืืืจืื ืืช ื-UI ืื ื-DIY ื ืืฉื
|
334 |
+
|
335 |
+
return res[["name","short_desc","price_usd","image_url","similarity"]].reset_index(drop=True)
|
336 |
+
|
337 |
q=profile_to_query(p); qv=EMB.query_vec(q).astype("float32")
|
338 |
emb_sims=np.asarray(EMB.embs,"float32")[idx]@qv
|
339 |
target=(lo+hi)/2.0 if hi>lo else hi; prices=CATALOG.iloc[idx]["price_usd"].to_numpy()
|
340 |
+
# Small bonus for being close to the budget mid-point
|
341 |
price_bonus=np.clip(.12-np.abs(prices-target)/max(target,1.0),0,.12).astype("float32")
|
342 |
int_bonus=_interest_bonus(p,idx); occ_bonus=_occasion_bonus(idx,p.get("occ_ui","Birthday"))
|
343 |
pre=emb_sims+price_bonus+int_bonus+occ_bonus
|
344 |
+
# Keep a local candidate pool for cost/quality tradeoff
|
345 |
K1=min(48,idx.size); top_local=np.argpartition(-pre,K1-1)[:K1]; cand_idx=idx[top_local]
|
346 |
emb_n=_minmax(emb_sims[top_local]); price_n=_minmax(price_bonus[top_local]); int_n=_minmax(int_bonus[top_local]); occ_n=_minmax(occ_bonus[top_local])
|
347 |
ce=_load_cross_encoder();
|
348 |
if ce is not None:
|
349 |
+
# Optional cross-encoder re-ranking on a smaller slice
|
350 |
docs=CATALOG.loc[cand_idx,"doc"].tolist(); pairs=[(q,d) for d in docs]
|
351 |
k_ce=min(24,len(pairs)); tl=np.argpartition(-emb_n,k_ce-1)[:k_ce]; ce_raw=np.array(ce.predict([pairs[i] for i in tl]),"float32"); ce_n=np.zeros_like(emb_n); ce_n[tl]=_minmax(ce_raw)
|
352 |
else:
|
353 |
ce_n=np.zeros_like(emb_n)
|
354 |
+
# Final weighted score (tuned manually)
|
355 |
final=(.56*emb_n+.26*ce_n+.10*int_n+.05*occ_n+.03*price_n).astype("float32")
|
356 |
pick=_mmr_select(cand_idx,final,k=min(3,cand_idx.size))
|
357 |
res=CATALOG.loc[pick].copy(); pos={int(cand_idx[i]):i for i in range(len(cand_idx))}; res["similarity"]=[float(final[pos[int(i)]]) for i in pick]
|
358 |
+
# === NEW: synthetic #4 ===
|
359 |
+
synth = pick_best_synthetic(p, qv, generate_synthetic_candidates(p, n=10))
|
360 |
+
if synth is not None:
|
361 |
+
res = pd.concat(
|
362 |
+
[res, pd.DataFrame([synth])[["name","short_desc","price_usd","image_url","similarity"]]],
|
363 |
+
ignore_index=True
|
364 |
+
)
|
365 |
return res[["name","short_desc","price_usd","image_url","similarity"]].reset_index(drop=True)
|
366 |
|
367 |
# ===== DIY (FLAN-only) =====
|
368 |
DIY_MODEL_ID=os.getenv("DIY_MODEL_ID","google/flan-t5-small"); DIY_DEVICE=torch.device("cpu")
|
369 |
MAX_INPUT_TOKENS=int(os.getenv("MAX_INPUT_TOKENS","384")); DIY_MAX_NEW_TOKENS=int(os.getenv("DIY_MAX_NEW_TOKENS","120"))
|
370 |
+
# Light aliases to seed the DIY gift title with an interest token
|
371 |
INTEREST_ALIASES={"Reading":["book","novel","literary"],"Fashion":["style","chic","silk"],"Home decor":["candle","wall","jar"],"Technology":["tech","gadget","usb"],"Movies":["film","cinema","poster"]}
|
372 |
FALLBACK_NOUNS=["Kit","Set","Bundle","Box","Pack"]
|
373 |
|
374 |
_diy_cache_model={}
|
375 |
def _load_flan(mid:str):
|
376 |
+
# Lazy-load and cache FLAN-T5 on CPU
|
377 |
if mid in _diy_cache_model: return _diy_cache_model[mid]
|
378 |
tok=AutoTokenizer.from_pretrained(mid, use_fast=True, trust_remote_code=True)
|
379 |
mdl=AutoModelForSeq2SeqLM.from_pretrained(mid, trust_remote_code=True, use_safetensors=True).to(DIY_DEVICE).eval()
|
|
|
381 |
|
382 |
@torch.inference_mode()
|
383 |
def _gen(tok, mdl, prompt, max_new_tokens=64, do_sample=False, temperature=.9, top_p=.95, seed=None):
|
384 |
+
# Small wrapper for deterministic/non-deterministic generation
|
385 |
if seed is None: seed=random.randint(1,10_000_000)
|
386 |
random.seed(seed); torch.manual_seed(seed)
|
387 |
enc=tok(prompt, truncation=True, max_length=MAX_INPUT_TOKENS, return_tensors="pt"); enc={k:v.to(DIY_DEVICE) for k,v in enc.items()}
|
|
|
389 |
return tok.decode(out[0], skip_special_tokens=True).strip()
|
390 |
|
391 |
def _choose_interest_token(interests):
|
392 |
+
# Pick a representative token to inject into the DIY name
|
393 |
for it in interests:
|
394 |
if INTEREST_ALIASES.get(it): return random.choice(INTEREST_ALIASES[it])
|
395 |
return (interests[0].split()[0].lower() if interests else "gift")
|
396 |
def _title_case(s): s=re.sub(r'\s+',' ',s).strip(); s=re.sub(r'["โโโโ]+','',s); return " ".join([w.capitalize() for w in s.split()])
|
397 |
def _sanitize_name(name, interests):
|
398 |
+
# Clean LLM-proposed name and enforce a short, interest-infused title
|
399 |
for b in [r"^the name\b",r"\bmember of the family\b",r"^name\b",r"^title\b"]: name=re.sub(b,"",name,flags=re.I).strip()
|
400 |
name=re.sub(r'[:\-โโ]+$',"",name).strip(); alias=_choose_interest_token(interests)
|
401 |
if alias not in name.lower():
|
|
|
406 |
return name
|
407 |
|
408 |
def _split_list_text(s,seps):
|
409 |
+
# Parse list-like text returned by LLM into clean items (fallback across separators)
|
410 |
s=s.strip()
|
411 |
for sep in seps:
|
412 |
if sep in s:
|
|
|
415 |
return [p.strip(" -โข*.,;:") for p in re.split(r"[\n\r;]+", s) if p.strip(" -โข*.,;:")]
|
416 |
|
417 |
def _coerce_materials(items):
|
418 |
+
# Normalize materials list: dedupe, keep short, ensure quantities, pad with basics
|
419 |
out=[]
|
420 |
for it in items:
|
421 |
it=re.sub(r'\s+',' ',it).strip(" -โข*.,;:");
|
|
|
432 |
return out[:8]
|
433 |
|
434 |
def _coerce_steps(items):
|
435 |
+
# Normalize step list: trim, remove numbering, enforce sentence case, pad to 6+
|
436 |
out=[]
|
437 |
for it in items:
|
438 |
it=it.strip(" -โข*.,;:");
|
|
|
446 |
|
447 |
def _only_int(s): m=re.search(r"-?\d+",s); return int(m.group()) if m else None
|
448 |
def _clamp_num(v,lo,hi,default):
|
449 |
+
# Clamp numeric values into a valid range; fallback to default or midpoint
|
450 |
try: x=float(v); return int(min(max(x,lo),hi))
|
451 |
except: return int((lo+hi)/2 if default is None else default)
|
452 |
|
453 |
def diy_generate(profile:Dict)->Tuple[dict,str]:
|
454 |
+
# Generate a DIY gift object (name, overview, materials, steps, cost, time)
|
455 |
tok,mdl=_load_flan(DIY_MODEL_ID)
|
456 |
p={"recipient_name":profile.get("recipient_name","Recipient"),"relationship":profile.get("relationship","Friend"),
|
457 |
"occ_ui":profile.get("occ_ui","Birthday"),"occasion":profile.get("occ_ui","Birthday"),"interests":profile.get("interests",[]),
|
|
|
475 |
"estimated_cost_usd":_clamp_num(cost,p["budget_min"],p["budget_max"],None),"estimated_time_minutes":_clamp_num(minutes,20,180,60)}
|
476 |
return idea,"ok"
|
477 |
|
478 |
+
def generate_synthetic_candidates(profile, n=10):
|
479 |
+
# Use FLAN-based DIY generator to create N lightweight candidates (name/overview/price)
|
480 |
+
cands = []
|
481 |
+
lo, hi = int(float(profile.get("budget_min", 10))), int(float(profile.get("budget_max", 100)))
|
482 |
+
for _ in range(n):
|
483 |
+
idea, _ = diy_generate(profile) # Already returns name/overview/estimated_cost
|
484 |
+
price = int(idea.get("estimated_cost_usd") or random.randint(lo, hi))
|
485 |
+
name = idea.get("gift_name", "Custom DIY Gift")[:160]
|
486 |
+
desc = (idea.get("overview", "") or "").strip()[:300]
|
487 |
+
doc = f"{name} | custom | {desc}".lower()
|
488 |
+
cands.append({"name": name, "short_desc": desc, "price_usd": price, "image_url": "", "doc": doc})
|
489 |
+
return cands
|
490 |
+
|
491 |
+
def pick_best_synthetic(profile, qv, candidates):
|
492 |
+
# Embed synthetic candidates and pick the one most similar to the query vector
|
493 |
+
if not candidates: return None
|
494 |
+
docs = [c["doc"] for c in candidates]
|
495 |
+
vecs = EMB.model.encode(docs, convert_to_numpy=True, normalize_embeddings=True)
|
496 |
+
sims = vecs @ qv
|
497 |
+
j = int(np.argmax(sims))
|
498 |
+
best = candidates[j].copy()
|
499 |
+
best["similarity"] = float(sims[j])
|
500 |
+
return best
|
501 |
+
|
502 |
+
|
503 |
+
# --------------------- Personalized Message (FLAN + validation) ---------------------
|
504 |
+
# Implementation ported from the Colab; tone-specific constraints + simple checks.
|
505 |
MSG_MODEL_ID = "google/flan-t5-small"
|
506 |
MSG_DEVICE = "cpu"
|
507 |
TEMP_RANGE = (0.88, 1.10)
|
|
|
603 |
]
|
604 |
|
605 |
def _msg_load():
|
606 |
+
# Lazy-load FLAN for message generation (CPU)
|
607 |
global _msg_tok, _msg_mdl
|
608 |
if _msg_tok is None or _msg_mdl is None:
|
609 |
_msg_tok = AutoTokenizer.from_pretrained(MSG_MODEL_ID)
|
|
|
612 |
return _msg_tok, _msg_mdl
|
613 |
|
614 |
def _norm(s: str) -> str:
|
615 |
+
# Collapse whitespace for more reliable validators
|
616 |
return re.sub(r"\s+", " ", s or "").strip()
|
617 |
|
618 |
def _sentences_n(s: str) -> int:
|
619 |
+
# Count sentences via punctuation boundaries
|
620 |
return len([p for p in re.split(r"(?<=[.!?])\s+", s.strip()) if p])
|
621 |
|
622 |
def _contains_any(text: str, terms: List[str]) -> bool:
|
623 |
+
# Case-insensitive containment check for any of the given terms
|
624 |
t = text.lower()
|
625 |
return any(term for term in terms if term) and any((term or "").lower() in t for term in terms)
|
626 |
|
627 |
def _too_similar(a: str, b: str, n=3, thr=0.85) -> bool:
|
628 |
+
# Approximate de-duplication via n-gram Jaccard similarity
|
629 |
def ngrams(txt):
|
630 |
toks = re.findall(r"[a-zA-Z']+", txt.lower())
|
631 |
return set(tuple(toks[i:i+n]) for i in range(max(0, len(toks)-n+1)))
|
|
|
635 |
return j >= thr
|
636 |
|
637 |
def _clean_occasion(occ: str) -> str:
|
638 |
+
# Normalize typographic apostrophes to ASCII and trim
|
639 |
return (occ or "").replace("โ","'").strip()
|
640 |
|
641 |
def _build_prompt(profile: Dict[str, Any]) -> Tuple[str, Dict[str,str]]:
|
642 |
+
# Compose a guided prompt (tone + micro-rules) for the message LLM
|
643 |
name = profile.get("recipient_name", "Friend")
|
644 |
rel = profile.get("relationship", "Friend")
|
645 |
occ = _clean_occasion(profile.get("occ_ui") or profile.get("occasion") or "Birthday")
|
|
|
672 |
|
673 |
@torch.inference_mode()
|
674 |
def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None, previous_message: Optional[str]=None) -> Dict[str, Any]:
|
675 |
+
# Sample multiple generations with slight sampling variance, validate, and return best
|
676 |
global _last_msg
|
677 |
tok, mdl = _msg_load()
|
678 |
if seed is None:
|
|
|
698 |
)
|
699 |
text = _norm(tok.decode(out_ids[0], skip_special_tokens=True))
|
700 |
|
701 |
+
# ===== Validators (mirrors the Colab logic) =====
|
702 |
ok_len = 1 <= _sentences_n(text) <= 3
|
703 |
name_ok = _contains_any(text, [need["name"].lower()])
|
704 |
occ_ok = _contains_any(text, [need["occ"].lower(), need["occ"].split()[0].lower()])
|
|
|
713 |
"seed": seed, "attempt": attempt, "model": MSG_MODEL_ID}}
|
714 |
tried.append({"text": text}); seed += 17
|
715 |
|
716 |
+
# Fallback if all attempts failed validation
|
717 |
fallback = tried[-1]["text"] if tried else f"Happy {(_clean_occasion(profile.get('occ_ui') or 'day')).lower()}, {profile.get('recipient_name','Friend')}!"
|
718 |
_last_msg = fallback
|
719 |
return {"message": fallback, "meta": {"failed": True, "model": MSG_MODEL_ID, "tone": profile.get("tone","Heartfelt")}}
|
|
|
722 |
|
723 |
# ===== Rendering & UI =====
|
724 |
def first_sentence(s,max_chars=140):
|
725 |
+
# Extract the first sentence or truncate; keeps the HTML cards compact
|
726 |
s=(s or "").strip();
|
727 |
if not s: return ""
|
728 |
cut=s.split(". ")[0];
|
729 |
return cut if len(cut)<=max_chars else cut[:max_chars-1]+"โฆ"
|
730 |
|
731 |
def render_top3_html(df, age_label):
|
732 |
+
# Render the 3 catalog picks plus the optional 4th "Generated" item
|
733 |
if df is None or df.empty: return "<em>No results found within the current filters.</em>"
|
734 |
rows=[]
|
735 |
+
for i, r in df.iterrows():
|
736 |
name=str(r.get("name","")).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
737 |
desc=str(first_sentence(r.get("short_desc",""))).replace("|","\\|").replace("*","\\*").replace("_","\\_")
|
738 |
price=r.get("price_usd"); sim=r.get("similarity"); img=r.get("image_url","") or ""
|
739 |
price_str=f"${price:.0f}" if pd.notna(price) else "N/A"; sim_str=f"{sim:.3f}" if pd.notna(sim) else "โ"
|
740 |
img_html=f'<img src="{img}" alt="" style="width:84px;height:84px;object-fit:cover;border-radius:10px;margin-left:12px;" />' if img else ""
|
741 |
+
tag = "Generated" if i==3 else f"#{i+1}"
|
742 |
rows.append(f"""
|
743 |
<div style="display:flex;align-items:flex-start;justify-content:space-between;gap:10px;padding:10px;border:1px solid #eee;border-radius:12px;margin-bottom:8px;background:#fff;">
|
744 |
+
<div style="flex:1;min-width:0;"><div style="font-weight:700;">{name} <span style="font-size:.8em;opacity:.7;">({tag})</span></div>
|
745 |
<div style="font-size:0.95em;margin-top:4px;">{desc}</div>
|
746 |
<div style="font-size:0.9em;margin-top:6px;opacity:0.8;">Price: <b>{price_str}</b> ยท Age: <code>{age_label}</code> ยท Score: <code>{sim_str}</code></div>
|
747 |
</div>{img_html}
|
|
|
787 |
tone=gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Funny")
|
788 |
|
789 |
go=gr.Button("Get GIfty!")
|
790 |
+
gr.Markdown("### ๐ Input summary"); out_summary = gr.HTML(visible=False)
|
791 |
gr.Markdown("### ๐ฏ Recommendations"); out_top3=gr.HTML()
|
792 |
gr.Markdown("### ๐ ๏ธ DIY Gift"); out_diy_md=gr.Markdown()
|
793 |
gr.Markdown("### ๐ Personalized Message"); out_msg=gr.Markdown()
|
794 |
run_token=gr.State(0)
|
795 |
|
796 |
def _on_example_select(evt: gr.SelectData):
|
797 |
+
# Clicking a row fills the input widgets with that example
|
798 |
r=int(evt.index[0] if isinstance(evt.index,(list,tuple)) else evt.index); row=EX_DF.iloc[r]; ints=[s.strip() for s in str(row["Interests"]).split("+")]
|
799 |
return (ints,row["Occasion"],int(row["Min $"]),int(row["Max $"]),row["Recipient"],row["Relationship"],row["Age group"],row["Gender"],row["Tone"])
|
800 |
ex_df.select(_on_example_select, outputs=[interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone])
|
801 |
|
802 |
def render_diy_md(j:dict)->str:
|
803 |
+
# Nicely format the DIY object as markdown
|
804 |
if not j: return "_DIY generation failed._"
|
805 |
steps=j.get('step_by_step_instructions', j.get('steps', []))
|
806 |
parts = [
|
|
|
811 |
f"**Estimated cost:** ${j.get('estimated_cost_usd','?')} ยท **Time:** {j.get('estimated_time_minutes','?')} min"
|
812 |
]
|
813 |
return "\n".join(parts)
|
814 |
+
def input_summary_html(p, age_label):
|
815 |
+
# Render a compact summary of the current input above the results
|
816 |
+
ints = ", ".join(p.get("interests", [])) or "โ"
|
817 |
+
budget = f"${int(float(p.get('budget_min',0)))}โ${int(float(p.get('budget_max',0)))}"
|
818 |
+
name = p.get("recipient_name","Friend"); rel = p.get("relationship","Friend")
|
819 |
+
occ = p.get("occ_ui", "Birthday"); gender = (p.get("gender","any") or "any").capitalize()
|
820 |
+
return f"""
|
821 |
+
<div style="padding:10px 12px;border:1px solid #e2e8f0;border-radius:12px;background:#f8fafc;margin-bottom:8px;">
|
822 |
+
<div style="display:flex;flex-wrap:wrap;gap:10px;align-items:center;">
|
823 |
+
<div><b>Recipient:</b> {name} ({rel})</div>
|
824 |
+
<div><b>Occasion:</b> {occ}</div>
|
825 |
+
<div><b>Age:</b> {age_label}</div>
|
826 |
+
<div><b>Gender:</b> {gender}</div>
|
827 |
+
<div><b>Budget:</b> {budget}</div>
|
828 |
+
<div style="flex-basis:100%;height:0;"></div>
|
829 |
+
<div><b>Interests:</b> {ints}</div>
|
830 |
+
</div>
|
831 |
+
</div>
|
832 |
+
"""
|
833 |
|
834 |
def _build_profile(ints, occ, bmin, bmax, name, rel, age_label, gender_val, tone_val):
|
835 |
+
# Convert UI widget values into an internal profile dict
|
836 |
try: bmin=float(bmin); bmax=float(bmax)
|
837 |
except: bmin,bmax=5.0,500.0
|
838 |
if bmin>bmax: bmin,bmax=bmax,bmin
|
839 |
return {"recipient_name":name or "Friend","relationship":rel or "Friend","interests":ints or [],"occ_ui":occ or "Birthday","budget_min":bmin,"budget_max":bmax,"age_range":AGE_OPTIONS.get(age_label,"any"),"gender":(gender_val or "any").lower(),"tone":tone_val or "Heartfelt"}
|
840 |
|
841 |
+
def start_run(curr):
|
842 |
+
# Simple monotonic counter to tie together chained events
|
843 |
+
return int(curr or 0) + 1
|
844 |
+
|
845 |
+
def predict_summary_only(rt, *args):
|
846 |
+
# args mapping:
|
847 |
+
# 0: interests, 1: occasion, 2: budget_min, 3: budget_max,
|
848 |
+
# 4: recipient_name, 5: relationship, 6: age_label, 7: gender, 8: tone
|
849 |
+
p = _build_profile(*args)
|
850 |
+
return gr.update(value=input_summary_html(p, args[6]), visible=True), rt
|
851 |
|
852 |
def predict_recs_only(rt, *args):
|
853 |
+
p = _build_profile(*args)
|
854 |
+
top3 = recommend_top3_budget_first(p, include_synth=False) # ืืืืจ
|
855 |
+
return gr.update(value=render_top3_html(top3, args[6]), visible=True), rt
|
856 |
+
|
857 |
+
def predict_recs_with_synth(rt, *args):
|
858 |
+
p = _build_profile(*args)
|
859 |
+
synth_n = int(os.getenv("SYNTH_N", "2"))
|
860 |
+
df = recommend_top3_budget_first(p, include_synth=True, synth_n=synth_n)
|
861 |
+
return gr.update(value=render_top3_html(df, args[6]), visible=True), rt
|
862 |
+
|
863 |
def predict_diy_only(rt, *args):
|
864 |
+
p = _build_profile(*args)
|
865 |
+
diy_json, _ = diy_generate(p)
|
866 |
+
return gr.update(value=render_diy_md(diy_json), visible=True), rt
|
867 |
|
868 |
+
def predict_msg_only(rt, *args):
|
869 |
+
p = _build_profile(*args)
|
870 |
+
msg_obj = generate_personal_message(p)
|
871 |
+
return gr.update(value=msg_obj["message"], visible=True), rt
|
872 |
+
|
873 |
+
ev_start = go.click(start_run, inputs=[run_token], outputs=[run_token], queue=True)
|
874 |
+
|
875 |
+
# 1) ืกืืืื ืงืื (ืืืืื)
|
876 |
+
ev_start.then(
|
877 |
+
predict_summary_only,
|
878 |
+
inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
|
879 |
+
outputs=[out_summary, run_token],
|
880 |
+
queue=True,
|
881 |
+
)
|
882 |
+
|
883 |
+
# 2) ืืืืฆืืช ืืืืจืืช (Top-3 ืืื ืกืื ืชืื)
|
884 |
+
recs_fast = ev_start.then(
|
885 |
+
predict_recs_only,
|
886 |
+
inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
|
887 |
+
outputs=[out_top3, run_token],
|
888 |
+
queue=True,
|
889 |
+
)
|
890 |
+
|
891 |
+
# 3) ืืืฉืื ืกืื ืชืื ืืฉืื ืืืฉื โ ืืจืขื ื ืืช ืืืชื out_top3 ืืฉืืืื
|
892 |
+
recs_fast.then(
|
893 |
+
predict_recs_with_synth,
|
894 |
+
inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
|
895 |
+
outputs=[out_top3, run_token],
|
896 |
+
queue=True,
|
897 |
+
)
|
898 |
+
|
899 |
+
# 4) DIY ืึพMessage ืืืืืื ืืจืืฅ ืืืงืืื ืึพ(3)
|
900 |
+
ev_start.then(
|
901 |
+
predict_diy_only,
|
902 |
+
inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
|
903 |
+
outputs=[out_diy_md, run_token],
|
904 |
+
queue=True,
|
905 |
+
)
|
906 |
+
ev_start.then(
|
907 |
+
predict_msg_only,
|
908 |
+
inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
|
909 |
+
outputs=[out_msg, run_token],
|
910 |
+
queue=True,
|
911 |
+
)
|
912 |
+
|
913 |
|
914 |
if __name__=="__main__":
|
915 |
demo.launch()
|