# NOTE: the file-size banner, per-line commit hashes and line-number gutter
# that a web file viewer prepended here were removed; they were not part of
# the program and prevented the module from being parsed.
import gradio as gr
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Define the GATConv model architecture
class ModeratelySimplifiedGATConvModel(torch.nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer uses two attention heads, so the second layer is built
    with ``hidden_channels * 2`` input features. A dropout layer (p=0.45)
    sits between the two convolutions.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        first_layer_heads = 2
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv1 = GATConv(in_channels, hidden_channels, heads=first_layer_heads)
        self.dropout1 = torch.nn.Dropout(0.45)
        self.conv2 = GATConv(
            hidden_channels * first_layer_heads, out_channels, heads=1
        )

    def forward(self, x, edge_index, edge_attr=None):
        """Apply conv1 -> ReLU -> dropout -> conv2 and return the result."""
        hidden = torch.relu(self.conv1(x, edge_index, edge_attr))
        return self.conv2(self.dropout1(hidden), edge_index, edge_attr)

# Load the graph dataset and the saved GATConv weights onto the CPU.
# NOTE(review): torch.load unpickles its input; both files must come from a
# trusted source.
data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))

# Build the model before touching the checkpoint so the saved key names can be
# compared with the parameter names the installed GATConv actually exposes.
gatconv_model = ModeratelySimplifiedGATConvModel(
    in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
)
expected_keys = set(gatconv_model.state_dict())

# Correct the state dictionary's key names.
# A saved "...lin.weight" entry is copied to both "...lin_src.weight" and
# "...lin_dst.weight", but only when the model does not already have a
# parameter under the saved name.
# NOTE(review): presumably the checkpoint was written by a torch_geometric
# release whose GATConv names its projection `lin`, while this script targeted
# one that uses `lin_src`/`lin_dst` -- confirm against the pinned version.
corrected_state_dict = {}
for key, value in original_state_dict.items():
    if key not in expected_keys and "lin.weight" in key:
        corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
        corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value
    else:
        corrected_state_dict[key] = value

# Initialize the GATConv model with the corrected state dictionary
gatconv_model.load_state_dict(corrected_state_dict)

# Load the sentence-transformer model used to embed the user's query text.
model_bert = SentenceTransformer("all-mpnet-base-v2")

# Load the video metadata table (gzip-compressed, one JSON record per line).
try:
    df = pd.read_json("combined_data.json.gz", orient='records', lines=True, compression='gzip')
except Exception as e:
    print(f"Error reading JSON file: {e}")
    # `df` is required by get_similar_and_recommend; continuing without it
    # would only defer the failure to a NameError on the first request.
    raise

# Generate GNN-based embeddings
# Switch to evaluation mode first: torch.no_grad() only disables gradient
# tracking, it does not deactivate the model's Dropout layer, which would
# otherwise randomly zero activations and make the embeddings differ between
# runs.
gatconv_model.eval()
with torch.no_grad():
    all_video_embeddings = gatconv_model(data.x, data.edge_index, data.edge_attr).cpu()

# Columns that are stripped from every record before it is returned.
_INTERNAL_COLUMNS = ("embeddings", "text_for_embedding")


def _public_video_features(row_index):
    """Return row ``row_index`` of ``df`` as a dict without the internal columns."""
    features = df.iloc[row_index].to_dict()
    for column in _INTERNAL_COLUMNS:
        features.pop(column, None)
    return features


def _top_scoring_video_indices(given_video_index, video_embeddings, weight, top_k=10):
    """Return up to ``top_k`` row indices ranked by weighted dot product.

    Every row of ``video_embeddings`` is scored against the row at
    ``given_video_index``; the best-scoring indices are returned highest
    first. The given video itself is never part of the result, so fewer than
    ``top_k`` indices come back when there are not enough other rows.
    """
    # One matrix-vector product replaces a Python loop of per-row torch.dot
    # calls. The result is a new tensor, so the in-place write below does not
    # modify ``video_embeddings``.
    scores = torch.mv(video_embeddings, video_embeddings[given_video_index]) * weight
    scores[given_video_index] = -float("inf")

    top_k = min(top_k, scores.shape[0] - 1)
    if top_k <= 0:
        return []
    return torch.topk(scores, top_k).indices.tolist()


# Function to find the most similar video and recommend the top 10 based on GNN embeddings
def get_similar_and_recommend(input_text):
    """Find the video closest to ``input_text`` and recommend related videos.

    The query is embedded with ``model_bert`` and compared by cosine
    similarity with the ``embeddings`` column of ``df`` to pick the most
    similar video. Recommendations are then ranked by the dot product between
    that video's row in ``all_video_embeddings`` and every other row.

    Assumes row ``i`` of ``df`` and row ``i`` of ``all_video_embeddings``
    describe the same video -- TODO confirm against how the graph was built.

    Args:
        input_text: Free-text query entered by the user.

    Returns:
        A dict with three keys: ``search_context`` (the input text and the
        weight that was applied), ``most_similar_video`` (features of the best
        text match) and ``top_10_recommended_videos`` (a list of feature
        dicts, best first). The ``embeddings`` and ``text_for_embedding``
        columns are removed from all returned records.
    """
    # Find the most similar video based on input text
    embeddings_matrix = np.array(df["embeddings"].tolist())
    input_embedding = model_bert.encode([input_text])[0]
    similarities = cosine_similarity([input_embedding], embeddings_matrix)[0]

    # Use unweighted scores for the most similar video. Converted to a plain
    # int so it can index both the DataFrame and the embedding tensor.
    most_similar_index = int(np.argmax(similarities))
    most_similar_video_features = _public_video_features(most_similar_index)

    # Apply search context weight for GNN recommendations.
    # The lower-cased titles are collected once instead of once per keyword.
    # NOTE(review): a keyword only matches when it equals a complete
    # lower-cased title; substring matching may have been intended.
    lowered_titles = set(df["title"].str.lower())
    weight = 1.0  # Initial weight factor
    for keyword in input_text.split():
        if keyword.lower() in lowered_titles:
            weight += 0.1  # Increase weight for each match

    # The weight multiplies every score by the same positive factor, so it is
    # reported in the output but does not change the ordering.
    top_indices = _top_scoring_video_indices(
        most_similar_index, all_video_embeddings, weight
    )
    top_10_recommended_videos_features = [
        _public_video_features(idx) for idx in top_indices
    ]

    # Create the output JSON with the search context
    return {
        "search_context": {
            "input_text": input_text,
            "weight": weight,  # Weight applied to the GNN recommendations
        },
        "most_similar_video": most_similar_video_features,
        "top_10_recommended_videos": top_10_recommended_videos_features,
    }

# Gradio front end: a single text box feeds get_similar_and_recommend and the
# returned dict is rendered as JSON.
_query_box = gr.Textbox(label="Enter Text to Find Most Similar Video")
_result_view = gr.JSON()

interface = gr.Interface(
    fn=get_similar_and_recommend,
    inputs=_query_box,
    outputs=_result_view,
    title="Video Recommendation System with GNN-based Recommendations",
    description="Enter text to find the most similar video and get top 10 recommended videos with search context applied to GNN results.",
)

# Start the web server for the interface.
interface.launch()