LinDee commited on
Commit
bfa5376
·
verified ·
1 Parent(s): d54e819

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -14
app.py CHANGED
@@ -3,34 +3,68 @@ import pickle
3
  import pandas as pd
4
  import numpy as np
5
  from sklearn.metrics.pairwise import cosine_similarity
 
6
 
7
- # Load model and data with error handling
8
  try:
 
9
  with open("recommender_model.pkl", "rb") as f:
10
  model = pickle.load(f)
11
- posts_df = pd.read_csv("posts_cleaned.csv")
 
 
12
  post_texts = posts_df["post_text"].astype(str).tolist()
13
- post_embeddings = np.load("post_embeddings.npy") # Precomputed
14
  except Exception as e:
15
  raise gr.Error(f"Error loading files: {str(e)}")
16
 
17
- # Predict function
 
 
 
 
 
 
 
 
 
 
 
18
  def recommend_from_input(user_text):
19
  if not user_text.strip():
20
- return "Please enter valid text."
 
 
 
 
 
21
  user_vec = model.encode([user_text])
22
- sims = cosine_similarity(user_vec, post_embeddings)[0]
23
- top_idxs = sims.argsort()[-5:][::-1] # Top 5 posts
 
 
 
 
24
  return posts_df.iloc[top_idxs]["post_text"].tolist()
25
 
26
- # Gradio UI
27
  interface = gr.Interface(
28
  fn=recommend_from_input,
29
- inputs=gr.Textbox(label="Describe your interests"),
30
- outputs=gr.Dataframe(headers=["Recommended Posts"]),
31
- title="AI Content Recommender",
32
- description="Enter a topic or post to get recommendations",
33
- examples=[["Web3 security"], ["Machine learning trends"]]
 
 
 
 
 
 
 
 
 
34
  )
35
 
36
- interface.launch()
 
 
3
  import pandas as pd
4
  import numpy as np
5
  from sklearn.metrics.pairwise import cosine_similarity
6
+ from sentence_transformers import SentenceTransformer # Explicit import
7
 
8
+ # Load model and data
9
  try:
10
+ # Load your trained SentenceTransformer model
11
  with open("recommender_model.pkl", "rb") as f:
12
  model = pickle.load(f)
13
+
14
+ # Load posts dataset
15
+ posts_df = pd.read_csv("posts_cleaned.csv")
16
  post_texts = posts_df["post_text"].astype(str).tolist()
17
+
18
  except Exception as e:
19
  raise gr.Error(f"Error loading files: {str(e)}")
20
 
21
+ # Cache embeddings in memory after first computation
22
+ post_embeddings = None
23
+
24
+ def get_embeddings():
25
+ global post_embeddings
26
+ if post_embeddings is None:
27
+ print("Computing embeddings for all posts...")
28
+ post_embeddings = model.encode(post_texts, convert_to_tensor=False)
29
+ print("Embeddings computed!")
30
+ return post_embeddings
31
+
32
+ # Prediction function
33
  def recommend_from_input(user_text):
34
  if not user_text.strip():
35
+ return []
36
+
37
+ # Get embeddings (computes only once)
38
+ embeddings = get_embeddings()
39
+
40
+ # Encode user input
41
  user_vec = model.encode([user_text])
42
+
43
+ # Calculate similarities
44
+ sims = cosine_similarity(user_vec, embeddings)[0]
45
+ top_idxs = sims.argsort()[-5:][::-1] # Top 5 most similar posts
46
+
47
+ # Return as list of strings
48
  return posts_df.iloc[top_idxs]["post_text"].tolist()
49
 
50
+ # Gradio Interface
51
  interface = gr.Interface(
52
  fn=recommend_from_input,
53
+ inputs=gr.Textbox(label="What are you interested in?", placeholder="e.g. Web3 security tips"),
54
+ outputs=gr.Dataframe(
55
+ headers=["Recommended Posts"],
56
+ datatype=["str"],
57
+ col_count=(1, "fixed")
58
+ ),
59
+ title="🔍 AI Content Recommender",
60
+ description="Enter a topic or interest to get personalized post recommendations",
61
+ examples=[
62
+ ["Blockchain scalability solutions"],
63
+ ["Latest breakthroughs in AI"],
64
+ ["How to write smart contracts"]
65
+ ],
66
+ allow_flagging="never"
67
  )
68
 
69
+ # Launch with queue for stability
70
+ interface.launch(share=False)