Spaces:
Sleeping
Sleeping
File size: 5,112 Bytes
c0032bb 58bda3d c0032bb 58bda3d c0032bb 58bda3d 8926c50 58bda3d f156242 222fb9e 8926c50 58bda3d c0032bb f156242 58bda3d 8926c50 c0032bb 0d2d9a0 58bda3d c0032bb 3baa867 0d2d9a0 3baa867 58bda3d c0032bb f156242 c0032bb 58bda3d 0d2d9a0 f156242 5feda0d 0d2d9a0 5feda0d 20ae2d2 5feda0d 0d2d9a0 5feda0d 0d2d9a0 20ae2d2 0d2d9a0 58bda3d 20ae2d2 67df04a 0d2d9a0 58bda3d 0d2d9a0 20ae2d2 58bda3d 0d2d9a0 0f9515d 5feda0d 0d2d9a0 5feda0d 0f9515d 0d2d9a0 0f9515d c0032bb 0d2d9a0 c0032bb 5feda0d 0f9515d c0032bb 0d2d9a0 c0032bb 58bda3d 0f9515d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Define the GATConv model architecture
class ModeratelySimplifiedGATConvModel(torch.nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer is created with two attention heads and the second layer
    is created with ``hidden_channels * 2`` input features and a single head.
    A ReLU and a dropout layer (p=0.45) sit between the two convolutions.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the two convolutions and the dropout layer between them."""
        super().__init__()
        first_layer_heads = 2
        self.conv1 = GATConv(in_channels, hidden_channels, heads=first_layer_heads)
        self.dropout1 = torch.nn.Dropout(0.45)
        self.conv2 = GATConv(first_layer_heads * hidden_channels, out_channels, heads=1)

    def forward(self, x, edge_index, edge_attr=None):
        """Run conv1 -> ReLU -> dropout -> conv2 and return the result.

        ``edge_index`` and ``edge_attr`` are forwarded unchanged to both
        convolutions.
        """
        hidden = self.conv1(x, edge_index, edge_attr)
        hidden = self.dropout1(torch.relu(hidden))
        return self.conv2(hidden, edge_index, edge_attr)
# Load the graph dataset and the trained GATConv weights.
# NOTE(security): torch.load unpickles Python objects; "graph_data.pt" and
# "graph_model.pth" must come from a trusted source.
data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))

# Build the model before touching the checkpoint so the checkpoint keys can be
# reconciled against the parameter names this installation actually registers.
gatconv_model = ModeratelySimplifiedGATConvModel(
    in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
)
expected_keys = set(gatconv_model.state_dict())

# Correct the state dictionary's key names.
# Checkpoint keys containing "lin.weight" are duplicated under
# "lin_src.weight" and "lin_dst.weight" -- but only when the model does not
# already accept the key as-is. Renaming unconditionally makes
# load_state_dict fail when the installed GATConv registers "lin.weight".
corrected_state_dict = {}
for key, value in original_state_dict.items():
    if key in expected_keys or "lin.weight" not in key:
        corrected_state_dict[key] = value
        continue
    corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
    corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value

# Initialize the GATConv model with the corrected state dictionary
gatconv_model.load_state_dict(corrected_state_dict)
# Load the sentence-transformer model ("all-mpnet-base-v2").
# This must be an assignment: the previous `is` expression only compared
# identities and never bound the name `model_bert`.
model_bert = SentenceTransformer("all-mpnet-base-v2")
# Load the video records (gzip-compressed JSON Lines) into a DataFrame.
try:
    df = pd.read_json("combined_data.json.gz", orient='records', lines=True, compression='gzip')
except Exception as e:
    print(f"Error reading JSON file: {e}")
    # Nothing below can work without `df`; continuing would only defer the
    # failure to the first request as a NameError, so fail at startup instead.
    raise
# Generate GNN-based embeddings.
# Switch to eval mode first: torch.no_grad() only disables gradient tracking,
# it does not deactivate the model's Dropout layer, which would otherwise
# randomly zero features and make the embeddings differ on every start.
gatconv_model.eval()
with torch.no_grad():
    all_video_embeddings = gatconv_model(data.x, data.edge_index, data.edge_attr).cpu()
# Rank all videos against one reference video by embedding dot product
def _top_k_by_dot_product(given_video_index, all_video_embeddings, k=10):
    """Return up to ``k`` ``(index, score)`` pairs, best score first.

    Args:
        given_video_index: Row of the reference video; it is never returned.
        all_video_embeddings: 2-D tensor with one embedding per row.
        k: Maximum number of neighbours to return.

    Returns:
        A list of ``(row_index, dot_product)`` tuples sorted by descending
        dot product. It is shorter than ``k`` when fewer other rows exist.
    """
    given_video_index = int(given_video_index)
    # One matrix-vector product replaces a Python loop of torch.dot calls.
    scores = all_video_embeddings @ all_video_embeddings[given_video_index]
    # Ask for one extra candidate so the reference row can be dropped.
    limit = min(k + 1, scores.shape[0])
    best = torch.topk(scores, limit)
    pairs = [
        (idx, score)
        for idx, score in zip(best.indices.tolist(), best.values.tolist())
        if idx != given_video_index
    ]
    return pairs[:k]


# Function to find the most similar video and recommend the top 10 based on GNN embeddings
def get_similar_and_recommend(input_text):
    """Find the video closest to ``input_text`` and recommend 10 related videos.

    The query is embedded with ``model_bert`` and compared (cosine similarity)
    with the ``"embeddings"`` column of ``df`` to pick the most similar video.
    Recommendations are the videos whose GNN embeddings have the largest dot
    product with that video's GNN embedding.

    Args:
        input_text: Free-text query entered by the user.

    Returns:
        A dict with three keys: ``"search_context"`` (the query and its
        keyword weight), ``"most_similar_video"`` (all columns of the best
        match) and ``"final_recommendations"`` (up to 10 row dicts without
        the ``"embeddings"`` column).
    """
    # Find the most similar video based on cosine similarity
    embeddings_matrix = np.array(df["embeddings"].tolist())
    input_embedding = model_bert.encode([input_text])[0]
    similarities = cosine_similarity([input_embedding], embeddings_matrix)[0]
    most_similar_index = int(np.argmax(similarities))

    # Get all features of the most similar video
    most_similar_video_features = df.iloc[most_similar_index].to_dict()

    # Recommend the top 10 videos based on GNN embeddings.
    # NOTE(review): assumes row i of `df` is node i of the graph -- confirm.
    ranked = _top_k_by_dot_product(most_similar_index, all_video_embeddings, k=10)

    # Apply search context: each query keyword that matches adds 0.1 to the weight.
    # NOTE(review): as in the original, a keyword matches only when it equals
    # an entire lower-cased title, not a word inside one -- confirm intent.
    known_titles = set(df["title"].str.lower())  # built once, not per keyword
    weight = 1.0  # Base weight factor
    for keyword in input_text.split():
        if keyword.lower() in known_titles:
            weight += 0.1

    # Sort by weighted score (higher is better). The scores travel with their
    # row indices, so the sort key no longer depends on helper-local names.
    # The weight is a uniform positive factor, so it cannot reorder results.
    ranked.sort(key=lambda pair: weight * pair[1], reverse=True)

    final_recommendations = [
        {
            key: value
            for key, value in df.iloc[idx].to_dict().items()
            if key != "embeddings"  # Exclude embeddings
        }
        for idx, _score in ranked
    ]

    # Create the output JSON with the most similar video and final recommendations
    return {
        "search_context": {
            "input_text": input_text,  # What the user provided
            "weight": weight,  # Weight based on search context
        },
        "most_similar_video": most_similar_video_features,
        "final_recommendations": final_recommendations,  # Top 10 with search context applied
    }
# Gradio front end: a single text box in, the recommendation JSON out
query_box = gr.Textbox(label="Enter Text to Find Most Similar Video")
json_view = gr.JSON()

interface = gr.Interface(
    fn=get_similar_and_recommend,
    inputs=query_box,
    outputs=json_view,
    title="Video Recommendation System with GNN-based Recommendations",
    description="Enter text to find the most similar video and get top 10 recommended videos with search context applied after GNN-based search.",
)

# Start serving the app
interface.launch()
|