import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download

# Custom and other imports
import project_config

# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Header
st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)

# Main content
# st.markdown(f"Hello, {st.session_state.name}!")
st.subheader("Model Predictions", divider = "green")

# Print current query
st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")

with st.spinner('Loading knowledge graph...'):
    kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
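# NOTE: downstream code assumes kg_nodes includes at least the columns
# node_index (int), node_type, node_name, node_id, and node_source.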

# Get paths to embeddings, relation weights, and edge types
with st.spinner('Downloading AI model...'):
    embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                 filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
                                 token=st.secrets["HF_TOKEN"])
    relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                            filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
                                            token=st.secrets["HF_TOKEN"])
    edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                      filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
                                      token=st.secrets["HF_TOKEN"])

# Load embeddings, relation weights, and edge types
with st.spinner('Loading AI model...'):
    embeddings = torch.load(embed_path)
    relation_weights = torch.load(relation_weights_path)
    edge_types = torch.load(edge_types_path)
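# As used below, embeddings is expected to be a 2-D tensor of node embeddings
# indexed by node_index, relation_weights one weight vector per edge type, and
# edge_types a sequence of (source_type, relation, target_type) tuples.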

# # Print source node type
# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")

# # Print source node
# st.write(f"Source Node: {st.session_state.query['source_node']}")

# # Print relation
# st.write(f"Edge Type: {st.session_state.query['relation']}")

# # Print target node type
# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")

# Compute predictions
with st.spinner('Computing predictions...'):
    source_node_type = st.session_state.query['source_node_type']
    source_node = st.session_state.query['source_node']
    relation = st.session_state.query['relation']
    target_node_type = st.session_state.query['target_node_type']

    # Get source node index
    src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]

    # Get relation index
    edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]

    # Get target node indices (copy so score and rank columns can be added safely)
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    dst_indices = target_nodes.node_index.values
    src_indices = np.repeat(src_index, len(dst_indices))
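    # The single source node is paired against every candidate target node so
    # that all candidate edges can be scored in one vectorized pass below.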
    # Retrieve cached embeddings
    src_embeddings = embeddings[src_indices]
    dst_embeddings = embeddings[dst_indices]

    # Apply activation function
    src_embeddings = F.leaky_relu(src_embeddings)
    dst_embeddings = F.leaky_relu(dst_embeddings)

    # Get relation weights
    rel_weights = relation_weights[edge_type_index]

    # Compute weighted dot product
    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
    scores = torch.sigmoid(scores)
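    # Each candidate edge is scored with a relation-weighted dot product of the
    # source and target embeddings (akin to a DistMult decoder), then passed
    # through a sigmoid so scores fall in (0, 1).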

    # Add scores to dataframe
    target_nodes['score'] = scores.detach().numpy()

    # Rank target nodes by score
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)

    # Add rank to dataframe
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

# Show top ranked nodes
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(50, target_nodes.shape[0]))

# Select and rename columns for display
display_data = target_nodes[['rank', 'node_id', 'node_name', 'node_source', 'score']].iloc[:top_k].copy()
display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'node_source': 'Database', 'score': 'Score'})
st.dataframe(display_data, use_container_width = True)