Spaces:
Sleeping
Sleeping
File size: 4,515 Bytes
40fa5b9 f128fe5 40fa5b9 ca764d6 40fa5b9 ca764d6 40fa5b9 f128fe5 40fa5b9 ba1c7a0 ca764d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import streamlit as st
from menu import menu_with_redirect
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download
# Custom and other imports
import project_config
# Send visitors who are not logged in back to app.py; logged-in users get the navigation menu
menu_with_redirect()

# Page banner and section title
header_svg = project_config.MEDIA_DIR / 'validate_header.svg'
st.image(str(header_svg), use_column_width=True)
st.subheader("Model Predictions", divider="green")

# Echo the current query as "source ➡️ relation ➡️ target type"
current_query = st.session_state.query
query_summary = " ➡️ ".join(
    f"{current_query[key]}" for key in ('source_node', 'relation', 'target_node_type')
)
st.markdown(f"**Query:** {query_summary}")
@st.cache_data(show_spinner=False)
def _load_kg_nodes():
    """Read the knowledge-graph node table and cache it across reruns.

    Streamlit re-executes this script on every widget interaction (such as the
    top-k slider further down), so without caching the CSV would be re-parsed
    each time. ``st.cache_data`` hands back a copy on each call, so callers may
    modify the returned frame freely.

    Returns:
        pandas.DataFrame read from ``DATA_DIR / 'kg_nodes.csv'`` with
        ``node_index`` parsed as ``int``.
    """
    return pd.read_csv(
        project_config.DATA_DIR / 'kg_nodes.csv',
        dtype={'node_index': int},
        low_memory=False,
    )


with st.spinner('Loading knowledge graph...'):
    kg_nodes = _load_kg_nodes()
# All model artifacts live in the same Hub repo and share one checkpoint prefix.
MODEL_REPO_ID = "ayushnoori/galaxy"
MODEL_CHECKPOINT = "2024_03_29_04_12_52_epoch=3-step=54291"


@st.cache_data(show_spinner=False)
def _download_model_file(artifact):
    """Download ``<MODEL_CHECKPOINT>_<artifact>.pt`` from the Hub.

    The path is cached so that reruns do not contact the Hub again.

    Args:
        artifact: artifact suffix, e.g. ``"embeddings"``.

    Returns:
        Local filesystem path (str) of the downloaded file.
    """
    return hf_hub_download(repo_id=MODEL_REPO_ID,
                           filename=f"{MODEL_CHECKPOINT}_{artifact}.pt",
                           token=st.secrets["HF_TOKEN"])


@st.cache_resource(show_spinner=False)
def _load_model_file(path):
    """Load a saved torch object once and share it across reruns and sessions.

    Because ``st.cache_resource`` returns the same object to every caller, the
    result must be treated as read-only (no in-place modification).

    Args:
        path: local path returned by ``_download_model_file``.
    """
    # map_location="cpu": predictions are later converted with .numpy(), which
    # requires CPU tensors, and the host may not have a GPU.
    return torch.load(path, map_location="cpu")


# Get paths to embeddings, relation weights, and edge types
with st.spinner('Downloading AI model...'):
    embed_path = _download_model_file("embeddings")
    relation_weights_path = _download_model_file("relation_weights")
    edge_types_path = _download_model_file("edge_types")

# Load embeddings, relation weights, and edge types
with st.spinner('Loading AI model...'):
    embeddings = _load_model_file(embed_path)
    relation_weights = _load_model_file(relation_weights_path)
    edge_types = _load_model_file(edge_types_path)
# Score every node of the requested target type against the query's source node
with st.spinner('Computing predictions...'):
    query = st.session_state.query
    source_node_type = query['source_node_type']
    source_node = query['source_node']
    relation = query['relation']
    target_node_type = query['target_node_type']

    # Get source node index; stop with a readable message instead of an IndexError
    source_rows = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)]
    if source_rows.empty:
        st.error(f"Source node '{source_node}' ({source_node_type}) was not found in the knowledge graph.")
        st.stop()
    src_index = int(source_rows.node_index.values[0])

    # Get relation index: position of the (source type, relation, target type) triple in edge_types
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next((i for i, etype in enumerate(edge_types) if etype == query_edge_type), None)
    if edge_type_index is None:
        st.error(f"The model has no edge type {query_edge_type}.")
        st.stop()

    # Candidate targets. .copy() so the columns added below go into an independent
    # frame rather than a slice of kg_nodes (avoids SettingWithCopyWarning).
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    if target_nodes.empty:
        st.warning(f"The knowledge graph has no nodes of type '{target_node_type}'.")
        st.stop()
    dst_indices = target_nodes.node_index.values

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Apply activation function. The single source row broadcasts against all
        # target rows, so it is not materialised len(dst_indices) times.
        src_embedding = F.leaky_relu(embeddings[src_index])
        dst_embeddings = F.leaky_relu(embeddings[dst_indices])
        # Weights for this edge type
        rel_weights = relation_weights[edge_type_index]
        # Weighted dot product, squashed into (0, 1) by the sigmoid
        scores = torch.sigmoid(torch.sum(src_embedding * rel_weights * dst_embeddings, dim=1))

    target_nodes['score'] = scores.detach().numpy()
    # Rank target nodes by score (rank 1 = highest score)
    target_nodes = target_nodes.sort_values(by='score', ascending=False)
    target_nodes['rank'] = np.arange(1, len(target_nodes) + 1)

# Show top ranked nodes. The default is clamped so it never exceeds the number of
# candidates; with a single candidate a 1..1 slider has no meaningful range.
num_targets = target_nodes.shape[0]
if num_targets > 1:
    top_k = st.slider('Select number of top ranked nodes to show.', 1, num_targets, min(50, num_targets))
else:
    top_k = num_targets

# Select, truncate and rename columns for display
display_data = (
    target_nodes[['rank', 'node_id', 'node_name', 'node_source', 'score']]
    .head(top_k)
    .rename(columns={'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'node_source': 'Database', 'score': 'Score'})
)
st.dataframe(display_data, use_container_width=True)
|