NimaKL committed on
Commit 0d2d9a0 · verified · 1 Parent(s): 4ba5bba

Update app.py

Files changed (1)
  1. app.py +35 -58
app.py CHANGED
@@ -12,41 +12,25 @@ class ModeratelySimplifiedGATConvModel(torch.nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels):
         super().__init__()
         self.conv1 = GATConv(in_channels, hidden_channels, heads=2)
-        self.dropout1 = torch.nn.Dropout(0.45)
-        self.conv2 = GATConv(hidden_channels * 2, out_channels, heads=1)
+        self.dropout1 = torch.nn.Dropout(0.45)
+        self.conv2 = GATConv(hidden_channels * 2, out_channels, heads=1)
 
     def forward(self, x, edge_index, edge_attr=None):
         x = self.conv1(x, edge_index, edge_attr)
         x = torch.relu(x)
         x = self.dropout1(x)
-        x = self.conv2(x, edge_index, edge_attr)
+        x = self.conv2(x, edge_index, edge_attr)
         return x
 
 # Load the dataset and the GATConv model
 data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
 
-# Correct the state dictionary's key names
-original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))
-corrected_state_dict = {}
-for key, value in original_state_dict.items():
-    if "lin.weight" in key:
-        corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
-        corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value
-    else:
-        corrected_state_dict[key] = value
-
-# Initialize the GATConv model with the corrected state dictionary
-gatconv_model = ModeratelySimplifiedGATConvModel(
-    in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
-)
-gatconv_model.load_state_dict(corrected_state_dict)
-
 # Load the BERT-based sentence transformer model
-model_bert = SentenceTransformer("all-mpnet-base-v2")
+model_bert = SentenceTransformer("all-mpnet-base-v2")
 
 # Ensure the DataFrame is loaded properly
 try:
-    df = pd.read_json("combined_data.json.gz", orient='records', lines=True, compression='gzip')
+    df = pd.read_json("combined_data.json.gz", orient='records', lines=True, compression='gzip')
 except Exception as e:
     print(f"Error reading JSON file: {e}")
 
@@ -56,74 +40,67 @@ with torch.no_grad():
 
 # Function to find the most similar video and recommend the top 10 based on GNN embeddings
 def get_similar_and_recommend(input_text):
-    # Find the most similar video based on input text
+    # Find the most similar video based on cosine similarity
     embeddings_matrix = np.array(df["embeddings"].tolist())
     input_embedding = model_bert.encode([input_text])[0]
     similarities = cosine_similarity([input_embedding], embeddings_matrix)[0]
 
-    most_similar_index = np.argmax(similarities) # Use unweighted scores for the most similar video
+    most_similar_index = np.argmax(similarities) # Find the most similar video
 
     # Get all features of the most similar video
     most_similar_video_features = df.iloc[most_similar_index].to_dict()
-    # Get all features of the most similar video
-    most_similar_video_features = df.iloc[most_similar_index].to_dict()
-
-    # Remove the "embeddings" key from most_similar_video_features
-    if "embeddings" in most_similar_video_features:
-        del most_similar_video_features["embeddings"]
-    if "text_for_embedding" in most_similar_video_features:
-        del most_similar_video_features["text_for_embedding"]
 
+    # Recommend the top 10 videos based on GNN embeddings
+    def recommend_top_10(given_video_index, all_video_embeddings):
+        dot_products = [
+            torch.dot(all_video_embeddings[given_video_index], all_video_embeddings[i])
+            for i in range(all_video_embeddings.shape[0])
+        ]
+        dot_products[given_video_index] = -float("inf") # Exclude the most similar video
+
+        top_10_indices = np.argsort(dot_products)[::-1][:10]
+        return [df.iloc[idx].to_dict() for idx in top_10_indices]
+
+    top_10_recommended_videos_features = recommend_top_10(most_similar_index, all_video_embeddings)
 
-    # Apply search context weight for GNN recommendations
+    # Apply search context to the top 10 results
     user_keywords = input_text.split() # Create a list of keywords from user input
-    weight = 1.0 # Initial weight factor
+    weight = 1.0 # Base weight factor
 
     for keyword in user_keywords:
         if keyword.lower() in df["title"].str.lower().tolist(): # Check for matching keywords
             weight += 0.1 # Increase weight for each match
 
-    # Recommend the top 10 videos based on GNN embeddings and weighted dot product
-    def recommend_next_10_videos(given_video_index, all_video_embeddings, weight):
-        dot_products = [
-            torch.dot(all_video_embeddings[given_video_index], all_video_embeddings[i]) * weight
-            for i in range(all_video_embeddings.shape[0])
-        ]
-        dot_products[given_video_index] = -float("inf")
+    # Adjust the recommendations based on the search context weight
+    final_recommendations = [
+        {key: value for key, value in video.items() if key != "embeddings"} # Exclude embeddings
+        for video in top_10_recommended_videos_features
+    ]
 
-        top_10_indices = np.argsort(dot_products)[::-1][:10]
-        return [df.iloc[idx].to_dict() for idx in top_10_indices]
-
-    top_10_recommended_videos_features = recommend_next_10_videos(
-        most_similar_index, all_video_embeddings, weight
+    # Apply the weight to sort the final recommendations (higher weight is better)
+    final_recommendations.sort(
+        key=lambda video: weight * dot_products[top_10_indices.index(video)], reverse=True
     )
 
-    # Exclude unwanted features for recommended videos
-    for recommended_video in top_10_recommended_videos_features:
-        if "text_for_embedding" in recommended_video:
-            del recommended_video["text_for_embedding"]
-        if "embeddings" in recommended_video:
-            del recommended_video["embeddings"]
-
-    # Create the output JSON with the search context
+    # Create the output JSON with the most similar video and final recommendations
     output = {
         "search_context": {
-            "input_text": input_text,
-            "weight": weight, # Weight applied to the GNN recommendations
+            "input_text": input_text, # What the user provided
+            "weight": weight, # Weight based on search context
        },
         "most_similar_video": most_similar_video_features,
-        "top_10_recommended_videos": top_10_recommended_videos_features,
+        "final_recommendations": final_recommendations, # Top 10 with search context applied
     }
 
     return output
 
-# Update the Gradio interface to output JSON with search context for GNN recommendations
+# Update the Gradio interface to output JSON with search context for the final recommendations
 interface = gr.Interface(
     fn=get_similar_and_recommend,
     inputs=gr.Textbox(label="Enter Text to Find Most Similar Video"),
     outputs=gr.JSON(),
     title="Video Recommendation System with GNN-based Recommendations",
-    description="Enter text to find the most similar video and get top 10 recommended videos with search context applied to GNN results.",
+    description="Enter text to find the most similar video and get top 10 recommended videos with search context applied after GNN-based search.",
 )
 
 interface.launch()
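
For reference, the recommend_top_10 helper in the new version ranks videos with a plain dot-product similarity search over precomputed GNN embeddings. Below is a minimal, self-contained sketch of that technique using random stand-in tensors; the names and data here are illustrative only and are not the Space's actual graph_data.pt or all_video_embeddings.

import numpy as np
import torch

# Stand-in for the precomputed GNN embeddings: 20 "videos", 768-dim each.
all_video_embeddings = torch.randn(20, 768)
given_video_index = 0  # index of the query video

# Score every video against the query video with a dot product.
dot_products = [
    torch.dot(all_video_embeddings[given_video_index], all_video_embeddings[i]).item()
    for i in range(all_video_embeddings.shape[0])
]
dot_products[given_video_index] = -float("inf")  # never recommend the query itself

# Highest scores first, keep the ten best indices.
top_10_indices = np.argsort(dot_products)[::-1][:10]
print(top_10_indices)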