Spaces:
Sleeping
Sleeping
File size: 4,523 Bytes
c0032bb 58bda3d c0032bb 58bda3d c0032bb 58bda3d c0032bb 58bda3d c0032bb 58bda3d c0032bb 3baa867 58bda3d c0032bb 58bda3d c0032bb 58bda3d 0f9515d 67df04a e89f25d 58bda3d c0032bb 58bda3d 0f9515d 58bda3d e89f25d 0f9515d 58bda3d 0f9515d c0032bb 58bda3d e89f25d 0f9515d c0032bb e89f25d c0032bb 0f9515d c0032bb e89f25d c0032bb 58bda3d 0f9515d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import torch
import pandas as pd
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Define the GATConv model architecture
class ModeratelySimplifiedGATConvModel(torch.nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer runs ``heads`` attention heads; the second layer is
    constructed with ``hidden_channels * heads`` input features (i.e. it
    expects the heads' outputs concatenated) and produces ``out_channels``
    features per node with a single head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head output size of the first layer.
        out_channels: Size of each output node embedding.
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability applied between the two layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, *, heads=2, dropout=0.45):
        super().__init__()
        # Attribute names are part of the checkpoint format (state-dict keys
        # start with "conv1." / "conv2."), so they must not be renamed.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index, edge_attr=None):
        """Return one embedding per node.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to both layers.
            edge_attr: Optional edge features, forwarded to both layers.
                NOTE(review): the layers are built without ``edge_dim``, so
                they presumably ignore this argument -- confirm against the
                installed torch_geometric version.
        """
        hidden = self.conv1(x, edge_index, edge_attr)
        hidden = torch.relu(hidden)
        # Dropout is only a no-op in eval mode; callers doing inference must
        # call ``model.eval()`` first.
        hidden = self.dropout1(hidden)
        return self.conv2(hidden, edge_index, edge_attr)
# Load the dataset and the GATConv model weights
# NOTE(review): torch.load unpickles arbitrary Python objects; both files must
# come from a trusted source.
data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))

# Initialize the GATConv model (input width is taken from the node features)
gatconv_model = ModeratelySimplifiedGATConvModel(
    in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
)

# Correct the state dictionary's key names.
# The checkpoint stores each layer's projection as "...lin.weight". If the
# instantiated model exposes that key it is kept unchanged; otherwise the
# weight is copied to both "...lin_src.weight" and "...lin_dst.weight".
# NOTE(review): which naming the model uses presumably depends on the
# installed torch_geometric version -- confirm.
expected_keys = set(gatconv_model.state_dict())
corrected_state_dict = {}
for key, value in original_state_dict.items():
    if key in expected_keys or "lin.weight" not in key:
        corrected_state_dict[key] = value
    else:
        corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
        corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value
gatconv_model.load_state_dict(corrected_state_dict)
# Switch to inference mode: without this the Dropout layer stays active and
# the embeddings below would be randomly perturbed on every start-up.
gatconv_model.eval()

# Load the BERT-based sentence transformer model
model_bert = SentenceTransformer("all-mpnet-base-v2")

# Ensure the DataFrame is loaded properly
try:
    df = pd.read_json("combined_data.json.gz", orient='records', lines=True, compression='gzip')
except Exception as e:
    print(f"Error reading JSON file: {e}")
    # Nothing below can work without `df`; fail here instead of raising a
    # NameError on every request later.
    raise

# Generate GNN-based embeddings (no gradients are needed for inference)
with torch.no_grad():
    all_video_embeddings = gatconv_model(data.x, data.edge_index, data.edge_attr).cpu()
# Helper: rank rows of an embedding matrix by dot product with one of its rows
def rank_by_dot_product(all_video_embeddings, given_video_index, top_k=10):
    """Return the indices of the rows most similar to one row by dot product.

    Args:
        all_video_embeddings: 2-D tensor with one embedding per row.
        given_video_index: Row index of the reference embedding.
        top_k: Maximum number of indices to return.

    Returns:
        A list of at most ``top_k`` ints, ordered from highest to lowest dot
        product. The reference row itself is never included, so fewer than
        ``top_k`` indices are returned when there are not enough other rows.
    """
    reference = all_video_embeddings[given_video_index]
    # One matrix-vector product instead of a Python loop of torch.dot calls.
    scores = (all_video_embeddings @ reference).detach().cpu().numpy()
    ranked = np.argsort(scores)[::-1]
    ranked = ranked[ranked != given_video_index]
    return [int(i) for i in ranked[:max(top_k, 0)]]


# Function to find the most similar video and recommend the top 10 based on GNN embeddings
def get_similar_and_recommend(input_text):
    """Find the video closest to ``input_text`` and recommend 10 related ones.

    The input text is encoded with ``model_bert`` and compared by cosine
    similarity against the ``"embeddings"`` column of ``df``. The best match
    is then used as the reference row for a dot-product ranking over
    ``all_video_embeddings``.

    NOTE(review): this assumes row ``i`` of ``df`` and row ``i`` of
    ``all_video_embeddings`` describe the same video -- confirm.

    Args:
        input_text: Free text describing the desired video.

    Returns:
        A dict with ``"most_similar_video"`` (one feature dict) and
        ``"top_10_recommended_videos"`` (a list of feature dicts). The
        ``"text_for_embedding"`` and ``"embeddings"`` fields are removed
        from every feature dict.
    """
    unwanted_keys = ("text_for_embedding", "embeddings")

    def public_features(row_index):
        # All columns of one row except the embedding-related ones.
        features = df.iloc[row_index].to_dict()
        return {key: value for key, value in features.items() if key not in unwanted_keys}

    # Find the most similar video based on input text
    embeddings_matrix = np.array(df["embeddings"].tolist())
    input_embedding = model_bert.encode([input_text])[0]
    similarities = cosine_similarity([input_embedding], embeddings_matrix)[0]
    most_similar_index = int(np.argmax(similarities))

    # Recommend the top 10 videos based on GNN embeddings and dot product
    recommended_indices = rank_by_dot_product(all_video_embeddings, most_similar_index)

    # Create the output JSON with all features except the unwanted ones
    return {
        "most_similar_video": public_features(most_similar_index),
        "top_10_recommended_videos": [public_features(idx) for idx in recommended_indices],
    }
# Build the Gradio UI: one free-text input, one JSON output produced by
# get_similar_and_recommend (embedding-related fields are already removed there).
query_box = gr.components.Textbox(label="Enter Text to Find Most Similar Video")
interface = gr.Interface(
    fn=get_similar_and_recommend,
    inputs=query_box,
    outputs=gr.JSON(),
    title="Video Recommendation System with GNN-based Recommendations",
    description=(
        "Enter text to find the most similar video and get the top 10 recommended videos with all features except embeddings-related fields in a JSON object."
    ),
)

# Start the web server for the app.
interface.launch()
|