File size: 4,515 Bytes
40fa5b9
f128fe5
40fa5b9
ca764d6
 
 
 
 
 
 
40fa5b9
 
ca764d6
40fa5b9
 
 
 
f128fe5
 
40fa5b9
ba1c7a0
 
 
 
ca764d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download

# Custom and other imports
import project_config

# Gate the page first: per the helper's contract, visitors who are not logged
# in are redirected to app.py; everyone else gets the navigation menu.
menu_with_redirect()

# Banner image shown at the top of the page
header_image = project_config.MEDIA_DIR / 'validate_header.svg'
st.image(str(header_image), use_column_width=True)

# Disabled greeting, kept for reference:
# st.markdown(f"Hello, {st.session_state.name}!")

# Section title
st.subheader("Model Predictions", divider="green")

# Echo the active query as: source node ➡️ relation ➡️ target node type
query = st.session_state.query
st.markdown(
    f"**Query:** {query['source_node']} ➡️ {query['relation']} ➡️ {query['target_node_type']}"
)

@st.cache_data(show_spinner=False)
def _load_kg_nodes():
    """Read the knowledge-graph node table from ``DATA_DIR / 'kg_nodes.csv'``.

    Streamlit re-executes this script on every widget interaction, so the
    CSV would otherwise be parsed again each time the user moves a widget.
    ``st.cache_data`` keeps the parsed table and hands each run its own copy,
    so callers may add columns without affecting the cached value.

    NOTE: the cache is not keyed on the file contents; if the CSV changes on
    disk the app must be restarted (or the cache cleared) to pick it up.

    Returns:
        pandas.DataFrame with ``node_index`` parsed as ``int``.
    """
    return pd.read_csv(
        project_config.DATA_DIR / 'kg_nodes.csv',
        dtype={'node_index': int},
        low_memory=False,
    )


with st.spinner('Loading knowledge graph...'):
    # `nodes` is kept as an alias of `kg_nodes` for backward compatibility.
    kg_nodes = nodes = _load_kg_nodes()

# Resolve local paths for the three model artifacts (embeddings, relation
# weights, edge types) by fetching each file from the Hugging Face Hub repo.
with st.spinner('Downloading AI model...'):
    hf_token = st.secrets["HF_TOKEN"]
    model_files = (
        "2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
        "2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
        "2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
    )
    # Same repo and token for every file; only the filename varies.
    embed_path, relation_weights_path, edge_types_path = [
        hf_hub_download(repo_id="ayushnoori/galaxy", filename=model_file, token=hf_token)
        for model_file in model_files
    ]

@st.cache_resource(show_spinner=False)
def _load_model_artifacts(embed_file, relation_weights_file, edge_types_file):
    """Deserialize the three model artifacts from disk onto the CPU.

    ``map_location='cpu'`` lets files that were saved from a GPU process load
    on a CPU-only host, and keeps the tensors on the CPU, which the later
    ``.numpy()`` conversion of the scores requires.

    The result is cached with ``st.cache_resource`` because Streamlit re-runs
    the script on every widget interaction; without it the artifacts are
    re-read from disk each time. Cached objects are shared between reruns and
    sessions, so callers must treat them as read-only.

    Args:
        embed_file: Path to the saved embeddings.
        relation_weights_file: Path to the saved relation weights.
        edge_types_file: Path to the saved edge types.

    Returns:
        Tuple ``(embeddings, relation_weights, edge_types)`` exactly as
        returned by ``torch.load`` for each file.
    """
    return (
        torch.load(embed_file, map_location='cpu'),
        torch.load(relation_weights_file, map_location='cpu'),
        torch.load(edge_types_file, map_location='cpu'),
    )


# Load embeddings, relation weights, and edge types
with st.spinner('Loading AI model...'):
    embeddings, relation_weights, edge_types = _load_model_artifacts(
        embed_path, relation_weights_path, edge_types_path
    )

# # Print source node type
# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")

# # Print source node
# st.write(f"Source Node: {st.session_state.query['source_node']}")

# # Print relation
# st.write(f"Edge Type: {st.session_state.query['relation']}")

# # Print target node type
# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")

# Compute predictions: score every node of the requested target type against
# the query's source node under the query's relation, then rank and display.
with st.spinner('Computing predictions...'):

    source_node_type = st.session_state.query['source_node_type']
    source_node = st.session_state.query['source_node']
    relation = st.session_state.query['relation']
    target_node_type = st.session_state.query['target_node_type']

    # Get source node index. Fail with a readable message instead of an
    # IndexError when the node is not present in the node table.
    src_matches = kg_nodes.loc[
        (kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node),
        'node_index',
    ]
    if src_matches.empty:
        st.error(f"Source node '{source_node}' of type '{source_node_type}' was not found in the knowledge graph.")
        st.stop()
    src_index = src_matches.values[0]

    # Get relation index: position of the first (source type, relation,
    # target type) triple in edge_types that equals the query's triple.
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next(
        (i for i, etype in enumerate(edge_types) if etype == query_edge_type),
        None,
    )
    if edge_type_index is None:
        st.error(f"Edge type {query_edge_type} is not supported by the model.")
        st.stop()

    # Get target nodes indices. Copy the filtered rows so the columns added
    # below are written to an independent frame rather than to a slice of
    # kg_nodes (which triggers pandas' SettingWithCopyWarning).
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    if target_nodes.empty:
        st.warning(f"No nodes of type '{target_node_type}' were found in the knowledge graph.")
        st.stop()
    dst_indices = target_nodes.node_index.values
    src_indices = np.repeat(src_index, len(dst_indices))

    # Retrieve cached embeddings and apply the activation function
    src_embeddings = F.leaky_relu(embeddings[src_indices])
    dst_embeddings = F.leaky_relu(embeddings[dst_indices])

    # Get relation weights
    rel_weights = relation_weights[edge_type_index]

    # Relation-weighted dot product per (source, target) pair, squashed to
    # (0, 1) with a sigmoid.
    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim=1)
    scores = torch.sigmoid(scores)

    # Add scores to dataframe. `.cpu()` is a no-op for CPU tensors and keeps
    # the NumPy conversion working if the tensors live on another device.
    target_nodes['score'] = scores.detach().cpu().numpy()

    # Rank target nodes by score (rank 1 = highest score)
    target_nodes = target_nodes.sort_values(by='score', ascending=False)
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

# Show top ranked nodes. The default of 50 is clamped to the number of
# candidates so it never exceeds the slider's maximum; with a single
# candidate there is nothing to choose, so the slider is skipped.
n_targets = target_nodes.shape[0]
if n_targets > 1:
    top_k = st.slider('Select number of top ranked nodes to show.', 1, n_targets, min(50, n_targets))
else:
    top_k = n_targets

# Select the top_k rows and rename columns for display
display_data = target_nodes[['rank', 'node_id', 'node_name', 'node_source', 'score']].iloc[:top_k].copy()
display_data = display_data.rename(columns={
    'rank': 'Rank',
    'node_id': 'ID',
    'node_name': 'Name',
    'node_source': 'Database',
    'score': 'Score',
})
st.dataframe(display_data, use_container_width=True)