# GATConvTest / app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sklearn.metrics.pairwise import cosine_similarity
# FastAPI App
app = FastAPI()
# Data and Model Initialization
data = torch.load("graph_data.pt", map_location=torch.device("cpu"))
# The saved checkpoint stores GATConv weights under "lin.weight"; the installed
# torch_geometric version expects "lin_src.weight"/"lin_dst.weight", so remap the keys.
original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))
corrected_state_dict = {}
for key, value in original_state_dict.items():
    if "lin.weight" in key:
        corrected_state_dict[key.replace("lin.weight", "lin_src.weight")] = value
        corrected_state_dict[key.replace("lin.weight", "lin_dst.weight")] = value
    else:
        corrected_state_dict[key] = value
# Define the GATConv model
class ModeratelySimplifiedGATConvModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=2)
        self.dropout1 = torch.nn.Dropout(0.45)
        self.conv2 = GATConv(hidden_channels * 2, out_channels, heads=1)

    def forward(self, x, edge_index, edge_attr=None):
        x = self.conv1(x, edge_index, edge_attr)
        x = torch.relu(x)
        x = self.dropout1(x)
        x = self.conv2(x, edge_index, edge_attr)
        return x
# Initialize GATConv model and BERT-based sentence transformer model
gatconv_model = ModeratelySimplifiedGATConvModel(
    in_channels=data.x.shape[1], hidden_channels=32, out_channels=768
)
gatconv_model.load_state_dict(corrected_state_dict)
gatconv_model.eval()  # inference only: disable dropout
model_bert = SentenceTransformer("all-mpnet-base-v2")
# Ensure DataFrame is loaded properly
df = pd.read_feather("EmbeddedCombined.feather")
# Get the most similar video to the input text, then recommend the top 10 videos based on GNN embeddings
def get_similar_and_recommend(input_text):
    embeddings_matrix = np.array(df["embeddings"].tolist())
    input_embedding = model_bert.encode([input_text])[0]
    similarities = cosine_similarity([input_embedding], embeddings_matrix)[0]
    most_similar_index = np.argmax(similarities)

    most_similar_video = {
        "title": df["title"].iloc[most_similar_index],
        "description": df["description"].iloc[most_similar_index],
        "similarity_score": float(similarities[most_similar_index]),
    }

    # Recommend the top 10 videos by dot-product similarity of GNN embeddings
    def recommend_next_10_videos(given_video_index, all_video_embeddings):
        dot_products = [
            torch.dot(
                all_video_embeddings[given_video_index].cpu(),
                all_video_embeddings[i].cpu(),
            ).item()
            for i in range(all_video_embeddings.shape[0])
        ]
        dot_products[given_video_index] = -float("inf")  # exclude the query video itself
        top_10_indices = np.argsort(dot_products)[::-1][:10]
        recommendations = [df["title"].iloc[idx] for idx in top_10_indices]
        return recommendations

    # Run the GNN forward pass without tracking gradients
    with torch.no_grad():
        all_video_embeddings = gatconv_model(data.x, data.edge_index, data.edge_attr).cpu()

    top_10_recommendations = recommend_next_10_videos(most_similar_index, all_video_embeddings)

    return {
        "most_similar_video_title": most_similar_video["title"],
        "top_10_recommendations": top_10_recommendations,
    }
# Define the endpoint for FastAPI to get video title and recommendations
class UserInput(BaseModel):
    text: str  # The string input from the user


@app.post("/recommendations")
def recommend_videos(user_input: UserInput):
    if not user_input.text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
    result = get_similar_and_recommend(user_input.text)
    return result
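

# Optional local entry point: a minimal sketch assuming uvicorn is installed
# and this file is saved as app.py. Port 7860 is what Hugging Face Spaces
# exposes by default; adjust as needed. Example request once the server is up:
#
#   curl -X POST http://localhost:7860/recommendations \
#        -H "Content-Type: application/json" \
#        -d '{"text": "graph neural networks explained"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)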